/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml.data;

import conifer.ml.data.CharacterReconstructionMethod;
import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Pair;
import gep.util.OutputManager;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import nuts.lang.ArrayUtils;
import nuts.math.RateMtxUtils;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Indexer;
import nuts.util.MathUtils;

public class HeldoutData {
    /**
     * Held-out observations recorded by {@link #holdout}: maps a (taxon, site index) pair to the
     * indexer id of the character that was observed at that site before it was masked.
     */
    public final Map<Pair<Taxon, Integer>, Integer> heldout = CollUtils.map();

    /** Maps characters to the column indices used in observation and posterior vectors. */
    private final Indexer<Character> indexer;

    /**
     * Scores a reconstruction method on every held-out (taxon, site) pair.
     *
     * <p>For each pair, the method's posterior over characters is checked to have one entry per
     * indexed character and to be a probability vector. It is then accumulated into the sum of
     * log predictives, the sum of predictives, the argmax accuracy and a confusion matrix.
     *
     * @param reconstructionMethod method queried via {@code posteriorOverCharacters(taxon, site)}
     * @return aggregate statistics over all held-out pairs
     * @throws RuntimeException if a posterior does not have one entry per indexed character
     */
    public EvalResult evaluate(CharacterReconstructionMethod reconstructionMethod) {
        final int nCharacters = this.indexer.size();
        EvalResult result = new EvalResult();
        result.indexer = this.indexer;
        result.confusions = new double[nCharacters][nCharacters];
        for (Map.Entry<Pair<Taxon, Integer>, Integer> entry : this.heldout.entrySet()) {
            Pair<Taxon, Integer> key = entry.getKey();
            int truth = entry.getValue();
            double[] reconstruction =
                    reconstructionMethod.posteriorOverCharacters(key.getFirst(), key.getSecond());
            if (reconstruction.length != nCharacters) {
                throw new RuntimeException("Posterior for " + key + " has " + reconstruction.length
                        + " entries but the indexer has " + nCharacters + " characters");
            }
            MathUtils.checkIsProb(reconstruction);
            result.sumOfLogPredictives += Math.log(reconstruction[truth]);
            result.sumOfPredictives += reconstruction[truth];
            result.n += 1.0;
            // Row = true character, column = another character: accumulate the posterior mass
            // placed on wrong characters. The diagonal is deliberately left at zero.
            for (int other = 0; other < nCharacters; ++other) {
                if (other != truth) {
                    result.confusions[truth][other] += reconstruction[other];
                }
            }
            if (ArrayUtils.argmax(reconstruction) == truth) {
                result.nCorrect += 1.0;
            }
        }
        for (double[] row : result.confusions) {
            // Rows with no mass (a character never held out, or always reconstructed with all of
            // its mass on the truth) stay all-zero instead of being normalized by a zero total.
            if (sum(row) > 0.0) {
                NumUtils.normalize(row);
            }
        }
        return result;
    }

    /**
     * Creates an empty held-out set.
     *
     * @param indexer maps characters to the column indices of observation/posterior vectors
     */
    public HeldoutData(Indexer<Character> indexer) {
        this.indexer = indexer;
    }

    /**
     * Randomly masks observed characters at odd site indices, mutating {@code data} in place.
     *
     * <p>Taxa are visited in sorted order (presumably so that a fixed {@code rand} seed selects the
     * same sites regardless of the map's iteration order). At each odd site {@code rand} is drawn
     * once. If the draw is below {@code holdPr} and the site has exactly one observed character,
     * that character is recorded in {@link #heldout}. The site's vector is then replaced by a
     * fresh all-ones vector of the same length.
     *
     * @param data per-taxon observations; {@code data.get(t)[site]} must contain only 0.0/1.0
     *     entries with at least one 1.0, otherwise a RuntimeException is thrown
     * @param holdPr probability of holding out each eligible odd site
     * @param rand source of randomness; one {@code nextDouble()} per odd site visited
     * @return an empty map. NOTE(review): nothing is ever inserted into it; it is kept only for
     *     interface compatibility — confirm whether callers expect content here
     */
    public Map<Taxon, String> holdout(Map<Taxon, double[][]> data, double holdPr, Random rand) {
        Counter<Character> heldoutDist = new Counter<Character>();
        HashMap<Taxon, String> result = new HashMap<Taxon, String>();
        ArrayList<Taxon> sorted = new ArrayList<Taxon>(data.keySet());
        Collections.sort(sorted);
        LogInfo.track("Holding out (only odd column indices)");
        try {
            int nHeld = 0;
            for (Taxon t : sorted) {
                double[][] current = data.get(t);
                // Only odd site indices are eligible. The random draw happens before the site is
                // inspected, so rand is consumed exactly once per odd site.
                for (int i = 1; i < current.length; i += 2) {
                    if (!(rand.nextDouble() < holdPr)) {
                        continue;
                    }
                    Character c = this.findChar(current[i]);
                    if (c == null) {
                        continue; // several characters marked observed: nothing to hold out
                    }
                    heldoutDist.incrementCount(c, 1.0);
                    this.heldout.put(Pair.makePair(t, i), this.indexer.o2i(c));
                    ++nHeld;
                    // A fresh vector per site, so masked sites never share (alias) one array.
                    current[i] = allOnes(current[i].length);
                }
            }
            LogInfo.logsForce("" + nHeld + " characters held out");
            heldoutDist.normalize();
            LogInfo.logsForce("Distribution of held-out characters: " + heldoutDist);
        } finally {
            // Close the log track even when findChar rejects a malformed vector.
            LogInfo.end_track();
        }
        return result;
    }

