/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import ev.poi.processors.TreeDistancesProcessor;
import gep.util.OutputManager;
import goblin.Taxon;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.eval.SymmetricDiff;
import pty.io.TreeEvaluator;
import pty.smc.NCPriorPriorCondKernel;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;

/**
 * Particle MCMC sampler over rooted trees.
 *
 * <p>Each call to {@link #next(Random)} does the following:
 * <ol>
 *   <li>Runs the particle filter once with the kernel's current conditioning.</li>
 *   <li>Accepts or rejects the resulting proposal using the difference of the particle filter's
 *       log normalizer estimates.</li>
 *   <li>Reports per-iteration diagnostics to {@link #outMan}.</li>
 *   <li>Draws a new random subset of the current tree's non-trivial splits for the kernel to
 *       condition on in the next call.</li>
 * </ol>
 */
public class PMCMC {
    private final ParticleFilter<PartialCoalescentState> pf;
    private final NCPriorPriorCondKernel kernel;
    private final TreeDistancesProcessor tdp;
    /** Log normalizer estimate of the currently accepted sample; -inf until the first acceptance. */
    private double previousLogLLEstimate = Double.NEGATIVE_INFINITY;
    /** Currently accepted tree; {@code null} until the first proposal is accepted. */
    private RootedTree currentSample = null;
    /** Sink for the per-iteration diagnostics written by {@link #next(Random)}. */
    public static OutputManager outMan = new OutputManager();

    /**
     * @param pf particle filter used to generate proposals
     * @param kernel proposal kernel; its conditioning is replaced at the end of every
     *     {@link #next(Random)} call
     * @param tdp processor that receives the current sample after every iteration
     */
    public PMCMC(ParticleFilter<PartialCoalescentState> pf, NCPriorPriorCondKernel kernel, TreeDistancesProcessor tdp) {
        this.pf = pf;
        this.kernel = kernel;
        this.tdp = tdp;
    }

    /**
     * Performs one PMCMC iteration.
     *
     * <p>The iteration proposes and accepts/rejects a new tree, then passes the current tree to
     * the tree-distance processor. Next it writes a "PMCMC" record containing the tree size, the
     * acceptance probability, the mask sparsity used for this proposal, and the Robinson-Foulds
     * distance to the previous sample. Finally it installs a freshly sampled conditioning mask
     * on the kernel.
     *
     * @param rand source of randomness for the accept/reject step, the proposal draw and the mask
     * @throws IllegalStateException if no proposal has been accepted yet (including the case where
     *     the particle filter failed on the very first iteration)
     */
    public void next(Random rand) {
        RootedTree previousSample = this.currentSample;
        double currentSparsity = this.maskSparsity(previousSample);
        double acceptPr = this.metropolisHastingsStep(rand);
        if (this.currentSample == null) {
            // Without a sample there is nothing to report or to derive the next mask from.
            throw new IllegalStateException("PMCMC: no proposal has been accepted yet (acceptPr=" + acceptPr + ")");
        }
        this.tdp.process(this.currentSample);
        int tSize = this.currentSample.topology().nLeaves();
        outMan.write("PMCMC", "treeSize", tSize, "acceptPr", acceptPr, "maskSparsity", currentSparsity, "rfDist", previousSample == null ? 0.0 : new TreeEvaluator.RobinsonFouldsMetric().score(this.currentSample, previousSample));
        this.kernel.setConditioning(this.sampleConditioningMask(rand));
    }

    /**
     * Computes the fraction of the tree's non-trivial splits that the kernel currently masks.
     *
     * <p>The denominator is {@code nLeaves - 3}, presumably the number of non-trivial splits of a
     * binary unrooted tree.
     *
     * @param tree the tree the current mask was derived from, or {@code null} before the first sample
     * @return the sparsity, or 0.0 when there is no tree or the tree has no non-trivial splits
     */
    private double maskSparsity(RootedTree tree) {
        if (tree == null) {
            return 0.0;
        }
        double nNonTrivialSplits = (double) tree.topology().nLeaves() - 3.0;
        // With 3 or fewer leaves the original formula divided by zero (or a negative number).
        return nNonTrivialSplits > 0.0 ? (double) this.kernel.getNMasks() / nNonTrivialSplits : 0.0;
    }

