/*
 * Decompiled with CFR 0.152.
 */
package ev.ex;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.StrUtils;
import fig.prob.Dirichlet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.io.CmlSequenceReader;
import nuts.io.IO;
import nuts.math.DynamicBinHistogram;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.Graphs;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.util.CollUtils;
import nuts.util.CounterMap;
import nuts.util.Indexer;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;

/**
 * Experiment driver that fits a chain-structured (HMM-style) model to labelled
 * sequences and evaluates per-position binary predictions on held-out sequences.
 *
 * <p>Each datum is a pair {@code (observations, hiddenStates)}. Transition
 * ("backbone") and emission ("observation") counts are collected from the
 * training data, normalized into initial parameters, and then alternately
 * (a) used to sample hidden sequences for the test data and (b) resampled from
 * Dirichlet distributions whose parameters are {@link #alpha0} plus the
 * training counts plus the counts of the freshly sampled sequences.
 *
 * <p>A position counts as "positive" when its hidden state is one of
 * {@link #targets}; a position is predicted positive when the summed marginal
 * of those states exceeds {@link #threshold}.
 */
public class HmmQuestion
implements Runnable {
    /** Pseudo-state recorded as the predecessor of the first state of every sequence. */
    private static final String BEG = "<BEGINNING>";
    /** Pseudo-state recorded as the successor of the last state of every sequence. */
    private static final String END = "<END>";
    /** Maximum number of test sequences to load and evaluate. */
    @Option
    public int nTest = 10;
    /** Maximum number of training sequences to load. */
    @Option
    public int nTrain = 10;
    /** Random source used for sampling hidden sequences and Dirichlet parameters. */
    @Option
    public Random rand = new Random(1L);
    /** Constant added to every count when forming the Dirichlet parameters in {@link #resampleParam}. */
    @Option
    public double alpha0 = 0.5;
    /** A position is predicted positive when its summed target marginal is strictly above this value. */
    @Option
    public double threshold = 0.5;
    /** Number of sampling iterations performed by {@link #run()}. */
    @Option
    public int nGibbsIters = 100;
    /** Location handed to {@code CmlSequenceReader.i} for the training sequences. */
    @Option
    public String trainPath = "/Users/bouchard/w/evolvere/data/secondaryStructure/train";
    /** Location handed to {@code CmlSequenceReader.i} for the test sequences. */
    @Option
    public String testPath = "/Users/bouchard/w/evolvere/data/secondaryStructure/test";
    /** Hidden-state labels that are treated as the positive class. */
    private final List<String> targets = Arrays.asList("H", "B", "E");
    private Indexer<String> hiddenIdx;
    private Indexer<String> obsIdx;
    /** Transition counts collected from the training data. */
    private CounterMap<String, String> backboneSuffStats;
    /** Emission counts (hidden state, observation) collected from the training data. */
    private CounterMap<String, String> observationSuffStats;
    /** Normalized copy of {@link #backboneSuffStats}. */
    private CounterMap<String, String> backboneMLE;
    /** Normalized copy of {@link #observationSuffStats}. */
    private CounterMap<String, String> observationMLE;
    /** Number of indexed hidden states, including the BEG and END pseudo-states. */
    private int stspsize;
    /** For each test sequence (by position in {@link #test}), the hidden-state samples kept so far. */
    private List<List<Map<Integer, Integer>>> samples;
    private List<Pair<List<String>, List<String>>> train = CollUtils.list();
    private List<Pair<List<String>, List<String>>> test = CollUtils.list();

    public static void main(String[] args) {
        IO.run(args, new HmmQuestion());
    }

    /** Returns whether {@code str} is one of the positive-class hidden-state labels. */
    private boolean isTarget(String str) {
        return this.targets.contains(str);
    }

    /**
     * Loads up to {@link #nTrain} training and {@link #nTest} test sequences,
     * builds the symbol indexers, collects training counts and derives the
     * normalized initial parameters.
     */
    private void loadData() {
        this.samples = CollUtils.list();
        int nData = 0;
        for (Pair<List<String>, List<String>> datum : CmlSequenceReader.i(this.trainPath)) {
            if (++nData > this.nTrain) {
                break;
            }
            this.train.add(datum);
        }
        nData = 0;
        // Fraction of test positions whose true state is a target state.
        SummaryStatistics propBeta = new SummaryStatistics();
        for (Pair<List<String>, List<String>> datum : CmlSequenceReader.i(this.testPath)) {
            if (++nData > this.nTest) {
                break;
            }
            this.test.add(datum);
            for (String cur : datum.getSecond()) {
                propBeta.addValue(this.isTarget(cur) ? 1.0 : 0.0);
            }
        }
        LogInfo.logsForce("Target structures proportion:" + propBeta.getMean());
        // END and BEG are indexed first, before any state seen in the data.
        this.hiddenIdx = new Indexer<String>();
        this.hiddenIdx.addToIndex(new String[]{END});
        this.hiddenIdx.addToIndex(new String[]{BEG});
        this.obsIdx = new Indexer<String>();
        this.backboneSuffStats = new CounterMap<String, String>();
        this.observationSuffStats = new CounterMap<String, String>();
        for (Pair<List<String>, List<String>> datum : this.train) {
            this.indexSymbols(datum);
            this.addBackBoneStats(datum.getSecond(), this.backboneSuffStats);
            this.addObsStats(datum, this.observationSuffStats);
        }
        // Test symbols are indexed too (but contribute no counts) so that every
        // symbol encountered during inference has an index.
        for (Pair<List<String>, List<String>> datum : this.test) {
            this.indexSymbols(datum);
        }
        this.stspsize = this.hiddenIdx.size();
        this.backboneMLE = new CounterMap<String, String>();
        this.observationMLE = new CounterMap<String, String>();
        this.backboneMLE.incrementAll(this.backboneSuffStats);
        this.observationMLE.incrementAll(this.observationSuffStats);
        this.backboneMLE.normalize();
        this.observationMLE.normalize();
    }

    /** Adds the observations of {@code datum} to {@link #obsIdx}, then its hidden states to {@link #hiddenIdx}. */
    private void indexSymbols(Pair<List<String>, List<String>> datum) {
        for (String obs : datum.getFirst()) {
            this.obsIdx.addToIndex(new String[]{obs});
        }
        for (String hidden : datum.getSecond()) {
            this.hiddenIdx.addToIndex(new String[]{hidden});
        }
    }

    /**
     * Builds the chain potentials for one observation sequence.
     *
     * <p>Pairwise potentials between consecutive positions are filled from
     * {@code backboneStats}, unary potentials from {@code observationStats},
     * and the first/last unary potentials are additionally scaled by the
     * BEG-to-state and state-to-END entries of {@code backboneStats}.
     *
     * @param observations     the observed symbols, one per chain position
     * @param backboneStats    transition weights keyed by (previous state, next state)
     * @param observationStats emission weights keyed by (hidden state, observation)
     * @return the potentials over a chain with one variable per observation
     */
    private GMFct<Integer> getGraphicalModel(List<String> observations, CounterMap<String, String> backboneStats, CounterMap<String, String> observationStats) {
        int size = observations.size();
        // Left raw as in the original: the parameter type expected by GMFctUtils.zeroes is not visible here.
        HashMap sizes = CollUtils.map();
        for (int j = 0; j < size; ++j) {
            sizes.put(j, this.stspsize);
        }
        TabularGMFct<Integer> pots = GMFctUtils.zeroes(Graphs.chainGraph(size), sizes);
        // Pairwise (transition) potentials.
        for (int t = 1; t < size; ++t) {
            for (int prev = 0; prev < this.stspsize; ++prev) {
                for (int next = 0; next < this.stspsize; ++next) {
                    pots.set(t - 1, t, prev, next, backboneStats.getCount(this.hiddenIdx.i2o(prev), this.hiddenIdx.i2o(next)));
                }
            }
        }
        // Unary (emission) potentials.
        for (int t = 0; t < size; ++t) {
            for (int s = 0; s < this.stspsize; ++s) {
                pots.set(t, s, observationStats.getCount(this.hiddenIdx.i2o(s), observations.get(t)));
            }
        }
        // Boundary factors: BEG -> first state and last state -> END.
        for (int s = 0; s < this.stspsize; ++s) {
            pots.scale((Integer)0, s, backboneStats.getCount(BEG, this.hiddenIdx.i2o(s)));
        }
        for (int s = 0; s < this.stspsize; ++s) {
            pots.scale((Integer)(size - 1), s, backboneStats.getCount(this.hiddenIdx.i2o(s), END));
        }
        return pots;
    }

    /**
     * Thresholds the per-position target marginals of {@code post}.
     *
     * @return one entry per position: 1 if the summed target marginal exceeds
     *         {@link #threshold}, 0 otherwise
     */
    private List<Integer> predictionMAP(GMFct<Integer> post) {
        List<Double> predictionPost = this.prediction(post);
        ArrayList<Integer> predicteds = CollUtils.list();
        for (double curPost : predictionPost) {
            predicteds.add(curPost > this.threshold ? 1 : 0);
        }
        return predicteds;
    }

    /**
     * Sums, for every position of the chain, the values {@code post} assigns to
     * the target states.
     *
     * @return one summed value per vertex of {@code post}'s graph, in position order
     */
    private List<Double> prediction(GMFct<Integer> post) {
        int size = post.graph().vertexSet().size();
        ArrayList<Double> predicteds = CollUtils.list();
        for (int t = 0; t < size; ++t) {
            double curPost = 0.0;
            for (String target : this.targets) {
                curPost += post.get(t, this.hiddenIdx.o2i(target));
            }
            predicteds.add(curPost);
        }
        return predicteds;
    }

    /** Maps each true hidden state to 1 if it is a target state and 0 otherwise. */
    private List<Integer> truth(List<String> states) {
        ArrayList<Integer> truths = CollUtils.list();
        for (String state : states) {
            truths.add(this.isTarget(state) ? 1 : 0);
        }
        return truths;
    }

    /** Counts the positions (over the length of {@code l1}) at which both lists hold equal values. */
    private int nMatches(List<Integer> l1, List<Integer> l2) {
        int result = 0;
        for (int i = 0; i < l1.size(); ++i) {
            if (l1.get(i).equals(l2.get(i))) {
                ++result;
            }
        }
        return result;
    }

    /** Counts the entries of {@code l} that are equal to 1. */
    private int nPositives(List<Integer> l) {
        int result = 0;
        for (int i : l) {
            if (i == 1) {
                ++result;
            }
        }
        return result;
    }

    /**
     * Runs one inference pass over the test sequences with the given parameters.
     *
     * <p>For every test sequence the marginals are computed, one hidden
     * sequence is sampled from them, and the transition/emission counts of
     * that sample are accumulated. The marginals are evaluated against the
     * truth, and so is a second function built from the samples stored so far
     * for that sequence.
     *
     * <p>The round in which a sequence's sample list is first created does not
     * store its sample, and a round is reported as "MLE" accuracy when any
     * sample list was created during it; all other rounds report the
     * sample-based accuracy.
     *
     * @param bbParams  transition parameters keyed by (previous state, next state)
     * @param obsParams emission parameters keyed by (hidden state, observation)
     * @return the pair (transition counts, emission counts) of the hidden
     *         sequences sampled during this call
     */
    public Pair<CounterMap<String, String>, CounterMap<String, String>> doInference(CounterMap<String, String> bbParams, CounterMap<String, String> obsParams) {
        SummaryStatistics globalAccuracyStats = new SummaryStatistics();
        DynamicBinHistogram hist = new DynamicBinHistogram();
        SummaryStatistics globalAccuracyStats2 = new SummaryStatistics();
        DynamicBinHistogram hist2 = new DynamicBinHistogram();
        CounterMap<String, String> currentBackSS = new CounterMap<String, String>();
        CounterMap<String, String> currentObsSS = new CounterMap<String, String>();
        LogInfo.track("Doing inference");
        int k = 0;
        boolean isFirstRound = false;
        for (Pair<List<String>, List<String>> datum : this.test) {
            if (k >= this.nTest) {
                break;
            }
            if (this.samples.size() <= k) {
                // No sample list exists yet for this sequence: this is its first pass.
                isFirstRound = true;
                this.samples.add(new ArrayList<Map<Integer, Integer>>());
            }
            // Left raw as in the original: the parameter type expected by GMFctUtils.fromSamples is not visible here.
            List currentListOfSamples = this.samples.get(k);
            LogInfo.logs("Sequence " + k);
            GMFct<Integer> pots = this.getGraphicalModel(datum.getFirst(), bbParams, obsParams);
            TabularGMFct<Integer> post = null;
            try {
                post = TreeSumProd.computeMoments(pots);
            }
            catch (Exception e) {
                // NOTE(review): any failure of computeMoments is treated as a zero-probability
                // sequence; the exception types it can throw are not visible from this file.
                LogInfo.warning("Warning: zero probability sequence!  Backing off to uniform prediction!");
                pots = GMFctUtils.ones(pots);
                post = TreeSumProd.computeMoments(pots);
            }
            Map<Integer, Integer> current = GMFctUtils.sample(post, this.rand);
            if (!isFirstRound) {
                currentListOfSamples.add(current);
            }
            List<String> sampledStates = this.convert(current);
            this.addBackBoneStats(sampledStates, currentBackSS);
            this.addObsStats(Pair.makePair(datum.getFirst(), sampledStates), currentObsSS);
            List<Integer> truths = this.truth(datum.getSecond());
            this.evaluate(truths, post, globalAccuracyStats, hist);
            GMFct<Integer> finalPost = GMFctUtils.fromSamples(this.getGraphicalModel(datum.getFirst(), this.backboneMLE, this.observationMLE), currentListOfSamples);
            this.evaluate(this.truth(datum.getSecond()), finalPost, globalAccuracyStats2, hist2);
            ++k;
        }
        LogInfo.end_track();
        if (isFirstRound) {
            LogInfo.logsForce("Mean accuracy for MLE:" + globalAccuracyStats.getMean());
        } else {
            LogInfo.logsForce("Mean accuracy for posterior:" + globalAccuracyStats2.getMean());
        }
        return Pair.makePair(currentBackSS, currentObsSS);
    }

    /**
     * Loads the data and performs {@link #nGibbsIters} iterations, each of
     * which runs {@link #doInference} with the current parameters and then
     * resamples both parameter tables from the training counts plus the counts
     * returned by that inference pass.
     */
    @Override
    public void run() {
        this.loadData();
        CounterMap<String, String> currentBackBoneParam = this.backboneMLE;
        CounterMap<String, String> currentObservationParam = this.observationMLE;
        // The counter advanced here must be gibbsIter; advancing nGibbsIters instead
        // would leave the loop condition true until the int overflows.
        for (int gibbsIter = 0; gibbsIter < this.nGibbsIters; ++gibbsIter) {
            LogInfo.track("Gibbs iteration " + gibbsIter);
            Pair<CounterMap<String, String>, CounterMap<String, String>> currentSS = this.doInference(currentBackBoneParam, currentObservationParam);
            CounterMap<String, String> currentBackSS = new CounterMap<String, String>();
            CounterMap<String, String> currentObsSS = new CounterMap<String, String>();
            currentBackSS.incrementAll(this.backboneSuffStats);
            currentObsSS.incrementAll(this.observationSuffStats);
            currentBackSS.incrementAll(currentSS.getFirst());
            currentObsSS.incrementAll(currentSS.getSecond());
            currentBackBoneParam = this.resampleParam(currentBackSS, this.hiddenIdx, this.hiddenIdx);
            currentObservationParam = this.resampleParam(currentObsSS, this.hiddenIdx, this.obsIdx);
            LogInfo.end_track();
        }
    }

    /**
     * Scores thresholded predictions derived from {@code posterior} against
     * {@code truths}, logging the per-sequence accuracy.
     *
     * @param truths              0/1 labels, one per position
     * @param posterior           function whose per-position target values are thresholded
     * @param globalAccuracyStats receives one 1.0/0.0 value per position (correct/incorrect)
     * @param hist                receives each position's summed target value together with its true label
     */
    public void evaluate(List<Integer> truths, GMFct<Integer> posterior, SummaryStatistics globalAccuracyStats, DynamicBinHistogram hist) {
        LogInfo.track((Object)"Evaluating data point", true);
        LogInfo.logs("Pred : " + StrUtils.join(this.predictionMAP(posterior), "") + "\n" + "Truth: " + StrUtils.join(truths, ""));
        SummaryStatistics accuracy = new SummaryStatistics();
        List<Double> posteriors = this.prediction(posterior);
        for (int i = 0; i < truths.size(); ++i) {
            double curPost = posteriors.get(i);
            int truthLabel = truths.get(i).intValue();
            int predIsEB = curPost > this.threshold ? 1 : 0;
            double score = predIsEB == truthLabel ? 1.0 : 0.0;
            accuracy.addValue(score);
            globalAccuracyStats.addValue(score);
            hist.add(curPost, truthLabel);
        }
        LogInfo.logs("Accuracy: " + accuracy.getMean());
        LogInfo.end_track();
    }

    /**
     * Draws a new parameter table: for every object of {@code indexer1}, one
     * Dirichlet sample over the objects of {@code indexer2} with parameters
     * {@code alpha0 + suffStats.getCount(row, column)}.
     *
     * @param suffStats counts keyed by (row object, column object)
     * @param indexer1  supplies the row objects
     * @param indexer2  supplies the column objects, in index order
     * @return a table holding the sampled value for every (row, column) pair
     */
    private CounterMap<String, String> resampleParam(CounterMap<String, String> suffStats, Indexer<String> indexer1, Indexer<String> indexer2) {
        CounterMap<String, String> result = new CounterMap<String, String>();
        for (String s1 : indexer1.objects()) {
            double[] alpha = new double[indexer2.size()];
            for (int i = 0; i < alpha.length; ++i) {
                alpha[i] = this.alpha0 + suffStats.getCount(s1, indexer2.i2o(i));
            }
            double[] sample = Dirichlet.sample(this.rand, alpha);
            for (int i = 0; i < alpha.length; ++i) {
                result.setCount(s1, indexer2.i2o(i), sample[i]);
            }
        }
        return result;
    }

    /**
     * Converts a sampled assignment (position to hidden-state index) into the
     * list of hidden-state labels, reading keys {@code 0 .. current.size() - 1}.
     */
    private List<String> convert(Map<Integer, Integer> current) {
        ArrayList<String> result = CollUtils.list();
        for (int i = 0; i < current.size(); ++i) {
            result.add(this.hiddenIdx.i2o(current.get(i)));
        }
        return result;
    }

    /**
     * Adds one (hidden state, observation) count per position of {@code datum}.
     *
     * @param datum            pair of (observations, hidden states)
     * @param observationStats counter to increment
     */
    private void addObsStats(Pair<List<String>, List<String>> datum, CounterMap<String, String> observationStats) {
        List<String> observations = datum.getFirst();
        List<String> states = datum.getSecond();
        for (int i = 0; i < observations.size(); ++i) {
            observationStats.incrementCount(states.get(i), observations.get(i), 1.0);
        }
    }

    /**
     * Adds the transition counts of {@code seq}: BEG to the first state, every
     * consecutive pair, and the last state to END. {@code seq} must be non-empty.
     *
     * @param seq           hidden-state sequence
     * @param backboneStats counter to increment
     */
    private void addBackBoneStats(List<String> seq, CounterMap<String, String> backboneStats) {
        backboneStats.incrementCount(BEG, seq.get(0), 1.0);
        for (int i = 1; i < seq.size(); ++i) {
            backboneStats.incrementCount(seq.get(i - 1), seq.get(i), 1.0);
        }
        backboneStats.incrementCount(seq.get(seq.size() - 1), END, 1.0);
    }
}

