/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml.tests;

import com.google.common.collect.Lists;
import conifer.Phylogeny;
import conifer.ml.data.PhylogeneticHeldoutDataset;
import conifer.multicategories.PhylogeneticFactorGraph;
import conifer.multicategories.PhylogenyPotentials;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.OptionSet;
import fig.basic.UnorderedPair;
import fig.prob.Gamma;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.io.IO;
import nuts.math.PositiveParamResampler;
import nuts.math.RateMtxUtils;
import nuts.util.Arbre;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.smc.models.CTMC;
import pty.smc.models.DiscreteModelCalculator;

/**
 * Test harness that draws branch lengths on a fixed topology in two ways.
 *
 * <p>{@link #run()} does the following:
 * <ol>
 *   <li>Loads a dataset.</li>
 *   <li>Roots its tree with {@code RootedTree.Util.centroidRooting}.</li>
 *   <li>Draws one tree from a bottom-up sequential Monte Carlo sampler ({@link #mpsmc}).</li>
 *   <li>Starting from that tree, runs a chain of single-edge branch-length moves through
 *       {@link PositiveParamResampler}, printing the unnormalized log density as it goes.</li>
 * </ol>
 */
public class TestBranchSampler implements Runnable {
    /** Options used to load the phylogenetic dataset (registered as option set {@code phylo}). */
    @OptionSet(name="phylo")
    public PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions phyloOptions = new PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions();

    /** Random source shared by the SMC initializer and the MCMC chain. */
    @Option
    public Random rand = new Random(1L);

    /**
     * Branch-length prior. The override returns a {@link Gamma} whose {@code logProb} is
     * identically 0, so it contributes nothing to {@link #density}.
     *
     * <p>NOTE(review): the NaN constructor arguments are presumably never read, which would make
     * this a flat improper prior. TODO confirm against {@code PositiveParamResampler}.
     */
    @OptionSet(name="prior")
    public PositiveParamResampler.GammaParameters gammaParams = new PositiveParamResampler.GammaParameters(){

        @Override
        public Gamma getDistrib() {
            return new Gamma(Double.NaN, Double.NaN){

                @Override
                public double logProb(double x) {
                    return 0.0;
                }
            };
        }
    };

    /** Tuning constant passed to {@code PositiveParamResampler.resample}. */
    @Option
    public double tuning = 1.2;

    /** Number of particles used by the SMC initializer in {@link #run()}. */
    @Option
    public int nParticles = 1000;

    /** Number of single-edge branch-length moves performed in {@link #run()}. */
    @Option
    public int nMCMCIterations = 1000000;

    /** Print the log density every this many MCMC iterations; values below 1 disable printing. */
    @Option
    public int mcmcPrintFrequency = 1;

    private PhylogeneticHeldoutDataset phyloData;
    private PhylogenyPotentials potentials;
    private double[][] rateMatrix = TestBranchSampler.jc();
    private RootedTree centroidRooted;

    /**
     * Returns the 4-state equal-rates matrix.
     *
     * @return the same matrix as {@code jc(4)}
     */
    public static double[][] jc() {
        return jc(4);
    }

