/*
 * Decompiled with CFR 0.152.
 */
package conifer.ssm;

import conifer.pip.LinearizedAlignment;
import conifer.pip.PIPLikelihoodCalculator;
import conifer.pip.simple.PIPProcess;
import conifer.pip.simple.PIPString;
import conifer.ssm.ImportanceSampler;
import conifer.ssm.SimplePotentialExperiments;
import ev.poi.PoissonParameters;
import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.exec.Execution;
import goblin.Taxon;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Random;
import ma.MSAPoset;
import nuts.io.CSV;
import nuts.io.IO;
import nuts.util.CollUtils;
import nuts.util.Counter;
import org.apache.commons.math.distribution.PoissonDistributionImpl;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.RootedTree;

/**
 * Experiment driver that generates a two-taxon alignment from a {@link PIPProcess} and then
 * compares Monte Carlo estimates of the conditional probability of that alignment against the
 * exact value obtained from {@link PIPLikelihoodCalculator}.
 *
 * <p>For each branch length, the estimate is recomputed for a geometrically growing number of
 * particles. Results are written to {@code result.csv} and {@code variances.csv} in the
 * execution directory.
 */
public class TestPIPPrs implements Runnable {
    /** Random source handed to {@link #stdPMCMC} when {@link #useSimplePMCMC} is set. */
    @Option
    public Random rand = new Random(1L);
    /** Random source used to generate the reference alignment. */
    @Option
    public Random genRand = new Random(1L);
    /** First PIP parameter; the stationary Poisson rate used for the prior is {@code lambda / mu}. */
    @Option
    public double lambda = 2.0;
    /** Second PIP parameter; see {@link #lambda}. */
    @Option
    public double mu = 0.5;
    /**
     * Branch length used to generate the data and as the first evaluated branch length. It is
     * multiplied by {@link #branchExpansionFactor} after each evaluated branch length.
     */
    @Option
    public double bl = 0.3;
    /** When true, estimates come from {@link #stdPMCMC} instead of the importance sampler. */
    @Option
    public boolean useSimplePMCMC = false;
    /** First particle count of the sweep (inclusive). */
    @Option
    public int minNParticles = 2;
    /** Upper bound of the particle sweep (exclusive). */
    @Option
    public int maxNParticles = 10000000;
    /** Multiplier applied to the particle count between consecutive steps of the sweep. */
    @Option
    public int particleGrowthFactor = 2;
    /** Number of branch lengths to evaluate. */
    @Option
    public int nBranchLens = 1;
    /** Multiplier applied to {@link #bl} between consecutive branch lengths. */
    @Option
    public double branchExpansionFactor = 1.2;

    public static void main(String[] args) {
        IO.run(args, new TestPIPPrs());
    }

    /**
     * Generates the reference alignment, then runs one particle sweep per branch length and
     * writes the results to the two CSV files.
     *
     * @throws IllegalArgumentException if the configured particle schedule could never terminate
     */
    @Override
    public void run() {
        this.validateParticleSchedule();
        PoissonParameters pip = PoissonParameters.simplePIP(this.lambda, this.mu);
        PIPProcess process = new PIPProcess(this.lambda, this.mu);
        Taxon ta = new Taxon("A");
        Taxon tb = new Taxon("B");
        LogInfo.logsForce("Data generate with branch = " + this.bl);
        MSAPoset msa = PIPProcess.keepOnlyEndPts(process.sample(this.genRand, this.bl), ta, tb);
        LogInfo.logsForce("alignment = \n" + msa + "\n");
        PrintWriter out = null;
        PrintWriter varout = null;
        try {
            out = IOUtils.openOutHard(Execution.getFile("result.csv"));
            varout = IOUtils.openOutHard(Execution.getFile("variances.csv"));
            out.println(CSV.header("nPart", "error", "estimate", "time", "absLogRatio", "truth", "bl"));
            varout.println(CSV.header("nPart", "var", "bl"));
            for (int i = 0; i < this.nBranchLens; ++i) {
                this.sweepParticleCounts(pip, process, msa, ta, tb, out, varout);
                // Expand the branch length only after the whole sweep, so that every estimate in
                // a sweep is compared against the exact value computed for the same branch length.
                this.bl *= this.branchExpansionFactor;
            }
        } finally {
            // Close the writers even when the sweep fails, so partial results are not lost.
            if (out != null) {
                out.close();
            }
            if (varout != null) {
                varout.close();
            }
        }
    }

