/*
 * Decompiled with CFR 0.152.
 */
package pty.eval;

import fig.basic.IOUtils;
import fig.basic.UnorderedPair;
import goblin.Taxon;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import ma.newick.NewickParser;
import ma.newick.ParseException;
import nuts.io.IO;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Tree;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.UnrootedTree;
import pty.eval.SymmetricDiff;

/**
 * Static helpers for scoring how well the clades of a tree agree with a
 * labelling of its leaves (the "purity" measure), plus a few related
 * utilities (recall between sets, lowest-common-ancestor lookup).
 */
public class Purity {
    /**
     * Computes the purity of a rooted tree with respect to {@code labels}.
     *
     * <p>The tree is converted with {@code Arbre.tree2Arbre}, its clades are
     * obtained from {@code SymmetricDiff.clades}, and the result is delegated
     * to {@link #purity(Set, Map)}.
     *
     * @param tree   the tree to evaluate
     * @param labels label of each leaf; leaves without an entry are unlabelled
     * @return the purity score, see {@link #purity(Set, Map)}
     */
    public static <S, T> double purity(Tree<S> tree, Map<S, T> labels) {
        return Purity.purity(SymmetricDiff.clades(Arbre.tree2Arbre(tree)), labels);
    }

    /**
     * Computes the purity of an unrooted tree with respect to {@code labels},
     * using the clades reported by {@code t.clades()}.
     *
     * @param t      the tree to evaluate
     * @param labels label of each taxon; taxa without an entry are unlabelled
     * @return the purity score, see {@link #purity(Set, Map)}
     */
    public static <T> double purity(UnrootedTree t, Map<Taxon, T> labels) {
        return Purity.purity(t.clades(), labels);
    }

    /**
     * Returns the fraction of {@code reference} that is also in {@code guess}:
     * the size of {@code CollUtils.inter(reference, guess)} (presumably the set
     * intersection) divided by the size of {@code reference}.
     *
     * @param reference the gold-standard set
     * @param guess     the proposed set
     * @return a value in [0, 1]; {@code NaN} when {@code reference} is empty
     */
    public static <T> double recall(Set<T> reference, Set<T> guess) {
        int shared = CollUtils.inter(reference, guess).size();
        // Promote to double BEFORE dividing: int / int truncates, which would
        // report every partial recall as 0.0.
        return (double) shared / reference.size();
    }

    /**
     * Computes purity from an explicit collection of clades.
     *
     * <p>The leaves are the members of the singleton clades. For every labelled
     * leaf, each other leaf carrying an equal label is considered: the smallest
     * clade containing both is located and the fraction of that clade's members
     * carrying the label is recorded. Those fractions are averaged per leaf,
     * and the per-leaf averages are averaged in turn. Leaves that are
     * unlabelled, or that have no other leaf with the same label, do not
     * contribute.
     *
     * @param clades the clades of the tree, each given as the set of its leaves
     * @param labels label of each leaf; leaves without an entry are unlabelled
     * @return the mean of the per-leaf averages; when no leaf contributes this
     *         is the mean of an empty {@code SummaryStatistics} (expected to be
     *         {@code NaN})
     * @throws NullPointerException if two same-labelled leaves share no clade
     */
    public static <S, T> double purity(Set<Set<S>> clades, Map<S, T> labels) {
        Set<S> leaves = Purity.leaves(clades);
        SummaryStatistics mainAvg = new SummaryStatistics();
        for (S leave : leaves) {
            SummaryStatistics subAvg = new SummaryStatistics();
            T cLabel = labels.get(leave);
            if (cLabel != null) {
                for (S other : leaves) {
                    T otherLabel = labels.get(other);
                    if (other.equals(leave) || otherLabel == null || !cLabel.equals(otherLabel)) continue;
                    Set<S> ancestor = Purity.lca(leave, other, clades);
                    if (ancestor == null) {
                        // Previously surfaced as an anonymous NPE inside ratio();
                        // same exception type, but now it says what went wrong.
                        throw new NullPointerException("No clade contains both " + leave + " and " + other);
                    }
                    subAvg.addValue(Purity.ratio(ancestor, labels, cLabel));
                }
            }
            // Skip leaves with no same-labelled partner so they do not add an
            // undefined mean to the outer average.
            if (subAvg.getN() <= 0L) continue;
            mainAvg.addValue(subAvg.getMean());
        }
        return mainAvg.getMean();
    }

    /**
     * Finds the smallest clade that contains both given leaves.
     *
     * @return the smallest such clade (the first one encountered on ties), or
     *         {@code null} if no clade contains both leaves
     */
    private static <S> Set<S> lca(S leave1, S leave2, Set<Set<S>> clades) {
        Set<S> result = null;
        for (Set<S> clade : clades) {
            // Keep the current best unless this clade is strictly smaller and
            // still contains both leaves.
            if (result != null && clade.size() >= result.size() || !clade.contains(leave1) || !clade.contains(leave2)) continue;
            result = clade;
        }
        return result;
    }

