/*
 * Decompiled with CFR 0.152.
 */
package conifer.multicategories;

import conifer.ml.data.CharacterReconstructionMethod;
import conifer.ml.data.PhylogeneticHeldoutDataset;
import conifer.multicategories.PhylogenyPotentials;
import fenchel.algo.FactorGraphSumProduct;
import fenchel.factor.multisitecat.MSCBinaryScaledFactor;
import fenchel.factor.multisitecat.MSCFactorGraph;
import fenchel.factor.multisitecat.MSCUnaryScaledFactor;
import fenchel.factor.multisitecat.MSCUtils;
import goblin.Taxon;
import java.util.Collection;
import ma.RateMatrixLoader;
import nuts.math.RateMtxUtils;
import nuts.util.Arbre;
import pty.Observations;
import pty.RootedTree;

/**
 * Multi-category phylogenetic model expressed as an {@link MSCFactorGraph} over {@link Taxon}
 * nodes, with sum-product inference used to obtain node and edge posteriors.
 *
 * <p>Each non-root branch receives one binary transition factor per rate category (split into two
 * factors through an intermediate node when a stem fraction strictly between 0 and 1 is set). Each
 * observed taxon is linked to an auxiliary {@link TaxonObservation} node through a symmetric noise
 * channel built by {@link #noiseTransition(int, double)}.
 *
 * <p>The factor graph and the sum-product calculator are built lazily on first use and cached. The
 * lazy initialization is unsynchronized.
 */
