/*
 * Decompiled with CFR 0.152.
 */
package ev.ex;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import goblin.CognateId;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.SortedSet;
import java.util.TreeSet;
import ma.BalibaseCorpus;
import nuts.io.IO;
import nuts.maxent.BaseMeasures;
import nuts.maxent.FeatureExtractor;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.CounterMap;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pepper.Encodings;

public class Shannon
implements Runnable {
    /** Corpus options; handed to {@code IO.run} in {@code main} and to the {@code BalibaseCorpus} constructor. */
    public static BalibaseCorpus.BalibaseCorpusOptions options = new BalibaseCorpus.BalibaseCorpusOptions();
    /** Selects the alphabet loaded by {@code initBM}; PROTEIN, RNA and DNA are handled there, anything else throws. */
    @Option
    public Encodings.EncodingType encodingType = Encodings.EncodingType.PROTEIN;
    /** Proportion of the sorted Balibase alignment ids that {@code split} puts into the test set. */
    @Option
    public double heldoutProp = 0.2;
    /** Location searched (through {@code IO.locate} with a "gz" suffix filter) for training FASTA files; used when {@code useBali} is false. */
    @Option
    public String fastaTrain = null;
    /** Location searched (through {@code IO.locate} with a "gz" suffix filter) for test FASTA files; used when {@code useBali} is false. */
    @Option
    public String fastaTest = null;
    /** When true {@code run} loads data with {@code DataFromBalibase}, otherwise with {@code DataFromFasta}. */
    @Option
    public boolean useBali = true;
    /** Limit passed to {@code iterate} for every training file or alignment. */
    @Option
    public int maxTrain = Integer.MAX_VALUE;
    /** Limit passed to {@code iterate} for every test file or alignment. */
    @Option
    public int maxTest = Integer.MAX_VALUE;
    /** Symbol that {@code suffix} prepends to a context shorter than the requested length. */
    public static Character pad = Character.valueOf('#');
    /** Lower-cased symbols of the selected alphabet; filled by {@code initBM} and used by {@code iterate} to filter positions. */
    private SortedSet<Character> bm = null;

    /**
     * Command-line entry point. Creates a fresh {@code Shannon} and passes it to
     * {@code IO.run} together with the string "bali" and the shared corpus options.
     *
     * @param args command-line arguments, forwarded unchanged to {@code IO.run}
     */
    public static void main(String[] args) {
        Shannon experiment = new Shannon();
        IO.run(args, experiment, "bali", options);
    }

    /**
     * Partitions {@code list} into a training part and a held-out part without reordering.
     * Elements whose index is strictly below {@code heldoutProp * list.size()} form the
     * held-out part; all remaining elements form the training part.
     *
     * @param list        elements to partition
     * @param heldoutProp proportion of the list (taken from its front) to hold out
     * @return a pair whose first component is the training list and whose second
     *         component is the held-out list
     */
    private static <T> Pair<List<T>, List<T>> split(List<T> list, double heldoutProp) {
        List<T> training = new ArrayList<T>();
        List<T> heldout = new ArrayList<T>();
        int total = list.size();
        for (int idx = 0; idx < total; ++idx) {
            boolean inHeldoutPrefix = (double) idx < heldoutProp * (double) total;
            List<T> target = inHeldoutPrefix ? heldout : training;
            target.add(list.get(idx));
        }
        return Pair.makePair(training, heldout);
    }

    /**
     * Reads the sequences of a FASTA-style file. A non-empty line starting with
     * {@code ';'} or {@code '>'} is treated as a record separator; every other line is
     * appended verbatim to the record being assembled. Records that end up empty are
     * not reported.
     *
     * @param f file whose lines are obtained through {@code IO.i}
     * @return the non-empty sequences in file order
     */
    public static List<String> readFasta(File f) {
        List<String> sequences = CollUtils.list();
        StringBuilder current = new StringBuilder();
        for (String line : IO.i(f)) {
            boolean isHeader = !line.isEmpty() && (line.charAt(0) == ';' || line.charAt(0) == '>');
            if (!isHeader) {
                current.append(line);
                continue;
            }
            // A header closes the record assembled so far (if it has any content).
            if (current.length() > 0) {
                sequences.add(current.toString());
                current.setLength(0);
            }
        }
        // Flush the record that was still open when the input ended.
        if (current.length() > 0) {
            sequences.add(current.toString());
        }
        return sequences;
    }

    /**
     * Runs the experiment: loads the data (Balibase or FASTA, depending on
     * {@code useBali}), then trains and evaluates each predictor in turn.
     *
     * <p>The predictors are one {@code RandomPredictor}, {@code NGramPredictor}s for
     * n = 1..10 and {@code MaxentPredictor}s for n = 1..10. For every predictor the
     * held-out accuracy statistics, the summed natural-log probability of the true
     * symbols and the corresponding perplexity {@code exp(-logLikelihood / N)} are logged.
     *
     * <p>Previously the summed log probability itself was logged under the label
     * "Perplexity:"; it is now logged as "LogLikelihood:" and the "Perplexity:" line
     * carries the actual perplexity.
     */
    @Override
    public void run() {
        this.initBM();
        Data data = this.useBali ? new DataFromBalibase() : new DataFromFasta();
        LogInfo.logsForce("Train:" + data.trainData.size());
        LogInfo.logsForce("Test:" + data.testData.size());

        List<Predictor> preds = new ArrayList<Predictor>();
        preds.add(new RandomPredictor());
        for (int n = 1; n <= 10; ++n) {
            preds.add(new NGramPredictor(n));
        }
        for (int n = 1; n <= 10; ++n) {
            // With these bounds the sweep visits the single value 0.01.
            for (double var = 0.01; var <= 0.01; var *= 10.0) {
                preds.add(new MaxentPredictor(var, n));
            }
        }

        for (Predictor predictor : preds) {
            LogInfo.track(predictor, true);
            for (Pair<PunctSeq, Character> example : data.trainData) {
                predictor.addLearn(example.getFirst(), example.getSecond().charValue());
            }
            predictor.finishLearn();

            SummaryStatistics heldoutAcc = new SummaryStatistics();
            double logLikelihood = 0.0;
            int evaluated = 0;
            for (Pair<PunctSeq, Character> example : data.testData) {
                Character truth = example.getSecond();
                Counter<Character> prediction = predictor.predict(example.getFirst());
                Character guess = prediction.argMax();
                // Math.log(0) is -Infinity, so a symbol given zero mass drives the
                // perplexity to +Infinity, which is the conventional value.
                logLikelihood += Math.log(prediction.getCount(truth));
                ++evaluated;
                heldoutAcc.addValue(truth.equals(guess) ? 1.0 : 0.0);
            }
            // Perplexity is undefined without test instances; report NaN explicitly.
            double perplexity = evaluated == 0 ? Double.NaN : Math.exp(-logLikelihood / (double) evaluated);
            LogInfo.logsForce(heldoutAcc);
            LogInfo.logsForce("LogLikelihood:" + logLikelihood);
            LogInfo.logsForce("Perplexity:" + perplexity);
            LogInfo.end_track();
        }
    }

    /**
     * Fills {@code bm} with the lower-cased symbols of the alphabet selected by
     * {@code encodingType} (PROTEIN, RNA or DNA).
     *
     * <p>The set is built locally and only published once the encoding type has been
     * recognised, so a failed call leaves {@code bm} untouched.
     *
     * @throws IllegalStateException if {@code encodingType} is null or is not one of
     *         the three supported values
     */
    private void initBM() {
        Encodings.EncodingType type = this.encodingType;
        SortedSet<Character> alphabet = new TreeSet<Character>();
        if (Encodings.EncodingType.PROTEIN.equals((Object)type)) {
            for (Character c : Encodings.AMINO_ACIDS) {
                alphabet.add(Character.valueOf(Character.toLowerCase(c.charValue())));
            }
        } else if (Encodings.EncodingType.RNA.equals((Object)type)) {
            for (Character c : Encodings.RNA) {
                alphabet.add(Character.valueOf(Character.toLowerCase(c.charValue())));
            }
        } else if (Encodings.EncodingType.DNA.equals((Object)type)) {
            for (Character c : Encodings.DNA) {
                alphabet.add(Character.valueOf(Character.toLowerCase(c.charValue())));
            }
        } else {
            // Name the offending value instead of failing with an empty exception.
            throw new IllegalStateException("Unsupported encoding type: " + type);
        }
        this.bm = alphabet;
    }

    /**
     * Turns sequences into prediction instances. Each sequence is lower-cased, and every
     * position whose character belongs to {@code bm} yields one instance: the text to
     * the left of the position (as a {@code PunctSeq}) paired with the character at the
     * position. Characters outside {@code bm} produce no instance but stay part of the
     * left context of later positions.
     *
     * <p>At most {@code max} instances are produced. The earlier implementation used a
     * post-incremented counter and returned only after {@code max + 2} instances.
     *
     * <p>Lower-casing uses {@code Locale.ROOT} so the result does not depend on the
     * default locale (under a Turkish locale "I" would otherwise become a dotless
     * "ı" and fail the {@code bm} lookup).
     *
     * @param sequences raw sequences, in the order they should be consumed
     * @param max       upper bound on the number of instances returned by this call;
     *                  zero or negative yields an empty list
     * @return the instances, in sequence and position order
     */
    private List<Pair<PunctSeq, Character>> iterate(Collection<String> sequences, int max) {
        List<Pair<PunctSeq, Character>> result = new ArrayList<Pair<PunctSeq, Character>>();
        for (String raw : sequences) {
            String current = raw.toLowerCase(java.util.Locale.ROOT);
            for (int pos = 0; pos < current.length(); ++pos) {
                // Check the cap before adding so that exactly max instances are the limit.
                if (result.size() >= max) {
                    return result;
                }
                char symbol = current.charAt(pos);
                if (!this.bm.contains(Character.valueOf(symbol))) {
                    continue;
                }
                result.add(Pair.makePair(new PunctSeq(current, pos), Character.valueOf(symbol)));
            }
        }
        return result;
    }

    /**
     * Returns the last {@code n} characters of {@code left}. If {@code left} is shorter
     * than {@code n}, the string form of {@code pad} is prepended repeatedly until the
     * length is at least {@code n} before the tail is taken.
     *
     * @param left context string
     * @param n    number of trailing characters wanted
     * @return the trailing {@code n} characters of the (possibly padded) context
     */
    public static String suffix(String left, int n) {
        StringBuilder padded = new StringBuilder(left);
        while (padded.length() < n) {
            // insert(int, Object) uses String.valueOf(pad), exactly like "pad + left".
            padded.insert(0, pad);
        }
        int end = padded.length();
        return padded.substring(end - n, end);
    }

    /**
     * Left context of a position in a sequence. Two instances are equal exactly when
     * their {@code left} strings are equal (or both null).
     */
    public static class PunctSeq {
        /** Text to the left of the position being predicted. */
        public final String left;

        /**
         * Builds a context of exactly {@code n} characters from {@code model} by taking
         * {@code Shannon.suffix(model.left, n)}.
         */
        public PunctSeq(PunctSeq model, int n) {
            String context = model.left;
            this.left = Shannon.suffix(context, n);
        }

        /**
         * Builds the context consisting of the first {@code pos} characters of
         * {@code fullSeq}.
         */
        public PunctSeq(String fullSeq, int pos) {
            this.left = fullSeq.substring(0, pos);
        }

        /** Hash derived from {@code left} only: {@code 31 + left.hashCode()}, or 31 for a null context. */
        @Override
        public int hashCode() {
            int leftHash = this.left == null ? 0 : this.left.hashCode();
            return 31 + leftHash;
        }

        /** Equal only to another object of exactly this class carrying an equal {@code left}. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || obj.getClass() != this.getClass()) {
                return false;
            }
            String theirs = ((PunctSeq)obj).left;
            if (this.left == null) {
                return theirs == null;
            }
            return this.left.equals(theirs);
        }

        /** Returns the context string itself. */
        @Override
        public String toString() {
            return this.left;
        }
    }

    /**
     * Predictor that conditions on the previous {@code n - 1} symbols. Observations are
     * accumulated in a {@code CounterMap} keyed by that context; {@code finishLearn}
     * calls {@code CounterMap.normalize()} on it.
     */
    public static class NGramPredictor
    implements Predictor {
        /** Order of the model; the conditioning context has {@code n - 1} characters. */
        public final int n;
        // Typed explicitly; the raw CounterMap used before gave up generic type checking.
        private final CounterMap<String, Character> stats = new CounterMap<String, Character>();

        /**
         * @param n model order, at least 1
         * @throws IllegalArgumentException if {@code n} is smaller than 1
         */
        public NGramPredictor(int n) {
            if (n < 1) {
                throw new IllegalArgumentException("n-gram order must be at least 1 but was " + n);
            }
            this.n = n;
        }

        /** Records one occurrence of {@code pred} following the context of {@code seq}. */
        @Override
        public void addLearn(PunctSeq seq, char pred) {
            this.stats.incrementCount(this.context(seq), Character.valueOf(pred), 1.0);
        }

        /** Normalizes the accumulated counts through {@code CounterMap.normalize()}. */
        @Override
        public void finishLearn() {
            this.stats.normalize();
        }

        @Override
        public String toString() {
            return "NGramPredictor(" + this.n + ")";
        }

        /**
         * Returns the counter stored for the context of {@code seq}.
         * NOTE(review): what {@code getCounter} yields for a context never seen in
         * training depends on {@code CounterMap} - confirm it is not null.
         */
        @Override
        public Counter<Character> predict(PunctSeq seq) {
            return this.stats.getCounter(this.context(seq));
        }

        /** Conditioning context: the last {@code n - 1} characters of the left text, padded if needed. */
        private String context(PunctSeq seq) {
            return Shannon.suffix(seq.left, this.n - 1);
        }
    }

    /**
     * Predictor backed by a {@code MaxentClassifier}. Training instances are collected
     * in a counter and the classifier is fitted in {@code finishLearn}, using a
     * {@code ShannonExtractor} limited to contexts of up to {@code maxN} characters.
     */
    public class MaxentPredictor
    implements Predictor {
        private final int maxN;
        private final double var;
        private Counter<LabeledInstance<PunctSeq, Character>> train = new Counter<LabeledInstance<PunctSeq, Character>>();
        private ShannonExtractor ext;
        private ShannonBaseMeasure bm = new ShannonBaseMeasure();
        private MaxentClassifier<PunctSeq, Character, String> maxent;

        /**
         * @param var  value copied into {@code MaxentOptions.sigma} when fitting
         * @param maxN longest context length used for training instances and features
         */
        public MaxentPredictor(double var, int maxN) {
            this.maxN = maxN;
            this.var = var;
            this.ext = new ShannonExtractor(maxN);
        }

        /** Adds one training instance whose input is the context of {@code seq} cut or padded to {@code maxN} characters. */
        @Override
        public void addLearn(PunctSeq seq, char pred) {
            PunctSeq context = new PunctSeq(seq, this.maxN);
            LabeledInstance<PunctSeq, Character> instance = new LabeledInstance<PunctSeq, Character>(Character.valueOf(pred), context);
            this.train.incrementCount(instance, 1.0);
        }

        /** Fits the classifier on everything collected so far. */
        @Override
        public void finishLearn() {
            MaxentClassifier.MaxentOptions settings = new MaxentClassifier.MaxentOptions();
            settings.sigma = this.var;
            this.maxent = MaxentClassifier.learnMaxentClassifier(this.bm, this.train, this.ext, settings);
        }

        @Override
        public String toString() {
            return "Maxent(" + this.var + ", " + this.maxN + ")";
        }

        /**
         * Exponentiates the classifier's log probabilities and pairs them with the labels
         * in the iteration order of {@code getLabels(seq)}.
         */
        @Override
        public Counter<Character> predict(PunctSeq seq) {
            Counter<Character> distribution = new Counter<Character>();
            SortedSet<Character> labels = this.maxent.getLabels(seq);
            double[] logProbs = this.maxent.logProb(seq);
            int slot = 0;
            for (Character label : labels) {
                double probability = Math.exp(logProbs[slot++]);
                distribution.incrementCount(label, probability);
            }
            return distribution;
        }
    }

    /**
     * Baseline that ignores its training data and its input: every prediction assigns a
     * fresh pseudo-random weight to each symbol of the enclosing instance's alphabet and
     * normalizes the result. The generator is seeded with 1.
     */
    public class RandomPredictor
    implements Predictor {
        private Random rand = new Random(1L);

        /** Does nothing; this baseline learns nothing. */
        @Override
        public void addLearn(PunctSeq seq, char pred) {
        }

        /** Does nothing; there is no model to finalize. */
        @Override
        public void finishLearn() {
        }

        /** Returns a normalized counter of random weights over the alphabet, drawn in alphabet order. */
        @Override
        public Counter<Character> predict(PunctSeq seq) {
            Counter<Character> scores = new Counter<Character>();
            for (Character symbol : Shannon.this.bm) {
                double weight = this.rand.nextDouble();
                scores.incrementCount(symbol, weight);
            }
            scores.normalize();
            return scores;
        }
    }

    /**
     * Feature extractor producing one indicator feature per context length 1..maxN.
     * Each feature is the string {@code "<context> -> <label>"}, built from the
     * lower-cased context suffix of that length and the lower-cased label.
     */
    public class ShannonExtractor
    implements FeatureExtractor<LabeledInstance<PunctSeq, Character>, String> {
        private final int maxN;

        /** @param maxN longest context suffix for which a feature is emitted */
        public ShannonExtractor(int maxN) {
            this.maxN = maxN;
        }

        /** Emits the features of {@code instance}, each with count 1.0. */
        @Override
        public Counter<String> extractFeatures(LabeledInstance<PunctSeq, Character> instance) {
            Counter<String> features = new Counter<String>();
            for (int order = 1; order <= this.maxN; ++order) {
                String context = Shannon.suffix(instance.getInput().left, order).toLowerCase();
                char label = Character.toLowerCase(instance.getLabel().charValue());
                String feature = context + " -> " + label;
                features.incrementCount(feature, 1.0);
            }
            return features;
        }

        /** Every feature gets the same regularization factor, 1.0. */
        @Override
        public double regularizationFactor(String feature) {
            return 1.0;
        }
    }

    /**
     * Base measure whose support is the same for every input: the alphabet held by the
     * enclosing {@code Shannon} instance.
     */
    public class ShannonBaseMeasure
    implements BaseMeasures<PunctSeq, Character> {
        /** Returns the enclosing instance's alphabet set itself (not a copy), whatever the input. */
        @Override
        public SortedSet<Character> support(PunctSeq input) {
            SortedSet<Character> alphabet = Shannon.this.bm;
            return alphabet;
        }
    }

    /**
     * A next-symbol predictor. {@code run} drives an implementation by calling
     * {@link #addLearn} once per training pair, then {@link #finishLearn} once, then
     * {@link #predict} once per test pair.
     */
    public static interface Predictor {
        /**
         * Offers one training example.
         *
         * @param var1 left context of the position
         * @param var2 symbol observed at the position
         */
        public void addLearn(PunctSeq var1, char var2);

        /** Called after the last {@code addLearn}; implementations complete their training here. */
        public void finishLearn();

        /**
         * Scores the candidate symbols for the given left context.
         *
         * @param var1 left context of the position to predict
         * @return a counter over symbols; {@code run} reads its {@code argMax()} as the
         *         guess and takes the logarithm of the count stored for the true symbol
         */
        public Counter<Character> predict(PunctSeq var1);
    }

    /**
     * Data set read from FASTA files. Training instances come from the files located
     * under {@code fastaTrain}, test instances from those under {@code fastaTest}; both
     * locations are searched through {@code IO.locate} with a "gz" suffix filter.
     */
    private class DataFromFasta
    extends Data {
        /**
         * Loads both parts of the data set.
         *
         * @throws IllegalArgumentException if {@code fastaTrain} or {@code fastaTest} is
         *         null (both default to null, which previously surfaced as a bare
         *         NullPointerException from {@code new File(null)})
         */
        public DataFromFasta() {
            // Validate both options up front so nothing is read before a certain failure.
            this.requireSet("fastaTrain", Shannon.this.fastaTrain);
            this.requireSet("fastaTest", Shannon.this.fastaTest);
            this.load(Shannon.this.fastaTrain, Shannon.this.maxTrain, this.trainData);
            this.load(Shannon.this.fastaTest, Shannon.this.maxTest, this.testData);
        }

        /** Rejects an unset location option with a message naming the option. */
        private void requireSet(String optionName, String location) {
            if (location == null) {
                throw new IllegalArgumentException("Option " + optionName + " must be set when useBali is false");
            }
        }

        /**
         * Appends the instances of every located file to {@code target}.
         * NOTE(review): {@code max} limits each file separately, not the total - confirm
         * this is the intended meaning of maxTrain/maxTest.
         */
        private void load(String location, int max, List<Pair<PunctSeq, Character>> target) {
            for (File f : IO.locate(new File(location), IO.suffixFilter("gz"))) {
                target.addAll(Shannon.this.iterate(Shannon.readFasta(f), max));
            }
        }
    }

    /**
     * Data set built from a {@code BalibaseCorpus}. The alignment ids are sorted by
     * their string form, split by {@code heldoutProp}, and the sequences of each
     * alignment are converted to instances with {@code iterate}.
     */
    private class DataFromBalibase
    extends Data {
        /**
         * Loads the corpus and fills {@code trainData} and {@code testData}.
         * Side effect: when the encoding type is RNA, the shared static
         * {@code Shannon.options.ignoreAnn} is set to true before the corpus is created.
         */
        public DataFromBalibase() {
            if (Shannon.this.encodingType == Encodings.EncodingType.RNA) {
                Shannon.options.ignoreAnn = true;
            }
            BalibaseCorpus bc = new BalibaseCorpus(Shannon.options);

            // Sort by string form so the split does not depend on key-set iteration order.
            List<CognateId> allIds = new ArrayList<CognateId>(bc.getMultiAlignments().keySet());
            Collections.sort(allIds, new Comparator<CognateId>(){

                @Override
                public int compare(CognateId arg0, CognateId arg1) {
                    return arg0.toString().compareTo(arg1.toString());
                }
            });

            // Typed pair: iterating raw Lists as CognateId does not type-check.
            Pair<List<CognateId>, List<CognateId>> parts = Shannon.split(allIds, Shannon.this.heldoutProp);
            List<CognateId> train = parts.getFirst();
            List<CognateId> test = parts.getSecond();

            // NOTE(review): maxTrain/maxTest limit each alignment separately, not the total - confirm intent.
            for (CognateId id : train) {
                this.trainData.addAll(Shannon.this.iterate(bc.getMultiAlignment(id).getSequences().values(), Shannon.this.maxTrain));
            }
            for (CognateId id : test) {
                this.testData.addAll(Shannon.this.iterate(bc.getMultiAlignment(id).getSequences().values(), Shannon.this.maxTest));
            }
        }
    }

    /**
     * Holder for the training and test instances. Both lists start out empty and are
     * filled by the constructors of the subclasses.
     */
    private static class Data {
        /** Instances the predictors learn from. */
        public List<Pair<PunctSeq, Character>> trainData;
        /** Instances the predictors are evaluated on. */
        public List<Pair<PunctSeq, Character>> testData;

        private Data() {
            this.trainData = CollUtils.list();
            this.testData = CollUtils.list();
        }
    }
}

