/*
 * Decompiled with CFR 0.152.
 */
package ev;

import ev.hmm.HetPairHMM;
import ev.par.ExponentialFamily;
import ev.par.FeatureExtractor;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.Parallelizer;
import fig.exec.Execution;
import goblin.CognateId;
import goblin.CognateSet;
import goblin.Taxon;
import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import ma.GreedyDecoder;
import ma.MSAPoset;
import ma.SequenceType;
import nuts.io.IO;
import nuts.lang.StringUtils;
import nuts.maxent.MaxentClassifier;
import nuts.util.CollUtils;
import nuts.util.Counter;
import pepper.Encodings;

/**
 * Fits pairwise alignment parameters on the words of a restored {@link CognateSet}. After each
 * iteration, it writes a posterior-based multiple alignment for every concept.
 *
 * <p>Each of the {@link #emIters} iterations (logged as "EM iter"):
 * <ol>
 *   <li>Builds a {@link HetPairHMM} for every word pair returned by {@link #goodData}. Each HMM's
 *       {@code logSumProduct()} is added to the iteration's log-likelihood, and the HMM is passed
 *       to {@code ExponentialFamily.addSufficientStatistics}.</li>
 *   <li>Calls {@code updateParameters()} and saves the weights as {@code reest-<iter>.weights}.</li>
 *   <li>Writes one {@link MSAPoset} per concept into the {@code alignments-<iter>} directory.
 *       This step uses a {@link Parallelizer} sized by {@link #nThreads}.</li>
 * </ol>
 */
