/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions;

import dr.math.GammaFunction;
import dr.math.MathUtils;
import dr.math.distributions.GammaDistribution;
import dr.math.distributions.MultivariateDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;

/**
 * Wishart distribution over square {@code dim x dim} matrices, parameterized by degrees of
 * freedom {@code df} and a scale matrix {@code S}.
 *
 * <p>The log density evaluated by this class is
 *
 * <pre>
 *   log p(W) = logC + (df - dim - 1) / 2 * log|W| - tr(S^{-1} W) / 2
 * </pre>
 *
 * where {@code logC} is produced by {@link #computeNormalizationConstant(Matrix, double, int)}.
 * A second constructor, {@link #WishartDistribution(int)}, builds an instance with no scale
 * matrix whose density reduces to the {@code log|W|} term alone.
 */
public class WishartDistribution implements MultivariateDistribution {

    /** Type identifier returned by {@link #getType()}. */
    public static final String TYPE = "Wishart";

    /** Degrees of freedom; 0 for an instance built with {@link #WishartDistribution(int)}. */
    private double df;
    /** Number of rows (and columns) of the matrices this distribution is defined over. */
    private int dim;
    /** Scale matrix S as supplied by the caller (stored by reference, not copied); may be null. */
    private double[][] scaleMatrix;
    /** S^{-1} flattened row by row (row-major); used by the 2x2 fast path; may be null. */
    private double[] Sinv;
    /** S^{-1} as a {@link Matrix}; used by the general path; may be null. */
    private Matrix SinvMat;
    /** Cached log normalization constant added to every log density. */
    private double logNormalizationConstant;

    /**
     * Creates a Wishart distribution with the given degrees of freedom and scale matrix.
     *
     * <p>The inverse of the scale matrix and the log normalization constant are computed once here.
     * The {@code scaleMatrix} array is kept by reference, so later mutation by the caller affects
     * {@link #getScaleMatrix()} and {@link #nextWishart()} but not the cached inverse/constant.
     *
     * @param df degrees of freedom
     * @param scaleMatrix square scale matrix S; its length defines {@code dim}
     */
    public WishartDistribution(double df, double[][] scaleMatrix) {
        this.df = df;
        this.scaleMatrix = scaleMatrix;
        this.dim = scaleMatrix.length;
        this.SinvMat = new Matrix(scaleMatrix).inverse();
        double[][] tmp = this.SinvMat.toComponents();
        this.Sinv = new double[this.dim * this.dim];
        for (int i = 0; i < this.dim; ++i) {
            // Row i of S^{-1} occupies Sinv[i*dim .. i*dim + dim - 1].
            System.arraycopy(tmp[i], 0, this.Sinv, i * this.dim, this.dim);
        }
        this.computeNormalizationConstant();
    }

    /**
     * Creates a scale-free instance of the given dimension (labelled "uninformative" in {@link
     * #main}). With {@code df = 0}, no inverse scale and a zero normalization constant, {@link
     * #logPdf(double[])} evaluates to {@code -(dim + 1) / 2 * log|W|} for positive-definite W.
     *
     * <p>{@link #nextWishart()} is not usable on such an instance because there is no scale matrix.
     *
     * @param dim number of rows (and columns) of the matrices evaluated
     */
    public WishartDistribution(int dim) {
        this.df = 0.0;
        this.scaleMatrix = null;
        this.dim = dim;
        this.logNormalizationConstant = 0.0;
    }

    /** Recomputes {@link #logNormalizationConstant} from the stored scale matrix, df and dim. */
    private void computeNormalizationConstant() {
        this.logNormalizationConstant =
                WishartDistribution.computeNormalizationConstant(new Matrix(this.scaleMatrix), this.df, this.dim);
    }

