/*
 * Decompiled with CFR 0.152.
 */
package conifer.spr;

import conifer.data.DataModel;
import conifer.data.DataOptions;
import conifer.data.PhylogeneticData;
import conifer.ml.tests.TestRealData;
import conifer.multicategories.PhylogeneticFactorGraph;
import fenchel.factor.UnaryFactor;
import fenchel.factor.multisitecat.MSCBinaryScaledFactor;
import fig.basic.NumUtils;
import fig.basic.UnorderedPair;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.RateMatrixLoader;
import nuts.math.MutableGraph;
import nuts.util.Arbre;
import nuts.util.Counter;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import pty.Observations;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.io.Dataset;

public class SPROperator {
    /** Assigns an integer index to every candidate regraft added via {@link #addRegrafts(double)}. */
    public final Indexer<Regraft> indexer = new Indexer<Regraft>();
    /** Unnormalized log-probabilities, one per indexed regraft, stored in index order. */
    private final List<Double> unnormalizedLogPrs = new ArrayList<Double>();
    /** Cache of the normalized probabilities; {@code null} until they are first requested. */
    private double[] normalizedPrs = null;
    /** Candidate stem lengths tried at every attachment point. */
    private final double[] stemLens;
    /** One pre-computed factor per entry of {@link #stemLens}, in the same order. */
    private final List<UnaryFactor> preConvolvedStems;
    /** Factor graph handed to the constructor, i.e. the tree before any leaf is attached. */
    public final PhylogeneticFactorGraph factorGraphBeforeRegraft;
    /** Taxon of the leaf being attached. */
    private final Taxon additionalLeafName;

    /**
     * Creates an operator that scores the attachment of one additional leaf onto an existing tree.
     *
     * @param factorGraphBeforeRegraft factor graph of the tree before the leaf is attached
     * @param additionalLeaf unary factor representing the leaf to attach
     * @param additionalLeafName taxon of the leaf to attach; must not already be among the
     *        observations of {@code factorGraphBeforeRegraft}
     * @param stemLens candidate stem lengths; the array is copied
     * @throws IllegalArgumentException if {@code additionalLeafName} is already observed in
     *         {@code factorGraphBeforeRegraft}
     */
    public SPROperator(PhylogeneticFactorGraph factorGraphBeforeRegraft, UnaryFactor additionalLeaf, Taxon additionalLeafName, double[] stemLens) {
        // Validate before doing any factor computation so a bad argument fails fast.
        if (factorGraphBeforeRegraft.observations.observations().containsKey(additionalLeafName)) {
            throw new IllegalArgumentException("Taxon " + additionalLeafName + " is already observed in the factor graph before regraft");
        }
        this.additionalLeafName = additionalLeafName;
        this.factorGraphBeforeRegraft = factorGraphBeforeRegraft;
        // Copy so later changes to the caller's array cannot desynchronize it from preConvolvedStems.
        this.stemLens = stemLens.clone();
        this.preConvolvedStems = SPROperator.preConvolve(additionalLeaf, this.stemLens, factorGraphBeforeRegraft);
    }

