/*
 * Decompiled with CFR 0.152.
 */
package conifer.spr;

import conifer.data.DataModel;
import conifer.data.DataOptions;
import conifer.data.PhylogeneticData;
import conifer.fastmetrics.CladeMetrics;
import conifer.fastpf.TaxaOrderHeuristic;
import conifer.largemove.LargeMoveUtils;
import conifer.ml.data.PhylogeneticHeldoutDataset;
import conifer.ml.tests.TestRealData;
import conifer.multicategories.PhylogeneticFactorGraph;
import conifer.multicategories.PhylogenyPotentials;
import conifer.spr.SPROperator;
import fenchel.factor.UnaryFactor;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.OptionSet;
import gep.util.OutputManager;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.RateMatrixLoader;
import nuts.io.IO;
import nuts.util.Counter;
import pty.Observations;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.io.Dataset;

/**
 * Experiment driver that rebuilds a tree by inserting taxa one at a time with an
 * {@link SPROperator} and measures how well the candidate regrafts recover a guide tree.
 *
 * <p>The guide tree comes either from a synthetic generator or from a real held-out dataset.
 * Starting from the guide restricted to the first three taxa of a (heuristic or shuffled)
 * order, each remaining taxon is regrafted. Each resulting candidate is scored against the
 * guide restricted to the same taxa. Per-step statistics are written through an
 * {@link OutputManager}.
 */
public class TestSPR implements Runnable {
    /**
     * Slack, in units of one increment, used by {@link #enumerateDoubles} so that an endpoint
     * which is only missed through floating-point rounding is still included.
     */
    private static final double STEP_TOLERANCE = 1e-9;

    /** If true, generate a synthetic dataset; otherwise load real data via {@link #phyloOptions}. */
    @Option
    public boolean generate = true;

    /** Options for the synthetic data generator (used when {@link #generate} is true). */
    @OptionSet(name = "synt")
    public DataOptions syntheticOptions = new DataOptions();

    /** Options for loading a real held-out dataset (used when {@link #generate} is false). */
    @OptionSet(name = "real")
    public PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions phyloOptions =
        new PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions();

    /** If true, shuffle the heuristic taxa order (with the run's seeded RNG) before inserting. */
    @Option
    public boolean useRandomOrder = false;

    /** Leaf factors for every taxon, computed once from the factor graph of the full guide tree. */
    private Map<Taxon, UnaryFactor> leafFactors;

    /** Observations of the dataset being analysed. */
    private Observations observations;

    /** Grid of values each passed to {@code SPROperator.addRegrafts} when building an operator. */
    private double[] fractionsFromBottom;

    /** Grid of stem lengths handed to the {@link SPROperator} constructor. */
    private double[] stemLens;

    /** Single-category K2P potentials used for every factor graph built in this run. */
    private PhylogenyPotentials potentials;

    /** The loaded (or generated) dataset, including the guide tree. */
    private PhylogeneticHeldoutDataset phyloData;

    /** Sink for the per-step statistics ("l1First", "massCorrect", "averageFirstCorrect"). */
    private final OutputManager outm = new OutputManager();

    /**
     * If true, every step starts from the guide tree restricted to the taxa placed so far.
     * If false, every step after the first starts from the tree chosen at the previous step.
     */
    @Option
    public boolean optimistic = true;

    /** Number of {@link #checkSPR} calls whose first candidate had the correct topology. */
    double nTimesFirstIsCorrectTopo = 0.0;

    /** Number of {@link #checkSPR} calls so far. */
    double nTimesTested = 0.0;

    public static void main(String[] args) {
        IO.run(args, new TestSPR());
    }