    /**
     * Computes the Wishart log normalization constant
     *
     * <pre>
     *   -df/2 * log|M| - df*dim/2 * log 2 - dim*(dim-1)/4 * log(pi) - sum_{i=1..dim} lnGamma((df + 1 - i)/2)
     * </pre>
     *
     * where {@code M} is the {@code Sinv} argument. Despite the parameter name, the instance
     * constructor passes the scale matrix S itself (not its inverse), which is what the standard
     * Wishart constant requires in the determinant term.
     *
     * @param Sinv matrix whose log-determinant enters as {@code -df/2 * log|Sinv|}
     * @param df degrees of freedom; when exactly 0 the method returns 0
     * @param dim matrix dimension
     * @return the log normalization constant
     */
    public static double computeNormalizationConstant(Matrix Sinv, double df, int dim) {
        if (df == 0.0) {
            return 0.0;
        }
        double logNormalizationConstant = 0.0;
        try {
            logNormalizationConstant = -df / 2.0 * Math.log(Sinv.determinant());
        } catch (IllegalDimension illegalDimension) {
            // NOTE(review): on failure the determinant term is silently left at 0 and the remaining
            // terms are still added; callers get a wrong constant rather than an error.
            illegalDimension.printStackTrace();
        }
        logNormalizationConstant -= df * (double) dim / 2.0 * Math.log(2.0);
        logNormalizationConstant -= (double) (dim * (dim - 1)) / 4.0 * Math.log(Math.PI);
        for (int i = 1; i <= dim; ++i) {
            logNormalizationConstant -= GammaFunction.lnGamma((df + 1.0 - (double) i) / 2.0);
        }
        return logNormalizationConstant;
    }

    /** @return {@link #TYPE} */
    @Override
    public String getType() {
        return TYPE;
    }

    /** @return the scale matrix passed to the constructor (not a copy), or null if none */
    @Override
    public double[][] getScaleMatrix() {
        return this.scaleMatrix;
    }

    /** Not implemented for this distribution; always returns null. */
    @Override
    public double[] getMean() {
        return null;
    }

    /**
     * Debug helper: draws 100000 samples and prints the Monte Carlo mean of the four entries of
     * the top-left 2x2 sub-matrix to standard error. Requires {@code dim >= 2} and a scale matrix.
     */
    public void testMe() {
        final int length = 100000;
        double sum00 = 0.0;
        double sum01 = 0.0;
        double sum10 = 0.0;
        double sum11 = 0.0;
        for (int i = 0; i < length; ++i) {
            double[][] draw = this.nextWishart();
            sum00 += draw[0][0];
            sum01 += draw[0][1];
            sum10 += draw[1][0];
            sum11 += draw[1][1];
        }
        System.err.println("S1: " + sum00 / (double) length);
        System.err.println("S2: " + sum01 / (double) length);
        System.err.println("S3: " + sum10 / (double) length);
        System.err.println("S4: " + sum11 / (double) length);
    }

    /** @return the degrees of freedom */
    public double df() {
        return this.df;
    }

    /** @return the scale matrix passed to the constructor (not a copy), or null if none */
    public double[][] scaleMatrix() {
        return this.scaleMatrix;
    }

    /**
     * Draws one matrix using this instance's degrees of freedom and scale matrix.
     *
     * @return a new {@code dim x dim} sample
     * @see #nextWishart(double, double[][])
     */
    public double[][] nextWishart() {
        return WishartDistribution.nextWishart(this.df, this.scaleMatrix);
    }

