/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import Jama.Matrix;
import fig.basic.Exceptions;
import fig.prob.Distrib;
import fig.prob.MargDistrib;
import fig.prob.MultGaussian;
import fig.prob.MultGaussianSuffStats;
import fig.prob.NormalInverseWishart;
import fig.prob.NormalInverseWishartDistrib;
import fig.prob.SuffStats;
import java.util.Random;
import nuts.util.MathUtils;

/**
 * Multivariate Gaussian whose mean and covariance are integrated out under a
 * {@link NormalInverseWishartDistrib} prior.
 *
 * <p>Supports marginal and predictive log likelihoods, posterior updates given
 * {@link MultGaussianSuffStats}, and sampling of (mean, covariance) pairs. The other
 * {@link MargDistrib} operations throw {@code Exceptions.unimplemented}.
 */
public class MargMultGaussian
implements MargDistrib<NormalInverseWishart> {
    /** Distribution over the Gaussian's mean and covariance (the prior, or a posterior). */
    public NormalInverseWishartDistrib meanVarDistrib;

    /**
     * Creates a marginal Gaussian with the given distribution over mean and covariance.
     *
     * @param prior distribution over mean/covariance; stored by reference, not copied
     */
    public MargMultGaussian(NormalInverseWishartDistrib prior) {
        this.meanVarDistrib = prior;
    }

    /**
     * Returns the log marginal likelihood of the observations summarized by {@code stats}.
     *
     * <p>Uses the Bayes'-rule identity
     * {@code log p(x) = log p(theta) + log p(x | theta) - log p(theta | x)},
     * evaluated at one fixed parameter value theta. The choice of theta does not
     * matter for the identity.
     * NOTE(review): theta is presumably zero mean and identity covariance, matching
     * {@code logProbAt01} and {@code MultGaussian.getStdNormal}; confirm.
     *
     * @param stats observations; must be a {@link MultGaussianSuffStats}
     * @return log marginal likelihood
     */
    @Override
    public double margLogLikelihood(SuffStats stats) {
        MultGaussianSuffStats observation = (MultGaussianSuffStats) stats;
        double logPrior = this.meanVarDistrib.logProbAt01();
        double logLikelihood = MultGaussian.getStdNormal(this.dim()).logProb(observation);
        double logPosterior = this.getPosterior(observation).meanVarDistrib.logProbAt01();
        return logPrior + logLikelihood - logPosterior;
    }

    /**
     * Returns the log marginal likelihood of {@code predStats} under the posterior
     * obtained by conditioning on {@code condStats}.
     *
     * @param condStats observations to condition on; must be a {@link MultGaussianSuffStats}
     * @param predStats observations to score
     * @return predictive log likelihood
     */
    @Override
    public double predLogLikelihood(SuffStats condStats, SuffStats predStats) {
        return this.getPosterior((MultGaussianSuffStats) condStats).margLogLikelihood(predStats);
    }

    /**
     * Returns a new marginal Gaussian whose distribution is the posterior given {@code stats}.
     * Delegates to {@link #fast_getPosterior}; this object is not modified.
     *
     * @param stats observations; must be a {@link MultGaussianSuffStats}
     * @return posterior marginal Gaussian
     */
    public MargMultGaussian getPosterior(SuffStats stats) {
        return this.fast_getPosterior(stats);
    }

    /**
     * Computes the same posterior update as {@link #fast_getPosterior} using whole-matrix
     * Jama operations. This version is simpler to read but allocates more.
     *
     * @param stats observations; must be a {@link MultGaussianSuffStats}
     * @return posterior marginal Gaussian
     */
    public MargMultGaussian slow_getPosterior(SuffStats stats) {
        MultGaussianSuffStats observation = (MultGaussianSuffStats) stats;
        // Jama's Matrix(double[], m) packs by columns, so this is a dim x 1 column vector.
        Matrix sumObs = new Matrix(observation.getSum(), observation.dim());
        Matrix outerProducts = new Matrix(observation.getOuterProduct());
        double kappa = this.meanVarDistrib.getKappa();
        double nu = this.meanVarDistrib.getNu();
        Matrix scriptV = this.meanVarDistrib.getScriptV();
        double kappaPrime = kappa + (double) observation.numPoints();
        double nuPrime = nu + (double) observation.numPoints();
        // v' = (kappa * v + sum) / kappa'
        Matrix scriptVPrime = scriptV.times(kappa).plus(sumObs).times(1.0 / kappaPrime);
        // delta' = (nu * delta + kappa * v v^T - kappa' * v' v'^T + S) / nu'
        Matrix t1 = this.meanVarDistrib.getDelta().times(nu);
        Matrix t2 = scriptV.times(scriptV.transpose()).times(kappa);
        Matrix t3 = scriptVPrime.times(scriptVPrime.transpose()).times(kappaPrime);
        Matrix deltaPrime = t1.plus(t2).minus(t3).plus(outerProducts).times(1.0 / nuPrime);
        return new MargMultGaussian(new NormalInverseWishartDistrib(kappaPrime, scriptVPrime, nuPrime, deltaPrime));
    }

    /**
     * Checks element-wise that two matrices are numerically close, using
     * {@code MathUtils.checkClose} on each entry. NOTE(review): checkClose presumably
     * throws on a mismatch; confirm.
     *
     * @param m1 first matrix
     * @param m2 second matrix
     * @throws IllegalArgumentException if the matrices have different dimensions
     */
    public static void checkEqual(Matrix m1, Matrix m2) {
        int rows = m1.getRowDimension();
        int cols = m1.getColumnDimension();
        // Without this check, a larger m2 would pass silently, and a smaller one
        // would fail with an uninformative index exception.
        if (rows != m2.getRowDimension() || cols != m2.getColumnDimension()) {
            throw new IllegalArgumentException("Matrix dimensions differ: " + rows + "x" + cols
                + " vs " + m2.getRowDimension() + "x" + m2.getColumnDimension());
        }
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                MathUtils.checkClose(m1.get(i, j), m2.get(i, j));
            }
        }
    }

    /**
     * Computes the Normal-Inverse-Wishart posterior update with explicit element loops.
     *
     * <p>With n observations, observation sum {@code sum}, and sum of outer products
     * {@code S}:
     * <pre>
     *   kappa' = kappa + n
     *   nu'    = nu + n
     *   v'     = (kappa * v + sum) / kappa'
     *   delta' = (nu * delta + S + kappa * v v^T - kappa' * v' v'^T) / nu'
     * </pre>
     * The dimension is taken from {@code stats}.
     *
     * @param stats observations; must be a {@link MultGaussianSuffStats}
     * @return posterior marginal Gaussian
     */
    public MargMultGaussian fast_getPosterior(SuffStats stats) {
        MultGaussianSuffStats observation = (MultGaussianSuffStats) stats;
        Matrix sumObs = observation.sum;
        Matrix outerProducts = observation.outerproducts;
        int dim = observation.dim();
        int numPoints = observation.numPoints();
        double nu = this.meanVarDistrib.getNu();
        double kappa = this.meanVarDistrib.getKappa();
        double kappaPrime = kappa + (double) numPoints;
        double nuPrime = nu + (double) numPoints;
        Matrix scriptV = this.meanVarDistrib.getScriptV();
        Matrix delta = this.meanVarDistrib.getDelta();

        // v' = (kappa * v + sum) / kappa'
        Matrix scriptVPrime = new Matrix(dim, 1);
        for (int i = 0; i < dim; ++i) {
            scriptVPrime.set(i, 0, (kappa * scriptV.get(i, 0) + sumObs.get(i, 0)) / kappaPrime);
        }

        // delta' = (nu * delta + S + kappa * v v^T - kappa' * v' v'^T) / nu'
        Matrix deltaPrime = new Matrix(dim, dim);
        for (int i = 0; i < dim; ++i) {
            // Row-invariant factors, hoisted out of the inner loop.
            double vI = scriptV.get(i, 0);
            double vPrimeI = scriptVPrime.get(i, 0);
            for (int j = 0; j < dim; ++j) {
                deltaPrime.set(i, j, (nu * delta.get(i, j) + outerProducts.get(i, j)
                    + kappa * vI * scriptV.get(j, 0)
                    - kappaPrime * vPrimeI * scriptVPrime.get(j, 0)) / nuPrime);
            }
        }
        return new MargMultGaussian(new NormalInverseWishartDistrib(kappaPrime, scriptVPrime, nuPrime, deltaPrime));
    }

    /** Not implemented; always throws {@code Exceptions.unimplemented}. */
    @Override
    public double logProb(SuffStats stats) {
        throw Exceptions.unimplemented;
    }

    /** Not implemented; always throws {@code Exceptions.unimplemented}. */
    @Override
    public double logProbObject(NormalInverseWishart distrib) {
        throw Exceptions.unimplemented;
    }

    /** Not implemented; always throws {@code Exceptions.unimplemented}. */
    @Override
    public double crossEntropy(Distrib<NormalInverseWishart> distrib) {
        throw Exceptions.unimplemented;
    }

    /** Not implemented; always throws {@code Exceptions.unimplemented}. */
    @Override
    public double expectedLogLikelihood(SuffStats stats) {
        throw Exceptions.unimplemented;
    }

    /**
     * Draws a (mean, covariance) pair from {@link #meanVarDistrib}.
     *
     * @param random source of randomness
     * @return sampled parameters
     */
    @Override
    public NormalInverseWishart sampleObject(Random random) {
        return this.meanVarDistrib.sampleObject(random);
    }

    /** Returns the dimension reported by {@link #meanVarDistrib}. */
    public int dim() {
        return this.meanVarDistrib.dim();
    }

    /**
     * Demo: draws 10000 samples from a 1-D Gaussian with mean 30 and variance 1, updates
     * a 1-D NIW prior with them, and prints the posterior expected variance. The
     * {@code Random} is unseeded, so the output varies between runs.
     */
    public static void main(String[] args) {
        double nu = 4.0;
        Matrix scriptV = new Matrix(1, 1);
        scriptV.set(0, 0, 5.0);
        Matrix delta = new Matrix(1, 1);
        delta.set(0, 0, 1.0);
        double kappa = 1.0;
        NormalInverseWishartDistrib prior = new NormalInverseWishartDistrib(kappa, scriptV, nu, delta);
        double[] mean = new double[]{30.0};
        double[][] variance = new double[][]{{1.0}};
        MultGaussian g = new MultGaussian(mean, variance);
        MultGaussianSuffStats observations = new MultGaussianSuffStats(1);
        Random random = new Random();
        for (int i = 0; i < 10000; ++i) {
            observations.add(g.sample(random));
        }
        MargMultGaussian margGaussian = new MargMultGaussian(prior);
        MargMultGaussian posterior = margGaussian.getPosterior(observations);
        System.out.println(posterior.meanVarDistrib.expectedVariance().get(0, 0));
    }
}