public class PairAlignCognateSet
implements Runnable {
    /** Number of parameter re-estimation iterations to run. */
    @Option
    public int emIters = 10;
    /**
     * Path passed to {@link CognateSet#restoreCognateSet}.
     * NOTE(review): the default is a developer-specific absolute path; override it via the option.
     */
    @Option
    public String pathToCognates = "/Users/bouchard/w/pepper/state/execs/538.exec/snapshot0000.CognateSet";
    /** Thread count given to the {@link Parallelizer} that writes per-concept alignments. */
    @Option
    public int nThreads = 2;
    /** Learner options passed to {@code ExponentialFamily.createExpfam}. */
    public static MaxentClassifier.MaxentOptions<Object> learningOptions = new MaxentClassifier.MaxentOptions<Object>();
    /** Exponential-family options; {@link #run()} overwrites {@code encodingType} with {@code SequenceType.REAL}. */
    public static ExponentialFamily.ExponentialFamilyOptions expFamOptions = new ExponentialFamily.ExponentialFamilyOptions();
    /** Feature options, handed to {@code IO.run} under the name "feat" in {@link #main}. */
    public static FeatureExtractor.FeatureOptions featureOptions = new FeatureExtractor.FeatureOptions();

    /**
     * Runs the job through {@code IO.run}.
     *
     * <p>{@link #featureOptions} is passed under "feat" and {@link Encodings} under "enc".
     * NOTE(review): learningOptions and expFamOptions are not passed to IO.run. Confirm whether
     * they are meant to be configurable from the command line.
     *
     * @param args command-line arguments forwarded to {@code IO.run}
     */
    public static void main(String[] args) {
        IO.run(args, new PairAlignCognateSet(), "feat", featureOptions, "enc", Encodings.class);
    }

    /**
     * Restores the cognate set, builds the model, and runs {@link #emIters} iterations.
     * Each iteration re-estimates the parameters and then writes per-concept alignments.
     *
     * @throws RuntimeException if the cognate set cannot be restored, an HMM yields a non-finite
     *     {@code logSumProduct()}, or an output directory cannot be created
     */
    @Override
    public void run() {
        CognateSet cs = this.restoreCognates();
        List<Pair<String, String>> data = goodData(cs);
        final CognateDataset cogData = new CognateDataset(cs);
        // Mutates shared static options before the model is created.
        PairAlignCognateSet.expFamOptions.encodingType = SequenceType.REAL;
        final ExponentialFamily expFam = ExponentialFamily.createExpfam(learningOptions, expFamOptions, featureOptions, null);
        for (int curEmIter = 0; curEmIter < this.emIters; ++curEmIter) {
            LogInfo.track("EM iter " + curEmIter + "/" + this.emIters);
            try {
                double logLikelihood = accumulateSufficientStatistics(data, expFam);
                expFam.updateParameters();
                expFam.saveWeightsInExec("reest-" + curEmIter + ".weights");
                LogInfo.logsForce("Loglikelihood at iteration " + curEmIter + ":" + logLikelihood);
                this.writeAlignments(cogData, expFam, curEmIter);
            } finally {
                // Keep the LogInfo track stack balanced even if this iteration throws.
                LogInfo.end_track();
            }
        }
    }

    /**
     * Loads the cognate set from {@link #pathToCognates}.
     *
     * @return the restored cognate set
     * @throws RuntimeException wrapping the original failure, with the offending path in the message
     */
    private CognateSet restoreCognates() {
        try {
            return CognateSet.restoreCognateSet(this.pathToCognates);
        }
        catch (Exception e) {
            throw new RuntimeException("Could not restore cognate set from " + this.pathToCognates, e);
        }
    }

    /**
     * Adds one HMM's sufficient statistics to {@code expFam} for every word pair.
     *
     * @param data word pairs to score
     * @param expFam model that supplies the HMMs and receives the statistics
     * @return the sum of {@code logSumProduct()} over all pairs
     */
    private static double accumulateSufficientStatistics(List<Pair<String, String>> data, ExponentialFamily expFam) {
        double logLikelihood = 0.0;
        int index = 0;
        for (Pair<String, String> words : data) {
            // goodData() drops the languages, so each pair is labelled with placeholder taxa.
            Taxon topL = new Taxon("l1");
            Taxon botL = new Taxon("l2");
            String top = words.getFirst();
            String bot = words.getSecond();
            HetPairHMM hmm = expFam.getHMM(top, bot, topL, botL);
            logLikelihood += checkedLogSumProduct(hmm, top, bot);
            expFam.addSufficientStatistics(hmm, topL, botL);
            LogInfo.logs("Extracting suff stats [" + index++ + "/" + data.size() + "]");
        }
        return logLikelihood;
    }

    /**
     * Writes the posterior alignment of every concept into {@code alignments-<curEmIter>}.
     * Each concept's file name comes from {@link #concept2file}.
     *
     * @throws RuntimeException if the output directory does not exist and cannot be created
     */
    private void writeAlignments(final CognateDataset cogData, final ExponentialFamily expFam, int curEmIter) {
        final File dir = new File(Execution.getFile("alignments-" + curEmIter));
        if (!dir.isDirectory() && !dir.mkdir()) {
            throw new RuntimeException("Could not create output directory " + dir);
        }
        Parallelizer<Concept> parallelizer = new Parallelizer<Concept>(this.nThreads);
        parallelizer.setPrimaryThread();
        parallelizer.process(CollUtils.list(cogData.concepts()), new Parallelizer.Processor<Concept>(){

            @Override
            public void process(Concept c, int index, int total, boolean log) {
                MSAPoset current = PairAlignCognateSet.posteriorAlignment(cogData.words(c), expFam);
                MSAPoset.save(current, new File(dir, PairAlignCognateSet.concept2file(c)));
                if (log) {
                    LogInfo.logs("Global align [" + index + "/" + total + "]");
                }
            }
        });
    }

    /**
     * Returns {@code hmm.logSumProduct()}, failing fast when it is NaN or infinite.
     *
     * @throws RuntimeException naming the value and the word pair when the result is not finite
     */
    private static double checkedLogSumProduct(HetPairHMM hmm, String top, String bot) {
        double value = hmm.logSumProduct();
        if (Double.isNaN(value) || Double.isInfinite(value)) {
            throw new RuntimeException("Non-finite logSumProduct (" + value + ") for pair (" + top + ", " + bot + ")");
        }
        return value;
    }

    /**
     * Maps a concept to a file name by replacing every "/" with "___".
     * NOTE(review): a concept that already contains "___" will not round-trip through
     * {@link #file2Concept}.
     */
    public static String concept2file(Concept c) {
        return c.toString().replace("/", "___");
    }

    /** Inverse of {@link #concept2file}: replaces every "___" in the file name with "/". */
    public static Concept file2Concept(File f) {
        return new Concept(f.getName().replace("___", "/"));
    }

    /**
     * Builds a multiple alignment from pairwise alignment posteriors.
     *
     * <p>For every unordered pair of languages in {@code data}, an HMM is built. Each candidate
     * edge (topPos, botPos) is weighted by {@code exp(logPosteriorAlignment)}. Edges are then
     * offered to {@link MSAPoset#tryAdding} in the counter's iteration order.
     * NOTE(review): this presumably relies on {@link Counter} iterating by decreasing count
     * (greedy decoding); confirm.
     *
     * @param data word per language
     * @param parameters model supplying the pairwise HMMs
     * @return the alignment poset over {@code data}
     * @throws RuntimeException if any pairwise HMM has a non-finite {@code logSumProduct()}
     */
    public static MSAPoset posteriorAlignment(Map<Taxon, String> data, ExponentialFamily parameters) {
        Counter<GreedyDecoder.Edge> edgePosteriors = new Counter<GreedyDecoder.Edge>();
        List<Taxon> orderedLang = CollUtils.list(data.keySet());
        for (int l1 = 0; l1 < orderedLang.size(); ++l1) {
            for (int l2 = l1 + 1; l2 < orderedLang.size(); ++l2) {
                Taxon topL = orderedLang.get(l1);
                Taxon botL = orderedLang.get(l2);
                String top = data.get(topL);
                String bot = data.get(botL);
                HetPairHMM hmm = parameters.getHMM(top, bot, topL, botL);
                checkedLogSumProduct(hmm, top, bot);
                for (int topPos = 0; topPos < top.length(); ++topPos) {
                    for (int botPos = 0; botPos < bot.length(); ++botPos) {
                        GreedyDecoder.Edge e = new GreedyDecoder.Edge(topPos, botPos, topL, botL);
                        edgePosteriors.setCount(e, Math.exp(hmm.logPosteriorAlignment(topPos, botPos)));
                    }
                }
            }
        }
        MSAPoset result = new MSAPoset(data);
        for (GreedyDecoder.Edge e : edgePosteriors) {
            result.tryAdding(e);
        }
        return result;
    }

    /**
     * Collects the training pairs for the model.
     *
     * @return every unordered pair (i &lt; j) of words that share a cognate id in {@code cs},
     *     in cognate-id order
     */
    private static List<Pair<String, String>> goodData(CognateSet cs) {
        List<Pair<String, String>> result = CollUtils.list();
        for (CognateId id : cs.getCognateIds()) {
            List<String> words = cs.allWords(id);
            for (int i = 0; i < words.size(); ++i) {
                for (int j = i + 1; j < words.size(); ++j) {
                    result.add(Pair.makePair(words.get(i), words.get(j)));
                }
            }
        }
        return result;
    }

    /** Immutable, serializable wrapper for a concept name; equality is by name. */
    public static class Concept
    implements Serializable {
        private static final long serialVersionUID = 1L;
        private final String str;

        public Concept(String str) {
            this.str = str;
        }

        @Override
        public String toString() {
            return this.str;
        }

        @Override
        public int hashCode() {
            // Equals 31 * 1 + h. Keep this value stable: hash-based collections of concepts
            // (e.g. CognateDataset.concepts()) iterate in an order that depends on it.
            return 31 + (this.str == null ? 0 : this.str.hashCode());
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            Concept other = (Concept)obj;
            return this.str == null ? other.str == null : this.str.equals(other.str);
        }
    }

    /**
     * Index from concept to, per language, the (cognate group, word) pair.
     *
     * <p>It is built from the string form of each {@link CognateId}, which is split into a concept
     * name and a parenthesised numeric group.
     */
    public static class CognateDataset
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** A parenthesised digit run, e.g. "(3)"; removed from an id to get the concept name. */
        private static final Pattern GROUP_SUFFIX = Pattern.compile("\\([0-9]*\\)");
        private final Map<Concept, Map<Taxon, Pair<Integer, String>>> data = CollUtils.map();

        /**
         * Indexes every observed (cognate id, language) word of {@code cs}.
         *
         * @throws NumberFormatException if an id's string form has no parenthesised group number
         */
        public CognateDataset(CognateSet cs) {
            for (CognateId id : cs.getCognateIds()) {
                String cAndGp = id.toString();
                Concept c = new Concept(GROUP_SUFFIX.matcher(cAndGp).replaceAll(""));
                int cognateGp = Integer.parseInt(StringUtils.selectFirstRegex("\\(([0-9]*)\\)", cAndGp));
                for (Taxon lang : cs.getObs(id).observedLanguages()) {
                    // Get-or-create inside the loop so concepts with no observed language are not added.
                    Map<Taxon, Pair<Integer, String>> byLang = this.data.get(c);
                    if (byLang == null) {
                        byLang = new HashMap<Taxon, Pair<Integer, String>>();
                        this.data.put(c, byLang);
                    }
                    byLang.put(lang, Pair.makePair(cognateGp, cs.getWord(id, lang)));
                }
            }
        }

        /** Returns a fresh map from each language covering {@code c} to its word. */
        public Map<Taxon, String> words(Concept c) {
            HashMap<Taxon, String> result = CollUtils.map();
            for (Taxon lang : this.coverage(c)) {
                result.put(lang, this.getWord(c, lang));
            }
            return result;
        }

        /** Returns the live key set of indexed concepts. */
        public Set<Concept> concepts() {
            return this.data.keySet();
        }

        /** Returns the live set of languages with a word for {@code c}; throws NPE if {@code c} is unknown. */
        public Set<Taxon> coverage(Concept c) {
            return this.data.get(c).keySet();
        }

        /** Returns the cognate group of {@code lang}'s word for {@code c}; throws NPE if absent. */
        public int getCognateGroup(Concept c, Taxon lang) {
            return this.data.get(c).get(lang).getFirst();
        }

        /** Returns {@code lang}'s word for {@code c}; throws NPE if the concept or language is absent. */
        public String getWord(Concept c, Taxon lang) {
            return this.data.get(c).get(lang).getSecond();
        }

        /** Returns a fresh map from each language covering {@code c} to its cognate group. */
        public Map<Taxon, Integer> cognateGroups(Concept c) {
            HashMap<Taxon, Integer> result = CollUtils.map();
            for (Taxon lang : this.coverage(c)) {
                result.put(lang, this.getCognateGroup(c, lang));
            }
            return result;
        }
    }
}