    /**
     * Draws one matrix via a Bartlett-style construction {@code W = (L Z)(L Z)^T}, where
     * {@code L} is the lower Cholesky factor of the (symmetrized) scale matrix and {@code Z} is
     * lower triangular with standard normal draws strictly below the diagonal and
     * {@code sqrt(nextGamma((df - i)/2, 0.5))} on diagonal entry {@code i}.
     *
     * @param df degrees of freedom
     * @param scaleMatrix square scale matrix; only its upper triangle (including diagonal) is read
     * @return a new {@code dim x dim} sample
     * @throws RuntimeException if the Cholesky decomposition reports an {@link IllegalDimension}
     */
    public static double[][] nextWishart(double df, double[][] scaleMatrix) {
        int dim = scaleMatrix.length;

        // Lower-triangular Bartlett factor. All normals are drawn first (row by row), then the
        // diagonal, so the random-number stream is consumed in a fixed order.
        double[][] z = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < i; ++j) {
                z[i][j] = MathUtils.nextGaussian();
            }
        }
        for (int i = 0; i < dim; ++i) {
            // NOTE(review): this is chi-square(df - i) if nextGamma's second argument is a rate;
            // confirm against MathUtils.nextGamma.
            z[i][i] = Math.sqrt(MathUtils.nextGamma((df - (double) i) * 0.5, 0.5));
        }

        // Mirror the upper triangle so the decomposition sees an exactly symmetric matrix.
        double[][] symmetric = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = i; j < dim; ++j) {
                double value = scaleMatrix[i][j];
                symmetric[j][i] = value;
                symmetric[i][j] = value;
            }
        }
        double[][] lower;
        try {
            lower = new CholeskyDecomposition(symmetric).getL();
        } catch (IllegalDimension illegalDimension) {
            // Keep the original message but preserve the cause for diagnosis.
            throw new RuntimeException("Numerical exception in WishartDistribution", illegalDimension);
        }

        // result = L * Z
        double[][] result = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                for (int k = 0; k < dim; ++k) {
                    result[i][j] += lower[i][k] * z[k][j];
                }
            }
        }

        // draw = result * result^T
        double[][] draw = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                for (int k = 0; k < dim; ++k) {
                    draw[i][j] += result[i][k] * result[j][k];
                }
            }
        }
        return draw;
    }

    /**
     * Evaluates the log density at a flattened matrix.
     *
     * <p>For a 2x2 distribution with a scale matrix, the closed-form {@link #logPdf2D} is used.
     * Otherwise, including the scale-free instance whose S^{-1} is null, evaluation goes through
     * {@link #logPdfSlow(double[])}, which skips the trace term when S^{-1} is absent.
     *
     * @param x the matrix as a flat array of length {@code dim * dim}
     * @return the log density, or {@code Double.NEGATIVE_INFINITY} outside the support
     */
    @Override
    public double logPdf(double[] x) {
        // The fast path dereferences Sinv, which is null for the scale-free constructor.
        if (this.dim == 2 && x.length == 4 && this.Sinv != null) {
            return WishartDistribution.logPdf2D(x, this.Sinv, this.df, this.dim, this.logNormalizationConstant);
        }
        return this.logPdfSlow(x);
    }

    /**
     * Evaluates the log density through the general {@link Matrix}-based path.
     *
     * @param x the matrix as a flat array of length {@code dim * dim}, in the layout expected by
     *     {@code new Matrix(double[], int, int)}
     * @return the log density, or {@code Double.NEGATIVE_INFINITY} outside the support
     */
    public double logPdfSlow(double[] x) {
        Matrix W = new Matrix(x, this.dim, this.dim);
        return WishartDistribution.logPdf(W, this.SinvMat, this.df, this.dim, this.logNormalizationConstant);
    }

    /**
     * Closed-form log density for 2x2 matrices.
     *
     * @param W the matrix as {@code {w00, w01, w10, w11}}; expected to be symmetric
     * @param Sinv S^{-1} flattened row by row (must be non-null, length 4)
     * @param df degrees of freedom
     * @param dim matrix dimension (2 for this path)
     * @param logNormalizationConstant constant added to the result
     * @return the log density, or {@code Double.NEGATIVE_INFINITY} if W is not positive definite
     */
    public static double logPdf2D(double[] W, double[] Sinv, double df, int dim, double logNormalizationConstant) {
        double det = W[0] * W[3] - W[1] * W[2];
        // A symmetric 2x2 matrix is positive definite iff its leading entry and its determinant
        // are both positive. Checking det alone would accept negative-definite matrices, which the
        // general path (via isPD) rejects.
        if (W[0] <= 0.0 || det <= 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        double logDensity = Math.log(det) * (0.5 * (df - (double) dim - 1.0));
        // tr(S^{-1} W) = sum_ij Sinv[i][j] * W[j][i] for row-major storage of both.
        double trace = Sinv[0] * W[0] + Sinv[1] * W[2] + Sinv[2] * W[1] + Sinv[3] * W[3];
        logDensity -= 0.5 * trace;
        return logDensity + logNormalizationConstant;
    }

    /**
     * General log density evaluation.
     *
     * @param W the matrix argument
     * @param Sinv inverse scale matrix, or null to omit the trace term
     * @param df degrees of freedom
     * @param dim matrix dimension
     * @param logNormalizationConstant constant added to the result
     * @return the log density, or {@code Double.NEGATIVE_INFINITY} if W is not positive definite
     *     or has non-positive determinant
     */
    public static double logPdf(Matrix W, Matrix Sinv, double df, int dim, double logNormalizationConstant) {
        double logDensity = 0.0;
        try {
            if (!W.isPD()) {
                return Double.NEGATIVE_INFINITY;
            }
            double det = W.determinant();
            if (det <= 0.0) {
                return Double.NEGATIVE_INFINITY;
            }
            logDensity = Math.log(det);
            logDensity *= 0.5;
            logDensity *= df - (double) dim - 1.0;
            if (Sinv != null) {
                Matrix product = Sinv.product(W);
                for (int i = 0; i < dim; ++i) {
                    logDensity -= 0.5 * product.component(i, i);
                }
            }
        } catch (IllegalDimension illegalDimension) {
            // NOTE(review): a dimension error is only printed; the partially accumulated value
            // plus the constant is still returned.
            illegalDimension.printStackTrace();
        }
        return logDensity + logNormalizationConstant;
    }

    /** Prints the fast (2x2 closed form) and slow (general) log densities for one example. */
    public static void testBivariateMethod() {
        System.out.println("Testing new computations ...");
        WishartDistribution wd = new WishartDistribution(5.0, new double[][] {{2.0, -0.5}, {-0.5, 2.0}});
        double[] W = new double[] {4.0, 1.0, 1.0, 3.0};
        System.out.println("Fast logPdf = " + wd.logPdf(W));
        System.out.println("Slow logPdf = " + wd.logPdfSlow(W));
    }

    /**
     * Prints log densities of several one-dimensional Wishart instances next to Gamma densities
     * (the labels describe the Gamma arguments as shape and scale), then runs {@link
     * #testBivariateMethod()}.
     */
    public static void main(String[] argv) {
        WishartDistribution wd = new WishartDistribution(2.0, new double[][] {{500.0}});
        GammaDistribution gd = new GammaDistribution(0.001, 1000.0);
        double[] x = new double[] {1.0};
        System.out.println("Wishart, df=2, scale = 500, PDF(1.0): " + wd.logPdf(x));
        System.out.println("Gamma, shape = 1/1000, scale = 1000, PDF(1.0): " + gd.logPdf(x[0]));
        wd = new WishartDistribution(4.0, new double[][] {{5.0}});
        gd = new GammaDistribution(2.0, 10.0);
        x = new double[] {1.0};
        System.out.println("Wishart, df=4, scale = 5, PDF(1.0): " + wd.logPdf(x));
        // Label previously said "shape = 1/1000"; the distribution above has shape 2.
        System.out.println("Gamma, shape = 2, scale = 10, PDF(1.0): " + gd.logPdf(x[0]));
        wd = new WishartDistribution(1);
        x = new double[] {0.1};
        System.out.println("Wishart, uninformative, PDF(0.1): " + wd.logPdf(x));
        x = new double[] {1.0};
        System.out.println("Wishart, uninformative, PDF(1.0): " + wd.logPdf(x));
        x = new double[] {10.0};
        System.out.println("Wishart, uninformative, PDF(10.0): " + wd.logPdf(x));
        WishartDistribution.testBivariateMethod();
    }
}

