/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import JSci.maths.statistics.NormalDistribution;
import Jama.CholeskyDecomposition;
import Jama.Matrix;
import fig.prob.Distrib;
import fig.prob.Gaussian;
import fig.prob.MultGaussianSuffStats;
import fig.prob.SampleUtils;
import fig.prob.SuffStats;
import java.util.Arrays;
import java.util.Random;

/**
 * A multivariate Gaussian distribution N(mean, covar) over {@code double[]} vectors,
 * backed by Jama matrices.
 *
 * <p>The mean is stored as a d x 1 column vector and the covariance as a d x d matrix.
 * Sampling uses a lazily computed (and cached) Cholesky factor of the covariance.
 * {@link #logProb(SuffStats)} is only implemented for diagonal covariances.
 */
public class MultGaussian implements Distrib<double[]> {
    /** Mean as a d x 1 column vector (copied from the constructor argument). */
    private final Matrix mean;
    /** Covariance as a d x d matrix (copied from the constructor argument). */
    private final Matrix covar;
    /** Lazily computed Cholesky factorization of {@link #covar}; see {@link #getChol()}. */
    private CholeskyDecomposition chol = null;
    /** Most recently requested standard normal; reused while the requested dimension matches. */
    private static MultGaussian stdNormal = null;

    /**
     * Creates N(mean, covar).
     *
     * @param mean  mean vector of length d; copied
     * @param covar d x d covariance matrix; copied, so later changes to the caller's array
     *              do not affect this distribution (or its cached Cholesky factor)
     * @throws IllegalArgumentException if {@code covar} is not d x d
     */
    public MultGaussian(double[] mean, double[][] covar) {
        int d = mean.length;
        if (covar.length != d) {
            throw new IllegalArgumentException(
                "covariance has " + covar.length + " rows but mean has length " + d);
        }
        for (int i = 0; i < d; ++i) {
            if (covar[i].length != d) {
                throw new IllegalArgumentException(
                    "covariance row " + i + " has length " + covar[i].length + ", expected " + d);
            }
        }
        this.mean = new Matrix(mean, d);
        // Jama's Matrix(double[][]) keeps a reference to the caller's array; copy it instead.
        this.covar = Matrix.constructWithCopy(covar);
    }

    /**
     * Log-likelihood of the data summarized by {@code _stats}.
     *
     * <p>With a diagonal covariance the coordinates are independent, so the result is the
     * sum over coordinates of univariate Gaussian log-likelihoods. Each term uses mean[i],
     * variance covar[i][i], and the per-coordinate statistics {@code getSum(i)},
     * {@code getOuterProduct(i, i)} and {@code numPoints()}. Presumably these are the sum and
     * sum of squares of coordinate i; confirm against {@code MultGaussianSuffStats}.
     *
     * @param _stats must be a {@link MultGaussianSuffStats}
     * @return the log-likelihood
     * @throws ClassCastException            if {@code _stats} is not a MultGaussianSuffStats
     * @throws UnsupportedOperationException if the covariance has a non-zero off-diagonal entry
     */
    @Override
    public double logProb(SuffStats _stats) {
        MultGaussianSuffStats stats = (MultGaussianSuffStats) _stats;
        if (!MultGaussian.isDiag(this.covar)) {
            throw new UnsupportedOperationException(
                "logProb is only implemented for diagonal covariance matrices");
        }
        double sum = 0.0;
        int d = this.dim();
        for (int i = 0; i < d; ++i) {
            sum += Gaussian.logProb(this.mean.get(i, 0), this.covar.get(i, i),
                stats.getSum(i), stats.getOuterProduct(i, i), stats.numPoints());
        }
        return sum;
    }

