/*
 * Decompiled with CFR 0.152.
 */
package ev.poi;

import ev.poi.MSAMarginalLikelihoodCalculator;
import ev.poi.PoissonParameters;
import goblin.CognateId;
import goblin.Taxon;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import ma.MSAPoset;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import pty.RootedTree;
import pty.UnrootedTree;

/**
 * Scores a collection of multiple sequence alignments using the likelihood computed by an
 * {@link MSAMarginalLikelihoodCalculator}, with exponential priors on the tree branch lengths and
 * on the insert/delete rates.
 *
 * <p>The model caches three terms and refreshes them whenever the tree or the rates change, or
 * (on request) when an individual alignment changes:
 * <ul>
 *   <li>the tree log prior, {@link #logPrior(RootedTree)} of the current tree;</li>
 *   <li>the rate log prior, {@link #paramLogPrior(double, double)} of the current rates;</li>
 *   <li>the alignment log likelihood. This is one cached term per alignment plus one term that
 *       depends on the total number of columns across all alignments.</li>
 * </ul>
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public final class PoissonModel {
    /**
     * Alignments being scored, keyed by cognate id. The map is stored by reference and changes to
     * it are not tracked. After adding or replacing an entry, call
     * {@link #recomputeAlignmentLogLikelihood(CognateId)} so the cached likelihood is refreshed.
     */
    public final Map<CognateId, MSAPoset> alignments;
    /** Current likelihood calculator; replaced with a copy when the tree or a rate changes. */
    private MSAMarginalLikelihoodCalculator calculator;
    /** Parameter passed to {@code Sampling.exponentialLogDensity} for each branch length. */
    public final double treeExpBranchHyperPrior;
    /** Parameter passed to {@code Sampling.exponentialLogDensity} for the insert rate. */
    public final double insertHyperPrior;
    /** Parameter passed to {@code Sampling.exponentialLogDensity} for the delete rate. */
    public final double deleteHyperPrior;
    /** Cached {@link #logPrior(RootedTree)} of the current tree. */
    private double _treeLogPrior = 0.0;
    /** Cached {@link #paramLogPrior(double, double)} of the current delete/insert rates. */
    private double _parametersLogPrior = 0.0;
    /**
     * Cached total alignment log likelihood. It equals the sum of the values in
     * {@code _msaLogLikelihoods} plus {@code _msaEmptyColumnLogLikelihood}.
     */
    private double _msaLogLikelihood = 0.0;
    /** Cached {@code _insert_marginalLogLikelihood} term evaluated on the total column count. */
    private double _msaEmptyColumnLogLikelihood = 0.0;
    /** Cached per-alignment {@code _subDel_marginalLogLikelihood} terms. */
    private final Map<CognateId, Double> _msaLogLikelihoods = CollUtils.map();

    /**
     * Returns the current likelihood calculator, which holds the current tree and rates.
     *
     * @return the calculator currently used to score the alignments
     */
    public MSAMarginalLikelihoodCalculator getCalculator() {
        return this.calculator;
    }

    /**
     * Creates the model and computes every cached prior and likelihood term.
     *
     * @param alignments alignments to score; stored by reference, not copied
     * @param calculator likelihood calculator holding the initial tree and rates
     * @param treeExpBranchHyperPrior exponential prior parameter for each branch length
     * @param insertHyperPrior exponential prior parameter for the insert rate
     * @param deleteHyperPrior exponential prior parameter for the delete rate
     */
    public PoissonModel(Map<CognateId, MSAPoset> alignments, MSAMarginalLikelihoodCalculator calculator, double treeExpBranchHyperPrior, double insertHyperPrior, double deleteHyperPrior) {
        this.insertHyperPrior = insertHyperPrior;
        this.deleteHyperPrior = deleteHyperPrior;
        this.treeExpBranchHyperPrior = treeExpBranchHyperPrior;
        this.calculator = calculator;
        this.alignments = alignments;
        this._treeLogPrior = this.logPrior(this.currentRooted());
        this._parametersLogPrior = this.paramLogPrior(calculator.params.deleteRate, calculator.params.insertRate);
        this.recomputeAllAlignmentsLogLikelihoods();
    }

    /**
     * Returns the unnormalized joint log probability of the current state.
     *
     * @return tree log prior + rate log prior + total alignment log likelihood
     */
    public double jointLogProbability() {
        return this._treeLogPrior + this._parametersLogPrior + this._msaLogLikelihood;
    }

    /**
     * Switches to {@code newTree}. Refreshes the tree prior and the likelihood of every alignment.
     *
     * @param newTree the new rooted tree
     */
    public void setPhylogeneticTree(RootedTree newTree) {
        this.calculator = MSAMarginalLikelihoodCalculator.copyWithNewTree(this.calculator, newTree);
        this._treeLogPrior = this.logPrior(this.currentRooted());
        this.recomputeAllAlignmentsLogLikelihoods();
    }

    /**
     * Switches to a new delete rate. Refreshes the rate prior and the likelihood of every alignment.
     *
     * @param newRate the new delete rate
     */
    public void setDeleteRate(double newRate) {
        this.calculator = MSAMarginalLikelihoodCalculator.copyWithNewParams(this.calculator, PoissonParameters.copyWithNewDeleteRate(this.calculator.params, newRate));
        this._parametersLogPrior = this.paramLogPrior(this.calculator.params.deleteRate, this.calculator.params.insertRate);
        this.recomputeAllAlignmentsLogLikelihoods();
    }

    /**
     * Switches to a new insert rate. Refreshes the rate prior and the likelihood of every alignment.
     *
     * @param newRate the new insert rate
     */
    public void setInsertRate(double newRate) {
        this.calculator = MSAMarginalLikelihoodCalculator.copyWithNewParams(this.calculator, PoissonParameters.copyWithNewInsertRate(this.calculator.params, newRate));
        this._parametersLogPrior = this.paramLogPrior(this.calculator.params.deleteRate, this.calculator.params.insertRate);
        this.recomputeAllAlignmentsLogLikelihoods();
    }

    /**
     * Computes the log prior of a (delete rate, insert rate) pair.
     *
     * @param del delete rate
     * @param ins insert rate
     * @return sum of the exponential log densities of {@code del} and {@code ins} under their
     *     respective hyper-priors
     */
    public double paramLogPrior(double del, double ins) {
        return this.deleteLogPrior(del) + this.insertLogPrior(ins);
    }

    /** Refreshes the cached likelihood of every alignment in {@link #alignments}. */
    private void recomputeAllAlignmentsLogLikelihoods() {
        this.recomputeAlignmentsLogLikelihoods(this.alignments.keySet());
    }

    /**
     * Refreshes the cached likelihood of one alignment, for example after its entry in
     * {@link #alignments} was added or replaced. The column-count-dependent term is refreshed too.
     *
     * @param key id of the alignment to rescore
     */
    public void recomputeAlignmentLogLikelihood(CognateId key) {
        this.recomputeAlignmentsLogLikelihoods(Collections.singleton(key));
    }

    /**
     * Rescores the alignments identified by {@code keys} and the column-count-dependent term, then
     * rebuilds the cached total alignment log likelihood.
     *
     * @param keys ids of the alignments to rescore
     */
    private void recomputeAlignmentsLogLikelihoods(Collection<CognateId> keys) {
        for (CognateId id : keys) {
            MSAPoset current = this.alignments.get(id);
            this._msaLogLikelihoods.put(id, this.calculator._subDel_marginalLogLikelihood(current));
        }
        this._msaEmptyColumnLogLikelihood = this.calculator._insert_marginalLogLikelihood(this.totalNColumns());
        // Rebuild the total from the cached terms instead of applying (new - old) deltas. Deltas
        // turn the total into a permanent NaN once any term is infinite (-Inf - -Inf, or
        // -Inf + +Inf), and they accumulate rounding drift. totalNColumns() already visits every
        // alignment, so this adds no asymptotic cost.
        double total = this._msaEmptyColumnLogLikelihood;
        for (double value : this._msaLogLikelihoods.values()) {
            total += value;
        }
        this._msaLogLikelihood = total;
    }

    /**
     * Counts the columns across all alignments.
     *
     * @return sum of {@code columns().size()} over every alignment in {@link #alignments}
     */
    private int totalNColumns() {
        int n = 0;
        for (MSAPoset msa : this.alignments.values()) {
            n += msa.columns().size();
        }
        return n;
    }

    /** Exponential log density of an insert rate under {@link #insertHyperPrior}. */
    private double insertLogPrior(double value) {
        return Sampling.exponentialLogDensity(this.insertHyperPrior, value);
    }

    /** Exponential log density of a delete rate under {@link #deleteHyperPrior}. */
    private double deleteLogPrior(double value) {
        return Sampling.exponentialLogDensity(this.deleteHyperPrior, value);
    }

    /**
     * Computes the log prior of a tree as a sum of independent per-branch terms.
     *
     * @param t the tree to score
     * @return sum over all branch lengths of their exponential log density under
     *     {@link #treeExpBranchHyperPrior}
     */
    public double logPrior(RootedTree t) {
        double sum = 0.0;
        // Only the lengths are needed, so iterate the values once rather than re-fetching
        // branchLengths() and looking up each key.
        for (double branchLength : t.branchLengths().values()) {
            sum += Sampling.exponentialLogDensity(this.treeExpBranchHyperPrior, branchLength);
        }
        return sum;
    }

    /**
     * Returns an unrooted view of the current tree.
     *
     * @return {@code UnrootedTree.fromRooted} applied to {@link #currentRooted()}
     */
    public UnrootedTree currentUnrooted() {
        return UnrootedTree.fromRooted(this.currentRooted());
    }

    /**
     * Returns the rooted tree held by the current calculator.
     *
     * @return the current rooted tree
     */
    public RootedTree currentRooted() {
        return this.calculator.rootedTree;
    }
}

