/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.util.HashSet;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import nuts.math.Sampling;
import nuts.maxent.SloppyMath;
import nuts.util.CoordinatesPacker;
import pty.eval.SymmetricDiff;
import pty.smc.MapLeaves;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleKernel;
import pty.smc.PriorPriorKernel;

/**
 * A {@link ParticleKernel} over {@link PartialCoalescentState} that performs one coalescence per
 * step, biasing which pair of roots is merged using a fixed set of conditioning clades.
 *
 * <p>Each call to {@link #next} samples a waiting time {@code delta} via
 * {@code Sampling.sampleExponential} with parameter {@code 1 / PriorPriorKernel.nChoose2(nRoots)}.
 * It then assigns a log-score to every pair {@code (i, j)} with {@code i < j} of current roots,
 * samples one pair from the normalized scores, and coalesces it.
 *
 * <p>A pair's log-score is {@code -agreementWeight * SymmetricDiff.deltaSymmetricDiff(...)},
 * computed against the (mapped) conditioning clades. When {@link #usesPriorPost} is set, the
 * log-score also includes the value of {@code state.peekLogLikelihoodRatio(i, j, delta, 0, 0)}.
 *
 * <p>The initial state must be a clock state; the constructor rejects anything else.
 */
public class ConditionalPriorPriorKernel
implements ParticleKernel<PartialCoalescentState> {
    /**
     * When true, each candidate pair's score also includes its peeked log likelihood ratio. The
     * returned weight is then the log-sum of those ratios, rather than the coalesced state's
     * {@code logLikelihoodRatio()}.
     */
    @Option
    public static boolean usesPriorPost = false;

    /** State returned by {@link #getInitial()}. */
    private final PartialCoalescentState initial;
    /** Conditioning clades, already passed through {@code MapLeaves.mapClades}. */
    private final Set<Set<Taxon>> conditionalClades;
    /**
     * Multiplier on the symmetric-difference term. For positive values, pairs with a larger
     * {@code deltaSymmetricDiff} are penalized more strongly as this grows.
     */
    private final double agreementWeight;
    /** Leaf mapping used to map, filter and restrict clades. */
    private final MapLeaves ml;

    /**
     * Creates the kernel.
     *
     * @param initial the state returned by {@link #getInitial()}; must report {@code isClock()}
     * @param conditionalClades clades to condition on; stored after {@code ml.mapClades(...)}
     * @param ml leaf mapping used for all clade mapping/filtering/restriction
     * @param agreementWeight multiplier applied to {@code deltaSymmetricDiff} when scoring pairs
     * @throws NullPointerException if {@code initial} or {@code ml} is null
     * @throws UnsupportedOperationException if {@code initial} is not a clock state
     */
    public ConditionalPriorPriorKernel(PartialCoalescentState initial, Set<Set<Taxon>> conditionalClades, MapLeaves ml, double agreementWeight) {
        Objects.requireNonNull(initial, "initial");
        Objects.requireNonNull(ml, "ml");
        // Validate the precondition before doing any clade-mapping work.
        if (!initial.isClock()) {
            throw new UnsupportedOperationException("Not yet supported");
        }
        this.initial = initial;
        this.conditionalClades = ml.mapClades(conditionalClades);
        this.ml = ml;
        this.agreementWeight = agreementWeight;
    }

    /**
     * Proposes the next state by coalescing one pair of roots of {@code state}.
     *
     * @param rand source of randomness for the waiting time and the pair choice
     * @param state current partial state; must have at least two roots
     * @return the coalesced state paired with its log weight
     * @throws IllegalArgumentException if {@code state} has fewer than two roots
     */
    @Override
    public Pair<PartialCoalescentState, Double> next(Random rand, PartialCoalescentState state) {
        final int nRoots = state.nRoots();
        // With fewer than two roots no pair (i < j) exists to merge, and nChoose2(nRoots) is
        // presumably 0, so the exponential parameter below would be infinite.
        if (nRoots < 2) {
            throw new IllegalArgumentException("Cannot coalesce a state with " + nRoots + " root(s)");
        }
        // Snapshot the mutable static flag so that scoring and weighting agree within one call.
        final boolean priorPost = usesPriorPost;

        final double delta = Sampling.sampleExponential(rand, 1.0 / PriorPriorKernel.nChoose2(nRoots));
        // Log-scores over the packed (i, j) grid; only entries with i < j are real candidates.
        final double[] prs = new double[nRoots * nRoots];
        // Log-sum of the peeked likelihood ratios; only accumulated and used when priorPost.
        double lognorm = priorPost ? Double.NEGATIVE_INFINITY : Double.NaN;
        final CoordinatesPacker cp = new CoordinatesPacker(nRoots);
        // Reuse the clade set cached on the state by the previous step, if any.
        final Set<Set<Taxon>> allCladesInState = state.mapped != null ? state.mapped : this.ml.filterClades(state.allClades());

        for (int i = 0; i < nRoots; ++i) {
            for (int j = 0; j < nRoots; ++j) {
                final int cell = cp.coord2int(i, j);
                if (i >= j) {
                    // Diagonal and lower triangle: never sampled.
                    prs[cell] = Double.NEGATIVE_INFINITY;
                    continue;
                }
                double priorPostTerm = 0.0;
                if (priorPost) {
                    priorPostTerm = state.peekLogLikelihoodRatio(i, j, delta, 0.0, 0.0);
                    lognorm = SloppyMath.logAdd(lognorm, priorPostTerm);
                }
                prs[cell] = priorPostTerm - this.agreementWeight * (double) SymmetricDiff.deltaSymmetricDiff(state, allCladesInState, i, j, this.conditionalClades, this.ml);
            }
        }

        NumUtils.expNormalize(prs);
        final int[] pair = cp.int2coord(SampleUtils.sampleMultinomial(rand, prs));
        final PartialCoalescentState result = state.coalesce(pair[0], pair[1], delta, 0.0, 0.0);
        // NOTE(review): lognorm sums only the peeked likelihood terms, not the agreement penalty
        // that also shaped prs, so the weight does not account for that penalty. Confirm this is
        // intended.
        final double w = priorPost ? lognorm : result.logLikelihoodRatio();

        // Carry the clade set forward (plus the newly merged clade) so the next step can reuse it.
        final HashSet<Set<Taxon>> newAllClades = new HashSet<>(allCladesInState);
        newAllClades.add(this.ml.restrict(state.mergedClade(pair[0], pair[1])));
        result.mapped = newAllClades;
        return Pair.makePair(result, w);
    }

    /** Returns the initial state supplied at construction. */
    @Override
    public PartialCoalescentState getInitial() {
        return this.initial;
    }

    /** Delegates to {@code partialState.nIterationsLeft()}. */
    @Override
    public int nIterationsLeft(PartialCoalescentState partialState) {
        return partialState.nIterationsLeft();
    }
}

