/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import java.util.Random;
import nuts.math.Sampling;
import nuts.math.TrapezoidLogSpaceIntegrator;
import nuts.util.CoordinatesPacker;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleKernel;
import pty.smc.PriorPriorKernel;
import pty.smc.models.BrownianModelCalculator;
import pty.smc.models.LikelihoodModelCalculator;
import pty.smc.models.VarianceMarginalUtils;

/**
 * Particle kernel that coalesces one pair of roots per step.
 *
 * <p>The pair is drawn with probability proportional to the exponential of
 * {@link #posteriorCoalesceLogPr}, a numerical integral (over the time increment) of the
 * integrand built by {@link #getProposal}. The time increment itself is drawn with
 * {@code Sampling.sampleExponential}. Scoring currently requires a reversible, clock,
 * Brownian-motion model.
 */
public class PostPostKernelA
implements ParticleKernel<PartialCoalescentState> {
    /** Upper limit of the numerical integration over the time increment. */
    private static final double MAX_DELTA = 3.0;
    /** Maximum number of candidates tried by {@link #samplePosteriorDelta}. */
    private static final int MAX_TRIALS = 1000000;
    /** Largest value accepted for a message entry by {@link #halfSumSquareDiff} (pi / 2). */
    private static final double MAX_MESSAGE_ENTRY = Math.PI / 2.0;

    private final PartialCoalescentState initial;

    /**
     * @param initial the state returned by {@link #getInitial()}; its {@code isClock()} flag
     *     also selects the parameter of the time-increment draw in {@link #next}
     */
    public PostPostKernelA(PartialCoalescentState initial) {
        this.initial = initial;
    }

    /**
     * Proposes the next state by coalescing one pair of roots of {@code state}.
     *
     * <p>Every unordered pair {@code (i, j)} with {@code i < j} is scored with
     * {@link #posteriorCoalesceLogPr}; one pair is drawn with probability proportional to the
     * exponentiated score, and the time increment is drawn with
     * {@code Sampling.sampleExponential}.
     *
     * @param rand source of randomness
     * @param state current partial state; must have at least two roots
     * @return the coalesced state paired with its log-weight update
     * @throws IllegalStateException if {@code state} has fewer than two roots
     */
    @Override
    public Pair<PartialCoalescentState, Double> next(Random rand, PartialCoalescentState state) {
        int nRoots = state.nRoots();
        if (nRoots < 2) {
            throw new IllegalStateException("Need at least two roots to coalesce, got " + nRoots);
        }
        double[] prs = new double[nRoots * nRoots];
        CoordinatesPacker cp = new CoordinatesPacker(nRoots);
        for (int i = 0; i < nRoots; ++i) {
            for (int j = 0; j < nRoots; ++j) {
                // Only pairs with i < j are candidates; every other cell gets zero probability.
                prs[cp.coord2int(i, j)] = i < j
                        ? posteriorCoalesceLogPr(i, j, state)
                        : Double.NEGATIVE_INFINITY;
            }
        }
        // log(sum_{i<j} exp(score_ij)). This must be taken before expNormalize, which
        // overwrites prs in place with linear-space, normalized probabilities.
        double logNorm = logSumExp(prs);
        NumUtils.expNormalize(prs);
        int idx = SampleUtils.sampleMultinomial(rand, prs);
        int[] pair = cp.int2coord(idx);
        double expParam = (this.initial.isClock() ? 1.0 : 0.5) / PriorPriorKernel.nChoose2(nRoots);
        double delta = Sampling.sampleExponential(rand, expParam);
        PartialCoalescentState result = state.coalesce(pair[0], pair[1], delta, 0.0, 0.0);
        // prs[idx] is a probability here, so move it to log space before mixing with log terms;
        // log(prs[idx]) + logNorm recovers the chosen pair's unnormalized log score.
        // NOTE(review): confirm the sign convention of this weight w.r.t. the likelihood ratio.
        double weightUpdate = Math.log(prs[idx]) + logNorm - result.logLikelihoodRatio();
        return Pair.makePair(result, weightUpdate);
    }

    /**
     * Returns {@code log(sum_i exp(logValues[i]))}, factoring out the maximum for stability.
     * Returns negative infinity when the array is empty or every entry is negative infinity.
     */
    private static double logSumExp(double[] logValues) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logValues) {
            max = Math.max(max, v);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        double sum = 0.0;
        for (double v : logValues) {
            sum += Math.exp(v - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Rejection-samples a time increment for coalescing roots {@code left} and {@code right}.
     * Not called from anywhere in this class.
     *
     * <p>Candidates {@code x} are drawn with {@code Sampling.sampleExponential}; a candidate is
     * accepted when {@code rand.nextDouble() * M} is at most the proposal function at {@code x},
     * where {@code M} is the second component returned by {@link #getProposal}.
     *
     * <p>NOTE(review): the proposal function and {@code M} appear to be log-space values while
     * the uniform draw is linear; verify the acceptance test before relying on this method.
     *
     * @throws UnsupportedOperationException if the state is not a Brownian-motion model
     * @throws IllegalStateException if no candidate is accepted within {@link #MAX_TRIALS} tries
     * @throws RuntimeException wrapping a {@link FunctionEvaluationException} from the proposal
     */
    private double samplePosteriorDelta(Random rand, PartialCoalescentState state, int left, int right) {
        if (!state.isBrownianMotion()) {
            throw new UnsupportedOperationException("Delta sampling requires a Brownian motion model");
        }
        Pair<UnivariateRealFunction, Double> proposal = getProposal(left, right, state);
        UnivariateRealFunction fct = proposal.getFirst();
        double M = proposal.getSecond();
        double param = 1.0 / PriorPriorKernel.nChoose2(state.nRoots());
        LogInfo.track("Sampling");
        try {
            for (int i = 0; i < MAX_TRIALS; ++i) {
                LogInfo.logs("Attempt " + i);
                double x = Sampling.sampleExponential(rand, param);
                double u = rand.nextDouble() * M;
                if (u <= fct.value(x)) {
                    return x;
                }
            }
        } catch (FunctionEvaluationException e) {
            throw new RuntimeException(e);
        } finally {
            // Close the track on every exit path, not only on acceptance.
            LogInfo.end_track();
        }
        throw new IllegalStateException("No delta accepted after " + MAX_TRIALS + " attempts");
    }

    /**
     * Builds the integrand used to score coalescing roots {@code left} and {@code right}.
     *
     * <p>The function maps a time increment {@code deltaT} to
     * {@code logPrior(deltaT, nRoots) - (n/2) log(2 pi) - (n/2) log(sigma) - S / sigma}, where
     * {@code sigma = 2 deltaT + |h_left - h_right| + v_left + v_right}, {@code n} is the number
     * of sites, {@code v} are the calculators' message variances and {@code S} is
     * {@link #halfSumSquareDiff} of their messages.
     *
     * <p>NOTE(review): {@link #posteriorCoalesceLogPr} scales the message variances by
     * {@code bm.varianceScale}; this integrand does not. Confirm which is intended.
     *
     * @return the integrand, paired with {@code peekCoalescedLogLikelihood} evaluated at the
     *     base branch lengths implied by the two root heights
     * @throws ClassCastException if the calculators are not {@link BrownianModelCalculator}s
     */
    private static Pair<UnivariateRealFunction, Double> getProposal(int left, int right, final PartialCoalescentState state) {
        LikelihoodModelCalculator leftCalculator = state.getLikelihoodModelCalculator(left);
        LikelihoodModelCalculator rightCalculator = state.getLikelihoodModelCalculator(right);
        final double node1H = state.getHeight(left);
        final double node2H = state.getHeight(right);
        // Branch extensions needed to bring the lower root up to the higher one.
        final double baseV1 = Math.max(0.0, node2H - node1H);
        final double baseV2 = Math.max(0.0, node1H - node2H);
        final BrownianModelCalculator l1 = (BrownianModelCalculator) leftCalculator;
        final BrownianModelCalculator l2 = (BrownianModelCalculator) rightCalculator;
        final double halfSumSquareDiff = halfSumSquareDiff(l1.message, l2.message);
        final double n = state.getObservations().nSites();
        final double prefix = n / 2.0 * Math.log(Math.PI * 2);
        UnivariateRealFunction fct = new UnivariateRealFunction() {
            @Override
            public double value(double deltaT) throws FunctionEvaluationException {
                double t = deltaT;
                // Avoid a zero total branch length (equal heights and deltaT == 0).
                if (baseV1 + t + baseV2 + t == 0.0) {
                    t = 1.0E-10;
                }
                double sigma = 2.0 * t + Math.abs(node1H - node2H) + l1.messageVariance + l2.messageVariance;
                double secondTerm = -prefix - n / 2.0 * Math.log(sigma) - halfSumSquareDiff / sigma;
                return logPrior(t, state.nRoots()) + secondTerm;
            }
        };
        return Pair.makePair(fct, leftCalculator.peekCoalescedLogLikelihood(leftCalculator, rightCalculator, baseV1, baseV2));
    }

    /**
     * Scores coalescing roots {@code left} and {@code right} by integrating the
     * {@link #getProposal} integrand over time increments in {@code [0, MAX_DELTA]} with a
     * {@code TrapezoidLogSpaceIntegrator} (presumably yielding a log integral, as the name
     * suggests — confirm).
     *
     * <p>As a diagnostic, the numerical value and a closed-form approximation are both logged.
     *
     * @return the numerical integral
     * @throws RuntimeException if the model is not reversible, clock and Brownian, or if the
     *     integrator cannot be constructed
     */
    private static double posteriorCoalesceLogPr(int left, int right, PartialCoalescentState state) {
        if (!state.isReversible() || !state.isClock() || !state.isBrownianMotion()) {
            throw new RuntimeException("Not yet supported");
        }
        TrapezoidLogSpaceIntegrator tlsi;
        LogInfo.track((Object) "Integrating...", true);
        try {
            tlsi = new TrapezoidLogSpaceIntegrator(getProposal(left, right, state).getFirst());
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            LogInfo.end_track();
        }
        double numerical = tlsi.integrate(0.0, MAX_DELTA);
        LogInfo.logs("Numerical approx.  :" + numerical);
        LogInfo.logs("Gamma approximation:" + gammaApproximationLogPr(left, right, state));
        return numerical;
    }

    /**
     * Closed-form counterpart of {@link #posteriorCoalesceLogPr} built from
     * {@code VarianceMarginalUtils.truncatedGIGLogNormalizationApprox}; used only as a logged
     * diagnostic.
     */
    private static double gammaApproximationLogPr(int left, int right, PartialCoalescentState state) {
        BrownianModelCalculator bmc1 = (BrownianModelCalculator) state.getLikelihoodModelCalculator(left);
        BrownianModelCalculator bmc2 = (BrownianModelCalculator) state.getLikelihoodModelCalculator(right);
        double varianceScale = bmc1.bm.varianceScale;
        double lambda = 1.0 / PriorPriorKernel.nChoose2(state.nRoots());
        double k = Math.abs(state.getHeight(left) - state.getHeight(right))
                + varianceScale * (bmc1.messageVariance + bmc2.messageVariance);
        double kPrime = halfSumSquareDiff(bmc1.message, bmc2.message);
        double n = state.getObservations().nSites();
        double p = 1.0 - n / 2.0;
        double a = lambda / 2.0 / varianceScale;
        double truncIntegral = VarianceMarginalUtils.truncatedGIGLogNormalizationApprox(p, a, kPrime, k);
        return Math.log(lambda) - Math.log(2.0 * varianceScale) + k / 2.0 / varianceScale
                - n / 2.0 * Math.log(Math.PI * 2) + truncIntegral;
    }

    /**
     * Returns half the sum of squared element-wise differences of two equal-length messages.
     *
     * @throws IllegalArgumentException if the lengths differ or any entry lies outside
     *     {@code [0, pi/2]}
     */
    private static double halfSumSquareDiff(double[] message1, double[] message2) {
        if (message1.length != message2.length) {
            throw new IllegalArgumentException(
                    "Message lengths differ: " + message1.length + " vs " + message2.length);
        }
        double sum = 0.0;
        for (int i = 0; i < message1.length; ++i) {
            if (message1[i] < 0.0 || message1[i] > MAX_MESSAGE_ENTRY) {
                throw new IllegalArgumentException("message1[" + i + "] out of range: " + message1[i]);
            }
            if (message2[i] < 0.0 || message2[i] > MAX_MESSAGE_ENTRY) {
                throw new IllegalArgumentException("message2[" + i + "] out of range: " + message2[i]);
            }
            double term = message1[i] - message2[i];
            sum += term * term;
        }
        return sum / 2.0;
    }

    /**
     * Evaluates {@code Sampling.exponentialDensity} at {@code x} with parameter
     * {@code 1 / nChoose2(nRoots)}.
     *
     * <p>NOTE(review): the name says "log", but this returns whatever
     * {@code exponentialDensity} returns; confirm it is a log density, since callers add it to
     * log-space terms.
     */
    public static double logPrior(double x, int nRoots) {
        double param = 1.0 / PriorPriorKernel.nChoose2(nRoots);
        return Sampling.exponentialDensity(param, x);
    }

    /** Returns the state this kernel was constructed with. */
    @Override
    public PartialCoalescentState getInitial() {
        return this.initial;
    }

    /** Delegates to {@code partialState.nIterationsLeft()}. */
    @Override
    public int nIterationsLeft(PartialCoalescentState partialState) {
        return partialState.nIterationsLeft();
    }
}