public class PhylogeneticFactorGraph
implements CharacterReconstructionMethod {
    /** Category priors, stationary distributions, rate matrices and observation error probability. */
    public final PhylogenyPotentials potentials;
    /** Tree topology and branch lengths. */
    public final RootedTree rootedTree;
    /** Observed data; {@code observations().get(taxon)[site][character]} fills the observation factors. */
    public final Observations observations;
    /** Lazily built factor graph; {@code null} until {@link #getFactorGraph()} first succeeds. */
    private MSCFactorGraph<Taxon> factorGraph = null;
    /**
     * Branch split used when building the graph. For a branch of length {@code T}, the
     * intermediate-to-child segment gets {@code T * fraction} and the parent-to-intermediate
     * segment gets {@code T * (1 - fraction)}. The values 0.0 and 1.0 mean "no intermediate node".
     */
    private double fraction = 1.0;
    /** Lazily created sum-product calculator; {@code null} until initialization succeeds. */
    private FactorGraphSumProduct<Taxon> sumProd = null;
    /** Number of rate categories ({@code potentials.categoryPriors.length}). */
    public final int nCategories;
    /** Number of character states ({@code potentials.stationaryDistributions[0].length}). */
    public final int nCharacters;
    /** Number of sites ({@code observations.nSites()}). */
    public final int nSites;

    /**
     * Returns the sum-product calculator for this model, creating and initializing it on first use.
     *
     * <p>The calculator is cached only after {@code init} returns. If construction fails (for
     * example, because the root is observed), the next call reports the failure again instead of
     * returning an uninitialized calculator.
     *
     * @return the initialized, cached sum-product calculator
     */
    public FactorGraphSumProduct<Taxon> getSumProductPosteriorCalculator() {
        if (this.sumProd == null) {
            FactorGraphSumProduct<Taxon> calculator = new FactorGraphSumProduct<Taxon>();
            calculator.init(this.getFactorGraph());
            this.sumProd = calculator;
        }
        return this.sumProd;
    }

    /**
     * Returns the posterior over characters at one site of a taxon, summed over all rate categories.
     *
     * @param t the taxon whose node posterior is read
     * @param site the site index
     * @return an array of length {@link #nCharacters}; entry {@code x} is the sum over categories
     *     {@code c} of the node posterior at {@code (site, c, x)}
     */
    @Override
    public double[] posteriorOverCharacters(Taxon t, int site) {
        MSCUnaryScaledFactor fullNodePost = this.getNodePosterior(t);
        double[] result = new double[this.nCharacters];
        // Category-outer order is kept so floating-point accumulation matches the original exactly.
        for (int cat = 0; cat < this.nCategories; ++cat) {
            for (int character = 0; character < this.nCharacters; ++character) {
                result[character] += fullNodePost.get(site, cat, character);
            }
        }
        return result;
    }

    /**
     * Returns the pairwise expectations between two adjacent nodes, as computed by
     * {@code MSCUtils.pairwiseExpectations} on this model's sum-product calculator.
     *
     * @param parent the first (upper) node of the edge
     * @param node the second (lower) node of the edge
     * @return the expectation array; indexed {@code [category][parentState][nodeState]} as used by
     *     {@link #observationErrorCount()} (TODO confirm against MSCUtils)
     */
    public double[][][] getPairwisePosterior(Taxon parent, Taxon node) {
        return MSCUtils.pairwiseExpectations(this.getSumProductPosteriorCalculator(), parent, node);
    }

    /**
     * Sums a node's posterior over all sites.
     *
     * @param node the taxon whose node posterior is read
     * @return a {@code [nCategories][nCharacters]} array; entry {@code [c][x]} is the sum over sites
     *     {@code s} of the node posterior at {@code (s, c, x)}
     */
    public double[][] getNodePosteriorSummingOverSites(Taxon node) {
        double[][] result = new double[this.nCategories][this.nCharacters];
        MSCUnaryScaledFactor factor = this.getNodePosterior(node);
        // Site-outer order is kept so floating-point accumulation matches the original exactly.
        for (int s = 0; s < this.nSites; ++s) {
            for (int cat = 0; cat < this.nCategories; ++cat) {
                for (int x = 0; x < this.nCharacters; ++x) {
                    result[cat][x] += factor.get(s, cat, x);
                }
            }
        }
        return result;
    }

    /**
     * Returns the node posterior ("moment") of a taxon from the sum-product calculator.
     *
     * @param taxon the node to query
     * @return the posterior factor, cast to {@link MSCUnaryScaledFactor}
     */
    public MSCUnaryScaledFactor getNodePosterior(Taxon taxon) {
        return (MSCUnaryScaledFactor)this.getSumProductPosteriorCalculator().moment(taxon);
    }

    /**
     * Creates a single-category, error-free model from one rate matrix.
     *
     * @param rt the rooted tree
     * @param rateMatrix the rate matrix of the single category
     * @param observations the observed data
     * @return a new, not-yet-built factor graph model
     */
    public static PhylogeneticFactorGraph createSingleCategoryFromStationaryProcess(RootedTree rt, double[][] rateMatrix, Observations observations) {
        return new PhylogeneticFactorGraph(rt, PhylogenyPotentials.createSingleCategoryErrorFreeFromStationaryProcess(rateMatrix), observations);
    }

    /**
     * Creates a multi-category model with a given observation error probability.
     *
     * @param rt the rooted tree
     * @param categoryPriors prior probability of each category
     * @param rateMatrices one rate matrix per category
     * @param observations the observed data
     * @param observationErrorPr probability that an observed character differs from the true one
     * @return a new, not-yet-built factor graph model
     */
    public static PhylogeneticFactorGraph createFromStationaryProcess(RootedTree rt, double[] categoryPriors, double[][][] rateMatrices, Observations observations, double observationErrorPr) {
        return new PhylogeneticFactorGraph(rt, PhylogenyPotentials.createFromStationaryProcess(categoryPriors, rateMatrices, observationErrorPr), observations);
    }

    /**
     * Creates a model. Nothing is built until the factor graph or calculator is first requested.
     *
     * @param rt the rooted tree
     * @param potentials the model potentials; category and character counts are read from them
     * @param observations the observed data; the site count is read from it
     */
    public PhylogeneticFactorGraph(RootedTree rt, PhylogenyPotentials potentials, Observations observations) {
        this.rootedTree = rt;
        this.potentials = potentials;
        this.observations = observations;
        this.nCategories = potentials.categoryPriors.length;
        this.nCharacters = potentials.stationaryDistributions[0].length;
        this.nSites = observations.nSites();
    }

    /**
     * Fills a unary factor with the root prior.
     *
     * @param factor the factor to overwrite; entry {@code (s, c, x)} is set to
     *     {@code categoryPriors[c] * stationaryDistributions[c][x]} for every site {@code s}
     */
    public void setInitialFactor(MSCUnaryScaledFactor factor) {
        for (int s = 0; s < this.nSites; ++s) {
            for (int c = 0; c < this.nCategories; ++c) {
                double catPr = this.potentials.categoryPriors[c];
                for (int x = 0; x < this.nCharacters; ++x) {
                    factor.set(s, c, x, catPr * this.potentials.stationaryDistributions[c][x]);
                }
            }
        }
    }

    /**
     * Returns the marginal transition matrix of one category over the branch above {@code child}.
     *
     * @param child the taxon whose branch length is used
     * @param category the rate-category index
     * @return {@code RateMtxUtils.marginalTransitionMtx(rateMatrices[category], branchLength(child))}
     */
    public double[][] getMarginalMtx(Taxon child, int category) {
        double T = this.rootedTree.branchLengths().get(child);
        return RateMtxUtils.marginalTransitionMtx(this.potentials.rateMatrices[category], T);
    }

    /**
     * Creates the synthetic node that splits the branch between two taxa.
     *
     * @param top the parent taxon
     * @param bot the child taxon
     * @return a taxon named {@code intermediate(top,bot)}
     */
    public static Taxon intermediateTaxon(Taxon top, Taxon bot) {
        return new Taxon("intermediate(" + top.toString() + "," + bot.toString() + ")");
    }

    /**
     * Returns the factor graph for this model, building and caching it on first use.
     *
     * <p>The graph is stored in the field only after it is completely built. If building fails, no
     * partially built graph is ever returned by a later call.
     *
     * @return the cached factor graph
     * @throws UnsupportedOperationException if the root taxon is among the observed taxa
     */
    public MSCFactorGraph<Taxon> getFactorGraph() {
        if (this.factorGraph == null) {
            this.factorGraph = this.buildFactorGraph();
        }
        return this.factorGraph;
    }

    /** Builds the complete factor graph: root prior, branch factors, then observation factors. */
    private MSCFactorGraph<Taxon> buildFactorGraph() {
        Taxon rootTaxon = this.rootedTree.topology().getContents();
        // Validate up front, before any (potentially expensive) matrix exponentials are computed.
        for (Taxon observedTaxon : this.observedTaxa()) {
            if (rootTaxon.equals(observedTaxon)) {
                throw new UnsupportedOperationException("This implementation does not support the root to be observed");
            }
        }
        MSCFactorGraph<Taxon> graph = new MSCFactorGraph<Taxon>(this.nSites, this.nCategories, this.nCharacters);
        this.setInitialFactor(graph.getMSCUnary(rootTaxon));
        for (int c = 0; c < this.nCategories; ++c) {
            this.addBranchFactors(graph, c);
        }
        this.addObservationFactors(graph);
        return graph;
    }

    /** Adds, for category {@code c}, the transition factor(s) of every non-root branch to {@code graph}. */
    private void addBranchFactors(MSCFactorGraph<Taxon> graph, int c) {
        for (Arbre<Taxon> node : this.rootedTree.topology().nodes()) {
            if (node.isRoot()) {
                continue;
            }
            Taxon currentTax = node.getContents();
            Taxon parentTax = node.getParent().getContents();
            if (this.fraction == 0.0 || this.fraction == 1.0) {
                // Degenerate split: a single factor covers the whole branch.
                graph.setBinary(c, parentTax, currentTax, this.getMarginalMtx(currentTax, c));
                continue;
            }
            Taxon intermediate = PhylogeneticFactorGraph.intermediateTaxon(parentTax, currentTax);
            double branchLength = this.rootedTree.branchLengths().get(currentTax);
            double[][] top2intM = RateMtxUtils.marginalTransitionMtx(this.potentials.rateMatrices[c], branchLength * (1.0 - this.fraction));
            double[][] int2botM = RateMtxUtils.marginalTransitionMtx(this.potentials.rateMatrices[c], branchLength * this.fraction);
            graph.setBinary(c, parentTax, intermediate, top2intM);
            graph.setBinary(c, intermediate, currentTax, int2botM);
        }
    }

    /** Links every observed taxon to its auxiliary observation node through the noise channel. */
    private void addObservationFactors(MSCFactorGraph<Taxon> graph) {
        double[][] noise = PhylogeneticFactorGraph.noiseTransition(this.nCharacters, this.potentials.observationErrorProbability);
        for (Taxon observedTaxon : this.observedTaxa()) {
            double[][] observed = this.observations.observations().get(observedTaxon);
            TaxonObservation auxiliary = PhylogeneticFactorGraph.auxiliaryTaxonObservationNode(observedTaxon);
            MSCUnaryScaledFactor observationFactor = graph.getMSCUnary(auxiliary);
            for (int s = 0; s < this.nSites; ++s) {
                for (int c = 0; c < this.nCategories; ++c) {
                    for (int x = 0; x < this.nCharacters; ++x) {
                        // The observation itself is category-independent; copy it into every category.
                        observationFactor.set(s, c, x, observed[s][x]);
                    }
                }
            }
            for (int c = 0; c < this.nCategories; ++c) {
                graph.setBinary(c, observedTaxon, auxiliary, noise);
            }
        }
    }

    /**
     * Returns the taxa that have observations.
     *
     * @return the key set of {@code observations.observations()} (a live view, not a copy)
     */
    public Collection<Taxon> observedTaxa() {
        return this.observations.observations().keySet();
    }

    /**
     * Creates the auxiliary node that carries the observation of a taxon.
     *
     * @param t the observed taxon
     * @return a {@link TaxonObservation} named after {@code t}
     */
    public static TaxonObservation auxiliaryTaxonObservationNode(Taxon t) {
        return new TaxonObservation(t.toString());
    }

    /**
     * Builds a symmetric noise channel.
     *
     * <p>The diagonal holds {@code 1 - observationErrorProbability}. Every off-diagonal entry holds
     * {@code observationErrorProbability / (nCharacters - 1)}, so each row sums to 1 when
     * {@code nCharacters > 1}. When {@code nCharacters == 1}, the off-diagonal value is never stored.
     *
     * @param nCharacters number of character states
     * @param observationErrorProbability probability of reading a wrong character
     * @return an {@code nCharacters x nCharacters} matrix
     */
    public static double[][] noiseTransition(int nCharacters, double observationErrorProbability) {
        double[][] result = new double[nCharacters][nCharacters];
        double correctReadPr = 1.0 - observationErrorProbability;
        double individualErrorPr = observationErrorProbability / ((double)nCharacters - 1.0);
        for (int x = 0; x < nCharacters; ++x) {
            for (int y = 0; y < nCharacters; ++y) {
                result[x][y] = x == y ? correctReadPr : individualErrorPr;
            }
        }
        return result;
    }

    /**
     * Creates a model with the same potentials and observations but a different tree.
     *
     * <p>The returned model uses the default stem fraction (1.0). It does not inherit this
     * instance's fraction.
     *
     * @param newRooted the replacement tree
     * @return a new, not-yet-built model
     */
    public PhylogeneticFactorGraph createWithNewTopology(RootedTree newRooted) {
        return new PhylogeneticFactorGraph(newRooted, this.potentials, this.observations);
    }

    /**
     * Creates a copy of this model that splits every branch with the given stem fraction.
     *
     * @param fraction value in {@code [0, 1]}; 0.0 and 1.0 disable the intermediate node
     * @return a new, not-yet-built model
     * @throws IllegalArgumentException if {@code fraction} is outside {@code [0, 1]} or is NaN
     */
    public PhylogeneticFactorGraph createWithStemFraction(double fraction) {
        // Written as a negated range test so that NaN (for which every comparison is false) is rejected.
        if (!(fraction >= 0.0 && fraction <= 1.0)) {
            throw new IllegalArgumentException("Stem fraction must be in [0, 1], got: " + fraction);
        }
        PhylogeneticFactorGraph result = new PhylogeneticFactorGraph(this.rootedTree, this.potentials, this.observations);
        result.fraction = fraction;
        return result;
    }

    /**
     * Builds a binary factor from the marginal transition matrices of every category over time {@code d}.
     *
     * @param d the elapsed time passed to {@code RateMtxUtils.marginalTransitionMtx}
     * @param reverse when true, each category's matrix is transposed before being stored
     * @return a new factor whose entry {@code (c, c0, c1)} is {@code marginal_c[c0][c1]}
     *     (or {@code marginal_c[c1][c0]} when reversed)
     */
    public MSCBinaryScaledFactor getBinaryFactor(double d, boolean reverse) {
        MSCBinaryScaledFactor currentFactor = new MSCBinaryScaledFactor(this.nSites, this.nCategories, this.nCharacters);
        for (int c = 0; c < this.nCategories; ++c) {
            double[][] marginal = RateMtxUtils.marginalTransitionMtx(this.potentials.rateMatrices[c], d);
            for (int c0 = 0; c0 < this.nCharacters; ++c0) {
                for (int c1 = 0; c1 < this.nCharacters; ++c1) {
                    currentFactor.set(c, c0, c1, reverse ? marginal[c1][c0] : marginal[c0][c1]);
                }
            }
        }
        return currentFactor;
    }

    /**
     * Returns the stem fraction used to split branches.
     *
     * @return the current stem fraction (1.0 unless set through {@link #createWithStemFraction(double)})
     */
    public double getStemFraction() {
        return this.fraction;
    }

    /**
     * Returns the {@code logZ()} value of the sum-product calculator. Presumably this is the log
     * normalization constant, i.e. the data log-likelihood given tree and parameters (TODO confirm).
     *
     * @return {@code getSumProductPosteriorCalculator().logZ()}
     */
    public double dataLogLikelihoodGivenTreeAndParameters() {
        return this.getSumProductPosteriorCalculator().logZ();
    }

    /**
     * Sums, over every observed taxon and category, the off-diagonal entries ({@code x != y}) of
     * the pairwise posterior between the taxon and its observation node. This is presumably the
     * expected number of observation errors.
     *
     * @return the summed off-diagonal posterior mass
     */
    public double observationErrorCount() {
        double result = 0.0;
        for (Taxon observedTaxon : this.observedTaxa()) {
            TaxonObservation auxiliary = PhylogeneticFactorGraph.auxiliaryTaxonObservationNode(observedTaxon);
            double[][][] posteriors = this.getPairwisePosterior(observedTaxon, auxiliary);
            for (int c = 0; c < this.nCategories; ++c) {
                for (int x = 0; x < this.nCharacters; ++x) {
                    for (int y = 0; y < this.nCharacters; ++y) {
                        if (x == y) continue;
                        result += posteriors[c][x][y];
                    }
                }
            }
        }
        return result;
    }

    /**
     * Ad-hoc timing benchmark: builds the model and computes the likelihood 5000 times, printing
     * each duration and then the total, all in milliseconds.
     *
     * @param args optional overrides: {@code args[0]} = alignment file, {@code args[1]} = tree file
     */
    public static void main(String[] args) {
        PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions opts = new PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions();
        opts.alignmentFile = args.length > 0 ? args[0] : "/Users/bouchard/Documents/data/utcs/23S.E/R0/cleaned.alignment.fasta";
        opts.treeFile = args.length > 1 ? args[1] : "/Users/bouchard/Documents/data/utcs/23S.E.raxml.nwk";
        PhylogeneticHeldoutDataset phyloData = PhylogeneticHeldoutDataset.loadData(opts);
        PhylogenyPotentials potentials = PhylogenyPotentials.createSingleCategoryErrorFreeFromStationaryProcess(RateMatrixLoader.k2p());
        long sum = 0L;
        for (int i = 0; i < 5000; ++i) {
            long start = System.currentTimeMillis();
            PhylogeneticFactorGraph factorGraph = new PhylogeneticFactorGraph(phyloData.rootedTree, potentials, phyloData.obs);
            factorGraph.dataLogLikelihoodGivenTreeAndParameters();
            long delta = System.currentTimeMillis() - start;
            sum += delta;
            System.out.println(delta);
        }
        // currentTimeMillis() differences are milliseconds, not microseconds.
        System.out.println("ms: " + sum);
    }

    /** Auxiliary node carrying the observed sequence of a taxon; named {@code OBSERVED_SEQUENCE:<name>}. */
    public static class TaxonObservation
    extends Taxon {
        private static final long serialVersionUID = 1L;

        /**
         * Creates the observation node for the given taxon name.
         *
         * @param string the underlying taxon's name
         */
        public TaxonObservation(String string) {
            super("OBSERVED_SEQUENCE:" + string);
        }
    }
}

