/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions;

import dr.math.MathUtils;
import dr.math.distributions.MultivariateDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;

/**
 * Multivariate normal distribution parameterised by a mean vector and a precision
 * (inverse covariance) matrix.
 *
 * <p>The covariance matrix, its Cholesky factor and the log-determinant of the precision
 * matrix are computed lazily on first request and cached in instance fields. The caches are
 * filled without any synchronization. The {@code mean} and {@code precision} arrays given to
 * the constructor are stored by reference, not copied. Mutating them after a cache has been
 * filled therefore leaves that cache stale.
 *
 * <p>Most of the numerical work is also exposed as static helpers (log density, random
 * draws from a precision, covariance or Cholesky factor), which can be used without
 * constructing an instance.
 */
public class MultivariateNormalDistribution implements MultivariateDistribution {

    /** Type identifier returned by {@link #getType()}. */
    public static final String TYPE = "MultivariateNormal";

    /** Per-dimension log normalising constant, {@code -0.5 * log(2 * pi)}. */
    public static final double logNormalize = -0.5 * Math.log(Math.PI * 2);

    private final double[] mean;
    private final double[][] precision;

    // Lazily computed caches; null until first requested.
    private double[][] variance = null;
    private double[][] cholesky = null;
    private Double logDet = null;

    /**
     * Creates a distribution with the given mean and precision matrix.
     *
     * @param mean      mean vector (stored by reference, not copied)
     * @param precision precision matrix, i.e. the inverse of the covariance matrix
     *                  (stored by reference, not copied)
     */
    public MultivariateNormalDistribution(double[] mean, double[][] precision) {
        this.mean = mean;
        this.precision = precision;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Returns the covariance matrix. On the first call it is computed by inverting the
     * precision matrix as a {@code SymmetricMatrix`, and it is cached after that.
     *
     * @return the cached covariance matrix (the internal array itself, not a copy)
     */
    public double[][] getVariance() {
        if (variance == null) {
            variance = getInverse(precision);
        }
        return variance;
    }

    /**
     * Returns the Cholesky factor of the covariance matrix, as produced by
     * {@code CholeskyDecomposition.getL()}. It is computed on first call and cached.
     * The samplers in this class read only the entries with column index &le; row index.
     *
     * @return the cached Cholesky factor (the internal array itself, not a copy)
     * @throws RuntimeException if the decomposition reports an illegal dimension
     */
    public double[][] getCholeskyDecomposition() {
        if (cholesky == null) {
            cholesky = getCholeskyDecomposition(getVariance());
        }
        return cholesky;
    }

    /**
     * Returns the natural log of the determinant of the precision matrix, computed on first
     * call and cached. The result is {@code -Infinity} for a zero determinant and
     * {@code NaN} for a negative one, following {@link Math#log(double)}.
     *
     * @return {@code log(det(precision))}
     */
    public double getLogDet() {
        if (logDet == null) {
            logDet = Math.log(calculatePrecisionMatrixDeterminate(precision));
        }
        return logDet;
    }

    /**
     * Returns the precision matrix this distribution was constructed with.
     *
     * @return the precision matrix (the array passed to the constructor, not a copy)
     */
    @Override
    public double[][] getScaleMatrix() {
        return precision;
    }

    /**
     * Returns the mean vector this distribution was constructed with.
     *
     * @return the mean vector (the array passed to the constructor, not a copy)
     */
    @Override
    public double[] getMean() {
        return mean;
    }

    /**
     * Draws one sample from this distribution.
     *
     * @return a newly allocated sample vector
     */
    public double[] nextMultivariateNormal() {
        return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), 1.0);
    }

    /**
     * Draws one sample centred at {@code x}, using this distribution's covariance.
     *
     * @param x centre of the draw, used in place of this distribution's mean
     * @return a newly allocated sample vector
     */
    public double[] nextMultivariateNormal(double[] x) {
        return nextMultivariateNormalCholesky(x, getCholeskyDecomposition(), 1.0);
    }

    /**
     * Draws one sample centred at {@code mean}, with this distribution's covariance
     * multiplied by {@code scale}.
     *
     * @param mean  centre of the draw
     * @param scale multiplier applied to the covariance (its square root scales the noise)
     * @return a newly allocated sample vector
     */
    public double[] nextScaledMultivariateNormal(double[] mean, double scale) {
        return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), Math.sqrt(scale));
    }

    /**
     * Draws one sample centred at {@code mean}, with this distribution's covariance
     * multiplied by {@code scale}, and writes it into {@code result}.
     *
     * @param mean   centre of the draw
     * @param scale  multiplier applied to the covariance (its square root scales the noise)
     * @param result output array; its first {@code mean.length} entries are overwritten
     */
    public void nextScaledMultivariateNormal(double[] mean, double scale, double[] result) {
        nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), Math.sqrt(scale), result);
    }

    /**
     * Computes the determinant of a precision matrix.
     *
     * @param precision the matrix whose determinant is required
     * @return {@code det(precision)}
     * @throws RuntimeException wrapping the {@code IllegalDimension} raised by the
     *                          determinant computation (the original is kept as the cause)
     */
    public static double calculatePrecisionMatrixDeterminate(double[][] precision) {
        try {
            return new Matrix(precision).determinant();
        } catch (IllegalDimension e) {
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Log density of this distribution at {@code x}, using the cached log-determinant.
     *
     * @param x point at which to evaluate the density
     * @return the log density
     */
    @Override
    public double logPdf(double[] x) {
        return logPdf(x, mean, precision, getLogDet(), 1.0);
    }

    /**
     * Log density at {@code x} of a normal with the given mean and covariance
     * {@code scale * precision^-1}.
     *
     * <p>The dimension is taken from {@code x.length}. {@code mean} and {@code precision}
     * must be at least that large.
     *
     * @param x         point at which to evaluate the density
     * @param mean      mean vector
     * @param precision precision matrix (before scaling)
     * @param logDet    natural log of {@code det(precision)}; if it is {@code -Infinity}
     *                  (singular precision) that value is returned immediately
     * @param scale     multiplier applied to the covariance
     * @return the log density
     */
    public static double logPdf(double[] x, double[] mean, double[][] precision, double logDet, double scale) {
        if (logDet == Double.NEGATIVE_INFINITY) {
            return logDet;
        }
        final int dim = x.length;
        final double[] delta = new double[dim];
        for (int i = 0; i < dim; i++) {
            delta[i] = x[i] - mean[i];
        }
        // Quadratic form delta^T * precision * delta, accumulated column by column
        // (same summation order as the original implementation).
        double sse = 0.0;
        for (int i = 0; i < dim; i++) {
            double column = 0.0;
            for (int j = 0; j < dim; j++) {
                column += delta[j] * precision[j][i];
            }
            sse += column * delta[i];
        }
        // log det(precision / scale) = logDet - dim * log(scale).
        return (double) dim * logNormalize + 0.5 * (logDet - (double) dim * Math.log(scale) - sse / scale);
    }

    /**
     * Log density at {@code x} of an isotropic normal whose precision is
     * {@code precision * I} and whose covariance is further multiplied by {@code scale}.
     *
     * @param x         point at which to evaluate the density
     * @param mean      mean vector (at least {@code x.length} long)
     * @param precision scalar precision shared by every dimension
     * @param scale     multiplier applied to the covariance
     * @return the log density
     */
    public static double logPdf(double[] x, double[] mean, double precision, double scale) {
        final int dim = x.length;
        double sse = 0.0;
        for (int i = 0; i < dim; i++) {
            final double delta = x[i] - mean[i];
            sse += delta * delta;
        }
        return (double) dim * logNormalize + 0.5 * ((double) dim * (Math.log(precision) - Math.log(scale)) - sse * precision / scale);
    }

    /** Inverts {@code x}, treating it as a symmetric matrix. */
    private static double[][] getInverse(double[][] x) {
        return new SymmetricMatrix(x).inverse().toComponents();
    }

    /**
     * Returns the Cholesky factor {@code L} of {@code variance}.
     *
     * @throws RuntimeException wrapping the {@code IllegalDimension} raised by the
     *                          decomposition (the original is kept as the cause)
     */
    private static double[][] getCholeskyDecomposition(double[][] variance) {
        try {
            return new CholeskyDecomposition(variance).getL();
        } catch (IllegalDimension illegalDimension) {
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new RuntimeException("Attempted Cholesky decomposition on non-square matrix", illegalDimension);
        }
    }

    /**
     * Draws one sample given a precision matrix. The matrix is inverted and factorised on
     * every call, so prefer the Cholesky variants for repeated draws.
     *
     * @param mean      mean vector
     * @param precision precision matrix
     * @return a newly allocated sample vector
     */
    public static double[] nextMultivariateNormalPrecision(double[] mean, double[][] precision) {
        return nextMultivariateNormalVariance(mean, getInverse(precision));
    }

    /**
     * Draws one sample given a covariance matrix. The matrix is factorised on every call.
     *
     * @param mean     mean vector
     * @param variance covariance matrix
     * @return a newly allocated sample vector
     */
    public static double[] nextMultivariateNormalVariance(double[] mean, double[][] variance) {
        return nextMultivariateNormalVariance(mean, variance, 1.0);
    }

    /**
     * Draws one sample given a covariance matrix multiplied by {@code scale}.
     *
     * @param mean     mean vector
     * @param variance covariance matrix (before scaling)
     * @param scale    multiplier applied to the covariance
     * @return a newly allocated sample vector
     */
    public static double[] nextMultivariateNormalVariance(double[] mean, double[][] variance, double scale) {
        return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(variance), Math.sqrt(scale));
    }

    /**
     * Draws one sample given the Cholesky factor of the covariance.
     *
     * @param mean     mean vector
     * @param cholesky Cholesky factor of the covariance (entries with column &le; row are read)
     * @return a newly allocated sample vector
     */
    public static double[] nextMultivariateNormalCholesky(double[] mean, double[][] cholesky) {
        return nextMultivariateNormalCholesky(mean, cholesky, 1.0);
    }

    /**
     * Draws one sample given the Cholesky factor of the covariance and a noise multiplier.
     *
     * @param mean      mean vector
     * @param cholesky  Cholesky factor of the covariance (entries with column &le; row are read)
     * @param sqrtScale multiplier applied to each underlying Gaussian variate
     * @return a newly allocated sample vector
     */
    public static double[] nextMultivariateNormalCholesky(double[] mean, double[][] cholesky, double sqrtScale) {
        final double[] result = new double[mean.length];
        nextMultivariateNormalCholesky(mean, cholesky, sqrtScale, result);
        return result;
    }

    /**
     * Writes {@code mean + L * (sqrtScale * z)} into {@code result}.
     *
     * <p>Here {@code L} is {@code cholesky} restricted to its lower triangle, and {@code z}
     * holds {@code mean.length} draws from {@code MathUtils.nextGaussian()} (presumably
     * standard normal; confirm in {@code MathUtils}).
     *
     * @param mean      mean vector; its length fixes the dimension
     * @param cholesky  Cholesky factor of the covariance (entries with column &le; row are read)
     * @param sqrtScale multiplier applied to each Gaussian variate
     * @param result    output array; its first {@code mean.length} entries are overwritten
     */
    public static void nextMultivariateNormalCholesky(double[] mean, double[][] cholesky, double sqrtScale, double[] result) {
        final int dim = mean.length;
        System.arraycopy(mean, 0, result, 0, dim);
        // Draw all variates first, so RNG consumption does not depend on the factor's sparsity.
        final double[] epsilon = new double[dim];
        for (int i = 0; i < dim; i++) {
            epsilon[i] = MathUtils.nextGaussian() * sqrtScale;
        }
        for (int i = 0; i < dim; i++) {
            for (int j = 0; j <= i; j++) {
                result[i] += cholesky[i][j] * epsilon[j];
            }
        }
    }

    /** Runs the manual sanity checks {@link #testPdf()} and {@link #testRandomDraws()}. */
    public static void main(String[] args) {
        testPdf();
        testRandomDraws();
    }

    /** Prints two log densities next to their hand-computed reference values. */
    public static void testPdf() {
        final double[] start = new double[]{1.0, 2.0};
        final double[] stop = new double[]{0.0, 0.0};
        final double[][] precision = new double[][]{{2.0, 0.5}, {0.5, 1.0}};
        final double scale = 0.2;
        System.err.println("logPDF = " + logPdf(start, stop, precision, Math.log(calculatePrecisionMatrixDeterminate(precision)), scale));
        System.err.println("Should = -19.94863\n");
        System.err.println("logPDF = " + logPdf(start, stop, 2.0, 0.2));
        System.err.println("Should = -24.53529\n");
    }

    /**
     * Draws many samples via {@link #nextMultivariateNormalPrecision} and prints their
     * empirical mean, variances and covariance next to the true values.
     */
    public static void testRandomDraws() {
        final double[] start = new double[]{1.0, 2.0};
        final double[][] precision = new double[][]{{2.0, 0.5}, {0.5, 1.0}};
        final int length = 100000;
        System.err.println("Random draws (via precision) ...");
        final double[] mean = new double[2];
        final double[] sumSquares = new double[2];
        final double[] var = new double[2];
        double crossSum = 0.0;
        for (int i = 0; i < length; i++) {
            final double[] draw = nextMultivariateNormalPrecision(start, precision);
            for (int j = 0; j < 2; j++) {
                mean[j] += draw[j];
                sumSquares[j] += draw[j] * draw[j];
            }
            crossSum += draw[0] * draw[1];
        }
        for (int j = 0; j < 2; j++) {
            mean[j] /= (double) length;
            sumSquares[j] /= (double) length;
            var[j] = sumSquares[j] - mean[j] * mean[j];
        }
        crossSum /= (double) length;
        final double covariance = crossSum - mean[0] * mean[1];
        System.err.println("Mean: " + new Vector(mean));
        System.err.println("TRUE: [ 1 2 ]\n");
        System.err.println("MVar: " + new Vector(var));
        System.err.println("TRUE: [ 0.571 1.14 ]\n");
        System.err.println("Covv: " + covariance);
        System.err.println("TRUE: -0.286");
    }
}

