/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.util.Random;
import java.util.Set;
import nuts.math.Sampling;
import nuts.maxent.SloppyMath;
import nuts.util.CollUtils;
import nuts.util.CoordinatesPacker;
import pty.eval.SymmetricDiff;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleKernel;
import pty.smc.PriorPriorKernel;

/**
 * Particle kernel that proposes coalescence events (merging two current roots) while avoiding
 * merges whose resulting clade violates one of a set of clade constraints.
 *
 * <p>Each call to {@link #next} draws one waiting time {@code delta} and then picks a pair of roots
 * to merge. Merges flagged by {@code SymmetricDiff.violatesOne} are excluded. Only when that
 * leaves no admissible merge are violating merges allowed, down-weighted by
 * {@link #constraintViolationLogPenalty}.
 *
 * <p>Only states for which {@code isClock()} holds are supported (see the constructor).
 */
public class ConstrainedKernel
implements ParticleKernel<PartialCoalescentState> {
    /**
     * When true, candidate merges are weighted by {@code peekLogLikelihoodRatio}, and the
     * particle weight is the log-sum of those ratios. When false, admissible merges are
     * uniform, and the weight is the coalesced state's {@code logLikelihoodRatio()}.
     */
    @Option
    public static boolean usesPriorPost = false;
    /** Log-weight added to constraint-violating merges in the fallback pass. */
    @Option
    public static double constraintViolationLogPenalty = -100.0;
    private final PartialCoalescentState initial;
    private final Set<Set<Taxon>> constraints;

    /**
     * @param initial state returned by {@link #getInitial()}; must satisfy {@code isClock()}
     * @param constraints clades that proposed merges should not violate
     * @throws RuntimeException if {@code initial} is not a clock state
     */
    public ConstrainedKernel(PartialCoalescentState initial, Set<Set<Taxon>> constraints) {
        this.constraints = constraints;
        this.initial = initial;
        if (!initial.isClock()) {
            throw new RuntimeException("Not yet supported");
        }
    }

    /**
     * Logs a warning (does not throw) when the leaves observed in {@code initial} differ from
     * the leaves mentioned by {@code constraints}.
     */
    public static void checkCompatible(PartialCoalescentState initial, Set<Set<Taxon>> constraints) {
        Set<Taxon> pcsLeaves = initial.getObservations().observations().keySet();
        Set<Taxon> constraintLeaves = SymmetricDiff.allLeaves(constraints);
        if (!pcsLeaves.equals(constraintLeaves)) {
            LogInfo.warning("Leaves of constraints:" + constraintLeaves + "\nLeaves of PCS:" + pcsLeaves + "\nSymmetric diff:" + CollUtils.symmetricDifference(pcsLeaves, constraintLeaves));
        }
    }

    /**
     * Proposes one coalescence from {@code state}.
     *
     * @return the coalesced state paired with its log-weight, or {@code null} if sampling the
     *     pair to merge fails
     */
    @Override
    public Pair<PartialCoalescentState, Double> next(Random rand, PartialCoalescentState state) {
        final int nRoots = state.nRoots();
        // One waiting time shared by every candidate pair.
        final double delta = Sampling.sampleExponential(rand, 1.0 / PriorPriorKernel.nChoose2(nRoots));
        final CoordinatesPacker cp = new CoordinatesPacker(nRoots);
        final double[] prs = new double[nRoots * nRoots];

        // Pass 1: constraint-violating merges are forbidden outright.
        double lognorm = fillLogWeights(state, nRoots, delta, cp, prs, true);
        // Pass 2 (fallback): only when pass 1 left no admissible merge are violating merges
        // allowed, with a log-penalty. The decision must be based on the weights themselves:
        // lognorm is only accumulated when usesPriorPost is set.
        if (!hasAdmissibleEntry(prs)) {
            lognorm = fillLogWeights(state, nRoots, delta, cp, prs, false);
        }

        NumUtils.expNormalize(prs);
        final int idx;
        try {
            idx = SampleUtils.sampleMultinomial(rand, prs);
        } catch (RuntimeException e) {
            // NOTE(review): a sampling failure yields a null proposal; presumably the SMC
            // driver treats null as a dead particle — confirm against callers.
            return null;
        }
        final int[] pair = cp.int2coord(idx);
        final PartialCoalescentState result = state.coalesce(pair[0], pair[1], delta, 0.0, 0.0);
        final double w = usesPriorPost ? lognorm : result.logLikelihoodRatio();
        return Pair.makePair(result, w);
    }

    /**
     * Fills {@code prs} with the unnormalized log-weight of merging each pair (i, j), i &lt; j.
     * Entries with i &gt;= j are set to negative infinity.
     *
     * @param hardConstraints if true, violating pairs get negative infinity; otherwise they get
     *     {@link #constraintViolationLogPenalty}
     * @return log-sum of the peeked likelihood ratios over the pairs that were scored, or
     *     negative infinity when {@link #usesPriorPost} is false
     */
    private double fillLogWeights(PartialCoalescentState state, int nRoots, double delta,
                                  CoordinatesPacker cp, double[] prs, boolean hardConstraints) {
        double lognorm = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < nRoots; ++i) {
            for (int j = 0; j < nRoots; ++j) {
                final int idx = cp.coord2int(i, j);
                if (i >= j) {
                    // Each unordered pair is represented once, by its i < j entry.
                    prs[idx] = Double.NEGATIVE_INFINITY;
                    continue;
                }
                final boolean violation = SymmetricDiff.violatesOne(state.mergedClade(i, j), this.constraints);
                if (violation && hardConstraints) {
                    prs[idx] = Double.NEGATIVE_INFINITY;
                    continue;
                }
                double logWeight = violation ? constraintViolationLogPenalty : 0.0;
                if (usesPriorPost) {
                    final double llr = state.peekLogLikelihoodRatio(i, j, delta, 0.0, 0.0);
                    // The penalty shapes the proposal only; it is not part of the normalizer.
                    lognorm = SloppyMath.logAdd(lognorm, llr);
                    logWeight += llr;
                }
                prs[idx] = logWeight;
            }
        }
        return lognorm;
    }

    /** Returns true if any entry of {@code logWeights} is greater than negative infinity. */
    private static boolean hasAdmissibleEntry(double[] logWeights) {
        for (double lw : logWeights) {
            if (lw > Double.NEGATIVE_INFINITY) {
                return true;
            }
        }
        return false;
    }

    @Override
    public PartialCoalescentState getInitial() {
        return this.initial;
    }

    @Override
    public int nIterationsLeft(PartialCoalescentState partialState) {
        return partialState.nIterationsLeft();
    }
}

