/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.Fmt;
import fig.basic.ListUtils;
import fig.basic.NumUtils;
import fig.prob.DegenerateDirichlet;
import fig.prob.DirichletInterface;
import fig.prob.DirichletUtils;
import fig.prob.Distrib;
import fig.prob.Gamma;
import fig.prob.SuffStats;
import java.io.Serializable;
import java.util.Random;

/**
 * A Dirichlet distribution over the probability simplex, parameterized by a vector of
 * concentration parameters {@code alpha}.
 *
 * <p>The sum of {@code alpha} is cached in {@code totalCount} when the object is constructed.
 * The {@code double[]} passed to {@link #Dirichlet(double[])} is stored by reference and
 * {@link #getAlpha()} returns that same array, so mutating it afterwards leaves
 * {@code totalCount} stale.
 */
public class Dirichlet implements DirichletInterface, Serializable {
    // NOTE(review): this field is not declared final, so Java serialization does not treat it
    // as the class's serialVersionUID and computes a default UID from the class shape instead.
    // Making it final (or changing non-private members / field modifiers) would change the
    // stream UID and break deserialization of previously serialized instances — confirm
    // compatibility requirements before touching it.
    private static long serialVersionUID = 42L;

    /** Concentration parameters, one per dimension; shared with callers, not copied. */
    private double[] alpha;

    /** Cached sum of {@link #alpha}; only assigned in the constructors. */
    private double totalCount;

    /**
     * Creates a symmetric Dirichlet whose concentration parameters all equal {@code alpha}.
     *
     * @param numDim number of dimensions
     * @param alpha concentration value for every dimension (array built by
     *     {@code ListUtils.newDouble}, presumably of length {@code numDim})
     */
    public Dirichlet(int numDim, double alpha) {
        this.alpha = ListUtils.newDouble(numDim, alpha);
        this.totalCount = alpha * (double) numDim;
    }

    /**
     * Creates a Dirichlet with the given concentration parameters.
     *
     * @param alpha concentration parameters; stored by reference (not copied)
     */
    public Dirichlet(double[] alpha) {
        this.alpha = alpha;
        this.totalCount = ListUtils.sum(alpha);
    }

    /**
     * Returns the log density of this distribution at {@code x}.
     *
     * @param x a point on the simplex; must have at least {@link #dim()} entries
     * @return {@code log p(x | alpha)}
     */
    public double logProb(double[] x) {
        return Dirichlet.logProb(this.alpha, this.totalCount, x);
    }

    /**
     * Computes the Dirichlet log density
     * {@code logGamma(totalCount) - sum_i logGamma(alpha[i]) + sum_i (alpha[i] - 1) * log(x[i])}.
     *
     * <p>There is deliberately no shortcut for {@code totalCount == alpha.length}. The uniform
     * Dirichlet (all alphas 1) has density {@code Gamma(K)}, not 1. Non-uniform alphas can also
     * sum to K (e.g. {@code {0.5, 1.5}}), so such a shortcut would return wrong values.
     *
     * @param alpha concentration parameters
     * @param totalCount the sum of {@code alpha}; not verified here
     * @param x a point on the simplex with at least {@code alpha.length} entries
     * @return the log density of {@code x}
     */
    public static double logProb(double[] alpha, double totalCount, double[] x) {
        double sum = NumUtils.logGamma(totalCount);
        for (int i = 0; i < alpha.length; ++i) {
            sum -= NumUtils.logGamma(alpha[i]);
            double exponent = alpha[i] - 1.0;
            // A unit concentration contributes x_i^0 = 1; skipping it keeps the result finite
            // when x_i == 0 instead of producing 0 * log(0) = NaN.
            if (exponent != 0.0) {
                sum += exponent * Math.log(x[i]);
            }
        }
        return sum;
    }

    /**
     * Not supported for this distribution.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double logProb(SuffStats stats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Same as {@link #logProb(double[])}; satisfies the generic distribution interface. */
    @Override
    public double logProbObject(double[] x) {
        return this.logProb(x);
    }

    /**
     * Draws a point on the simplex from this distribution.
     *
     * @param random source of randomness
     * @return a freshly allocated sample
     */
    public double[] sample(Random random) {
        return Dirichlet.sample(random, this.alpha);
    }

