/*
 * Decompiled with CFR 0.152.
 */
package conifer.largemove;

import conifer.fastpf.FastPriorPrior;
import conifer.largemove.LargeMoveOperator;
import conifer.largemove.LargeMoveOperatorSelection;
import conifer.largemove.LargeMoveUtils;
import conifer.largemove.MicroGibbsResult;
import conifer.ml.tests.TestRealData;
import conifer.multicategories.PhylogeneticFactorGraph;
import conifer.multicategories.PhylogenyPotentials;
import conifer.spr.SPROperator;
import fenchel.factor.UnaryFactor;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.UnorderedPair;
import goblin.Taxon;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.math.MutableGraph;
import nuts.math.Sampling;
import nuts.util.Arbre;
import pty.Observations;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.smc.ParticleKernel;

/**
 * SMC particle kernel that proposes successive phylogenetic trees.
 *
 * <p>In regraft mode ({@link LargeMoveKernelOptions#useRegraftMode}) each step grafts the next
 * taxon of {@code graftOrder} onto the current tree with an {@link SPROperator}, then runs
 * {@link LargeMoveKernelOptions#regraftNBranchResamplingRounds} rounds of micro-Gibbs
 * resampling. Otherwise each step performs one micro-Gibbs sweep, alternating between
 * branch-length-only and topology moves according to the parity of
 * {@link LargeMoveParticle#nIterationLeft}.
 *
 * <p>Every proposal is returned with a log weight of the form
 * {@code (logPrior' + logLik') - (logPrior + logLik) - log q(forward)}, plus
 * {@code log q(backward)} for micro-Gibbs moves.
 */
