/*
 * Decompiled with CFR 0.152.
 */
package conifer.ssm;

import conifer.pip.LinearizedAlignment;
import conifer.pip.PIPLikelihoodCalculator;
import conifer.pip.simple.PIPProcess;
import conifer.pip.simple.PIPString;
import conifer.ssm.ImportanceSampler;
import conifer.ssm.SimplePotentialExperiments;
import conifer.ssm.TestPIPPrs;
import ev.poi.PoissonParameters;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.exec.Execution;
import gep.util.OutputManager;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import ma.MSAPoset;
import nuts.io.IO;
import nuts.math.Plot2D;
import nuts.math.Sampling;
import nuts.math.StatisticsMap;
import nuts.util.Counter;
import org.apache.commons.math.distribution.PoissonDistributionImpl;
import pty.RootedTree;

public class TestPIPParam
implements Runnable {
    @Option
    public Random rand = new Random(1L);
    @Option
    public Random genRand = new Random(1L);
    @Option
    public double lambda = 2.0;
    @Option
    public double mu = 0.5;
    @Option
    public double bl = 0.3;
    @Option
    public double initMu = 1.0;
    @Option
    public double initLambda = 1.0;
    @Option
    public int nObservations = 10;
    @Option
    public int nMCMCIters = 10000000;
    @Option
    public int printStatFreq = 100;
    @Option
    public int nParticles = 100;
    @Option
    public double proposalWidth = 2.0;
    @Option
    public double priorMean = 1.0;
    @Option
    public boolean useAnalytic = false;
    @Option
    public boolean computeNumericalIntegral = false;
    @Option
    public boolean useSimplePMCMC = false;
    private ArrayList<MSAPoset> msas;
    public static double[] qs = new double[]{25.0, 50.0, 75.0};

    public static void main(String[] args) {
        IO.run(args, new TestPIPParam());
    }

    @Override
    public void run() {
        PIPProcess process = new PIPProcess(this.lambda, this.mu);
        Taxon ta = new Taxon("A");
        Taxon tb = new Taxon("B");
        OutputManager out = new OutputManager();
        this.msas = new ArrayList();
        ArrayList<Pair<PIPString, PIPString>> observations = new ArrayList<Pair<PIPString, PIPString>>();
        LogInfo.track("Generating " + this.nObservations + " datapoints");
        for (int i = 0; i < this.nObservations; ++i) {
            MSAPoset msa = PIPProcess.keepOnlyEndPts(process.sample(this.genRand, this.bl), ta, tb);
            LogInfo.logs(msa);
            this.msas.add(msa);
            Pair<PIPString, PIPString> endPoints = TestPIPPrs.getEndPoints(msa, ta, tb);
            observations.add(endPoints);
        }
        LogInfo.end_track();
        LogInfo.logsForce("brachLen = " + this.bl);
        PIPProcess currentParams = new PIPProcess(this.initLambda, this.initMu);
        StatisticsMap.DescriptiveStatisticsMap<String> stats = new StatisticsMap.DescriptiveStatisticsMap<String>();
        if (this.computeNumericalIntegral) {
            this.computeNumericalIntegral(this.msas);
        }
        double currentLogPi = Double.NEGATIVE_INFINITY;
        LogInfo.track("PMCMC is running");
        long start = System.currentTimeMillis();
        for (int mcmcIter = 0; mcmcIter < this.nMCMCIters; ++mcmcIter) {
            LogInfo.logs("Iteration " + mcmcIter + "/" + this.nMCMCIters);
            HashMap<String, Double> curQts = new HashMap<String, Double>();
            Pair<PIPProcess, Double> proposedParams_logRatio = this.proposeParam_logRatio(currentParams, this.rand);
            PIPProcess proposedParam = proposedParams_logRatio.getFirst();
            double logPi = (this.useAnalytic ? this.computeLLAnalytic(proposedParams_logRatio.getFirst(), this.msas) : this.computeLL(proposedParams_logRatio.getFirst(), observations, this.rand)) + this.logPrior(proposedParam);
            double logRatio = logPi - currentLogPi + proposedParams_logRatio.getSecond();
            double acceptPr = Sampling.min1exp(logRatio);
            curQts.put("acceptPr", acceptPr);
            boolean accepted = Sampling.sampleBern(acceptPr, this.rand);
            if (accepted) {
                currentParams = proposedParam;
                currentLogPi = logPi;
            }
            curQts.put("mu", currentParams.mu);
            curQts.put("lambda", currentParams.lambda);
            for (String qt : curQts.keySet()) {
                stats.addValue(qt, (Double)curQts.get(qt));
            }
            if (mcmcIter % this.printStatFreq != 0) continue;
            for (String qt : curQts.keySet()) {
                long time = System.currentTimeMillis() - start;
                for (double q : qs) {
                    out.write("cumulatives", "time", time, "iter", mcmcIter, "qt", qt, "mean", stats.getDescriptiveStat(qt).getMean(), "sd", stats.getDescriptiveStat(qt).getStandardDeviation(), "q", q, "qvalue", stats.getDescriptiveStat(qt).getPercentile(q));
                }
                out.write("samples", "time", time, "iter", mcmcIter, "qt", qt, "value", curQts.get(qt));
            }
        }
        LogInfo.end_track();
        out.close();
    }

    private void computeNumericalIntegral(List<MSAPoset> msas) {
        int nParts = 100;
        double norm = 0.0;
        double minMu = 0.4;
        double maxMu = 0.7;
        double minLa = 1.0;
        double maxLa = 3.0;
        double muLen = maxMu - minMu;
        double laLen = maxLa - minLa;
        double gridSizeMu = muLen / (double)nParts;
        double gridSizeLa = laLen / (double)nParts;
        double area = gridSizeMu * gridSizeLa;
        double[][] values = new double[nParts][nParts];
        LogInfo.track("Analytic integration...");
        for (int i = 0; i < nParts; ++i) {
            LogInfo.logs("" + i + "/" + nParts);
            double lambda = minLa + gridSizeLa * (double)i + gridSizeLa / 2.0;
            for (int j = 0; j < nParts; ++j) {
                double value;
                double mu = minMu + gridSizeMu * (double)j + gridSizeMu / 2.0;
                PIPProcess proposedParam = new PIPProcess(lambda, mu);
                values[i][j] = value = Math.exp(this.computeLLAnalytic(proposedParam, msas) + this.logPrior(proposedParam));
                norm += value * area;
            }
        }
        LogInfo.end_track();
        for (boolean useLambda : new boolean[]{true, false}) {
            ArrayList<Pair<Double, Double>> points = new ArrayList<Pair<Double, Double>>();
            for (int i = 0; i < nParts; ++i) {
                double cur = useLambda ? minLa + gridSizeLa * (double)i + gridSizeLa / 2.0 : minMu + gridSizeMu * (double)i + gridSizeMu / 2.0;
                double sum = 0.0;
                for (int j = 0; j < nParts; ++j) {
                    sum += (useLambda ? values[i][j] : values[j][i]) * (useLambda ? gridSizeMu : gridSizeLa) / norm;
                }
                points.add(Pair.makePair(cur, sum));
            }
            Plot2D plot = new Plot2D();
            plot.addSeries(points, true, "");
            plot.savePlot(new File(Execution.getFile(useLambda ? "lambda.pdf" : "mu.pdf")));
        }
    }

    private double computeLLAnalytic(PIPProcess first, List<MSAPoset> observations) {
        double lambda = first.lambda;
        double mu = first.mu;
        double result = 0.0;
        PoissonParameters pip = PoissonParameters.simplePIP(lambda, mu);
        Taxon ta = new Taxon("A");
        Taxon tb = new Taxon("B");
        RootedTree rt = RootedTree.Util.fromNewickString("(A:" + this.bl / 2.0 + ",B:" + this.bl / 2.0 + ");");
        for (MSAPoset msa : observations) {
            LinearizedAlignment lia = new LinearizedAlignment(msa);
            PIPLikelihoodCalculator mf = new PIPLikelihoodCalculator(pip, lia, rt);
            double logJoint = mf.computeDataLogProbabilityGivenTree();
            double statRate = lambda / mu;
            PoissonDistributionImpl pd = new PoissonDistributionImpl(statRate);
            double logPrior = Math.log(pd.probability(msa.sequences().get(ta).length()));
            double logConditional = logJoint - logPrior;
            result += logConditional;
        }
        return result;
    }

    private double logPrior(PIPProcess proposedParam) {
        return Sampling.exponentialLogDensity(this.priorMean, proposedParam.lambda) + Sampling.exponentialLogDensity(this.priorMean, proposedParam.mu);
    }

    private Pair<PIPProcess, Double> proposeParam_logRatio(PIPProcess currentParams, Random rand2) {
        boolean useMu = rand2.nextBoolean();
        double m = Sampling.nextDouble(rand2, 1.0 / this.proposalWidth, this.proposalWidth);
        double oldMu = currentParams.mu;
        double oldLa = currentParams.lambda;
        double newMu = useMu ? m * oldMu : oldMu;
        double newLa = useMu ? oldLa : m * oldLa;
        PIPProcess proposed = new PIPProcess(newLa, newMu);
        return Pair.makePair(proposed, Math.log(m));
    }

    private double computeLL(PIPProcess proposedParams, List<Pair<PIPString, PIPString>> observations, Random rand) {
        double result = 0.0;
        System.out.println("l=" + proposedParams.lambda + ",m=" + proposedParams.mu);
        SimplePotentialExperiments.PIPPotential pot = new SimplePotentialExperiments.PIPPotential();
        SimplePotentialExperiments.PotProposal<PIPString> prop = new SimplePotentialExperiments.PotProposal<PIPString>(proposedParams, pot, new SimplePotentialExperiments.PotPropOptions());
        ImportanceSampler<PIPString> is = new ImportanceSampler<PIPString>(prop, proposedParams);
        is.rand = rand;
        for (int i = 0; i < observations.size(); ++i) {
            double estimate;
            is.nParticles = this.nParticles;
            if (this.useSimplePMCMC) {
                MSAPoset msa = this.msas.get(i);
                estimate = TestPIPPrs.stdPMCMC(msa, this.bl, proposedParams, this.nParticles, rand, null);
            } else {
                Pair<PIPString, PIPString> obs = observations.get(i);
                Counter samples = is.sample(obs.getFirst(), obs.getSecond(), this.bl);
                estimate = is.estimateZ(samples);
            }
            result += Math.log(estimate);
        }
        return result;
    }
}