    /**
     * Returns true iff every off-diagonal entry of {@code covar2} is exactly 0.0.
     *
     * @param covar2 matrix to inspect (not required to be square)
     * @return whether the matrix is diagonal
     */
    public static boolean isDiag(Matrix covar2) {
        for (int i = 0; i < covar2.getRowDimension(); ++i) {
            for (int j = 0; j < covar2.getColumnDimension(); ++j) {
                if (i != j && covar2.get(i, j) != 0.0) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Log-density of the single point {@code x}, computed via its sufficient statistics.
     *
     * @param x point of dimension {@link #dim()}
     * @return log p(x); same restrictions as {@link #logProb(SuffStats)}
     */
    @Override
    public double logProbObject(double[] x) {
        return this.logProb(new MultGaussianSuffStats(x));
    }

    /**
     * Returns the Cholesky factorization of the covariance, computing it on first use.
     *
     * <p>NOTE: Jama does not throw for non-symmetric-positive-definite input; it sets
     * {@code isSPD()} to false, and {@code getL()} may then be inaccurate or contain NaN.
     */
    private CholeskyDecomposition getChol() {
        if (this.chol == null) {
            this.chol = this.covar.chol();
        }
        return this.chol;
    }

    /**
     * Draws one sample as mean + L z, where L is the lower Cholesky factor of the covariance
     * and z is a vector of independent draws from {@code SampleUtils.sampleGaussian}.
     *
     * @param random source of randomness
     * @return a new array of length {@link #dim()}
     */
    public double[] sample(Random random) {
        Matrix lower = this.getChol().getL();
        int d = this.dim();
        Matrix z = new Matrix(d, 1);
        for (int i = 0; i < d; ++i) {
            z.set(i, 0, SampleUtils.sampleGaussian(random));
        }
        Matrix result = lower.times(z);
        result.plusEquals(this.mean);
        return result.getColumnPackedCopy();
    }

    /** Same as {@link #sample(Random)}. */
    @Override
    public double[] sampleObject(Random random) {
        return this.sample(random);
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double crossEntropy(Distrib<double[]> _that) {
        throw new UnsupportedOperationException("unsupported");
    }

    /**
     * Ad-hoc demo. It compares the diagonal log-likelihood against univariate computations,
     * then prints samples from a correlated 2-D Gaussian.
     */
    public static void main(String[] args) {
        double x1 = 1.2;
        double y1 = 4.5;
        MultGaussianSuffStats s1 = new MultGaussianSuffStats(new double[]{x1, y1});
        MultGaussian mg = MultGaussian.getStdNormal(2);
        System.out.println(mg.logProb(s1));
        System.out.println("---");
        System.out.println(Gaussian.logProb(0.0, 1.0, x1) + Gaussian.logProb(0.0, 1.0, y1));
        NormalDistribution nd = new NormalDistribution();
        System.out.println(Math.log(nd.probability(x1)) + Math.log(nd.probability(y1)));
        double[] mean = new double[]{1.0, 2.0};
        double[][] covar = new double[2][2];
        covar[0][0] = 1.0;
        covar[1][1] = 4.0;
        covar[0][1] = 1.0;
        covar[1][0] = 1.0;
        MultGaussian g = new MultGaussian(mean, covar);
        Random random = new Random();
        for (int i = 0; i < 10000; ++i) {
            System.out.println(Arrays.toString(g.sample(random)));
        }
    }

    /**
     * Sum of element-wise products of two equally sized matrices (the Frobenius inner
     * product). Dimension equality is only checked when assertions are enabled.
     *
     * @param m1 first matrix
     * @param m2 second matrix, same shape as {@code m1}
     * @return sum over i, j of m1[i][j] * m2[i][j]
     */
    public static double aggregatePtwiseProduct(Matrix m1, Matrix m2) {
        assert (m1.getRowDimension() == m2.getRowDimension());
        assert (m1.getColumnDimension() == m2.getColumnDimension());
        double sum = 0.0;
        for (int i = 0; i < m1.getRowDimension(); ++i) {
            for (int j = 0; j < m1.getColumnDimension(); ++j) {
                sum += m1.get(i, j) * m2.get(i, j);
            }
        }
        return sum;
    }

    /** @return the dimension d (number of rows of the covariance) */
    public int dim() {
        return this.covar.getRowDimension();
    }

    /**
     * Returns N(0, I) in {@code n} dimensions.
     *
     * <p>The instance is cached and reused while successive calls ask for the same
     * dimension. Because it is shared, callers must not mutate the arrays or matrix returned
     * by its {@link #getCovar()} / {@link #getCovarMatrix()}. The cache is not synchronized.
     *
     * @param n dimension
     * @return a (possibly shared) standard normal distribution
     */
    public static MultGaussian getStdNormal(int n) {
        if (stdNormal == null || stdNormal.dim() != n) {
            stdNormal = new MultGaussian(MultGaussian.getZeroVector(n), MultGaussian.getIdentityMtx(n));
        }
        return stdNormal;
    }

    /**
     * @param n length
     * @return a new all-zero array of length {@code n}. A fresh array is returned on every
     *         call, so callers may modify it freely.
     */
    public static double[] getZeroVector(int n) {
        // Java zero-initializes numeric arrays.
        return new double[n];
    }

    /**
     * @param n size
     * @return a new n x n identity matrix. A fresh array is returned on every call, so
     *         callers may modify it freely.
     */
    public static double[][] getIdentityMtx(int n) {
        double[][] identity = new double[n][n];
        for (int i = 0; i < n; ++i) {
            identity[i][i] = 1.0;
        }
        return identity;
    }

    /**
     * @return a copy of the mean vector (length {@link #dim()})
     */
    public double[] getMean() {
        // mean is a d x 1 column vector; getArray()[0] would be only its first row.
        return this.mean.getColumnPackedCopy();
    }

    /**
     * @return the internal covariance array (a live view, not a copy). Mutating it changes
     *         this distribution but does NOT refresh the cached Cholesky factor used by
     *         {@link #sample(Random)}.
     */
    public double[][] getCovar() {
        return this.covar.getArray();
    }

    /**
     * @return the internal covariance matrix (a live view, not a copy); the same mutation
     *         caveat as {@link #getCovar()} applies
     */
    public Matrix getCovarMatrix() {
        return this.covar;
    }

    @Override
    public String toString() {
        return "N(mean = " + Arrays.toString(this.mean.getColumnPackedCopy())
            + ", covar = " + Arrays.deepToString(this.covar.getArray()) + ")";
    }
}

