/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.util.Random;
import java.util.Set;
import nuts.math.Sampling;
import nuts.maxent.SloppyMath;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.CoordinatesPacker;
import nuts.util.MathUtils;
import pty.eval.SymmetricDiff;
import pty.smc.NCPriorPriorKernel;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleKernel;

/**
 * Non-clock particle kernel that, at each step, merges a uniformly chosen pair of current roots,
 * restricted to pairs whose merged clade does not violate any of the conditioning clades set via
 * {@link #setConditioning(Set)} (as judged by {@link SymmetricDiff#violatesOne}).
 *
 * <p>The two new branch lengths are drawn independently with {@code Sampling.sampleExponential}
 * using parameter {@code 1.0 / NCPriorPriorKernel.deltaProposalRate}. Both draws are divided by
 * 2 when exactly two roots remain.
 */
public class NCPriorPriorCondKernel
implements ParticleKernel<PartialCoalescentState> {
    private final PartialCoalescentState initial;
    /** Clades that proposed merges must not violate; empty (no constraint) by default. */
    private Set<Set<Taxon>> conditioningClades = CollUtils.set();

    /**
     * Returns the number of conditioning clades currently in effect.
     *
     * @return size of the conditioning clade set
     */
    public int getNMasks() {
        return this.conditioningClades.size();
    }

    /**
     * Replaces the set of conditioning clades. The given set is stored by reference, not copied.
     *
     * @param conditioningClades clades that proposed merges must not violate
     */
    public void setConditioning(Set<Set<Taxon>> conditioningClades) {
        this.conditioningClades = conditioningClades;
    }

    /**
     * Creates a kernel starting from the given state.
     *
     * @param initial state returned by {@link #getInitial()}
     */
    public NCPriorPriorCondKernel(PartialCoalescentState initial) {
        this.initial = initial;
    }

    @Override
    public PartialCoalescentState getInitial() {
        return this.initial;
    }

    @Override
    public int nIterationsLeft(PartialCoalescentState partialState) {
        return partialState.nIterationsLeft();
    }

    /**
     * Proposes the next state by merging a uniformly chosen admissible pair of roots.
     *
     * @param rand source of randomness
     * @param current state to extend
     * @return the merged state paired with its log weight, or {@code (null, -Infinity)} when no
     *     admissible merge exists or multinomial sampling fails
     */
    @Override
    public Pair<PartialCoalescentState, Double> next(Random rand, PartialCoalescentState current) {
        final int nRoots = current.nRoots();
        final double branchScale = nRoots == 2 ? 2.0 : 1.0;
        // Draw both branch lengths first, in the same order as before, so the random stream is
        // consumed identically.
        final double delta0 = Sampling.sampleExponential(rand, 1.0 / NCPriorPriorKernel.deltaProposalRate) / branchScale;
        final double delta1 = Sampling.sampleExponential(rand, 1.0 / NCPriorPriorKernel.deltaProposalRate) / branchScale;

        // Unnormalized uniform weights over admissible pairs (i < j), indexed by packed coordinate.
        final double[] prs = new double[nRoots * nRoots];
        final CoordinatesPacker cp = new CoordinatesPacker(nRoots);
        int nPossibleMerges = 0;
        for (int i = 0; i < nRoots; ++i) {
            for (int j = i + 1; j < nRoots; ++j) {
                if (SymmetricDiff.violatesOne(CollUtils.union(current.rootedClade(i), current.rootedClade(j)), this.conditioningClades)) {
                    continue;
                }
                prs[cp.coord2int(new int[] {i, j})] = 1.0;
                ++nPossibleMerges;
            }
        }
        if (nPossibleMerges == 0) {
            // Every pair violates a constraint: report a dead particle instead of normalizing
            // an all-zero vector.
            return Pair.makePair(null, Double.NEGATIVE_INFINITY);
        }
        NumUtils.normalize(prs);

        int idx;
        try {
            idx = SampleUtils.sampleMultinomial(rand, prs);
        }
        catch (Exception e) {
            // Kept from the original: any sampling failure yields a dead particle.
            return Pair.makePair(null, Double.NEGATIVE_INFINITY);
        }
        final int[] pair = cp.int2coord(idx);
        final PartialCoalescentState next = current.coalesce(pair[0], pair[1], 0.0, delta0, delta1);
        return Pair.makePair(next, this.nonClockLogWeight(next));
    }

    /**
     * Computes the log weight of a non-clock state. For every non-leaf root it takes the sum of
     * the two children's log-likelihoods, minus the root's log-likelihood, minus
     * {@code topoNorm}. It log-sums these terms and returns the negation.
     *
     * @param next state to weight; must be non-clock and every non-leaf root must be binary
     * @return negated log-sum of the per-root terms
     * @throws RuntimeException if {@code next} is a clock state or a non-leaf root is not binary
     */
    public double nonClockLogWeight(PartialCoalescentState next) {
        if (next.isClock()) {
            throw new RuntimeException("nonClockLogWeight called on a clock state");
        }
        double logSum = Double.NEGATIVE_INFINITY;
        for (Arbre<PartialCoalescentState.CoalescentNode> root : next.roots) {
            if (root.isLeaf()) {
                continue;
            }
            if (root.getChildren().size() != 2) {
                throw new RuntimeException("Expected a binary root but found " + root.getChildren().size() + " children");
            }
            final double leftLogLik = root.getChildren().get(0).getContents().likelihoodModelCache.logLikelihood();
            // Fixed: the original read child 0 a second time here instead of child 1.
            final double rightLogLik = root.getChildren().get(1).getContents().likelihoodModelCache.logLikelihood();
            final double rootLogLik = root.getContents().likelihoodModelCache.logLikelihood();
            // NOTE(review): topoNorm returns an integer count but is subtracted directly from
            // log-likelihoods; a log-space weight would usually subtract Math.log(count). Kept
            // unchanged pending confirmation of the intended weight.
            final double logCur = leftLogLik + rightLogLik - rootLogLik - this.topoNorm(next, root);
            logSum = SloppyMath.logAdd(logSum, logCur);
        }
        return -logSum;
    }

    /**
     * Starts from nChoose2({@code next.nRoots() + 1}), presumably the number of pairs among the
     * roots before the last merge. For a non-leaf {@code root}, it then subtracts the violation
     * counts of each of its two children.
     *
     * @param next current state
     * @param root a root of {@code next}
     * @return the resulting count, as a double
     */
    private double topoNorm(PartialCoalescentState next, Arbre<PartialCoalescentState.CoalescentNode> root) {
        int result = MathUtils.safeIntValue(MathUtils.nChoose2(next.nRoots() + 1));
        if (!root.isLeaf()) {
            result -= this.nViolations(next, root.getChildren().get(0), root);
            result -= this.nViolations(next, root.getChildren().get(1), root);
        }
        return result;
    }

    /**
     * Counts the roots of {@code next}, other than {@code parent}, whose rooted clade
     * {@link SymmetricDiff#violates} the rooted clade of {@code child}.
     *
     * @param next current state
     * @param child node whose clade is tested
     * @param parent root excluded from the count
     * @return number of violating roots
     */
    private int nViolations(PartialCoalescentState next, Arbre<PartialCoalescentState.CoalescentNode> child, Arbre<PartialCoalescentState.CoalescentNode> parent) {
        int nViol = 0;
        for (Arbre<PartialCoalescentState.CoalescentNode> cRoot : next.roots) {
            if (cRoot != parent && SymmetricDiff.violates(child.getContents().rootedClade, cRoot.getContents().rootedClade)) {
                ++nViol;
            }
        }
        return nViol;
    }
}