    /**
     * Scores every (non-root edge, stem length) combination at the given position along the edge
     * and records each one as a candidate regraft.
     *
     * <p>May be called several times with different fractions, but only before the normalized
     * probabilities have been computed.
     *
     * @param fractionFromBottom position of the attachment point along each edge, passed to
     *        {@code createWithStemFraction} and stored in each {@link Regraft}
     * @throws IllegalStateException if the normalized probabilities were already computed, or if
     *         the indexer and the score list fall out of step
     */
    public void addRegrafts(double fractionFromBottom) {
        if (this.normalizedPrs != null) {
            throw new IllegalStateException("Cannot add regrafts after the probabilities have been normalized");
        }
        // Named differently from the field on purpose: this is the derived graph for this fraction.
        PhylogeneticFactorGraph graphAtFraction = this.factorGraphBeforeRegraft.createWithStemFraction(fractionFromBottom);
        for (Arbre<Taxon> subtree : graphAtFraction.rootedTree.topology().nodes()) {
            if (subtree.isRoot()) {
                // The root has no parent edge to attach onto.
                continue;
            }
            Taxon top = subtree.getParent().getContents();
            Taxon bot = subtree.getContents();
            for (int i = 0; i < this.stemLens.length; ++i) {
                double stemLen = this.stemLens[i];
                UnaryFactor factor = this.preConvolvedStems.get(i);
                double unnormalizedLogPr = SPROperator.evaluateRegraftLogProbability(graphAtFraction, factor, top, bot);
                this.unnormalizedLogPrs.add(unnormalizedLogPr);
                Regraft candidate = new Regraft(top, bot, fractionFromBottom, stemLen, graphAtFraction.rootedTree, this.additionalLeafName);
                this.indexer.addToIndex(new Regraft[]{candidate});
                // Index i of the indexer must always correspond to entry i of the score list.
                if (this.indexer.size() != this.unnormalizedLogPrs.size()) {
                    throw new IllegalStateException("Indexer size " + this.indexer.size() + " does not match number of scores " + this.unnormalizedLogPrs.size());
                }
            }
        }
    }

    /**
     * Returns the probabilities of all recorded regrafts, computing and caching them on first use.
     *
     * <p>The stored log-scores are copied into an array that is then passed to
     * {@code NumUtils.expNormalize}. The returned array is the internal cache itself, not a copy.
     *
     * @return probabilities indexed like {@link #indexer}
     */
    public double[] getNormalizedProbabilities() {
        if (this.normalizedPrs == null) {
            int count = this.unnormalizedLogPrs.size();
            this.normalizedPrs = new double[count];
            int slot = 0;
            for (Double logPr : this.unnormalizedLogPrs) {
                this.normalizedPrs[slot++] = logPr;
            }
            NumUtils.expNormalize(this.normalizedPrs);
        }
        return this.normalizedPrs;
    }

    /**
     * Packages the normalized probabilities as a counter keyed by regraft.
     *
     * @return a new counter holding, for every index, the regraft at that index with its probability
     */
    public Counter<Regraft> getNormalizedPrCounter() {
        double[] probabilities = this.getNormalizedProbabilities();
        Counter<Regraft> counter = new Counter<Regraft>();
        int index = 0;
        for (double probability : probabilities) {
            counter.setCount(this.getRegraft(index), probability);
            ++index;
        }
        return counter;
    }

    /**
     * Draws a regraft index using the normalized probabilities as the multinomial weights.
     *
     * @param rand source of randomness
     * @return the sampled index, usable with {@link #getRegraft(int)}
     */
    public int sample(Random rand) {
        double[] distribution = this.getNormalizedProbabilities();
        return SampleUtils.sampleMultinomial(rand, distribution);
    }

    /**
     * Finds the argmax of the normalized probabilities.
     *
     * @return the index reported by {@code MathUtils.argmax}, usable with {@link #getRegraft(int)}
     */
    public int argmax() {
        double[] distribution = this.getNormalizedProbabilities();
        return MathUtils.argmax(distribution);
    }

    /**
     * Looks up the regraft stored under the given index.
     *
     * @param i index as returned by {@link #sample(Random)} or {@link #argmax()}
     * @return the regraft the indexer holds at {@code i}
     */
    public Regraft getRegraft(int i) {
        Regraft regraft = this.indexer.i2o(i);
        return regraft;
    }

    /**
     * Builds one factor per stem length by marginalizing the graph's binary factor for that
     * length against the leaf factor.
     *
     * @param additionalLeaf factor of the leaf to attach
     * @param stemLens candidate stem lengths
     * @param fg graph supplying the binary factors (requested with flag {@code true};
     *        NOTE(review): meaning of the flag is not visible here)
     * @return the resulting factors, in the same order as {@code stemLens}
     */
    private static List<UnaryFactor> preConvolve(UnaryFactor additionalLeaf, double[] stemLens, PhylogeneticFactorGraph fg) {
        List<UnaryFactor> convolved = new ArrayList<UnaryFactor>();
        for (double stemLen : stemLens) {
            MSCBinaryScaledFactor stemFactor = fg.getBinaryFactor(stemLen, true);
            UnaryFactor folded = stemFactor.marginalize(Collections.singletonList(additionalLeaf));
            convolved.add(folded);
        }
        return convolved;
    }

