/*
 * Decompiled with CFR 0.152.
 */
package conifer.largemove;

import conifer.data.DataModel;
import conifer.data.DataOptions;
import conifer.data.PhylogeneticData;
import conifer.fastmetrics.CladeMetrics;
import conifer.fastpf.FastParticle;
import conifer.fastpf.FastPriorPrior;
import conifer.fastpf.TaxaOrderHeuristic;
import conifer.largemove.LargeMoveKernel;
import conifer.largemove.LargeMoveOperator;
import conifer.largemove.LargeMoveOperatorSelection;
import conifer.largemove.LargeMoveUtils;
import conifer.largemove.MicroGibbsResult;
import conifer.ml.data.PhylogeneticHeldoutDataset;
import conifer.multicategories.PhylogeneticFactorGraph;
import conifer.spr.SPROperator;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.UnorderedPair;
import gep.util.OutputManager;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import ma.RateMatrixLoader;
import monaco.process.ProcessSchedule;
import monaco.process.ProcessScheduleContext;
import nuts.io.IO;
import nuts.math.MutableGraph;
import nuts.math.Sampling;
import nuts.util.Arbre;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.io.Dataset;
import pty.smc.LazyParticleFilter;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.PriorPriorKernel;
import pty.smc.models.CTMC;

public class TestClimbLikelihood
implements Runnable {
    /** Dataset loading options; handed to {@code IO.run} under the name "phylo" and used by {@link #run()} and {@link #test1()}. */
    private static final PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions phyloOptions = new PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions();
    /** Options for the {@link LazyParticleFilter} instances built in {@code newPFInit} and {@code pfInit} ("pf"). */
    private static final LazyParticleFilter.ParticleFilterOptions pfOptions = new LazyParticleFilter.ParticleFilterOptions();
    /** Options for the {@link FastPriorPrior} kernel built in {@code newPFInit} ("pfk"). */
    private static final FastPriorPrior.FastPriorPriorOptions pfKernelOptions = new FastPriorPrior.FastPriorPriorOptions();
    /** Options passed to {@code DataModel.GENERATED.loadDataset} when {@link #generate} is set ("synth"). */
    private static final DataOptions syntheticOptions = new DataOptions();
    /** Prior options shared by the {@link FastPriorPrior} and {@link LargeMoveKernel} instances ("prior"). */
    private static final FastPriorPrior.SimplePriorOptions priorOptions = new FastPriorPrior.SimplePriorOptions();
    /** Options for {@link LargeMoveKernel}; {@code useRegraftMode}, {@code modifierBound} and {@code nBLExpansions} are read in this class ("pfk2"). */
    private static final LargeMoveKernel.LargeMoveKernelOptions largeKernelOptions = new LargeMoveKernel.LargeMoveKernelOptions();
    /**
     * Particle filter used by {@code regraftSMC} and {@code newLargeMovePF}; also handed to
     * {@code IO.run} under the name "pf2". Instantiated with explicit type arguments rather
     * than as a raw type so the assignment is checked by the compiler.
     */
    public static final ParticleFilter<LargeMoveKernel.LargeMoveParticle> pf2 = new ParticleFilter<LargeMoveKernel.LargeMoveParticle>();
    /** Random source used throughout the experiment; seeded with 1 by default. */
    @Option
    public Random rand = new Random(1L);
    /** When set, {@code useOld} restricts each iteration to one randomly chosen edge. */
    @Option
    public boolean useStandardMove = false;
    /** When set, {@code useOld} takes the arg-max micro-Gibbs index instead of sampling one. */
    @Option
    public boolean optimize = false;
    /** Path of a taxa-order file; when null or empty, {@code taxaOrder} falls back to the heuristic order. */
    @Option
    public String taxaOrderFile = "";
    /** Selects {@code newLargeMovePF} (true) or {@code useOld} (false) in the non-regraft branch of {@link #run()}. */
    @Option
    public boolean useNewLargeMovePF = true;
    /** When set, {@link #run()} builds the dataset from {@code DataModel.GENERATED} instead of loading it. */
    @Option
    public boolean generate = false;
    /** Selects {@code newPFInit} (true) or {@code pfInit} (false) for the initial tree. */
    @Option
    public boolean useNewSMC = true;
    /** Sink for all {@code printWrite} records; created in {@link #run()}. */
    private OutputManager output;
    /** Rate matrix; set to {@code RateMatrixLoader.k2p()} in {@link #run()}. */
    private double[][] rateMtx;
    /** Dataset under study; populated in {@link #run()}. */
    private PhylogeneticHeldoutDataset phyloData = null;
    /** Reference tree ({@code phyloData.rootedTree}) against which tree metrics are computed. */
    private RootedTree truth;
    /** Iteration count for {@code mcmc} and {@code useOld}; also passed to the {@code LargeMoveParticle} constructor. */
    @Option
    public int nPhaseTwoIterations = 10000;
    /** When set, {@link #run()} invokes {@code mcmc}. */
    @Option
    public boolean useMCMC = true;
    /** Observation error passed to {@link LargeMoveKernel}; {@link #run()} sets it to 0. */
    private double observationError;
    /** Counter incremented each time the {@link #always} schedule reports the evidence. */
    private int iter = 0;
    /**
     * Process schedule that fires on every call and writes one "evidence" record holding
     * the running iteration counter and the current normalizer estimate of {@code pf2}.
     */
    public final ProcessSchedule always = new ProcessSchedule(){

        /** Always requests processing. */
        @Override
        public boolean shouldProcess(ProcessScheduleContext context) {
            return true;
        }

        /** Emits the evidence record and advances the outer iteration counter. */
        @Override
        public void monitor(ProcessScheduleContext context) {
            // Read-then-increment, before the normalizer is queried (same order as a post-increment argument).
            int reportedIteration = TestClimbLikelihood.this.iter++;
            TestClimbLikelihood.this.output.printWrite("evidence", "iteration", reportedIteration, "evidence", pf2.estimateNormalizer());
        }
    };

    /**
     * Entry point: hands a fresh experiment plus every named option holder to {@code IO.run}.
     *
     * @param args command-line arguments forwarded unchanged to {@code IO.run}
     */
    public static void main(String[] args) {
        TestClimbLikelihood experiment = new TestClimbLikelihood();
        IO.run(
                args,
                experiment,
                "phylo", phyloOptions,
                "pf", pfOptions,
                "pfk", pfKernelOptions,
                "synth", syntheticOptions,
                "prior", priorOptions,
                "pfk2", largeKernelOptions,
                "pf2", pf2);
    }

    /**
     * Runs the experiment: obtains the dataset (generated or loaded), logs the likelihood of
     * the reference tree and of a random tree, then either runs the regraft SMC or builds an
     * initial tree by SMC and refines it with the configured phase-two procedure.
     */
    @Override
    public void run() {
        TestClimbLikelihood.phyloOptions.holdOutFre = 0.0;

        // Step 1: obtain the dataset.
        if (!this.generate) {
            this.phyloData = PhylogeneticHeldoutDataset.loadData(phyloOptions);
        } else {
            this.phyloData = new PhylogeneticHeldoutDataset();
            PhylogeneticData generated = DataModel.GENERATED.loadDataset(syntheticOptions);
            this.phyloData.rootedTree = generated.getPhylogeny().getRooted();
            this.phyloData.obs = (Dataset)((Object)generated.getObservedTaxonIndexData());
        }

        // Step 2: shared state for the rest of the run.
        this.truth = this.phyloData.rootedTree;
        this.rateMtx = RateMatrixLoader.k2p();
        this.observationError = 0.0;
        this.output = new OutputManager();
        // Only a zero observation error is accepted here.
        if (this.observationError != 0.0) {
            throw new RuntimeException();
        }

        // Step 3: baseline likelihoods.
        if (this.phyloData.rootedTree != null) {
            PhylogeneticFactorGraph referenceGraph = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(this.phyloData.rootedTree, this.rateMtx, this.phyloData.obs);
            LogInfo.logsForce("true tree log ll: " + referenceGraph.getSumProductPosteriorCalculator().logZ());
        }
        // NOTE(review): rootedTree is dereferenced unconditionally here although the
        // block above guards against null -- confirm whether a tree-less dataset is meant to work.
        RootedTree randomTree = RootedTree.Util.random(this.rand, this.phyloData.rootedTree.getUnrooted().leaves());
        PhylogeneticFactorGraph randomGraph = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(randomTree, this.rateMtx, this.phyloData.obs);
        LogInfo.logsForce("random tree log ll: " + randomGraph.getSumProductPosteriorCalculator().logZ());
        LogInfo.logsForce("random tree metrics: " + CladeMetrics.computeTreeMetrics(randomTree, this.truth));

        // Step 4a: regraft mode.
        if (TestClimbLikelihood.largeKernelOptions.useRegraftMode) {
            List<Taxon> order = this.taxaOrder(this.phyloData.obs.observations());
            UnrootedTree best = this.regraftSMC(order, this.rand);
            // NOTE(review): the SMC is run a second time and its result discarded -- confirm this is intentional.
            this.regraftSMC(order, this.rand);
            if (this.useMCMC) {
                this.mcmc(LargeMoveUtils.nextRooting(best, this.rand));
            }
            return;
        }

        // Step 4b: SMC initialisation followed by phase two.
        RootedTree initial;
        if (this.useNewSMC) {
            initial = this.newPFInit(this.phyloData);
        } else {
            initial = this.pfInit(this.phyloData);
        }
        Map<CladeMetrics.TreeMetric, Double> initialMetrics = CladeMetrics.computeTreeMetrics(this.truth, initial);
        for (Map.Entry<CladeMetrics.TreeMetric, Double> entry : initialMetrics.entrySet()) {
            // Iteration -1 marks the metrics of the initial tree.
            this.output.printWrite("metrics", new Object[]{"iteration", -1, "metric", entry.getKey(), "value", entry.getValue()});
        }
        if (this.useMCMC) {
            this.mcmc(LargeMoveUtils.nextRooting(initial.getUnrooted(), this.rand));
        }
        if (!this.useNewLargeMovePF) {
            this.useOld(initial.getUnrooted());
        } else {
            this.newLargeMovePF(LargeMoveUtils.nextRooting(initial.getUnrooted(), this.rand));
        }
    }

    /**
     * Runs {@code pf2} with a {@link LargeMoveKernel} seeded from a star tree over the first
     * three taxa of {@code order}, logs the estimated normalizer and returns the tree of the
     * highest-log-likelihood particle seen by the processor.
     *
     * @param order taxa order; its first three elements seed the star tree
     * @param rand random source forwarded to {@link #starInit}
     * @return tree of the arg-max particle
     */
    private UnrootedTree regraftSMC(List<Taxon> order, Random rand) {
        List<Taxon> seedTaxa = order.subList(0, 3);
        int nTaxa = this.phyloData.obs.observations().keySet().size();
        LargeMoveKernel.LargeMoveParticle seedParticle = TestClimbLikelihood.starInit(seedTaxa, rand, nTaxa);

        RootedTree star = SPROperator.starInitialization(seedTaxa, 0.2);
        PhylogeneticFactorGraph starGraph = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(star, this.rateMtx, this.phyloData.obs);

        LargeMoveKernel kernel = new LargeMoveKernel(
                largeKernelOptions,
                priorOptions,
                seedParticle,
                starGraph.potentials.categoryPriors,
                starGraph.potentials.rateMatrices,
                starGraph.observations,
                starGraph.potentials.stationaryDistributions,
                this.observationError,
                order);

        Phase2ParticleProcessor tracker = new Phase2ParticleProcessor();
        pf2.sample(kernel, tracker);
        LogInfo.logsForce("logZ = " + pf2.estimateNormalizer());
        return tracker.argmax.tree;
    }

    /**
     * Builds the initial particle for the regraft SMC: a star tree (branch length 0.2) over
     * exactly three taxa, with {@code nTaxa - 3} passed as the particle's second constructor
     * argument.
     *
     * @param subList the three seed taxa
     * @param rand unused; kept so existing callers keep compiling
     * @param nTaxa total number of taxa in the dataset
     * @return the initial particle
     * @throws IllegalArgumentException if {@code subList} does not contain exactly three taxa
     */
    public static LargeMoveKernel.LargeMoveParticle starInit(List<Taxon> subList, Random rand, int nTaxa) {
        if (subList.size() != 3) {
            // IllegalArgumentException is a RuntimeException, so existing catch blocks still apply.
            throw new IllegalArgumentException("starInit requires exactly 3 taxa but got " + subList.size());
        }
        RootedTree tree = SPROperator.starInitialization(subList, 0.2);
        return new LargeMoveKernel.LargeMoveParticle(tree.getUnrooted(), nTaxa - 3, Double.NaN, Double.NaN, 0.0, 0.0);
    }

    /**
     * Runs {@code nPhaseTwoIterations} accept/reject iterations starting from {@code original}.
     * Each iteration applies the kernel twice (third argument {@code false}, then {@code true}),
     * sums the two returned values and accepts the second particle when a uniform draw is
     * below the exponential of that sum. Metrics and log-likelihood are written only for
     * accepted iterations.
     *
     * @param original rooted starting tree
     */
    private void mcmc(RootedTree original) {
        PhylogeneticFactorGraph graph = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(original, this.rateMtx, this.phyloData.obs);
        LargeMoveKernel.LargeMoveParticle current = new LargeMoveKernel.LargeMoveParticle(original.getUnrooted(), this.nPhaseTwoIterations, 0.0, 0.0, priorOptions.logDensity(original), graph.getSumProductPosteriorCalculator().logZ());
        LargeMoveKernel kernel = new LargeMoveKernel(
                largeKernelOptions,
                priorOptions,
                current,
                graph.potentials.categoryPriors,
                graph.potentials.rateMatrices,
                graph.observations,
                graph.potentials.stationaryDistributions,
                this.observationError,
                null);

        for (int i = 0; i < this.nPhaseTwoIterations; ++i) {
            Pair<LargeMoveKernel.LargeMoveParticle, Double> firstStep = kernel.next(this.rand, current, false, false);
            double logSum = firstStep.getSecond();
            System.out.println("===>" + logSum);

            Pair<LargeMoveKernel.LargeMoveParticle, Double> secondStep = kernel.next(this.rand, firstStep.getFirst(), true, false);
            logSum += secondStep.getSecond().doubleValue();
            System.out.println("===>" + logSum);
            System.out.println("----");

            double acceptance = Math.exp(logSum);
            // A NaN acceptance value rejects, because the comparison is then false.
            if (this.rand.nextDouble() < acceptance) {
                current = secondStep.getFirst();
                Map<CladeMetrics.TreeMetric, Double> treeMetrics = CladeMetrics.computeTreeMetrics(this.truth, current.tree);
                for (Map.Entry<CladeMetrics.TreeMetric, Double> entry : treeMetrics.entrySet()) {
                    this.output.printWrite("metrics", new Object[]{"iteration", i, "metric", entry.getKey(), "value", entry.getValue()});
                }
                this.output.printWrite("ll", "iteration", i, "ll", current.logLikelihood);
            }
        }
    }

    /**
     * Runs {@code pf2} with a {@link LargeMoveKernel} started at {@code original}, reporting
     * evidence through the {@link #always} schedule and per-particle statistics through a
     * {@code Phase2ParticleProcessor}.
     *
     * @param original rooted starting tree
     */
    private void newLargeMovePF(RootedTree original) {
        PhylogeneticFactorGraph graph = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(original, this.rateMtx, this.phyloData.obs);
        LargeMoveKernel.LargeMoveParticle seed = new LargeMoveKernel.LargeMoveParticle(original.getUnrooted(), this.nPhaseTwoIterations, 0.0, 0.0, priorOptions.logDensity(original), graph.getSumProductPosteriorCalculator().logZ());
        LargeMoveKernel kernel = new LargeMoveKernel(
                largeKernelOptions,
                priorOptions,
                seed,
                graph.potentials.categoryPriors,
                graph.potentials.rateMatrices,
                graph.observations,
                graph.potentials.stationaryDistributions,
                this.observationError,
                null);

        pf2.setProcessSchedule(this.always);
        Phase2ParticleProcessor tracker = new Phase2ParticleProcessor();
        pf2.sample(kernel, tracker);
    }

    /**
     * Older phase-two procedure. For {@code nPhaseTwoIterations} iterations it re-roots the
     * current tree, selects a list of edges (all edges on even iterations, random non-terminal
     * edges in preorder on odd ones; a single random edge when {@code useStandardMove} is set),
     * runs a micro-Gibbs step on each edge and applies the chosen outcome to a copy of the
     * topology and branch lengths, which becomes the tree for the next iteration.
     *
     * <p>The log-likelihood of the re-rooted tree is computed once per iteration and reused
     * for the "ll" record, the diagnostic printout and the next iteration's reference value.
     *
     * @param ut unrooted starting tree
     */
    private void useOld(UnrootedTree ut) {
        boolean standard = this.useStandardMove;
        LogInfo.track("Sampling");
        double previousLogLL = Double.NaN;
        double previousPropLog = Double.NaN;
        for (int i = 0; i < this.nPhaseTwoIterations; ++i) {
            // Even iterations only perturb branch lengths; odd ones may also change topology.
            boolean onlyBranches = i % 2 == 0;
            LogInfo.logsForce("Iteration " + i);
            RootedTree rt = LargeMoveUtils.nextRooting(ut, this.rand);
            List<Pair<Arbre<Taxon>, Arbre<Taxon>>> edges = onlyBranches ? LargeMoveOperatorSelection.allEdges(rt.topology()) : LargeMoveOperatorSelection.randomNonTerminalEdgesInPreorder(this.rand, rt.topology());
            if (standard) {
                edges = Collections.singletonList(edges.get(this.rand.nextInt(edges.size())));
            }
            PhylogeneticFactorGraph pfg = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(rt, this.rateMtx, this.phyloData.obs);
            // Computed once; the factor graph is not modified before the later uses below.
            double currentLogLL = pfg.getSumProductPosteriorCalculator().logZ();
            this.output.printWrite("ll", "iteration", i, "standard", standard, "ll", currentLogLL);
            if (!Double.isNaN(previousLogLL)) {
                System.out.println();
                System.out.println("newLogLL = " + currentLogLL);
                System.out.println("oldLogLL = " + previousLogLL);
                System.out.println("propLogP = " + previousPropLog);
                System.out.println("naiveLogWeight = " + (currentLogLL - previousLogLL - previousPropLog));
                System.out.println();
            }
            previousLogLL = currentLogLL;
            Map<CladeMetrics.TreeMetric, Double> metrics = CladeMetrics.computeTreeMetrics(this.truth, rt);
            for (CladeMetrics.TreeMetric t : metrics.keySet()) {
                this.output.printWrite("metrics", new Object[]{"iteration", i, "standard", standard, "metric", t, "value", metrics.get((Object)t)});
            }
            // Work on copies so the per-edge outcomes accumulate into a new tree.
            MutableGraph<Taxon> newTopology = new MutableGraph<Taxon>(rt.getUnrooted().getTopology());
            HashMap<UnorderedPair<Taxon, Taxon>, Double> newBranchLengths = new HashMap<UnorderedPair<Taxon, Taxon>, Double>(ut.branchLengths);
            LogInfo.track("Processing micro kernels");
            int edgeCounter = 1;
            double nTopologyMoves = 0.0;
            double nBLMoves = 0.0;
            // Accumulates the log-probability of the outcomes chosen in this iteration.
            previousPropLog = 0.0;
            for (Pair<Arbre<Taxon>, Arbre<Taxon>> orderedEdge : edges) {
                LogInfo.logs("" + edgeCounter++ + "/" + edges.size());
                double modifier = Sampling.nextDouble(this.rand, 1.0, TestClimbLikelihood.largeKernelOptions.modifierBound);
                double[] bls = LargeMoveOperatorSelection.branchLengthPerturbations(TestClimbLikelihood.largeKernelOptions.nBLExpansions, modifier, rt.branchLengths().get(orderedEdge.getSecond().getContents()));
                MicroGibbsResult currentMicroGibbs = LargeMoveOperator.efficientMicroGibbs(pfg, orderedEdge, bls, onlyBranches, false);
                int sampled = this.optimize ? currentMicroGibbs.argmaxIndex() : currentMicroGibbs.sampleIndex(this.rand);
                double[] prs = currentMicroGibbs.samplingProbabilities();
                previousPropLog += Math.log(prs[sampled]);
                if (currentMicroGibbs.branchLength(sampled) != rt.branchLengths().get(orderedEdge.getSecond().getContents()).doubleValue()) {
                    nBLMoves += 1.0;
                }
                if (!onlyBranches && currentMicroGibbs.configuration(sampled) != 0) {
                    nTopologyMoves += 1.0;
                }
                currentMicroGibbs.alter(newTopology, newBranchLengths, sampled);
            }
            LogInfo.end_track();
            this.output.printWrite("branchStats", "iteration", i, "fraction", nBLMoves / (double)edges.size());
            if (!onlyBranches) {
                this.output.printWrite("topoStats", "iteration", i, "fraction", nTopologyMoves / (double)edges.size());
            }
            UnrootedTree newUt = new UnrootedTree(newTopology, newBranchLengths);
            Map<CladeMetrics.TreeMetric, Double> metrics2 = CladeMetrics.computeTreeMetrics(ut, newUt);
            for (CladeMetrics.TreeMetric t : metrics2.keySet()) {
                this.output.printWrite("delta-metrics", new Object[]{"iteration", i, "standard", standard, "metric", t, "value", metrics2.get((Object)t)});
            }
            ut = newUt;
        }
        LogInfo.end_track();
    }

    /**
     * Builds the initial tree with the fast SMC: a {@link LazyParticleFilter} over a
     * {@link FastPriorPrior} kernel, returning the tree of the stored arg-max particle.
     *
     * @param phyloData dataset whose observations drive the filter
     * @return tree of the arg-max {@link FastParticle}
     */
    private RootedTree newPFInit(PhylogeneticHeldoutDataset phyloData) {
        LogInfo.track("First Fast-SMC phase");

        CTMC.SimpleCTMC dnaModel = CTMC.SimpleCTMC.dnaCTMC(phyloData.obs.nSites());
        List<Taxon> order = this.taxaOrder(phyloData.obs.observations());
        FastParticle seed = FastParticle.initFastParticle(phyloData.obs.observations(), dnaModel, order);
        FastPriorPrior kernel = new FastPriorPrior(seed, pfKernelOptions, priorOptions);

        ParticleFilter.StoreProcessor store = new ParticleFilter.StoreProcessor();
        LazyParticleFilter<FastParticle> smc = new LazyParticleFilter<FastParticle>(kernel, pfOptions);
        smc.sample(store);

        LogInfo.end_track();
        FastParticle best = (FastParticle)store.argmax();
        return best.getTree();
    }

    /**
     * Chooses the taxa order: the heuristic order derived from {@code observations} when no
     * order file is configured, otherwise the order read from {@code taxaOrderFile}.
     *
     * @param observations per-taxon observations handed to the heuristic
     * @return the taxa order
     */
    private List<Taxon> taxaOrder(Map<Taxon, double[][]> observations) {
        String path = this.taxaOrderFile;
        if (path == null || path.isEmpty()) {
            return TaxaOrderHeuristic.heuristicOrder(observations);
        }
        return TaxaOrderHeuristic.fromFile(new File(path));
    }

    /**
     * Older initialisation path: a {@link LazyParticleFilter} over a {@link PriorPriorKernel},
     * returning the full coalescent state of the stored arg-max particle. Prints a warning
     * on standard output when used.
     *
     * @param phyloData dataset whose observations drive the filter
     * @return full coalescent state of the arg-max particle
     */
    private RootedTree pfInit(PhylogeneticHeldoutDataset phyloData) {
        System.out.println("WARNING: OLD INIT");
        LogInfo.track("First SMC phase");

        CTMC.SimpleCTMC dnaModel = CTMC.SimpleCTMC.dnaCTMC(phyloData.obs.nSites());
        PartialCoalescentState seed = PartialCoalescentState.initFastState(phyloData.obs, dnaModel);
        PriorPriorKernel kernel = new PriorPriorKernel(seed);
        LazyParticleFilter<PartialCoalescentState> smc = new LazyParticleFilter<PartialCoalescentState>(kernel, pfOptions);

        ParticleFilter.StoreProcessor store = new ParticleFilter.StoreProcessor();
        smc.sample(store);

        LogInfo.end_track();
        PartialCoalescentState best = (PartialCoalescentState)store.argmax();
        return best.getFullCoalescentState();
    }

    /**
     * Manual smoke test. Overwrites the static dataset options ({@code maxNSites},
     * {@code alignmentFile}, {@code treeFile}) with hard-coded absolute paths, loads that
     * dataset, logs the likelihood of the loaded tree and of its topological neighbours
     * around one random non-terminal edge, then runs one micro-Gibbs step on that edge.
     */
    public void test1() {
        TestClimbLikelihood.phyloOptions.maxNSites = 100;
        TestClimbLikelihood.phyloOptions.alignmentFile = "/Users/bouchard/Documents/data/small-test-data/synthetic-phylo-10/alignment.fasta";
        TestClimbLikelihood.phyloOptions.treeFile = "/Users/bouchard/Documents/data/small-test-data/synthetic-phylo-10/tree.newick";
        PhylogeneticHeldoutDataset phyloData = PhylogeneticHeldoutDataset.loadData(phyloOptions);
        UnrootedTree ut = phyloData.rootedTree.getUnrooted();
        Random rand = new Random(1L);
        RootedTree rt = LargeMoveUtils.nextRooting(ut, rand);
        System.out.println(LargeMoveOperatorSelection.randomNonTerminalEdgesInPreorder(rand, rt.topology()));
        LogInfo.logsForce(rt);
        Pair<Arbre<Taxon>, Arbre<Taxon>> orderedEdge = LargeMoveOperatorSelection.randomNonTerminalEdge(rand, rt.topology());
        LogInfo.logsForce(orderedEdge);
        UnorderedPair<Arbre<Taxon>, Arbre<Taxon>> edge = new UnorderedPair<Arbre<Taxon>, Arbre<Taxon>>(orderedEdge.getFirst(), orderedEdge.getSecond());
        ArrayList<UnrootedTree> nbrs = new ArrayList<UnrootedTree>();
        UnorderedPair<Taxon, Taxon> edge2 = new UnorderedPair<Taxon, Taxon>(edge.getFirst().getContents(), edge.getSecond().getContents());
        nbrs.addAll(ut.topologicalNeighbors(edge2));
        nbrs.add(ut);
        double[][] rateMtx = RateMatrixLoader.k2p();
        LogInfo.track("Method 1");
        for (UnrootedTree nei : nbrs) {
            RootedTree rtnei = RootedTree.Util.centroidRooting(nei);
            PhylogeneticFactorGraph pfg = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(rtnei, rateMtx, phyloData.obs);
            LogInfo.logsForce(pfg.getSumProductPosteriorCalculator().logZ());
        }
        LogInfo.end_track();
        PhylogeneticFactorGraph pfg = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(rt, rateMtx, phyloData.obs);
        // Branch lengths are keyed by the node's Taxon (as in useOld), not by the Arbre node.
        // Looking up the Arbre itself yields null and an NPE when unboxed into the array.
        double[] bls = new double[]{rt.branchLengths().get(orderedEdge.getSecond().getContents())};
        // The result is not inspected; the call only checks that the step runs.
        LargeMoveOperator.efficientMicroGibbs(pfg, orderedEdge, bls, false, false);
    }

    /**
     * Particle processor that remembers the particle with the highest log-likelihood seen so
     * far and writes a "metrics" record per tree metric plus one "ll" record for every
     * particle it is given.
     */
    private class Phase2ParticleProcessor
    implements ParticleFilter.ParticleProcessor<LargeMoveKernel.LargeMoveParticle> {
        /** Particle with the highest log-likelihood so far; null until one is processed. */
        LargeMoveKernel.LargeMoveParticle argmax = null;
        /** Log-likelihood of {@link #argmax}. */
        double max = Double.NEGATIVE_INFINITY;

        private Phase2ParticleProcessor() {
        }

        /**
         * Updates the running arg-max and reports the particle.
         *
         * @param state particle to record
         * @param weight particle weight; not used here
         */
        @Override
        public void process(LargeMoveKernel.LargeMoveParticle state, double weight) {
            double logLikelihood = state.logLikelihood;
            // Iteration index = configured iteration count minus the particle's remaining count.
            int iteration = TestClimbLikelihood.this.nPhaseTwoIterations - state.nIterationLeft;

            if (logLikelihood > this.max) {
                this.argmax = state;
                this.max = logLikelihood;
            }

            Map<CladeMetrics.TreeMetric, Double> treeMetrics = CladeMetrics.computeTreeMetrics(TestClimbLikelihood.this.truth, state.tree);
            for (Map.Entry<CladeMetrics.TreeMetric, Double> entry : treeMetrics.entrySet()) {
                TestClimbLikelihood.this.output.printWrite("metrics", new Object[]{"iteration", iteration, "metric", entry.getKey(), "value", entry.getValue()});
            }
            TestClimbLikelihood.this.output.printWrite("ll", "iteration", iteration, "ll", logLikelihood);
        }
    }
}

