/*
 * Decompiled with CFR 0.152.
 */
package ev;

import ev.PairAlignCognateSet;
import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.Parallelizer;
import fig.basic.UnorderedPair;
import fig.exec.Execution;
import goblin.CognateId;
import goblin.CognateSet;
import goblin.DerivationTree;
import goblin.HLFeatureExtractor;
import goblin.HLParams;
import goblin.HLParamsUpdater;
import goblin.Taxon;
import java.io.File;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import ma.MSAPoset;
import nuts.io.IO;
import nuts.math.GMFct;
import nuts.math.Graph;
import nuts.math.Graphs;
import nuts.math.HashGraph;
import nuts.math.MutableGraph;
import nuts.math.TreeSumProd;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Indexer;
import pepper.Encodings;

/**
 * Greedy agglomerative cognate clustering.
 *
 * <p>For every concept of the cognate dataset, each covered leaf language starts in its own
 * block. Every pair of blocks of a concept is a candidate {@link Merger}. Each candidate is
 * scored as the log-likelihood of the merged block minus the log-likelihoods of the two separate
 * blocks. Each log-likelihood is a sum, over alignment columns, of {@code TreeSumProd.logZ()} on
 * a tree-shaped model whose factors come from the loaded {@link HLParams}.
 *
 * <p>The best-scoring candidate across all concepts is applied repeatedly. After every merge the
 * Rand index against the cognate groups of the input dataset is logged, and the current
 * partitions are serialized every {@link #saveInterval} merges.
 */