    /**
     * Collects the members of every clade of size one, i.e. the leaves.
     */
    private static <S> Set<S> leaves(Set<Set<S>> clades) {
        HashSet<S> result = new HashSet<S>();
        for (Set<S> clade : clades) {
            if (clade.size() != 1) continue;
            result.addAll(clade);
        }
        return result;
    }

    /**
     * Groups the labelled items found in the yield of {@code t} by label and
     * drops every label whose group has exactly one member.
     *
     * <p>{@code labels} itself is not modified; a copy restricted to the yield
     * of {@code t} is passed to {@code CollUtils.invert} (presumably producing
     * label -> items).
     *
     * @param t      tree whose yield selects the items of interest
     * @param labels label of each item
     * @return the inverted map without the single-member groups
     */
    public static <S, T> Map<T, Set<S>> partitionsUsedForEval(Tree<S> t, Map<S, T> labels) {
        HashSet<S> leaves = new HashSet<S>(t.getYield());
        HashMap<S, T> labelsCopy = new HashMap<S, T>(labels);
        labelsCopy.keySet().retainAll(leaves);
        Map<T, Set<S>> inverse = CollUtils.invert(labelsCopy);
        Iterator<T> iter = inverse.keySet().iterator();
        while (iter.hasNext()) {
            if (inverse.get(iter.next()).size() != 1) continue;
            // Remove through the iterator to avoid a concurrent modification.
            iter.remove();
        }
        return inverse;
    }

    /**
     * Command-line entry point: parses a Newick tree from {@code args[0]},
     * reads a tab-separated file from {@code args[1]} mapping column 0 to
     * column 4 (lines starting with "wals" are skipped), and prints the
     * purity of the tree under that mapping.
     *
     * @param args {@code <newick-format-tree> <path-to-wals-languages-file>}
     * @throws IOException    if an input cannot be read
     * @throws ParseException if the Newick tree cannot be parsed
     */
    public static void main(String[] args) throws IOException, ParseException {
        if (args.length != 2) {
            System.err.println("Evaluates the purity of a newick linguistic tree against");
            System.err.println("genus annotation in the file language.tab in wals_data");
            System.err.println("<newick-format-tree> <path-to-wals-languages-file>");
            return;
        }
        // NOTE(review): the handle returned by IOUtils.openIn is never closed
        // here; harmless for a short-lived process, but confirm.
        NewickParser np = new NewickParser(IOUtils.openIn(args[0]));
        Tree<String> tree = np.parse();
        HashMap<String, String> lang2genus = new HashMap<String, String>();
        for (String line : IO.i(args[1])) {
            if (line.startsWith("wals")) continue;
            String[] fields = line.split("\\t");
            // NOTE(review): a line with fewer than five tab-separated fields
            // fails here with ArrayIndexOutOfBoundsException.
            lang2genus.put(fields[0], fields[4]);
        }
        System.out.println("Purity: " + Purity.purity(tree, lang2genus));
    }

    /**
     * Returns the fraction of {@code clade} members whose label equals
     * {@code label}; unlabelled members count as non-matching.
     *
     * @throws RuntimeException if the fraction is undefined, which happens
     *                          when {@code clade} is empty
     */
    private static <S, T> double ratio(Set<S> clade, Map<S, T> labels, T label) {
        SummaryStatistics ratio = new SummaryStatistics();
        for (S leaf : clade) {
            T leafLabel = labels.get(leaf);
            if (leafLabel == null) {
                ratio.addValue(0.0);
                continue;
            }
            ratio.addValue(leafLabel.equals(label) ? 1.0 : 0.0);
        }
        if (Double.isNaN(ratio.getMean())) {
            // Only 0.0 and 1.0 are ever added, so a NaN mean means nothing was added.
            throw new RuntimeException("Cannot compute the label ratio of an empty clade (label: " + label + ")");
        }
        return ratio.getMean();
    }

    /**
     * Builds a lookup from unordered pairs of yield elements to a subtree of
     * {@code tree} that covers both.
     *
     * <p>A pair is recorded at every node where its two elements fall under
     * different children; when each element occurs only once in the tree that
     * node is the pair's lowest common ancestor.
     *
     * @param tree the root of the tree to index
     * @return a new map from element pairs to the subtree recorded for them
     */
    public static <S> Map<UnorderedPair<S, S>, Tree<S>> lcas(Tree<S> tree) {
        HashMap<UnorderedPair<S, S>, Tree<S>> result = new HashMap<UnorderedPair<S, S>, Tree<S>>();
        Purity._lcas(tree, result);
        return result;
    }

    /**
     * Recursive worker for {@link #lcas(Tree)}: records every cross-child pair
     * at {@code tree}, then descends into each child.
     */
    private static <S> void _lcas(Tree<S> tree, Map<UnorderedPair<S, S>, Tree<S>> result) {
        for (int i = 0; i < tree.getChildren().size(); ++i) {
            for (int j = i + 1; j < tree.getChildren().size(); ++j) {
                for (S first : tree.getChildren().get(i).getYield()) {
                    for (S second : tree.getChildren().get(j).getYield()) {
                        result.put(new UnorderedPair<S, S>(first, second), tree);
                    }
                }
            }
        }
        for (Tree<S> child : tree.getChildren()) {
            Purity._lcas(child, result);
        }
    }
}