    /**
     * Runs the particle filter once and performs the Metropolis-Hastings accept/reject step.
     *
     * <p>On acceptance, a particle is drawn from the stored particle population. Its full
     * coalescent state becomes {@link #currentSample}, and the normalizer estimate becomes
     * {@link #previousLogLLEstimate}.
     *
     * <p>If the particle filter (or the draw) throws and an earlier sample exists, the failure is
     * treated as a rejection: the previous sample is kept.
     *
     * @param rand source of randomness for the accept/reject decision and the particle draw
     * @return the acceptance probability computed for this proposal; 0.0 if the failure happened
     *     before it was computed
     * @throws IllegalStateException if the proposal failed and there is no earlier sample to keep
     */
    private double metropolisHastingsStep(Random rand) {
        ParticleFilter.StoreProcessor pro = new ParticleFilter.StoreProcessor();
        double acceptPr = 0.0;
        try {
            this.pf.sample(this.kernel, pro);
            double logLLEstimate = this.pf.estimateNormalizer();
            double logRatio = logLLEstimate - this.previousLogLLEstimate;
            acceptPr = Math.min(1.0, Math.exp(logRatio));
            if (Double.isNaN(acceptPr)) {
                // NaN arises e.g. when both estimates are -inf (-inf - -inf); treat it as a rejection
                // instead of handing an undefined probability to the Bernoulli sampler.
                acceptPr = 0.0;
            }
            if (Sampling.sampleBern(acceptPr, rand)) {
                PartialCoalescentState sampled = (PartialCoalescentState) pro.sample(rand);
                this.currentSample = sampled.getFullCoalescentState();
                this.previousLogLLEstimate = logLLEstimate;
            }
        } catch (Exception e) {
            if (this.currentSample == null) {
                throw new IllegalStateException("PMCMC: particle filter failed before any proposal was accepted", e);
            }
            // Deliberate best effort: keep the previous sample, i.e. treat the failed proposal as rejected.
        }
        return acceptPr;
    }

    /**
     * Draws a new conditioning mask from the splits of {@link #currentSample}.
     *
     * <p>Each split of the unrooted tree appears in {@code unRootedClades()} together with its
     * complement. The method first keeps only one clade per complementary pair. It then draws a
     * single retention probability from {@code rand.nextDouble()}. Finally it keeps each
     * non-trivial clade when {@link Sampling#sampleBern} succeeds with that probability.
     *
     * @param rand source of randomness for the retention probability and per-clade draws
     * @return the sampled set of clades
     * @throws IllegalStateException if the clades do not pair up exactly with their complements
     */
    private HashSet<Set<Taxon>> sampleConditioningMask(Random rand) {
        HashSet<Set<Taxon>> newMask = CollUtils.set();
        double currentRetProb = rand.nextDouble();
        // NOTE(review): assumes unRootedClades() returns Set<Set<Taxon>> (the decompiler erased the type).
        Set<Set<Taxon>> clades = UnrootedTree.fromRooted(this.currentSample).unRootedClades();
        Set<Taxon> allTaxa = SymmetricDiff.allLeaves(clades);
        HashSet<Set<Taxon>> simplified = CollUtils.set(clades);
        for (Set<Taxon> clade : clades) {
            HashSet<Taxon> complement = CollUtils.set(allTaxa);
            complement.removeAll(clade);
            // Drop the complement only while this clade is still present, so exactly one of each pair survives.
            if (simplified.contains(complement) && simplified.contains(clade)) {
                simplified.remove(complement);
            }
        }
        if (simplified.size() != clades.size() / 2) {
            throw new IllegalStateException("PMCMC: expected every clade to be paired with its complement, but "
                    + clades.size() + " clades reduced to " + simplified.size());
        }
        for (Set<Taxon> clade : simplified) {
            // sampleBern is only drawn for non-trivial clades, matching the short-circuit order of the draws.
            if (this.nonTrivial(clade, allTaxa.size()) && Sampling.sampleBern(currentRetProb, rand)) {
                newMask.add(clade);
            }
        }
        return newMask;
    }

    /**
     * Returns whether {@code clade} is a non-trivial split: it contains more than one taxon and
     * fewer than {@code nLeaves - 1} taxa.
     */
    private boolean nonTrivial(Set<Taxon> clade, int nLeaves) {
        return clade.size() > 1 && clade.size() < nLeaves - 1;
    }
}

