/*
 * Decompiled with CFR 0.152.
 */
package conifer.ssm;

import conifer.clock.ClockTree;
import conifer.clock.ClockTreeUtils;
import conifer.clock.proposals.ClockTreeProposal;
import conifer.clock.proposals.ClockTreeProposals;
import conifer.ssm.InformedProposal2;
import conifer.ssm.PMCMCSample;
import conifer.ssm.SSMDataGenerator;
import conifer.ssm.SSMKernel;
import conifer.ssm.StringMutationModel;
import ev.ex.TreeGenerators;
import ev.poi.processors.TreeDistancesProcessor;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.OptionSet;
import fig.basic.Pair;
import gep.util.OutputManager;
import goblin.Taxon;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import ma.MSAPoset;
import nuts.io.IO;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.io.TreeEvaluator;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;

/**
 * Simulation experiment for tree inference under the string mutation model.
 *
 * <p>{@link #run()} samples a reference ("true") tree, generates {@link #nBlocks}
 * independent data blocks along it, runs the particle filter once on all
 * blocks, and scores the consensus and mode trees against the reference tree
 * with every metric in {@code TreeEvaluator.coreTreeMetrics}.
 *
 * <p>The class also contains a particle-MCMC step ({@code next}) and a
 * helper that exercises it ({@code testTruth}); neither is invoked from
 * {@link #run()} in this file.
 */
