/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.test;

import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.exec.Execution;
import goblin.BayesRiskMinimizer;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.newick.NewickParser;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Tree;
import pty.RootedTree;
import pty.Train;
import pty.eval.Purity;
import pty.eval.SymmetricDiff;
import pty.io.Dataset;
import pty.io.HGDPDataset;
import pty.io.LeaveOneOut;
import pty.io.WalsDataset;
import pty.learn.CTMCLoader;
import pty.smc.ConditionalPriorPriorKernel;
import pty.smc.MapLeaves;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;
import pty.smc.models.BrownianModel;
import pty.smc.models.BrownianModelCalculator;
import pty.smc.models.CTMC;
import pty.smc.models.DiscreteModelCalculator;
import pty.smc.models.LikelihoodModelCalculator;
import pty.smc.models.ProductModel;
import pty.smc.test.TestBrownianModel;

/**
 * Driver for joint inference of a language tree and a gene tree.
 *
 * <p>{@link #run()} builds an initial coalescent state for the linguistic data (WALS corpus) and
 * one for the gene data (via {@code TestBrownianModel.initGeneState}). It then does one of two
 * things:
 * <ul>
 *   <li>If {@link #softAgreement} is true, it runs a two-block Gibbs sampler. Each block is
 *       resampled by the shared particle filter with a {@link ConditionalPriorPriorKernel} that is
 *       given the other block's current clades.</li>
 *   <li>Otherwise, it runs a single particle filter whose leaves carry a product of the language
 *       CTMC likelihood and a Brownian likelihood on HGDP data.</li>
 * </ul>
 * The {@code @Option} fields are set through {@code fig.exec.Execution} (see {@link #main}).
 */
