/*
 * Decompiled with CFR 0.152.
 */
package pty.learn;

import goblin.Taxon;
import java.util.HashMap;
import java.util.Map;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.Graph;
import nuts.math.HashGraph;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Tree;
import pty.Observations;
import pty.RootedTree;
import pty.smc.PartialCoalescentState;
import pty.smc.models.CTMC;

/**
 * Turns a rooted tree, a {@link CTMC} and per-taxon observations into a tabular graphical model
 * over taxa ({@link #toGraphicalModel}) and evaluates it with {@link TreeSumProd}.
 */
public class DiscreteBP {
    /**
     * Runs {@link TreeSumProd#computeMoments} on the model for one site of the full coalescent
     * state held by {@code pcs}.
     *
     * @param pcs source of the tree, the CTMC and the observations
     * @param site site index
     * @return the result of {@code TreeSumProd.computeMoments} on that model
     */
    public static GMFct<Taxon> posteriorMarginalTransitions(PartialCoalescentState pcs, int site) {
        GMFct<Taxon> model = toGraphicalModel(pcs.getFullCoalescentState(), pcs.getCTMC(), pcs.getObservations(), site);
        return TreeSumProd.computeMoments(model);
    }

    /**
     * Returns {@code logZ()} of the model built for a single site.
     *
     * @param state rooted tree (topology and branch lengths)
     * @param ctmc substitution model
     * @param observations per-taxon data, or {@code null} for none
     * @param site site index
     * @return log normalization of the site's model, as computed by {@link TreeSumProd}
     */
    public static double dataLogLikelihood(RootedTree state, CTMC ctmc, Observations observations, int site) {
        return new TreeSumProd<Taxon>(toGraphicalModel(state, ctmc, observations, site)).logZ();
    }

    /**
     * Sums {@link #dataLogLikelihood(RootedTree, CTMC, Observations, int)} over every site
     * {@code 0..ctmc.nSites()-1}.
     */
    public static double dataLogLikelihood(RootedTree state, CTMC ctmc, Observations observations) {
        double sum = 0.0;
        for (int s = 0; s < ctmc.nSites(); ++s) {
            sum += dataLogLikelihood(state, ctmc, observations, s);
        }
        return sum;
    }

    /** Same as the five-argument overload with no held-out taxon. */
    public static GMFct<Taxon> posteriorMarginalTransitions(RootedTree state, CTMC ctmc, Observations observations, int site) {
        return posteriorMarginalTransitions(state, ctmc, observations, site, null);
    }

    /**
     * Runs {@link TreeSumProd#computeMoments} on the model for {@code site}.
     *
     * @param languageToHeldout taxon whose observation is ignored, or {@code null}
     */
    public static GMFct<Taxon> posteriorMarginalTransitions(RootedTree state, CTMC ctmc, Observations observations, int site, Taxon languageToHeldout) {
        return TreeSumProd.computeMoments(toGraphicalModel(state, ctmc, observations, site, languageToHeldout));
    }

    /**
     * Builds tree potentials once, from a single rate matrix, and pairs them with per-site
     * observations.
     *
     * @param observations per-taxon arrays indexed {@code [site][character]}; the first array's
     *     length gives the number of sites
     * @param tree rooted tree
     * @param rateMatrix rate matrix shared by every site
     * @return a factory yielding one graphical model per site
     */
    public static SiteHomogeneousPhylogeneticGMFcts getPhylogeneticGMFcts(Map<Taxon, double[][]> observations, RootedTree tree, double[][] rateMatrix) {
        int nSites = CollUtils.pick(observations.values()).length;
        CTMC.SimpleCTMC ctmc = new CTMC.SimpleCTMC(rateMatrix, nSites);
        // Every site uses the same rate matrix, so any valid site index yields the same
        // potentials. Use site 0: indices run 0..nSites-1, so nSites itself is out of range.
        GMFct<Taxon> pairPotentials = toGraphicalModel(tree, ctmc, null, 0);
        return new SiteHomogeneousPhylogeneticGMFcts(pairPotentials, observations);
    }

    /** Same as the five-argument overload with no held-out taxon. */
    public static GMFct<Taxon> toGraphicalModel(RootedTree state, CTMC ctmc, Observations observations, int site) {
        return toGraphicalModel(state, ctmc, observations, site, null);
    }

    /**
     * Builds a tabular graphical model over the taxa of {@code state} for one site.
     *
     * <ul>
     *   <li>Every variable has {@code ctmc.nCharacter(site)} states.</li>
     *   <li>Each (parent, child) edge holds the CTMC transition probabilities for the child's
     *       branch length.</li>
     *   <li>The root's node potential is the CTMC initial distribution.</li>
     *   <li>Each observed taxon's node potential is multiplied by its observation vector for
     *       {@code site}. The held-out taxon, if any, is left unchanged.</li>
     * </ul>
     *
     * @param state rooted tree (topology and branch lengths)
     * @param ctmc substitution model
     * @param observations per-taxon data indexed {@code [site][character]}, or {@code null}
     * @param site site index
     * @param languageToHeldout taxon whose observation is ignored, or {@code null}
     * @return the populated potential table
     */
    public static GMFct<Taxon> toGraphicalModel(RootedTree state, CTMC ctmc, Observations observations, int site, Taxon languageToHeldout) {
        int nCharacters = ctmc.nCharacter(site);
        TabularGMFct<Taxon> pot = newOnesTable(state, nCharacters);
        setBranchPotentials(pot, state, ctmc, site, nCharacters);
        setRootPrior(pot, state, ctmc, site, nCharacters);
        if (observations != null) {
            applyObservations(pot, observations.observations(), site, nCharacters, languageToHeldout);
        }
        return pot;
    }