public class GreedyCognateInference
implements Runnable {
    /** Serialized {@link CognateSet} to cluster. The default is a machine-specific absolute path; override it. */
    @Option
    public String pathToCognates = "/Users/bouchard/w/pepper/state/execs/537.exec/snapshot0001.CognateSet";
    /** Directory of per-concept MSA files, read by {@link #readMSAPosets(File)}. The default is machine-specific. */
    @Option
    public String pathToMSADir = "/Users/bouchard/Documents/workspace/evolvere/state/execs/494.exec/alignments-0";
    /** Weight file from which {@link HLParams} are built. The default is machine-specific. */
    @Option
    public String pathToParams = "/Users/bouchard/w/pepper/state/execs/537.exec/reest-proposal0000.weights";
    /** Snapshot the current partitions every this many merges. A non-positive value disables snapshots. */
    @Option
    public int saveInterval = 100;
    /** Thread count passed to the {@code Parallelizer}s and to {@code HLParams} construction. */
    @Option
    public int nThreads = 1;
    /** Feature extractor shared with {@code HLParams}; also handed to {@code IO.run} under the name "Feat". */
    public static HLFeatureExtractor featureExtractor = new HLFeatureExtractor();
    /** Cognate dataset built from the restored {@link CognateSet}. */
    private PairAlignCognateSet.CognateDataset data;
    /** Union of all derivation-tree topologies (see {@link #unionTree(CognateSet)}). */
    private Arbre<Taxon> globalTopo;
    /** {@link #globalTopo} as a graph; each likelihood is computed on a sub-graph of it. */
    private Graph<Taxon> globalTopoGraph;
    /** Multiple sequence alignment of each concept. */
    private Map<PairAlignCognateSet.Concept, MSAPoset> globalAligns;
    /** Model parameters used by {@link #potentials}. */
    private HLParams params;
    /** Child -> parent map over {@link #globalTopo} (from {@code Arbre.parents}); used to orient edge factors. */
    private Map<Taxon, Taxon> parentPointers;
    /** Leaf taxa of {@link #globalTopo}; only these languages are clustered. */
    private Set<Taxon> leaves;
    /** Precomputed {@link Site} of every alignment column, per concept. */
    private Map<PairAlignCognateSet.Concept, Map<MSAPoset.Column, Site>> sites;
    private Encodings enc;
    /** Current clustering: for each concept, leaf language -> block id. */
    private Map<PairAlignCognateSet.Concept, Map<Taxon, Integer>> currentPartitions = CollUtils.map();
    /** Every live candidate merger with its log-likelihood ratio. */
    private Counter<Merger> mergers = new Counter<Merger>();
    /** The candidates of {@link #mergers}, grouped by concept. */
    private Map<PairAlignCognateSet.Concept, Set<Merger>> mergerMap = CollUtils.map();

    /** Command-line entry point; option parsing and execution are delegated to {@code IO.run}. */
    public static void main(String[] args) {
        IO.run(args, new GreedyCognateInference(), "Feat", featureExtractor, "enc", Encodings.class);
    }

    /**
     * Describes a merger for logging.
     *
     * @param m the merger to describe
     * @return its two blocks, its concept, and the concept's alignment rendered for the merged languages
     */
    public String toString(Merger m) {
        MSAPoset currentAlign = this.globalAligns.get(m.c);
        return "" + m.blocks.getFirst() + " and " + m.blocks.getSecond() + " @ " + m.c + "\n" + "Result:\n" + currentAlign.toString(m.union());
    }

    /**
     * Scores a merger as logL(merged) - logL(first block) - logL(second block).
     *
     * <p>Positive values mean the merged block has a higher likelihood than the two blocks scored
     * separately.
     */
    private double logLikelihoodRatio(Merger merger) {
        Set<Taxon> merged = merger.union();
        return this.logLikelihood(merger.c, merged) - this.logLikelihood(merger.c, (Set)merger.blocks.getFirst()) - this.logLikelihood(merger.c, (Set)merger.blocks.getSecond());
    }

    /** Languages attested for concept {@code c} in the dataset, restricted to leaves of the global topology. */
    private Set<Taxon> coverage(PairAlignCognateSet.Concept c) {
        HashSet<Taxon> result = CollUtils.set(this.data.coverage(c));
        result.retainAll(this.leaves);
        return result;
    }

    /**
     * Loads data, model and alignments, then runs the greedy merge loop to completion.
     *
     * <p>Any failure is rethrown wrapped in a {@link RuntimeException}.
     */
    @Override
    public void run() {
        try {
            this.enc = Encodings.realEncoding();
            CognateSet cs = CognateSet.restoreCognateSet(this.pathToCognates);
            this.data = new PairAlignCognateSet.CognateDataset(cs);
            this.globalTopo = GreedyCognateInference.unionTree(cs);
            this.globalTopoGraph = new HashGraph<Taxon>(Arbre.arbre2Tree(this.globalTopo));
            this.globalAligns = GreedyCognateInference.readMSAPosets(new File(this.pathToMSADir));
            this.params = this.loadParams(this.pathToParams);
            this.parentPointers = Arbre.parents(this.globalTopo);
            this.leaves = CollUtils.set(this.globalTopo.leaveContents());
            this.sites = CollUtils.map();
            for (PairAlignCognateSet.Concept c : this.data.concepts()) {
                MSAPoset align = this.globalAligns.get(c);
                if (align == null) {
                    // readMSAPosets skips unreadable files, so a concept may lack an alignment;
                    // report it clearly instead of failing with an NPE inside createSites.
                    throw new IllegalStateException("No alignment found for concept " + c + " in " + this.pathToMSADir);
                }
                this.sites.put(c, this.createSites(align));
            }
            // Every covered language starts in its own block.
            for (PairAlignCognateSet.Concept c : this.data.concepts()) {
                this.currentPartitions.put(c, this.singletons(this.coverage(c)));
            }
            LogInfo.track("Initialization");
            Parallelizer<PairAlignCognateSet.Concept> parallelizer = new Parallelizer<PairAlignCognateSet.Concept>(this.nThreads);
            parallelizer.setPrimaryThread();
            parallelizer.process(CollUtils.list(this.data.concepts()), new Parallelizer.Processor<PairAlignCognateSet.Concept>(){

                @Override
                public void process(PairAlignCognateSet.Concept c, int _i, int _n, boolean log) {
                    if (_i % 10 == 0) {
                        LogInfo.logs("Processing concept " + _i + "/" + _n);
                    }
                    // Seed one candidate merger per unordered pair of singleton blocks.
                    ArrayList<Taxon> langs = CollUtils.list(GreedyCognateInference.this.coverage(c));
                    for (int i = 0; i < langs.size(); ++i) {
                        for (int j = i + 1; j < langs.size(); ++j) {
                            GreedyCognateInference.this.add(c, Collections.singleton(langs.get(i)), Collections.singleton(langs.get(j)));
                        }
                    }
                }
            });
            LogInfo.end_track();
            this.eval();
            int i = 0;
            int nMergers = this.nMergers();
            // NOTE: there is no stopping threshold on the score. Merging continues, even when the best
            // ratio is negative, until no candidates remain, i.e. every concept is a single block.
            while (this.mergers.size() > 0) {
                LogInfo.track((Object)("Merge operation " + i++ + "/" + nMergers), true);
                Merger best = this.mergers.argMax();
                LogInfo.logs("Best merger: " + this.toString(best));
                LogInfo.logs("Best merger log likelihood ratio: " + this.mergers.max());
                this.apply(best);
                this.updateMergers(best);
                this.eval();
                // Guard against saveInterval <= 0, which would otherwise divide by zero.
                if (this.saveInterval > 0 && i % this.saveInterval == 0) {
                    this.saveState(i);
                }
                LogInfo.end_track();
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds {@link HLParams} from a saved weight counter, over the nodes of the global topology.
     *
     * @param pathToParams file read by {@code HLParamsUpdater.restoreCounter}
     */
    private HLParams loadParams(String pathToParams) throws Exception {
        Counter ws = HLParamsUpdater.restoreCounter(pathToParams);
        return HLParams.createHLParamsFromWeights(this.enc, this.globalTopoGraph.vertexSet(), featureExtractor, ws, false, this.nThreads);
    }

    /**
     * Serializes {@link #currentPartitions} to {@code Execution.getFile("iteration<i>.currentPartition")}.
     *
     * <p>The stream is always closed, even when writing fails. Any failure is rethrown as a
     * {@link RuntimeException}.
     *
     * @param i merge count used in the file name
     */
    private void saveState(int i) {
        ObjectOutputStream out = null;
        try {
            out = IOUtils.openBinOutHard(Execution.getFile("iteration" + i + ".currentPartition"));
            out.writeObject(this.currentPartitions);
            // Close on the success path so a failure while flushing is still reported.
            out.close();
            out = null;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        finally {
            if (out != null) {
                try {
                    out.close();
                }
                catch (Exception ignored) {
                    // The primary failure is already propagating; do not mask it with a close error.
                }
            }
        }
    }

    /**
     * Returns the number of merges needed to reduce every concept to one block.
     *
     * <p>The result is the sum over concepts of {@code max(0, |coverage| - 1)}. It is used only for
     * progress logging.
     */
    private int nMergers() {
        int total = 0;
        for (PairAlignCognateSet.Concept c : this.data.concepts()) {
            // A concept with no covered language needs no merge (and must not count as -1).
            total += Math.max(0, this.coverage(c).size() - 1);
        }
        return total;
    }

    /**
     * Refreshes the candidates after {@code best} has been applied.
     *
     * <p>It scores a new candidate pairing the merged block with every other block of the same
     * concept. It then drops every candidate that still refers to either of the two blocks that
     * no longer exist.
     */
    private void updateMergers(final Merger best) {
        this.remove(best);
        // Every pair of current blocks has a candidate, so the remaining candidates of best.c
        // enumerate all of its blocks; the two merged ones are then excluded.
        HashSet remainingBlocks = CollUtils.set();
        for (Merger m : this.mergerMap.get(best.c)) {
            remainingBlocks.add(m.blocks.getFirst());
            remainingBlocks.add(m.blocks.getSecond());
        }
        remainingBlocks.remove(best.blocks.getFirst());
        remainingBlocks.remove(best.blocks.getSecond());
        final Set<Taxon> union = best.union();
        Parallelizer<Set<Taxon>> parallelizer = new Parallelizer<Set<Taxon>>(this.nThreads);
        parallelizer.setPrimaryThread();
        parallelizer.process(CollUtils.list(remainingBlocks), new Parallelizer.Processor<Set<Taxon>>(){

            @Override
            public void process(Set<Taxon> rem, int _i, int _n, boolean log) {
                GreedyCognateInference.this.add(best.c, union, rem);
            }
        });
        // Iterate over a copy: remove() mutates the very set being scanned.
        for (Merger m : new ArrayList<Merger>(this.mergerMap.get(best.c))) {
            if (!m.hasBlock((Set)best.blocks.getFirst()) && !m.hasBlock((Set)best.blocks.getSecond())) continue;
            this.remove(m);
        }
    }

    /**
     * Applies a merger to {@link #currentPartitions}: every language of the second block gets the
     * block id of the first block.
     *
     * @throws IllegalStateException if the languages of the first block do not share one id
     */
    private void apply(Merger best) {
        Map<Taxon, Integer> partitions = this.currentPartitions.get(best.c);
        int newId = partitions.get(CollUtils.pick((Collection)best.blocks.getFirst()));
        for (Taxon lang : (Set<Taxon>)best.blocks.getFirst()) {
            // Integer vs int: the comparison unboxes, so it is a numeric comparison.
            if (partitions.get(lang) != newId) {
                throw new IllegalStateException("Block " + best.blocks.getFirst() + " of concept " + best.c + " is not labeled with a single id");
            }
        }
        for (Taxon lang : (Set<Taxon>)best.blocks.getSecond()) {
            partitions.put(lang, newId);
        }
    }

    /** Removes a candidate from both {@link #mergers} and {@link #mergerMap}. */
    private void remove(Merger m) {
        this.mergers.removeKey(m);
        this.mergerMap.get(m.c).remove(m);
    }

    /**
     * Scores the candidate merger of {@code s1} and {@code s2} for concept {@code c} and registers it
     * in {@link #mergers} and {@link #mergerMap}.
     *
     * <p>It is called from {@code Parallelizer} processors. The scoring runs outside the lock; only
     * the updates of the shared collections are synchronized on {@code this}.
     */
    private void add(PairAlignCognateSet.Concept c, Set<Taxon> s1, Set<Taxon> s2) {
        Merger m = new Merger(c, s1, s2);
        double logLikelihoodRatio = this.logLikelihoodRatio(m);
        synchronized (this) {
            this.mergers.setCount(m, logLikelihoodRatio);
            Set<Merger> current = this.mergerMap.get(c);
            if (current == null) {
                current = CollUtils.set();
                this.mergerMap.put(c, current);
            }
            current.add(m);
        }
    }

    /** Logs the current Rand index. */
    private void eval() {
        LogInfo.logs("Rand index: " + this.randIndex());
    }

    /**
     * Rand index of the current partitions against the dataset's cognate groups.
     *
     * <p>It is the fraction of same-concept language pairs on which both groupings agree (both
     * together or both apart). The result is NaN when no concept covers two languages.
     */
    private double randIndex() {
        double agree = 0.0;
        double disagree = 0.0;
        for (PairAlignCognateSet.Concept c : this.data.concepts()) {
            ArrayList<Taxon> langs = CollUtils.list(this.coverage(c));
            for (int i = 0; i < langs.size(); ++i) {
                Taxon li = langs.get(i);
                for (int j = i + 1; j < langs.size(); ++j) {
                    Taxon lj = langs.get(j);
                    // NOTE(review): if data.getCognateGroup returns a boxed Integer, == compares
                    // references (wrong outside the Integer cache); confirm it returns a primitive.
                    boolean sameInGold = this.data.getCognateGroup(c, li) == this.data.getCognateGroup(c, lj);
                    boolean sameInGuess = this.getCognateGroup(c, li) == this.getCognateGroup(c, lj);
                    if (sameInGold == sameInGuess) {
                        agree += 1.0;
                    }
                    else {
                        disagree += 1.0;
                    }
                }
            }
        }
        return agree / (agree + disagree);
    }

    /** Returns the block id of {@code lang} in the current partition of concept {@code c}. */
    public int getCognateGroup(PairAlignCognateSet.Concept c, Taxon lang) {
        return this.currentPartitions.get(c).get(lang);
    }

    /** Assigns each language a distinct block id 0..n-1, in iteration order. */
    private Map<Taxon, Integer> singletons(Set<Taxon> coverage) {
        HashMap<Taxon, Integer> result = CollUtils.map();
        int i = 0;
        for (Taxon lang : coverage) {
            result.put(lang, i++);
        }
        return result;
    }

    /**
     * Re-encodes every column of {@code align} as a {@link Site}.
     *
     * <p>Each distinct phone id seen in the column gets a column-local code (its index in the
     * site's indexer). Every language with a character in the column records its local code.
     */
    private Map<MSAPoset.Column, Site> createSites(MSAPoset align) {
        HashMap<MSAPoset.Column, Site> result = CollUtils.map();
        for (MSAPoset.Column col : align.columns()) {
            HashMap<Taxon, Integer> valuesAtLeaves = CollUtils.map();
            Indexer<Integer> siteIndexer = new Indexer<Integer>();
            for (Taxon lang : col.getPoints().keySet()) {
                char currentCharacter = align.charAt(col, lang);
                int globalCode = this.enc.char2PhoneId(currentCharacter);
                if (!siteIndexer.containsObject(globalCode)) {
                    siteIndexer.addToIndex((Integer[])new Integer[]{globalCode});
                }
                int localCode = siteIndexer.o2i(globalCode);
                valuesAtLeaves.put(lang, localCode);
            }
            result.put(col, new Site(valuesAtLeaves, siteIndexer));
        }
        return result;
    }

    /**
     * Returns the character preceding {@code lang}'s character in column {@code col}.
     *
     * @return {@code enc.boundChar()} when that character is at position 0 of its sequence
     */
    private char previousCharacter(MSAPoset align, MSAPoset.Column col, Taxon lang) {
        int currentPos = col.getPoints().get(lang);
        if (currentPos == 0) {
            return this.enc.boundChar();
        }
        return align.sequences().get(lang).charAt(currentPos - 1);
    }

    /**
     * Reads every file of {@code directory} as an {@link MSAPoset}, keyed by
     * {@code PairAlignCognateSet.file2Concept(file)}.
     *
     * <p>Files that fail to load are skipped with a warning that names the cause.
     */
    public static Map<PairAlignCognateSet.Concept, MSAPoset> readMSAPosets(File directory) {
        HashMap<PairAlignCognateSet.Concept, MSAPoset> result = CollUtils.map();
        for (File f : IO.ls(directory)) {
            try {
                result.put(PairAlignCognateSet.file2Concept(f), MSAPoset.restore(f));
            }
            catch (Exception e) {
                LogInfo.warning("Skipping: " + f + " (" + e + ")");
            }
        }
        return result;
    }

    /**
     * Builds one topology containing every edge of every cognate's derivation tree, with each
     * node mapped to its language.
     *
     * <p>The root starts as the root of the first tree. It is replaced by a later tree's root
     * whenever that tree contains the current root as a non-root node.
     */
    public static Arbre<Taxon> unionTree(CognateSet cs) {
        MutableGraph<Taxon> mg = new MutableGraph<Taxon>();
        Taxon root = null;
        for (CognateId id : cs.getCognateIds()) {
            Arbre<Taxon> current = cs.getTree(id).preOrderMap(new Arbre.ArbreMap<DerivationTree.DerivationNode, Taxon>(){

                @Override
                public Taxon map(Arbre<DerivationTree.DerivationNode> currentDomainNode) {
                    return currentDomainNode.getContents().getLanguage();
                }
            });
            // Decide on the root only after scanning every node, so the outcome does not depend on
            // whether nodes() yields the root before or after its descendants.
            boolean prevRootVisited = root == null;
            Taxon currentRoot = null;
            for (Arbre<Taxon> subt : current.nodes()) {
                if (subt.isRoot()) {
                    currentRoot = subt.getContents();
                    continue;
                }
                mg.addEdge(subt.getContents(), subt.getParent().getContents());
                if (subt.getContents().equals(root)) {
                    prevRootVisited = true;
                }
            }
            if (prevRootVisited && currentRoot != null) {
                root = currentRoot;
            }
        }
        return Arbre.tree2Arbre(Graphs.toTree(mg, root));
    }

    /**
     * Log-likelihood of the languages in {@code partition} forming one cognate block of
     * {@code concept}.
     *
     * <p>It is the sum, over the alignment columns relevant to the block, of
     * {@code TreeSumProd.logZ()} on the sub-tree spanning the block (see
     * {@link #relevantLanguages(Set)}).
     */
    private double logLikelihood(PairAlignCognateSet.Concept concept, Set<Taxon> partition) {
        Pair<Taxon, Set<Taxon>> p = this.relevantLanguages(partition);
        Graphs.SubGraph<Taxon> graphicalModelTopology = new Graphs.SubGraph<Taxon>(this.globalTopoGraph, p.getSecond());
        MSAPoset align = this.globalAligns.get(concept);
        Set<MSAPoset.Column> relevantColumns = align.relevantColumns(partition);
        Map<MSAPoset.Column, Site> currentSites = this.sites.get(concept);
        double logLikelihood = 0.0;
        for (MSAPoset.Column c : relevantColumns) {
            Site s = currentSites.get(c);
            // Conditioning context for the potentials: the most frequent preceding phone among
            // the block's languages in this column.
            Counter<Integer> prevCounts = new Counter<Integer>();
            for (Taxon lang : c.getPoints().keySet()) {
                if (!partition.contains(lang)) continue;
                char prevCharacter = this.previousCharacter(align, c, lang);
                int prevCharGlobalCode = this.enc.char2PhoneId(prevCharacter);
                prevCounts.incrementCount(prevCharGlobalCode, 1.0);
            }
            int prevMaj = (Integer)prevCounts.argMax();
            GMFct<Taxon> pots = this.potentials(graphicalModelTopology, s, p.getFirst(), prevMaj);
            logLikelihood += new TreeSumProd<Taxon>(pots).logZ();
        }
        return logLikelihood;
    }

    /**
     * Builds the graphical-model factors for one column.
     *
     * <p>Node states are the site's local codes plus {@link Site#gapCode()}.
     * <ul>
     *   <li>Node factors: a language with an observed value in {@code s} is clamped to it. The
     *       {@code root} node gets {@code params.rootPr(...)} for non-gap states and 1 for gap.</li>
     *   <li>Edge factors: gap-to-gap is 1. Otherwise the edge is oriented child -> parent with
     *       {@link #parentPointers}, and the child's branch parameters give {@code death} (child
     *       gap), {@code ins} (parent gap) or {@code sub}.</li>
     * </ul>
     * All factors are conditioned on {@code previousSiteMajority}.
     */
    private GMFct<Taxon> potentials(final Graph<Taxon> graph, final Site s, final Taxon root, final int previousSiteMajority) {
        final int GAP_CODE = s.gapCode();
        return new GMFct<Taxon>(){

            @Override
            public double get(Taxon curTax, Taxon parTax, int curSt, int parSt) {
                if (curSt == GAP_CODE && parSt == GAP_CODE) {
                    return 1.0;
                }
                // Edges may be queried in either orientation: swap so that curTax is parTax's child.
                Taxon parentOfCur = GreedyCognateInference.this.parentPointers.get(curTax);
                if (parentOfCur == null || !parentOfCur.equals(parTax)) {
                    Taxon tmp = curTax;
                    curTax = parTax;
                    parTax = tmp;
                    int tmpSt = curSt;
                    curSt = parSt;
                    parSt = tmpSt;
                }
                if (curSt == GAP_CODE) {
                    return GreedyCognateInference.this.params.getBranchParams().get(curTax).death(s.local2global(parSt), previousSiteMajority);
                }
                if (parSt == GAP_CODE) {
                    return GreedyCognateInference.this.params.getBranchParams().get(curTax).ins(previousSiteMajority, previousSiteMajority, s.local2global(curSt));
                }
                return GreedyCognateInference.this.params.getBranchParams().get(curTax).sub(s.local2global(parSt), previousSiteMajority, s.local2global(curSt));
            }

            @Override
            public double get(Taxon tax, int st) {
                // Observed languages are clamped to their observed local code.
                Integer observed = s.valueAtLeaves.get(tax);
                if (observed != null && st != observed) {
                    return 0.0;
                }
                if (tax.equals(root)) {
                    if (st == s.gapCode()) {
                        return 1.0;
                    }
                    return GreedyCognateInference.this.params.rootPr(tax, previousSiteMajority, s.local2global(st));
                }
                return 1.0;
            }

            @Override
            public Graph<Taxon> graph() {
                return graph;
            }

            @Override
            public int nStates(Taxon node) {
                return s.stateSize();
            }
        };
    }

    /**
     * Finds the part of {@link #globalTopo} needed to score {@code selectedLeaves}.
     *
     * @return a pair (subtree root, nodes). The subtree root is the lowest node from which every
     *         selected leaf descends (the leaf itself when only one is selected). The nodes are
     *         the selected leaves plus every ancestor up to and including that subtree root.
     */
    private Pair<Taxon, Set<Taxon>> relevantLanguages(Set<Taxon> selectedLeaves) {
        HashSet<Taxon> result = CollUtils.set(selectedLeaves);
        result.add(this.globalTopo.getContents());
        // Mark the ancestors of every selected leaf. Stop climbing at an internal node that is
        // already marked (its ancestors were marked by an earlier climb) or at the global root.
        for (Arbre<Taxon> leaf : this.globalTopo.leaves()) {
            if (!selectedLeaves.contains(leaf.getContents())) continue;
            Arbre<Taxon> node = leaf;
            // The null check also covers a single-node topology, whose only leaf is the root.
            while (node != null && !node.isRoot()) {
                Taxon lang = node.getContents();
                if (!node.isLeaf() && result.contains(lang)) break;
                result.add(lang);
                node = node.getParent();
            }
        }
        // Walk down from the global root while exactly one child is marked, unmarking the nodes
        // that lie above the lowest common ancestor of the selected leaves.
        Arbre<Taxon> current = this.globalTopo;
        Arbre<Taxon> next;
        while ((next = this.onlyMarkedDesc(current, result)) != null) {
            result.remove(current.getContents());
            current = next;
        }
        return Pair.makePair(current.getContents(), result);
    }

    /**
     * Returns the unique child of {@code current} whose contents are in {@code relCandidates}.
     *
     * @return that child, or null if {@code current} is a leaf or has two or more such children
     * @throws IllegalStateException if {@code current} has children but none of them is marked
     */
    private Arbre<Taxon> onlyMarkedDesc(Arbre<Taxon> current, Set<Taxon> relCandidates) {
        if (current.getChildren().size() == 0) {
            return null;
        }
        Arbre<Taxon> result = null;
        for (Arbre<Taxon> child : current.getChildren()) {
            if (!relCandidates.contains(child.getContents())) continue;
            if (result == null) {
                result = child;
                continue;
            }
            return null;
        }
        if (result == null) {
            throw new IllegalStateException("No marked child below " + current.getContents());
        }
        return result;
    }

    /**
     * One alignment column re-encoded with column-local state codes.
     *
     * <p>{@code siteIndexer} maps local codes to global phone ids. The extra code
     * {@link #gapCode()}, equal to the indexer's size, stands for a gap.
     */
    public static final class Site {
        /** Observed local code of each language that has a character in the column. */
        public final Map<Taxon, Integer> valueAtLeaves;
        /** Local code <-> global phone id. */
        public final Indexer<Integer> siteIndexer;

        /** Number of states: every indexed phone plus the gap. */
        public int stateSize() {
            return this.siteIndexer.size() + 1;
        }

        /** Local code reserved for a gap. */
        public int gapCode() {
            return this.siteIndexer.size();
        }

        /** Global phone id of a (non-gap) local code. */
        public int local2global(int localIndex) {
            return this.siteIndexer.i2o(localIndex);
        }

        /** Local code of a global phone id. */
        public int global2local(int globalIndex) {
            return this.siteIndexer.o2i(globalIndex);
        }

        public Site(Map<Taxon, Integer> valueAtLeaves, Indexer<Integer> siteIndexer) {
            this.valueAtLeaves = valueAtLeaves;
            this.siteIndexer = siteIndexer;
        }
    }

    /**
     * Candidate merge of two blocks of languages for one concept.
     *
     * <p>NOTE: equals/hashCode are not overridden, so mergers compare by identity in the
     * enclosing class's candidate collections. Callers only remove the very instances they
     * obtained from those collections.
     */
    public static class Merger {
        private final PairAlignCognateSet.Concept c;
        private final UnorderedPair<Set<Taxon>, Set<Taxon>> blocks;

        public Merger(PairAlignCognateSet.Concept c, Set<Taxon> set1, Set<Taxon> set2) {
            this.c = c;
            this.blocks = new UnorderedPair<Set<Taxon>, Set<Taxon>>(set1, set2);
        }

        /** Languages of both blocks together. */
        public Set<Taxon> union() {
            return CollUtils.union(this.blocks.getFirst(), this.blocks.getSecond());
        }

        /** Whether either block equals {@code block}. */
        public boolean hasBlock(Set<Taxon> block) {
            if (this.blocks.getFirst().equals(block)) {
                return true;
            }
            return this.blocks.getSecond().equals(block);
        }
    }
}

