/*
 * Decompiled with CFR 0.152.
 */
package ev.ex;

import fig.basic.Pair;
import fig.basic.UnorderedPair;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import nuts.io.IO;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import nuts.util.Counter;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.mcmc.UnrootedTreeState;
import pty.smc.PartialCoalescentState;
import pty.smc.models.DiscreteModelCalculator;
import pty.smc.models.LikelihoodModelCalculator;

/**
 * Wraps a {@link PartialCoalescentState} together with a counter of pairwise
 * distances between the taxa of its current roots, and lazily caches values
 * derived from them (an unrooted tree over the roots, the same tree with the
 * already-built subtrees substituted in, and a log-scale weight {@link #gamma()}).
 *
 * <p>Instances are created through {@link #init(PartialCoalescentState)} and
 * {@link #coalesce(int, int, double)}. The cached fields are filled on first
 * access without any synchronization.
 *
 * <p>NOTE(review): "NJ" presumably stands for neighbor-joining; the step that
 * would actually build the tree from {@link #pairwiseDistances} is not present
 * in this file (see {@link #getNJTree()}).
 */
public class NJPState {
    /** The partial coalescent state this object is built around. */
    public final PartialCoalescentState pcs;
    /** Pairwise distances keyed by unordered pairs of top-level taxa. */
    public final Counter<UnorderedPair<Taxon, Taxon>> pairwiseDistances;
    /** Cache for {@link #getNJTree()}; {@code null} until computed. */
    private UnrootedTree _njTree = null;
    /** Cache for {@link #getFullNJTree()}; {@code null} until computed. */
    private UnrootedTree _fullNJTree = null;
    /** Cache for {@link #gamma()}; {@code NaN} means "not computed yet". */
    private double _gamma = Double.NaN;

    /**
     * Creates the state for {@code initPCS}, computing its pairwise distances
     * from an empty counter via {@link #refreshPairwiseDistance}.
     *
     * @param initPCS the initial partial coalescent state
     * @return a new state wrapping {@code initPCS}
     */
    public static NJPState init(PartialCoalescentState initPCS) {
        return new NJPState(initPCS, NJPState.refreshPairwiseDistance(new Counter<UnorderedPair<Taxon, Taxon>>(), initPCS));
    }

    private NJPState(PartialCoalescentState pcs, Counter<UnorderedPair<Taxon, Taxon>> pairwiseDistances) {
        this.pcs = pcs;
        this.pairwiseDistances = pairwiseDistances;
    }

    /**
     * Returns the (cached) unrooted tree over the current roots of {@link #pcs}.
     *
     * <p>After the tree is obtained, every root taxon must have exactly one
     * neighbour in the tree topology; the branch leading to it is lengthened to
     * at least {@code pcs.topHeight() - pcs.getHeight(i)} when it is shorter.
     *
     * <p>NOTE(review): the statement that should construct the tree is missing
     * here: the field is assigned {@code null} and then dereferenced, so with at
     * least one root this method currently fails with a
     * {@link NullPointerException}, and with no roots it returns {@code null}.
     * The original construction call is not recoverable from this file; it must
     * be restored before this method can work.
     *
     * @return the cached tree, or the newly adjusted one
     * @throws RuntimeException if a root taxon does not have exactly one neighbour
     */
    public UnrootedTree getNJTree() {
        if (this._njTree != null) {
            return this._njTree;
        }
        // TODO(review): build the tree from this.pairwiseDistances here.
        this._njTree = null;
        double h = this.pcs.topHeight();
        for (int i = 0; i < this.pcs.nRoots(); ++i) {
            Taxon cur = this.pcs.getRoots().get((int)i).getContents().nodeIdentifier;
            Set<Taxon> nbs = this._njTree.getTopology().nbrs(cur);
            if (nbs.size() != 1) {
                throw new RuntimeException("Root taxon " + cur + " has " + nbs.size() + " neighbours in the NJ tree; expected exactly 1");
            }
            UnorderedPair<Taxon, Taxon> edge = new UnorderedPair<Taxon, Taxon>(cur, CollUtils.pick(nbs));
            double curBL = this._njTree.branchLength(edge);
            // Lower bound for the branch: distance from this root up to the top height.
            double min = h - this.pcs.getHeight(i);
            if (curBL < min) {
                this._njTree.changeBranchLength(edge, min);
            }
        }
        return this._njTree;
    }

