/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.test;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.exec.Execution;
import goblin.Taxon;
import java.util.HashSet;
import java.util.Set;
import nuts.util.CollUtils;
import pty.Train;
import pty.eval.SymmetricDiff;
import pty.io.Dataset;
import pty.io.WalsDataset;
import pty.learn.CTMCLoader;
import pty.smc.ConstrainedKernel;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.PriorPriorKernel;
import pty.smc.models.CTMC;
import pty.smc.models.ForestModelCalculator;

/**
 * Command-line experiment that runs a particle filter over a forest state
 * built from the preprocessed WALS corpus.
 *
 * <p>The sampling kernel is a {@link ConstrainedKernel} given constraint sets
 * derived from the genus and family maps of {@code WalsDataset.langDB}. After
 * sampling:
 * <ul>
 *   <li>the centroid computed by the clade processor is written out as {@code "consensusTree"};</li>
 *   <li>the accumulated single-tree posterior is logged.</li>
 * </ul>
 */
public class TestForest
implements Runnable {
    /** Passed as the last argument of {@code PartialCoalescentState.initForestState}. */
    @Option
    public double langInvRate = 10.0;
    /** Passed as the root-height argument of {@code PartialCoalescentState.initForestState}. */
    @Option
    public double rootHeight = 10.0;
    /**
     * Shared particle filter. It is also handed to {@code Execution.run} under the name
     * {@code "filter"}, presumably so its options can be set from the command line (confirm
     * against fig's Execution).
     */
    private static final ParticleFilter<PartialCoalescentState> pf = new ParticleFilter<PartialCoalescentState>();
    /** Loads the CTMC parameters for the language data. */
    private static final CTMCLoader langParamLoader = new CTMCLoader();
    /** Corpus loaded by {@link #initLanguageState()}. */
    private Dataset langData;

    /**
     * Configures fig's {@code Execution} flags and runs this experiment.
     *
     * <p>The name/object pairs after the runnable are registered with
     * {@code Execution.run}. They are presumably option groups; verify against fig.
     *
     * @param args command-line arguments forwarded to {@code Execution.run}
     */
    public static void main(String[] args) {
        Execution.monitor = true;
        Execution.makeThunk = false;
        Execution.create = true;
        Execution.useStandardExecPoolDirStrategy = true;
        Execution.run(args, new TestForest(), "wals", WalsDataset.class, "kernel", ConstrainedKernel.class, "filter", pf, "pcs", PartialCoalescentState.class, "ppk", PriorPriorKernel.class);
    }

    /**
     * Loads the preprocessed WALS corpus and its CTMC parameters, then builds
     * the initial forest state.
     *
     * <p>Side effects: stores the corpus in {@link #langData} and hands it to
     * the shared {@link #langParamLoader}.
     *
     * @return the initial forest state for the sampler
     */
    private PartialCoalescentState initLanguageState() {
        this.langData = WalsDataset.getPreprocessedCorpus();
        langParamLoader.setData(this.langData);
        CTMC langParam = langParamLoader.load();
        return PartialCoalescentState.initForestState(this.langData, langParam, this.rootHeight, this.langInvRate);
    }

    /**
     * Samples with the constrained kernel, then outputs the consensus tree and
     * logs the single-tree posterior.
     */
    @Override
    public void run() {
        PartialCoalescentState initState = this.initLanguageState();
        ConstrainedKernel kernel = new ConstrainedKernel(initState, this.getConstraints());
        SingleTreePosteriorDecoder dec = new SingleTreePosteriorDecoder();
        ParticleFilter.ParticleMapperProcessor<PartialCoalescentState, Set<Set<Taxon>>> mbr = SymmetricDiff.createCladeProcessor();
        // Feeds each particle to both the clade processor and the posterior decoder.
        // NOTE(review): raw type kept as decompiled; ForkedProcessor's type parameters are
        // not visible from this file.
        ParticleFilter.ForkedProcessor processors = new ParticleFilter.ForkedProcessor(mbr, dec);
        pf.sample(kernel, processors);
        Train.outputTree(SymmetricDiff.clades2arbre(mbr.centroid(SymmetricDiff.CLADE_SYMMETRIC_DIFFERENCE)), "consensusTree");
        LogInfo.logs("Single tree posterior:" + dec.getSingleTreePosterior());
    }

    /**
     * Builds the constraint sets for the kernel.
     *
     * <p>The result is the union of the partitions induced by the genus map and
     * by the family map of {@code WalsDataset.langDB}. The constraints are
     * logged before returning.
     *
     * @return a mutable set of taxon groups
     */
    private Set<Set<Taxon>> getConstraints() {
        Set<Set<Taxon>> constraints = new HashSet<Set<Taxon>>();
        constraints.addAll(CollUtils.inducedPartition(WalsDataset.langDB.genusMap()));
        constraints.addAll(CollUtils.inducedPartition(WalsDataset.langDB.familyMap()));
        LogInfo.logs("Constraints:" + constraints);
        return constraints;
    }

    /**
     * Particle processor that accumulates, over all processed particles,
     * {@code weight * (1 - posteriorNoLanguagePr())}. The probability comes
     * from the state's first likelihood model calculator (index 0), cast to
     * {@link ForestModelCalculator}.
     *
     * <p>NOTE(review): the sum is an expectation only if the filter passes
     * normalized weights. Confirm against {@code ParticleFilter}.
     */
    public static class SingleTreePosteriorDecoder
    implements ParticleFilter.ParticleProcessor<PartialCoalescentState> {
        /** Running weighted sum; starts at 0. */
        private double singleTreePosterior = 0.0;

        /**
         * Adds this particle's weighted contribution to the running sum.
         *
         * @param state  particle state; its calculator 0 must be a {@link ForestModelCalculator}
         * @param weight the particle's weight as supplied by the filter
         */
        @Override
        public void process(PartialCoalescentState state, double weight) {
            ForestModelCalculator node = (ForestModelCalculator)state.getLikelihoodModelCalculator(0);
            this.singleTreePosterior += weight * (1.0 - node.posteriorNoLanguagePr());
        }

        /** @return the weighted sum accumulated so far */
        public double getSingleTreePosterior() {
            return this.singleTreePosterior;
        }
    }
}