    /**
     * Builds a {@code size x size} rate matrix whose off-diagonal entries all equal
     * {@code 1 / (size - 1)}, so the off-diagonal entries of every row sum to 1.
     *
     * <p>The diagonal is then filled by {@link RateMtxUtils#fillRateMatrixDiagonalEntries}.
     * It is presumably set to the negated off-diagonal row sum; verify against that helper.
     *
     * @param size number of states; must be at least 2
     * @return the rate matrix
     * @throws IllegalArgumentException if {@code size < 2}
     */
    public static double[][] jc(int size) {
        if (size < 2) {
            throw new IllegalArgumentException("size must be at least 2, got " + size);
        }
        double offDiagonal = 1.0 / (size - 1);
        double[][] result = new double[size][size];
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                if (i != j) {
                    result[i][j] = offDiagonal;
                }
            }
        }
        RateMtxUtils.fillRateMatrixDiagonalEntries(result);
        return result;
    }

    public static void main(String[] args) {
        IO.run(args, new TestBranchSampler());
    }

    /**
     * Recursively builds a weighted particle population for the subtree rooted at {@code node}.
     *
     * <p>A leaf yields a single particle, with weight 1, holding the observation calculator for
     * its taxon. An internal node builds {@code nParticles} particles. For each one it:
     * <ol>
     *   <li>samples one particle from each child population;</li>
     *   <li>draws both child branch lengths uniformly in [0, 1);</li>
     *   <li>combines the two child calculators;</li>
     *   <li>records the combined log-likelihood minus the sum of the children's
     *       log-likelihoods as the particle's log-weight.</li>
     * </ol>
     * The log-weights are then passed to {@code NumUtils.expNormalize}. That call presumably
     * exponentiates and normalizes them in place, which is what {@link PhyloPartPop#sample}
     * expects.
     *
     * @param node root of the (sub)topology; every internal node must have exactly two children
     * @param rand random source
     * @param nParticles number of particles built at every internal node; must be at least 1
     * @return the particle population for {@code node}
     * @throws IllegalStateException if an internal node does not have exactly two children
     * @throws IllegalArgumentException if an internal node is reached with {@code nParticles < 1}
     */
    public PhyloPartPop mpsmc(Arbre<Taxon> node, Random rand, int nParticles) {
        if (node.isLeaf()) {
            PhyloPartPop leafPop = new PhyloPartPop(1);
            leafPop.normalizedWeights[0] = 1.0;
            Node leaf = new Node();
            leaf.calc = DiscreteModelCalculator.observation(new CTMC.SimpleCTMC(this.rateMatrix, this.phyloData.obs.nSites()), this.phyloData.obs.observations().get(node.getContents()));
            leaf.t = node.getContents();
            leafPop.particles.add(leaf);
            return leafPop;
        }
        int nChildren = node.getChildren().size();
        if (nChildren != 2) {
            throw new IllegalStateException("mpsmc requires a binary topology, but a node has " + nChildren + " children");
        }
        if (nParticles < 1) {
            throw new IllegalArgumentException("nParticles must be at least 1, got " + nParticles);
        }
        PhyloPartPop p0 = this.mpsmc(node.getChildren().get(0), rand, nParticles);
        PhyloPartPop p1 = this.mpsmc(node.getChildren().get(1), rand, nParticles);
        PhyloPartPop result = new PhyloPartPop(nParticles);
        for (int k = 0; k < nParticles; ++k) {
            Node left = p0.sample(rand);
            Node right = p1.sample(rand);
            // Uniform [0, 1) proposal: its density is constant, so it adds no correction
            // to the log-weight.
            double bl0 = rand.nextDouble();
            double bl1 = rand.nextDouble();
            DiscreteModelCalculator combined = (DiscreteModelCalculator)left.calc.combine(left.calc, right.calc, bl0, bl1, false);
            double combinedLogLikelihood = combined.logLikelihood();
            Node merged = new Node();
            merged.b0 = bl0;
            merged.b1 = bl1;
            merged.calc = combined;
            merged.left = left;
            merged.right = right;
            merged.t = node.getContents();
            result.particles.add(merged);
            double childrenLogLikelihood = left.calc.logLikelihood() + right.calc.logLikelihood();
            result.normalizedWeights[k] = combinedLogLikelihood - childrenLogLikelihood;
        }
        NumUtils.expNormalize(result.normalizedWeights);
        return result;
    }

    /**
     * Walks a sampled particle tree and records, for every child node, the length of the branch
     * leading to it.
     *
     * @param n root of the particle tree
     * @param bls output map from child taxon to branch length ({@code b0} for the left child,
     *     {@code b1} for the right child)
     */
    public void reconstruct(Node n, Map<Taxon, Double> bls) {
        if (n.left != null) {
            bls.put(n.left.t, n.b0);
            bls.put(n.right.t, n.b1);
            this.reconstruct(n.left, bls);
            this.reconstruct(n.right, bls);
        }
    }

    /**
     * Loads the data, draws an SMC initial tree, then runs the single-edge MCMC chain.
     *
     * <p>The chain runs for {@link #nMCMCIterations} iterations and prints the log density every
     * {@link #mcmcPrintFrequency} iterations.
     */
    @Override
    public void run() {
        System.out.println("Rate mtx = " + Arrays.deepToString((Object[])this.rateMatrix));
        DiscreteModelCalculator.allowDiscreteModelCalculator = true;
        this.phyloData = PhylogeneticHeldoutDataset.loadData(this.phyloOptions);
        this.potentials = PhylogenyPotentials.createSingleCategoryErrorFreeFromStationaryProcess(this.rateMatrix);
        this.centroidRooted = RootedTree.Util.centroidRooting(this.phyloData.rootedTree.getUnrooted());
        UnrootedTree current = this.sampleInitialTree().getUnrooted();
        System.out.println("Comparing to custom MCMC code");
        for (int mcmcIter = 0; mcmcIter < this.nMCMCIterations; ++mcmcIter) {
            current = this.branchLengthMove(current);
            // Guard against a non-positive frequency, which would otherwise divide by zero.
            if (this.mcmcPrintFrequency > 0 && mcmcIter % this.mcmcPrintFrequency == 0) {
                System.out.println(this.density(current));
            }
        }
    }

    /**
     * Runs {@link #mpsmc} with {@link #nParticles} particles on the rooted topology.
     * It then samples one particle, rebuilds a tree from its branch lengths, and prints that
     * tree's density.
     */
    private Phylogeny sampleInitialTree() {
        int k = this.nParticles;
        System.out.println("starting " + k);
        long time = System.currentTimeMillis();
        Node sampledParticle = this.mpsmc(this.centroidRooted.topology(), this.rand, k).sample(this.rand);
        System.out.println("time = " + (System.currentTimeMillis() - time));
        HashMap<Taxon, Double> bls = new HashMap<Taxon, Double>();
        this.reconstruct(sampledParticle, bls);
        Phylogeny sampled = RootedTree.Util.create(this.centroidRooted.topology(), bls);
        System.out.println("MP-SMC LL REF = " + this.density(sampled.getUnrooted()));
        return sampled;
    }

    /**
     * Performs one {@code PositiveParamResampler} move on the branch length of one edge.
     * The edge is chosen by {@code UnrootedTree.randomEdge}.
     */
    private UnrootedTree branchLengthMove(UnrootedTree tree) {
        final UnorderedPair<Taxon, Taxon> edge = tree.randomEdge(this.rand);
        PositiveParamResampler.PositiveParameterization<UnrootedTree> parameterization = new PositiveParamResampler.PositiveParameterization<UnrootedTree>(){

            @Override
            public UnrootedTree state(UnrootedTree t, double p) {
                // NOTE(review): mutates t in place instead of copying it; confirm that
                // PositiveParamResampler expects this.
                t.branchLengths.put(edge, p);
                return t;
            }

            @Override
            public double param(UnrootedTree state) {
                return state.branchLength(edge);
            }

            @Override
            public double unormLogDensity(UnrootedTree x) {
                return TestBranchSampler.this.density(x);
            }
        };
        PositiveParamResampler.Step<UnrootedTree> step = PositiveParamResampler.resample(tree, parameterization, this.gammaParams, this.rand, this.tuning);
        return (UnrootedTree)step.state;
    }

    /**
     * Computes the unnormalized log density of a tree.
     *
     * <p>This is the data log-likelihood from a {@link PhylogeneticFactorGraph} built on the
     * tree re-rooted at an arbitrary internal node, plus the prior {@code logProb} of every
     * edge's branch length.
     */
    private double density(UnrootedTree current) {
        PhylogeneticFactorGraph factorGraph = new PhylogeneticFactorGraph(current.reRootAtArbitraryInternalNode(), this.potentials, this.phyloData.obs);
        double density = factorGraph.dataLogLikelihoodGivenTreeAndParameters();
        Gamma distrib = this.gammaParams.getDistrib();
        for (UnorderedPair<Taxon, Taxon> e : current.edges()) {
            density += distrib.logProb(current.branchLength(e));
        }
        return density;
    }

    /**
     * Returns the unrooted version of {@code rootedTree} with every branch length redrawn
     * uniformly in [0, 0.1). Not called from this class; kept as an alternative initializer.
     */
    private UnrootedTree randomize(RootedTree rootedTree, Random rand) {
        UnrootedTree result = rootedTree.getUnrooted();
        for (UnorderedPair<Taxon, Taxon> edge : result.edges()) {
            result.branchLengths.put(edge, rand.nextDouble() * 0.1);
        }
        return result;
    }

    /**
     * A weighted population of {@link Node} particles.
     *
     * <p>While {@link TestBranchSampler#mpsmc} fills {@code normalizedWeights} it holds
     * log-weights. They are normalized once {@code NumUtils.expNormalize} has been applied.
     */
    public static class PhyloPartPop {
        public double[] normalizedWeights;
        public List<Node> particles;

        /**
         * Allocates a weight array of size {@code nParticles}; particles are appended by the
         * caller.
         */
        public PhyloPartPop(int nParticles) {
            this.normalizedWeights = new double[nParticles];
            this.particles = new ArrayList<Node>();
        }

        /** Draws one particle with probability given by {@code normalizedWeights}. */
        public Node sample(Random rand) {
            int index = SampleUtils.sampleMultinomial(rand, this.normalizedWeights);
            return this.particles.get(index);
        }

        /** Returns the weights sorted in descending order. */
        @Override
        public String toString() {
            List<Double> ws = Lists.newArrayList();
            for (double w : this.normalizedWeights) {
                ws.add(w);
            }
            Collections.sort(ws);
            Collections.reverse(ws);
            return ws.toString();
        }
    }

    /**
     * One particle: a likelihood calculator plus, for internal nodes, links to the two child
     * particles and the branch lengths leading to them.
     */
    public static class Node {
        public DiscreteModelCalculator calc;
        /** Left child particle, or null for a leaf. */
        public Node left = null;
        /** Right child particle, or null for a leaf. */
        public Node right = null;
        /** Branch length to {@link #left}; NaN for a leaf. */
        public double b0 = Double.NaN;
        /** Branch length to {@link #right}; NaN for a leaf. */
        public double b1 = Double.NaN;
        /** Taxon labelling this node in the topology. */
        public Taxon t;
    }
}

