/*
 * Decompiled with CFR 0.152.
 */
package slice.likelihood;

import fig.prob.Gaussian;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.util.MathUtils;
import slice.likelihood.NormalLocation;
import slice.stickrep.LocationDistribution;

/**
 * Prior over the per-dimension mean of a {@link NormalLocation} whose per-dimension
 * variance is held fixed.
 *
 * <p>Coordinate {@code d} of the mean is drawn from a Gaussian with mean
 * {@code meanOfTheNormalMeanDist[d]} and variance {@code varOfTheNormalMeanDist[d]}.
 * The variance of every location produced is always {@code spikeOfTheDiracDeltaVarDist},
 * i.e. a point mass ("Dirac delta") on the variance. Because the variance is fixed,
 * {@link #posterior(List)} has a closed-form Gaussian update.
 *
 * <p>Instances are not mutated after construction. The constructor copies its
 * arguments and {@link #posterior(List)} returns a new instance.
 */
public class NormalMeanFixedVarLocationDist
implements LocationDistribution<NormalLocation, List<Double>> {
    /** Per-dimension mean of the Gaussian prior on the location mean. */
    private final List<Double> meanOfTheNormalMeanDist = new ArrayList<Double>();
    /** Per-dimension variance of the Gaussian prior on the location mean. */
    private final List<Double> varOfTheNormalMeanDist = new ArrayList<Double>();
    /** Per-dimension fixed variance attached to every sampled location. */
    private final List<Double> spikeOfTheDiracDeltaVarDist = new ArrayList<Double>();

    /**
     * Creates the prior. All three lists are copied, so later changes to the
     * arguments do not affect this distribution.
     *
     * @param meanOfTheNormalMeanDist per-dimension prior mean of the location mean
     * @param varOfTheNormalMeanDist per-dimension prior variance of the location mean
     * @param spikeOfTheDiracDeltaVarDist per-dimension fixed variance of the location
     */
    public NormalMeanFixedVarLocationDist(List<Double> meanOfTheNormalMeanDist, List<Double> varOfTheNormalMeanDist, List<Double> spikeOfTheDiracDeltaVarDist) {
        this.meanOfTheNormalMeanDist.addAll(meanOfTheNormalMeanDist);
        this.varOfTheNormalMeanDist.addAll(varOfTheNormalMeanDist);
        this.spikeOfTheDiracDeltaVarDist.addAll(spikeOfTheDiracDeltaVarDist);
        assert this.invar() : "mean, variance and spike lists must all have the same length";
    }

    /**
     * Returns the number of dimensions of the locations this distribution produces.
     *
     * @return the length of the parameter lists
     */
    public int dim() {
        return this.meanOfTheNormalMeanDist.size();
    }

    /**
     * Draws a location from this distribution. Each mean coordinate is sampled
     * independently, and the variance is the fixed {@code spikeOfTheDiracDeltaVarDist}.
     *
     * @param rand source of randomness
     * @return a new location whose variance list is a private copy, so mutating it
     *     cannot alter this distribution
     */
    @Override
    public NormalLocation sample(Random rand) {
        final int dim = this.dim();
        ArrayList<Double> sampledMean = new ArrayList<Double>(dim);
        for (int d = 0; d < dim; ++d) {
            // NOTE(review): assumes Gaussian.sample's third argument is a variance (not a
            // standard deviation), matching how posterior() treats varOfTheNormalMeanDist.
            double sampled = Gaussian.sample(rand, this.meanOfTheNormalMeanDist.get(d), this.varOfTheNormalMeanDist.get(d));
            sampledMean.add(sampled);
        }
        // Hand out a copy: sharing the internal list would let the returned location
        // mutate this distribution's fixed variance.
        return new NormalLocation(sampledMean, new ArrayList<Double>(this.spikeOfTheDiracDeltaVarDist));
    }

    /**
     * Returns the posterior over the location mean after observing {@code data}.
     *
     * <p>For each dimension {@code i}, write {@code m0}/{@code v0} for the prior mean and
     * variance, {@code s} for the fixed variance, {@code n} for the number of points and
     * {@code xbar} for the sample mean. The posterior is Gaussian with
     * mean {@code (v0 * xbar + (s / n) * m0) / (v0 + s / n)} and
     * variance {@code v0 * (s / n) / (v0 + s / n)}. The fixed variance is unchanged.
     *
     * @param data observed points, each expected to have {@link #dim()} coordinates;
     *     may be {@code null} or empty
     * @return {@code this} when there is no data, otherwise a new distribution
     */
    @Override
    public LocationDistribution<NormalLocation, List<Double>> posterior(List<List<Double>> data) {
        if (data == null || data.isEmpty()) {
            // No evidence: the posterior is the prior. Returning this is safe because
            // instances are never mutated.
            return this;
        }
        assert this.allPointsHaveDim(data) : "every data point must have dimension " + this.dim();
        final double n = data.size();
        // NOTE(review): assumes MathUtils.addVectors returns the coordinate-wise sum of the
        // points, so sums[i] / n is the sample mean of dimension i.
        double[] sums = MathUtils.addVectors(data);
        final int dim = this.dim();
        ArrayList<Double> postMeanOfTheNormalMeanDist = new ArrayList<Double>(dim);
        ArrayList<Double> postVarOfTheNormalMeanDist = new ArrayList<Double>(dim);
        for (int i = 0; i < dim; ++i) {
            double prM = this.meanOfTheNormalMeanDist.get(i);
            double prV = this.varOfTheNormalMeanDist.get(i);
            double mx = sums[i] / n;
            // Fixed per-observation variance divided by n: the variance of the sample mean.
            double sdOverN = this.spikeOfTheDiracDeltaVarDist.get(i) / n;
            double denom = prV + sdOverN;
            postMeanOfTheNormalMeanDist.add((prV * mx + sdOverN * prM) / denom);
            postVarOfTheNormalMeanDist.add(prV * sdOverN / denom);
        }
        return new NormalMeanFixedVarLocationDist(postMeanOfTheNormalMeanDist, postVarOfTheNormalMeanDist, this.spikeOfTheDiracDeltaVarDist);
    }

    /**
     * Checks the class invariant: all three parameter lists have the same length.
     *
     * @return {@code true} if the invariant holds, {@code false} otherwise
     */
    public boolean invar() {
        int dim = this.meanOfTheNormalMeanDist.size();
        return this.varOfTheNormalMeanDist.size() == dim
            && this.spikeOfTheDiracDeltaVarDist.size() == dim;
    }

    /**
     * Returns {@code true} when every point in {@code data} has exactly {@link #dim()} coordinates.
     */
    private boolean allPointsHaveDim(List<List<Double>> data) {
        final int dim = this.dim();
        for (List<Double> point : data) {
            if (point.size() != dim) {
                return false;
            }
        }
        return true;
    }

    /**
     * Samples a location from the posterior given {@code data}.
     *
     * @param rand source of randomness
     * @param data observed points; may be {@code null} or empty, in which case the prior is sampled
     * @return a location drawn from {@code posterior(data)}
     */
    @Override
    public NormalLocation samplePosterior(Random rand, List<List<Double>> data) {
        return this.posterior(data).sample(rand);
    }
}

