/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.models;

import fig.basic.Option;
import fig.basic.Pair;
import fig.prob.Gaussian;
import java.util.Random;
import org.apache.commons.math.distribution.GammaDistributionImpl;
import org.apache.commons.math.special.Gamma;
import pty.smc.PartialCoalescentState;
import pty.smc.models.BrownianModel;
import pty.smc.models.GIG;
import pty.smc.models.LikelihoodModelCalculator;
import pty.smc.models.VarianceMarginalUtils;

public class BrownianModelCalculator
implements LikelihoodModelCalculator {
    @Option
    public static boolean useVarianceTransform = true;
    public final boolean resampleRoot;
    public final double[] message;
    public final double messageVariance;
    private final double loglikelihood;
    public double lognorm;
    public final BrownianModel bm;
    private final int nsites;
    private static Random rand = new Random(1L);

    private BrownianModelCalculator(double[] message, double messageVariance, BrownianModel bm, double loglikelihood, boolean resampleRoot) {
        this.message = message;
        this.messageVariance = messageVariance;
        this.loglikelihood = loglikelihood;
        this.bm = bm;
        this.nsites = bm.nsites;
        this.lognorm = 0.0;
        this.resampleRoot = resampleRoot;
    }

    private BrownianModelCalculator(double[] message, double messageVariance, BrownianModel bm, double loglikelihood, double lognorm, boolean resampleRoot) {
        this.message = message;
        this.messageVariance = messageVariance;
        this.loglikelihood = loglikelihood;
        this.bm = bm;
        this.nsites = bm.nsites;
        this.lognorm = lognorm;
        this.resampleRoot = resampleRoot;
    }

    public static BrownianModelCalculator observation(double[] observed, BrownianModel bm, boolean resampleRoot) {
        return BrownianModelCalculator.observation(observed, bm, useVarianceTransform, resampleRoot);
    }

    private static BrownianModelCalculator observation(double[] observed, BrownianModel bm, boolean useVarianceTransform, boolean resampleRoot) {
        if (observed.length != bm.nsites) {
            throw new RuntimeException();
        }
        double[] state = BrownianModelCalculator.getNewMessage(bm.nsites);
        for (int i = 0; i < bm.nsites; ++i) {
            state[i] = useVarianceTransform ? Math.asin(Math.sqrt(observed[i])) : observed[i];
        }
        return new BrownianModelCalculator(state, 0.0, bm, 0.0, resampleRoot);
    }

    private final Object calculate(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2, boolean peek) {
        BrownianModelCalculator l1 = (BrownianModelCalculator)node1;
        BrownianModelCalculator l2 = (BrownianModelCalculator)node2;
        double[] message = peek ? null : BrownianModelCalculator.getNewMessage(this.nsites);
        double logl = 0.0;
        double var1 = l1.messageVariance;
        double var2 = l2.messageVariance;
        double var = 1.0 / (var1 + v1) + 1.0 / (var2 + v2);
        double newMessageVariance = 1.0 / var;
        for (int i = 0; i < this.nsites; ++i) {
            double mean1 = l1.message[i];
            double mean2 = l2.message[i];
            if (!peek) {
                message[i] = (mean1 / (var1 + v1) + mean2 / (var2 + v2)) / var;
            }
            double cur = BrownianModelCalculator.logNormalDensity(mean1 - mean2, 0.0, this.bm.varianceScale * (v1 + var1 + v2 + var2));
            logl += cur;
        }
        double lognorm = logl;
        logl += l1.loglikelihood + l2.loglikelihood;
        if (peek) {
            return logl;
        }
        return this.resampleRoot ? BrownianModelCalculator.resampleRoot(message, newMessageVariance, this.bm) : new BrownianModelCalculator(message, newMessageVariance, this.bm, logl, lognorm, this.resampleRoot);
    }

    private static BrownianModelCalculator resampleRoot(double[] message, double newMessageVariance, BrownianModel bm) {
        for (int i = 0; i < message.length; ++i) {
            message[i] = Gaussian.sample(rand, message[i], newMessageVariance);
        }
        return BrownianModelCalculator.observation(message, bm, false, true);
    }

    @Override
    public LikelihoodModelCalculator combine(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double delta1, double delta2, boolean doNotBuildCache) {
        return (LikelihoodModelCalculator)this.calculate(node1, node2, delta1, delta2, false);
    }

    public final Pair<Double, Double> sampleBranchLength(PartialCoalescentState.CoalescentNode node1, PartialCoalescentState.CoalescentNode node2, double topHeight, int nroots, double particleweight, Random rand) {
        BrownianModelCalculator l1 = (BrownianModelCalculator)node1.likelihoodModelCache;
        BrownianModelCalculator l2 = (BrownianModelCalculator)node2.likelihoodModelCache;
        double var1 = l1.messageVariance;
        double var2 = l2.messageVariance;
        double c = var1 + var2 + 2.0 * topHeight - node1.height - node2.height;
        double b = 0.0;
        double a = 0.5 * (double)nroots * (double)(nroots - 1);
        int m = this.nsites;
        double p = 0.5 * (double)(m - 2);
        for (int i = 0; i < this.nsites; ++i) {
            double mean1 = l1.message[i];
            double mean2 = l2.message[i];
            b += Math.pow(mean1 - mean2, 2.0) / this.bm.varianceScale;
        }
        try {
            double mode = (double)(m - 4) / b;
            double result = 0.0;
            double correct = 0.0;
            if (mode < 1.0 / c) {
                int maxiter;
                GammaDistributionImpl gamma;
                do {
                    maxiter = (int)(0.5 * (double)m - 2.0);
                    double sum = 0.0;
                    for (int i = 0; i < maxiter; ++i) {
                        sum -= Math.log(rand.nextDouble());
                    }
                } while (!((result = (sum += (gamma = new GammaDistributionImpl(0.5 * (double)m - 1.0 - (double)maxiter, 1.0)).inverseCumulativeProbability(rand.nextDouble())) * (2.0 / b)) <= 1.0 / c));
                result = 0.5 * (1.0 / result - c);
                correct = -0.5 * (double)m * (Math.log(Math.PI * 2) + Math.log(this.bm.varianceScale));
            } else {
                result = rand.nextDouble() * (1.0 / c);
                VarianceMarginalUtils.GIGlogDensity gig = new VarianceMarginalUtils.GIGlogDensity(p - 1.0, b, a);
                BrownianModelCalculator bmc1 = (BrownianModelCalculator)node1.likelihoodModelCache;
                BrownianModelCalculator bmc2 = (BrownianModelCalculator)node2.likelihoodModelCache;
                correct = gig.value(result) + 0.5 * a * c - particleweight - Math.log(2.0 * c);
                correct += bmc1.lognorm + bmc2.lognorm - bmc1.logLikelihood() - bmc2.logLikelihood();
                correct -= 0.5 * (double)m * (Math.log(Math.PI * 2) + Math.log(this.bm.varianceScale));
                result = 0.5 * (1.0 / result - c);
            }
            return Pair.makePair(result, correct);
        }
        catch (Exception e) {
            double correct = 0.0;
            double result = rand.nextGaussian();
            result = (result + (double)(m - 2) / b) * (Math.sqrt(2 * (m - 2)) / b);
            result = 0.5 * (1.0 / result - c);
            return Pair.makePair(result, correct);
        }
    }

    public final double evaluatePair(PartialCoalescentState.CoalescentNode node1, PartialCoalescentState.CoalescentNode node2, int nroots, double topHeight) {
        double weight;
        BrownianModelCalculator l1 = (BrownianModelCalculator)node1.likelihoodModelCache;
        BrownianModelCalculator l2 = (BrownianModelCalculator)node2.likelihoodModelCache;
        double var1 = l1.messageVariance;
        double var2 = l2.messageVariance;
        double c = var1 + var2 + 2.0 * topHeight - node1.height - node2.height;
        double b = 0.0;
        double a = 0.5 * (double)nroots * (double)(nroots - 1);
        int m = this.nsites;
        double p = 0.5 * (double)(m - 2);
        for (int i = 0; i < this.nsites; ++i) {
            double mean1 = l1.message[i];
            double mean2 = l2.message[i];
            b += Math.pow(mean1 - mean2, 2.0) / this.bm.varianceScale;
        }
        if (node1.isLeaf() && node2.isLeaf()) {
            weight = GIG.GIGapproxLognorm(b, a, p);
        } else {
            weight = GIG.GIGapproxLognorm(b, a, p);
            if (Double.isNaN(weight)) {
                System.exit(1);
            }
            try {
                GammaDistributionImpl gamma = new GammaDistributionImpl(p, 1.0 / (0.5 * b));
                double result = Math.log(gamma.cumulativeProbability(0.0, 1.0 / c));
                if (Double.isInfinite(result) || Double.isNaN(result)) {
                    result = this.logGammaDensity(1.0 / c, p, 0.5 * b) + 1.0 / c;
                }
                if (Double.isNaN(weight += result)) {
                    System.exit(1);
                }
            }
            catch (Exception e) {
                e.printStackTrace();
                System.exit(1);
            }
        }
        weight += a * c / 2.0;
        BrownianModelCalculator bmc1 = (BrownianModelCalculator)node1.likelihoodModelCache;
        BrownianModelCalculator bmc2 = (BrownianModelCalculator)node2.likelihoodModelCache;
        if (Double.isNaN(weight += bmc1.lognorm + bmc2.lognorm - bmc1.logLikelihood() - bmc2.logLikelihood())) {
            System.exit(1);
        }
        return weight;
    }

    public double logGammaDensity(double x, double a, double b) {
        double y = (a - 1.0) * Math.log(x) - b * x + a * Math.log(b) - Gamma.logGamma((double)a);
        return y;
    }

    @Override
    public double peekCoalescedLogLikelihood(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double delta1, double delta2) {
        return (Double)this.calculate(node1, node2, delta1, delta2, true);
    }

    public static final double logNormalDensity(double x, double mean, double var) {
        return -0.5 * (x - mean) * (x - mean) / var - 0.5 * Math.log(Math.PI * 2 * var);
    }

    @Override
    public double extendLogLikelihood(double delta) {
        return this.loglikelihood;
    }

    @Override
    public double logLikelihood() {
        return this.loglikelihood;
    }

    private static double[] getNewMessage(int nsites) {
        return new double[nsites];
    }

    @Override
    public boolean isReversible() {
        return true;
    }
}

