/*
 * Decompiled with CFR 0.152.
 */
package nuts.maxent;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import java.io.Serializable;
import java.util.Random;
import java.util.Set;
import java.util.SortedSet;
import nuts.math.HMCMC;
import nuts.maxent.BaseMeasures;
import nuts.maxent.DifferentiableFunction;
import nuts.maxent.FeatureExtractor;
import nuts.maxent.FeatureVectors;
import nuts.maxent.FeatureVectorsInterface;
import nuts.maxent.HashedFeatureVectors;
import nuts.maxent.LBFGSMinimizer;
import nuts.maxent.LabeledInstance;
import nuts.maxent.SloppyMath;
import nuts.maxent.SparseVector;
import nuts.util.Counter;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;

public final class MaxentClassifier<I, L, F>
implements Serializable {
    private static final long serialVersionUID = 1L;
    private final FeatureVectorsInterface<I, L> featureVectors;
    private final BaseMeasures<I, L> baseMeasures;
    private double[] weights;
    private final double[] regularizationCenters;

    /**
     * Command-line entry point. It has no behavior: the body is empty.
     *
     * @param args command-line arguments; never read
     */
    public static void main(String[] args) {
        // Intentionally empty.
    }

    private MaxentClassifier(BaseMeasures<I, L> baseMeasures, Set<LabeledInstance<I, L>> data, FeatureExtractor<LabeledInstance<I, L>, F> extractor, Counter<F> regularizerCenters, boolean useHash) {
        this.baseMeasures = baseMeasures;
        this.featureVectors = useHash ? HashedFeatureVectors.createFeatureVectors(extractor) : FeatureVectors.createFeatureVectors(baseMeasures, data, extractor);
        this.regularizationCenters = this.featureVectors.createInitialWeight(regularizerCenters);
    }

    private MaxentClassifier(BaseMeasures<I, L> baseMeasures, Counter<F> weights, FeatureExtractor<LabeledInstance<I, L>, F> extractor, boolean useHash) {
        this.baseMeasures = baseMeasures;
        this.featureVectors = useHash ? HashedFeatureVectors.createFeatureVectors(extractor) : FeatureVectors.createFeatureVectorsFromSet(weights.keySet(), extractor);
        this.weights = this.featureVectors.createInitialWeight(weights);
        this.regularizationCenters = this.featureVectors.createInitialWeight();
    }

    /**
     * Trains a classifier with the given options, using a freshly constructed
     * {@code Counter} as the regularizer centers.
     *
     * @param baseMeasures    supplies the candidate labels for each input
     * @param trainingData    counts of labeled instances
     * @param extractor       feature extractor for labeled instances
     * @param learningOptions training options
     * @return the trained classifier
     */
    public static <I, L, F> MaxentClassifier<I, L, F> learnMaxentClassifier(BaseMeasures<I, L> baseMeasures, Counter<LabeledInstance<I, L>> trainingData, FeatureExtractor<LabeledInstance<I, L>, F> extractor, MaxentOptions<F> learningOptions) {
        final Counter<F> defaultCenters = new Counter<F>();
        return MaxentClassifier.learnMaxentClassifier(baseMeasures, trainingData, extractor, learningOptions, defaultCenters);
    }

    /**
     * Trains a classifier with the given options and regularizer centers.
     *
     * <p>The classifier is first built over the key set of {@code trainingData}
     * and then fitted by its private {@code learn} method.
     *
     * @param baseMeasures       supplies the candidate labels for each input
     * @param trainingData       counts of labeled instances
     * @param extractor          feature extractor for labeled instances
     * @param learningOptions    training options; {@code useHash} selects the feature-vector implementation
     * @param regularizerCenters per-feature regularization centers
     * @return the trained classifier
     */
    public static <I, L, F> MaxentClassifier<I, L, F> learnMaxentClassifier(BaseMeasures<I, L> baseMeasures, Counter<LabeledInstance<I, L>> trainingData, FeatureExtractor<LabeledInstance<I, L>, F> extractor, MaxentOptions<F> learningOptions, Counter<F> regularizerCenters) {
        MaxentClassifier<I, L, F> result = new MaxentClassifier<I, L, F>(baseMeasures, trainingData.keySet(), extractor, regularizerCenters, learningOptions.useHash);
        // learn is a private instance method: call it on the new classifier.
        // (`super` is not available in a static context.)
        result.learn(trainingData, learningOptions);
        return result;
    }

    /**
     * Trains a classifier with default {@code MaxentOptions} and a freshly
     * constructed {@code Counter} as the regularizer centers.
     *
     * @param baseMeasures supplies the candidate labels for each input
     * @param trainingData counts of labeled instances
     * @param extractor    feature extractor for labeled instances
     * @return the trained classifier
     */
    public static <I, L, F> MaxentClassifier<I, L, F> learnMaxentClassifier(BaseMeasures<I, L> baseMeasures, Counter<LabeledInstance<I, L>> trainingData, FeatureExtractor<LabeledInstance<I, L>, F> extractor) {
        final MaxentOptions<F> defaultOptions = new MaxentOptions<F>();
        final Counter<F> defaultCenters = new Counter<F>();
        return MaxentClassifier.learnMaxentClassifier(baseMeasures, trainingData, extractor, defaultOptions, defaultCenters);
    }

    /**
     * Wraps existing weights in a classifier backed by hashed feature vectors.
     *
     * @param baseMeasures supplies the candidate labels for each input
     * @param weights      named feature weights
     * @param extractor    feature extractor for labeled instances
     * @return a classifier using the given weights, with no training performed
     */
    public static <I, L, F> MaxentClassifier<I, L, F> createMaxentClassifierFromHashWeights(BaseMeasures<I, L> baseMeasures, Counter<F> weights, FeatureExtractor<LabeledInstance<I, L>, F> extractor) {
        final boolean useHash = true;
        return new MaxentClassifier<I, L, F>(baseMeasures, weights, extractor, useHash);
    }

    /**
     * Wraps existing weights in a classifier backed by non-hashed feature
     * vectors, indexed from the key set of {@code weights}.
     *
     * @param baseMeasures supplies the candidate labels for each input
     * @param weights      named feature weights
     * @param extractor    feature extractor for labeled instances
     * @return a classifier using the given weights, with no training performed
     */
    public static <I, L, F> MaxentClassifier<I, L, F> createMaxentClassifierFromWeights(BaseMeasures<I, L> baseMeasures, Counter<F> weights, FeatureExtractor<LabeledInstance<I, L>, F> extractor) {
        final boolean useHash = false;
        return new MaxentClassifier<I, L, F>(baseMeasures, weights, extractor, useHash);
    }

    /**
     * Returns the normalized log-probability of each label of {@code input}
     * under the current weights, without feature-vector caching.
     *
     * @param input the input to score
     * @return one log-probability per label, in the iteration order of {@link #getLabels(Object)}
     */
    public double[] logProb(I input) {
        final double[] currentWeights = this.weights;
        return this.logProb(input, currentWeights, false);
    }

    /**
     * Returns the label distribution for {@code input} as a counter keyed by label.
     *
     * <p>The log-probabilities are passed through {@code NumUtils.expNormalize}
     * (presumably an in-place exponentiate-and-normalize; confirm against
     * {@code fig.basic.NumUtils}) and then paired with the labels in
     * sorted-set order.
     *
     * @param input the input to score
     * @return a counter mapping each label to its value after {@code expNormalize}
     */
    public Counter<L> probabilitiesCounter(I input) {
        final double[] probabilities = this.logProb(input);
        NumUtils.expNormalize(probabilities);
        final Counter<L> distribution = new Counter<L>();
        int index = 0;
        for (L label : this.getLabels(input)) {
            distribution.setCount(label, probabilities[index]);
            index++;
        }
        return distribution;
    }

    /**
     * Returns the normalized log-probability of each label of {@code input}
     * under the current weights.
     *
     * @param input the input to score
     * @param cache forwarded to the feature-vector lookup as its cache flag
     * @return one log-probability per label, in the iteration order of {@link #getLabels(Object)}
     */
    public double[] logProb(I input, boolean cache) {
        final double[] currentWeights = this.weights;
        return this.logProb(input, currentWeights, cache);
    }

    /**
     * Returns the log normalizer for {@code input} without feature-vector caching.
     *
     * @param input the input whose labels are scored
     * @return {@code SloppyMath.logAdd} of the unnormalized label scores
     */
    public double localLogNormalization(I input) {
        final boolean cache = false;
        return this.localLogNormalization(input, cache);
    }

    /**
     * Returns the log normalizer for {@code input} under the current weights.
     *
     * @param input the input whose labels are scored
     * @param cache forwarded to the feature-vector lookup as its cache flag
     * @return {@code SloppyMath.logAdd} of the unnormalized label scores
     */
    public double localLogNormalization(I input, boolean cache) {
        final double[] unnormalized = this.unNormLogProb(input, this.weights, cache);
        return SloppyMath.logAdd(unnormalized);
    }

    /**
     * Returns the candidate labels for {@code input}, as given by the base measures.
     *
     * @param input the input to look up
     * @return the sorted support set returned by {@code baseMeasures.support(input)}
     */
    public SortedSet<L> getLabels(I input) {
        final SortedSet<L> support = this.baseMeasures.support(input);
        return support;
    }

    /**
     * Renders the named weights, one {@code key<TAB>weight} pair per line.
     *
     * @return the tab-separated listing; every line, including the last, ends with a newline
     */
    @Override
    public String toString() {
        final Counter named = this.featureVectors.namedCounter(this.weights);
        final StringBuilder text = new StringBuilder();
        for (Object feature : named) {
            text.append(feature.toString())
                .append("\t")
                .append(named.getCount(feature))
                .append("\n");
        }
        return text.toString();
    }

    /**
     * Returns the dimension reported by the feature-vector store.
     *
     * @return {@code featureVectors.dim()}
     */
    public int numberOfActiveFeatures() {
        final int dimension = this.featureVectors.dim();
        return dimension;
    }

    /**
     * Returns the current weights as a counter keyed by feature name.
     *
     * @return the result of {@code featureVectors.namedCounter} on the raw weight array
     */
    public Counter weights() {
        final double[] raw = this.weights;
        return this.featureVectors.namedCounter(raw);
    }

    /**
     * Returns the internal weight array itself, not a copy. Changes the
     * caller makes to it are visible to this classifier.
     *
     * @return the live weight array; unset if the classifier has not been trained or given weights
     */
    public double[] rawWeights() {
        final double[] live = this.weights;
        return live;
    }

    /**
     * Computes normalized log-probabilities for every label of {@code instance}
     * under the weight vector {@code w}.
     *
     * @param instance the input to score
     * @param w        weight vector to score with
     * @param cache    forwarded to the feature-vector lookup as its cache flag
     * @return unnormalized scores minus their {@code SloppyMath.logAdd}, in label order
     */
    private double[] logProb(I instance, double[] w, boolean cache) {
        final double[] scores = this.unNormLogProb(instance, w, cache);
        final double logNormalizer = SloppyMath.logAdd(scores);
        // Normalize in place; the array is freshly allocated by unNormLogProb.
        for (int idx = 0; idx < scores.length; idx++) {
            scores[idx] -= logNormalizer;
        }
        return scores;
    }

    /**
     * Computes the unnormalized score (feature vector dotted with {@code w})
     * for every label of {@code instance}.
     *
     * @param instance the input to score
     * @param w        weight vector to score with
     * @param cache    forwarded to {@code featureVectors.getFeatureVector} as its cache flag
     * @return a new array with one score per label, in sorted-set iteration order
     */
    private double[] unNormLogProb(I instance, double[] w, boolean cache) {
        final SortedSet<L> support = this.baseMeasures.support(instance);
        final double[] scores = new double[support.size()];
        int position = 0;
        for (L label : support) {
            final LabeledInstance<I, L> candidate = new LabeledInstance<I, L>(label, instance);
            final SparseVector features = this.featureVectors.getFeatureVector(candidate, cache);
            scores[position++] = features.dotProduct(w);
        }
        return scores;
    }

    /**
     * Fits {@code weights} to the training counts with the algorithm selected
     * by {@code options.learningAlgo}.
     *
     * <ul>
     *   <li>{@code LBFGS}: minimizes the regularized objective; L1 is rejected.</li>
     *   <li>{@code MH}: random-walk Metropolis-Hastings; keeps the last accepted state.</li>
     *   <li>{@code HMC}: delegates to {@code HMCMC}; keeps the sampler's final position.</li>
     * </ul>
     *
     * @param suffStats counts of labeled instances
     * @param options   training options
     * @throws UnsupportedOperationException if the algorithm is {@code OW}, {@code null} or
     *         otherwise unhandled, or if L1 regularization is requested with LBFGS
     */
    private void learn(Counter<LabeledInstance<I, L>> suffStats, MaxentOptions<F> options) {
        double[] init = this.featureVectors.createInitialWeight(options.initialWeights);
        ObjectiveFunction objective = this.objectiveFunction(suffStats, options.sigma, options.useL1);
        LogInfo.logs("featureVectors.dim()=" + this.featureVectors.dim());
        LearningAlgorithm algo = options.learningAlgo;
        if (algo == LearningAlgorithm.LBFGS) {
            LogInfo.logs("Optimization using LBFGS");
            if (options.useL1) {
                throw new UnsupportedOperationException("L1 regularization is not supported with LBFGS");
            }
            LBFGSMinimizer minimizer = new LBFGSMinimizer(options.iterations);
            this.weights = minimizer.minimize(objective, init, options.tolerance);
        } else if (algo == LearningAlgorithm.MH) {
            this.weights = this.sampleWithMetropolisHastings(objective, init, options);
        } else if (algo == LearningAlgorithm.HMC) {
            this.weights = this.sampleWithHamiltonianMonteCarlo(objective, init, options);
        } else {
            // Covers OW (not implemented here), null, and any future enum constant.
            throw new UnsupportedOperationException("Unsupported learning algorithm: " + algo);
        }
    }

    /**
     * Runs random-walk Metropolis-Hastings on the negated objective, treated
     * as a log-density, and returns the final accepted state.
     *
     * @param objective the objective whose negation is the log-density
     * @param init      starting point; updated in place and returned
     * @param options   supplies iterations, proposal standard deviation and RNG
     * @return {@code init}, holding the last accepted sample
     */
    private double[] sampleWithMetropolisHastings(ObjectiveFunction objective, double[] init, MaxentOptions<F> options) {
        double[] prev = init;
        double[] cur = new double[init.length];
        double prevLogPr = -objective.valueAt(prev);
        LogInfo.track((Object)"MH sampling", true);
        LogInfo.logs("Initial log pr:" + prevLogPr + ", norm:" + NumUtils.l2Norm(prev));
        SummaryStatistics acceptRatioStats = new SummaryStatistics();
        long initTime = System.currentTimeMillis();
        for (int iter = 0; iter < options.iterations; ++iter) {
            // Propose: independent Gaussian perturbation of every coordinate.
            for (int i = 0; i < cur.length; ++i) {
                cur[i] = prev[i] + options.proposalStdDev * SampleUtils.sampleGaussian(options.rand);
            }
            double curLogPr = -objective.valueAt(cur);
            double ratio = Math.min(1.0, Math.exp(curLogPr - prevLogPr));
            acceptRatioStats.addValue(ratio);
            if (options.rand.nextDouble() < ratio) {
                // Accept: copy rather than swap so `prev` stays the same array as `init`.
                System.arraycopy(cur, 0, prev, 0, cur.length);
                prevLogPr = curLogPr;
            }
            LogInfo.logs("MH-iteration " + (iter + 1) + "/" + options.iterations + " " + (System.currentTimeMillis() - initTime) + " " + curLogPr);
        }
        LogInfo.end_track();
        LogInfo.logs("Mean acceptance ratio:" + acceptRatioStats.getMean());
        return prev;
    }

    /**
     * Runs the {@code HMCMC} sampler on the objective and returns its final position.
     *
     * @param objective the energy function handed to the sampler
     * @param init      starting point handed to the sampler
     * @param options   supplies epsilon, tau, proposal standard deviation, iterations and RNG
     * @return the sampler's position after the last iteration
     */
    private double[] sampleWithHamiltonianMonteCarlo(ObjectiveFunction objective, double[] init, MaxentOptions<F> options) {
        HMCMC sampler = new HMCMC(objective, init);
        sampler.setEpsilon(options.hmcEpsilon);
        sampler.setTau(options.hmcTau);
        sampler.setProposalStdDev(options.proposalStdDev);
        long initTime = System.currentTimeMillis();
        LogInfo.track((Object)"HMC sampling", true);
        for (int iter = 0; iter < options.iterations; ++iter) {
            sampler.next(options.rand);
            LogInfo.logs("HMC-iteration " + (iter + 1) + "/" + options.iterations + " " + (System.currentTimeMillis() - initTime) + "ms " + -sampler.getValue());
        }
        LogInfo.end_track();
        double[] position = sampler.getPosition();
        LogInfo.logs("Mean acceptance ratio:" + sampler.acceptRatio());
        return position;
    }

    /**
     * Creates the regularized training objective bound to this classifier.
     *
     * @param suffStats counts of labeled instances
     * @param sigma     regularization scale
     * @param useL1     true for the L1 penalty, false for the quadratic one
     * @return a new objective over this classifier's feature vectors and base measures
     */
    public ObjectiveFunction objectiveFunction(Counter<LabeledInstance<I, L>> suffStats, double sigma, boolean useL1) {
        final ObjectiveFunction objective = new ObjectiveFunction(suffStats, sigma, useL1);
        return objective;
    }

    /**
     * Sums the feature vectors of all training instances, each weighted by its count.
     *
     * @param featureVectors feature-vector store; lookups are made with the cache flag set
     * @param trainingCounts counts of labeled instances
     * @return the count-weighted feature sum as a {@code SparseVector}
     */
    private static <I, L, F> SparseVector computeActualExpectedFeatureVector(FeatureVectorsInterface<I, L> featureVectors, Counter<LabeledInstance<I, L>> trainingCounts) {
        final double[] accumulated = new double[featureVectors.dim()];
        for (LabeledInstance<I, L> datum : trainingCounts.keySet()) {
            final double weight = trainingCounts.getCount(datum);
            final SparseVector features = featureVectors.getFeatureVector(datum, true);
            features.linearIncrement(weight, accumulated);
        }
        return new SparseVector(accumulated);
    }

    /**
     * Sums the training counts over labels, giving a total count per input.
     *
     * @param trainingCount counts of labeled instances
     * @return a counter mapping each input to the sum of its labeled counts
     */
    private static <I, L> Counter<I> computeMarginalExpectedCounts(Counter<LabeledInstance<I, L>> trainingCount) {
        final Counter<I> perInput = new Counter<I>();
        for (LabeledInstance<I, L> datum : trainingCount) {
            perInput.incrementCount(datum.getInput(), trainingCount.getCount(datum));
        }
        return perInput;
    }

    /**
     * Copies only those training entries whose label is in the support the
     * base measures give for their input.
     *
     * @param data         counts of labeled instances
     * @param baseMeasures supplies the valid labels for each input
     * @return a new counter holding the valid entries with their original counts
     */
    private static <I, L> Counter<LabeledInstance<I, L>> restrictToValidTrainingData(Counter<LabeledInstance<I, L>> data, BaseMeasures<I, L> baseMeasures) {
        final Counter<LabeledInstance<I, L>> valid = new Counter<LabeledInstance<I, L>>();
        for (LabeledInstance<I, L> datum : data.keySet()) {
            final L label = datum.getLabel();
            final I input = datum.getInput();
            final boolean supported = baseMeasures.support(input).contains(label);
            if (supported) {
                valid.setCount(datum, data.getCount(datum));
            }
        }
        return valid;
    }

    /**
     * Regularized negative log-likelihood of the training counts, as a function
     * of the weight vector.
     *
     * <p>The value is {@code -sum count(x, y) * logProb(y | x)} plus a penalty
     * around {@code regularizationCenters}. The penalty is quadratic
     * ({@code diff^2 / (2 * sigma)}) by default and absolute
     * ({@code |diff| / (2 * sigma)}) when {@code useL1} is set; {@code sigma} is
     * scaled per feature by {@code featureVectors.getRegularizationFactor}.
     *
     * <p>The most recent evaluation is cached and reused while the query point
     * is element-wise equal to the previous one. The arrays returned by
     * {@link #derivativeAt(double[])} and
     * {@link #unregularizedDerivativeAt(double[])} are the cached instances, not copies.
     */
    public final class ObjectiveFunction
    implements DifferentiableFunction,
    HMCMC.Energy {
        /** Training counts restricted to labels inside the support of their input. */
        private final Counter<LabeledInstance<I, L>> trainingCounts;
        /** Count-weighted sum of the feature vectors of the training instances. */
        private final SparseVector actualExpectedFeatureVector;
        /** Total training count per input, summed over labels. */
        private final Counter<I> marginalExpectedCounts;
        /** True for the absolute-value penalty, false for the quadratic one. */
        private final boolean useL1;
        /** Regularization scale; see {@link #setSigma(double)}. */
        private double sigma;
        /** Cached regularized value at {@code lastX}. */
        double lastValue;
        /** Cached regularized gradient at {@code lastX}. */
        double[] lastDerivative;
        /** Cached gradient at {@code lastX} before the penalty was applied. */
        double[] lastUnregDeriv;
        /** Point of the cached evaluation; {@code null} when the cache is invalid. */
        double[] lastX = null;

        /**
         * Builds the objective and precomputes the statistics that do not
         * depend on the weights.
         *
         * @param expectedCounts counts of labeled instances; entries whose label is
         *                       outside the support of their input are dropped
         * @param sigma          regularization scale (not validated here)
         * @param useL1          true for the absolute-value penalty
         */
        public ObjectiveFunction(Counter<LabeledInstance<I, L>> expectedCounts, double sigma, boolean useL1) {
            this.useL1 = useL1;
            this.sigma = sigma;
            this.trainingCounts = MaxentClassifier.restrictToValidTrainingData(expectedCounts, MaxentClassifier.this.baseMeasures);
            this.marginalExpectedCounts = MaxentClassifier.computeMarginalExpectedCounts(this.trainingCounts);
            this.actualExpectedFeatureVector = MaxentClassifier.computeActualExpectedFeatureVector(MaxentClassifier.this.featureVectors, this.trainingCounts);
        }

        /**
         * Changes the regularization scale and invalidates the cached evaluation.
         *
         * @param s new scale; must be strictly positive
         * @throws IllegalArgumentException if {@code s <= 0}
         */
        public void setSigma(double s) {
            if (s <= 0.0) {
                throw new IllegalArgumentException("sigma must be strictly positive, got " + s);
            }
            this.lastX = null;
            this.sigma = s;
        }

        /** @return the dimension reported by the feature-vector store */
        @Override
        public int dimension() {
            return MaxentClassifier.this.featureVectors.dim();
        }

        /**
         * @param x weight vector
         * @return the regularized objective value at {@code x}
         */
        @Override
        public double valueAt(double[] x) {
            this.ensureCache(x);
            return this.lastValue;
        }

        /**
         * @param x weight vector
         * @return the regularized gradient at {@code x} (cached array, not a copy)
         */
        @Override
        public double[] derivativeAt(double[] x) {
            this.ensureCache(x);
            return this.lastDerivative;
        }

        /**
         * @param x weight vector
         * @return the gradient at {@code x} without the penalty term (cached array, not a copy)
         */
        public double[] unregularizedDerivativeAt(double[] x) {
            this.ensureCache(x);
            return this.lastUnregDeriv;
        }

        /** Recomputes value and gradients when {@code x} differs from the cached point. */
        private void ensureCache(double[] x) {
            if (this.requiresUpdate(this.lastX, x)) {
                Pair<Double, double[]> currentValueAndDerivative = this.calculate(x);
                this.lastValue = currentValueAndDerivative.getFirst();
                this.lastDerivative = currentValueAndDerivative.getSecond();
                // Snapshot before regularize() mutates lastDerivative in place.
                this.lastUnregDeriv = this.lastDerivative.clone();
                this.lastValue += this.regularize(x, this.lastDerivative);
                if (this.lastX == null || this.lastX.length != x.length) {
                    this.lastX = new double[x.length];
                }
                System.arraycopy(x, 0, this.lastX, 0, x.length);
            }
        }

        /**
         * Adds the penalty's gradient to {@code gradient} in place and returns
         * the penalty's value.
         *
         * @param x        weight vector
         * @param gradient unregularized gradient; modified in place
         * @return the penalty value at {@code x}
         */
        private double regularize(double[] x, double[] gradient) {
            double result = 0.0;
            for (int j = 0; j < x.length; ++j) {
                double diff = x[j] - MaxentClassifier.this.regularizationCenters[j];
                double scaledSigma = this.sigma * MaxentClassifier.this.featureVectors.getRegularizationFactor(j);
                if (this.useL1) {
                    double halfOverSigma = 0.5 / scaledSigma;
                    if (diff == 0.0) {
                        // |diff| is not differentiable here, so a subgradient is chosen.
                        // NOTE(review): the usual minimum-norm choice would shrink the
                        // gradient toward zero (g + c when g < -c, g - c when g > c);
                        // this code shifts it away from zero. Kept as in the original —
                        // confirm the intended sign.
                        if (gradient[j] < -halfOverSigma) {
                            gradient[j] = gradient[j] - halfOverSigma;
                        } else if (gradient[j] > halfOverSigma) {
                            gradient[j] = gradient[j] + halfOverSigma;
                        } else {
                            gradient[j] = 0.0;
                        }
                    } else {
                        // d/dx |diff| / (2 * sigma) = sign(diff) / (2 * sigma).
                        // Computed in floating point: the former (±1) / 2 was integer
                        // division and always produced 0.
                        gradient[j] = gradient[j] + (diff < 0.0 ? -halfOverSigma : halfOverSigma);
                    }
                    result += Math.abs(diff) / 2.0 / scaledSigma;
                    continue;
                }
                gradient[j] = gradient[j] + diff / scaledSigma;
                result += diff * diff / 2.0 / scaledSigma;
            }
            return result;
        }

        /**
         * @return true when there is no cached point, the lengths differ, or any
         *         coordinate differs (a NaN coordinate always counts as different)
         */
        private boolean requiresUpdate(double[] lastX, double[] x) {
            if (lastX == null || lastX.length != x.length) {
                return true;
            }
            for (int i = 0; i < x.length; ++i) {
                if (lastX[i] == x[i]) continue;
                return true;
            }
            return false;
        }

        /**
         * Evaluates the unregularized negative log-likelihood and its gradient.
         *
         * <p>The gradient starts at minus the empirical feature sum. For every
         * input, each label's feature vector is then added, weighted by the
         * model probability times the input's total count.
         *
         * @param x weight vector
         * @return the pair (value, gradient); the gradient is a fresh array
         */
        private Pair<Double, double[]> calculate(double[] x) {
            double fValue = 0.0;
            double[] gradient = new double[x.length];
            this.actualExpectedFeatureVector.linearIncrement(-1.0, gradient);
            for (I input : this.marginalExpectedCounts.keySet()) {
                SortedSet<L> labels = MaxentClassifier.this.baseMeasures.support(input);
                // logProb entries follow the iteration order of support(input).
                double[] logProb = MaxentClassifier.this.logProb(input, x, true);
                double currentMarginalCount = this.marginalExpectedCounts.getCount(input);
                int l = 0;
                for (L label : labels) {
                    LabeledInstance<I, L> key = new LabeledInstance<I, L>(label, input);
                    double currentCount = this.trainingCounts.getCount(key);
                    fValue -= currentCount * logProb[l];
                    SparseVector currentVector = MaxentClassifier.this.featureVectors.getFeatureVector(key, true);
                    double coef = Math.exp(logProb[l]) * currentMarginalCount;
                    currentVector.linearIncrement(coef, gradient);
                    ++l;
                }
            }
            return new Pair<Double, double[]>(fValue, gradient);
        }

        /** @return the dimension reported by the feature-vector store */
        public int gradientDim() {
            return MaxentClassifier.this.featureVectors.dim();
        }

        /** @return same as {@link #gradientDim()} */
        @Override
        public int dim() {
            return this.gradientDim();
        }

        /**
         * @param x weight vector
         * @return same as {@link #derivativeAt(double[])}
         */
        @Override
        public double[] gradientAt(double[] x) {
            return this.derivativeAt(x);
        }
    }

    /**
     * Training options for {@link MaxentClassifier}. The fields are public and
     * mutable; those annotated with {@code @Option} are presumably populated by
     * the fig options framework (confirm against {@code fig.basic.Option}).
     *
     * <p>{@link #clone()} delegates to {@code Object.clone()}, so a clone
     * shares {@code rand} and {@code initialWeights} with its source.
     */
    public static class MaxentOptions<F>
    implements Cloneable {
        /** Selects hashed feature vectors when true. */
        @Option
        public boolean useHash = false;
        // The three regResampling* options are not read anywhere in this class.
        @Option
        public double regResamplingPriorShape = 2.0;
        @Option
        public double regResamplingPriorRate = 2.0;
        @Option
        public int regResamplingFreq = 50;
        /** Passed to {@code HMCMC.setTau} for HMC training. */
        @Option
        public int hmcTau = 100;
        /** Passed to {@code HMCMC.setEpsilon} for HMC training. */
        @Option
        public double hmcEpsilon = 0.005;
        /** Random source used by the MH and HMC samplers. */
        @Option
        public Random rand = new Random(1L);
        /** Training algorithm to run. */
        @Option
        public LearningAlgorithm learningAlgo = LearningAlgorithm.LBFGS;
        /** Scale of the Gaussian MH proposal; also passed to {@code HMCMC.setProposalStdDev}. */
        @Option
        public double proposalStdDev = 0.05;
        /** Use the L1 penalty instead of the quadratic one. */
        @Option
        public boolean useL1 = false;
        @Option(gloss="Default regularization factor")
        public double sigma = 1.0;
        @Option(gloss="Max number of learning iterations")
        public int iterations = 100;
        @Option(gloss="Under this delta, stop LBFGS optimization")
        public double tolerance = 1.0E-8;
        /** Named starting weights, converted to the initial weight array by training. */
        public Counter<F> initialWeights = new Counter<F>();

        /**
         * Returns a shallow copy of {@code model} whose {@code initialWeights}
         * is replaced by {@code initWeights}.
         *
         * @param model       options to copy
         * @param initWeights initial weights for the copy
         * @return the modified copy
         * @throws RuntimeException wrapping any exception raised while cloning
         */
        public static <F> MaxentOptions<F> cloneWithWeights(MaxentOptions<F> model, Counter<F> initWeights) {
            try {
                @SuppressWarnings("unchecked") // clone() of a MaxentOptions<F> is a MaxentOptions<F>
                final MaxentOptions<F> copy = (MaxentOptions<F>) model.clone();
                copy.initialWeights = initWeights;
                return copy;
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Field-by-field copy via {@code Object.clone()}.
         *
         * @return the shallow copy
         * @throws CloneNotSupportedException as declared by {@code Object.clone()}
         */
        @Override
        public Object clone() throws CloneNotSupportedException {
            final Object shallowCopy = super.clone();
            return shallowCopy;
        }
    }

    /**
     * Training algorithms selectable through {@code MaxentOptions.learningAlgo}.
     * Each constant carries a flag saying whether it is a sampling method.
     */
    public enum LearningAlgorithm {
        OW(false),
        LBFGS(false),
        MH(true),
        HMC(true);

        /** True for the sampling-based algorithms ({@code MH}, {@code HMC}). */
        public final boolean isSampling;

        LearningAlgorithm(boolean sampling) {
            this.isSampling = sampling;
        }
    }
}