public class TreeInferenceTest
implements Runnable {
    /** Number of taxa passed to the coalescent tree sampler. */
    @Option
    public int nTaxa = 10;
    /** Random source used when sampling the reference tree. */
    @Option
    public Random treeGenRand = new Random(1L);
    /** When strictly positive, the sampled tree is rescaled by {@code fixTreeHeight / height}. */
    @Option
    public double fixTreeHeight = -1.0;
    /** Number of data blocks generated along the reference tree. */
    @Option
    public int nBlocks = 10;
    /** Not read by any method in this class. */
    @Option
    public int nPmcmcIters = 10000;
    /** Random source used by {@code testTruth} for the PMCMC step. */
    @Option
    public Random pmcmcRand = new Random(1L);
    /** Not read by any method in this class. */
    @Option
    public int pmcmcEvalPeriod = 10;
    @OptionSet(name="prop")
    public InformedProposal2 proposal = new InformedProposal2();
    @OptionSet(name="model")
    public StringMutationModel model = new StringMutationModel();
    @OptionSet(name="datagen")
    public SSMDataGenerator generator = new SSMDataGenerator();
    // Parameterized construction: the original used the raw type here.
    @OptionSet(name="pf")
    public ParticleFilter<PartialCoalescentState> particleFilter = new ParticleFilter<PartialCoalescentState>();
    @OptionSet(name="treeprop")
    public ClockTreeProposals treeProposal = new ClockTreeProposals();
    /** Reference tree the data is generated from; set by {@link #run()}. */
    private RootedTree trueTree;
    /** For each leaf taxon, one generated sequence per block, in block order. */
    private Map<Taxon, List<String>> data;
    private OutputManager outputManager = new OutputManager();
    /** Wall-clock time (ms) at which inference started; basis for the reported "time(s)". */
    private long startTimeMilli;

    public static void main(String[] args) {
        IO.run(args, new TreeInferenceTest());
    }

    /**
     * Runs the full experiment: sample the reference tree, generate the data
     * blocks, then run SMC inference and evaluate the resulting trees.
     */
    @Override
    public void run() {
        this.sampleTrueTree();
        this.generateBlocks();
        this.runSmcInference();
    }

    /** Samples the reference tree, optionally rescales it, and logs its height. */
    private void sampleTrueTree() {
        this.trueTree = TreeGenerators.sampleCoalescent(this.treeGenRand, this.nTaxa, false);
        if (this.fixTreeHeight > 0.0) {
            double h = RootedTree.Util.height(this.trueTree);
            // Scale factor is requested height over current height.
            this.trueTree = TreeGenerators.scaleRootedTree(this.trueTree, this.fixTreeHeight / h);
        }
        LogInfo.logsForce("GeneratedTreeHeight=" + RootedTree.Util.height(this.trueTree));
    }

    /** Generates {@link #nBlocks} blocks and appends each leaf's sequence to {@link #data}. */
    private void generateBlocks() {
        LogInfo.track((Object)"Generating data", true);
        this.data = new HashMap<Taxon, List<String>>();
        for (int b = 0; b < this.nBlocks; ++b) {
            LogInfo.logs("Block " + b);
            this.generator.generateData(this.model, this.trueTree);
            MSAPoset msa = this.generator.getSpeciationPointsAlignment();
            LogInfo.logs(msa);
            for (Taxon t : this.trueTree.topology().leaveContents()) {
                CollUtils.getNoNullList(this.data, t).add(msa.sequences().get(t));
            }
        }
        LogInfo.end_track();
    }

    /** Runs the particle filter on all blocks and evaluates the consensus and mode trees. */
    private void runSmcInference() {
        // Timer starts after data generation so "time(s)" covers inference only.
        this.startTimeMilli = System.currentTimeMillis();
        LogInfo.track("SMC inference");
        this.proposal.setModel(this.model);
        SSMKernel kernel = new SSMKernel(this.proposal, this.data);
        TreeDistancesProcessor tdp = new TreeDistancesProcessor();
        this.particleFilter.sample(kernel, tdp);
        UnrootedTree consensus = tdp.getConsensus(true);
        this.eval("consensus", 0, consensus);
        this.eval("mode", 0, tdp.getMode().getUnrooted());
        LogInfo.end_track();
    }

    /**
     * Performs one PMCMC step starting from the reference tree with a
     * log-likelihood of negative infinity. The returned sample is discarded.
     */
    private void testTruth() {
        PMCMCSample truth = new PMCMCSample(ClockTreeUtils.fromRooted(this.trueTree), Double.NEGATIVE_INFINITY);
        LogInfo.track((Object)"Trying truth", true);
        this.next(this.pmcmcRand, truth);
        LogInfo.end_track();
    }

    /**
     * Performs one accept/reject step over clock trees.
     *
     * <p>A tree is proposed from {@code current.clockTree}; for each block the
     * particle filter is run with the proposed tree as guide tree and its
     * normalizer estimate is accumulated. The acceptance probability is
     * {@code min(1, exp(logNum - current.logL + proposedTree.getSecond()))}.
     * The decision is written to the "acceptancePrs" output.
     *
     * @param rand random source for the proposal and the accept/reject draw
     * @param current the current sample
     * @return a new sample holding the proposed tree if accepted, otherwise {@code current}
     */
    private PMCMCSample next(Random rand, PMCMCSample current) {
        ClockTreeProposal currentTreeProposal = this.treeProposal.nextProposal(rand, current.clockTree);
        Pair<ClockTree, Double> proposedTree = currentTreeProposal.propose(rand, current.clockTree);
        double logNum = 0.0;
        for (int block = 0; block < this.nBlocks; ++block) {
            Map<Taxon, List<String>> datum = TreeInferenceTest.getDatum(this.data, block);
            SSMKernel kernel = new SSMKernel(this.proposal, datum);
            kernel.setGuideTree(proposedTree.getFirst());
            this.particleFilter.sample(kernel, new ParticleFilter.DoNothingProcessor());
            logNum += this.particleFilter.estimateNormalizer();
        }
        LogInfo.logs("CurrentLogP=" + logNum);
        // NOTE(review): if logNum and current.logL are both negative infinity this is NaN,
        // and so is acceptPr; how Sampling.sampleBern treats NaN is not visible here.
        double logRatio = logNum - current.logL + proposedTree.getSecond();
        double acceptPr = Math.min(1.0, Math.exp(logRatio));
        boolean accept = Sampling.sampleBern(acceptPr, rand);
        this.outputManager.write("acceptancePrs", "proposal", currentTreeProposal.description(), "acceptPr", acceptPr, "accepted", accept);
        return accept ? new PMCMCSample(proposedTree.getFirst(), logNum) : current;
    }

    /**
     * Extracts a single block from the per-taxon sequence lists.
     *
     * @param data for each taxon, its sequences indexed by block
     * @param block zero-based block index
     * @return a new map from each taxon to a one-element {@code subList} view
     *         of that taxon's list at position {@code block}
     */
    private static Map<Taxon, List<String>> getDatum(Map<Taxon, List<String>> data, int block) {
        Map<Taxon, List<String>> result = new HashMap<Taxon, List<String>>();
        // Iterate entries directly rather than looking each key up again.
        for (Map.Entry<Taxon, List<String>> entry : data.entrySet()) {
            result.put(entry.getKey(), entry.getValue().subList(block, block + 1));
        }
        return result;
    }

    /**
     * Scores {@code ut} against the reference tree with every core tree metric
     * and writes one record per metric, then flushes the output manager.
     *
     * @param method label used as the output name (e.g. "consensus", "mode")
     * @param iter iteration number recorded with each score
     * @param ut the inferred tree to evaluate
     */
    private void eval(String method, int iter, UnrootedTree ut) {
        double time = (double)(System.currentTimeMillis() - this.startTimeMilli) / 1000.0;
        // The reference tree does not change between metrics; convert it once.
        UnrootedTree reference = this.trueTree.getUnrooted();
        for (TreeEvaluator.TreeMetric metric : TreeEvaluator.coreTreeMetrics) {
            this.outputManager.printWrite(method, "iter", iter, "metric", metric, "value", metric.score(ut, reference), "time(s)", time);
        }
        this.outputManager.flush();
    }
}

