/*
 * Decompiled with CFR 0.152.
 */
package ev;

import ev.ex.PairAlign;
import ev.hmm.HetPairHMM;
import ev.hmm.SPMinBayesRiskDecoder;
import ev.multi.MessageComputations;
import ev.par.ExponentialFamily;
import ev.par.FeatureExtractor;
import fig.basic.LogInfo;
import fig.basic.Option;
import goblin.CognateId;
import goblin.DerivationTree;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import ma.BalibaseCorpus;
import ma.GreedyDecoder;
import ma.MSAPoset;
import ma.MultiAlignment;
import ma.SequenceType;
import nuts.io.IO;
import nuts.maxent.MaxentClassifier;
import nuts.util.Counter;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;

public class Test
implements Runnable {
    public static ExponentialFamily.ExponentialFamilyOptions expFamOptions = new ExponentialFamily.ExponentialFamilyOptions();
    public static FeatureExtractor.FeatureOptions featureOptions = new FeatureExtractor.FeatureOptions();
    public static MaxentClassifier.MaxentOptions<Object> learningOptions = new MaxentClassifier.MaxentOptions();
    @Option
    public int nIterations = 3;

    public static void main(String[] args) {
        IO.run(args, new Test(), "anno", MessageComputations.class, "bali", PairAlign.options, "efo", expFamOptions, "featop", featureOptions, "lo", learningOptions);
    }

    @Override
    public void run() {
        BalibaseCorpus bc = new BalibaseCorpus(PairAlign.options);
        Test.expFamOptions.encodingType = SequenceType.PROTEIN;
        ExponentialFamily expFam = ExponentialFamily.createExpfam(learningOptions, expFamOptions, featureOptions, bc.getDistances());
        DescriptiveStatistics converged = new DescriptiveStatistics();
        DescriptiveStatistics baseline = new DescriptiveStatistics();
        for (CognateId id : bc.getMultiAlignments().keySet()) {
            MultiAlignment msa = bc.getMultiAlignments().get(id);
            LogInfo.track("Current file:" + id);
            List<MultiAlignment> pairs = Test.extractPairwiseAlignments(msa, true);
            MessageComputations.QMessages previousQs = null;
            Counter<GreedyDecoder.Edge> edgePosteriors = new Counter<GreedyDecoder.Edge>();
            for (int i = 0; i < this.nIterations; ++i) {
                DescriptiveStatistics stats = new DescriptiveStatistics();
                MessageComputations.QMessages newQs = MessageComputations.blank(msa);
                for (MultiAlignment pairAlign : pairs) {
                    Taxon l1 = pairAlign.nodes().get(0);
                    Taxon l2 = pairAlign.nodes().get(1);
                    String s1 = pairAlign.getSequences().get(l1);
                    String s2 = pairAlign.getSequences().get(l2);
                    double[][][] rMessages = previousQs == null ? MessageComputations.initRMessages(s1.length(), s2.length()) : MessageComputations.rMessages(l1, l2, previousQs);
                    HetPairHMM hmm = expFam.getReweightedHMM(rMessages, s1, s2, l1, l2);
                    MessageComputations.qMessages(hmm, newQs, l1, l2, rMessages);
                    DerivationTree.Derivation guess = HetPairHMM.removeBoundary(SPMinBayesRiskDecoder.decode(hmm), expFam.model.enc.boundChar());
                    if (i == this.nIterations - 1) {
                        Taxon botLang = l2;
                        Taxon topLang = l1;
                        for (int botPos = 0; botPos < guess.getCurrentWord().length(); ++botPos) {
                            for (int topPos = 0; topPos < guess.getAncestorWord().length(); ++topPos) {
                                GreedyDecoder.Edge current = new GreedyDecoder.Edge(topPos, botPos, topLang, botLang);
                                edgePosteriors.setCount(current, Math.exp(hmm.logPosteriorAlignment(topPos, botPos)));
                            }
                        }
                    }
                    MultiAlignment guessMSA = MultiAlignment.inducedMultiAlignment(l1, l2, guess);
                    double curSP = pairAlign.sumOfPairsScore(guessMSA);
                    stats.addValue(curSP);
                }
                previousQs = newQs;
                LogInfo.logsForce("Iteration " + i + ":" + stats.getMean());
                if (i == this.nIterations - 1) {
                    for (Object v : (Object)stats.getValues()) {
                        converged.addValue((double)v);
                    }
                }
                if (i != 0) continue;
                for (Object v : (Object)stats.getValues()) {
                    baseline.addValue((double)v);
                }
            }
            MSAPoset msaPoset = new MSAPoset(msa.getSequences());
            for (GreedyDecoder.Edge e : edgePosteriors) {
                msaPoset.tryAdding(e);
            }
            LogInfo.logsForce("Global score:" + msa.sumOfPairsScore(msaPoset.toMultiAlignmentObject()));
            LogInfo.end_track();
        }
        LogInfo.logsForce("Baseline SP score:" + baseline.getMean());
        LogInfo.logsForce("Global SP score:" + converged.getMean());
    }

    public static List<MultiAlignment> extractPairwiseAlignments(MultiAlignment msa, boolean fix) {
        ArrayList<MultiAlignment> result = new ArrayList<MultiAlignment>();
        for (int i = 0; i < msa.nodes().size(); ++i) {
            Taxon l1 = msa.nodes().get(i);
            for (int j = i + 1; j < msa.nodes().size(); ++j) {
                Taxon l2 = msa.nodes().get(j);
                MultiAlignment curPair = msa.restrict(Arrays.asList(l1, l2));
                if (fix) {
                    curPair = PairAlign.PairAlignMain.fixMSA(curPair);
                }
                result.add(curPair);
            }
        }
        return result;
    }
}

