/*
 * Decompiled with CFR 0.152.
 */
package pty;

import fig.basic.IOUtils;
import fig.basic.Interner;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.UnorderedPair;
import fig.exec.Execution;
import goblin.Taxon;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import nuts.io.IO;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.EasyFormat;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.UnrootedTree;
import pty.eval.Purity;
import pty.io.WalsAnn;
import pty.io.WalsDataset;

public class SumT
implements Runnable {
    // Path handed to loadTrees(); sample files ending in "gz"/"newick" are located under it.
    @Option
    public String treeFile = null;
    // When true, reference clusters are loaded and purity scores are logged by consensus().
    @Option
    public boolean evaluatePurity = true;
    // File parsed by parseReferences(); when empty, the reference clusters come from
    // WalsDataset.getPreprocessedCorpus() instead.
    @Option
    public String purityGroupsFile = "";
    // Only files whose name contains this substring are read by loadTrees().
    @Option
    public String fileFilter = "sample";
    // Files whose numeric index (see getIndex) is greater than this value are ignored.
    @Option
    public int maxFileIndex = Integer.MAX_VALUE;
    @Option(gloss="fraction of files to be skipped after sorting and truncation")
    public double burnin = 0.1;
    // 0 disables thinning. Otherwise loadTrees() keeps a tree line only when the running
    // counter loadedTreeCount is a multiple of this value.
    @Option(gloss="number to skip between each tree processed")
    public int thinning = 0;
    // 0 disables the periodic purity evaluation performed while consensus() counts clades.
    @Option
    public int evalFrequency = 0;
    // Trees accumulated by loadTrees().
    private List<UnrootedTree> trees = new ArrayList<UnrootedTree>();
    // Reference cluster of each taxon; only populated by consensus() when evaluatePurity is set.
    private Map<Taxon, String> reference = null;
    // Number of loaded trees per topology (a topology being represented by its set of clades).
    private Counter<Set<Set<Taxon>>> topoCounter = new Counter();
    // Number of loaded trees in which each clade appears.
    private Counter<Set<Taxon>> cladeCounter = new Counter();
    // Running count of tree lines seen by loadTrees(); only advanced when thinning != 0.
    int loadedTreeCount = 0;

    /**
     * Reads the sampled trees stored under {@code file} and appends them to {@link #trees}.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>locate every file with suffix {@code gz}/{@code newick} under {@code file};</li>
     *   <li>drop files whose name does not contain {@link #fileFilter} or whose index
     *       (see {@link #getIndex(File)}) exceeds {@link #maxFileIndex};</li>
     *   <li>sort the remaining files by index and discard the first
     *       {@code burnin * nFiles} of them;</li>
     *   <li>parse every line starting with {@code '('} as a newick tree, subject to
     *       {@link #thinning}.</li>
     * </ol>
     * Empty lines are ignored silently; any other line is reported with a warning.
     * A file that cannot be read is reported and skipped.
     *
     * @param file root location under which the sample files are searched
     */
    public void loadTrees(File file) {
        LogInfo.track("Reading samples under " + file);
        LinkedList<File> subFiles = new LinkedList<File>(IO.locate(file, IO.suffixFilter("gz", "newick")));
        // Keep only the files matching the name filter whose index is within the limit.
        Iterator<File> candidates = subFiles.iterator();
        while (candidates.hasNext()) {
            File cur = candidates.next();
            if (!cur.getName().contains(this.fileFilter) || SumT.indexExceedLimit(cur, this.maxFileIndex)) {
                candidates.remove();
            }
        }
        Collections.sort(subFiles, new Comparator<File>(){

            @Override
            public int compare(File o1, File o2) {
                return Integer.valueOf(SumT.getIndex(o1)).compareTo(Integer.valueOf(SumT.getIndex(o2)));
            }
        });
        int nFile = subFiles.size();
        // Clamp to [0, nFile] so that a burnin fraction outside [0, 1] cannot run past the list.
        int nToBurnIn = Math.max(0, Math.min(nFile, (int)(this.burnin * (double)nFile)));
        LogInfo.logsForce("Skipping " + nToBurnIn + " files (burned-in)");
        subFiles.subList(0, nToBurnIn).clear();
        LogInfo.logsForce("Processing files with thinning: " + this.thinning);
        int nTreeReadSoFar = 1;
        for (File subFile : subFiles) {
            BufferedReader reader = IOUtils.openInHard(subFile);
            try {
                String line;
                while ((line = reader.readLine()) != null) {
                    // Test for emptiness first: charAt(0) on an empty line would throw.
                    if (line.isEmpty()) {
                        continue;
                    }
                    if (line.charAt(0) != '(') {
                        LogInfo.warning("Warning skipped line in " + subFile.getAbsolutePath() + "\n\t" + line);
                        continue;
                    }
                    // The counter is only advanced when thinning is enabled.
                    if (this.thinning != 0 && this.loadedTreeCount++ % this.thinning != 0) {
                        continue;
                    }
                    this.trees.add(UnrootedTree.fromNewick(line));
                }
                LogInfo.logs("" + nTreeReadSoFar++ + "/" + subFiles.size() + ", " + this.trees.size() + " trees read so far");
            }
            catch (IOException e) {
                LogInfo.warning("Problems with file:" + subFile + " (" + e + ")");
            }
            finally {
                // Always release the reader, even when reading or parsing failed.
                try {
                    reader.close();
                }
                catch (IOException closeError) {
                    LogInfo.warning("Problems with file:" + subFile + " (" + closeError + ")");
                }
            }
        }
        LogInfo.end_track();
    }

    /**
     * Tells whether the numeric index of a file is strictly greater than a limit.
     *
     * @param cur file whose index is extracted with {@link #getIndex(File)}
     * @param maxFileIndex largest index still accepted
     * @return {@code true} when the index of {@code cur} is above {@code maxFileIndex}
     */
    private static boolean indexExceedLimit(File cur, int maxFileIndex) {
        int index = SumT.getIndex(cur);
        return index > maxFileIndex;
    }

    /**
     * Command-line entry point: configures the static {@code Execution} flags and runs a
     * fresh {@code SumT} through {@code Execution.run}, registering {@code WalsDataset}
     * under the name "wals".
     *
     * @param args command-line arguments forwarded to {@code Execution.run}
     */
    public static void main(String[] args) {
        Execution.monitor = true;
        Execution.makeThunk = false;
        Execution.create = true;
        Execution.useStandardExecPoolDirStrategy = true;
        SumT summarizer = new SumT();
        Execution.run(args, summarizer, "wals", WalsDataset.class);
    }

    /** Runs {@link #consensus()} and discards the tree it returns. */
    @Override
    public void run() {
        consensus();
    }

    /**
     * Loads the sampled trees and summarises them.
     *
     * <p>The method
     * <ol>
     *   <li>loads the reference clusters when {@link #evaluatePurity} is set;</li>
     *   <li>loads the trees under {@link #treeFile};</li>
     *   <li>counts topologies and clades, accumulates pairwise distances and, every
     *       {@link #evalFrequency} trees, logs intermediate purity scores;</li>
     *   <li>writes the averaged pairwise distances to the "distances" execution file;</li>
     *   <li>selects the clade set returned by {@link #findMin(Counter, Counter)} (or the
     *       clades of the only tree when exactly one was loaded) and logs its clades;</li>
     *   <li>writes the first loaded tree whose clades equal that selection to
     *       "consensus.newick";</li>
     *   <li>when purity is evaluated, also scores the tree returned by the external
     *       {@code cmds/phylipnj.bash} script run on the distance file.</li>
     * </ol>
     *
     * @return the loaded tree whose clades equal the selected clade set, or {@code null}
     *         when no loaded tree matches
     * @throws RuntimeException if the number of loaded trees differs from the total
     *         topology count
     */
    public UnrootedTree consensus() {
        LogInfo.logs("");
        if (this.evaluatePurity) {
            if (this.purityGroupsFile.equals("")) {
                WalsDataset ds = WalsDataset.getPreprocessedCorpus();
                this.reference = ds.getReferenceClusters();
            } else {
                this.reference = SumT.parseReferences(this.purityGroupsFile);
            }
        }
        this.loadTrees(new File(this.treeFile));
        LogInfo.logsForce("" + this.trees.size() + " trees read");
        LogInfo.track("Creating clade counters");
        CladeInterner interner = new CladeInterner();
        Counter<UnorderedPair<Taxon, Taxon>> pairwiseDistances = new Counter<UnorderedPair<Taxon, Taxon>>();
        SummaryStatistics hstat = new SummaryStatistics();
        int processed = 0;
        for (UnrootedTree t : this.trees) {
            hstat.addValue(this.height(t));
            pairwiseDistances.incrementAll(t.pairwiseDistances());
            Set<Set<Taxon>> clades = t.clades();
            // Only intern topologies seen for the first time; known ones are already keys.
            if (!this.topoCounter.containsKey(clades)) {
                clades = interner.internClades(clades);
            }
            this.topoCounter.incrementCount(clades, 1.0);
            for (Set<Taxon> clade : clades) {
                this.cladeCounter.incrementCount(clade, 1.0);
            }
            ++processed;
            LogInfo.logs("" + processed + "/" + this.trees.size() + " [" + this.cladeCounter.size() + " unique clades]");
            // Use the number of trees processed so far, so the reported index is the
            // index of the tree that was just counted.
            boolean evaluateNow = this.evaluatePurity && this.evalFrequency != 0 && processed % this.evalFrequency == 0;
            if (!evaluateNow) {
                continue;
            }
            LogInfo.logsForce("Purity of current sample (" + processed + "): " + Purity.purity(t.clades(), this.reference));
            // findMin relies on sumOfSymmLosses, which rejects a total count of one tree.
            if (processed > 1) {
                LogInfo.logsForce("MBR Purity so far: " + Purity.purity(SumT.findMin(this.topoCounter, this.cladeCounter), this.reference));
            }
        }
        LogInfo.end_track();
        LogInfo.logs("Approx reconstructed dist:" + hstat.getMean());
        // Turn the accumulated pairwise distances into per-tree averages.
        double norm = this.trees.size();
        for (UnorderedPair<Taxon, Taxon> key : pairwiseDistances.keySet()) {
            pairwiseDistances.setCount(key, pairwiseDistances.getCount(key) / norm);
        }
        String distanceFile = Execution.getFile("distances");
        IO.writeToDisk(distanceFile, SumT.phylipDistanceMatrix(pairwiseDistances));
        Set<Set<Taxon>> reconstructedClades = null;
        if (this.trees.size() == 1) {
            reconstructedClades = this.trees.get(0).clades();
        } else {
            LogInfo.track((Object)"MBR reconstruction", false);
            LogInfo.logs("Number of unique topologies: " + this.topoCounter.size());
            reconstructedClades = SumT.findMin(this.topoCounter, this.cladeCounter);
            if (this.trees.size() != (int)this.topoCounter.totalCount()) {
                throw new RuntimeException(this.trees.size() + " vs " + (int)this.topoCounter.totalCount());
            }
            LogInfo.end_track();
        }
        LogInfo.logsForce("MBR clades: " + reconstructedClades);
        int rk = 1;
        LogInfo.track((Object)"Clades:", true);
        // Loop-invariant; findMin returns null when no topology was counted.
        int leafCount = reconstructedClades == null ? 0 : this.nLeaves(reconstructedClades);
        for (Set<Taxon> clade : this.cladeCounter) {
            // Report only selected clades with more than one taxon and fewer than leafCount - 1.
            if (!reconstructedClades.contains(clade) || clade.size() <= 1 || clade.size() >= leafCount - 1) {
                continue;
            }
            LogInfo.logs("" + rk++ + "\t" + this.cladeCounter.getCount(clade) / (double)this.trees.size() + "\t" + clade);
        }
        LogInfo.end_track();
        if (this.evaluatePurity) {
            LogInfo.track((Object)"Evaluation", true);
            LogInfo.logs("Purity: " + Purity.purity(reconstructedClades, this.reference));
            LogInfo.end_track();
        }
        // Return (and save) the first loaded tree realising the selected clade set.
        UnrootedTree result = null;
        for (UnrootedTree nct : this.trees) {
            if (!nct.clades().equals(reconstructedClades)) {
                continue;
            }
            result = nct;
            IO.writeToDisk(new File(Execution.getFile("consensus.newick")), nct.toNewick());
            break;
        }
        if (this.evaluatePurity) {
            // NOTE(review): the command line is built by string concatenation; this presumably
            // relies on distanceFile containing no shell metacharacters - confirm.
            String recon = IO.call("/bin/bash cmds/phylipnj.bash " + distanceFile);
            UnrootedTree pairRecon = UnrootedTree.fromNewick(recon);
            LogInfo.logs("Purity (from distances): " + Purity.purity(pairRecon.clades(), this.reference));
        }
        return result;
    }

    /**
     * Returns the smallest value of {@link #meanD(Taxon, UnrootedTree)} over the vertices
     * of the tree topology that are not in its leaf set.
     *
     * @param t tree to measure
     * @return the minimum mean distance, or {@code Double.POSITIVE_INFINITY} when no
     *         non-leaf vertex yields a smaller value
     */
    private double height(UnrootedTree t) {
        double best = Double.POSITIVE_INFINITY;
        for (Taxon vertex : t.getTopology().vertexSet()) {
            if (!t.leavesSet().contains(vertex)) {
                double candidate = this.meanD(vertex, t);
                if (candidate < best) {
                    best = candidate;
                }
            }
        }
        return best;
    }

    /**
     * Averages the total-branch-length distance between {@code lang} and every leaf of
     * {@code t}.
     *
     * @param lang vertex from which the distances are measured
     * @param t tree providing the leaves and the distances
     * @return the sum of the distances divided by the number of leaves
     */
    private double meanD(Taxon lang, UnrootedTree t) {
        double total = 0.0;
        for (Taxon leaf : t.leaves()) {
            total += t.totalBranchLengthDistance(leaf, lang);
        }
        double leafCount = (double)t.leaves().size();
        return total / leafCount;
    }

    /**
     * Renders pairwise distances as a square matrix in text form.
     *
     * <p>The first line holds the number of taxa. Each following line starts with the
     * taxon name passed through {@code WalsAnn.cleanForPhylip}, followed by one formatted
     * distance per taxon, each preceded by two spaces. The taxa are those appearing in any
     * key of {@code pairwiseDistances}; their order is the iteration order of a hash set.
     *
     * @param pairwiseDistances distance recorded for each unordered pair of taxa
     * @return the matrix as a string, one taxon per line
     */
    public static String phylipDistanceMatrix(Counter<UnorderedPair<Taxon, Taxon>> pairwiseDistances) {
        StringBuilder result = new StringBuilder();
        // Collect every taxon mentioned in at least one pair.
        HashSet<Taxon> langsSet = CollUtils.set();
        for (UnorderedPair<Taxon, Taxon> key : pairwiseDistances.keySet()) {
            langsSet.add(key.getFirst());
            langsSet.add(key.getSecond());
        }
        // Freeze an order so that rows and columns line up.
        List<Taxon> langs = new ArrayList<Taxon>(langsSet);
        result.append(langs.size()).append('\n');
        for (Taxon lang : langs) {
            result.append(WalsAnn.cleanForPhylip(lang.toString()));
            for (Taxon lang2 : langs) {
                result.append("  " + EasyFormat.fmt(pairwiseDistances.getCount(new UnorderedPair<Taxon, Taxon>(lang, lang2))));
            }
            result.append('\n');
        }
        return result.toString();
    }

    /**
     * Parses a reference grouping file into a map from taxon to group label.
     *
     * <p>Each non-blank line is split on whitespace; the first field names the taxon and
     * the second its group. Further fields are ignored. Leading and trailing whitespace is
     * stripped and blank lines are skipped.
     *
     * @param purityGroupsFile2 path of the file to read
     * @return the group label of every taxon listed in the file
     * @throws RuntimeException if a non-blank line has fewer than two fields
     */
    public static Map<Taxon, String> parseReferences(String purityGroupsFile2) {
        HashMap<Taxon, String> result = CollUtils.map();
        for (String line : IO.i(purityGroupsFile2)) {
            // Trim first: a leading blank would otherwise produce an empty first field.
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            String[] fields = trimmed.split("\\s+");
            if (fields.length < 2) {
                throw new RuntimeException("Malformed line in " + purityGroupsFile2 + " (expected \"<taxon> <group>\"): " + line);
            }
            result.put(new Taxon(fields[0]), fields[1]);
        }
        return result;
    }

    /**
     * Returns the size of the largest clade in the given collection.
     *
     * @param clades clades to inspect
     * @return the largest clade size, or 0 when {@code clades} is empty
     */
    private int nLeaves(Set<Set<Taxon>> clades) {
        int largest = 0;
        for (Set<Taxon> clade : clades) {
            largest = Math.max(largest, clade.size());
        }
        return largest;
    }

    /**
     * Selects, among the counted topologies, the one with the smallest value of
     * {@code sumOfSymmLosses} against the clade counts.
     *
     * <p>Ties are resolved in favour of the topology encountered first while iterating
     * over the keys of {@code topoCounter}.
     *
     * @param topoCounter number of trees per topology (a topology is a set of clades)
     * @param cladeCounter number of trees containing each clade
     * @return the minimising topology, or {@code null} when {@code topoCounter} has no key
     */
    public static Set<Set<Taxon>> findMin(Counter<Set<Set<Taxon>>> topoCounter, Counter<Set<Taxon>> cladeCounter) {
        // The total is the same for every candidate, so compute it once.
        int nTrees = (int)topoCounter.totalCount();
        Set<Set<Taxon>> argmin = null;
        double minValue = Double.POSITIVE_INFINITY;
        for (Set<Set<Taxon>> currentElt : topoCounter.keySet()) {
            int currentValue = SumT.sumOfSymmLosses(currentElt, cladeCounter, nTrees);
            if ((double)currentValue < minValue) {
                minValue = currentValue;
                argmin = currentElt;
            }
        }
        return argmin;
    }

    /**
     * Scores a candidate topology against the clade counts of all trees.
     *
     * <p>Each clade of the candidate contributes {@code nTrees - count(clade)}, and the
     * sum is doubled. Presumably the doubling stands for the symmetric difference under
     * the assumption that all trees have the same number of clades - TODO confirm.
     *
     * @param currentElt clades of the candidate topology
     * @param cladeCounts number of trees containing each clade
     * @param nTrees total number of trees
     * @return twice the sum over the candidate's clades of {@code nTrees - count(clade)}
     * @throws RuntimeException if {@code nTrees} is 1
     */
    private static int sumOfSymmLosses(Set<Set<Taxon>> currentElt, Counter<Set<Taxon>> cladeCounts, int nTrees) {
        if (nTrees == 1) {
            throw new RuntimeException("sumOfSymmLosses requires more than one tree (nTrees == 1)");
        }
        int nCladesPerTree = currentElt.size();
        // Start from the case where none of the clades is shared, then subtract the
        // occurrences found in the counts (minus one occurrence per clade).
        int result = (nTrees - 1) * nCladesPerTree;
        for (Set<Taxon> clade : currentElt) {
            result -= (int)cladeCounts.getCount(clade) - 1;
        }
        return 2 * result;
    }

    /**
     * Extracts a numeric index from a file name by concatenating all of its ASCII digits,
     * in order, and parsing the result as an {@code int}.
     *
     * @param f file whose name is inspected
     * @return the parsed index, or -1 when anything goes wrong (for example no digit in
     *         the name, or a digit string that does not fit in an {@code int})
     */
    public static int getIndex(File f) {
        int index;
        try {
            String name = f.getName();
            StringBuilder digits = new StringBuilder();
            for (int pos = 0; pos < name.length(); ++pos) {
                char c = name.charAt(pos);
                if (c >= '0' && c <= '9') {
                    digits.append(c);
                }
            }
            index = Integer.parseInt(digits.toString());
        }
        catch (Exception e) {
            index = -1;
        }
        return index;
    }

    /**
     * Builds clade sets out of interned taxa and interned clades, using one interner for
     * taxa and one for clades.
     */
    private static class CladeInterner {
        private Interner<Taxon> lInterner = new Interner<Taxon>();
        private Interner<Set<Taxon>> sInterner = new Interner<Set<Taxon>>();

        private CladeInterner() {
        }

        /**
         * Returns a new set holding the interned version of every clade of {@code clades}.
         */
        public Set<Set<Taxon>> internClades(Set<Set<Taxon>> clades) {
            HashSet<Set<Taxon>> interned = new HashSet<Set<Taxon>>(clades.size());
            for (Set<Taxon> current : clades) {
                Set<Taxon> canonical = this.internClade(current);
                interned.add(canonical);
            }
            return interned;
        }

        /**
         * Returns {@code clade} itself when the clade interner reports it as canonical;
         * otherwise interns a copy of it made of interned taxa.
         */
        private Set<Taxon> internClade(Set<Taxon> clade) {
            if (!this.sInterner.isCanonical(clade)) {
                HashSet<Taxon> copy = new HashSet<Taxon>(clade.size());
                for (Taxon member : clade) {
                    copy.add(this.lInterner.intern(member));
                }
                return this.sInterner.intern(copy);
            }
            return clade;
        }
    }
}

