/*
 * Decompiled with CFR 0.152.
 */
package ev;

import ev.ex.PairAlign;
import ev.hmm.HetPairHMM;
import ev.par.ExponentialFamily;
import ev.par.FeatureExtractor;
import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Parallelizer;
import fig.exec.Execution;
import goblin.DerivationTree;
import goblin.Taxon;
import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import ma.BalibaseCorpus;
import ma.MultiAlignment;
import ma.SequenceType;
import nuts.io.IO;
import nuts.maxent.MaxentClassifier;
import nuts.util.CollUtils;
import nuts.util.EasyFormat;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;

/**
 * Trains an {@link ExponentialFamily} pair-HMM model with EM on pairwise alignments extracted
 * from a {@link BalibaseCorpus}.
 *
 * <p>When {@link #k} is 1 training is unsupervised: every pair is run through the sequence-only
 * HMM ({@code expFam.getHMM}) and scored against its gold alignment. Otherwise the MSAs are split
 * into {@code k} blocks and one model per held-out block is trained through a {@link Parallelizer}
 * built with {@code k}: pairs in the held-out block use the sequence-only HMM and are scored, the
 * remaining pairs use the gold-constrained HMM ({@code expFam.getSupervisedHMM}).
 *
 * <p>Per EM iteration, the mean sum-of-pairs (SP) score over all scored pairs is logged and
 * appended to the "SP" output once every pair has been scored.
 */
