/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml.tests;

import conifer.data.PhylogeneticData;
import conifer.evol.EvolutionaryModel;
import conifer.evol.EvolutionaryOptions;
import conifer.evol.SimpleSubstitutionLikelihoodModel;
import conifer.fastmetrics.CladeMetrics;
import conifer.fastpf.FastPriorPrior;
import conifer.largemove.LargeMoveKernel;
import conifer.ml.tests.TestRealData;
import conifer.multicategories.PhylogeneticFactorGraph;
import fenchel.factor.multisitecat.MSCUnaryScaledFactor;
import fig.basic.LogInfo;
import fig.basic.Option;
import gep.util.OutputManager;
import goblin.Taxon;
import java.io.File;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import ma.SequenceType;
import nuts.io.IO;
import pty.UnrootedTree;
import pty.io.Dataset;

/**
 * Hold-out experiment. Data are generated from a reference tree read from
 * {@link #treeFile}. Then, for each of {@link #nPaths} independent paths, a
 * large-move kernel is iterated {@link #nIterPerPath} times starting from the
 * reference tree. After every step the current tree is compared to the
 * reference (clade metrics), and each leaf is held out in turn to score how
 * well the current tree predicts it. One "points" record is written per
 * (leaf, tree metric, error metric) combination.
 */