public class TestJointModel
implements Runnable {
    /** File mapping language leaves to gene leaves; parsed with {@code MapLeaves.parse} in {@link #run()}. */
    @Option
    public String mapfile = "data/language-gene-map.txt";
    /** Number of Gibbs sweeps; each sweep updates the gene block and then the language block. */
    @Option
    public int gibbsIterations = 100;
    /** Amount added to the shared filter's particle count {@code pf.N} after every Gibbs sweep. */
    @Option
    public int increaseNSamplesPerGibbsIteration = 0;
    /** Variance handed to the Brownian gene model (initial gene state and hard-agreement leaves). */
    @Option
    public double variance = 0.1;
    /** If true, the gene block is not sampled; its clades are read once from {@link #fixedTreePath}. */
    @Option
    public boolean testAgainstFixedTree = false;
    /** Newick file whose clades replace the gene block when {@link #testAgainstFixedTree} is set. */
    @Option
    public String fixedTreePath = "data/hgdp/contml.all.newick";
    /** Weight passed to each {@link ConditionalPriorPriorKernel} built by the Gibbs sampler. */
    @Option
    public double agreementWeight = 1.0;
    /** Selects the Gibbs sampler (true) or the joint product-likelihood filter (false). */
    @Option
    public boolean softAgreement = true;
    /** Shared particle filter; registered with Execution as "filter" and mutated between Gibbs sweeps. */
    private static final ParticleFilter<PartialCoalescentState> pf = new ParticleFilter<PartialCoalescentState>();
    /** Loader for the language CTMC parameters; registered with Execution as "langparam". */
    private static final CTMCLoader langParamLoader = new CTMCLoader();
    /** Language dataset; set by {@link #initLanguageState()}. */
    private Dataset langData;
    /** Language-to-gene leaf mapping; set in {@link #run()}. */
    private MapLeaves ml;
    /** Language CTMC parameters; set by {@link #initLanguageState()}. */
    private CTMC langParam;
    /** Fixed-seed generator used to draw each block's next clade state. */
    public static final Random rand = new Random(1L);
    // NOTE(review): not referenced in this class; kept since other classes in the package may use them.
    PartialCoalescentState genepcs;
    PartialCoalescentState langpcs;
    /** Cached clades of the fixed gene tree; null until {@link #loadFixedGeneState} first succeeds. */
    private Set<Set<Taxon>> fixedGeneState = null;

    /**
     * Entry point. Configures fig's Execution harness and runs this driver.
     *
     * <p>The extra name/object pairs presumably register additional option groups: the datasets,
     * the shared filter, the CTMC loader and the kernel settings.
     */
    public static void main(String[] args) {
        Execution.monitor = true;
        Execution.makeThunk = false;
        Execution.create = true;
        Execution.useStandardExecPoolDirStrategy = true;
        // NOTE(review): the group name "hddp" does not match HGDPDataset's name. It is left
        // unchanged because existing command lines may depend on it.
        Execution.run(args, new TestJointModel(),
                "wals", WalsDataset.class,
                "hddp", HGDPDataset.class,
                "filter", pf,
                "langparam", langParamLoader,
                "cppk", ConditionalPriorPriorKernel.class);
    }

    /**
     * Builds the initial language coalescent state.
     *
     * <p>Loads the preprocessed WALS corpus, loads CTMC parameters for it through
     * {@link #langParamLoader}, and builds the initial state from both.
     *
     * <p>Side effects: sets {@link #langData} and {@link #langParam}.
     */
    private PartialCoalescentState initLanguageState() {
        this.langData = WalsDataset.getPreprocessedCorpus();
        langParamLoader.setData(this.langData);
        this.langParam = langParamLoader.load();
        return PartialCoalescentState.initState(this.langData, this.langParam);
    }

    /**
     * Builds the kernel used when no other-block clades are available.
     *
     * <p>Returns {@code PRIOR_POST2} when {@code ConditionalPriorPriorKernel.usesPriorPost} is set,
     * and {@code PRIOR_PRIOR} otherwise. Shared by the Gibbs initialization and the
     * hard-agreement run so the two cannot drift apart.
     */
    private static ParticleKernel<PartialCoalescentState> createPriorKernel(PartialCoalescentState init) {
        if (ConditionalPriorPriorKernel.usesPriorPost) {
            return TestBrownianModel.KernelType.PRIOR_POST2.load(init, null);
        }
        return TestBrownianModel.KernelType.PRIOR_PRIOR.load(init, null);
    }

    /**
     * Runs the two-block Gibbs sampler.
     *
     * <p>First, the language block is sampled on its own to obtain starting clades. This writes
     * "ling-consensusTree-init", and also logs purity when reference clusters exist.
     *
     * <p>Each sweep then does the following:
     * <ol>
     *   <li>Updates the gene block given the current language clades. When
     *       {@link #testAgainstFixedTree} is set, it loads the fixed tree instead.</li>
     *   <li>Resamples the language block given the new gene clades.</li>
     *   <li>Increases the shared filter's particle count by
     *       {@link #increaseNSamplesPerGibbsIteration}.</li>
     * </ol>
     *
     * @param initGeneState initial coalescent state of the gene block
     * @param initLangState initial coalescent state of the language block
     */
    public void gibbsSampler(PartialCoalescentState initGeneState, PartialCoalescentState initLangState) {
        ParticleKernel<PartialCoalescentState> ppk = createPriorKernel(initLangState);
        ParticleFilter.ParticleMapperProcessor<PartialCoalescentState, Set<Set<Taxon>>> processor = SymmetricDiff.createCladeProcessor();
        pf.sample(ppk, processor);
        // Starting language clades come from processor.map() of this first run.
        Set<Set<Taxon>> currentLangState = processor.map();
        Arbre<Taxon> reconstruction = Train.outputTree(
                SymmetricDiff.clades2arbre(processor.centroid(SymmetricDiff.CLADE_SYMMETRIC_DIFFERENCE)),
                "ling-consensusTree-init");
        if (this.langData.hasReferenceClusters()) {
            Tree<Taxon> recon = Arbre.arbre2Tree(reconstruction);
            Map<Taxon, String> allLabels = this.langData.getReferenceClusters();
            LogInfo.logs("Labels used for evaluation:" + Purity.partitionsUsedForEval(recon, allLabels));
            LogInfo.logs("Purity-init:" + Purity.purity(recon, allLabels));
        }
        // Per-block clade-set counts accumulated over all sweeps (used for consensus trees).
        Counter<Set<Set<Taxon>>> allLinguisticSamples = new Counter<Set<Set<Taxon>>>();
        Counter<Set<Set<Taxon>>> allGeneSamples = new Counter<Set<Set<Taxon>>>();
        for (int i = 0; i < this.gibbsIterations; ++i) {
            Set<Set<Taxon>> currentGeneState = this.testAgainstFixedTree
                    ? this.loadFixedGeneState(this.fixedTreePath)
                    : this.sampleBlock(initGeneState, currentLangState, allGeneSamples, false, i);
            currentLangState = this.sampleBlock(initLangState, currentGeneState, allLinguisticSamples, true, i);
            pf.N += this.increaseNSamplesPerGibbsIteration;
        }
    }

    /**
     * Returns the clades of the Newick tree at {@code fixedTreePath}.
     *
     * <p>Every node's label is wrapped in a {@link Taxon}, and the clades are computed with
     * {@code SymmetricDiff.cladesFromUnrooted}. The result is cached after the first successful
     * call, so later calls ignore their argument.
     *
     * @param fixedTreePath path of the Newick file
     * @return the clade set of the fixed tree
     * @throws RuntimeException wrapping any failure to open or parse the file
     */
    private Set<Set<Taxon>> loadFixedGeneState(String fixedTreePath) {
        if (this.fixedGeneState != null) {
            return this.fixedGeneState;
        }
        try {
            NewickParser np = new NewickParser(IOUtils.openIn(fixedTreePath));
            Tree<String> tree = np.parse();
            this.fixedGeneState = SymmetricDiff.cladesFromUnrooted(Arbre.tree2Arbre(tree).postOrderMap(new Arbre.ArbreMap<String, Taxon>() {

                @Override
                public Taxon map(Arbre<String> d) {
                    return new Taxon(d.getContents());
                }
            }));
            LogInfo.logs("Fixed constraints:" + this.fixedGeneState);
            return this.fixedGeneState;
        }
        catch (Exception e) {
            // Include the path so a bad --fixedTreePath is diagnosable from the stack trace.
            throw new RuntimeException("Could not load fixed gene tree from " + fixedTreePath, e);
        }
    }

    /**
     * Resamples one Gibbs block, logs diagnostics, and draws the block's next clade state.
     *
     * <p>The block is run through the shared filter with a {@link ConditionalPriorPriorKernel} that
     * is given the other block's clades. The method writes two trees:
     * <ul>
     *   <li>{@code <prefix>-mapTree-<i>}: the MAP tree of this run.</li>
     *   <li>{@code <prefix>-consensusTree-<i>}: the consensus over all samples accumulated so far
     *       for this block.</li>
     * </ul>
     * For the language block it also logs purity (when reference clusters exist) and leave-one-out
     * results.
     *
     * @param initCoalescentState     initial coalescent state of the block being resampled
     * @param otherNodeState          current clades of the other block
     * @param allSampleForCurrentNode clade-set counts for this block across sweeps; mutated
     * @param isLang                  true for the language block, false for the gene block
     * @param i                       Gibbs sweep index, used in output names
     * @return a clade set drawn from this run's clade counter using {@link #rand}
     */
    private Set<Set<Taxon>> sampleBlock(PartialCoalescentState initCoalescentState, Set<Set<Taxon>> otherNodeState, Counter<Set<Set<Taxon>>> allSampleForCurrentNode, boolean isLang, int i) {
        ParticleFilter.ParticleMapperProcessor<PartialCoalescentState, Set<Set<Taxon>>> processor = SymmetricDiff.createCladeProcessor();
        ParticleFilter.PCSHash hashProcessor = new ParticleFilter.PCSHash();
        ParticleFilter.MAPDecoder mapDecoder = new ParticleFilter.MAPDecoder();
        ParticleFilter.ForkedProcessor processors = new ParticleFilter.ForkedProcessor(processor, hashProcessor, mapDecoder);
        ConditionalPriorPriorKernel pk = new ConditionalPriorPriorKernel(initCoalescentState, otherNodeState, this.ml, this.agreementWeight);
        pf.sample(pk, processors);
        allSampleForCurrentNode.incrementAll(processor.getCounter());
        PartialCoalescentState mapState = (PartialCoalescentState) mapDecoder.map();
        RootedTree map = mapState.getFullCoalescentState();
        String prefix = isLang ? "ling" : "bio";
        Train.outputTree(map.topology(), prefix + "-mapTree-" + i, map.branchLengths());
        Arbre<Taxon> reconstruction = Train.outputTree(
                SymmetricDiff.clades2arbre(new BayesRiskMinimizer<Set<Taxon>>(SymmetricDiff.CLADE_SYMMETRIC_DIFFERENCE).findMin(allSampleForCurrentNode)),
                prefix + "-consensusTree-" + i);
        if (isLang && this.langData.hasReferenceClusters()) {
            LogInfo.logs("Purity-" + i + ":" + Purity.purity(Arbre.arbre2Tree(reconstruction), this.langData.getReferenceClusters()));
        }
        if (isLang) {
            LogInfo.logs("LOO-" + i + ":" + LeaveOneOut.loo(mapState));
        }
        LogInfo.logs("Hash-" + i + (isLang ? "-lang" : "-bio") + "=" + hashProcessor.getHash());
        LogInfo.logs("TopologyDistributionLargestPrt=" + processor.getCounter().max());
        return processor.sample(rand);
    }

    /**
     * Initializes both data blocks and the leaf map, then runs the soft or hard agreement
     * inference.
     *
     * <p>When {@link #softAgreement} is true it runs the Gibbs sampler; otherwise it runs the joint
     * (hard-agreement) filter.
     */
    @Override
    public void run() {
        // The gene state is built even on the hard-agreement path, which does not use it.
        PartialCoalescentState geneState = TestBrownianModel.initGeneState(this.variance);
        PartialCoalescentState languageState = this.initLanguageState();
        this.ml = MapLeaves.parse(this.mapfile);
        if (this.softAgreement) {
            this.gibbsSampler(geneState, languageState);
        } else {
            this.hardAgreement();
        }
    }

    /**
     * Runs one particle filter over the joint product-likelihood state from {@link #initHardState()}.
     *
     * <p>Writes "mapTree" and "consensusTree", logs purity when reference clusters exist, and logs
     * the particle hash.
     */
    private void hardAgreement() {
        PartialCoalescentState jointInit = this.initHardState();
        ParticleKernel<PartialCoalescentState> ppk = createPriorKernel(jointInit);
        ParticleFilter.ParticleMapperProcessor<PartialCoalescentState, Set<Set<Taxon>>> processor = SymmetricDiff.createCladeProcessor();
        ParticleFilter.PCSHash hashProcessor = new ParticleFilter.PCSHash();
        ParticleFilter.MAPDecoder mapDecoder = new ParticleFilter.MAPDecoder();
        ParticleFilter.ForkedProcessor processors = new ParticleFilter.ForkedProcessor(processor, hashProcessor, mapDecoder);
        pf.sample(ppk, processors);
        PartialCoalescentState mapState = (PartialCoalescentState) mapDecoder.map();
        RootedTree map = mapState.getFullCoalescentState();
        Train.outputTree(map.topology(), "mapTree", map.branchLengths());
        Arbre<Taxon> reconstruction = Train.outputTree(
                SymmetricDiff.clades2arbre(processor.centroid(SymmetricDiff.CLADE_SYMMETRIC_DIFFERENCE)),
                "consensusTree");
        if (this.langData.hasReferenceClusters()) {
            LogInfo.logs("Purity:" + Purity.purity(Arbre.arbre2Tree(reconstruction), this.langData.getReferenceClusters()));
        }
        LogInfo.logs("Hash=" + hashProcessor.getHash());
    }

    /**
     * Builds the initial state for the hard-agreement model, with one leaf per language taxon.
     *
     * <p>Each leaf's likelihood is a {@link ProductModel} of two parts:
     * <ul>
     *   <li>the language CTMC observation;</li>
     *   <li>a Brownian observation built from the first entry of each row of the mapped HGDP
     *       taxon's observation matrix.</li>
     * </ul>
     *
     * @throws IllegalStateException if the language and HGDP datasets have different numbers of
     *         taxa, or if a language taxon maps (through {@link #ml}) to a taxon with no HGDP data
     * @deprecated the reason is not recorded in this file; it is still used by the
     *             {@code softAgreement=false} path.
     */
    @Deprecated
    private PartialCoalescentState initHardState() {
        ArrayList<Taxon> leafNames = new ArrayList<Taxon>();
        ArrayList<LikelihoodModelCalculator> leaves = new ArrayList<LikelihoodModelCalculator>();
        Map<Taxon, double[][]> langObservations = this.langData.observations();
        Map<Taxon, double[][]> bioObservations = Dataset.DatasetType.HGDP.loadDataset().observations();
        if (langObservations.size() != bioObservations.size()) {
            throw new IllegalStateException("Language and HGDP datasets have different numbers of taxa: "
                    + langObservations.size() + " vs " + bioObservations.size());
        }
        for (Taxon lang : langObservations.keySet()) {
            Taxon bioEq = this.ml.translate(lang);
            leafNames.add(lang);
            ArrayList<LikelihoodModelCalculator> models = CollUtils.list();
            models.add(DiscreteModelCalculator.observation(this.langParam, langObservations.get(lang)));
            double[][] cObs = bioObservations.get(bioEq);
            if (cObs == null) {
                // Previously surfaced as a bare NullPointerException on cObs.length.
                throw new IllegalStateException("No HGDP observations for " + bioEq
                        + " (mapped from language taxon " + lang + ")");
            }
            double[] converted = new double[cObs.length];
            for (int i = 0; i < converted.length; ++i) {
                converted[i] = cObs[i][0];
            }
            BrownianModel bm = new BrownianModel(converted.length, this.variance);
            models.add(BrownianModelCalculator.observation(converted, bm, false));
            leaves.add(new ProductModel(models));
        }
        return PartialCoalescentState.initialState(leaves, leafNames, null, true);
    }
}