    /** Rejects particle schedules for which the multiplicative sweep would never reach its bound. */
    private void validateParticleSchedule() {
        boolean sweepRuns = this.minNParticles < this.maxNParticles;
        if (sweepRuns && (this.minNParticles < 1 || this.particleGrowthFactor < 2)) {
            throw new IllegalArgumentException("Particle sweep cannot terminate: minNParticles=" + this.minNParticles
                    + " (must be >= 1), particleGrowthFactor=" + this.particleGrowthFactor + " (must be >= 2)");
        }
    }

    /**
     * Computes the exact conditional probability for the current {@link #bl} and writes one row per
     * particle count to both CSV writers.
     */
    private void sweepParticleCounts(PoissonParameters pip, PIPProcess process, MSAPoset msa, Taxon ta, Taxon tb,
            PrintWriter out, PrintWriter varout) {
        final double branchLen = this.bl;
        LogInfo.logsForce("branchLen = " + branchLen);
        // Two-leaf tree with the branch length split evenly between A and B.
        RootedTree rt = RootedTree.Util.fromNewickString("(A:" + branchLen / 2.0 + ",B:" + branchLen / 2.0 + ");");
        LinearizedAlignment lia = new LinearizedAlignment(msa);
        PIPLikelihoodCalculator mf = new PIPLikelihoodCalculator(pip, lia, rt);
        double logJoint = mf.computeDataLogProbabilityGivenTree();
        // Prior on the length of A's sequence: Poisson with rate lambda / mu.
        double statRate = this.lambda / this.mu;
        PoissonDistributionImpl pd = new PoissonDistributionImpl(statRate);
        double logPrior = Math.log(pd.probability(msa.sequences().get(ta).length()));
        double logConditional = logJoint - logPrior;
        double truth = Math.exp(logConditional);
        LogInfo.logsForce("exact = " + truth);
        // Variance of a 0/1 indicator whose success probability is the exact value.
        double theoreticalSimplePMCMCVariance = truth * (1.0 - truth);
        LogInfo.logsForce("theoreticalSimplePMCMCVariance = " + theoreticalSimplePMCMCVariance);
        Pair<PIPString, PIPString> endPoints = TestPIPPrs.getEndPoints(msa, ta, tb);
        SimplePotentialExperiments.PIPPotential pot = new SimplePotentialExperiments.PIPPotential();
        SimplePotentialExperiments.PotProposal<PIPString> prop = new SimplePotentialExperiments.PotProposal<PIPString>(process, pot, new SimplePotentialExperiments.PotPropOptions());
        ImportanceSampler<PIPString> is = new ImportanceSampler<PIPString>(prop, process);
        // Shared across the sweep: weight statistics accumulate over all particle counts.
        SummaryStatistics weightVariance = new SummaryStatistics();
        // A long counter keeps the multiplication from overflowing; the value is below
        // maxNParticles whenever the body runs, so the cast back to int is lossless.
        for (long n = this.minNParticles; n < this.maxNParticles; n *= this.particleGrowthFactor) {
            int nPart = (int) n;
            long start = System.currentTimeMillis();
            double estimate;
            if (this.useSimplePMCMC) {
                estimate = TestPIPPrs.stdPMCMC(msa, branchLen, process, nPart, this.rand, weightVariance);
            } else {
                is.nParticles = nPart;
                Counter samples = is.sample(endPoints.getFirst(), endPoints.getSecond(), branchLen, weightVariance);
                estimate = is.estimateZ(samples);
            }
            double reportedVariance = this.useSimplePMCMC ? theoreticalSimplePMCMCVariance : weightVariance.getVariance();
            LogInfo.logsForce("approx(" + nPart + ") = " + estimate);
            LogInfo.logsForce("weightVarianceEstimate = " + reportedVariance);
            varout.println(CSV.body(nPart, reportedVariance, branchLen));
            double error = Math.abs(truth - estimate);
            double absLogRatio = Math.abs(Math.log(truth / estimate));
            long delta = System.currentTimeMillis() - start;
            out.println(CSV.body(nPart, error, estimate, delta, absLogRatio, truth, branchLen));
            out.flush();
            varout.flush();
        }
    }

