/*
 * Decompiled with CFR 0.152.
 */
package conifer.ssm;

import conifer.clock.ClockTree;
import conifer.clock.ClockTreeUtils;
import conifer.ssm.InformedProposal2;
import conifer.ssm.SSMModelCalculator;
import fig.basic.Pair;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.math.Sampling;
import nuts.util.Arbre;
import nuts.util.MathUtils;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleKernel;

/**
 * Particle kernel over {@link PartialCoalescentState}s.
 *
 * <p>Each call to {@link #next} merges two roots of the current forest and then fills in the
 * sequences of the newly created root. Those sequences are sampled position by position from an
 * {@link InformedProposal2}, conditioned on the two children's sequences and branch lengths.
 *
 * <p>When no guide tree is set, the merge itself is sampled. The waiting time is drawn with
 * {@code Sampling.sampleExponential} using the parameter {@code 1 / C(nRoots, 2)}, and two
 * distinct root indices are drawn with {@code Sampling.sampleWithoutReplacement}. When a guide
 * tree is set (see {@link #setGuideTree}), the merge is delegated to
 * {@code ClockTreeUtils.constrainedNext}.
 */
public class SSMKernel
implements ParticleKernel<PartialCoalescentState> {
    /** Proposal used to sample the sequences of each newly created root. */
    private final InformedProposal2 proposal;
    /** State returned by {@link #getInitial()}. */
    private final PartialCoalescentState initial;
    /** Optional guide tree; when non-null, merges are taken from it instead of being sampled. */
    private ClockTree guideTree = null;

    /**
     * Creates a kernel with an explicitly supplied initial state.
     *
     * @param proposal proposal used to sample the sequences of new roots
     * @param initial  state returned by {@link #getInitial()}
     */
    public SSMKernel(InformedProposal2 proposal, PartialCoalescentState initial) {
        this.proposal = proposal;
        this.initial = initial;
    }

    /**
     * Creates a kernel whose initial state is built from the given sequences via
     * {@code PartialCoalescentState.initSSMState}.
     *
     * @param proposal  proposal used to sample the sequences of new roots
     * @param sequences per-taxon sequences used to build the initial state
     */
    public SSMKernel(InformedProposal2 proposal, Map<Taxon, List<String>> sequences) {
        this.initial = PartialCoalescentState.initSSMState(sequences);
        this.proposal = proposal;
    }

    /**
     * Performs one coalescence step and proposes the sequences of the resulting root.
     *
     * <p>Side effect: the proposed sequences are stored on the new root's
     * {@code SSMModelCalculator} through {@code setTopSequences}.
     *
     * @param rand    source of randomness for the merge (unguided mode) and the sequence proposal
     * @param current state to extend; must have at least two roots
     * @return the successor state paired with an accumulated log value. The log value is the sum
     *         of the per-position log probabilities returned by
     *         {@code InformedProposal2.proposeAndGetLogProb}. When a guide tree is set, it also
     *         includes the exponential log density of the constrained waiting time.
     * @throws IllegalStateException if {@code current} has fewer than two roots, if the new root
     *         already carries top sequences, if a child of the new root has no top sequences, or
     *         if the two children hold different numbers of sequences
     */
    @Override
    public Pair<PartialCoalescentState, Double> next(Random rand, PartialCoalescentState current) {
        final int nRoots = current.nRoots();
        if (nRoots < 2) {
            // A merge needs two roots; otherwise C(n, 2) is 0 and no pair can be drawn.
            throw new IllegalStateException(
                "Cannot coalesce: state has " + nRoots + " root(s), at least 2 are required");
        }
        // Parameter handed to Sampling for the waiting time. It is presumably the mean
        // 1 / C(n, 2) of a Kingman coalescent waiting time -- TODO confirm against Sampling.
        final double expParam = 1.0 / (double) MathUtils.nChoose2(nRoots);

        final PartialCoalescentState next;
        double logProp = 0.0;
        if (this.guideTree == null) {
            double delta = Sampling.sampleExponential(rand, expParam);
            List<Integer> sampledIndices = Sampling.sampleWithoutReplacement(rand, nRoots, 2);
            next = current.coalesce(sampledIndices.get(0), sampledIndices.get(1), delta, 0.0, 0.0);
        } else {
            next = ClockTreeUtils.constrainedNext(current, this.guideTree);
            double delta = next.topHeight() - current.topHeight();
            // NOTE(review): only this guided branch adds a term for the merge to logProp; the
            // unguided branch adds nothing. Confirm this matches how callers interpret the
            // returned log value.
            logProp += Sampling.exponentialLogDensity(expParam, delta);
        }

        // The root created by this step is taken to be the last entry of getRoots().
        List<Arbre<PartialCoalescentState.CoalescentNode>> roots = next.getRoots();
        Arbre<PartialCoalescentState.CoalescentNode> latest = roots.get(roots.size() - 1);
        SSMModelCalculator topCalc = (SSMModelCalculator) latest.getContents().likelihoodModelCache;
        if (topCalc.getTopSequences() != null) {
            throw new IllegalStateException(
                "Top sequences of the newly coalesced root are already set");
        }

        Pair<Double, List<String>> c0 = this.childrenInfo(latest, 0);
        Pair<Double, List<String>> c1 = this.childrenInfo(latest, 1);
        List<String> seqs0 = c0.getSecond();
        List<String> seqs1 = c1.getSecond();
        if (seqs0 == null || seqs1 == null) {
            throw new IllegalStateException(
                "Both children of the new root must have top sequences set (left set: "
                    + (seqs0 != null) + ", right set: " + (seqs1 != null) + ")");
        }
        int nSeq = seqs0.size();
        if (seqs1.size() != nSeq) {
            throw new IllegalStateException(
                "Children of the new root hold different numbers of sequences: "
                    + nSeq + " vs " + seqs1.size());
        }

        // Branch lengths are the same for every position; read them once.
        double bl0 = c0.getFirst();
        double bl1 = c1.getFirst();
        ArrayList<String> newSeq = new ArrayList<String>(nSeq);
        for (int i = 0; i < nSeq; ++i) {
            Pair<String, Double> cur =
                this.proposal.proposeAndGetLogProb(rand, seqs0.get(i), bl0, seqs1.get(i), bl1);
            logProp += cur.getSecond().doubleValue();
            newSeq.add(cur.getFirst());
        }
        topCalc.setTopSequences(newSeq);
        return Pair.makePair(next, logProp);
    }

    /**
     * Collects the branch length to one child of {@code latest} and that child's top sequences.
     *
     * @param latest node whose child is inspected
     * @param i      0 selects the left branch length; any other value selects the right one.
     *               The same value is used as the child index.
     * @return the branch length paired with the child's top sequences, which may be null if the
     *         child's calculator has none set
     */
    private Pair<Double, List<String>> childrenInfo(Arbre<PartialCoalescentState.CoalescentNode> latest, int i) {
        PartialCoalescentState.CoalescentNode node = latest.getContents();
        double bl = i == 0 ? node.leftBranchLength : node.rightBranchLength;
        SSMModelCalculator childCalc =
            (SSMModelCalculator) latest.getChildren().get(i).getContents().likelihoodModelCache;
        List<String> strs = childCalc.getTopSequences();
        return Pair.makePair(bl, strs);
    }

    /** Delegates to {@link PartialCoalescentState#nIterationsLeft()}. */
    @Override
    public int nIterationsLeft(PartialCoalescentState partialState) {
        return partialState.nIterationsLeft();
    }

    /** Returns the initial state supplied to, or built by, the constructor. */
    @Override
    public PartialCoalescentState getInitial() {
        return this.initial;
    }

    /** Returns the current guide tree, or {@code null} if merges are sampled freely. */
    public ClockTree getGuideTree() {
        return this.guideTree;
    }

    /**
     * Sets the guide tree used to constrain merges in {@link #next}.
     *
     * @param guideTree guide tree, or {@code null} to sample merges freely
     */
    public void setGuideTree(ClockTree guideTree) {
        this.guideTree = guideTree;
    }
}

