/*
 * Decompiled with CFR 0.152.
 */
package ev.ex;

import ev.ex.NJPState;
import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.basic.UnorderedPair;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.MathUtils;
import pty.smc.ParticleKernel;

/**
 * A {@link ParticleKernel} over {@link NJPState}s. Each step proposes merging one pair of the
 * current state's roots.
 *
 * <p>The pair is drawn from the keys of {@code current.pairwiseDistances}. They are weighted
 * 1/2, 1/4, 1/8, ... (then renormalized) according to the iteration order of a {@link Counter}
 * whose counts are the negated pairwise distances. NOTE(review): this presumably means the
 * Counter iterates by decreasing count, so closer pairs get more mass. Confirm against
 * {@code nuts.util.Counter}.
 *
 * <p>The height increment {@code delta} is drawn via {@code Sampling.sampleExponential} with
 * parameter {@code 1 / C(nRoots, 2)}. Whether that parameter is a rate or a mean is defined by
 * {@code Sampling}, not here.
 */
public class NJStateKernel
implements ParticleKernel<NJPState> {
    private final NJPState initState;

    /**
     * @param initState the state returned by {@link #getInitial()}
     */
    public NJStateKernel(NJPState initState) {
        this.initState = initState;
    }

    /** Returns the initial state supplied at construction time. */
    @Override
    public NJPState getInitial() {
        return this.initState;
    }

    /** Delegates to {@code partialState.pcs.nIterationsLeft()}. */
    @Override
    public int nIterationsLeft(NJPState partialState) {
        return partialState.pcs.nIterationsLeft();
    }

    /**
     * Proposes the next state by coalescing one pair of roots of {@code current}.
     *
     * @param rand source of randomness for both the pair choice and the height increment
     * @param current the state to extend
     * @return the proposed state paired with its incremental log weight,
     *     {@code gamma(next) - gamma(current) - log q(delta) - log q(pair)}
     * @throws IllegalStateException if {@code current} has no candidate pair, or if the computed
     *     log weight is NaN
     */
    @Override
    public Pair<NJPState, Double> next(Random rand, NJPState current) {
        List<UnorderedPair<Taxon, Taxon>> candidates = rankCandidates(current);
        if (candidates.isEmpty()) {
            // Sampling from an empty multinomial is meaningless; fail with a clear message.
            throw new IllegalStateException("NJStateKernel.next: no pair of roots left to coalesce");
        }
        double[] pairProbs = geometricProbabilities(candidates.size());
        int sampledIndex = SampleUtils.sampleMultinomial(rand, pairProbs);
        UnorderedPair<Taxon, Taxon> chosen = candidates.get(sampledIndex);

        double param = 1.0 / (double) MathUtils.nChoose2(current.pcs.nRoots());
        double delta = Sampling.sampleExponential(rand, param);
        double logPropDensity = Sampling.exponentialLogDensity(param, delta);

        NJPState next = current.coalesce(chosen.getFirst(), chosen.getSecond(), delta);
        // Evaluate gamma() once per state and validate the exact value that is returned.
        double logWeight = next.gamma() - current.gamma() - logPropDensity
                - Math.log(pairProbs[sampledIndex]);
        if (Double.isNaN(logWeight)) {
            throw new IllegalStateException("NJStateKernel.next: log weight is NaN (delta=" + delta
                    + ", param=" + param + ", pairProb=" + pairProbs[sampledIndex] + ")");
        }
        return Pair.makePair(next, logWeight);
    }

    /**
     * Lists the keys of {@code current.pairwiseDistances} in the iteration order of a Counter
     * holding their negated distances.
     */
    private static List<UnorderedPair<Taxon, Taxon>> rankCandidates(NJPState current) {
        Counter<UnorderedPair<Taxon, Taxon>> negated = new Counter<UnorderedPair<Taxon, Taxon>>();
        for (UnorderedPair<Taxon, Taxon> key : current.pairwiseDistances.keySet()) {
            negated.setCount(key, -1.0 * current.pairwiseDistances.getCount(key));
        }
        List<UnorderedPair<Taxon, Taxon>> ranked = new ArrayList<UnorderedPair<Taxon, Taxon>>();
        for (UnorderedPair<Taxon, Taxon> pair : negated) {
            ranked.add(pair);
        }
        return ranked;
    }

    /** Builds weights 1/2, 1/4, ... of length {@code n}, normalized in place by NumUtils. */
    private static double[] geometricProbabilities(int n) {
        double[] probs = new double[n];
        double mass = 0.5;
        for (int i = 0; i < n; i++) {
            probs[i] = mass;
            mass /= 2.0;
        }
        NumUtils.normalize(probs);
        return probs;
    }

    /**
     * Computes {@code (d - (topHeight - leftHeight - rightHeight)) / 2} for the selected pair,
     * where {@code d} is the pair's pairwise distance.
     *
     * <p>NOTE(review): currently unused within this class; kept for reference.
     */
    private double computeDeltaParam(NJPState current, UnorderedPair<Taxon, Taxon> selectedPair) {
        double currentHeight = current.pcs.topHeight();
        Map<Taxon, Integer> taxa = current.pcs.rootsTaxa();
        double leftH = current.pcs.getHeight(taxa.get(selectedPair.getFirst()));
        double rightH = current.pcs.getHeight(taxa.get(selectedPair.getSecond()));
        double min = currentHeight - leftH - rightH;
        double pairwiseD = current.pairwiseDistances.getCount(selectedPair);
        return (pairwiseD - min) / 2.0;
    }
}

