/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.models;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import nuts.maxent.SloppyMath;
import nuts.util.MathUtils;
import pty.smc.models.CTMC;
import pty.smc.models.DiscreteModelCalculator;
import pty.smc.models.LikelihoodModelCalculator;

/**
 * Likelihood calculator for a forest model in which a "language" may be absent on some lineages.
 *
 * <p>Each node carries two conditional log-likelihoods of the data below it:
 * <ul>
 *   <li>{@code withLangLogLikelihood}: the language is present at the node (delegated to the
 *       wrapped {@link DiscreteModelCalculator});</li>
 *   <li>{@code noLangLogLikelihood}: the language is absent at the node.</li>
 * </ul>
 *
 * <p>The probability of at least one invention event over a time span {@code t} is
 * {@code 1 - exp(-languageInventionRate * t)}. The prior probability that the language is
 * present at height {@code h} is that quantity with {@code t = rootHeight - h}.
 *
 * <p>A language invented somewhere on a branch is scored with the child's with-language
 * likelihood, regardless of where on the branch the invention happened. Per the log message in
 * {@link #observation}, this presumably relies on the underlying CTMC being reversible.
 */
public class ForestModelCalculator
implements LikelihoodModelCalculator {
    private final DiscreteModelCalculator dmc;
    private final double rootHeight;
    private final double currentHeight;
    private final double languageInventionRate;
    private final double noLangLogLikelihood;
    private final double withLangLogLikelihood;
    /** Cached result of {@link #logLikelihood()}; only meaningful when {@link #llCached} is true. */
    private transient double ll;
    /**
     * Whether {@link #ll} holds a computed value. A boolean flag is used instead of a NaN sentinel
     * because transient fields are reset to their default value (0.0 for a double) if an instance
     * is ever deserialized, which a sentinel check would mistake for a cached result.
     */
    private transient boolean llCached;

    /**
     * Returns the posterior probability that the language is absent at this node's height.
     *
     * <p>The prior at {@code currentHeight} is combined with the two conditional likelihoods, and
     * the result is normalized with {@link NumUtils#expNormalize} (assumed to exponentiate and
     * normalize the array in place).
     *
     * @return the normalized weight of the "no language" hypothesis
     * @throws IllegalArgumentException if {@code currentHeight >= rootHeight}
     */
    public double posteriorNoLanguagePr() {
        double timeBelowRoot = this.timeBelowRoot(this.currentHeight);
        double[] prs = new double[]{
            this.noLangLogLikelihood + this.logPrNoInvention(timeBelowRoot),
            this.withLangLogLikelihood + this.logPrInvention(timeBelowRoot)
        };
        NumUtils.expNormalize(prs);
        return prs[0];
    }

    /**
     * Creates a leaf calculator for an observed taxon at height 0.
     *
     * <p>Observed leaves must carry the language, so the no-language log-likelihood is
     * {@code -Infinity} and the with-language log-likelihood is {@code 0.0}.
     *
     * <p>NOTE(review): the leaf's with-language log-likelihood is fixed at 0.0 rather than taken
     * from {@code dmc.logLikelihood()}; confirm this is intended.
     *
     * @param ctmc the substitution model passed to {@link DiscreteModelCalculator#observation}
     * @param initCache the observation cache passed to {@link DiscreteModelCalculator#observation}
     * @param rootHeight height of the root; all node heights must stay strictly below it
     * @param languageInventionRate rate of language-invention events per unit height
     * @return a new leaf calculator
     */
    public static ForestModelCalculator observation(CTMC ctmc, double[][] initCache, double rootHeight, double languageInventionRate) {
        LogInfo.logs("Currently assumes reversibility!");
        DiscreteModelCalculator dmc = DiscreteModelCalculator.observation(ctmc, initCache);
        return new ForestModelCalculator(dmc, rootHeight, 0.0, languageInventionRate, Double.NEGATIVE_INFINITY, 0.0);
    }

    /**
     * Creates a calculator for a node.
     *
     * @param dmc calculator for the with-language part of the likelihood
     * @param rootHeight height of the root
     * @param currentHeight height of this node
     * @param languageInventionRate rate of language-invention events per unit height
     * @param noLangLogLikelihood log-likelihood of the data below, given no language at this node
     * @param withLangLogLikelihood log-likelihood of the data below, given language at this node
     */
    public ForestModelCalculator(DiscreteModelCalculator dmc, double rootHeight, double currentHeight, double languageInventionRate, double noLangLogLikelihood, double withLangLogLikelihood) {
        this.dmc = dmc;
        this.rootHeight = rootHeight;
        this.currentHeight = currentHeight;
        this.languageInventionRate = languageInventionRate;
        this.noLangLogLikelihood = noLangLogLikelihood;
        this.withLangLogLikelihood = withLangLogLikelihood;
    }

    /**
     * Returns {@code rootHeight - height}, the time span from the root down to {@code height}.
     *
     * @throws IllegalArgumentException if {@code height >= rootHeight}
     */
    private double timeBelowRoot(double height) {
        if (height >= this.rootHeight) {
            throw new IllegalArgumentException("Reached height of " + height + " but max is " + this.rootHeight);
        }
        return this.rootHeight - height;
    }

    /**
     * Returns the log-probability of at least one invention event over {@code duration}, that is
     * {@code log(1 - exp(-rate * duration))}.
     *
     * <p>{@code expm1} avoids the cancellation in {@code 1 - exp(-x)}. Otherwise, for tiny
     * {@code x} the result would collapse to {@code log(0) = -Infinity}.
     */
    private double logPrInvention(double duration) {
        return Math.log(-Math.expm1(-this.languageInventionRate * duration));
    }

    /**
     * Returns the log-probability of no invention event over {@code duration}, that is
     * {@code -rate * duration}, computed exactly without going through {@code exp}.
     */
    private double logPrNoInvention(double duration) {
        return -this.languageInventionRate * duration;
    }

    /**
     * Returns the log-likelihood seen from the top of a span of length {@code duration} that
     * starts without the language.
     *
     * <p>Either an invention occurs along the span (scored with {@code withLang}) or none does
     * (scored with {@code noLang}).
     */
    private double logBranchLikelihood(double duration, double withLang, double noLang) {
        return SloppyMath.logAdd(this.logPrInvention(duration) + withLang, this.logPrNoInvention(duration) + noLang);
    }

    /**
     * Returns the log-likelihood of this node after extending its parent branch by {@code delta}.
     *
     * <p>At the new height the language is present with the prior probability at that height. If
     * it is absent there, it may still be invented along the extra {@code delta}. In exact
     * arithmetic this equals the prior mixture at {@code currentHeight}, because
     * {@code prior(h+delta) + (1 - prior(h+delta)) * (1 - exp(-rate*delta)) = prior(h)}.
     *
     * @throws IllegalArgumentException if {@code currentHeight + delta >= rootHeight}
     */
    @Override
    public double extendLogLikelihood(double delta) {
        double timeBelowRoot = this.timeBelowRoot(this.currentHeight + delta);
        double noLangAbove = this.logBranchLikelihood(delta, this.withLangLogLikelihood, this.noLangLogLikelihood);
        return SloppyMath.logAdd(this.logPrInvention(timeBelowRoot) + this.withLangLogLikelihood, this.logPrNoInvention(timeBelowRoot) + noLangAbove);
    }

    /**
     * Coalesces two child calculators into a parent.
     *
     * <p>Both children must reach the same parent height, which is checked via
     * {@link MathUtils#checkClose}.
     *
     * <p>Given no language at the parent, each child branch independently either sees an invention
     * (child scored with language) or not (child scored without language). Given language at the
     * parent, the likelihood comes from the wrapped {@link DiscreteModelCalculator}.
     *
     * @param node1 first child; must be a {@code ForestModelCalculator}
     * @param node2 second child; must be a {@code ForestModelCalculator}
     * @param v1 branch length from {@code node1} to the parent
     * @param v2 branch length from {@code node2} to the parent
     * @param isPeek if true, return only the parent's log-likelihood (a {@link Double}) without
     *     building a new calculator
     * @param doNotBuildCache forwarded to {@link DiscreteModelCalculator#combine}
     * @return a {@link Double} log-likelihood when {@code isPeek}, otherwise the new parent
     *     {@code ForestModelCalculator}
     */
    public Object calculate(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2, boolean isPeek, boolean doNotBuildCache) {
        ForestModelCalculator n1 = (ForestModelCalculator)node1;
        ForestModelCalculator n2 = (ForestModelCalculator)node2;
        MathUtils.checkClose(v1 + n1.currentHeight, v2 + n2.currentHeight);
        double newHeight = v1 + n1.currentHeight;
        // The four (invention on branch 1?) x (invention on branch 2?) cases factorize into a
        // product of two independent per-branch mixtures, i.e. a sum in log space.
        double resultNoLang = this.logBranchLikelihood(v1, n1.withLangLogLikelihood, n1.noLangLogLikelihood)
            + this.logBranchLikelihood(v2, n2.withLangLogLikelihood, n2.noLangLogLikelihood);
        if (isPeek) {
            double resultWithLang = this.dmc.peekCoalescedLogLikelihood(n1.dmc, n2.dmc, v1, v2);
            return this.logLikelihood(newHeight, resultWithLang, resultNoLang);
        }
        DiscreteModelCalculator resultDMC = (DiscreteModelCalculator)this.dmc.combine(n1.dmc, n2.dmc, v1, v2, doNotBuildCache);
        double resultWithLang = resultDMC.logLikelihood();
        return new ForestModelCalculator(resultDMC, this.rootHeight, newHeight, this.languageInventionRate, resultNoLang, resultWithLang);
    }

    /**
     * Returns this node's log-likelihood, mixing the two conditional likelihoods with the prior at
     * {@code currentHeight}. The result is cached after the first call.
     *
     * @throws IllegalArgumentException if {@code currentHeight >= rootHeight}
     */
    @Override
    public double logLikelihood() {
        if (!this.llCached) {
            this.ll = this.logLikelihood(this.currentHeight, this.withLangLogLikelihood, this.noLangLogLikelihood);
            this.llCached = true;
        }
        return this.ll;
    }

    /**
     * Returns the prior mixture of the two conditional likelihoods at {@code currentHeight}.
     *
     * <p>The lineage is language-free at the root and may gain the language over the span down to
     * this height, so this is the branch mixture over {@code rootHeight - currentHeight}.
     *
     * @throws IllegalArgumentException if {@code currentHeight >= rootHeight}
     */
    private double logLikelihood(double currentHeight, double withLangLogLikelihood, double noLangLogLikelihood) {
        return this.logBranchLikelihood(this.timeBelowRoot(currentHeight), withLangLogLikelihood, noLangLogLikelihood);
    }

    /** Always returns true; the model is treated as reversible (see {@link #observation}). */
    @Override
    public boolean isReversible() {
        return true;
    }

    /**
     * Returns the log-likelihood the parent of {@code node1} and {@code node2} would have, without
     * building a new calculator.
     */
    @Override
    public double peekCoalescedLogLikelihood(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double delta1, double delta2) {
        return (Double)this.calculate(node1, node2, delta1, delta2, true, true);
    }

    /** Builds the parent calculator of {@code node1} and {@code node2}. */
    @Override
    public LikelihoodModelCalculator combine(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2, boolean doNotBuildCache) {
        return (LikelihoodModelCalculator)this.calculate(node1, node2, v1, v2, false, doNotBuildCache);
    }
}