    /**
     * Returns the (cached) unrooted tree in which every non-leaf root of
     * {@link #pcs} has been replaced by its full subtree.
     *
     * <p>In a final state the tree is taken directly from the full coalescent
     * state. Otherwise the tree from {@link #getNJTree()} is written as a Newick
     * string (rooted at some vertex that is not one of the subtree roots), each
     * occurrence of {@code "<rootTaxon>:"} is textually replaced by the Newick
     * form of the corresponding subtree, and the result is parsed back.
     *
     * @return the full unrooted tree
     * @throws RuntimeException if the leaves of the resulting tree differ from
     *         the observed taxa of {@link #pcs}
     */
    public UnrootedTree getFullNJTree() {
        if (this._fullNJTree != null) {
            return this._fullNJTree;
        }
        if (this.pcs.isFinalState()) {
            this._fullNJTree = UnrootedTree.fromRooted(this.pcs.getFullCoalescentState());
        } else {
            UnrootedTree result = this.getNJTree();
            // Contents of the subtree roots; these must not be used as the Newick root.
            Set<Object> toSub = CollUtils.set();
            for (int r = 0; r < this.pcs.nRoots(); ++r) {
                toSub.add(this.pcs.getSubtree(r).topology().getContents());
            }
            // Pick the first vertex that is not a subtree root; stays null if there is none.
            Taxon reroot = null;
            for (Taxon cur : result.getTopology().vertexSet()) {
                if (toSub.contains(cur)) continue;
                reroot = cur;
                break;
            }
            String curStr = result.toNewick(reroot);
            for (int r = 0; r < this.pcs.nRoots(); ++r) {
                RootedTree curRooted = this.pcs.getSubtree(r);
                if (curRooted.topology().isLeaf()) continue;
                String rootTaxon = curRooted.topology().getContents().toString();
                String subTreeStr = RootedTree.Util.toNewick(curRooted);
                // Splice the subtree in place of its root label, dropping the trailing ';'.
                curStr = curStr.replace(rootTaxon + ":", subTreeStr.replace(";", "") + ":");
            }
            this._fullNJTree = UnrootedTree.fromNewick(curStr);
        }
        if (!this._fullNJTree.leavesSet().equals(this.pcs.getObservations().observations().keySet())) {
            throw new RuntimeException("Leaves of the full NJ tree do not match the observed taxa");
        }
        return this._fullNJTree;
    }

    /** Returns the string form of {@link #getFullNJTree()}. */
    @Override
    public String toString() {
        return this.getFullNJTree().toString();
    }

    /**
     * Returns the (cached) log-scale weight of this state.
     *
     * <p>In a final state this is {@code pcs.logLikelihood()}. Otherwise it is the
     * log-likelihood of an {@link UnrootedTreeState} built on
     * {@link #getNJTree()} plus the log prior of that tree's branch lengths.
     *
     * @return the log-scale weight
     */
    public double gamma() {
        if (!Double.isNaN(this._gamma)) {
            return this._gamma;
        }
        if (this.pcs.isFinalState()) {
            this._gamma = this.pcs.logLikelihood();
        } else {
            UnrootedTree urt = this.getNJTree();
            UnrootedTreeState ncts = new UnrootedTreeState(urt, this.pcs.getTopLevelModelCalculators(), null, 0);
            // Both terms are log-scale (prior() sums log densities), so they are
            // combined by addition; multiplying two logs is not a valid weight.
            this._gamma = ncts.logLikelihood() + this.prior(urt);
        }
        return this._gamma;
    }

