/*
 * Decompiled with CFR 0.152.
 */
package pty.mcmc;

import conifer.Phylogeny;
import conifer.particle.PhyloParticle;
import conifer.particle.PhyloParticleInitContext;
import fig.basic.NumUtils;
import fig.basic.Pair;
import goblin.Taxon;
import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import ma.SequenceType;
import monaco.Density;
import nuts.util.CollUtils;
import pty.UnrootedTree;
import pty.io.Dataset;
import pty.io.HGDPDataset;
import pty.smc.PartialCoalescentState;
import pty.smc.models.BrownianModel;
import pty.smc.models.BrownianModelCalculator;
import pty.smc.models.CTMC;
import pty.smc.models.FastDiscreteModelCalculator;
import pty.smc.models.LikelihoodModelCalculator;

/**
 * Phylogenetic particle/state wrapping an unrooted (non-clock) tree together with the
 * per-leaf likelihood calculators used to score it.
 *
 * <p>The log likelihood is computed lazily by {@link #logLikelihood()} and cached in the
 * instance (the cache is not synchronized). The log prior is evaluated once, when a prior
 * density is supplied either to the full constructor or to {@link #init}.
 */
public final class UnrootedTreeState
implements PhyloParticle,
Serializable {
    private static final long serialVersionUID = 1L;
    /** Prior density over particles; {@code null} for states built by the legacy factories. */
    private Density<PhyloParticle> prior;
    /** Log prior of this state; stays 0.0 when no prior density was supplied. */
    private double logPrior;
    private final int generationIndex;
    private UnrootedTree t;
    /** Leaf taxon -> likelihood calculator holding that leaf's data. */
    private Map<Taxon, LikelihoodModelCalculator> likelihoodModels;
    /** Cached result of {@link #logLikelihood()}; NaN means "not computed yet". */
    private double loglikelihood = Double.NaN;

    /** Creates an uninitialized state; {@link #init} must be called before it is used. */
    public UnrootedTreeState() {
        this.generationIndex = 0;
    }

    /**
     * Creates a fully specified state.
     *
     * @param t the unrooted tree
     * @param likelihoodModels leaf likelihood calculators keyed by taxon (stored by reference)
     * @param prior prior density, or {@code null} to leave the log prior at 0.0
     * @param generationIndex value reported by {@link #generationIndex()}
     */
    public UnrootedTreeState(UnrootedTree t, Map<Taxon, LikelihoodModelCalculator> likelihoodModels, Density<PhyloParticle> prior, int generationIndex) {
        this.generationIndex = generationIndex;
        this.t = t;
        this.likelihoodModels = likelihoodModels;
        this.prior = prior;
        if (prior != null) {
            // Evaluated last so the density sees a fully populated state.
            this.logPrior = prior.logDensity(this);
        }
    }

    /**
     * Initializes a state created with the no-argument constructor.
     *
     * @param calculators (taxon, calculator) pairs for the observed leaves
     * @param prior prior density used to compute {@link #getLogPrior()}
     * @param context supplies the initial tree; its unrooted version is stored
     * @throws RuntimeException if likelihood models are already set, i.e. the state was
     *     built through the full constructor or a legacy factory
     */
    @Override
    public void init(List<Pair<Taxon, LikelihoodModelCalculator>> calculators, Density<PhyloParticle> prior, PhyloParticleInitContext context) {
        if (this.likelihoodModels != null) {
            throw new RuntimeException("Looks like trying to init a UTS constructed from legacy constructors");
        }
        this.likelihoodModels = CollUtils.map();
        for (Pair<Taxon, LikelihoodModelCalculator> calculator : calculators) {
            this.likelihoodModels.put(calculator.getFirst(), calculator.getSecond());
        }
        this.t = context.getTree().getUnrooted();
        this.prior = prior;
        this.logPrior = prior.logDensity(this);
    }

    @Override
    public Phylogeny getPhylogeny() {
        return this.t;
    }

    /** Same as {@link #logLikelihood()}. */
    @Override
    public double getLogLikelihood() {
        return this.logLikelihood();
    }

    @Override
    public double getLogPrior() {
        return this.logPrior;
    }

    @Override
    public int generationIndex() {
        return this.generationIndex;
    }

    /** Returns the wrapped tree (same object as {@link #getNonClockTree()}). */
    public UnrootedTree getUnrootedTree() {
        return this.t;
    }

    /** Returns the wrapped tree (same object as {@link #getUnrootedTree()}). */
    public UnrootedTree getNonClockTree() {
        return this.t;
    }

    /**
     * Legacy factory: builds leaf models from {@code data} with the given CTMC and no prior.
     *
     * @deprecated legacy construction path; see {@link #init}
     */
    @Deprecated
    public static UnrootedTreeState initFastState(UnrootedTree t, Dataset data, CTMC ctmc) {
        return new UnrootedTreeState(t, UnrootedTreeState.fastLeafModels(data.observations(), ctmc), null, 0);
    }

    /**
     * Legacy factory: builds leaf models from {@code data} with the given CTMC and prior.
     *
     * @deprecated legacy construction path; see {@link #init}
     */
    @Deprecated
    public static UnrootedTreeState initFastState(UnrootedTree t, Dataset data, CTMC ctmc, Density<PhyloParticle> prior) {
        return new UnrootedTreeState(t, UnrootedTreeState.fastLeafModels(data.observations(), ctmc), prior, 0);
    }

    /**
     * Legacy factory: builds leaf models from raw observations with the given CTMC and no prior.
     *
     * @deprecated legacy construction path; see {@link #init}
     */
    @Deprecated
    public static UnrootedTreeState initFastState(UnrootedTree t, Map<Taxon, double[][]> observations, CTMC ctmc) {
        return new UnrootedTreeState(t, UnrootedTreeState.fastLeafModels(observations, ctmc), null, 0);
    }

    /** Wraps each taxon's observation matrix in a {@link FastDiscreteModelCalculator}. */
    private static Map<Taxon, LikelihoodModelCalculator> fastLeafModels(Map<Taxon, double[][]> observations, CTMC ctmc) {
        Map<Taxon, LikelihoodModelCalculator> leaves = CollUtils.map();
        for (Map.Entry<Taxon, double[][]> entry : observations.entrySet()) {
            leaves.put(entry.getKey(), FastDiscreteModelCalculator.observation(ctmc, entry.getValue(), false));
        }
        return leaves;
    }

    /**
     * Returns a copy whose tree is duplicated via the {@code UnrootedTree} copy constructor.
     * Likelihood models and prior are shared with this state; the log likelihood cache is not
     * copied and will be recomputed lazily.
     */
    public UnrootedTreeState deepClone() {
        UnrootedTree copiedTree = new UnrootedTree(this.t);
        return new UnrootedTreeState(copiedTree, this.likelihoodModels, this.prior, this.generationIndex);
    }

    /** Returns a new state on {@code newTree}, sharing models/prior, with generation index + 1. */
    public UnrootedTreeState copyAndChange(UnrootedTree newTree) {
        return new UnrootedTreeState(newTree, this.likelihoodModels, this.prior, this.generationIndex + 1);
    }

    /**
     * Loads an alignment and builds a state on {@code tree} under a simple CTMC.
     *
     * @param tree the unrooted tree
     * @param align alignment file
     * @param st sequence type; DNA and RNA use {@code dnaCTMC}, PROTEIN uses {@code proteinCTMC}
     * @throws IllegalArgumentException for any other sequence type
     */
    public static UnrootedTreeState fromAlignment(UnrootedTree tree, File align, SequenceType st) {
        Dataset data = Dataset.DatasetUtils.fromAlignment(align, st);
        CTMC ctmc;
        if (st == SequenceType.RNA || st == SequenceType.DNA) {
            ctmc = CTMC.SimpleCTMC.dnaCTMC(data.nSites());
        } else if (st == SequenceType.PROTEIN) {
            ctmc = CTMC.SimpleCTMC.proteinCTMC(data.nSites());
        } else {
            throw new IllegalArgumentException("Unsupported sequence type: " + st);
        }
        return new UnrootedTreeState(tree, UnrootedTreeState.fastLeafModels(data.observations(), ctmc), null, 0);
    }

    /** Converts a coalescent state into an unrooted state with generation index 0. */
    public static UnrootedTreeState fromPartialCoalescentState(PartialCoalescentState pcs) {
        return UnrootedTreeState.fromPartialCoalescentState(pcs, 0);
    }

    /** Converts a coalescent state into an unrooted state, reusing its leaf models and prior. */
    public static UnrootedTreeState fromPartialCoalescentState(PartialCoalescentState pcs, int generationIndex) {
        return new UnrootedTreeState(UnrootedTree.fromRooted(pcs.getFullCoalescentState()), pcs.getLeafLikelihoodModels(), pcs.getPriorDensity(), generationIndex);
    }

    /** @deprecated legacy construction path; see {@link #init} */
    @Deprecated
    public static UnrootedTreeState fromBrownianMotion(UnrootedTree nct, Dataset data, BrownianModel bm) {
        PartialCoalescentState pcs = PartialCoalescentState.initState(data, bm, false);
        return new UnrootedTreeState(nct, pcs.getLeafLikelihoodModels(), null, 0);
    }

    /** @deprecated legacy construction path; see {@link #init} */
    @Deprecated
    public static UnrootedTreeState fromCTMC(UnrootedTree nct, Dataset data, CTMC ctmc) {
        PartialCoalescentState pcs = PartialCoalescentState.initState(data, ctmc);
        return new UnrootedTreeState(nct, pcs.getLeafLikelihoodModels(), null, 0);
    }

    /** Returns the log likelihood, computing and caching it on first call. */
    public double logLikelihood() {
        if (Double.isNaN(this.loglikelihood)) {
            this.loglikelihood = UnrootedTreeState.computeLogLikelihood(this.t, this.likelihoodModels);
        }
        return this.loglikelihood;
    }

    /**
     * Computes the log likelihood of {@code t} given its leaf likelihood calculators.
     *
     * <p>The smallest taxon (by the natural ordering of {@code Taxon}) is set aside. The other
     * leaves form a fringe whose nodes are merged pairwise: two fringe nodes sharing a common
     * neighbour are combined into a single message placed at that neighbour. When one fringe
     * node remains, it is combined with the set-aside leaf, splitting the connecting branch
     * length in half on each side.
     *
     * @param t the tree; its leaves are assumed to be exactly the keys of {@code likelihoodModels}
     * @param likelihoodModels leaf calculators keyed by taxon
     * @return the log likelihood
     * @throws IllegalArgumentException if fewer than two taxa are given
     */
    public static double computeLogLikelihood(UnrootedTree t, Map<Taxon, LikelihoodModelCalculator> likelihoodModels) {
        if (likelihoodModels.size() < 2) {
            throw new IllegalArgumentException("Need at least two taxa to compute a tree likelihood, got " + likelihoodModels.size());
        }
        List<Taxon> listOfTaxa = CollUtils.list(likelihoodModels.keySet());
        Collections.sort(listOfTaxa);
        // Insertion-ordered so the merge order (and hence the floating point result) is deterministic.
        LinkedHashMap<Taxon, LikelihoodModelCalculator> fringe = new LinkedHashMap<Taxon, LikelihoodModelCalculator>();
        for (Taxon taxon : listOfTaxa) {
            fringe.put(taxon, likelihoodModels.get(taxon));
        }
        Taxon aLeaf = listOfTaxa.get(0);
        fringe.remove(aLeaf);
        Set<Taxon> closedList = new HashSet<Taxon>();
        while (fringe.size() > 1) {
            Pair<Taxon, Taxon> siblings = UnrootedTreeState.peek(t, fringe, closedList);
            Taxon l1 = siblings.getFirst();
            Taxon l2 = siblings.getSecond();
            Taxon parent = UnrootedTreeState.parent(t, l1, l2);
            double bl1 = t.branchLength(l1, parent);
            double bl2 = t.branchLength(l2, parent);
            LikelihoodModelCalculator lm1 = fringe.get(l1);
            LikelihoodModelCalculator lm2 = fringe.get(l2);
            LikelihoodModelCalculator combined = lm1.combine(lm1, lm2, bl1, bl2, false);
            fringe.remove(l1);
            fringe.remove(l2);
            fringe.put(parent, combined);
            closedList.add(l1);
            closedList.add(l2);
        }
        Taxon last = fringe.keySet().iterator().next();
        LikelihoodModelCalculator lastLM = fringe.get(last);
        LikelihoodModelCalculator aLeafLM = likelihoodModels.get(aLeaf);
        double halfBL = t.branchLength(last, aLeaf) / 2.0;
        return lastLM.combine(lastLM, aLeafLM, halfBL, halfBL, true).logLikelihood();
    }

    /**
     * Returns the unique node adjacent to both {@code l1} and {@code l2}.
     *
     * @throws IllegalStateException if the nodes do not share exactly one neighbour
     */
    private static Taxon parent(UnrootedTree t, Taxon l1, Taxon l2) {
        Set<Taxon> shared = CollUtils.inter(t.getTopology().nbrs(l1), t.getTopology().nbrs(l2));
        if (shared.size() != 1) {
            throw new IllegalStateException("Expected exactly one common neighbour of " + l1 + " and " + l2 + ", found " + shared);
        }
        return shared.iterator().next();
    }

    /**
     * Finds two fringe nodes that can be merged next. For each fringe node (in fringe order),
     * it takes the first non-closed neighbour (in sorted order) as a candidate parent and looks
     * for another fringe node adjacent to it.
     *
     * @throws IllegalStateException if no mergeable pair exists (inconsistent topology/models)
     */
    private static Pair<Taxon, Taxon> peek(UnrootedTree t, Map<Taxon, LikelihoodModelCalculator> fringe, Set<Taxon> closedList) {
        for (Taxon l1 : fringe.keySet()) {
            Taxon parent = null;
            List<Taxon> nbrs1 = CollUtils.list(t.getTopology().nbrs(l1));
            Collections.sort(nbrs1);
            for (Taxon candidate : nbrs1) {
                if (!closedList.contains(candidate)) {
                    parent = candidate;
                    break;
                }
            }
            if (parent == null) {
                // Every neighbour already merged: this node cannot be paired from here.
                continue;
            }
            List<Taxon> nbrs2 = CollUtils.list(t.getTopology().nbrs(parent));
            Collections.sort(nbrs2);
            for (Taxon sibling : nbrs2) {
                if (!sibling.equals(l1) && fringe.containsKey(sibling)) {
                    return Pair.makePair(l1, sibling);
                }
            }
        }
        throw new IllegalStateException("No pair of fringe nodes sharing a neighbour found; fringe=" + fringe.keySet());
    }

    @Override
    public String toString() {
        return this.t.toString() + "\n" + "LogLikelihood: " + this.logLikelihood();
    }

    /** Precision-weighted average of {@code x1} and {@code x2} with variances {@code v1}, {@code v2}. */
    private static double harmonicCombine(double v1, double v2, double x1, double x2) {
        return (x1 / v1 + x2 / v2) / (1.0 / v1 + 1.0 / v2);
    }

    /** Returns {@code 1 / (1/v1 + 1/v2)}. */
    private static double harmonic(double v1, double v2) {
        return 1.0 / (1.0 / v1 + 1.0 / v2);
    }

    /**
     * Evaluates a closed-form log likelihood on hard-coded data for four leaves grouped as
     * (1,2) and (3,4), built from {@link BrownianModelCalculator#logNormalDensity} terms.
     * NOTE(review): appears to be a hand-derived check for the Brownian pruning code; confirm.
     */
    public static double explicit4LeavesLikelihoodFormula() {
        double[] x1 = {0.0, 0.9, 0.5};
        double[] x2 = {0.9, 1.0, 0.0};
        double[] x3 = {0.5, 0.0, 1.0};
        double[] x4 = {1.0, 0.5, 0.9};
        double v1 = 0.10341776;
        double v2 = 0.10207381;
        double v3 = 0.03380679;
        double v4 = 0.02367421;
        double vcenter = 0.09358231;
        double sum = 0.0;
        // Term for the difference between the two pair-averages across the central branch.
        double bigv = UnrootedTreeState.harmonic(v1, v2) + UnrootedTreeState.harmonic(v3, v4) + vcenter;
        for (int i = 0; i < x1.length; ++i) {
            sum += BrownianModelCalculator.logNormalDensity(0.0, UnrootedTreeState.harmonicCombine(v1, v2, x1[i], x2[i]) - UnrootedTreeState.harmonicCombine(v3, v4, x3[i], x4[i]), bigv);
        }
        // Within-pair difference terms.
        double smallv = v1 + v2;
        for (int i = 0; i < x1.length; ++i) {
            sum += BrownianModelCalculator.logNormalDensity(0.0, x1[i] - x2[i], smallv);
        }
        smallv = v3 + v4;
        for (int i = 0; i < x1.length; ++i) {
            sum += BrownianModelCalculator.logNormalDensity(0.0, x3[i] - x4[i], smallv);
        }
        return sum;
    }

    /**
     * Closed-form log likelihood for a three-leaf star tree whose leaf models are
     * {@link BrownianModelCalculator}s: with branch variances v_i and squared leaf distances
     * d_jk, returns {@code -p log(2 pi) - (p/2) log(sum v_i v_j) - (sum v_i d_jk) / (2 sum v_i v_j)}.
     *
     * @throws IllegalArgumentException if the state does not have exactly three leaves or the
     *     tree has no internal node
     */
    public static double explicit3LeavesLikelihoodFormula(UnrootedTreeState state) {
        if (state.likelihoodModels.keySet().size() != 3) {
            throw new IllegalArgumentException("Expected exactly 3 leaves, got " + state.likelihoodModels.keySet().size());
        }
        UnrootedTree tree = state.getNonClockTree();
        List<Taxon> leaves = new ArrayList<Taxon>();
        List<double[]> xs = new ArrayList<double[]>();
        Taxon c = null;
        for (Taxon l : tree.getTopology().vertexSet()) {
            if (!state.likelihoodModels.containsKey(l)) {
                c = l;
                continue;
            }
            BrownianModelCalculator bm = (BrownianModelCalculator) state.likelihoodModels.get(l);
            leaves.add(l);
            xs.add(bm.message);
        }
        if (c == null) {
            throw new IllegalArgumentException("Tree has no internal node to act as the star centre");
        }
        double v1 = tree.branchLength(leaves.get(0), c);
        double v2 = tree.branchLength(leaves.get(1), c);
        double v3 = tree.branchLength(leaves.get(2), c);
        double[] x1 = xs.get(0);
        double[] x2 = xs.get(1);
        double[] x3 = xs.get(2);
        double d23 = NumUtils.l2DistSquared(x2, x3);
        double d13 = NumUtils.l2DistSquared(x1, x3);
        double d12 = NumUtils.l2DistSquared(x1, x2);
        double sumvs = v1 * v2 + v1 * v3 + v2 * v3;
        double sumvds = v1 * d23 + v2 * d13 + v3 * d12;
        double p = x1.length;
        return -p * Math.log(Math.PI * 2) - p * Math.log(sumvs) / 2.0 - sumvds / sumvs / 2.0;
    }

    /** Ad-hoc driver: prints the log likelihood of the HGDP tree under a Brownian model. */
    public static void main(String[] args) {
        HGDPDataset.path = "data/hgdp/hgdp.ie.phylip";
        BrownianModelCalculator.useVarianceTransform = true;
        HGDPDataset data = new HGDPDataset();
        UnrootedTreeState infileState = UnrootedTreeState.fromBrownianMotion(UnrootedTree.fromNewick(new File("data/hgdp/hgdp.ie.tree")), data, new BrownianModel(data.nSites(), 1.0));
        System.out.println("Code:" + infileState.logLikelihood());
    }

    /** Returns a read-only view of the leaf likelihood calculators. */
    public Map<Taxon, LikelihoodModelCalculator> getLikelihoodModels() {
        return Collections.unmodifiableMap(this.likelihoodModels);
    }
}

