/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml.tests;

import conifer.ml.CTMCExpFam;
import conifer.ml.ExpectedStatistics;
import conifer.ml.OptimizationOptions;
import conifer.ml.data.PhylogeneticHeldoutDataset;
import conifer.ml.extractors.IdentityExtractor;
import conifer.multicategories.PhylogeneticFactorGraph;
import fenchel.algo.FactorGraphSumProduct;
import fenchel.factor.multisitecat.MSCUnaryScaledFactor;
import fenchel.factor.multisitecat.MSCUtils;
import fig.basic.LogInfo;
import fig.basic.Pair;
import goblin.Taxon;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import ma.RateMatrixLoader;
import nuts.io.IO;
import nuts.math.GMFct;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.tui.Table;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.CounterMap;
import pty.io.Dataset;
import pty.learn.DiscreteBP;
import pty.smc.models.CTMC;

public class TestRealData
implements Runnable {
    /** Dataset-loading options; registered with {@code IO.run} under the name "phyloData" in {@code main} and passed to {@code PhylogeneticHeldoutDataset.loadData} in {@code run()}. */
    public static PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions phyloDataOptions = new PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions();
    /**
     * Model handed to every {@code ExpectedStatistics} built by the E steps and used to fit the rate matrix in {@code runEM()}.
     * NOTE(review): no code in this class assigns a non-null value to this field.
     */
    private CTMCExpFam<Character> frrm;
    /** Dataset loaded by {@code run()}; the E steps read its tree, observations, state indexer and held-out entries. */
    private PhylogeneticHeldoutDataset phyloData;

    /**
     * Command-line entry point: creates the experiment and hands it to {@code IO.run}
     * together with the dataset options, which are registered under the name "phyloData".
     *
     * @param args command-line arguments, forwarded unchanged to {@code IO.run}
     */
    public static void main(String[] args) {
        TestRealData experiment = new TestRealData();
        IO.run(args, experiment, "phyloData", phyloDataOptions);
    }

    /**
     * Loads the held-out phylogenetic dataset described by {@code phyloDataOptions},
     * stores it in {@code phyloData}, and then starts the EM loop.
     */
    @Override
    public void run() {
        PhylogeneticHeldoutDataset loaded = PhylogeneticHeldoutDataset.loadData(phyloDataOptions);
        this.phyloData = loaded;
        runEM();
    }

    /**
     * Reference ("old method") E step, computed one site at a time with {@code TreeSumProd}.
     *
     * <p>For every site this builds a graphical model from the rooted tree, the CTMC and the
     * observations, adds its log-normalizer to a running total, adds the root marginals as
     * initial-value statistics, accumulates parent/child end-point moments per branch, and adds
     * the log of the marginal assigned to each held-out leaf entry. After all sites, the
     * per-branch end-point counts are turned into path statistics with
     * {@code addMarginalizedPath}, using the branch length of each branch.
     *
     * <p>A site that throws is logged and skipped (best effort); whatever that site had already
     * contributed before the failure stays in the totals.
     *
     * @param currentRateMtx rate matrix used to build the CTMC and to marginalize paths; its row
     *                       count is taken as the number of states
     * @return the statistics accumulated over all sites and branches
     */
    private ExpectedStatistics<Character> oldEStep(double[][] currentRateMtx) {
        // The rate matrix is indexed by state, so its row count replaces the former literal 4
        // (and 1/nStates replaces the former literal 0.25 for the uniform baseline).
        final int nStates = currentRateMtx.length;
        final int nSites = this.phyloData.obs.nSites();
        final double uniformLogProb = Math.log(1.0 / nStates);
        double heldoutLogProb = 0.0;
        double baselineLogProb = 0.0;
        // End-point counts per branch, keyed by the taxon at the child end of the branch.
        Map<Taxon, CounterMap<Character, Character>> counts = CollUtils.map();
        LogInfo.track("old method");
        ExpectedStatistics<Character> currentStats = new ExpectedStatistics<Character>(this.frrm);
        CTMC.SimpleCTMC ctmc = new CTMC.SimpleCTMC(currentRateMtx, nSites);
        double logNorm = 0.0;
        for (int site = 0; site < nSites; ++site) {
            try {
                GMFct<Taxon> pots = DiscreteBP.toGraphicalModel(this.phyloData.rootedTree, ctmc, this.phyloData.obs, site);
                LogInfo.logs("starting site " + (site + 1) + "/" + nSites);
                TreeSumProd<Taxon> tsp = new TreeSumProd<Taxon>(pots);
                logNorm += tsp.logZ();
                TabularGMFct<Taxon> moments = tsp.moments();

                // Root marginals become the initial-value statistics.
                Taxon root = this.phyloData.rootedTree.topology().getContents();
                for (Object indexed : this.phyloData.indexer.objects()) {
                    // Explicit cast: a no-op when the indexer is already typed on Character.
                    Character state = (Character) indexed;
                    int i = this.phyloData.indexer.o2i(state);
                    currentStats.addInitialValue(state, moments.get(root, i));
                }

                // Pairwise (parent state, child state) moments for every non-root branch.
                for (Arbre<Taxon> node : this.phyloData.rootedTree.topology().nodes()) {
                    if (node.isRoot()) {
                        continue;
                    }
                    Taxon child = node.getContents();
                    Taxon parent = node.getParent().getContents();
                    CounterMap<Character, Character> endPointCounts = new CounterMap<Character, Character>();
                    for (int pindex = 0; pindex < nStates; ++pindex) {
                        for (int cindex = 0; cindex < nStates; ++cindex) {
                            endPointCounts.incrementCount((Character) this.phyloData.indexer.i2o(pindex), (Character) this.phyloData.indexer.i2o(cindex), moments.get(parent, child, pindex, cindex));
                        }
                    }
                    CollUtils.getNoNull(counts, child, new CounterMap<Character, Character>()).incrementAll(endPointCounts);
                }

                // Log-probability given to the true state of each held-out leaf entry at this site.
                for (Taxon leaf : this.phyloData.rootedTree.topology().leaveContents()) {
                    Pair<Taxon, Integer> key = Pair.makePair(leaf, site);
                    if (!this.phyloData.heldOut.heldout.containsKey(key)) {
                        continue;
                    }
                    int correct = this.phyloData.heldOut.heldout.get(key);
                    heldoutLogProb += Math.log(moments.get(leaf, correct));
                    baselineLogProb += uniformLogProb;
                }
            }
            catch (Exception e) {
                // Best effort: keep going with the remaining sites, but say which site was lost.
                LogInfo.logsForce("old method: skipping site " + (site + 1) + "/" + nSites + " after failure: " + e);
                e.printStackTrace();
            }
        }
        LogInfo.end_track();
        LogInfo.logsForce("old method logz = " + logNorm);
        // Second end_track: runEM() opens an "E steps" track before calling this method and
        // never closes it itself, so it is closed here.
        LogInfo.end_track();
        LogInfo.track("Expected counts");
        int brIndex = 0;
        for (Map.Entry<Taxon, CounterMap<Character, Character>> branch : counts.entrySet()) {
            LogInfo.logs("starting branch " + (++brIndex) + "/" + counts.size());
            Taxon tax = branch.getKey();
            currentStats.addMarginalizedPath(branch.getValue(), currentRateMtx, (double)this.phyloData.rootedTree.branchLengths().get(tax));
        }
        LogInfo.end_track();
        LogInfo.logsForce("baseline = " + -baselineLogProb);
        LogInfo.logsForce("perplexity = " + -heldoutLogProb);
        return currentStats;
    }

    /**
     * Runs the EM loop: extracts features on the model, starts from the HKY85(1.0, 1.0) rate
     * matrix, and for a fixed number of iterations computes expected statistics and refits the
     * rate matrix, logging the matrix after every update.
     *
     * <p>Both E steps are run on every iteration, but only the statistics from
     * {@code oldEStep} are used for fitting; {@code newEStep} is run for the values it logs.
     *
     * @throws IllegalStateException if the {@code frrm} model has not been constructed
     */
    public void runEM() {
        // This method used to assign null to frrm and dereference it on the next statement,
        // which always failed with a message-less NullPointerException. The way the model is
        // meant to be built is not visible in this class, so fail fast with an explanation.
        // TODO: construct the CTMCExpFam<Character> model here (or before calling runEM()).
        if (this.frrm == null) {
            throw new IllegalStateException("TestRealData.frrm has not been constructed: runEM() needs a CTMCExpFam<Character> model before features can be extracted");
        }
        final int nEmIterations = 100;
        OptimizationOptions optimizationOptions = new OptimizationOptions();
        Set c1 = Collections.singleton(new IdentityExtractor());
        this.frrm.extractUnivariateFeatures(c1);
        this.frrm.extractReversibleBivariateFeatures(c1);
        double[][] currentRateMtx = RateMatrixLoader.hky85(1.0, 1.0);
        LogInfo.logsForce(Table.toString(currentRateMtx));
        for (int emIter = 0; emIter < nEmIterations; ++emIter) {
            // This track is closed inside oldEStep(), which calls end_track() twice.
            LogInfo.track("E steps");
            ExpectedStatistics<Character> oldStat = this.oldEStep(currentRateMtx);
            // Result intentionally unused: the call logs the new method's logZ and
            // held-out score next to the old method's.
            this.newEStep(currentRateMtx);
            CTMCExpFam.LearnedReversibleModel currentModel = this.frrm.fitReversibleModel(optimizationOptions, oldStat, null);
            currentRateMtx = currentModel.getRateMatrix();
            LogInfo.logsForce(Table.toString(currentRateMtx));
        }
    }

    /**
     * "New method" E step, computed for all sites at once with a factor-graph sum-product.
     *
     * <p>Builds a single-category {@code PhylogeneticFactorGraph} from the rooted tree, the rate
     * matrix and the observations, logs its log-normalizer, adds the root marginals of every
     * site as initial-value statistics, adds the pairwise parent/child expectations of every
     * non-root branch as marginalized-path statistics, and logs the summed log-marginal of the
     * held-out leaf entries.
     *
     * @param currentRateMtx rate matrix used to build the factor graph and to marginalize paths
     * @return the accumulated statistics
     * @throws RuntimeException if the root factor does not have exactly one category
     */
    private ExpectedStatistics<Character> newEStep(double[][] currentRateMtx) {
        ExpectedStatistics<Character> currentStats = new ExpectedStatistics<Character>(this.frrm);
        LogInfo.track("Starting new method");
        PhylogeneticFactorGraph pfg = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(this.phyloData.rootedTree, currentRateMtx, this.phyloData.obs);
        FactorGraphSumProduct<Taxon> fgsp = pfg.getSumProductPosteriorCalculator();
        LogInfo.logsForce("new method: " + fgsp.logZ());

        // Root marginals, one per (site, state), become the initial-value statistics.
        Taxon root = this.phyloData.rootedTree.topology().getContents();
        MSCUnaryScaledFactor rFactor = (MSCUnaryScaledFactor)fgsp.moment(root);
        if (rFactor.nCategories != 1) {
            // Everything below reads category 0 only, so more than one category is unsupported.
            throw new RuntimeException("Expected a single-category root factor but found " + rFactor.nCategories + " categories");
        }
        for (int site = 0; site < pfg.nSites; ++site) {
            for (int i = 0; i < pfg.nCharacters; ++i) {
                currentStats.addInitialValue((Character)currentStats.model.stateIndexer.i2o(i), rFactor.get(site, 0, i));
            }
        }

        // Pairwise parent/child expectations for every non-root branch.
        for (Arbre<Taxon> tax : this.phyloData.rootedTree.topology().nodes()) {
            if (tax.isRoot()) {
                continue;
            }
            Taxon node = tax.getContents();
            Taxon parent = tax.getParent().getContents();
            double[][][] marginal = MSCUtils.pairwiseExpectations(fgsp, parent, node);
            // Index 0 selects the single category checked above.
            currentStats.addMarginalizedPath(marginal[0], currentRateMtx, (double)this.phyloData.rootedTree.branchLengths().get(node));
        }

        // Summed log-marginal of the true state over all held-out leaf entries.
        // NOTE(review): this is logged un-negated, whereas oldEStep() logs the negated sum
        // under "perplexity = " -- confirm which sign is intended before comparing the two.
        double perplexity = 0.0;
        for (Taxon taxon : this.phyloData.rootedTree.topology().leaveContents()) {
            MSCUnaryScaledFactor factor = (MSCUnaryScaledFactor)fgsp.moment(taxon);
            for (int site = 0; site < this.phyloData.obs.nSites(); ++site) {
                Pair<Taxon, Integer> key = Pair.makePair(taxon, site);
                if (!this.phyloData.heldOut.heldout.containsKey(key)) {
                    continue;
                }
                int correct = this.phyloData.heldOut.heldout.get(key);
                perplexity += Math.log(factor.get(site, 0, correct));
            }
        }
        LogInfo.logsForce("perp = " + perplexity);
        LogInfo.end_track();
        return currentStats;
    }

    /**
     * Minimal {@code Dataset} backed by a caller-supplied map from taxon to a two-dimensional
     * array whose first index is the site. The map is stored as given (not copied), and this
     * dataset reports no reference clusters.
     */
    public static class SimpleObservations
    implements Dataset {
        private final Map<Taxon, double[][]> observations;

        /**
         * @param observations per-taxon observation arrays; kept by reference
         */
        public SimpleObservations(Map<Taxon, double[][]> observations) {
            this.observations = observations;
        }

        /** Number of sites, read from the first dimension of one picked taxon's array. */
        @Override
        public int nSites() {
            double[][] sample = CollUtils.pick(this.observations.values());
            return sample.length;
        }

        /** Length of the row stored for {@code site} in one picked taxon's array. */
        @Override
        public int nCharacter(int site) {
            double[][] sample = CollUtils.pick(this.observations.values());
            double[] row = sample[site];
            return row.length;
        }

        /** Returns the backing map itself. */
        @Override
        public Map<Taxon, double[][]> observations() {
            return this.observations;
        }

        /** Always {@code false}. */
        @Override
        public boolean hasReferenceClusters() {
            return false;
        }

        /** Always {@code null}, matching {@link #hasReferenceClusters()} being {@code false}. */
        @Override
        public Map<Taxon, String> getReferenceClusters() {
            return null;
        }
    }
}