public class TestHoldOut
implements Runnable {
    private double[][] rateMtx;
    private Dataset dataset;
    /** Per-site all-ones indicator rows used in place of a held-out leaf's data. */
    private double[][] unobserved;
    private UnrootedTree trueTree;
    private OutputManager out = new OutputManager();
    /** Newick file containing the reference tree the data are generated from. */
    @Option
    public String treeFile = "/Users/bouchard/Documents/data/utcs/23S.E.raxml.nwk";
    @Option
    public int nPaths = 1000;
    @Option
    public Random rand = new Random(1L);
    @Option
    public int nIterPerPath = 100;
    @Option
    public boolean largeMoves = false;

    public static void main(String[] args) {
        IO.run(args, new TestHoldOut());
    }

    /**
     * Loads the reference tree, generates a protein dataset on it, records the
     * metrics for the reference tree itself, then runs the kernel paths.
     */
    @Override
    public void run() {
        this.trueTree = UnrootedTree.fromNewick(new File(this.treeFile));
        EvolutionaryOptions evoOptions = new EvolutionaryOptions();
        evoOptions.sequenceType = SequenceType.PROTEIN;
        evoOptions.nSites = 100;
        evoOptions.model = EvolutionaryModel.SIMPLE_SUBSTITUTION;
        SimpleSubstitutionLikelihoodModel params = (SimpleSubstitutionLikelihoodModel)evoOptions.model.instantiateParameters(evoOptions, null);
        this.rateMtx = params.rateMatrix;
        PhylogeneticData data = params.generate(this.trueTree);
        this.dataset = (Dataset)((Object)data.getObservedTaxonIndexData());
        this.unobserved = TestHoldOut.allOnes(evoOptions.nSites, this.rateMtx.length);
        // NOTE(review): the reference tree is reported as path 0, which is also
        // the index of the first kernel path below -- confirm this is intended.
        this.process(this.trueTree, 0);
        LargeMoveKernel.LargeMoveKernelOptions ko = new LargeMoveKernel.LargeMoveKernelOptions();
        ko.useRegraftMode = false;
        // Built only to obtain the potentials (priors, rate matrices, stationary
        // distributions, error probability) that the kernel constructor needs.
        PhylogeneticFactorGraph templateGraph = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(this.trueTree.reRootAtArbitraryInternalNode(), this.rateMtx, this.dataset);
        for (int i = 0; i < this.nPaths; ++i) {
            LogInfo.track("Processing path " + (i + 1) + "/" + this.nPaths);
            try {
                LargeMoveKernel.LargeMoveParticle lmp = new LargeMoveKernel.LargeMoveParticle(this.trueTree, 10000, Double.NaN, Double.NaN, Double.NaN, Double.NaN);
                for (int j = 0; j < this.nIterPerPath; ++j) {
                    // The kernel is built around the current particle, so it is rebuilt every step.
                    LargeMoveKernel lmk = new LargeMoveKernel(ko, new FastPriorPrior.SimplePriorOptions(), lmp, templateGraph.potentials.categoryPriors, templateGraph.potentials.rateMatrices, this.dataset, templateGraph.potentials.stationaryDistributions, templateGraph.potentials.observationErrorProbability, null);
                    lmp = lmk.next(this.rand, lmp, false, this.largeMoves).getFirst();
                    this.process(lmp.tree, i);
                }
            } finally {
                // Keep LogInfo's track nesting balanced even if an iteration throws.
                LogInfo.end_track();
            }
        }
    }

    /**
     * Writes one "points" record for every combination of observed leaf,
     * tree metric (of {@code t2} against the reference tree) and hold-out
     * error metric for that leaf.
     *
     * @param t2   tree to evaluate
     * @param path path index stored in each record
     */
    public void process(UnrootedTree t2, int path) {
        Map<CladeMetrics.TreeMetric, Double> treeMetrics = CladeMetrics.computeTreeMetrics(this.trueTree, t2);
        for (Taxon leaf : this.dataset.observations().keySet()) {
            Map<ErrorMetric, Double> meanPred = this.predictives(t2, leaf);
            for (Map.Entry<CladeMetrics.TreeMetric, Double> tm : treeMetrics.entrySet()) {
                for (Map.Entry<ErrorMetric, Double> em : meanPred.entrySet()) {
                    this.out.write("points", new Object[]{"path", path, "treeMetricType", tm.getKey(), "errorMetricType", em.getKey(), "treeMetric", tm.getValue(), "errorMetric", em.getValue()});
                }
            }
        }
    }

    /**
     * Holds out {@code leaf}, replacing its data with all-ones indicators. It
     * then scores the posterior over that leaf's characters, computed on
     * {@code t2}, against the leaf's actual data.
     *
     * @param t2   tree on which the posterior is computed
     * @param leaf taxon whose observations are hidden and predicted
     * @return per-site averages of {@link ErrorMetric#NEG_CROSS_ENTROPY}
     *         (sum of truth * log posterior) and {@link ErrorMetric#PRED_PROB}
     *         (sum of truth * posterior). The map iterates in enum
     *         declaration order.
     */
    public Map<ErrorMetric, Double> predictives(UnrootedTree t2, Taxon leaf) {
        Map<Taxon, double[][]> observations = this.dataset.observations();
        double[][] truth = observations.get(leaf);
        HashMap<Taxon, double[][]> map = new HashMap<Taxon, double[][]>();
        for (Map.Entry<Taxon, double[][]> entry : observations.entrySet()) {
            Taxon t = entry.getKey();
            map.put(t, t.equals(leaf) ? this.unobserved : entry.getValue());
        }
        TestRealData.SimpleObservations restricted = new TestRealData.SimpleObservations(map);
        PhylogeneticFactorGraph pfg = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(t2.reRootAtArbitraryInternalNode(), this.rateMtx, restricted);
        MSCUnaryScaledFactor posterior = pfg.getNodePosterior(leaf);
        double negCrossEntropy = 0.0;
        double predProb = 0.0;
        for (int s = 0; s < posterior.nSites; ++s) {
            double[] sitePost = pfg.posteriorOverCharacters(leaf, s);
            for (int x = 0; x < posterior.nCharacters; ++x) {
                double weight = truth[s][x];
                // Convention 0 * log 0 = 0: without this guard, a zero-weight
                // character with zero posterior mass yields 0 * -Infinity = NaN.
                if (weight != 0.0) {
                    negCrossEntropy += weight * Math.log(sitePost[x]);
                }
                predProb += weight * sitePost[x];
            }
        }
        // EnumMap gives a fixed iteration order, so the records written by
        // process() come out in the same order on every run.
        EnumMap<ErrorMetric, Double> result = new EnumMap<ErrorMetric, Double>(ErrorMetric.class);
        result.put(ErrorMetric.NEG_CROSS_ENTROPY, negCrossEntropy / (double)posterior.nSites);
        result.put(ErrorMetric.PRED_PROB, predProb / (double)posterior.nSites);
        return result;
    }

    /** Returns an nSites x nChars matrix of 1.0, with a separate row array per site. */
    private static double[][] allOnes(int nSites, int nChars) {
        double[][] rows = new double[nSites][nChars];
        for (double[] row : rows) {
            Arrays.fill(row, 1.0);
        }
        return rows;
    }

    /** Hold-out scores reported by {@link TestHoldOut#predictives}. */
    public static enum ErrorMetric {
        /** Per-site average of sum_x truth[x] * log(posterior[x]). */
        NEG_CROSS_ENTROPY,
        /** Per-site average of sum_x truth[x] * posterior[x]. */
        PRED_PROB;

    }
}