public class Train
implements Runnable {
    /** Passed through to the pairwise-alignment loader ({@code PairAlignMain.distLowerBound}). */
    @Option
    public double distLowerBound = Double.NEGATIVE_INFINITY;
    /** Passed through to the pairwise-alignment loader ({@code PairAlignMain.distUpperBound}). */
    @Option
    public double distUpperBound = Double.POSITIVE_INFINITY;
    /** Number of EM iterations run for each block. */
    @Option
    public int emIters = 10;
    /** Number of blocks; 1 selects unsupervised training. */
    @Option(gloss="k=1 means unsup, o.w. k-fold supervised")
    public int k = 2;
    /** Maximum number of pairs kept per MSA (after shuffling); MAX_VALUE keeps all. */
    @Option
    public int maxNPairs = Integer.MAX_VALUE;
    /** Random source used to shuffle each MSA's pairs before truncating to {@link #maxNPairs}. */
    @Option
    public Random pairShuffleRandom = new Random(1L);
    // Option groups registered with IO.run in main under the prefixes "efo", "featop" and "lo".
    public static ExponentialFamily.ExponentialFamilyOptions expFamOptions = new ExponentialFamily.ExponentialFamilyOptions();
    public static FeatureExtractor.FeatureOptions featureOptions = new FeatureExtractor.FeatureOptions();
    public static MaxentClassifier.MaxentOptions<Object> learningOptions = new MaxentClassifier.MaxentOptions<Object>();

    /**
     * Command-line entry point. Delegates to {@code IO.run} with this class and its option groups
     * (presumably parses {@code args} into them and then calls {@link #run()} — confirm in IO).
     */
    public static void main(String[] args) {
        IO.run(args, new Train(), "bali", PairAlign.options, "efo", expFamOptions, "featop", featureOptions, "lo", learningOptions);
    }

    /**
     * Loads the pairwise alignments, writes the training partitions, and runs EM for every block.
     */
    @Override
    public void run() {
        PairAlign.PairAlignMain dataLoader = new PairAlign.PairAlignMain();
        dataLoader.distLowerBound = this.distLowerBound;
        dataLoader.distUpperBound = this.distUpperBound;
        final BalibaseCorpus bc = new BalibaseCorpus(PairAlign.options);
        // One inner list of pairwise alignments per MSA of the corpus.
        final List<List<MultiAlignment>> data = dataLoader.extractPairwiseAlignments(bc, Integer.MAX_VALUE, false, Train.expFamOptions.encodingType == SequenceType.PROTEIN);
        if (this.maxNPairs != Integer.MAX_VALUE) {
            this.truncatePairs(data);
        }
        final int nPairwiseAligns = PairAlign.PairAlignMain.flattenList(data).size();
        final boolean unsup = this.k == 1;
        // NOTE(review): the SP bookkeeping below assumes these blocks are disjoint and cover every
        // MSA index exactly once, so each pair is scored exactly once per EM iteration.
        final List<Set<Integer>> dataPartitions = CollUtils.partition(data.size(), this.k);
        // One accumulator per EM iteration, shared by all block threads.
        final List<SummaryStatistics> stats = new ArrayList<SummaryStatistics>();
        for (int i = 0; i < this.emIters; ++i) {
            stats.add(new SummaryStatistics());
        }
        this.storeTrainingPartitions(dataPartitions, data);
        Parallelizer<Integer> parallelizer = new Parallelizer<Integer>(this.k);
        parallelizer.setPrimaryThread();
        parallelizer.process(CollUtils.ints(this.k), new Parallelizer.Processor<Integer>(){

            /** Trains a fresh model with {@code evaluationBlockIndex} as the held-out block. */
            @Override
            public void process(Integer evaluationBlockIndex, int _i, int _n, boolean log) {
                ExponentialFamily expFam = ExponentialFamily.createExpfam(learningOptions, expFamOptions, featureOptions, bc.getDistances());
                if (evaluationBlockIndex == 0) {
                    expFam.saveWeightsInExec("init.weights");
                }
                if (log) {
                    LogInfo.track("Validation block " + evaluationBlockIndex + "/" + Train.this.k);
                }
                Set<Integer> evaluationBlock = dataPartitions.get(evaluationBlockIndex);
                for (int curEmIter = 0; curEmIter < Train.this.emIters; ++curEmIter) {
                    if (log) {
                        LogInfo.track((Object)("EM iteration " + curEmIter + "/" + Train.this.emIters), true);
                    }
                    for (int msaIndex = 0; msaIndex < data.size(); ++msaIndex) {
                        if (log) {
                            LogInfo.track("MSA " + msaIndex + "/" + data.size() + "(thread-" + evaluationBlockIndex + ",em" + curEmIter + ")");
                        }
                        List<MultiAlignment> pairwiseAligns = data.get(msaIndex);
                        boolean isValidation = evaluationBlock.contains(msaIndex);
                        // Held-out pairs (and every pair when unsupervised) see only the sequences and
                        // are scored; the other pairs are constrained to their gold alignment.
                        boolean scoreOnly = isValidation || unsup;
                        int pairN = 0;
                        for (MultiAlignment gold : pairwiseAligns) {
                            if (log) {
                                LogInfo.logs("Pair " + pairN++ + "/" + pairwiseAligns.size());
                            }
                            Taxon topL = gold.nodes().get(0);
                            Taxon botL = gold.nodes().get(1);
                            String top = gold.getSequences().get(topL);
                            String bot = gold.getSequences().get(botL);
                            HetPairHMM hmm = scoreOnly ? expFam.getHMM(top, bot, topL, botL) : expFam.getSupervisedHMM(gold, topL, botL);
                            double logSumProduct = hmm.logSumProduct();
                            if (Double.isNaN(logSumProduct) || Double.isInfinite(logSumProduct)) {
                                throw new RuntimeException("Non-finite logSumProduct " + logSumProduct + " for pair " + topL + "/" + botL + " (MSA " + msaIndex + ", EM iteration " + curEmIter + ")");
                            }
                            // NOTE(review): held-out pairs contribute (unsupervised) sufficient
                            // statistics too — presumably intentional; confirm.
                            expFam.addSufficientStatistics(hmm, topL, botL);
                            DerivationTree.Derivation guess = HetPairHMM.removeBoundary(hmm.viterbi(null), expFam.model.enc.boundChar());
                            MultiAlignment guessMSA = MultiAlignment.inducedMultiAlignment(topL, botL, guess);
                            double curSP = gold.sumOfPairsScore(guessMSA);
                            if (scoreOnly) {
                                Train.recordScore(stats.get(curEmIter), curEmIter, curSP, nPairwiseAligns);
                            } else if (curSP < 1.0) {
                                // A gold-constrained decode should reproduce the gold alignment.
                                Train.dumpSupervisedFailure(gold, expFam, curSP);
                            }
                        }
                        if (log) {
                            LogInfo.end_track();
                        }
                    }
                    if (log) {
                        LogInfo.end_track();
                    }
                    expFam.updateParameters();
                    expFam.saveWeightsInExec("reest-block" + evaluationBlockIndex + "-iter" + curEmIter + ".weights");
                }
                if (log) {
                    LogInfo.end_track();
                }
            }
        });
    }

    /**
     * Shuffles each MSA's pair list with {@link #pairShuffleRandom} and keeps at most
     * {@link #maxNPairs} of them. MSAs with fewer pairs are kept whole; previously this threw
     * {@code IndexOutOfBoundsException} from {@code subList}.
     *
     * @param data per-MSA pair lists; shuffled in place and replaced by (views of) their prefixes
     */
    private void truncatePairs(List<List<MultiAlignment>> data) {
        for (int i = 0; i < data.size(); ++i) {
            List<MultiAlignment> pairs = data.get(i);
            Collections.shuffle(pairs, this.pairShuffleRandom);
            data.set(i, pairs.subList(0, Math.min(this.maxNPairs, pairs.size())));
        }
    }

    /**
     * Adds one SP score to the accumulator of EM iteration {@code emIter}. When the accumulator
     * reaches {@code expectedCount} values, it logs the mean and appends it to the "SP" output.
     */
    private static void recordScore(SummaryStatistics iterStats, int emIter, double sp, int expectedCount) {
        // All block threads report into the same per-iteration object, so serialize on it.
        synchronized (iterStats) {
            iterStats.addValue(sp);
            if (iterStats.getN() == (long)expectedCount) {
                LogInfo.logsForce("EM-" + emIter + " SP score :" + EasyFormat.fmt2(iterStats.getMean()));
                IO.appendLine("SP", emIter, iterStats);
            }
        }
    }

    /**
     * Logs a gold-constrained pair whose decoded SP score is below 1. It also saves the gold MSA
     * and the current weights to a fresh {@code Bug_<millis>} directory in the execution dir.
     * NOTE(review): two block threads failing in the same millisecond would share this directory.
     */
    private static void dumpSupervisedFailure(MultiAlignment gold, ExponentialFamily expFam, double sp) {
        LogInfo.error("SP on supervised less than one:" + sp);
        long time = System.currentTimeMillis();
        File f = new File(Execution.getFile("Bug_" + time));
        f.mkdir();
        gold.saveToMSF(new File(f, "goldMSA"));
        expFam.saveWeights(new File(f, "current.weights"));
    }

    /**
     * For each block {@code b}, writes the file {@code trainingBlock-b} in the execution
     * directory. It lists the taxa of every MSA outside block {@code b}, one per line; duplicates
     * are removed only within a single MSA. With {@code k == 1} the single file is empty.
     */
    private void storeTrainingPartitions(List<Set<Integer>> dataPartitions, List<List<MultiAlignment>> data) {
        for (int block = 0; block < this.k; ++block) {
            PrintWriter out = IOUtils.openOutHard(Execution.getFile("trainingBlock-" + block));
            try {
                for (int other = 0; other < this.k; ++other) {
                    if (other == block) continue;
                    for (int alignIndex : dataPartitions.get(other)) {
                        for (Taxon l : this.allSeqIds(data.get(alignIndex))) {
                            out.println(l);
                        }
                    }
                }
            } finally {
                // Close even if writing fails, so the file handle is not leaked.
                out.close();
            }
        }
    }

    /** Returns the union of the taxa (sequence keys) of all alignments in {@code list}. */
    private Set<Taxon> allSeqIds(List<MultiAlignment> list) {
        HashSet<Taxon> all = new HashSet<Taxon>();
        for (MultiAlignment msa : list) {
            all.addAll(msa.getSequences().keySet());
        }
        return all;
    }
}