    /**
     * Decodes a 0/1 indicator vector into its single observed character.
     *
     * @param ds vector whose entries must each be exactly 0.0 or 1.0
     * @return the character whose entry is 1.0, or {@code null} as soon as a second 1.0 is seen
     *     (ambiguous or unknown site)
     * @throws RuntimeException if an entry is neither 0.0 nor 1.0 (checked up to the point of
     *     return), or if no entry is 1.0
     */
    private Character findChar(double[] ds) {
        int observed = -1;
        for (int i = 0; i < ds.length; ++i) {
            double value = ds[i];
            if (value == 1.0) {
                if (observed != -1) {
                    return null;
                }
                observed = i;
            } else if (value != 0.0) {
                throw new RuntimeException(
                        "Expected a 0/1 indicator vector but found " + value + " at index " + i);
            }
        }
        if (observed == -1) {
            throw new RuntimeException("Indicator vector has no entry equal to 1.0");
        }
        return this.indexer.i2o(observed);
    }

    /** Returns a new vector of the given length with every entry set to 1.0. */
    private static double[] allOnes(int length) {
        double[] ones = new double[length];
        for (int i = 0; i < length; ++i) {
            ones[i] = 1.0;
        }
        return ones;
    }

    /** Returns the sum of the entries of {@code values}. */
    private static double sum(double[] values) {
        double total = 0.0;
        for (double v : values) {
            total += v;
        }
        return total;
    }

    /** Aggregate scores produced by {@link HeldoutData#evaluate}. */
    public static class EvalResult {
        /** Indexer used to label the rows/columns of {@link #confusions} when reporting. */
        private Indexer<Character> indexer;
        /**
         * Row = true character, column = another character. Each row holds the posterior mass
         * placed on wrong characters, normalized when it has positive total; the diagonal is zero.
         */
        private double[][] confusions;
        /** Sum over held-out sites of the natural log of the posterior at the true character. */
        public double sumOfLogPredictives = 0.0;
        /** Number of held-out sites whose posterior argmax is the true character. */
        public double nCorrect = 0.0;
        /** Number of held-out sites evaluated. */
        public double n = 0.0;
        /** Sum over held-out sites of the posterior at the true character. */
        public double sumOfPredictives = 0.0;

        @Override
        public String toString() {
            return "perplexity = " + this.perplexity() + "\n" + this.accuracy();
        }

        /** Fraction of held-out sites whose argmax is correct ({@code NaN} when {@code n == 0}). */
        public double accuracy() {
            return this.nCorrect / this.n;
        }

        /**
         * Returns {@code 2^(-mean log2 predictive)}, i.e. {@code exp(-sumOfLogPredictives / n)}.
         */
        public double perplexity() {
            return Math.pow(2.0, -this.sumOfLogPredictives / Math.log(2.0) / this.n);
        }

        /** Mean posterior probability assigned to the true character. */
        public double meanPredictive() {
            return this.sumOfPredictives / this.n;
        }

        /**
         * Writes the scalar scores to {@code output} and logs the confusion matrix.
         *
         * @param output destination of the perplexity, sum of log predictives, accuracy and
         *     mean predictive
         */
        public void report(OutputManager output) {
            output.printWrite("perplexity", "perplexity", this.perplexity());
            output.printWrite("sumLogPred", "sumLogPred", this.sumOfLogPredictives);
            output.printWrite("accuracy", "accuracy", this.accuracy());
            output.printWrite("meanPred", "meanPred", this.meanPredictive());
            LogInfo.logsForce("confusions:\n" + RateMtxUtils.toString(this.confusions, this.indexer));
        }
    }
}