    /**
     * Sums, over all edges of {@code urt}, the exponential log density (rate
     * fixed to 1.0) of the edge's branch length.
     *
     * @param urt the tree whose branch lengths are scored
     * @return the summed log density
     */
    private double prior(UnrootedTree urt) {
        IO.warnOnce("WARNING: assuming a rate of one in NJPState");
        double result = 0.0;
        for (UnorderedPair<Taxon, Taxon> edge : urt.edges()) {
            result += Sampling.exponentialLogDensity(1.0, urt.branchLength(edge));
        }
        return result;
    }

    /**
     * Coalesces the two roots identified by taxon, looking up their indices in
     * {@code pcs.rootsTaxa()}.
     *
     * @param left taxon of the first root
     * @param right taxon of the second root
     * @param delta forwarded unchanged to {@link #coalesce(int, int, double)}
     * @return the state after the coalescent event
     * @throws NullPointerException if either taxon has no entry in
     *         {@code pcs.rootsTaxa()} (the missing index is unboxed)
     */
    public NJPState coalesce(Taxon left, Taxon right, double delta) {
        Map<Taxon, Integer> taxa = this.pcs.rootsTaxa();
        return this.coalesce(taxa.get(left), taxa.get(right), delta);
    }

    /**
     * Coalesces the roots at indices {@code left} and {@code right} and returns
     * a new state; this object is not modified. Distances already known for
     * surviving pairs are carried over by {@link #refreshPairwiseDistance}.
     *
     * @param left index of the first root
     * @param right index of the second root
     * @param delta forwarded to {@code PartialCoalescentState.coalesce}; the two
     *        trailing arguments of that call are fixed to 0.0
     * @return the state after the coalescent event
     */
    public NJPState coalesce(int left, int right, double delta) {
        PartialCoalescentState newPCSS = this.pcs.coalesce(left, right, delta, 0.0, 0.0);
        return new NJPState(newPCSS, NJPState.refreshPairwiseDistance(this.pairwiseDistances, newPCSS));
    }

    /**
     * Builds the pairwise-distance counter for the top-level taxa of
     * {@code pcs}, reusing the value from {@code model} for every pair that
     * {@code model} already contains.
     *
     * <p>NOTE(review): for pairs not present in {@code model} the K2P sufficient
     * statistics are computed but nothing is stored in the result, so those pairs
     * are absent from the returned counter. The step converting the statistics
     * into a distance appears to be missing and is not recoverable from this file.
     *
     * @param model previously computed distances to reuse
     * @param pcs state whose top-level model calculators define the taxa
     * @return a new counter; {@code model} is not modified
     */
    public static Counter<UnorderedPair<Taxon, Taxon>> refreshPairwiseDistance(Counter<UnorderedPair<Taxon, Taxon>> model, PartialCoalescentState pcs) {
        Counter<UnorderedPair<Taxon, Taxon>> result = new Counter<UnorderedPair<Taxon, Taxon>>();
        Map<Taxon, LikelihoodModelCalculator> topLevelMCs = pcs.getTopLevelModelCalculators();
        ArrayList<Taxon> taxa = CollUtils.list(topLevelMCs.keySet());
        for (int i = 0; i < taxa.size(); ++i) {
            for (int j = i + 1; j < taxa.size(); ++j) {
                Taxon t1 = taxa.get(i);
                Taxon t2 = taxa.get(j);
                UnorderedPair<Taxon, Taxon> key = new UnorderedPair<Taxon, Taxon>(t1, t2);
                if (model.keySet().contains(key)) {
                    result.setCount(key, model.getCount(key));
                    continue;
                }
                // TODO(review): ss is never turned into a distance nor stored in result.
                Pair<Double, Double> ss = DiscreteModelCalculator.k2pDistanceSuffStat((DiscreteModelCalculator)topLevelMCs.get(t1), (DiscreteModelCalculator)topLevelMCs.get(t2));
            }
        }
        return result;
    }
}

