/*
 * Decompiled with CFR 0.152.
 */
package nuts.math;

import Jama.Matrix;
import java.util.Objects;
import java.util.Random;
import nuts.math.MtxUtils;
import nuts.util.MathUtils;

/**
 * Markov chain Monte Carlo sampler for a target density proportional to {@code exp(-E(x))}.
 *
 * <p>{@link #next(Random)} performs a Hybrid/Hamiltonian Monte Carlo transition (leapfrog
 * integration followed by a Metropolis accept/reject); {@link #nextMH(Random, double)} performs a
 * plain random-walk Metropolis–Hastings transition. The two kinds of step may be freely interleaved
 * on the same chain. Instances hold mutable state without synchronization and are not thread-safe.
 */
public final class HMCMC {
    /** Energy function; the target density is proportional to {@code exp(-e.valueAt(x))}. */
    private final Energy e;
    /** Current position. Never mutated in place: it is replaced wholesale on acceptance. */
    private double[] x;
    /** Gradient of {@link #e} at {@link #x}, or {@code null} when stale (recomputed lazily). */
    private double[] g;
    /** Energy at {@link #x}. */
    private double eValue;
    /** Number of leapfrog steps per HMC transition. */
    private int tau = 4;
    /** Leapfrog step size. */
    private double epsilon = 1.0;
    /** HMC momentum standard deviation (mass = its square); also scales MH proposals. */
    private double proposalStdDev = 1.0;
    /** Number of accepted proposals so far. */
    private long accepted = 0L;
    /** Number of proposals (HMC and MH) made so far. */
    private long n = 0L;

    /**
     * Returns a copy of the current position; modifying it does not affect the chain.
     *
     * @return the current state of the chain
     */
    public double[] getPosition() {
        return this.x.clone();
    }

    /**
     * Returns the energy at the current position.
     *
     * @return {@code E(x)} for the current state
     */
    public double getValue() {
        return this.eValue;
    }

    /**
     * Creates a sampler started at {@code init}.
     *
     * @param energyFct energy function defining the target density {@code exp(-E(x))}
     * @param init starting position; copied, so later changes by the caller do not affect the chain
     * @throws NullPointerException if either argument is {@code null}
     */
    public HMCMC(Energy energyFct, double[] init) {
        this.e = Objects.requireNonNull(energyFct, "energyFct");
        this.init(Objects.requireNonNull(init, "init"));
    }

    /**
     * Returns the fraction of proposals accepted so far, counting both {@link #next} and
     * {@link #nextMH} steps.
     *
     * @return accepted / proposed, or {@code NaN} before the first step
     */
    public double acceptRatio() {
        return (double) this.accepted / this.n;
    }

    /** Sets the chain state to a copy of {@code start} and evaluates energy and gradient there. */
    private void init(double[] start) {
        this.x = start.clone();
        this.g = this.e.gradientAt(this.x);
        this.eValue = this.e.valueAt(this.x);
    }

    /**
     * Performs one HMC transition.
     *
     * <p>Draws momentum {@code p ~ N(0, s^2 I)} with {@code s = proposalStdDev}, runs {@code tau}
     * leapfrog steps of size {@code epsilon} under {@code H(x, p) = E(x) + p.p / (2 s^2)}, and
     * accepts the end point with probability {@code min(1, exp(-dH))}. The mass {@code s^2} is used
     * in both the kinetic energy and the position update so that the Hamiltonian matches the
     * momentum distribution; with {@code s = 1} this is standard unit-mass HMC.
     *
     * @param rand source of randomness
     */
    public void next(Random rand) {
        if (this.g == null) {
            // Invalidated by an accepted nextMH step: refresh at the current position.
            this.g = this.e.gradientAt(this.x);
        }
        final double mass = this.proposalStdDev * this.proposalStdDev;
        final double halfStep = -this.epsilon / 2.0;
        final double positionStep = this.epsilon / mass;

        double[] p = this.sampleIsotropicNormal(rand, this.e.dim());
        double h = kineticEnergy(p, mass) + this.eValue;
        double[] xNew = this.x.clone();
        // this.g is never mutated in place, so it can seed the trajectory without a copy.
        double[] gNew = this.g;
        for (int i = 0; i < this.tau; ++i) {
            axpy(p, halfStep, gNew);
            axpy(xNew, positionStep, p);
            gNew = this.e.gradientAt(xNew);
            axpy(p, halfStep, gNew);
        }
        double eNew = this.e.valueAt(xNew);
        double hNew = kineticEnergy(p, mass) + eNew;
        double dH = hNew - h;

        this.n++;
        // Short-circuit: a uniform draw is consumed only when dH >= 0 (or NaN, which rejects).
        if (dH < 0.0 || rand.nextDouble() < Math.exp(-dH)) {
            this.accepted++;
            this.g = gNew;
            this.x = xNew;
            this.eValue = eNew;
        }
    }

