/*
 * Decompiled with CFR 0.152.
 */
package conifer.fastpf;

import conifer.fastpf.FastBranchEstimate;
import conifer.fastpf.FastParticle;
import conifer.fastpf.TaxaOrderHeuristic;
import conifer.ml.data.PhylogeneticHeldoutDataset;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.prob.Beta;
import goblin.Taxon;
import java.util.List;
import java.util.Random;
import nuts.math.ProposalRandom;
import nuts.math.Sampling;
import nuts.util.Counter;
import pty.RootedTree;
import pty.smc.LazyParticleFilter;
import pty.smc.ParticleFilter;
import pty.smc.models.CTMC;
import pty.smc.models.FastDiscreteModelCalculator;

/**
 * {@link LazyParticleFilter.LazyParticleKernel} over {@link FastParticle} states.
 *
 * <p>Each step picks two distinct roots of the current particle, samples branch lengths for
 * joining them, and combines their likelihood calculators. A particle with {@code n} roots
 * therefore has {@code n - 1} steps left (see {@link #nIterationsLeft}).
 */
public class FastPriorPrior
implements LazyParticleFilter.LazyParticleKernel<FastParticle> {
    private final FastParticle initial;
    private final FastPriorPriorOptions options;
    private final SimplePriorOptions priorOptions;

    /**
     * @param initial      the particle returned by {@link #getInitial()}
     * @param options      proposal tuning parameters
     * @param priorOptions branch-length prior used when computing the log ratio
     */
    public FastPriorPrior(FastParticle initial, FastPriorPriorOptions options, SimplePriorOptions priorOptions) {
        this.priorOptions = priorOptions;
        this.initial = initial;
        this.options = options;
    }

    /**
     * Shared implementation of {@link #next} and {@link #peekNext}.
     *
     * <p>The step does the following:
     * <ol>
     *   <li>Picks root {@code p0} uniformly and a different root {@code p1} with
     *       {@link #pickOtherPoint}.</li>
     *   <li>Samples a total branch length from {@link FastBranchEstimate}, centred on the
     *       fraction of characters that differ between the two roots.</li>
     *   <li>Splits that length with a symmetric Beta draw into {@code bl0} and {@code bl1}.</li>
     * </ol>
     *
     * <p>The returned log ratio is computed as
     * {@code logPrior(bl0, bl1) + newLogLik - (oldLogLik0 + oldLogLik1) - proposalLogProb
     * - log(nNonTrivialRoots)}. Here {@code proposalLogProb} is whatever
     * {@code pr.logProbability} has accumulated.
     *
     * @param rand    source of randomness
     * @param current particle to extend
     * @param isPeek  when true, only the log ratio is computed and no new particle is built
     * @return a {@code Double} (the log ratio) when {@code isPeek} is true; otherwise a
     *     {@code Pair<FastParticle, Double>} holding the new particle and the log ratio
     */
    private Object compute(Random rand, FastParticle current, boolean isPeek) {
        // NOTE(review): ProposalRandom presumably also records the log-probability of draws made
        // through it (sampleDiscreteUniform); the terms below are added to it explicitly.
        ProposalRandom pr = new ProposalRandom(rand);
        int p0 = pr.sampleDiscreteUniform(current.nRoots());
        int p1 = this.pickOtherPoint(p0, current.nRoots(), pr);
        FastParticle.CacheNode c0 = current.cacheNodes[p0];
        FastParticle.CacheNode c1 = current.cacheNodes[p1];
        double expectedFractionOfDifferentChars = FastBranchEstimate.fractionThatDiffers(c0.calculator.cache, c1.calculator.cache);
        double sampledBL = FastBranchEstimate.fastBranchEstimateSample(expectedFractionOfDifferentChars, this.options.branchProposalVariance, rand);
        pr.logProbability += FastBranchEstimate.fastBranchEstimateLogDensity(expectedFractionOfDifferentChars, this.options.branchProposalVariance, sampledBL);
        // Split the total length between the two new branches with a symmetric Beta draw.
        double fraction = Beta.sample(rand, this.options.branchSplitParameter, this.options.branchSplitParameter);
        pr.logProbability += Beta.logProb(this.options.branchSplitParameter, this.options.branchSplitParameter, fraction);
        double bl0 = fraction * sampledBL;
        double bl1 = (1.0 - fraction) * sampledBL;
        double logPrior = this.priorOptions.logDensity(bl0) + this.priorOptions.logDensity(bl1);
        double newTreeLogLikelihood = Double.NaN;
        FastParticle newPart = null;
        // Non-leaf roots consumed by this step are removed; the combined root is added.
        int nNonTrivialRoots = current.nNonTrivialRoots - (c0.isLeaf() ? 0 : 1) - (c1.isLeaf() ? 0 : 1) + 1;
        if (isPeek) {
            newTreeLogLikelihood = c0.calculator.peekCoalescedLogLikelihood(c0.calculator, c1.calculator, bl0, bl1);
        } else {
            int nIterLeft = this.nIterationsLeft(current);
            // The last flag is true only on the final step (one iteration left).
            FastDiscreteModelCalculator newDMC = (FastDiscreteModelCalculator)c0.calculator.combine(c0.calculator, c1.calculator, bl0, bl1, nIterLeft == 1);
            newTreeLogLikelihood = newDMC.logLikelihood();
            newPart = current.next(p0, p1, bl0, bl1, newDMC, nNonTrivialRoots);
        }
        double oldTreesLogLikelihoods = c0.calculator.logLikelihood() + c1.calculator.logLikelihood();
        double logRatio = logPrior + newTreeLogLikelihood - oldTreesLogLikelihoods - pr.logProbability - Math.log(nNonTrivialRoots);
        return isPeek ? Double.valueOf(logRatio) : Pair.makePair(newPart, logRatio);
    }

    /**
     * Picks a root index different from {@code point} and adds the log-probability of that
     * choice to {@code rand.logProbability}.
     *
     * <p>With at least 25 roots, a "close" proposal is used with probability
     * {@code options.probabilityToPickClosebyNode}. It picks an index within
     * {@code (int) sqrt(nRoots)} of {@code point}, wrapping around the ends of
     * {@code [0, nRoots)}. Otherwise an index is drawn uniformly among the {@code nRoots - 1}
     * other indices. The recorded probability is that of the two-component mixture.
     *
     * @param point  index of the already chosen root, assumed in {@code [0, nRoots)}
     * @param nRoots number of roots; assumed to be at least 2
     * @param rand   random source whose {@code logProbability} is incremented
     * @return the chosen index, in {@code [0, nRoots)} and different from {@code point}
     */
    private int pickOtherPoint(int point, int nRoots, ProposalRandom rand) {
        int radius = (int)Math.sqrt(nRoots);
        double closePr = nRoots >= 25 ? this.options.probabilityToPickClosebyNode : 0.0;
        boolean pickCloseOne = rand.rand.nextDouble() < closePr;
        int pickedPosition = pickCloseOne
                ? FastPriorPrior.pickClose(point, rand.rand, nRoots, radius)
                : FastPriorPrior.pickUnif(point, rand.rand, nRoots);
        double pr = (1.0 - closePr) / (double)(nRoots - 1);
        // pickClose wraps around the ends of [0, nRoots), so closeness must be measured on the
        // ring. A plain |picked - point| test misses the wrapped neighbours and under-states
        // their proposal probability.
        if (FastPriorPrior.circularDistance(pickedPosition, point, nRoots) <= radius) {
            pr += closePr / (double)(2 * radius);
        }
        rand.logProbability += Math.log(pr);
        return pickedPosition;
    }

    /**
     * Returns the shorter distance between {@code a} and {@code b} on a ring of {@code size}
     * positions. Both arguments are assumed to lie in {@code [0, size)}.
     */
    private static int circularDistance(int a, int b, int size) {
        int direct = Math.abs(a - b);
        return Math.min(direct, size - direct);
    }

    /** Draws an index uniformly from {@code [0, nRoots)} excluding {@code point}. */
    private static int pickUnif(int point, Random rand, int nRoots) {
        int picked = rand.nextInt(nRoots - 1);
        // Shift indices at or above point up by one so that point itself is skipped.
        if (picked >= point) {
            ++picked;
        }
        return picked;
    }

    /**
     * Draws an index at offset {@code ±k} from {@code point}, where {@code k} is uniform in
     * {@code [1, radius]} and the sign is uniform. The result wraps modulo {@code size}.
     */
    private static int pickClose(int point, Random rand, int size, int radius) {
        int offset = rand.nextInt(radius) + 1;
        int raw = point + offset * (rand.nextBoolean() ? 1 : -1);
        return (raw + size) % size;
    }

    @Override
    @SuppressWarnings("unchecked") // compute returns a Pair<FastParticle, Double> when isPeek is false
    public Pair<FastParticle, Double> next(Random rand, FastParticle current) {
        return (Pair<FastParticle, Double>)this.compute(rand, current, false);
    }

    /** Computes only the log ratio of a proposed step, without building the new particle. */
    @Override
    public double peekNext(Random rand, FastParticle current) {
        return (Double)this.compute(rand, current, true);
    }

    /** Returns {@code nRoots() - 1}, the number of combination steps remaining. */
    @Override
    public int nIterationsLeft(FastParticle partialState) {
        return partialState.nRoots() - 1;
    }

    @Override
    public FastParticle getInitial() {
        return this.initial;
    }

    /** Ad-hoc driver that runs the particle filter on a small dataset. */
    public static void main(String[] args) {
        PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions phyloOptions = new PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions();
        phyloOptions.holdOutFre = 0.0;
        phyloOptions.minFractionObserved = 0.5;
        // NOTE(review): machine-specific absolute paths; adjust before running elsewhere.
        phyloOptions.alignmentFile = "/Users/bouchard/Documents/data/small-test-data/synthetic-phylo-10/alignment.fasta";
        phyloOptions.treeFile = "/Users/bouchard/Documents/data/small-test-data/synthetic-phylo-10/tree.newick";
        phyloOptions.maxNSites = 200;
        PhylogeneticHeldoutDataset phyloData = PhylogeneticHeldoutDataset.loadData(phyloOptions);
        CTMC.SimpleCTMC ctmc = CTMC.SimpleCTMC.dnaCTMC(phyloData.obs.nSites());
        List<Taxon> taxaOrder = TaxaOrderHeuristic.heuristicOrder(phyloData.obs.observations());
        FastParticle fp = FastParticle.initFastParticle(phyloData.obs.observations(), ctmc, taxaOrder);
        FastPriorPriorOptions options = new FastPriorPriorOptions();
        SimplePriorOptions pOp = new SimplePriorOptions();
        FastPriorPrior fpp = new FastPriorPrior(fp, options, pOp);
        ParticleFilter.StoreProcessor storeP = new ParticleFilter.StoreProcessor();
        LazyParticleFilter.ParticleFilterOptions pfOptions = new LazyParticleFilter.ParticleFilterOptions();
        pfOptions.verbose = true;
        pfOptions.nThreads = 1;
        LazyParticleFilter<FastParticle> filter = new LazyParticleFilter<FastParticle>(fpp, pfOptions);
        LogInfo.track("Sampling");
        filter.sample(storeP);
        LogInfo.end_track();
    }

    /**
     * Ad-hoc check of the random helpers.
     *
     * <p>It first prints 1000 Beta(5, 5) draws followed by their empirical mean. It then prints,
     * for each index of a 300-element range, the empirical frequency with which
     * {@link #pickOtherPoint} picks that index from position 2, alongside the probability that
     * {@code pickOtherPoint} itself records for it.
     */
    public static void testPairSampling() {
        Random rand = new Random(1L);
        double sum = 0.0;
        int nBetaDraws = 1000;
        for (int i = 0; i < nBetaDraws; ++i) {
            double draw = Beta.sample(rand, 5.0, 5.0);
            System.out.println(draw);
            sum += draw;
        }
        System.out.println(sum / nBetaDraws);
        SimplePriorOptions pOp = new SimplePriorOptions();
        FastPriorPrior fpp = new FastPriorPrior(null, new FastPriorPriorOptions(), pOp);
        Counter<Integer> theory = new Counter<Integer>();
        Counter<Integer> sampled = new Counter<Integer>();
        int point = 2;
        int size = 300;
        for (int i = 0; i < 1000000; ++i) {
            ProposalRandom pRand = new ProposalRandom(rand);
            int picked = fpp.pickOtherPoint(point, size, pRand);
            sampled.incrementCount(picked, 1.0);
            theory.setCount(picked, Math.exp(pRand.logProbability));
        }
        sampled.normalize();
        for (int i = 0; i < size; ++i) {
            System.out.println("" + i + "\t" + sampled.getCount(i) + "\t" + theory.getCount(i));
        }
    }

    /** Branch-length prior used by {@link FastPriorPrior}. */
    public static final class SimplePriorOptions {
        /**
         * Parameter passed to {@code Sampling.exponentialLogDensity}.
         * NOTE(review): whether this is a mean or a rate depends on that method; confirm there.
         */
        @Option
        public double meanBranchLengthPriorParam = 3.0;

        /** Log prior density of a single branch length. */
        private double logDensity(double length) {
            return Sampling.exponentialLogDensity(this.meanBranchLengthPriorParam, length);
        }

        /** Sums the single-branch log density over every branch length of {@code t}. */
        public double logDensity(RootedTree t) {
            double result = 0.0;
            for (double d : t.branchLengths().values()) {
                result += this.logDensity(d);
            }
            return result;
        }
    }

    /** Tuning parameters for the proposal in {@link FastPriorPrior}. */
    public static final class FastPriorPriorOptions {
        /** Probability of using the "close" second-root proposal; only used with at least 25 roots. */
        @Option
        public double probabilityToPickClosebyNode = 0.5;
        /** Variance argument passed to the {@code FastBranchEstimate} sampler and density. */
        @Option
        public double branchProposalVariance = 2.0;
        /** Both shape parameters of the symmetric Beta used to split the sampled branch length. */
        @Option
        public double branchSplitParameter = 5.0;
    }
}

