/*
 * Decompiled with CFR 0.152.
 */
package hashing;

import fig.basic.LogInfo;
import fig.basic.Option;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import nuts.io.IO;
import nuts.math.Sampling;
import nuts.maxent.HomogeneousBaseMeasures;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.maxent.MaxentClassifierTest;
import nuts.util.Counter;
import nuts.util.MathUtils;

/**
 * Command-line experiment that trains a {@link MaxentClassifier} on n-gram
 * contexts read from {@link #trainingFile}, evaluates it on {@link #testFile}
 * and logs the resulting perplexity.
 *
 * <p>Both files are read line by line; each line is split on whitespace into
 * tokens. Blank lines are ignored.
 */
public class DataInput
implements Runnable {
    /** Padding token used for context positions that fall before the start of a line. */
    public static final String BEG = "BEG";
    /** Label predicted after the last token of a line. */
    public static final String END = "END";
    /** Path of the training corpus. */
    @Option(required=true)
    public String trainingFile;
    /** Path of the test corpus. */
    @Option(required=true)
    public String testFile;
    /** Number of preceding tokens used as context by {@link #guessAllWords()}. */
    @Option(required=true)
    public int n;

    /**
     * Program entry point; hands the arguments and a fresh instance to {@code IO.run}.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        IO.run(args, new DataInput());
    }

    /**
     * Trains a classifier that predicts the last token of a line from all the
     * tokens before it, then logs the most probable label for each scored test
     * line and the perplexity over the scored test lines.
     *
     * <p>Test lines whose true label receives probability zero are excluded
     * from the perplexity (and their prediction is not logged); the number of
     * excluded lines is logged.
     */
    private void guessLastWord() {
        LogInfo.logs("Begin DataInput");
        ArrayList<String> labels = new ArrayList<String>();
        LogInfo.logs("Generate training data");
        ArrayList<LabeledInstance<List<String>, String>> trainingList = new ArrayList<LabeledInstance<List<String>, String>>();
        for (String line : IO.i(this.trainingFile)) {
            String[] tokens = DataInput.tokenize(line);
            if (tokens.length == 0) {
                continue;
            }
            String label = tokens[tokens.length - 1];
            trainingList.add(new LabeledInstance<List<String>, String>(label, this.formInput(tokens)));
            labels.add(label);
        }
        LogInfo.logs("Update training Counter");
        Counter training = new Counter();
        for (LabeledInstance labeledInstance : trainingList) {
            training.incrementCount(labeledInstance, 1.0);
        }
        HomogeneousBaseMeasures homogeneousBaseMeasures = new HomogeneousBaseMeasures(labels);
        MaxentClassifierTest.NGramFE nGramFE = new MaxentClassifierTest.NGramFE();
        MaxentClassifier<List<String>, String, String> maxent = MaxentClassifier.learnMaxentClassifier(homogeneousBaseMeasures, training, nGramFE);
        LogInfo.logs("Extract test data");
        ArrayList<LabeledInstance<List<String>, String>> testList = new ArrayList<LabeledInstance<List<String>, String>>();
        for (String line : IO.i(this.testFile)) {
            String[] tokens = DataInput.tokenize(line);
            if (tokens.length == 0) {
                continue;
            }
            String label = tokens[tokens.length - 1];
            testList.add(new LabeledInstance<List<String>, String>(label, this.formInput(tokens)));
        }
        double sum = 0.0;
        int scored = 0;
        int skipped = 0;
        for (LabeledInstance labeledInstance : testList) {
            List<String> input = (List<String>)labeledInstance.getInput();
            String trueLabel = (String)labeledInstance.getLabel();
            double[] probs = MathUtils.exp(maxent.logProb(input));
            System.out.println(maxent.getLabels(input));
            System.out.println(Arrays.toString(probs));
            String[] orderedLabels = new String[probs.length];
            maxent.getLabels(input).toArray(orderedLabels);
            double prob = DataInput.probabilityOf(trueLabel, orderedLabels, probs);
            if (prob == 0.0) {
                ++skipped;
                continue;
            }
            // Each test line contributes exactly once; the average is taken
            // over the test lines that were actually scored, not over the
            // training counter.
            sum += DataInput.log2(prob);
            ++scored;
            LogInfo.logs(orderedLabels[DataInput.argMax(probs)]);
        }
        if (skipped > 0) {
            LogInfo.logs("Skipped test instances with zero probability: " + skipped);
        }
        // With no scored instance this is 0/0 and the logged value is NaN.
        LogInfo.logs("Perplexity: " + Math.pow(2.0, -sum / scored));
    }

    /**
     * Trains a classifier that predicts every token of a line (and a final
     * {@link #END} label) from the {@link #n} preceding tokens, logs statistics
     * about pairs of hashed feature buckets holding equal weights, and then
     * logs the test perplexity together with text sampled from the model.
     *
     * <p>Test events whose true label receives probability zero are excluded
     * from the perplexity; the number of excluded events is logged.
     */
    private void guessAllWords() {
        LogInfo.logs("Begin DataInput");
        ArrayList<String> labels = new ArrayList<String>();
        LogInfo.logs("Parse training data");
        ArrayList<LabeledInstance<List<String>, String>> trainingList = new ArrayList<LabeledInstance<List<String>, String>>();
        Counter training = new Counter();
        for (String line : IO.i(this.trainingFile)) {
            String[] tokens = DataInput.tokenize(line);
            if (tokens.length == 0) {
                continue;
            }
            // Position tokens.length is the end-of-line event.
            for (int i = 0; i <= tokens.length; ++i) {
                String label = i == tokens.length ? END : tokens[i];
                LabeledInstance<List<String>, String> instance = new LabeledInstance<List<String>, String>(label, this.prevTokens(tokens, this.n, i));
                trainingList.add(instance);
                labels.add(label);
                training.incrementCount(instance, 1.0);
            }
        }
        LogInfo.logs("Number of training data: " + trainingList.size());
        LogInfo.logs("Number of training total counts: " + training.totalCount());
        HomogeneousBaseMeasures fbm = new HomogeneousBaseMeasures(labels);
        MaxentClassifierTest.NGramFE featureExtractor = new MaxentClassifierTest.NGramFE();
        MaxentClassifier<List<String>, String, String> maxent = MaxentClassifier.learnMaxentClassifier(fbm, training, featureExtractor);
        double[] weights = maxent.rawWeights();
        int sameWeightCounter = 0;
        int sameZeroWeightCounter = 0;
        int m = weights.length;
        // An empty weight vector has no buckets to compare.
        if (m > 0) {
            for (LabeledInstance labeledInstance : trainingList) {
                Counter features = featureExtractor.extractFeatures(labeledInstance);
                Iterator iterator = features.iterator();
                while (iterator.hasNext()) {
                    String feature = (String)iterator.next();
                    // NOTE(review): the "_h1" / "h2_" decorations presumably mirror the
                    // two hash functions used inside the classifier - confirm.
                    int key1 = DataInput.bucket(feature + "_h1", m);
                    int key2 = DataInput.bucket("h2_" + feature, m);
                    if (weights[key1] != weights[key2]) {
                        continue;
                    }
                    ++sameWeightCounter;
                    if (weights[key1] == 0.0) {
                        ++sameZeroWeightCounter;
                    }
                }
            }
        }
        LogInfo.logs("Same weights: " + sameWeightCounter);
        LogInfo.logs("Same weights that are zero: " + sameZeroWeightCounter);
        LogInfo.logs("Extract test data");
        ArrayList<LabeledInstance<List<String>, String>> testList = new ArrayList<LabeledInstance<List<String>, String>>();
        for (String line : IO.i(this.testFile)) {
            String[] tokens = DataInput.tokenize(line);
            if (tokens.length == 0) {
                continue;
            }
            // Include the end-of-line event, exactly as the training loop does.
            for (int i = 0; i <= tokens.length; ++i) {
                String label = i == tokens.length ? END : tokens[i];
                testList.add(new LabeledInstance<List<String>, String>(label, this.prevTokens(tokens, this.n, i)));
            }
        }
        double sum = 0.0;
        int scored = 0;
        int skipped = 0;
        Random rand = new Random();
        StringBuilder sentences = new StringBuilder();
        for (LabeledInstance labeledInstance : testList) {
            List<String> input = (List<String>)labeledInstance.getInput();
            String trueLabel = (String)labeledInstance.getLabel();
            double[] probs = MathUtils.exp(maxent.logProb(input));
            String[] orderedLabels = new String[probs.length];
            maxent.getLabels(input).toArray(orderedLabels);
            double prob = DataInput.probabilityOf(trueLabel, orderedLabels, probs);
            if (prob == 0.0) {
                ++skipped;
            } else {
                // Every occurrence in testList contributes exactly once, so the
                // result does not depend on how LabeledInstance defines equality.
                sum += DataInput.log2(prob);
                ++scored;
            }
            // Draw one label from the predicted distribution for this context.
            ArrayList<Double> probList = new ArrayList<Double>(probs.length);
            for (double p : probs) {
                probList.add(p);
            }
            int sample = Sampling.sample(rand, probList);
            String word = orderedLabels[sample];
            sentences.append(word).append(" ");
            if (word.equalsIgnoreCase(END)) {
                sentences.append("\n");
            }
        }
        if (skipped > 0) {
            LogInfo.logs("Skipped test instances with zero probability: " + skipped);
        }
        // With no scored instance this is 0/0 and the logged value is NaN.
        LogInfo.logs("Perplexity: " + Math.pow(2.0, -sum / scored));
        LogInfo.logs("Sentence: ");
        LogInfo.logs(sentences.toString());
    }

    /**
     * Splits a line into whitespace-separated tokens.
     *
     * @param line raw line of text
     * @return the tokens, or an empty array when the line is blank; no token is
     *         empty or contains whitespace
     */
    private static String[] tokenize(String line) {
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            return new String[0];
        }
        return trimmed.split("\\s+");
    }

    /**
     * Maps a feature name to an index in {@code [0, m)}.
     *
     * <p>Taking the remainder before the absolute value yields the same index
     * as {@code Math.abs(hash) % m} for every hash except
     * {@code Integer.MIN_VALUE}, whose absolute value is still negative and
     * would otherwise produce a negative index.
     *
     * @param key feature name
     * @param m number of buckets, must be positive
     * @return bucket index for {@code key}
     */
    private static int bucket(String key, int m) {
        return Math.abs(key.hashCode() % m);
    }

    /**
     * Looks up the probability assigned to a label, comparing labels
     * case-insensitively.
     *
     * @param trueLabel label to look up
     * @param orderedLabels labels aligned index-by-index with {@code probs}
     * @param probs probability of each label
     * @return probability of the last label matching {@code trueLabel}, or
     *         {@code 0.0} when none matches
     */
    private static double probabilityOf(String trueLabel, String[] orderedLabels, double[] probs) {
        double prob = 0.0;
        for (int i = 0; i < probs.length; ++i) {
            if (orderedLabels[i].equalsIgnoreCase(trueLabel)) {
                prob = probs[i];
            }
        }
        return prob;
    }

    /**
     * Finds the position of the largest strictly positive value.
     *
     * @param probs values to scan
     * @return index of the first maximum, or {@code 0} when no value is positive
     */
    private static int argMax(double[] probs) {
        int best = 0;
        double max = 0.0;
        for (int i = 0; i < probs.length; ++i) {
            if (probs[i] > max) {
                max = probs[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * @param x a positive number
     * @return the base-2 logarithm of {@code x}
     */
    private static double log2(double x) {
        return Math.log(x) / Math.log(2.0);
    }

    /**
     * Copies every token except the last one into a new list.
     *
     * @param tokens tokens of one line
     * @return all tokens but the last; empty when there are fewer than two tokens
     */
    private List<String> formInput(String[] tokens) {
        int end = Math.max(0, tokens.length - 1);
        return new ArrayList<String>(Arrays.asList(tokens).subList(0, end));
    }

    /**
     * Builds the context of size {@code n} that precedes position {@code i},
     * left-padded with {@link #BEG} when fewer than {@code n} tokens precede it.
     *
     * @param tokens tokens of one line
     * @param n context size
     * @param i position whose context is requested
     * @return list of exactly {@code n} context tokens, each trimmed
     * @throws IllegalArgumentException if a context of exactly {@code n}
     *         elements cannot be built (for example when {@code n} or
     *         {@code i} is negative)
     */
    private List<String> prevTokens(String[] tokens, int n, int i) {
        ArrayList<String> prevs = new ArrayList<String>();
        for (int pad = n - i; pad > 0; --pad) {
            prevs.add(BEG);
        }
        for (int j = Math.max(0, i - n); j < i; ++j) {
            prevs.add(tokens[j].trim());
        }
        if (prevs.size() != n) {
            throw new IllegalArgumentException("Cannot build a context of size " + n + " for position " + i);
        }
        return prevs;
    }

    /** Runs the all-words experiment. */
    @Override
    public void run() {
        this.guessAllWords();
    }
}