    /** Draws a {@code dim}-vector of independent {@code N(0, proposalStdDev^2)} samples. */
    private double[] sampleIsotropicNormal(Random rand, int dim) {
        double[] result = new double[dim];
        for (int i = 0; i < dim; ++i) {
            result[i] = rand.nextGaussian() * this.proposalStdDev;
        }
        return result;
    }

    /**
     * Performs one random-walk Metropolis–Hastings transition with the symmetric Gaussian proposal
     * {@code x + N(0, proposalStdDev^2 * pVariance * I)}, accepted with probability
     * {@code min(1, exp(-dE))}.
     *
     * @param rand source of randomness
     * @param pVariance multiplier on the proposal variance; must be non-negative
     * @throws IllegalArgumentException if {@code pVariance} is negative or NaN
     */
    public void nextMH(Random rand, double pVariance) {
        if (!(pVariance >= 0.0)) {
            throw new IllegalArgumentException("pVariance must be non-negative, got " + pVariance);
        }
        double scale = Math.sqrt(pVariance);
        double[] step = this.sampleIsotropicNormal(rand, this.e.dim());
        double[] xNew = new double[step.length];
        for (int i = 0; i < step.length; ++i) {
            xNew[i] = this.x[i] + step[i] * scale;
        }
        double eNew = this.e.valueAt(xNew);
        double dE = eNew - this.eValue;

        this.n++;
        if (dE < 0.0 || rand.nextDouble() < Math.exp(-dE)) {
            this.accepted++;
            this.x = xNew;
            this.eValue = eNew;
            // The cached gradient belongs to the old position; next() recomputes it on demand.
            this.g = null;
        }
    }

    /** Returns {@code p.p / (2 * mass)}, the kinetic energy under an isotropic mass matrix. */
    private static double kineticEnergy(double[] p, double mass) {
        double sumSq = 0.0;
        for (double v : p) {
            sumSq += v * v;
        }
        return sumSq / (2.0 * mass);
    }

    /** In-place update {@code y[i] += a * v[i]} for every index of {@code y}. */
    private static void axpy(double[] y, double a, double[] v) {
        for (int i = 0; i < y.length; ++i) {
            y[i] += a * v[i];
        }
    }

    /**
     * Demo: samples the 1-D Gaussian N(mean = 10, variance = 2) and prints the post-burn-in sample
     * mean and the acceptance rate.
     */
    public static void main(String[] args) {
        final double variance = 2.0;
        final double mean = 10.0;
        Energy ef = new Energy() {

            @Override
            public int dim() {
                return 1;
            }

            @Override
            public double[] gradientAt(double[] x) {
                return new double[] {(x[0] - mean) / variance};
            }

            @Override
            public double valueAt(double[] x) {
                if (x.length != 1) {
                    throw new IllegalArgumentException(
                            "expected a 1-dimensional point, got length " + x.length);
                }
                double d = x[0] - mean;
                return d * d / 2.0 / variance;
            }
        };
        HMCMC sampler = new HMCMC(ef, new double[1]);
        double sum = 0.0;
        long count = 0L;
        Random rand = new Random(1L);
        for (int i = 0; i < 1000000; ++i) {
            sampler.next(rand);
            if (i <= 100) {
                continue; // burn-in
            }
            sum += sampler.x[0];
            count++;
        }
        System.out.println(sum / count);
        System.out.println(sampler.acceptRatio());
    }

    /** @return number of leapfrog steps per HMC transition */
    public int getTau() {
        return this.tau;
    }

    /** @param tau number of leapfrog steps per HMC transition */
    public void setTau(int tau) {
        this.tau = tau;
    }

    /** @return leapfrog step size */
    public double getEpsilon() {
        return this.epsilon;
    }

    /** @param epsilon leapfrog step size */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /** @return momentum standard deviation for HMC, also the scale of MH proposals */
    public double getProposalStdDev() {
        return this.proposalStdDev;
    }

    /**
     * Sets the HMC momentum standard deviation (the mass is its square), which also scales MH
     * proposals.
     *
     * @param proposalStdDev a positive, finite standard deviation
     * @throws IllegalArgumentException if {@code proposalStdDev} is not positive and finite
     */
    public void setProposalStdDev(double proposalStdDev) {
        if (!(proposalStdDev > 0.0) || Double.isInfinite(proposalStdDev)) {
            throw new IllegalArgumentException(
                    "proposalStdDev must be positive and finite, got " + proposalStdDev);
        }
        this.proposalStdDev = proposalStdDev;
    }

    /** Energy function over a {@code dim()}-dimensional space; the target density is exp(-energy). */
    public interface Energy {
        /** @return dimension of the state space */
        int dim();

        /** @return energy at {@code x} */
        double valueAt(double[] x);

        /** @return gradient of the energy at {@code x} */
        double[] gradientAt(double[] x);
    }
}

