/*
 * Decompiled with CFR 0.152.
 */
package conifer.clock;

import conifer.clock.ClockTree;
import ev.ex.TreeGenerators;
import fig.basic.Pair;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.util.Arbre;
import nuts.util.MathUtils;
import pty.RootedTree;
import pty.smc.PartialCoalescentState;

/**
 * Utilities for converting between a {@link RootedTree} with branch lengths and a
 * {@link ClockTree}, which stores a height for each clade (set of leaf taxa).
 */
public class ClockTreeUtils {
    /**
     * Converts a rooted tree into a {@link ClockTree}. The result maps the clade of each node,
     * as computed by {@code Arbre.leavesMap_Efficient}, to that node's height.
     *
     * @param t the rooted tree to convert
     * @return the clade-to-height representation of {@code t}
     * @throws RuntimeException if {@code t} is not a clock tree (see {@link #heights(RootedTree)})
     */
    public static ClockTree fromRooted(RootedTree t) {
        Map<Taxon, Set<Taxon>> rootedClades = Arbre.leavesMap_Efficient(t.topology());
        Map<Taxon, Double> heights = ClockTreeUtils.heights(t);
        HashMap<Set<Taxon>, Double> convertedHeights = new HashMap<Set<Taxon>, Double>();
        for (Map.Entry<Taxon, Double> entry : heights.entrySet()) {
            convertedHeights.put(rootedClades.get(entry.getKey()), entry.getValue());
        }
        return new ClockTree(convertedHeights);
    }

    /**
     * Computes the height of every node of {@code t}.
     *
     * <p>Starting from each leaf at height 0.0, the method walks up through parents. It adds
     * each node's branch length, if one is recorded, to obtain the height of the next node up.
     *
     * @param t the rooted tree whose node heights are computed
     * @return a map from each node's taxon to its height; leaves map to 0.0
     * @throws RuntimeException if two leaves assign heights to the same node that are not
     *     {@code MathUtils.close}, meaning {@code t} is not a clock tree
     */
    public static Map<Taxon, Double> heights(RootedTree t) {
        HashMap<Taxon, Double> result = new HashMap<Taxon, Double>();
        for (Arbre<Taxon> leaf : t.topology().leaves()) {
            double curHeight = 0.0;
            for (Arbre<Taxon> cur = leaf; cur != null; cur = cur.getParentEasy()) {
                Taxon curTax = cur.getContents();
                Double previous = result.get(curTax);
                // Every leaf below a node must assign it the same height (up to tolerance).
                if (previous != null && !MathUtils.close(previous, curHeight)) {
                    throw new RuntimeException("Seems like the argument is not a clock tree");
                }
                result.put(curTax, curHeight);
                Double curBL = t.branchLengths().get(curTax);
                // A node with no recorded branch length (presumably the root) adds nothing.
                if (curBL != null) {
                    curHeight += curBL;
                }
            }
        }
        return result;
    }

    /**
     * Rebuilds a {@link RootedTree} from a {@link ClockTree}.
     *
     * <p>The method starts from a partial coalescent state over the leaves of {@code rct*}. It
     * then repeatedly applies {@link #constrainedNext} until a single root remains.
     *
     * @param rct the clock tree to convert
     * @return the fully coalesced rooted tree
     * @throws IllegalStateException if at some step no pair of roots forms a clade of {@code rct}
     */
    public static RootedTree toRooted(ClockTree rct) {
        PartialCoalescentState curPCS = PartialCoalescentState.initState(new ArrayList<Taxon>(rct.leaves()), true);
        while (curPCS.nRoots() > 1) {
            curPCS = ClockTreeUtils.constrainedNext(curPCS, rct);
        }
        return curPCS.getFullCoalescentState();
    }

    /**
     * Performs one coalescence step that is consistent with {@code rct}.
     *
     * <p>Among all pairs of current roots whose merged leaf set has a height in {@code rct},
     * the method coalesces the pair with the smallest such height. The pair is merged with
     * increment {@code height - current.topHeight()}.
     *
     * @param current the current partial coalescent state
     * @param rct the clock tree that constrains which roots may merge, and at what height
     * @return the next state, or {@code null} if {@code current} already has a single root
     * @throws IllegalStateException if no pair of current roots merges into a clade of {@code rct}
     */
    public static PartialCoalescentState constrainedNext(PartialCoalescentState current, ClockTree rct) {
        int nForests = current.nRoots();
        if (nForests == 1) {
            return null;
        }
        double min = Double.POSITIVE_INFINITY;
        int bestI = -1;
        int bestJ = -1;
        for (int i = 0; i < nForests; ++i) {
            for (int j = i + 1; j < nForests; ++j) {
                Set<Taxon> merged = new HashSet<Taxon>(current.rootedClade(i));
                merged.addAll(current.rootedClade(j));
                Double curH = rct.height(merged);
                // A null height means the union is not a clade of rct; NaN also fails the comparison.
                if (curH != null && curH < min) {
                    min = curH;
                    bestI = i;
                    bestJ = j;
                }
            }
        }
        if (bestI < 0) {
            throw new IllegalStateException("No pair of the " + nForests
                    + " current roots merges into a clade of the given ClockTree");
        }
        double delta = min - current.topHeight();
        // NOTE(review): meaning of the two trailing 0.0 arguments to coalesce is not visible here.
        return current.coalesce(bestI, bestJ, delta, 0.0, 0.0);
    }

    /**
     * Round-trip self-check. The method samples a coalescent tree with 100 leaves and converts
     * it to a {@link ClockTree}, back to a {@link RootedTree}, and to a {@link ClockTree} again.
     * It then verifies that both clock trees have the same clades and matching heights.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Random rand = new Random(1L);
        RootedTree rt = TreeGenerators.sampleCoalescent(rand, 100, false);
        ClockTree rct = ClockTreeUtils.fromRooted(rt);
        RootedTree rt2 = ClockTreeUtils.toRooted(rct);
        ClockTree rct2 = ClockTreeUtils.fromRooted(rt2);
        if (!rct.getHeights().keySet().equals(rct2.getHeights().keySet())) {
            throw new RuntimeException("Clade sets differ after the round trip");
        }
        for (Set<Taxon> key : rct.getHeights().keySet()) {
            if (!MathUtils.close(rct.height(key), rct2.height(key))) {
                throw new RuntimeException("" + rct.height(key) + " vs " + rct2.height(key));
            }
        }
        System.out.println("Test 1 passed");
        System.out.println(rct);
    }
}