    /**
     * Draws a Dirichlet sample by drawing one {@code Gamma(alpha[i], 1.0)} variate per
     * dimension and normalizing the resulting vector with {@code NumUtils.normalize}.
     *
     * @param random source of randomness
     * @param alpha concentration parameters
     * @return a freshly allocated sample of length {@code alpha.length}
     */
    public static double[] sample(Random random, double[] alpha) {
        double[] x = new double[alpha.length];
        for (int i = 0; i < alpha.length; ++i) {
            x[i] = Gamma.sample(random, alpha[i], 1.0);
        }
        NumUtils.normalize(x);
        return x;
    }

    /** Same as {@link #sample(Random)}; satisfies the generic distribution interface. */
    @Override
    public double[] sampleObject(Random random) {
        return this.sample(random);
    }

    /**
     * Cross entropy between this distribution and {@code _that}, assembled from
     * {@code DirichletUtils} terms (the exact formula lives there).
     *
     * @param _that another distribution; must be a {@code Dirichlet}
     * @return the cross entropy
     * @throws ClassCastException if {@code _that} is not a {@code Dirichlet}
     */
    @Override
    public double crossEntropy(Distrib<double[]> _that) {
        Dirichlet that = (Dirichlet) _that;
        double sum = DirichletUtils.thatTotalCountContrib(that.totalCount());
        // Iterates over this.alpha and reads that.alpha[i]; assumes both have the same dimension.
        for (int i = 0; i < this.alpha.length; ++i) {
            sum += DirichletUtils.elementContrib(this.alpha[i], that.alpha[i], this.totalCount());
        }
        return sum;
    }

    /**
     * Expected log of coordinate {@code i}, delegated to {@code DirichletUtils.expectedLog}
     * (presumably {@code digamma(alpha[i]) - digamma(totalCount)} — confirm there).
     *
     * @param i dimension index
     * @return the expected log value for dimension {@code i}
     */
    @Override
    public double expectedLog(int i) {
        return DirichletUtils.expectedLog(this.alpha[i], this.totalCount);
    }

    /**
     * Returns {@link #expectedLog(int)} for every dimension.
     *
     * @return a freshly allocated array of length {@link #dim()}
     */
    @Override
    public double[] expectedLog() {
        double[] result = new double[this.dim()];
        for (int i = 0; i < this.dim(); ++i) {
            result[i] = this.expectedLog(i);
        }
        return result;
    }

    /** Returns a {@code DegenerateDirichlet} concentrated at {@link #getMode()}. */
    @Override
    public DirichletInterface modeSpike() {
        return new DegenerateDirichlet(this.getMode());
    }

    /**
     * Returns a new Dirichlet whose alphas are a sample from this distribution scaled by
     * {@code totalCount} (via {@code ListUtils.mult}).
     *
     * @param random source of randomness
     * @return the perturbed distribution
     */
    public Dirichlet perturb(Random random) {
        return new Dirichlet(ListUtils.mult(this.totalCount, this.sample(random)));
    }

    /**
     * Returns the mean: a copy of {@code alpha} passed through {@code NumUtils.normalize}.
     *
     * @return a freshly allocated array; {@code alpha} itself is not modified
     */
    @Override
    public double[] getMean() {
        double[] mean = this.alpha.clone();
        NumUtils.normalize(mean);
        return mean;
    }

    /**
     * Returns an approximate mode: {@code alpha[i] - 1}, clamped below at {@code 1e-8}, then
     * normalized. The clamp means dimensions with {@code alpha[i] <= 1} (where the true mode is
     * on the boundary or undefined) get a tiny positive mass.
     *
     * @return a freshly allocated array
     */
    @Override
    public double[] getMode() {
        double[] mode = ListUtils.add(this.alpha, -1.0);
        for (int i = 0; i < mode.length; ++i) {
            mode[i] = Math.max(mode[i], 1.0E-8);
        }
        NumUtils.normalize(mode);
        return mode;
    }

    /**
     * Returns the backing concentration array (not a copy). Mutating it does not update the
     * cached {@link #totalCount()}.
     */
    public double[] getAlpha() {
        return this.alpha;
    }

    /** Returns the concentration parameter of dimension {@code i}. */
    @Override
    public double getAlpha(int i) {
        return this.alpha[i];
    }

    /** Returns the sum of the concentration parameters, cached at construction. */
    @Override
    public double totalCount() {
        return this.totalCount;
    }

    /** Returns the number of dimensions. */
    @Override
    public int dim() {
        return this.alpha.length;
    }

    @Override
    public String toString() {
        return String.format("Dirichlet(%s)", Fmt.D(this.alpha));
    }
}