    /**
     * Scores attaching {@code toRegraft} on the edge between {@code top} and {@code bot}.
     *
     * <p>Takes the posterior moment at the intermediate taxon of that edge, multiplies it by the
     * factor being grafted, and returns the log-norm of the product.
     *
     * @param factorGraph graph that contains the intermediate node for the edge
     * @param toRegraft factor of the subtree or leaf to attach
     * @param top upper endpoint of the edge
     * @param bot lower endpoint of the edge
     * @return the unnormalized log-score of this attachment
     */
    public static final double evaluateRegraftLogProbability(PhylogeneticFactorGraph factorGraph, UnaryFactor toRegraft, Taxon top, Taxon bot) {
        Taxon attachmentPoint = PhylogeneticFactorGraph.intermediateTaxon(top, bot);
        double logScore = factorGraph.getSumProductPosteriorCalculator()
                .moment(attachmentPoint)
                .multiply(Collections.singletonList(toRegraft))
                .logNorm();
        return logScore;
    }

    /**
     * Adds the undirected edge {@code t1 -- t2} to the graph and records its length.
     *
     * @param t1 one endpoint
     * @param t2 the other endpoint
     * @param len branch length to store for the edge
     * @param topo graph that receives the edge
     * @param branchLengths map that receives the edge's length
     */
    public static void addEdge(Taxon t1, Taxon t2, double len, MutableGraph<Taxon> topo, Map<UnorderedPair<Taxon, Taxon>, Double> branchLengths) {
        UnorderedPair<Taxon, Taxon> endpoints = new UnorderedPair<Taxon, Taxon>(t1, t2);
        topo.addEdge(endpoints);
        branchLengths.put(endpoints, Double.valueOf(len));
    }