public class LargeMoveKernel
implements ParticleKernel<LargeMoveParticle> {
    private final LargeMoveKernelOptions kernelOptions;
    private final FastPriorPrior.SimplePriorOptions priorOptions;
    private final LargeMoveParticle initial;
    private final double[] categoryPriors;
    private final double[][][] rateMatrices;
    private final Observations observations;
    private final double[][] stationaryDistributions;
    private final double observationErrorPr;
    /** Per-taxon leaf factors used by SPR grafting; {@code null} unless regraft mode is enabled. */
    private final Map<Taxon, UnaryFactor> leafFactors;
    private final List<Taxon> graftOrder;

    /**
     * Creates the kernel.
     *
     * @param kernelOptions move tuning options
     * @param priorOptions prior whose {@code logDensity} scores every proposed tree
     * @param initial particle returned by {@link #getInitial()}
     * @param categoryPriors passed to every {@link PhylogenyPotentials} built by this kernel
     * @param rateMatrices passed to every {@link PhylogenyPotentials} built by this kernel
     * @param observations data for all taxa; restricted to the current leaves when needed
     * @param stationaryDistributions passed to every {@link PhylogenyPotentials} built by this kernel
     * @param observationErrorPr passed to every {@link PhylogenyPotentials} built by this kernel
     * @param graftOrder order in which taxa are grafted in regraft mode
     */
    public LargeMoveKernel(LargeMoveKernelOptions kernelOptions, FastPriorPrior.SimplePriorOptions priorOptions, LargeMoveParticle initial, double[] categoryPriors, double[][][] rateMatrices, Observations observations, double[][] stationaryDistributions, double observationErrorPr, List<Taxon> graftOrder) {
        this.kernelOptions = kernelOptions;
        this.priorOptions = priorOptions;
        this.initial = initial;
        this.categoryPriors = categoryPriors;
        this.rateMatrices = rateMatrices;
        this.observations = observations;
        this.stationaryDistributions = stationaryDistributions;
        this.observationErrorPr = observationErrorPr;
        if (kernelOptions.useRegraftMode) {
            // Leaf factors are read off a factor graph over a seeded random tree of all observed taxa.
            // NOTE(review): presumably they depend only on each taxon's data, not the topology — confirm.
            RootedTree randomTree = RootedTree.Util.random(new Random(1L), observations.observations().keySet());
            PhylogeneticFactorGraph starFG = new PhylogeneticFactorGraph(randomTree, new PhylogenyPotentials(categoryPriors, rateMatrices, stationaryDistributions, observationErrorPr), observations);
            this.leafFactors = SPROperator.leafFactors(starFG);
        } else {
            this.leafFactors = null;
        }
        this.graftOrder = graftOrder;
    }

    /**
     * Proposes the successor of {@code current}.
     *
     * <p>In regraft mode, a tree with exactly three leaves first gets branch-only resampling rounds.
     * Then one taxon is grafted, followed by resampling rounds. Only the first of those rounds may
     * move the topology, and it does so with probability 1/2. The returned weight is the sum of
     * the log weights of all sub-moves.
     *
     * @param rand source of randomness
     * @param current particle to extend
     * @return the successor particle and its log weight
     */
    @Override
    public Pair<LargeMoveParticle, Double> next(Random rand, LargeMoveParticle current) {
        if (!this.kernelOptions.useRegraftMode) {
            // Alternate between branch-length-only sweeps and topology sweeps.
            boolean onlyBranches = current.nIterationLeft % 2 == 0;
            return this.next(rand, current, onlyBranches, false);
        }
        int nRounds = this.kernelOptions.regraftNBranchResamplingRounds;
        double ratioLogProduct = 0.0;
        LargeMoveParticle init = current;
        if (current.tree.leaves().size() == 3) {
            for (int i = 0; i < nRounds; ++i) {
                Pair<LargeMoveParticle, Double> resampled = this.next(rand, init, true, false);
                init = resampled.getFirst();
                ratioLogProduct += resampled.getSecond();
            }
        }
        Pair<LargeMoveParticle, Double> proposal = this.nextSPR(rand, init);
        ratioLogProduct += proposal.getSecond();
        for (int i = 0; i < nRounds; ++i) {
            // Short-circuit: the coin is only flipped in the first round.
            boolean onlyBranches = i > 0 || rand.nextBoolean();
            proposal = this.next(rand, proposal.getFirst(), onlyBranches, false);
            ratioLogProduct += proposal.getSecond();
        }
        LargeMoveParticle last = proposal.getFirst();
        // Change fractions are not tracked across the composite move, hence NaN.
        LargeMoveParticle successor = new LargeMoveParticle(last.tree, current.nIterationLeft - 1, Double.NaN, Double.NaN, last.logPrior, last.logLikelihood);
        return Pair.makePair(successor, ratioLogProduct);
    }

    /**
     * Grafts {@code graftOrder.get(current.tree.nTaxa())} onto the current tree using the
     * configured stem-length options.
     *
     * <p>NOTE(review): assumes the tree holds exactly the first {@code nTaxa()} entries of
     * {@code graftOrder} — confirm with callers.
     *
     * @param rand source of randomness
     * @param current particle to extend
     * @return the grafted particle and its log weight
     */
    public Pair<LargeMoveParticle, Double> nextSPR(Random rand, LargeMoveParticle current) {
        Taxon toGraft = this.graftOrder.get(current.tree.nTaxa());
        return this.nextSPR(rand, current, toGraft, this.kernelOptions.regraftNStemLens, this.kernelOptions.regraftBasicRangerUpperBound);
    }

    /**
     * Grafts {@code leafToAdd} onto a random rooting of the current tree.
     *
     * <p>The candidate stem lengths are {@code r, 2r, 4r, ...} ({@code nStemLens} values), where
     * {@code r} is drawn uniformly from {@code [0, basicRangeUpperBound)}.
     *
     * @param rand source of randomness
     * @param current particle to extend
     * @param leafToAdd taxon to graft; must not already be a leaf of the tree
     * @param nStemLens number of candidate stem lengths
     * @param basicRangeUpperBound upper bound for the smallest candidate stem length
     * @return the grafted particle and the log weight {@code Δ logPrior + Δ logLik - log q(graft)}
     * @throws IllegalArgumentException if {@code leafToAdd} is already in the tree
     * @throws IllegalStateException if the kernel was built without regraft mode
     */
    public Pair<LargeMoveParticle, Double> nextSPR(Random rand, LargeMoveParticle current, Taxon leafToAdd, int nStemLens, double basicRangeUpperBound) {
        Set<Taxon> leaves = current.tree.leavesSet();
        if (leaves.contains(leafToAdd)) {
            throw new IllegalArgumentException("Taxon " + leafToAdd + " is already in the tree");
        }
        if (this.leafFactors == null) {
            throw new IllegalStateException("SPR grafting requires useRegraftMode (leaf factors were not computed)");
        }
        TestRealData.SimpleObservations restricted = SPROperator.restrictedObservations(this.observations, leaves);
        RootedTree rt = LargeMoveUtils.nextRooting(current.tree, rand);
        PhylogeneticFactorGraph beforeGraftFC = new PhylogeneticFactorGraph(rt, new PhylogenyPotentials(this.categoryPriors, this.rateMatrices, this.stationaryDistributions, this.observationErrorPr), restricted);
        double basicRange = rand.nextDouble() * basicRangeUpperBound;
        double[] stemLens = new double[nStemLens];
        for (int i = 0; i < nStemLens; ++i) {
            stemLens[i] = basicRange;
            basicRange *= 2.0;
        }
        SPROperator spr = new SPROperator(beforeGraftFC, this.leafFactors.get(leafToAdd), leafToAdd, stemLens);
        spr.addRegrafts(rand.nextDouble());
        int graftIdx = spr.sample(rand);
        double proposalPr = spr.getNormalizedProbabilities()[graftIdx];
        UnrootedTree newTree = spr.getRegraft(graftIdx).treeAfterRegraft();
        RootedTree newTreeRandomRooted = LargeMoveUtils.nextRooting(newTree, rand);
        double newLogPrior = this.priorOptions.logDensity(newTreeRandomRooted);
        // NOTE(review): used directly as a log-likelihood, so getUnnormalizedPrs() is presumably log-scale — confirm.
        double newLogLikelihood = spr.getUnnormalizedPrs().get(graftIdx);
        LargeMoveParticle newParticle = new LargeMoveParticle(newTree, current.nIterationLeft - 1, Double.NaN, Double.NaN, newLogPrior, newLogLikelihood);
        double logRatio = newParticle.logPrior - current.logPrior + newParticle.logLikelihood - current.logLikelihood - Math.log(proposalPr);
        return Pair.makePair(newParticle, logRatio);
    }

    /**
     * Performs one micro-Gibbs sweep over edges of a random rooting of the current tree.
     *
     * <p>With {@code onlyBranches} every edge gets a branch-length move. Otherwise a random
     * selection of non-terminal edges gets topology and branch-length moves. Each edge's move is
     * computed against the factor graph of the current tree. The reverse-move probabilities are
     * then evaluated on the new tree.
     *
     * @param rand source of randomness
     * @param current particle to perturb
     * @param onlyBranches whether to restrict moves to branch lengths
     * @param forceBreak forwarded to {@link LargeMoveOperator#efficientMicroGibbs}
     * @return the new particle and the log weight
     *     {@code Δ logPrior + Δ logLik - log q(forward) + log q(backward)}
     */
    public Pair<LargeMoveParticle, Double> next(Random rand, LargeMoveParticle current, boolean onlyBranches, boolean forceBreak) {
        UnrootedTree currentUnrooted = current.tree;
        RootedTree currentRooted = LargeMoveUtils.nextRooting(currentUnrooted, rand);
        List<Pair<Arbre<Taxon>, Arbre<Taxon>>> edges = onlyBranches
                ? LargeMoveOperatorSelection.allEdges(currentRooted.topology())
                : LargeMoveOperatorSelection.randomNonTerminalEdgesInPreorder(rand, currentRooted.topology());
        MutableGraph<Taxon> newTopology = new MutableGraph<Taxon>(currentRooted.getUnrooted().getTopology());
        HashMap<UnorderedPair<Taxon, Taxon>, Double> newBranchLengths = new HashMap<UnorderedPair<Taxon, Taxon>, Double>(currentUnrooted.branchLengths);
        Observations obs = this.observations;
        if (obs.observations().keySet().size() != current.tree.nTaxa()) {
            // Partial tree (regraft mode): the likelihood must involve only the taxa present.
            obs = SPROperator.restrictedObservations(obs, current.tree.leavesSet());
        }
        PhylogeneticFactorGraph factorGraph = new PhylogeneticFactorGraph(currentRooted, new PhylogenyPotentials(this.categoryPriors, this.rateMatrices, this.stationaryDistributions, this.observationErrorPr), obs);
        double nTopologyMoves = 0.0;
        double nBLMoves = 0.0;
        double proposalLogPr = 0.0;
        // Per-edge bookkeeping needed to reconstruct the reverse move.
        Map<Pair<Arbre<Taxon>, Arbre<Taxon>>, Double> sampledModifiers = new HashMap<Pair<Arbre<Taxon>, Arbre<Taxon>>, Double>();
        Map<Pair<Arbre<Taxon>, Arbre<Taxon>>, Integer> sampledBL = new HashMap<Pair<Arbre<Taxon>, Arbre<Taxon>>, Integer>();
        Map<Pair<Arbre<Taxon>, Arbre<Taxon>>, Taxon> originalC0I1 = new HashMap<Pair<Arbre<Taxon>, Arbre<Taxon>>, Taxon>();
        for (Pair<Arbre<Taxon>, Arbre<Taxon>> orderedEdge : edges) {
            Double currentBL = currentRooted.branchLengths().get(orderedEdge.getSecond().getContents());
            double modifier = Sampling.nextDouble(rand, 1.0, this.kernelOptions.modifierBound);
            sampledModifiers.put(orderedEdge, modifier);
            double[] bls = LargeMoveOperatorSelection.branchLengthPerturbations(this.kernelOptions.nBLExpansions, modifier, currentBL);
            MicroGibbsResult forward = LargeMoveOperator.efficientMicroGibbs(factorGraph, orderedEdge, bls, onlyBranches, forceBreak);
            originalC0I1.put(orderedEdge, forward.getTaxon(0, 0, 1));
            int sampled = forward.sampleIndex(rand);
            double[] prs = forward.samplingProbabilities();
            proposalLogPr += Math.log(prs[sampled]);
            sampledBL.put(orderedEdge, forward.branchLengthIndex(sampled));
            if (forward.branchLength(sampled) != currentBL.doubleValue()) {
                nBLMoves += 1.0;
            }
            if (!onlyBranches && forward.configuration(sampled) != 0) {
                nTopologyMoves += 1.0;
            }
            forward.alter(newTopology, newBranchLengths, sampled);
        }
        // NaN (0/0) when no edge was selected.
        double branchChangeFraction = nBLMoves / (double)edges.size();
        double topoChangeFraction = nTopologyMoves / (double)edges.size();
        UnrootedTree newUnrooted = new UnrootedTree(newTopology, newBranchLengths);
        RootedTree newWithSameRooting = newUnrooted.reRootAtNode(currentRooted.topology().getContents());
        double newLogPrior = this.priorOptions.logDensity(newWithSameRooting);
        PhylogeneticFactorGraph newFactorGraph = factorGraph.createWithNewTopology(newWithSameRooting);
        LargeMoveParticle newParticle = new LargeMoveParticle(newUnrooted, current.nIterationLeft - 1, branchChangeFraction, topoChangeFraction, newLogPrior, newFactorGraph.getSumProductPosteriorCalculator().logZ());
        double backLogPr = 0.0;
        for (Pair<Arbre<Taxon>, Arbre<Taxon>> orderedEdge : edges) {
            double modifier = sampledModifiers.get(orderedEdge);
            Double bl = newBranchLengths.get(new UnorderedPair<Taxon, Taxon>(orderedEdge.getFirst().getContents(), orderedEdge.getSecond().getContents()));
            if (bl == null) {
                // The edge key is absent from the new tree's branch-length map: fall back to the old length.
                bl = currentRooted.branchLengths().get(orderedEdge.getSecond().getContents());
            }
            double[] bls = LargeMoveOperatorSelection.branchLengthPerturbations(this.kernelOptions.nBLExpansions, modifier, bl);
            MicroGibbsResult backward = LargeMoveOperator.efficientMicroGibbs(newFactorGraph, orderedEdge, bls, onlyBranches, forceBreak);
            int backConfig = backward.findConfig(originalC0I1.get(orderedEdge));
            int backBL = this.invertBLIndex(sampledBL.get(orderedEdge), this.kernelOptions.nBLExpansions);
            int reversedIdx = backward.prsIndex(backConfig, backBL);
            // Accumulate in log space, matching proposalLogPr above.
            backLogPr += Math.log(backward.samplingProbabilities()[reversedIdx]);
        }
        double logRatio = newParticle.logPrior - current.logPrior + newParticle.logLikelihood() - current.logLikelihood() - proposalLogPr + backLogPr;
        return Pair.makePair(newParticle, logRatio);
    }

    /**
     * Mirrors a branch-length grid index around the centre index {@code nExp}.
     *
     * <p>NOTE(review): assumes {@code branchLengthPerturbations(nExp, ...)} yields a grid centred
     * at index {@code nExp}, so the mirrored index undoes the multiplicative step — confirm.
     *
     * @param position sampled index in the forward grid
     * @param nExp number of expansions on each side of the centre
     * @return {@code 2 * nExp - position}
     */
    private int invertBLIndex(int position, int nExp) {
        return 2 * nExp - position;
    }

    @Override
    public int nIterationsLeft(LargeMoveParticle partialState) {
        return partialState.nIterationLeft;
    }

    @Override
    public LargeMoveParticle getInitial() {
        return this.initial;
    }

    /** Command-line tunable options for {@link LargeMoveKernel}. */
    public static class LargeMoveKernelOptions {
        /** Upper bound of the multiplicative modifier drawn via {@code Sampling.nextDouble(rand, 1.0, modifierBound)}. */
        @Option
        public double modifierBound = 1.3;
        /** Number of expansions passed to {@code LargeMoveOperatorSelection.branchLengthPerturbations}. */
        @Option
        public int nBLExpansions = 1;
        /** The smallest SPR candidate stem length is drawn uniformly from {@code [0, this)}. */
        @Option
        public double regraftBasicRangerUpperBound = 0.2;
        /** Number of SPR candidate stem lengths; each is twice the previous one. */
        @Option
        public int regraftNStemLens = 5;
        /** Whether to grow trees by grafting taxa in {@code graftOrder}. */
        @Option
        public boolean useRegraftMode = true;
        /** Micro-Gibbs resampling rounds run around each graft. */
        @Option
        public int regraftNBranchResamplingRounds = 2;
        /** Not read within this class. NOTE(review): presumably used by callers to order taxa — confirm. */
        @Option
        public Random regraftRandomOrderRand = new Random(1L);
    }

    /** Immutable SMC particle: an unrooted tree plus its bookkeeping and scores. */
    public static class LargeMoveParticle {
        public final UnrootedTree tree;
        public final int nIterationLeft;
        /** Fraction of visited edges whose branch length changed; NaN for grafting moves. */
        public final double branchChangeFraction;
        /** Fraction of visited edges whose topology changed; NaN for grafting moves. */
        public final double topoChangeFraction;
        public final double logPrior;
        public final double logLikelihood;

        public LargeMoveParticle(UnrootedTree tree, int nIterationLeft, double branchChangeFraction, double topoChangeFraction, double logPrior, double logLikelihood) {
            this.tree = tree;
            this.nIterationLeft = nIterationLeft;
            this.branchChangeFraction = branchChangeFraction;
            this.topoChangeFraction = topoChangeFraction;
            this.logPrior = logPrior;
            this.logLikelihood = logLikelihood;
        }

        /** @return the particle's log-likelihood */
        public double logLikelihood() {
            return this.logLikelihood;
        }
    }
}

