/*
 * Decompiled with CFR 0.152.
 */
package nuts.maxent;

import fig.basic.LogInfo;
import java.util.LinkedList;
import nuts.maxent.BacktrackingLineSearcher;
import nuts.maxent.DifferentiableFunction;
import nuts.maxent.DoubleArrays;
import nuts.maxent.GradientMinimizer;
import nuts.maxent.SloppyMath;

/**
 * Limited-memory BFGS minimizer.
 *
 * <p>Each iteration computes a search direction by applying an implicit inverse-Hessian
 * approximation (built from the most recent {@code maxHistorySize} input/gradient difference
 * pairs) to the current gradient. It then hands that direction to a
 * {@link BacktrackingLineSearcher}. Iteration stops when the relative change in function value
 * drops below the caller's tolerance, or after {@code maxIterations} iterations.
 *
 * <p>Instances keep optimization state in mutable fields with no synchronization, so a single
 * instance must not run concurrent {@link #minimize} calls.
 */
public class LBFGSMinimizer
implements GradientMinimizer {
    /** Small constant added inside the value average of the convergence test to avoid dividing by zero. */
    double EPS = 1.0E-10;
    /** Maximum number of L-BFGS iterations performed by {@link #minimize}. */
    int maxIterations = 20;
    /** Maximum number of (input difference, gradient difference) pairs retained. */
    int maxHistorySize = 5;
    /** Recent input differences {@code s = x_{k+1} - x_k}; index 0 is the most recent. */
    LinkedList<double[]> inputDifferenceVectorList = new LinkedList<double[]>();
    /** Recent gradient differences {@code y = g_{k+1} - g_k}; index 0 is the most recent. */
    LinkedList<double[]> derivativeDifferenceVectorList = new LinkedList<double[]>();
    /** Number of iterations performed by the most recent {@link #minimize} call. */
    private int nRequiredIterations;

    /**
     * Minimizes {@code function} starting from {@code initial}.
     *
     * @param function the objective; queried for values and gradients
     * @param initial the starting point (not modified; it is cloned)
     * @param tolerance relative-change threshold used by the convergence test
     * @return the point at which the relative value change fell below {@code tolerance}, or the
     *     last accepted point if {@code maxIterations} was reached first
     * @throws RuntimeException if a stored pair has zero curvature ({@code s . y == 0})
     */
    @Override
    public double[] minimize(DifferentiableFunction function, double[] initial, double tolerance) {
        // Correction pairs left over from an earlier call describe a different run (possibly of a
        // different dimension) and must not leak into this one.
        this.clearHistories();
        BacktrackingLineSearcher lineSearcher = new BacktrackingLineSearcher();
        double[] guess = DoubleArrays.clone(initial);
        int iteration = 0;
        LogInfo.track((Object)"LBFGS training:", true);
        try {
            // Evaluate the starting point once; later points are carried forward from the
            // previous iteration rather than re-evaluated.
            double value = function.valueAt(guess);
            double[] derivative = function.derivativeAt(guess);
            for (; iteration < this.maxIterations; ++iteration) {
                double[] initialInverseHessianDiagonal = this.getInitialInverseHessianDiagonal(function);
                double[] direction = this.implicitMultiply(initialInverseHessianDiagonal, derivative);
                DoubleArrays.scale(direction, -1.0);
                // NOTE(review): the line searcher's use of this multiplier is defined in
                // BacktrackingLineSearcher; the first iteration uses a more conservative value
                // because no curvature history exists yet.
                lineSearcher.stepSizeMultiplier = iteration == 0 ? 0.01 : 0.5;
                double[] nextGuess = lineSearcher.minimize(function, guess, direction);
                double nextValue = function.valueAt(nextGuess);
                double[] nextDerivative = function.derivativeAt(nextGuess);
                LogInfo.logs("LBFGS iteration %d ended with value %.6f\n", iteration, nextValue);
                if (this.converged(value, nextValue, tolerance)) {
                    // iteration is 0-based, so iteration + 1 iterations have completed.
                    this.nRequiredIterations = iteration + 1;
                    return nextGuess;
                }
                this.updateHistories(guess, nextGuess, derivative, nextDerivative);
                guess = nextGuess;
                value = nextValue;
                derivative = nextDerivative;
            }
        } finally {
            // Always close the log track, including when the function or the curvature check throws.
            LogInfo.end_track();
        }
        this.nRequiredIterations = iteration;
        LogInfo.warning("LBFGSMinimizer.minimize: Exceeded maxIterations without converging.");
        return guess;
    }

    /**
     * Returns true when the two values are identical, or when their absolute difference divided by
     * {@code |nextValue + value + EPS| / 2} is below {@code tolerance}.
     */
    private boolean converged(double value, double nextValue, double tolerance) {
        if (value == nextValue) {
            return true;
        }
        double valueChange = SloppyMath.abs(nextValue - value);
        double valueAverage = SloppyMath.abs(nextValue + value + this.EPS) / 2.0;
        return valueChange / valueAverage < tolerance;
    }

    /** Discards all stored correction pairs. */
    private void clearHistories() {
        this.inputDifferenceVectorList.clear();
        this.derivativeDifferenceVectorList.clear();
    }

    /** Records the new pair {@code s = nextGuess - guess}, {@code y = nextDerivative - derivative}. */
    private void updateHistories(double[] guess, double[] nextGuess, double[] derivative, double[] nextDerivative) {
        double[] guessChange = DoubleArrays.addMultiples(nextGuess, 1.0, guess, -1.0);
        double[] derivativeChange = DoubleArrays.addMultiples(nextDerivative, 1.0, derivative, -1.0);
        this.pushOntoList(guessChange, this.inputDifferenceVectorList);
        this.pushOntoList(derivativeChange, this.derivativeDifferenceVectorList);
    }

    /** Inserts {@code vector} as the newest entry (index 0), evicting the oldest beyond {@code maxHistorySize}. */
    private void pushOntoList(double[] vector, LinkedList<double[]> vectorList) {
        vectorList.addFirst(vector);
        if (vectorList.size() > this.maxHistorySize) {
            vectorList.removeLast();
        }
    }

    /** Number of stored correction pairs. */
    private int historySize() {
        return this.inputDifferenceVectorList.size();
    }

    /** Input difference {@code num} steps back; 0 is the most recent. */
    private double[] getInputDifference(int num) {
        return this.inputDifferenceVectorList.get(num);
    }

    /** Gradient difference {@code num} steps back; 0 is the most recent. */
    private double[] getDerivativeDifference(int num) {
        return this.derivativeDifferenceVectorList.get(num);
    }

    /** Most recently recorded gradient difference. */
    private double[] getLastDerivativeDifference() {
        return this.derivativeDifferenceVectorList.getFirst();
    }

    /** Most recently recorded input difference. */
    private double[] getLastInputDifference() {
        return this.inputDifferenceVectorList.getFirst();
    }

    /**
     * Computes {@code H * derivative} with the L-BFGS two-loop recursion. {@code H} is the
     * inverse-Hessian approximation built from the stored pairs on top of the diagonal initial
     * matrix {@code initialInverseHessianDiagonal}.
     *
     * <p>The first loop must visit pairs from newest to oldest and the second from oldest to newest.
     * This order makes the most recent pair's secant condition hold. Index 0 is the newest pair.
     *
     * @throws RuntimeException if some stored pair has {@code s . y == 0}
     */
    private double[] implicitMultiply(double[] initialInverseHessianDiagonal, double[] derivative) {
        int size = this.historySize();
        // rho[i] holds s_i . y_i (the reciprocal of the textbook rho).
        double[] rho = new double[size];
        double[] alpha = new double[size];
        double[] right = DoubleArrays.clone(derivative);
        // First loop: newest (index 0) to oldest.
        for (int i = 0; i < size; ++i) {
            double[] inputDifference = this.getInputDifference(i);
            double[] derivativeDifference = this.getDerivativeDifference(i);
            rho[i] = DoubleArrays.innerProduct(inputDifference, derivativeDifference);
            if (rho[i] == 0.0) {
                throw new RuntimeException("LBFGSMinimizer.implicitMultiply: Curvature problem.");
            }
            alpha[i] = DoubleArrays.innerProduct(inputDifference, right) / rho[i];
            right = DoubleArrays.addMultiples(right, 1.0, derivativeDifference, -1.0 * alpha[i]);
        }
        double[] left = DoubleArrays.pointwiseMultiply(initialInverseHessianDiagonal, right);
        // Second loop: oldest to newest.
        for (int i = size - 1; i >= 0; --i) {
            double[] inputDifference = this.getInputDifference(i);
            double[] derivativeDifference = this.getDerivativeDifference(i);
            double beta = DoubleArrays.innerProduct(derivativeDifference, left) / rho[i];
            left = DoubleArrays.addMultiples(left, 1.0, inputDifference, alpha[i] - beta);
        }
        return left;
    }

    /**
     * Returns a constant diagonal of length {@code function.dimension()}. The constant is 1.0 with
     * no history, and otherwise {@code (y . s) / (y . y)} for the most recent pair.
     */
    private double[] getInitialInverseHessianDiagonal(DifferentiableFunction function) {
        double scale = 1.0;
        if (this.derivativeDifferenceVectorList.size() >= 1) {
            double[] lastDerivativeDifference = this.getLastDerivativeDifference();
            double[] lastInputDifference = this.getLastInputDifference();
            double num = DoubleArrays.innerProduct(lastDerivativeDifference, lastInputDifference);
            double den = DoubleArrays.innerProduct(lastDerivativeDifference, lastDerivativeDifference);
            scale = num / den;
        }
        return DoubleArrays.constantArray(scale, function.dimension());
    }

    /** Creates a minimizer with the default iteration limit (20). */
    public LBFGSMinimizer() {
    }

    /** Creates a minimizer that performs at most {@code maxIterations} iterations. */
    public LBFGSMinimizer(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** Returns the number of iterations performed by the most recent {@link #minimize} call. */
    public int getnRequiredIterations() {
        return this.nRequiredIterations;
    }
}