    /**
     * Demo driver: generates a 4-taxon dataset, holds one taxon out, and uses an
     * {@code SPROperator} to graft it back onto a star tree of the other three.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Generate a synthetic dataset with four taxa and print the generating tree.
        DataOptions syntheticOptions = new DataOptions();
        syntheticOptions.generatingTreeOptions.nTaxa = 4;
        PhylogeneticData data = DataModel.GENERATED.loadDataset(syntheticOptions);
        RootedTree trueTree = data.getPhylogeny().getRooted();
        System.out.println("true tree:\n" + trueTree);
        Dataset obs = (Dataset)((Object)data.getObservedTaxonIndexData());
        double[][] rateMatrix = RateMatrixLoader.k2p();
        // Fixed seed: the shuffle below, and hence the held-out taxon, depends only on this seed.
        Random rand = new Random(1L);
        PhylogeneticFactorGraph pfg = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(trueTree, rateMatrix, obs);
        System.out.println("true tree ll: " + pfg.getSumProductPosteriorCalculator().logZ());
        // A star tree over all observed taxa is only used to extract the per-leaf factors.
        RootedTree bigStar = SPROperator.starInitialization(obs.observations().keySet(), 1.0);
        PhylogeneticFactorGraph starFG = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(bigStar, rateMatrix, obs);
        Map<Taxon, UnaryFactor> leafFactors = SPROperator.leafFactors(starFG);
        pfg = null;
        // Sort before shuffling so the seeded shuffle starts from a fixed order.
        List<Taxon> taxa = new ArrayList<Taxon>(obs.observations().keySet());
        Collections.sort(taxa);
        Collections.shuffle(taxa, rand);
        // Hold out the fourth taxon; the first three form the starting tree.
        Taxon outside = (Taxon)taxa.get(3);
        taxa = taxa.subList(0, 3);
        RootedTree rt = SPROperator.starInitialization(taxa, 0.5);
        System.out.println("star init:\n" + rt);
        TestRealData.SimpleObservations restricted = SPROperator.restrictedObservations(obs, new HashSet<Taxon>(taxa));
        PhylogeneticFactorGraph beforeGraftFC = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(rt, rateMatrix, restricted);
        // Score attaching the held-out leaf at the midpoint of every edge, for five stem lengths.
        double[] stemLens = new double[]{0.05, 0.1, 0.2, 0.4, 0.8};
        SPROperator spr = new SPROperator(beforeGraftFC, leafFactors.get(outside), outside, stemLens);
        spr.addRegrafts(0.5);
        // Print the tree of the best-scoring regraft, then the raw log-scores.
        UnrootedTree ut = spr.getRegraft(spr.argmax()).treeAfterRegraft();
        System.out.println(ut);
        System.out.println(spr.unnormalizedLogPrs);
        // Re-root at a random non-terminal edge and print logZ of the regrafted tree on the full data.
        // NOTE(review): this uses a fresh unseeded Random rather than `rand`, so the chosen edge varies between runs.
        PhylogeneticFactorGraph regraftedFG = PhylogeneticFactorGraph.createSingleCategoryFromStationaryProcess(ut.reRootAtNode(ut.randomNonTerminalEdge(new Random()).getFirst()), rateMatrix, obs);
        System.out.println(regraftedFG.getSumProductPosteriorCalculator().logZ());
    }

    /**
     * Builds a star tree: a single internal node named {@code "internal_0"} with every given
     * taxon attached directly to it as a leaf.
     *
     * @param leaves taxa to attach as leaves
     * @param bl branch length assigned to every leaf
     * @return the rooted star tree
     */
    public static RootedTree starInitialization(Collection<Taxon> leaves, double bl) {
        // Parameterized (the previous raw ArrayList produced unchecked warnings).
        ArrayList<Arbre<Taxon>> children = new ArrayList<Arbre<Taxon>>();
        HashMap<Taxon, Double> bls = new HashMap<Taxon, Double>();
        for (Taxon t : leaves) {
            children.add(Arbre.arbre(t));
            bls.put(t, bl);
        }
        Arbre<Taxon> topo = new Arbre<Taxon>(new Taxon("internal_0"), children);
        return RootedTree.Util.create(topo, bls);
    }

    /**
     * Collects, for every observed taxon, the message sent from its auxiliary observation node
     * to the taxon itself.
     *
     * @param pfg factor graph whose observations and posterior calculator are queried
     * @return a new map from each observed taxon to that message
     */
    public static Map<Taxon, UnaryFactor> leafFactors(PhylogeneticFactorGraph pfg) {
        Map<Taxon, UnaryFactor> factorsByLeaf = new HashMap<Taxon, UnaryFactor>();
        for (Taxon leaf : pfg.observations.observations().keySet()) {
            UnaryFactor message = pfg.getSumProductPosteriorCalculator().getMessage(PhylogeneticFactorGraph.auxiliaryTaxonObservationNode(leaf), leaf);
            factorsByLeaf.put(leaf, message);
        }
        return factorsByLeaf;
    }

    /**
     * Copies the observations of the taxa contained in {@code restriction} into a new
     * observation set.
     *
     * @param obs source observations
     * @param restriction taxa to keep; taxa absent from {@code obs} are ignored
     * @return observations limited to the taxa present in both {@code obs} and {@code restriction}
     */
    public static TestRealData.SimpleObservations restrictedObservations(Observations obs, Set<Taxon> restriction) {
        HashMap<Taxon, double[][]> kept = new HashMap<Taxon, double[][]>();
        for (Taxon taxon : obs.observations().keySet()) {
            if (restriction.contains(taxon)) {
                kept.put(taxon, obs.observations().get(taxon));
            }
        }
        return new TestRealData.SimpleObservations(kept);
    }

