/*
 * Decompiled with CFR 0.152.
 */
package ev.ex;

import ev.to.Clust;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.UnorderedPair;
import fig.exec.Execution;
import goblin.CognateId;
import goblin.DerivationTree;
import goblin.HLFeatureExtractor;
import goblin.HLParams;
import goblin.HLParamsUpdater;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.AffineGapAlignmentSampler;
import ma.BalibaseCorpus;
import ma.MultiAlignment;
import nuts.io.IO;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.util.Counter;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pepper.Encodings;

public class PairAlign {
    /** Corpus options; handed to {@code IO.run} next to the name "bali" and used to construct the {@code BalibaseCorpus}. */
    public static BalibaseCorpus.BalibaseCorpusOptions options = new BalibaseCorpus.BalibaseCorpusOptions();
    /** Feature extractor; handed to {@code IO.run} next to the name "extract" and passed to the parameter updater and the initial parameters. */
    public static HLFeatureExtractor extractor = new HLFeatureExtractor();
    /** Maxent options; handed to {@code IO.run} next to the name "maxent" and passed to the {@code HLParamsUpdater}. */
    public static MaxentClassifier.MaxentOptions maxentOptions = new MaxentClassifier.MaxentOptions();

    /**
     * Program entry point: forwards the command-line arguments to {@code IO.run} together
     * with a fresh {@link PairAlignMain} and the three named option holders of this class.
     *
     * @param args command-line arguments, forwarded unchanged
     */
    public static void main(String[] args) {
        PairAlignMain job = new PairAlignMain();
        IO.run(args, job, "bali", options, "extract", extractor, "maxent", maxentOptions);
    }

