/*
 * Decompiled with CFR 0.152.
 */
package scratch;

import fig.basic.IOUtils;
import fig.basic.UnorderedPair;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ma.newick.NewickParser;
import ma.newick.ParseException;
import nuts.io.IO;
import nuts.util.Tree;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;

/**
 * Scores how well the leaves of a Newick tree are grouped by a labelling, using a leaf-pair
 * purity measure. The command-line entry point reads the labels (genus per language) from a
 * tab-separated WALS languages file.
 */
public class WalsEval {
    /** Zero-based column of the WALS languages file whose value {@link #main} stores as the genus. */
    private static final int GENUS_COLUMN = 4;

    /**
     * Reads a Newick tree and a tab-separated WALS languages file, maps column 0 of each data
     * line to column {@value #GENUS_COLUMN} (the genus), and prints the tree's purity.
     *
     * @param args {@code args[0]} path of the Newick tree, {@code args[1]} path of the WALS
     *     languages file
     * @throws IOException if either file cannot be read
     * @throws ParseException if the tree file is not valid Newick
     */
    public static void main(String[] args) throws IOException, ParseException {
        if (args.length != 2) {
            System.err.println("<newick-format-tree> <wals-languages-file>");
            return;
        }
        // NOTE(review): the reader opened here is never closed; harmless for a one-shot CLI.
        NewickParser np = new NewickParser(IOUtils.openIn(args[0]));
        Tree<String> tree = np.parse();
        HashMap<String, String> lang2genus = new HashMap<String, String>();
        for (String line : IO.i(args[1])) {
            // Skip the header line, which starts with "wals".
            if (line.startsWith("wals")) continue;
            String[] fields = line.split("\\t");
            if (fields.length <= GENUS_COLUMN) {
                // Blank or truncated line (split() also drops trailing empty fields): it cannot
                // supply a genus, so skip it instead of failing with an out-of-bounds index.
                if (line.trim().length() > 0) {
                    System.err.println("Skipping malformed line: " + line);
                }
                continue;
            }
            lang2genus.put(fields[0], fields[GENUS_COLUMN]);
        }
        System.out.println("Purity: " + purity(tree, lang2genus));
    }

    /**
     * Computes the purity of {@code tree} with respect to {@code labels}.
     *
     * <p>The computation has three steps:
     *
     * <ol>
     *   <li>Take each labelled leaf {@code a} and each other leaf {@code b} with the same label.
     *   <li>For each such pair, find the fraction of leaves under their lowest common ancestor
     *       that carry that label.
     *   <li>Average these fractions per leaf {@code a}. Then average the per-leaf values over
     *       all leaves that have at least one same-label partner.
     * </ol>
     *
     * @param tree tree whose leaf values are looked up in {@code labels}
     * @param labels label of each leaf. Leaves with no (or a null) entry are never used as
     *     anchors, and they count as non-matching inside an ancestor's subtree.
     * @return a value in [0, 1], or {@code NaN} if no leaf shares its label with another leaf
     */
    public static <S, T> double purity(Tree<S> tree, Map<S, T> labels) {
        Map<UnorderedPair<S, S>, Tree<S>> lcas = lcas(tree);
        // Materialize the yield once instead of recomputing it for every anchor leaf.
        List<S> leaves = yieldOf(tree);
        SummaryStatistics mainAvg = new SummaryStatistics();
        for (S leaf : leaves) {
            T cLabel = labels.get(leaf);
            if (cLabel == null) {
                // Unlabeled leaf: there is no class to measure purity against.
                continue;
            }
            SummaryStatistics subAvg = new SummaryStatistics();
            for (S other : leaves) {
                // Leaves with equal values are treated as the same leaf (lcas() is keyed by value).
                if (other.equals(leaf) || !cLabel.equals(labels.get(other))) continue;
                // Two distinct leaf values always have an LCA entry: _lcas() records every pair
                // split across two different children of some node.
                Tree<S> lca = lcas.get(new UnorderedPair<S, S>(leaf, other));
                subAvg.addValue(ratio(lca, labels, cLabel));
            }
            if (subAvg.getN() > 0L) {
                mainAvg.addValue(subAvg.getMean());
            }
        }
        // getMean() of an empty SummaryStatistics is NaN.
        return mainAvg.getMean();
    }

    /**
     * Returns the fraction of leaves of {@code subtree} whose label equals {@code label}.
     * Leaves with no entry in {@code labels} count as non-matching. If the subtree has no
     * leaves, prints "error" and returns {@code NaN}.
     */
    private static <S, T> double ratio(Tree<S> subtree, Map<S, T> labels, T label) {
        int total = 0;
        int matching = 0;
        for (S leaf : subtree.getYield()) {
            total++;
            // label is non-null (purity() skips null labels), so this is null-safe for the lookup.
            if (label.equals(labels.get(leaf))) {
                matching++;
            }
        }
        if (total == 0) {
            System.out.println("error");
            return Double.NaN;
        }
        return (double) matching / total;
    }

    /**
     * Maps each unordered pair of leaf values of {@code tree} to the subtree rooted at their
     * lowest common ancestor.
     *
     * <p>Leaf values are presumably unique (TODO confirm). If a value repeats, pairs involving it
     * keep whichever node {@link #_lcas} visits last.
     */
    public static <S> Map<UnorderedPair<S, S>, Tree<S>> lcas(Tree<S> tree) {
        HashMap<UnorderedPair<S, S>, Tree<S>> result = new HashMap<UnorderedPair<S, S>, Tree<S>>();
        _lcas(tree, result);
        return result;
    }

    /**
     * Records {@code tree} as the LCA of every pair of leaves that lie under two different
     * children of it, then recurses into the children.
     *
     * <p>Because children are processed after the current node, a deeper node overwrites any
     * entry an ancestor already wrote for the same pair.
     */
    private static <S> void _lcas(Tree<S> tree, Map<UnorderedPair<S, S>, Tree<S>> result) {
        // Compute each child's yield once rather than once per (i, j) child pair.
        List<List<S>> childYields = new ArrayList<List<S>>();
        for (Tree<S> child : tree.getChildren()) {
            childYields.add(yieldOf(child));
        }
        for (int i = 0; i < childYields.size(); ++i) {
            for (int j = i + 1; j < childYields.size(); ++j) {
                for (S first : childYields.get(i)) {
                    for (S second : childYields.get(j)) {
                        result.put(new UnorderedPair<S, S>(first, second), tree);
                    }
                }
            }
        }
        for (Tree<S> child : tree.getChildren()) {
            _lcas(child, result);
        }
    }

    /** Copies the leaf yield of {@code tree} into a list, so it can be iterated repeatedly. */
    private static <S> List<S> yieldOf(Tree<S> tree) {
        List<S> leaves = new ArrayList<S>();
        for (S leaf : tree.getYield()) {
            leaves.add(leaf);
        }
        return leaves;
    }
}

