/*
 * Decompiled with CFR 0.152.
 */
package pty;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.exec.Execution;
import goblin.DataPrepUtils;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import nuts.io.IO;
import nuts.util.Arbre;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.eval.Purity;
import pty.eval.SymmetricDiff;
import pty.io.Dataset;
import pty.io.GeneratedDataset;
import pty.io.HGDPDataset;
import pty.io.WalsDataset;
import pty.learn.CTMCLoader;
import pty.learn.DiscreteBP;
import pty.learn.LearningProcessor;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;
import pty.smc.PriorPostKernel;
import pty.smc.PriorPriorKernel;
import pty.smc.models.CTMC;
import pty.smc.models.CTMCUtils;

/**
 * Command-line entry point that alternates between sampling coalescent trees with a particle
 * filter and re-estimating CTMC parameters, for {@link #nEMIters} iterations.
 *
 * <p>Each iteration does the following:
 * <ul>
 *   <li>writes a consensus tree built from the sampled particles;</li>
 *   <li>logs purity when the dataset has reference clusters;</li>
 *   <li>depending on {@link #estimationMethod}, re-estimates the CTMC either from the particle
 *       filter's statistics (unsupervised) or from a fixed guide tree (supervised).</li>
 * </ul>
 */
public class Train
implements Runnable {
    /** Number of EM iterations to run. */
    @Option
    public int nEMIters = 20;
    /** Which dataset to load at the start of {@link #run()}. */
    @Option
    public Dataset.DatasetType datasetType = Dataset.DatasetType.WALS;
    /** Amount added to the particle filter's {@code N} after every EM iteration. */
    @Option
    public int increaseNSamplesPerEMIteration = 0;
    /** How CTMC parameters are re-estimated after each sampling pass. */
    @Option
    public EstimationMethod estimationMethod = EstimationMethod.UNSUPERVISED;
    /**
     * Path of a guide-tree file used for supervised re-estimation, or {@link #autoGuideTree}
     * to build a two-level tree from the dataset's reference clusters.
     */
    @Option
    public String guideTreePath = "AUTO";
    /** Value of {@link #guideTreePath} that requests an automatically built guide tree. */
    public static final String autoGuideTree = "AUTO";
    /** Multiplier applied to every branch length of a guide tree loaded from a file. */
    @Option
    public double guideTreeScaling = 100.0;
    /** If true, sample with {@link PriorPostKernel}; otherwise with {@link PriorPriorKernel}. */
    @Option
    public boolean usePriorPost = false;
    // Both are created in main() and handed to Execution.run (presumably so that their
    // options are parsed); run() assumes main() has initialised them.
    private static ParticleFilter<PartialCoalescentState> pf;
    private static CTMCLoader loader;
    private Dataset data;

    /**
     * Runs the EM loop.
     *
     * <p>Loads the dataset and the initial CTMC, then for each iteration samples with the
     * particle filter, writes the consensus tree ({@code consensusTree-<iter>}), and
     * re-estimates the CTMC according to {@link #estimationMethod}. Every intermediate CTMC is
     * saved with {@link CTMCUtils#saveInExec}.
     */
    @Override
    public void run() {
        this.data = this.datasetType.loadDataset();
        loader.setData(this.data);
        CTMC ctmc = loader.load();
        CTMCUtils.saveInExec(ctmc, "init");
        // Supervised mode only. The guide tree depends solely on options and data, which are
        // fixed for the whole run. It is built lazily on first use and then reused; previously
        // it was rebuilt (and its files rewritten) three times per iteration.
        RootedTree guideTree = null;
        for (int emIter = 0; emIter < this.nEMIters; ++emIter) {
            LogInfo.track((Object)("EM Iteration " + (emIter + 1) + "/" + this.nEMIters), true);
            PartialCoalescentState initState = PartialCoalescentState.initState(this.data, ctmc);
            ParticleFilter.ParticleMapperProcessor<PartialCoalescentState, Set<Set<Taxon>>> mbr = SymmetricDiff.createCladeProcessor();
            LearningProcessor ssp = new LearningProcessor(ctmc);
            ParticleFilter.ForkedProcessor processors = new ParticleFilter.ForkedProcessor(mbr);
            if (this.estimationMethod == EstimationMethod.UNSUPERVISED) {
                // Only the unsupervised mode gathers statistics from the sampled particles.
                processors.processors.add(ssp);
            }
            ParticleKernel<PartialCoalescentState> kernel = this.usePriorPost ? new PriorPostKernel(initState) : new PriorPriorKernel(initState);
            pf.sample(kernel, processors);
            Train.pf.N += this.increaseNSamplesPerEMIteration;
            Arbre<Taxon> reconstruction = Train.outputTree(SymmetricDiff.clades2arbre(mbr.centroid(SymmetricDiff.CLADE_SYMMETRIC_DIFFERENCE)), "consensusTree-" + emIter);
            if (this.data.hasReferenceClusters()) {
                LogInfo.logs("Purity:" + Purity.purity(Arbre.arbre2Tree(reconstruction), this.data.getReferenceClusters()));
            }
            if (this.estimationMethod == EstimationMethod.UNSUPERVISED) {
                LogInfo.track("Unsupervised reestimation of parameters");
                ctmc = ssp.reestimate(ctmc);
                CTMCUtils.saveInExec(ctmc, "unsup-reest-" + (emIter + 1));
                LogInfo.end_track();
            } else if (this.estimationMethod == EstimationMethod.SUPERVISED) {
                LogInfo.track("Supervised reestimation of parameters");
                if (guideTree == null) {
                    guideTree = Train.getGuideCoalescent(this.guideTreePath, this.data, this.guideTreeScaling);
                }
                ssp.process(guideTree, ctmc, this.data, 1.0);
                LogInfo.logs("Data likelihood before reestimation: " + DiscreteBP.dataLogLikelihood(guideTree, ctmc, this.data));
                ctmc = ssp.reestimate(ctmc);
                LogInfo.logs("Data likelihood after reestimation: " + DiscreteBP.dataLogLikelihood(guideTree, ctmc, this.data));
                CTMCUtils.saveInExec(ctmc, "sup-reest-" + (emIter + 1));
                LogInfo.end_track();
            }
            LogInfo.end_track();
        }
    }

    /**
     * Loads a guide tree from {@code guideTreePath} with no dataset and a branch-length scaling
     * of 1.0.
     *
     * <p>Because no dataset is supplied, this overload cannot be used with
     * {@link #autoGuideTree}.
     *
     * @param guideTreePath path of the guide-tree file
     * @return the guide tree
     * @throws IllegalArgumentException if {@code guideTreePath} is {@link #autoGuideTree}
     */
    public static RootedTree getGuideCoalescent(String guideTreePath) {
        return Train.getGuideCoalescent(guideTreePath, null, 1.0);
    }

    /**
     * Returns the guide tree and writes its topology to the {@code guide-tree} output files.
     *
     * <p>If {@code guideTreePath} equals {@link #autoGuideTree}, the tree is built from the
     * dataset's reference clusters. It has a root, one child per cluster that contains at least
     * one observed taxon (branch length 2.0), and the observed taxa of that cluster as leaves
     * (branch length 1.0). Otherwise the tree is loaded from the file at {@code guideTreePath}
     * and every branch length is multiplied by {@code guideTreeScaling}.
     *
     * @param guideTreePath file path, or {@link #autoGuideTree}
     * @param data dataset providing reference clusters and observations; required only for
     *     {@link #autoGuideTree}
     * @param guideTreeScaling branch-length multiplier for a tree loaded from file
     * @return the guide tree
     * @throws IllegalArgumentException if {@link #autoGuideTree} is requested with a null dataset
     * @throws RuntimeException wrapping any exception raised while loading the file
     */
    public static RootedTree getGuideCoalescent(String guideTreePath, Dataset data, double guideTreeScaling) {
        RootedTree guide = guideTreePath.equals(autoGuideTree)
            ? Train.buildAutoGuideTree(data)
            : Train.loadScaledGuideTree(guideTreePath, guideTreeScaling);
        Train.outputTree(guide.topology(), "guide-tree");
        return guide;
    }

    /** Builds the two-level guide tree (root, then clusters, then observed taxa) from {@code data}. */
    private static RootedTree buildAutoGuideTree(Dataset data) {
        if (data == null) {
            throw new IllegalArgumentException("guideTreePath \"" + autoGuideTree + "\" requires a dataset to read reference clusters from");
        }
        Map<Taxon, String> clusters = data.getReferenceClusters();
        Set<Taxon> obsLang = data.observations().keySet();
        HashMap<Taxon, Double> bl = new HashMap<Taxon, Double>();
        ArrayList<Arbre<Taxon>> clusterNodes = new ArrayList<Arbre<Taxon>>();
        for (String cluster : new HashSet<String>(clusters.values())) {
            ArrayList<Arbre<Taxon>> leaves = new ArrayList<Arbre<Taxon>>();
            for (Taxon lang : clusters.keySet()) {
                if (obsLang.contains(lang) && clusters.get(lang).equals(cluster)) {
                    leaves.add(Arbre.arbre(lang));
                    bl.put(lang, 1.0);
                }
            }
            // Clusters that contain no observed taxon are left out of the tree entirely.
            if (leaves.isEmpty()) {
                continue;
            }
            Taxon clusterTaxon = new Taxon(cluster);
            clusterNodes.add(Arbre.arbre(clusterTaxon, leaves));
            bl.put(clusterTaxon, 2.0);
        }
        return Train.asRootedTree(Arbre.arbre(new Taxon("root"), clusterNodes), bl);
    }

    /** Loads a tree from {@code guideTreePath} and multiplies every branch length by {@code scaling}. */
    private static RootedTree loadScaledGuideTree(String guideTreePath, double scaling) {
        RootedTree loaded;
        try {
            loaded = RootedTree.Util.load(new File(guideTreePath));
        }
        catch (Exception e) {
            throw new RuntimeException("Could not load guide tree from " + guideTreePath, e);
        }
        HashMap<Taxon, Double> bl = new HashMap<Taxon, Double>();
        for (Map.Entry<Taxon, Double> entry : loaded.branchLengths().entrySet()) {
            bl.put(entry.getKey(), entry.getValue() * scaling);
        }
        return Train.asRootedTree(loaded.topology(), bl);
    }

    /** Wraps a fixed topology and branch-length map as a {@link RootedTree} that is its own rooted view. */
    private static RootedTree asRootedTree(final Arbre<Taxon> topology, final Map<Taxon, Double> bl) {
        return new RootedTree(){

            @Override
            public Arbre<Taxon> topology() {
                return topology;
            }

            @Override
            public Map<Taxon, Double> branchLengths() {
                return bl;
            }

            @Override
            public int nTaxa() {
                return topology.nLeaves();
            }

            @Override
            public RootedTree getRooted() {
                return this;
            }

            @Override
            public UnrootedTree getUnrooted() {
                return UnrootedTree.fromRooted(this);
            }
        };
    }

    /**
     * Writes {@code reconstruction} without branch lengths and logs it.
     *
     * @param reconstruction tree to write
     * @param prefix file-name prefix inside the execution directory
     * @return {@code reconstruction}, unchanged
     */
    public static Arbre<Taxon> outputTree(Arbre<Taxon> reconstruction, String prefix) {
        return Train.outputTree(reconstruction, prefix, null);
    }

    /**
     * Writes {@code reconstruction} to {@code <prefix>.newick} and {@code <prefix>.txt} in the
     * execution directory, and logs its textual form.
     *
     * @param reconstruction tree to write
     * @param prefix file-name prefix inside the execution directory
     * @param bl branch lengths passed to {@link DataPrepUtils#newick}; may be null (the
     *     two-argument overload passes null)
     * @return {@code reconstruction}, unchanged, so calls can be chained
     */
    public static Arbre<Taxon> outputTree(Arbre<Taxon> reconstruction, String prefix, Map<Taxon, Double> bl) {
        IO.writeToDisk(Execution.getFile(prefix + ".newick"), DataPrepUtils.newick(Arbre.arbre2Tree(reconstruction), bl, true));
        IO.writeToDisk(Execution.getFile(prefix + ".txt"), reconstruction.deepToString());
        LogInfo.logs("Reconstructed tree (" + prefix + "):\n" + reconstruction.deepToString());
        return reconstruction;
    }

    /**
     * Configures {@link Execution}, creates the shared particle filter and CTMC loader, and runs
     * a {@link Train} instance.
     *
     * <p>Unless {@code NOJARS} appears among {@code args}, a fixed list of jar paths is
     * registered with {@link Execution}. NOTE(review): these paths are machine-specific.
     */
    public static void main(String[] args) {
        Execution.monitor = true;
        Execution.makeThunk = false;
        Execution.create = true;
        Execution.useStandardExecPoolDirStrategy = true;
        if (!Arrays.asList(args).contains("NOJARS")) {
            Execution.jarFiles = new ArrayList<String>(Arrays.asList("/home/eecs/bouchard/jars/ptychodus.jar", "/home/eecs/bouchard/jars/nuts.jar", "/home/eecs/bouchard/jars/pepper.jar", "/home/eecs/bouchard/jars/fig.jar"));
        }
        pf = new ParticleFilter<PartialCoalescentState>();
        loader = new CTMCLoader();
        Execution.run(args, new Train(), "wals", WalsDataset.class, "hgdb", HGDPDataset.class, "dataGen", GeneratedDataset.class, "paramGen", GeneratedDataset.genCTMCLoader, "filter", pf, "init", loader, "suffstat", LearningProcessor.class);
    }

    /** Strategy for re-estimating CTMC parameters after each sampling pass. */
    public static enum EstimationMethod {
        /** No re-estimation; the CTMC stays fixed across iterations. */
        NONE,
        /** Re-estimate from statistics gathered over the particle filter's samples. */
        UNSUPERVISED,
        /** Re-estimate from the guide tree selected by {@code guideTreePath}. */
        SUPERVISED;

    }
}