    public static class PairAlignMain
    implements Runnable {
        /** Path given to {@code HLParamsUpdater.restoreCounter} to load the initial weights. */
        @Option
        public String initWeightsPath = "weights/wDna";
        /** Flag forwarded to {@code HLParams.createHLParamsFromWeights} when building the initial parameters. */
        @Option
        public boolean checkInitFeatures = true;
        /** Random source passed to the alignment sampler; seeded with 1 by default. */
        @Option
        public Random rand = new Random(1L);
        /** Number of EM iterations performed by {@link #run()}. */
        @Option
        public int nIters = 10;
        /** Number of derivations sampled per sequence pair in each EM iteration. */
        @Option
        public int nSamples = 1000;
        /** Upper limit on the total number of pairwise alignments extracted from the corpus. */
        @Option
        public int maxNAlign = Integer.MAX_VALUE;
        /** Pairs whose distance is below this bound are skipped; negative infinity means no lower bound. */
        @Option
        public double distLowerBound = Double.NEGATIVE_INFINITY;
        /** Pairs whose distance is above this bound are skipped; positive infinity means no upper bound. */
        @Option
        public double distUpperBound = Double.POSITIVE_INFINITY;
        /** Taxa of every extracted pair; filled during pair extraction and later passed to the updater and initial parameters. */
        private Set<Taxon> langs = new HashSet<Taxon>();

        /**
         * Runs the whole pairwise-alignment experiment.
         *
         * <ol>
         *   <li>Builds the corpus from {@code options} and extracts up to {@code maxNAlign}
         *       repaired pairwise alignments.</li>
         *   <li>Scores a Clustal baseline against each gold pair. A failure here is only
         *       logged (with its cause) and the baseline is skipped.</li>
         *   <li>Restores the initial weights and performs {@code nIters} EM iterations. For each
         *       pair, an iteration samples {@code nSamples} derivations, accumulates the averaged
         *       sufficient statistics, and scores two decodings against the gold alignment: the
         *       mode of the current model ("Max derivation") and the mode of a context-sensitive
         *       sampler built from the sampled alignment frequencies ("Vari derivation"). The
         *       parameters are then re-estimated and the weights saved.</li>
         * </ol>
         *
         * @throws RuntimeException wrapping any exception raised outside the Clustal baseline
         */
        @Override
        public void run() {
            try {
                BalibaseCorpus bc = new BalibaseCorpus(options);
                List<MultiAlignment> pairwise = this.extractPairwiseAlignments(bc, this.maxNAlign, true);
                LogInfo.track("Baseline: clustalw");
                try {
                    SummaryStatistics clustSP = new SummaryStatistics();
                    for (MultiAlignment gold : pairwise) {
                        MultiAlignment clustGuess = Clust.clustal(gold);
                        LogInfo.logsForce("Clust:\n" + clustGuess);
                        LogInfo.logsForce("Truth:\n" + gold);
                        double curSP = gold.sumOfPairsScore(clustGuess);
                        clustSP.addValue(curSP);
                        LogInfo.logsForce("SP: " + curSP);
                    }
                    LogInfo.logsForce("Clustal Mean SP: " + clustSP.getMean());
                }
                catch (Exception e) {
                    // The baseline is best-effort, but report why it failed instead of dropping the cause.
                    LogInfo.warning("Clustal failed.. skipping: " + e);
                }
                LogInfo.end_track();
                Encodings enc = Encodings.proteinEncodings(true);
                Counter initWeights = HLParamsUpdater.restoreCounter(this.initWeightsPath);
                HLParamsUpdater updater = new HLParamsUpdater(enc, this.langs, extractor, maxentOptions, initWeights, 1);
                HLParams cur = HLParams.createHLParamsFromWeights(enc, this.langs, extractor, initWeights, this.checkInitFeatures, 1);
                for (int curEmIter = 0; curEmIter < this.nIters; ++curEmIter) {
                    LogInfo.track((Object)("EM iteration " + curEmIter + "/" + this.nIters), true);
                    SummaryStatistics maxDecSP = new SummaryStatistics();
                    SummaryStatistics variSP = new SummaryStatistics();
                    // Sufficient statistics summed over all pairs for this iteration's re-estimation.
                    Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> globalSuffStat = new Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>();
                    int msaIndex = 0;
                    for (MultiAlignment gold : pairwise) {
                        LogInfo.track((Object)("Sequences " + gold.nodes() + " (" + msaIndex++ + "/" + pairwise.size() + ")"), true);
                        Taxon topL = gold.nodes().get(0);
                        Taxon botL = gold.nodes().get(1);
                        String top = gold.getSequences().get(topL);
                        String bot = gold.getSequences().get(botL);
                        AffineGapAlignmentSampler sampler = AffineGapAlignmentSampler.createHLAlignmentSampler(top, bot, cur.getBranchParams().get(botL));
                        Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> current = new Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>();
                        // alignmentFreq[ti][bi] counts the samples that link top position ti with bottom position bi.
                        double[][] alignmentFreq = new double[top.length()][bot.length()];
                        LogInfo.track("Sampling");
                        for (int s = 0; s < this.nSamples; ++s) {
                            LogInfo.logs("" + s + "/" + this.nSamples);
                            DerivationTree.Derivation d = sampler.sample(this.rand);
                            this.addToFreq(d, alignmentFreq);
                            HLParams.addBranchSuffStats(current, d, botL, enc, new DerivationTree.Window(0, top.length()), new DerivationTree.Window(0, bot.length()));
                        }
                        LogInfo.end_track();
                        // Average this pair's statistics over the number of samples before pooling them.
                        for (LabeledInstance key : current.keySet()) {
                            current.setCount(key, current.getCount(key) / (double)this.nSamples);
                        }
                        globalSuffStat.incrementAll(current);
                        // Decoding 1: mode of a sampler built from the current parameters.
                        AffineGapAlignmentSampler maxDec = AffineGapAlignmentSampler.createHLAlignmentSampler(top, bot, cur.getBranchParams().get(botL));
                        DerivationTree.Derivation guess = maxDec.mode();
                        MultiAlignment guessMSA = MultiAlignment.inducedMultiAlignment(topL, botL, guess);
                        LogInfo.logsForce("Truth:\n" + gold);
                        LogInfo.logsForce("Max derivation:\n" + guessMSA);
                        double curSP = gold.sumOfPairsScore(guessMSA);
                        maxDecSP.addValue(curSP);
                        LogInfo.logsForce("Max derivation SP: " + curSP);
                        // Replace each count by exp(count / nSamples) before handing the matrix
                        // to the context-sensitive sampler.
                        for (int ti = 0; ti < top.length(); ++ti) {
                            for (int bi = 0; bi < bot.length(); ++bi) {
                                alignmentFreq[ti][bi] = Math.exp(alignmentFreq[ti][bi] / (double)this.nSamples);
                            }
                        }
                        // Decoding 2: mode of the sampler built from the transformed sample frequencies.
                        AffineGapAlignmentSampler varDec = AffineGapAlignmentSampler.createContextSensitiveAlignmentSampler(top, bot, alignmentFreq);
                        guess = varDec.mode();
                        guessMSA = MultiAlignment.inducedMultiAlignment(topL, botL, guess);
                        LogInfo.logsForce("Vari derivation:\n" + guessMSA);
                        curSP = gold.sumOfPairsScore(guessMSA);
                        variSP.addValue(curSP);
                        LogInfo.logsForce("Vari derivation SP: " + curSP);
                        LogInfo.end_track();
                    }
                    LogInfo.logsForce("Mean max dec SP: " + maxDecSP.getMean());
                    LogInfo.logsForce("Mean vari dec SP: " + variSP.getMean());
                    // Re-estimate the parameters from the pooled statistics and persist the weights.
                    cur = updater.update(globalSuffStat);
                    updater.saveWeightsInExec("reest-proposal", curEmIter);
                    LogInfo.end_track();
                }
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Adds one sampled derivation to the alignment-count matrix: for every position of the
         * derivation's current word that has an ancestor, increments
         * {@code alignmentFreq[ancestor][position]} by one.
         *
         * @param d             sampled derivation to record
         * @param alignmentFreq count matrix, updated in place
         */
        private void addToFreq(DerivationTree.Derivation d, double[][] alignmentFreq) {
            int pos = 0;
            while (pos < d.getCurrentWord().length()) {
                if (d.hasAncestor(pos)) {
                    alignmentFreq[d.ancestor(pos)][pos] += 1.0;
                }
                ++pos;
            }
        }

        /**
         * Extracts pairwise alignments without writing any output and returns them as a
         * single flat list instead of one list per source alignment.
         *
         * @param bc   corpus to extract from
         * @param maxN maximum total number of pairs to extract
         * @param fix  whether each source alignment is passed through {@code fixMSA} first
         * @return all extracted pairs, in extraction order
         */
        public List<MultiAlignment> extractPairwiseAlignments(BalibaseCorpus bc, int maxN, boolean fix) {
            List<List<MultiAlignment>> perMsa = this.extractPairwiseAlignments(bc, maxN, false, fix);
            return PairAlignMain.flattenList(perMsa);
        }

        /**
         * Concatenates the inner lists of {@code nestedList} into one new list, preserving
         * both the order of the inner lists and the order of their elements.
         *
         * @param nestedList list of lists to concatenate; the input is not modified
         * @param <T>        element type
         * @return a new list holding every element of every inner list
         */
        public static <T> List<T> flattenList(List<List<T>> nestedList) {
            List<T> flat = new ArrayList<T>();
            for (List<T> inner : nestedList) {
                for (T element : inner) {
                    flat.add(element);
                }
            }
            return flat;
        }

        /**
         * Splits every multiple alignment of the corpus into pairwise alignments.
         *
         * <p>For each alignment, every unordered pair of its nodes that passes
         * {@code distanceValid} is restricted to a two-sequence alignment. Extraction stops as
         * soon as {@code maxN} pairs have been collected over the whole corpus. The taxa of every
         * extracted pair are added to {@code langs} as a side effect.
         *
         * @param bc     corpus providing the multiple alignments and the distances
         * @param maxN   maximum total number of pairs to extract
         * @param output when true, creates the "pairwiseCorpus" directory, saves each pair there
         *               as an MSF file and logs progress
         * @param fix    whether each alignment is passed through {@code fixMSA} before splitting
         * @return one list of pairs per alignment visited; when the {@code maxN} cap is hit, the
         *         list of the alignment being processed is included and may be partial or empty
         */
        public List<List<MultiAlignment>> extractPairwiseAlignments(BalibaseCorpus bc, int maxN, boolean output, boolean fix) {
            File outDir = output ? new File(Execution.getFile("pairwiseCorpus")) : null;
            if (output) {
                outDir.mkdir();
            }
            List<List<MultiAlignment>> finalresult = new ArrayList<List<MultiAlignment>>();
            int total = 0;
            for (CognateId id : bc.getMultiAlignments().keySet()) {
                List<MultiAlignment> result = new ArrayList<MultiAlignment>();
                finalresult.add(result);
                int pairIndex = 0;
                MultiAlignment full = bc.getMultiAlignments().get(id);
                if (fix) {
                    full = PairAlignMain.fixMSA(full);
                }
                if (output) {
                    LogInfo.logs("Current MSA:\n" + full);
                    LogInfo.track((Object)"Extracting pairwise aligns", true);
                }
                for (int i = 0; i < full.nodes().size(); ++i) {
                    Taxon l1 = full.nodes().get(i);
                    for (int j = i + 1; j < full.nodes().size(); ++j) {
                        Taxon l2 = full.nodes().get(j);
                        if (!this.distanceValid(bc.getDistances(), l1, l2)) {
                            continue;
                        }
                        if (total >= maxN) {
                            if (output) {
                                // Close the track opened for this alignment so the early exit
                                // does not leave the log nesting unbalanced.
                                LogInfo.end_track();
                            }
                            return finalresult;
                        }
                        ++total;
                        this.langs.add(l1);
                        this.langs.add(l2);
                        MultiAlignment curPair = full.restrict(Arrays.asList(l1, l2));
                        result.add(curPair);
                        if (output) {
                            curPair.saveToMSF(new File(outDir, "" + id + "-" + pairIndex++ + ".msf"));
                            LogInfo.logs(curPair);
                        }
                    }
                }
                if (output) {
                    LogInfo.end_track();
                }
            }
            return finalresult;
        }

        /**
         * Tells whether the pair {@code (l1, l2)} lies within the configured distance bounds.
         *
         * <p>When both bounds are infinite (the defaults) every pair is accepted and the
         * distance map is not consulted at all.
         *
         * @param distances distances keyed by unordered taxon pair
         * @param l1        first taxon of the pair
         * @param l2        second taxon of the pair
         * @return false if the distance is below {@code distLowerBound} or above
         *         {@code distUpperBound}, true otherwise
         * @throws NullPointerException if bounds are configured and the map has no distance
         *                              for the pair
         */
        private boolean distanceValid(Map<UnorderedPair<Taxon, Taxon>, Double> distances, Taxon l1, Taxon l2) {
            if (this.distLowerBound == Double.NEGATIVE_INFINITY && this.distUpperBound == Double.POSITIVE_INFINITY) {
                return true;
            }
            Double dist = distances.get(new UnorderedPair<Taxon, Taxon>(l1, l2));
            if (dist == null) {
                // Unboxing a missing entry would throw the same exception type with no message;
                // fail explicitly so the offending pair is identifiable.
                throw new NullPointerException("No distance available for taxon pair (" + l1 + ", " + l2 + ")");
            }
            if (dist < this.distLowerBound) {
                return false;
            }
            return !(dist > this.distUpperBound);
        }

        /**
         * Repairs an alignment using the encodings returned by
         * {@code Encodings.proteinEncodings(true)}.
         *
         * @param msa alignment to repair
         * @return the result of {@link #fixMSA(Encodings, MultiAlignment)} for those encodings
         */
        public static MultiAlignment fixMSA(MultiAlignment msa) {
            return PairAlignMain.fixMSA(Encodings.proteinEncodings(true), msa);
        }

        /**
         * Returns a copy of {@code msa} in which every character that the encodings do not
         * recognise ({@code char2PhoneId} yields -1) is replaced by the character of phone id 0.
         * A warning is logged for each replaced character. The method works on
         * {@code msa.copy()} and returns that copy.
         *
         * @param enc encodings used to recognise characters and to supply the replacement
         * @param msa alignment to repair
         * @return the repaired copy
         */
        public static MultiAlignment fixMSA(Encodings enc, MultiAlignment msa) {
            MultiAlignment fixed = msa.copy();
            // Snapshot the keys first, since the sequences are rewritten inside the loop.
            List<Taxon> taxa = new ArrayList<Taxon>(fixed.getSequences().keySet());
            for (Taxon taxon : taxa) {
                StringBuilder buffer = new StringBuilder();
                String original = fixed.getSequences().get(taxon);
                for (int pos = 0; pos < original.length(); ++pos) {
                    char symbol = original.charAt(pos);
                    if (enc.char2PhoneId(symbol) != -1) {
                        buffer.append(symbol);
                    } else {
                        LogInfo.warning("Warning: encountered an unk base (" + symbol + ").  Hack: replacing it arbitrarily");
                        buffer.append(enc.phoneId2Char(0));
                    }
                }
                fixed.changeString(taxon, buffer.toString());
            }
            return fixed;
        }
    }
}

