/*
 * Decompiled with CFR 0.152.
 */
package conifer.ssm;

import conifer.ssm.ImportanceSampler;
import conifer.ssm.SimplePotentialExperiments;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.exec.Execution;
import fig.prob.SampleUtils;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.io.IO;
import nuts.math.Plot3D;
import nuts.math.Sampling;
import nuts.math.StatisticsMap;
import nuts.maxent.Function;
import nuts.util.Counter;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import org.apache.commons.math.random.MersenneTwister;
import org.apache.commons.math.random.RandomAdaptor;
import org.apache.commons.math.random.RandomGenerator;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

public class BirthDeathExperiment
implements Runnable {
    /** RNG used to draw the "true" (lambda, mu) from the prior. */
    @Option
    public Random trueParamGenRandom = new Random(1L);
    /** RNG used to generate the synthetic observations. */
    @Option
    public Random dataGenRandom = new Random(1L);
    /** Number of observed (start, end) pairs to generate. */
    @Option
    public int nSeries = 1000;
    /** Length of the time interval spanned by each observation. */
    @Option
    public double timeWindow = 0.4;
    /** Start states are drawn uniformly from {0, ..., initWindow - 1}. */
    @Option
    public int initWindow = 5;
    /** Grid resolution along each axis of the likelihood plots. */
    @Option
    public int nGrids = 10;
    /** RNG handed to the importance sampler for the approximate likelihood. */
    @Option
    public Random samplingRandom = new Random(1L);
    /** Number of particles per importance-sampling estimate. */
    @Option
    public int nParticles = 1000;
    /** Number of states of the truncated process (states 0 .. maxValue - 1). */
    @Option
    public int maxValue = 20;
    /** If true use SimplePotentialExperiments.PotProposal, otherwise BDProposer. */
    @Option
    public boolean useNewProposal = true;
    /** (mu, lambda) of the best grid point seen in the most recent plot. */
    Pair<Double, Double> maxFound = null;
    /** Log-likelihood at {@link #maxFound}. */
    double maxFoundValue = Double.NEGATIVE_INFINITY;
    /** Exact log-likelihood of the data under the true parameters. */
    double trueLikelihoodValue = 0.0;
    /** Upper bound on loop iterations in the forward sampler and in BDProposer. */
    private static final int MAX_SAMPLE_TRIES = 1000000;
    /** Lower corner of the uniform prior box on (lambda, mu); also the plot range. */
    private static final double PRIOR_LOWER = 1.0;
    /** Upper corner of the uniform prior box on (lambda, mu); also the plot range. */
    private static final double PRIOR_UPPER = 3.0;

    /**
     * Draws true parameters, simulates data, logs the exact true log-likelihood, and then plots the
     * likelihood surface twice: once computed exactly, once estimated by importance sampling.
     */
    @Override
    public void run() {
        SimpleBD prior = BirthDeathExperiment.priorSampleSimpleBD(this.trueParamGenRandom).getFirst();
        SimpleBD trueParam = new SimpleBD(prior.lambda, prior.mu, this.maxValue);
        LogInfo.logsForce("trueParam = " + trueParam);
        List<Datum<Integer>> data = BirthDeathExperiment.generateData(trueParam, this.dataGenRandom, this.nSeries, this.timeWindow, this.initWindow);
        this.trueLikelihoodValue = BirthDeathExperiment.exactLogLikelihood(trueParam, data, this.timeWindow);
        LogInfo.logsForce("trueLLValue = " + this.trueLikelihoodValue);
        for (boolean exact : new boolean[]{true, false}) {
            this.plotLikelihoodSurface(trueParam, data, exact);
        }
    }

    /**
     * Evaluates the log-likelihood over the (mu, lambda) prior box, saves the surface as a PDF and
     * logs the best grid point of THIS plot.
     *
     * @param trueParam parameters that generated {@code data}; only used in the output file name
     * @param data observed series
     * @param exact true for the matrix-exponential likelihood, false for the importance-sampling estimate
     */
    private void plotLikelihoodSurface(SimpleBD trueParam, final List<Datum<Integer>> data, final boolean exact) {
        LogInfo.track("Creating likelihood plot, exact = " + exact);
        try {
            // Reset per plot; otherwise the second plot would report the max over both surfaces.
            this.maxFound = null;
            this.maxFoundValue = Double.NEGATIVE_INFINITY;
            Function f = new Function(){

                @Override
                public double valueAt(double[] x) {
                    // Plot x-axis is mu, y-axis is lambda.
                    double curMu = x[0];
                    double curLambda = x[1];
                    double logl = BirthDeathExperiment.this.gridLogLikelihood(curMu, curLambda, data, exact);
                    if (logl > BirthDeathExperiment.this.maxFoundValue) {
                        BirthDeathExperiment.this.maxFoundValue = logl;
                        BirthDeathExperiment.this.maxFound = Pair.makePair(curMu, curLambda);
                    }
                    return logl;
                }

                @Override
                public int dimension() {
                    return 2;
                }
            };
            Plot3D plot = new Plot3D(f);
            plot.setMax_x(PRIOR_UPPER);
            plot.setMin_x(PRIOR_LOWER);
            plot.setMax_y(PRIOR_UPPER);
            plot.setMin_y(PRIOR_LOWER);
            plot.setRes_x(this.nGrids);
            plot.setRes_y(this.nGrids);
            plot.savePlot(new File(Execution.getFile("output(" + exact + "," + trueParam.mu + "," + trueParam.lambda + ").pdf")));
            LogInfo.logsForce("maxLL = " + this.maxFound);
            LogInfo.logsForce("maxLLValue = " + this.maxFoundValue);
        } finally {
            // Keep LogInfo's track nesting balanced even if plotting fails.
            LogInfo.end_track();
        }
    }

    /**
     * Log-likelihood of {@code data} at one grid point. The value is exact when {@code exact} is
     * set, and an importance-sampling estimate otherwise.
     */
    private double gridLogLikelihood(double curMu, double curLambda, List<Datum<Integer>> data, boolean exact) {
        SimpleBD currentBD = new SimpleBD(curLambda, curMu, this.maxValue);
        if (exact) {
            double logl = BirthDeathExperiment.exactLogLikelihood(currentBD, data, this.timeWindow);
            LogInfo.logsForce("(" + curMu + "," + curLambda + ") = " + logl);
            return logl;
        }
        Proposal<Integer> proposal;
        if (this.useNewProposal) {
            proposal = new SimplePotentialExperiments.PotProposal<Integer>(currentBD, SimplePotentialExperiments.INTEGER_POTENTIAL, new SimplePotentialExperiments.PotPropOptions());
        } else {
            proposal = new BDProposer();
        }
        ImportanceSampler<Integer> sampler = new ImportanceSampler<Integer>(proposal, currentBD);
        sampler.rand = this.samplingRandom;
        sampler.nParticles = this.nParticles;
        return BirthDeathExperiment.estimateZ(data, sampler);
    }

    /**
     * Exact log-likelihood of the observed endpoints.
     *
     * <p>The rate matrix of {@code bd} is exponentiated over {@code timeWindow}, and log
     * P(end | start) is summed over the data. States are used directly as matrix indices.
     *
     * @param bd a bounded process (its {@code indexer()} requires finite {@code nValues})
     * @param data observations; each must have length close to {@code timeWindow}
     * @param timeWindow interval length shared by all observations
     * @return the sum of log transition probabilities
     */
    private static double exactLogLikelihood(SimpleBD bd, List<Datum<Integer>> data, double timeWindow) {
        double[][] rateMatrix = SimplePotentialExperiments.getRateMatrix(bd, bd.indexer());
        double[][] transition = MatrixFunctions.expm(new DoubleMatrix(rateMatrix).mul(timeWindow)).toArray2();
        double logl = 0.0;
        for (Datum<Integer> datum : data) {
            MathUtils.checkClose(timeWindow, datum.length);
            logl += Math.log(transition[datum.startPoint][datum.endPoint]);
        }
        return logl;
    }

    /**
     * Samples one holding time and the following jump from state {@code x}.
     *
     * @return (next state, exponential waiting time with mean 1 / holdRate(x))
     */
    public static <S> Pair<S, Double> sampleJumpWaitingTime(Random rand, S x, Process<S> p) {
        // The waiting time is drawn before the jump so the RNG stream order is fixed.
        double time = Sampling.sampleExponential(rand, 1.0 / p.holdRate(x));
        return Pair.makePair(p.sample(rand, x), time);
    }

    public static void main(String[] args) {
        IO.run(args, new BirthDeathExperiment());
    }

    /** Scratch driver: runs PMCMC on data simulated from a prior draw. */
    private static void test2() {
        Random rand = new Random(4329087L);
        SimpleBD bd = BirthDeathExperiment.priorSampleSimpleBD(rand).getFirst();
        System.out.println("trueParm = " + bd);
        List<Datum<Integer>> data = BirthDeathExperiment.generateData(bd, rand, 1000, 0.5, 5);
        PMCMC pmcmc = new PMCMC();
        pmcmc.sample(data);
    }

    /**
     * Simulates {@code nPoints} independent observations. Each one starts in a state drawn
     * uniformly from {0, ..., initWindow - 1} and runs forward for {@code timeWindow}.
     */
    private static List<Datum<Integer>> generateData(SimpleBD bd, Random rand, int nPoints, double timeWindow, int initWindow) {
        List<Datum<Integer>> result = new ArrayList<Datum<Integer>>(Math.max(nPoints, 0));
        for (int i = 0; i < nPoints; ++i) {
            int start = rand.nextInt(initWindow);
            List<Integer> path = BirthDeathExperiment.forwardSample(rand, bd, start, timeWindow);
            result.add(new Datum<Integer>(path, timeWindow));
        }
        return result;
    }

    /** Scratch driver: compares importance-sampling normalizers across particle counts and RNGs. */
    private static void test1() {
        Random[] rands = new Random[]{new Random(1L), new RandomAdaptor(new MersenneTwister(10000))};
        SimpleBD bd = new SimpleBD(3.0, 2.0);
        BDProposer bdp = new BDProposer();
        for (int i = 0; i < 100; ++i) {
            double t = 0.2;
            List<Integer> currentPath = BirthDeathExperiment.forwardSample(rands[0], bd, rands[0].nextInt(20), t);
            System.out.println("forwardSample = " + currentPath);
            int x = currentPath.get(0);
            int y = currentPath.get(currentPath.size() - 1);
            for (int nPart = 1000; nPart < 10000000; nPart *= 10) {
                for (Random rand : rands) {
                    Counter<List<Integer>> sample = BirthDeathExperiment.importanceSample(rand, nPart, bd, bdp, x, y, t);
                    System.out.println("approxNomr(" + nPart + ") = " + sample.totalCount() / (double)nPart);
                    sample.normalize();
                    System.out.println("importanceSample(" + nPart + ") = " + sample);
                }
            }
            System.out.println();
        }
    }

    /**
     * Importance-sampling estimate of the log-likelihood of {@code series}.
     *
     * <p>For each datum this adds log(total importance weight / nParticles).
     */
    public static <S> double estimateZ(List<Datum<S>> series, ImportanceSampler<S> sampler) {
        double logParticles = Math.log(sampler.nParticles);
        double sumOfLogs = 0.0;
        for (Datum<S> d : series) {
            Counter<List<S>> samples = sampler.sample(d.startPoint, d.endPoint, d.length);
            sumOfLogs = sumOfLogs + Math.log(samples.totalCount()) - logParticles;
        }
        return sumOfLogs;
    }

    /**
     * Draws (lambda, mu) uniformly from [PRIOR_LOWER, PRIOR_UPPER)^2.
     *
     * <p>mu is drawn first, then lambda. The returned second component is the prior density,
     * 1 / (PRIOR_UPPER - PRIOR_LOWER)^2 = 0.25. It is a plain density, not a log density.
     */
    public static Pair<SimpleBD, Double> priorSampleSimpleBD(Random rand) {
        double width = PRIOR_UPPER - PRIOR_LOWER;
        double mu = width * rand.nextDouble() + PRIOR_LOWER;
        double lambda = width * rand.nextDouble() + PRIOR_LOWER;
        return Pair.makePair(new SimpleBD(lambda, mu), 1.0 / (width * width));
    }

    /**
     * Simulates the jump chain of {@code p} from {@code x} until time {@code t}.
     *
     * @return the sequence of visited states; the first element is {@code x} and the last is the state at time t
     * @throws IllegalStateException if more than MAX_SAMPLE_TRIES jumps occur before time t
     */
    public static <S> List<S> forwardSample(Random rand, Process<S> p, S x, double t) {
        List<S> result = new ArrayList<S>();
        result.add(x);
        double elapsed = 0.0;
        S current = x;
        for (int i = 0; i < MAX_SAMPLE_TRIES; ++i) {
            Pair<S, Double> jump = BirthDeathExperiment.sampleJumpWaitingTime(rand, current, p);
            elapsed += jump.getSecond();
            if (elapsed > t) {
                return result;
            }
            current = jump.getFirst();
            result.add(current);
        }
        throw new IllegalStateException("forwardSample: exceeded " + MAX_SAMPLE_TRIES + " jumps before time " + t + " starting from " + x);
    }

    public static <S> Counter<List<S>> importanceSample(Random rand, int nParticles, Process<S> process, Proposal<S> proposal, S x, S y, double t) {
        return BirthDeathExperiment.importanceSample(rand, nParticles, process, proposal, x, y, t, null);
    }

    /**
     * Proposes {@code nParticles} jump sequences from x to y over time t and weights each one.
     *
     * <p>Each weight is (target path probability / proposal probability). The target path
     * probability is the holding-time integral times the product of jump probabilities.
     *
     * @param weightVariance if non-null, receives every particle weight
     * @return a counter mapping each proposed path to its accumulated weight
     */
    public static <S> Counter<List<S>> importanceSample(Random rand, int nParticles, Process<S> process, Proposal<S> proposal, S x, S y, double t, SummaryStatistics weightVariance) {
        Counter<List<S>> result = new Counter<List<S>>();
        for (int pIdx = 0; pIdx < nParticles; ++pIdx) {
            Pair<List<S>, Double> proposed = proposal.propose(rand, x, y, t);
            List<S> jumps = proposed.getFirst();
            double weight = BirthDeathExperiment.integral(process, jumps, t);
            for (int jIdx = 0; jIdx + 1 < jumps.size(); ++jIdx) {
                weight *= process.transitionProbability(jumps.get(jIdx), jumps.get(jIdx + 1));
            }
            weight /= proposed.getSecond();
            if (weightVariance != null) {
                weightVariance.addValue(weight);
            }
            result.incrementCount(jumps, weight);
        }
        return result;
    }

    /**
     * Probability of a particular jump count for a chain that visits the states of
     * {@code proposed} in order, using the holding rates from {@code process}.
     *
     * <p>The value is the probability that exactly proposed.size() - 1 jumps have happened by
     * time t. A pure-birth chain over positions 0..k is built, with an absorbing extra state k;
     * the code reads entry (0, k-1) of its matrix exponential.
     */
    public static <S> double integral(Process<S> process, List<S> proposed, double t) {
        int size = proposed.size() + 1;
        double[][] mtx = new double[size][size];
        for (int i = 0; i < proposed.size(); ++i) {
            double scaledRate = process.holdRate(proposed.get(i)) * t;
            mtx[i][i] = -scaledRate;
            mtx[i][i + 1] = scaledRate;
        }
        return MatrixFunctions.expm(new DoubleMatrix(mtx)).get(0, size - 2);
    }

    /**
     * Heuristic proposal for integer birth–death paths from x to y.
     *
     * <p>At y it stops with probability {@code stopPr}. Elsewhere it steps toward y with
     * probability {@code greed}, except that state 0 always steps up to 1.
     */
    public static class BDProposer
    implements Proposal<Integer> {
        public double stopPr = 0.95;
        public double greed = 0.95;

        @Override
        public Pair<List<Integer>, Double> propose(Random rand, Integer _x, Integer _y, double t) {
            int y = _y;
            double proposalPr = 1.0;
            int current = _x;
            List<Integer> result = new ArrayList<Integer>();
            result.add(current);
            for (int i = 0; i < MAX_SAMPLE_TRIES; ++i) {
                if (current == y) {
                    if (Sampling.sampleBern(this.stopPr, rand)) {
                        proposalPr *= this.stopPr;
                        return Pair.makePair(result, proposalPr);
                    }
                    proposalPr *= 1.0 - this.stopPr;
                }
                if (current == 0) {
                    // Deterministic reflection at 0; contributes probability 1.
                    current = 1;
                } else {
                    boolean flip = !Sampling.sampleBern(this.greed, rand);
                    proposalPr *= flip ? 1.0 - this.greed : this.greed;
                    // Step up when at/below y (unless flipped) or above y when flipped.
                    boolean stepUp = (current <= y) != flip;
                    current += stepUp ? 1 : -1;
                }
                result.add(current);
            }
            throw new IllegalStateException("BDProposer: no path from " + _x + " to " + _y + " within " + MAX_SAMPLE_TRIES + " steps");
        }
    }

    /** Proposal distribution over jump sequences between two endpoints. */
    public interface Proposal<S> {
        /**
         * @return (proposed sequence of states from start to end, probability of proposing it)
         */
        public Pair<List<S>, Double> propose(Random rand, S start, S end, double t);
    }

    /**
     * Birth–death process on {0, ..., nValues - 1}, with birth rate lambda and death rate mu * x.
     * With the two-argument constructor the state space is effectively unbounded.
     */
    public static class SimpleBD
    implements Process<Integer>,
    SimplePotentialExperiments.SparseProcess<Integer> {
        public final double lambda;
        public final double mu;
        public final int nValues;

        public SimpleBD(double lambda, double mu) {
            this(lambda, mu, Integer.MAX_VALUE);
        }

        public SimpleBD(double lambda, double mu, int nValues) {
            this.lambda = lambda;
            this.mu = mu;
            this.nValues = nValues;
        }

        @Override
        public double holdRate(Integer _x) {
            int x = _x;
            double birth = x == this.nValues - 1 ? 0.0 : this.lambda;
            double death = x == 0 ? 0.0 : this.mu * (double)x;
            return birth + death;
        }

        @Override
        public double transitionProbability(Integer _x, Integer _y) {
            int x = _x;
            int y = _y;
            if (x < 0) {
                throw new IllegalArgumentException("negative state: " + x);
            }
            if (y < 0 || Math.abs(x - y) != 1 || y >= this.nValues) {
                return 0.0;
            }
            // Boundary states have a single possible move.
            if (x == 0 && y == 1) {
                return 1.0;
            }
            if (x == this.nValues - 1 && y == this.nValues - 2) {
                return 1.0;
            }
            if (y == x + 1) {
                return this.lambda / this.holdRate(x);
            }
            // Only y == x - 1 remains after the |x - y| == 1 check.
            return this.mu * (double)x / this.holdRate(x);
        }

        @Override
        public Integer sample(Random rand, Integer x) {
            return x + (Sampling.sampleBern(this.transitionProbability(x, x + 1), rand) ? 1 : -1);
        }

        public String toString() {
            return "SimpleBD [lambda=" + this.lambda + ", mu=" + this.mu + ", nValues=" + this.nValues + "]";
        }

        /**
         * Indexer mapping state i to index i.
         *
         * @throws IllegalStateException if the state space is unbounded
         */
        public Indexer<Integer> indexer() {
            if (this.nValues == Integer.MAX_VALUE) {
                throw new IllegalStateException("indexer() requires a bounded SimpleBD (nValues was not set)");
            }
            Indexer<Integer> result = new Indexer<Integer>();
            for (int i = 0; i < this.nValues; ++i) {
                result.addToIndex(new Integer[]{i});
            }
            return result;
        }

        /** Outgoing jump rates from {@code point}; only neighbours with non-zero probability are included. */
        @Override
        public Counter<Integer> rates(Integer point) {
            Counter<Integer> result = new Counter<Integer>();
            double totalRate = this.holdRate(point);
            for (int delta = -1; delta <= 1; delta += 2) {
                double pr = this.transitionProbability(point, point + delta);
                if (pr != 0.0) {
                    result.setCount(point + delta, totalRate * pr);
                }
            }
            return result;
        }
    }

    /** Process defined by explicit tables; the fields must be populated by the caller before use. */
    public static class FiniteBD
    implements Process<Integer> {
        public double[][] rateMtx;
        public double[] holdRates;
        public double[][] normalizedJumpPrs;

        @Override
        public double holdRate(Integer x) {
            return this.holdRates[x];
        }

        @Override
        public double transitionProbability(Integer x, Integer y) {
            return this.normalizedJumpPrs[x][y];
        }

        @Override
        public Integer sample(Random rand, Integer x) {
            return SampleUtils.sampleMultinomial(rand, this.normalizedJumpPrs[x]);
        }
    }

    /** One observation: the process was in startPoint and, {@code length} time later, in endPoint. */
    public static class Datum<S> {
        public final S startPoint;
        public final S endPoint;
        public final double length;

        public Datum(S x, S y, double t) {
            this.startPoint = x;
            this.endPoint = y;
            this.length = t;
        }

        /** Keeps only the endpoints of a full simulated path. */
        private Datum(List<S> fullSeqn, double t) {
            this.startPoint = fullSeqn.get(0);
            this.endPoint = fullSeqn.get(fullSeqn.size() - 1);
            this.length = t;
        }
    }

    /**
     * Independence Metropolis–Hastings sampler over (lambda, mu), using the prior as the proposal.
     * The likelihood is estimated by importance sampling with BDProposer.
     */
    public static class PMCMC {
        public int iters = 10000;
        public int summaryPrintPeriod = 10;
        public Random rand = new Random(1L);
        public int nParticles = 1000;

        public void sample(List<Datum<Integer>> data) {
            // Start from a prior draw so currentParam is never null when summaries are recorded.
            SimpleBD currentParam = BirthDeathExperiment.priorSampleSimpleBD(this.rand).getFirst();
            StatisticsMap<String> summaries = new StatisticsMap<String>();
            double currentDensity = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < this.iters; ++i) {
                Pair<SimpleBD, Double> sbd = BirthDeathExperiment.priorSampleSimpleBD(this.rand);
                ImportanceSampler<Integer> sampler = new ImportanceSampler<Integer>(new BDProposer(), sbd.getFirst());
                sampler.rand = this.rand;
                sampler.nParticles = this.nParticles;
                double logl = BirthDeathExperiment.estimateZ(data, sampler);
                // The prior term is the same constant for every draw, so it cancels in the ratio.
                double candidateDensity = logl + sbd.getSecond();
                double logRatio = candidateDensity - currentDensity;
                // -inf minus -inf gives NaN; treat that as an undefined ratio and reject.
                double acceptPr = Double.isNaN(logRatio) ? 0.0 : Math.min(1.0, Math.exp(logRatio));
                summaries.addValue("acceptPr", acceptPr);
                if (Sampling.sampleBern(acceptPr, this.rand)) {
                    currentDensity = candidateDensity;
                    currentParam = sbd.getFirst();
                }
                summaries.addValue("lambda", currentParam.lambda);
                summaries.addValue("mu", currentParam.mu);
                if (i % this.summaryPrintPeriod == 0 || i == this.iters - 1) {
                    System.out.println(summaries.printAll());
                }
            }
        }
    }

    /** Continuous-time jump process, described by its holding rates and embedded jump chain. */
    public interface Process<S> {
        /** Total rate of leaving {@code state}. */
        public double holdRate(S state);

        /** Probability that the next jump from {@code from} goes to {@code to}. */
        public double transitionProbability(S from, S to);

        /** Samples the state reached by the next jump from {@code current}. */
        public S sample(Random rand, S current);
    }
}