    /**
     * Returns the evenly spaced grid {@code min, min + increment, min + 2*increment, ...}, up to
     * and including {@code max}.
     *
     * <p>Each value is computed as {@code min + i * increment} rather than by repeated addition.
     * This keeps rounding error from accumulating and silently dropping the final endpoint:
     * repeated addition of 0.1 overshoots 1.5 and would omit it. The step count tolerates
     * rounding of up to {@link #STEP_TOLERANCE} of one increment, and values are clamped to
     * {@code max}.
     *
     * @param min first value of the grid
     * @param max inclusive upper bound of the grid
     * @param increment spacing between consecutive values; must be positive
     * @return the grid. It is empty when {@code min > max} or {@code min}/{@code max} is NaN.
     * @throws IllegalArgumentException if {@code increment} is not positive (or is NaN) while
     *     the range is non-empty, or if the grid would be too large to store in an array
     */
    public static double[] enumerateDoubles(double min, double max, double increment) {
        // Written as a negated comparison so that NaN bounds also yield an empty grid.
        if (!(min <= max)) {
            return new double[0];
        }
        if (!(increment > 0.0)) {
            throw new IllegalArgumentException("increment must be positive, got " + increment);
        }
        double stepCount = Math.floor((max - min) / increment + STEP_TOLERANCE);
        if (stepCount >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                "too many values between " + min + " and " + max + " with increment " + increment);
        }
        int steps = (int) stepCount;
        double[] result = new double[steps + 1];
        for (int i = 0; i <= steps; ++i) {
            result[i] = Math.min(min + i * increment, max);
        }
        return result;
    }

    /**
     * Runs the experiment.
     *
     * <p>Steps:
     * <ol>
     *   <li>Load the data and log the guide-tree log-likelihood as {@code guideLL}.</li>
     *   <li>Insert taxa one by one, scoring every candidate with {@link #checkSPR}.</li>
     *   <li>Log the log-likelihood of the final reconstructed tree as {@code finalLL}.</li>
     * </ol>
     * A single {@code Random} with a fixed seed drives every rooting and the optional shuffle,
     * so runs are reproducible.
     */
    @Override
    public void run() {
        this.phyloData = loadPhyloData();
        UnrootedTree guide = this.phyloData.rootedTree.getUnrooted();
        this.observations = this.phyloData.obs;
        this.potentials = PhylogenyPotentials.createSingleCategoryErrorFreeFromStationaryProcess(RateMatrixLoader.k2p());
        this.fractionsFromBottom = enumerateDoubles(0.1, 0.9, 0.1);
        this.stemLens = enumerateDoubles(0.1, 1.5, 0.1);

        Random rand = new Random(1L);
        PhylogeneticFactorGraph fullTree = rootAndLogLikelihood("guideLL", guide, rand);
        this.leafFactors = SPROperator.leafFactors(fullTree);

        List<Taxon> order = TaxaOrderHeuristic.heuristicOrder(this.observations.observations());
        if (this.useRandomOrder) {
            Collections.shuffle(order, rand);
        }

        UnrootedTree previous = null;
        for (int nPlaced = 3; nPlaced < guide.nTaxa(); ++nPlaced) {
            // order.subList(0, nPlaced) is already in the tree; order.get(nPlaced) is inserted next.
            Taxon taxonToAdd = order.get(nPlaced);
            UnrootedTree treeBefore = this.optimistic || previous == null
                ? UnrootedTree.restrict(guide, new HashSet<Taxon>(order.subList(0, nPlaced)))
                : previous;
            SPROperator op = buildSPROperator(taxonToAdd, treeBefore, rand);
            UnrootedTree goldTreeAfter =
                UnrootedTree.restrict(guide, new HashSet<Taxon>(order.subList(0, nPlaced + 1)));
            previous = checkSPR(goldTreeAfter, op);
        }

        if (previous == null) {
            // Either the guide has too few taxa for any insertion step, or the last step produced
            // no candidate regraft; in both cases there is no tree to root.
            LogInfo.logsForce("No tree was reconstructed; skipping final likelihood.");
            return;
        }
        rootAndLogLikelihood("finalLL", previous, rand);
    }

    /**
     * Loads the dataset selected by {@link #generate}.
     *
     * <p>When generating, the synthetic phylogeny and its observed data are copied into a fresh
     * {@link PhylogeneticHeldoutDataset}.
     */
    private PhylogeneticHeldoutDataset loadPhyloData() {
        if (!this.generate) {
            return PhylogeneticHeldoutDataset.loadData(this.phyloOptions);
        }
        PhylogeneticHeldoutDataset result = new PhylogeneticHeldoutDataset();
        PhylogeneticData generated = DataModel.GENERATED.loadDataset(this.syntheticOptions);
        result.rootedTree = generated.getPhylogeny().getRooted();
        // NOTE(review): the intermediate Object cast comes from the decompiler and presumably is
        // needed because the static types are not directly convertible; keep it unless verified.
        result.obs = (Dataset) ((Object) generated.getObservedTaxonIndexData());
        return result;
    }

    /**
     * Roots {@code tree} using {@code rand} and builds its factor graph over all observations.
     * Logs the data log-likelihood as {@code "<label> = <value>"}.
     *
     * @return the factor graph that was built
     */
    private PhylogeneticFactorGraph rootAndLogLikelihood(String label, UnrootedTree tree, Random rand) {
        RootedTree rooted = LargeMoveUtils.nextRooting(tree, rand);
        PhylogeneticFactorGraph graph = new PhylogeneticFactorGraph(rooted, this.potentials, this.observations);
        LogInfo.logsForce(label + " = " + graph.dataLogLikelihoodGivenTreeAndParameters());
        return graph;
    }

    /**
     * Builds an SPR operator that inserts {@code leafToAdd} into {@code treeBeforeSPR}.
     *
     * <p>The tree is rooted with {@code rand}, and its factor graph is built over the
     * observations restricted to the tree's current leaves. One batch of regrafts is registered
     * for each value in {@link #fractionsFromBottom}.
     *
     * @param leafToAdd taxon to insert; must not already be a leaf of {@code treeBeforeSPR}
     * @param treeBeforeSPR tree to insert into
     * @param rand source of randomness for the rooting
     * @return the operator with all regrafts added
     * @throws IllegalArgumentException if {@code leafToAdd} is already a leaf of the tree
     */
    public SPROperator buildSPROperator(Taxon leafToAdd, UnrootedTree treeBeforeSPR, Random rand) {
        Set<Taxon> leaves = treeBeforeSPR.leavesSet();
        if (leaves.contains(leafToAdd)) {
            throw new IllegalArgumentException("Taxon to add is already a leaf of the tree: " + leafToAdd);
        }
        TestRealData.SimpleObservations restricted = SPROperator.restrictedObservations(this.observations, leaves);
        RootedTree rooted = LargeMoveUtils.nextRooting(treeBeforeSPR, rand);
        PhylogeneticFactorGraph beforeGraftFC = new PhylogeneticFactorGraph(rooted, this.potentials, restricted);
        SPROperator result = new SPROperator(beforeGraftFC, this.leafFactors.get(leafToAdd), leafToAdd, this.stemLens);
        for (double fractionFromBottom : this.fractionsFromBottom) {
            result.addRegrafts(fractionFromBottom);
        }
        return result;
    }

    /**
     * Scores every candidate regraft of {@code operator} against {@code truth} and records
     * statistics.
     *
     * <p>The first candidate visited is treated as the operator's choice. This assumes the
     * counter iterates from most to least probable (TODO: confirm against
     * {@code nuts.util.Counter}). The method writes:
     * <ul>
     *   <li>{@code l1First}: the l1 metric of the first candidate.</li>
     *   <li>{@code massCorrect}: the summed normalized probability of all candidates whose l0
     *       metric is 0 (correct topology).</li>
     *   <li>{@code averageFirstCorrect}: the running fraction of calls whose first candidate had
     *       the correct topology.</li>
     * </ul>
     *
     * @param truth the guide tree restricted to the taxa present after the insertion
     * @param operator operator whose normalized regraft probabilities are evaluated
     * @return the tree produced by the first candidate, or {@code null} if there are none
     */
    public UnrootedTree checkSPR(UnrootedTree truth, SPROperator operator) {
        Counter<SPROperator.Regraft> prs = operator.getNormalizedPrCounter();
        UnrootedTree firstChoice = null;
        boolean isFirst = true;
        double massCorrectTopo = 0.0;
        // Branch length below the LAST visited correct-topology regraft; NaN if none is correct.
        double attachedBL = Double.NaN;
        for (SPROperator.Regraft r : prs) {
            UnrootedTree guess = r.treeAfterRegraft();
            Map<CladeMetrics.TreeMetric, Double> metrics = CladeMetrics.computeTreeMetrics(truth, guess);
            double l0Metric = metrics.get(CladeMetrics.TreeMetric.l0);
            boolean correctTopo = l0Metric == 0.0;
            if (isFirst) {
                firstChoice = guess;
                // NOTE(review): truth is the tree AFTER insertion, so the "nTaxaBefore" column
                // receives nTaxaBefore + 1 here and below — confirm this is intended.
                this.outm.printWrite("l1First", "nTaxaBefore", truth.nTaxa(), "value", metrics.get(CladeMetrics.TreeMetric.l1));
                if (correctTopo) {
                    this.nTimesFirstIsCorrectTopo += 1.0;
                }
                isFirst = false;
            }
            if (correctTopo) {
                massCorrectTopo += prs.getCount(r);
                attachedBL = operator.factorGraphBeforeRegraft.rootedTree.branchLengths().get(r.bot);
            }
        }
        this.nTimesTested += 1.0;
        this.outm.printWrite("massCorrect", "nTaxaBefore", truth.nTaxa(), "attachedBL", attachedBL, "value", massCorrectTopo);
        this.outm.printWrite("averageFirstCorrect", "nTaxaBefore", truth.nTaxa(), "value", this.nTimesFirstIsCorrectTopo / this.nTimesTested);
        return firstChoice;
    }
}

