/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import Jama.Matrix;
import fig.prob.Distrib;
import fig.prob.NormalInverseWishart;
import fig.prob.SuffStats;
import java.util.Random;
import org.apache.commons.math.special.Gamma;

/**
 * Normal-inverse-Wishart (NIW) distribution over a (mean, covariance) pair.
 *
 * <p>Parameters, as used by this class:
 * <ul>
 *   <li>{@code kappa}: scaling of the mean's precision given the covariance; must be &gt; 0.</li>
 *   <li>{@code scriptV}: location of the mean, a column vector (see {@link #getMeanMean()}).</li>
 *   <li>{@code nu}: degrees of freedom; must be &gt; d + 1 so that the expected covariance exists.</li>
 *   <li>{@code delta}: d&times;d matrix. The inverse-Wishart scale matrix is {@code nu * delta}
 *       (see {@link #logProbAt01()}), so {@link #expectedVariance()} equals
 *       {@code nu / (nu - d - 1) * delta}.</li>
 * </ul>
 *
 * <p>Only a point density and a few moments are implemented. The generic {@link Distrib}
 * operations throw {@link UnsupportedOperationException}.
 */
public class NormalInverseWishartDistrib
implements Distrib<NormalInverseWishart> {
    private final double nu;
    private final Matrix delta;
    private final Matrix scriptV;
    private final double kappa;

    /**
     * Creates an NIW distribution.
     *
     * @param kappa   mean precision scaling; must be strictly positive
     * @param scriptV location of the mean (expected to be a d&times;1 column vector)
     * @param nu      degrees of freedom; must be strictly greater than d + 1
     * @param delta   square d&times;d matrix; {@code nu * delta} is the inverse-Wishart scale
     * @throws IllegalArgumentException if {@code kappa} or {@code nu} is out of range (including NaN)
     *                                  or {@code delta} is not square
     */
    public NormalInverseWishartDistrib(double kappa, Matrix scriptV, double nu, Matrix delta) {
        // Written as !(x > bound) so that NaN parameters are rejected as well.
        if (!(kappa > 0.0)) {
            throw new IllegalArgumentException("kappa " + kappa + " should be > 0");
        }
        if (delta.getRowDimension() != delta.getColumnDimension()) {
            throw new IllegalArgumentException("delta should be square, got "
                    + delta.getRowDimension() + "x" + delta.getColumnDimension());
        }
        if (!(nu > (double)(delta.getColumnDimension() + 1))) {
            throw new IllegalArgumentException("nu " + nu + " should be > d + 1, d = " + delta.getColumnDimension());
        }
        this.nu = nu;
        this.delta = delta;
        this.scriptV = scriptV;
        this.kappa = kappa;
    }

    /**
     * Convenience constructor taking raw arrays.
     *
     * @param kappa   mean precision scaling; must be strictly positive
     * @param scriptV location of the mean; wrapped as a column vector
     * @param nu      degrees of freedom; must be strictly greater than d + 1
     * @param delta   square d&times;d matrix given row by row
     */
    public NormalInverseWishartDistrib(double kappa, double[] scriptV, double nu, double[][] delta) {
        this(kappa, new Matrix(scriptV, scriptV.length), nu, new Matrix(delta));
    }

    /**
     * Log density of this NIW at mean = 0 and covariance = identity.
     *
     * <p>With {@code Psi = nu * delta} and {@code d = dim()}, this is
     * {@code log N(0 | scriptV, I / kappa) + log IW(I | Psi, nu)}, i.e.
     * {@code -tr(Psi)/2 - kappa/2 ||scriptV||^2 + nu/2 log|Psi| - nu*d/2 log 2
     * - log Gamma_d(nu/2) - d/2 log(2 pi / kappa)}.
     * The {@code log|Sigma|} terms vanish because Sigma = I.
     *
     * @return the log density at (0, I)
     */
    public double logProbAt01() {
        int d = this.dim();
        Matrix deltaTimeNu = this.delta.times(this.nu);
        double result = 0.0;
        result -= 0.5 * deltaTimeNu.trace();
        result -= this.kappa / 2.0 * NormalInverseWishartDistrib.norm(this.scriptV);
        result += this.nu / 2.0 * Math.log(deltaTimeNu.det());
        // Inverse-Wishart normalizer contains 2^(nu*d/2), hence the division by 2.
        result -= this.nu * (double)d / 2.0 * Math.log(2.0);
        result -= multivariateLogGamma(d, this.nu / 2.0);
        result -= (double)d / 2.0 * Math.log(Math.PI * 2 / this.kappa);
        return result;
    }

    /**
     * Log of the multivariate gamma function,
     * {@code log Gamma_dim(a) = dim(dim-1)/4 log(pi) + sum_{i=0}^{dim-1} log Gamma(a - i/2)}.
     *
     * @param dim dimension (&gt;= 1)
     * @param a   argument; presumably must exceed (dim - 1) / 2 for the result to be finite
     * @return {@code log Gamma_dim(a)}
     */
    private static double multivariateLogGamma(int dim, double a) {
        // The pi exponent is dim*(dim-1)/4, not dim/2. The two coincide only for dim <= 3.
        double result = (double)(dim * (dim - 1)) / 4.0 * Math.log(Math.PI);
        for (int i = 0; i < dim; ++i) {
            result += Gamma.logGamma(a - (double)i / 2.0);
        }
        return result;
    }

    /**
     * Squared Euclidean norm {@code m^T m} of a column vector.
     *
     * @param m column vector (exactly one column)
     * @return the squared norm
     * @throws RuntimeException if {@code m} has more than one column
     */
    public static double norm(Matrix m) {
        if (m.getColumnDimension() > 1) {
            throw new RuntimeException("expected a column vector, got "
                    + m.getRowDimension() + "x" + m.getColumnDimension());
        }
        return m.transpose().times(m).get(0, 0);
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double logProb(SuffStats stats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double logProbObject(NormalInverseWishart x) {
        throw new UnsupportedOperationException("Not supported right now");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NormalInverseWishart sampleObject(Random random) {
        throw new UnsupportedOperationException("Not supported right now");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double crossEntropy(Distrib<NormalInverseWishart> _that) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Returns whether {@code lambda} is exactly the identity matrix. Entries are compared
     * with {@code !=}, so any NaN entry makes this return false.
     */
    private boolean isIdentity(Matrix lambda) {
        for (int i = 0; i < lambda.getRowDimension(); ++i) {
            for (int j = 0; j < lambda.getColumnDimension(); ++j) {
                double expected = i == j ? 1.0 : 0.0;
                if (lambda.get(i, j) != expected) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Quadratic form {@code vector^T * kernel * vector}.
     *
     * @param kernel square matrix whose size matches {@code vector}'s row count
     * @param vector column vector
     * @return the scalar value of the quadratic form
     */
    public static double norm(Matrix kernel, Matrix vector) {
        assert (kernel.getColumnDimension() == kernel.getRowDimension());
        assert (kernel.getColumnDimension() == vector.getRowDimension());
        assert (vector.getColumnDimension() == 1);
        Matrix result = vector.transpose().times(kernel).times(vector);
        assert (result.getColumnDimension() == 1);
        assert (result.getRowDimension() == 1);
        return result.get(0, 0);
    }

    /**
     * Expected covariance, {@code nu / (nu - d - 1) * delta}. The constructor's
     * {@code nu > d + 1} check guarantees a positive coefficient.
     *
     * @return a new matrix holding the expected covariance
     */
    public Matrix expectedVariance() {
        double coefficient = this.nu / (this.nu - (double)this.dim() - 1.0);
        return this.delta.times(coefficient);
    }

    /** @return the dimension d (number of columns of {@code delta}) */
    public int dim() {
        return this.delta.getColumnDimension();
    }

    /** @return the {@code delta} matrix (internal reference, not a copy) */
    public Matrix getDelta() {
        return this.delta;
    }

    /** @return the mean precision scaling {@code kappa} */
    public double getKappa() {
        return this.kappa;
    }

    /** @return the degrees of freedom {@code nu} */
    public double getNu() {
        return this.nu;
    }

    /** @return the mean location {@code scriptV} (internal reference, not a copy) */
    public Matrix getScriptV() {
        return this.scriptV;
    }

    /** @return a fresh column-packed copy of {@code scriptV} */
    public double[] getMeanMean() {
        return this.scriptV.getColumnPackedCopy();
    }

    /**
     * Returns {@code nu / (nu - d - 1) * delta} as a fresh d&times;d array. The values are the
     * same as {@link #expectedVariance()}.
     *
     * @return a newly allocated array
     */
    public double[][] getMeanCovar() {
        // expectedVariance() builds a new Matrix, so its backing array is not shared.
        return this.expectedVariance().getArray();
    }

    @Override
    public String toString() {
        return "NIW(nu=" + this.nu + ", kappa=" + this.kappa + ")";
    }
}