    /** Creates a table over the tree's graph, with nCharacters states per taxon, initialized via GMFctUtils.ones. */
    private static TabularGMFct<Taxon> newOnesTable(RootedTree state, int nCharacters) {
        Tree<Taxon> t = Arbre.arbre2Tree(state.topology());
        HashGraph<Taxon> g = new HashGraph<Taxon>(t);
        HashMap<Taxon, Integer> rvRange = new HashMap<Taxon, Integer>();
        for (Taxon taxon : state.topology().nodeContents()) {
            rvRange.put(taxon, nCharacters);
        }
        return GMFctUtils.ones(new TabularGMFct<Taxon>(g, rvRange));
    }

    /** Stores, for every non-root node, the transition matrix of its incoming branch as the (parent, child) pair potential. */
    private static void setBranchPotentials(TabularGMFct<Taxon> pot, RootedTree state, CTMC ctmc, int site, int nCharacters) {
        for (Arbre<Taxon> arbre : state.topology().nodes()) {
            if (arbre.isRoot()) {
                continue;
            }
            Taxon cur = arbre.getContents();
            Taxon par = arbre.getParent().getContents();
            double[][] prs = ctmc.getTransitionPr(site, state.branchLengths().get(cur));
            for (int ts = 0; ts < nCharacters; ++ts) {
                for (int bs = 0; bs < nCharacters; ++bs) {
                    pot.set(par, cur, ts, bs, prs[ts][bs]);
                }
            }
        }
    }

    /** Sets the root's node potential to the CTMC initial distribution for this site. */
    private static void setRootPrior(TabularGMFct<Taxon> pot, RootedTree state, CTMC ctmc, int site, int nCharacters) {
        Taxon root = state.topology().getContents();
        double[] prior = ctmc.getInitialDistribution(site);
        for (int c = 0; c < nCharacters; ++c) {
            pot.set(root, c, prior[c]);
        }
    }

    /**
     * Multiplies each observed taxon's node potential by its observation vector for {@code site}.
     * Multiplying, rather than overwriting, keeps the root prior intact if the root is observed.
     * For other nodes the potential is still the initial ones value, so the result equals the
     * observation. The held-out taxon is skipped, which is equivalent to multiplying by 1.
     */
    private static void applyObservations(TabularGMFct<Taxon> pot, Map<Taxon, double[][]> obs, int site, int nCharacters, Taxon languageToHeldout) {
        for (Map.Entry<Taxon, double[][]> entry : obs.entrySet()) {
            Taxon lang = entry.getKey();
            if (lang.equals(languageToHeldout)) {
                continue;
            }
            double[] siteObs = entry.getValue()[site];
            for (int c = 0; c < nCharacters; ++c) {
                pot.set(lang, c, pot.get(lang, c) * siteObs[c]);
            }
        }
    }

    /** No-op entry point. */
    public static void main(String[] args) {
    }

    /**
     * Pairs site-independent pair potentials with per-site observations. {@link #getGM(int)}
     * exposes the combination as one graphical model per site without copying the potential table.
     */
    public static class SiteHomogeneousPhylogeneticGMFcts {
        private final GMFct<Taxon> pairPotentials;
        private final Map<Taxon, double[][]> observations;

        /**
         * @param pairPotentials potentials shared by all sites (edges, plus node potentials such as the root prior)
         * @param observations per-taxon arrays indexed {@code [site][character]}
         */
        public SiteHomogeneousPhylogeneticGMFcts(GMFct<Taxon> pairPotentials, Map<Taxon, double[][]> observations) {
            this.pairPotentials = pairPotentials;
            this.observations = observations;
        }

        /**
         * Returns a view of the model for {@code site}.
         *
         * <ul>
         *   <li>Pair potentials, graph and state counts come from the shared potentials.</li>
         *   <li>A node potential is the shared node potential, multiplied by the node's
         *       observation at {@code site} when the node has one. Unobserved nodes (e.g.
         *       internal nodes, the root) therefore keep their shared potential instead of
         *       failing.</li>
         * </ul>
         */
        public GMFct<Taxon> getGM(final int site) {
            return new GMFct<Taxon>() {

                @Override
                public double get(Taxon n1, Taxon n2, int s1, int s2) {
                    return pairPotentials.get(n1, n2, s1, s2);
                }

                @Override
                public Graph<Taxon> graph() {
                    return pairPotentials.graph();
                }

                @Override
                public int nStates(Taxon node) {
                    return pairPotentials.nStates(node);
                }

                @Override
                public double get(Taxon n, int s) {
                    double base = pairPotentials.get(n, s);
                    double[][] data = observations.get(n);
                    return data == null ? base : base * data[site][s];
                }
            };
        }
    }
}