    /**
     * Returns the raw scores of all recorded regrafts.
     *
     * <p>Despite the method name, the values are unnormalized <em>log</em>-probabilities. The
     * result is a read-only view: it reflects later calls to {@link #addRegrafts(double)} but
     * cannot be modified, because the list must stay index-aligned with {@link #indexer}.
     *
     * @return unmodifiable view of the log-scores, in indexer order
     */
    public List<Double> getUnnormalizedPrs() {
        return Collections.unmodifiableList(this.unnormalizedLogPrs);
    }

    /**
     * One candidate attachment of a leaf onto an edge of a rooted tree: which edge, where along
     * the edge, and with what stem length.
     */
    public static class Regraft {
        /** Upper endpoint of the edge being split. */
        public final Taxon top;
        /** Lower endpoint of the edge being split. */
        public final Taxon bot;
        /** Fraction of the original branch length placed between {@link #bot} and the new node. */
        public final double fractionFromBot;
        /** Length of the new branch leading to the grafted leaf. */
        public final double stemLen;
        /** Tree before the leaf is attached. */
        public final RootedTree beforeRG;
        /** Taxon of the leaf being attached. */
        public final Taxon graftedLeafName;

        /**
         * @param top upper endpoint of the edge to split
         * @param bot lower endpoint of the edge to split
         * @param fractionFromBot fraction of the edge's length kept below the new node
         * @param stemLen length of the branch to the grafted leaf
         * @param beforeRG tree before the regraft
         * @param graftedTaxon taxon of the leaf to attach
         */
        public Regraft(Taxon top, Taxon bot, double fractionFromBot, double stemLen, RootedTree beforeRG, Taxon graftedTaxon) {
            this.top = top;
            this.bot = bot;
            this.fractionFromBot = fractionFromBot;
            this.stemLen = stemLen;
            this.beforeRG = beforeRG;
            this.graftedLeafName = graftedTaxon;
        }

        @Override
        public String toString() {
            return "Regraft [top=" + this.top + ", bot=" + this.bot + ", fractionFromBot=" + this.fractionFromBot + ", stemLen=" + this.stemLen + "]";
        }

        /**
         * Builds the unrooted tree obtained by applying this regraft to {@link #beforeRG}.
         *
         * <p>The edge above {@link #bot} is split by a new internal node named
         * {@code "internal_" + nInternalNodes}; the grafted leaf hangs off that node with length
         * {@link #stemLen}. Every other edge is copied with its original length. The edge to split
         * is identified by its lower endpoint only.
         *
         * @return the tree after the regraft
         * @throws IllegalStateException if the tree already contains a node with the new
         *         internal node's name
         */
        public UnrootedTree treeAfterRegraft() {
            Taxon extra = new Taxon("internal_" + this.beforeRG.topology().nInternalNodes());
            MutableGraph<Taxon> topo = new MutableGraph<Taxon>();
            HashMap<UnorderedPair<Taxon, Taxon>, Double> branchLengths = new HashMap<UnorderedPair<Taxon, Taxon>, Double>();
            for (Arbre<Taxon> subtree : this.beforeRG.topology().nodes()) {
                Taxon bot = subtree.getContents();
                if (extra.equals(bot)) {
                    throw new IllegalStateException("Tree already contains a node named " + extra + "; cannot create the new internal node");
                }
                if (subtree.isRoot()) {
                    // The root has no parent edge to copy.
                    continue;
                }
                Taxon top = subtree.getParent().getContents();
                double oldBL = this.beforeRG.branchLengths().get(bot);
                if (bot.equals(this.bot)) {
                    // Split the edge at the new node and hang the grafted leaf off it.
                    SPROperator.addEdge(bot, extra, oldBL * this.fractionFromBot, topo, branchLengths);
                    SPROperator.addEdge(extra, top, oldBL * (1.0 - this.fractionFromBot), topo, branchLengths);
                    SPROperator.addEdge(extra, this.graftedLeafName, this.stemLen, topo, branchLengths);
                } else {
                    SPROperator.addEdge(top, bot, oldBL, topo, branchLengths);
                }
            }
            return new UnrootedTree(topo, branchLengths);
        }
    }
}