    /**
     * Estimates the probability of reproducing {@code ref} by forward simulation: starting from
     * taxon A's sequence in {@code ref}, simulates {@code nPart} alignments over branch length
     * {@code bl2} and counts how many have the same edge set and the same sequences as {@code ref}.
     *
     * @param ref         reference alignment over taxa "A" and "B"
     * @param bl2         branch length passed to the simulator
     * @param process     process used to build the initial alignment and to simulate
     * @param nPart       number of simulations
     * @param rand        random source; 1000 draws are discarded from it before simulating
     * @param weightStats if non-null, receives 1.0 for every exact match and 0.0 otherwise
     * @return {@code (matches + 1) / (nPart + 2)}, so the estimate is never exactly 0 or 1
     */
    public static double stdPMCMC(MSAPoset ref, double bl2, PIPProcess process, int nPart, Random rand, SummaryStatistics weightStats) {
        // Discard a fixed number of draws before simulating (kept so the random stream is unchanged).
        for (int i = 0; i < 1000; ++i) {
            rand.nextDouble();
        }
        Taxon ta = new Taxon("A");
        Taxon tb = new Taxon("B");
        double num = 0.0;
        for (int i = 0; i < nPart; ++i) {
            MSAPoset temp = process.createInitMSA(ref.sequences().get(ta));
            MSAPoset proposed = PIPProcess.keepOnlyEndPts(process.sample(rand, temp, bl2), ta, tb);
            boolean matches = CollUtils.set(ref.edges()).equals(CollUtils.set(proposed.edges()))
                    && ref.sequences().equals(proposed.sequences());
            if (matches) {
                num += 1.0;
            }
            if (weightStats != null) {
                weightStats.addValue(matches ? 1.0 : 0.0);
            }
        }
        // Denominator is computed in floating point so that nPart + 2 cannot overflow.
        return (num + 1.0) / ((double) nPart + 2.0);
    }

    /**
     * Encodes a two-sequence alignment as a pair of {@link PIPString}s, one per taxon.
     *
     * <p>A position that takes part in a column containing two points is encoded as 0. Any other
     * position is encoded as -1 in the string for {@code ta} and as 1 in the string for
     * {@code tb} (presumably marking deleted and inserted characters respectively; confirm
     * against {@link PIPString}).
     *
     * @param msa alignment holding exactly two sequences
     * @param ta  taxon whose sequence produces the first string
     * @param tb  taxon whose sequence produces the second string
     * @return the pair (string for {@code ta}, string for {@code tb})
     * @throws RuntimeException if {@code msa} does not hold exactly two sequences
     */
    public static Pair<PIPString, PIPString> getEndPoints(MSAPoset msa, Taxon ta, Taxon tb) {
        if (msa.sequences().size() != 2) {
            throw new RuntimeException("Expected an alignment with exactly 2 sequences but found " + msa.sequences().size());
        }
        HashSet<Integer> aligned1 = new HashSet<Integer>();
        HashSet<Integer> aligned2 = new HashSet<Integer>();
        for (MSAPoset.Column c : msa.columns()) {
            // Only columns containing two points count as aligned.
            if (c.getPoints().size() != 2) {
                continue;
            }
            if (c.getPoints().containsKey(ta)) {
                aligned1.add(c.getPoints().get(ta));
            }
            if (c.getPoints().containsKey(tb)) {
                aligned2.add(c.getPoints().get(tb));
            }
        }
        PIPString pips1 = TestPIPPrs.encodeEndPoint(msa.sequences().get(ta).length(), aligned1, -1);
        PIPString pips2 = TestPIPPrs.encodeEndPoint(msa.sequences().get(tb).length(), aligned2, 1);
        return Pair.makePair(pips1, pips2);
    }

    /** Builds a PIPString of the given length holding 0 at aligned positions and {@code unalignedCode} elsewhere. */
    private static PIPString encodeEndPoint(int length, HashSet<Integer> aligned, int unalignedCode) {
        ArrayList<Integer> codes = new ArrayList<Integer>();
        for (int i = 0; i < length; ++i) {
            codes.add(aligned.contains(i) ? 0 : unalignedCode);
        }
        return new PIPString(codes);
    }
}

