/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.Exceptions;
import fig.basic.Fmt;
import fig.basic.ListUtils;
import fig.basic.NumUtils;

/**
 * Static helpers for Dirichlet distributions: digamma-based expected logs, log-gamma terms, and a
 * table-driven fast approximation of {@code exp(digamma(x))}.
 */
public class DirichletUtils {
    /**
     * Upper end of the range tabulated by {@link #fastExpDigamma}. At or above this value the
     * asymptotic approximation {@code exp(digamma(x)) ~= x - 0.5} is used instead of the table.
     */
    private static final double FAST_EXP_MAX_RANGE = 100.0;

    /** Number of equally spaced buckets covering {@code [0, FAST_EXP_MAX_RANGE)}. */
    private static final int FAST_EXP_NUM_BUCKETS = 1000000;

    /**
     * Lazily built lookup table for {@link #fastExpDigamma}.
     *
     * <p>Initialization-on-demand holder: the JVM runs {@link #build()} exactly once, on first
     * access to {@link #BUCKETS}, and class initialization guarantees that every thread then sees
     * the fully populated array. Publishing the array only after it is filled prevents callers
     * from ever reading a partially built table.
     */
    private static final class FastExpTable {
        static final double[] BUCKETS = build();

        private FastExpTable() {
        }

        /**
         * Builds the table.
         *
         * <p>Bucket {@code i} holds {@code exp(digamma(i * FAST_EXP_MAX_RANGE / FAST_EXP_NUM_BUCKETS))}.
         */
        private static double[] build() {
            double[] buckets = new double[FAST_EXP_NUM_BUCKETS];
            // Bucket 0 (count == 0) is intentionally left at 0.0, the limit of exp(digamma(x))
            // as x -> 0+ (digamma itself diverges to -infinity there).
            for (int i = 1; i < buckets.length; ++i) {
                double icount = FAST_EXP_MAX_RANGE * (double) i / (double) buckets.length;
                buckets[i] = Math.exp(NumUtils.digamma(icount));
            }
            return buckets;
        }
    }

    /**
     * Computes {@code digamma(count) - digamma(totalCount)}.
     *
     * <p>For {@code theta ~ Dirichlet(alpha)} this is the identity
     * {@code E[log theta_i] = digamma(alpha_i) - digamma(sum_j alpha_j)}.
     *
     * @param count the component's pseudo-count
     * @param totalCount the sum of all pseudo-counts
     * @return the difference of digammas
     * @throws RuntimeException (the exception built by {@code Exceptions.bad}) if the result is not
     *     finite according to {@code NumUtils.isFinite}; presumably this happens for zero or
     *     negative arguments. NOTE(review): confirm the exact type thrown by Exceptions.bad.
     */
    public static double expectedLog(double count, double totalCount) {
        double x = NumUtils.digamma(count) - NumUtils.digamma(totalCount);
        if (!NumUtils.isFinite(x)) {
            throw Exceptions.bad("count=%f, totalCount=%f", count, totalCount);
        }
        return x;
    }

    /**
     * Returns {@code logGamma(thatTotalCount)}.
     *
     * <p>Together with {@link #elementContrib} this matches the normalizer term of
     * {@code E_{theta ~ Dir(this)}[log Dir(theta | that)]}. NOTE(review): confirm that usage
     * against callers.
     *
     * @param thatTotalCount the sum of the "that" Dirichlet's pseudo-counts
     * @return {@code logGamma(thatTotalCount)}
     */
    public static double thatTotalCountContrib(double thatTotalCount) {
        return NumUtils.logGamma(thatTotalCount);
    }

    /**
     * Computes the per-component term
     * {@code (thatCount - 1) * expectedLog(thisCount, thisTotalCount) - logGamma(thatCount)}.
     *
     * @param thisCount the component's pseudo-count under the "this" Dirichlet
     * @param thatCount the component's pseudo-count under the "that" Dirichlet
     * @param thisTotalCount the sum of the "this" Dirichlet's pseudo-counts
     * @return the component's contribution
     * @throws RuntimeException propagated from {@link #expectedLog} when its result is not finite
     */
    public static double elementContrib(double thisCount, double thatCount, double thisTotalCount) {
        return (thatCount - 1.0) * DirichletUtils.expectedLog(thisCount, thisTotalCount) - NumUtils.logGamma(thatCount);
    }

    /**
     * Computes {@code log(Gamma(a + n) / Gamma(a))}.
     *
     * @param a the base argument
     * @param n the offset; {@code n == 1} is special-cased through {@code Gamma(a + 1) = a * Gamma(a)},
     *     which reduces the ratio to {@code log(a)} and avoids two log-gamma evaluations
     * @return {@code logGamma(a + n) - logGamma(a)}
     */
    public static double logGammaRatio(double a, double n) {
        if (n == 1.0) {
            return Math.log(a);
        }
        return NumUtils.logGamma(a + n) - NumUtils.logGamma(a);
    }

    /**
     * Approximates {@code exp(digamma(count))}.
     *
     * <p>For {@code count < 100} the value comes from a 1,000,000-bucket table over {@code [0, 100)}
     * (bucket width 1e-4), using nearest-bucket rounding. The table is built on the first such call.
     * For {@code count >= 100} the method returns {@code count - 0.5}, from the asymptotic
     * {@code digamma(x) ~= log(x - 0.5)}.
     *
     * <p>Thread-safe: the table is built once and published only when complete.
     *
     * @param count a non-negative count; this is checked only when assertions are enabled, and
     *     negative values may otherwise cause an {@link ArrayIndexOutOfBoundsException}
     * @return the approximation; {@code 0.0} when {@code count} rounds to bucket 0
     */
    public static double fastExpDigamma(double count) {
        assert (count >= 0.0) : count;
        if (count >= FAST_EXP_MAX_RANGE) {
            return count - 0.5;
        }
        double[] buckets = FastExpTable.BUCKETS;
        // Round to the nearest bucket; counts just below the range can round one past the end,
        // so clamp to the last bucket.
        int i = (int) ((double) buckets.length * count / FAST_EXP_MAX_RANGE + 0.5);
        if (i >= buckets.length) {
            i = buckets.length - 1;
        }
        return buckets[i];
    }

    /**
     * Returns a new array with {@code fastExpDigamma(counts[i]) / fastExpDigamma(sum(counts))}.
     *
     * <p>This is the fast counterpart of {@link #expExpectedLog}. The normalizer is not checked for
     * zero: if it rounds to 0 (for example, all counts are 0), the result contains NaN or Infinity.
     * {@link #fastExpExpectedLogMut} reports that case instead.
     *
     * @param counts non-negative pseudo-counts; not modified
     * @return a fresh array of the same length
     */
    public static double[] fastExpExpectedLog(double[] counts) {
        int n = counts.length;
        double[] scores = new double[n];
        double normalizer = DirichletUtils.fastExpDigamma(ListUtils.sum(counts));
        for (int i = 0; i < n; ++i) {
            scores[i] = DirichletUtils.fastExpDigamma(counts[i]) / normalizer;
        }
        return scores;
    }

    /**
     * In-place variant of {@link #fastExpExpectedLog}.
     *
     * <p>Replaces each {@code counts[i]} with
     * {@code fastExpDigamma(counts[i]) / fastExpDigamma(sum(counts))}.
     *
     * @param counts non-negative pseudo-counts; overwritten on success
     * @return {@code false} (leaving {@code counts} untouched) if the normalizer is exactly 0,
     *     otherwise {@code true}
     */
    public static boolean fastExpExpectedLogMut(double[] counts) {
        int n = counts.length;
        double normalizer = DirichletUtils.fastExpDigamma(ListUtils.sum(counts));
        if (normalizer == 0.0) {
            return false;
        }
        for (int i = 0; i < n; ++i) {
            counts[i] = DirichletUtils.fastExpDigamma(counts[i]) / normalizer;
        }
        return true;
    }

    /**
     * Returns a new array with {@code exp(digamma(counts[i])) / exp(digamma(sum(counts)))}.
     *
     * <p>This is the exact, non-tabulated version of {@link #fastExpExpectedLog}.
     *
     * @param counts pseudo-counts; not modified
     * @return a fresh array of the same length
     */
    public static double[] expExpectedLog(double[] counts) {
        int n = counts.length;
        double[] scores = new double[n];
        double normalizer = Math.exp(NumUtils.digamma(ListUtils.sum(counts)));
        for (int i = 0; i < n; ++i) {
            scores[i] = Math.exp(NumUtils.digamma(counts[i])) / normalizer;
        }
        return scores;
    }

    /**
     * Small demo that prints a few results for a fixed count vector.
     *
     * <p>It prints {@link #expExpectedLog} of {@code counts + prior}, where the prior has every
     * entry set to {@code alpha = 1}. It then prints that result's sum and its normalized form,
     * followed by the normalized raw counts (labelled "MLE").
     */
    public static void main(String[] args) {
        double[] counts = new double[]{3.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0};
        double alpha = 1.0;
        // Presumably an array of counts.length copies of alpha; confirm against ListUtils.
        double[] prior = ListUtils.newDouble(counts.length, alpha);
        System.out.println(counts.length);
        double[] hack = DirichletUtils.expExpectedLog(ListUtils.add(counts, prior));
        System.out.println(Fmt.D(hack));
        System.out.println("sum = " + ListUtils.sum(hack));
        System.out.println("norm = " + Fmt.D(DirichletUtils.norm(hack)));
        System.out.println("MLE = " + Fmt.D(DirichletUtils.norm(counts)));
    }

    /**
     * Returns a copy of {@code x} passed through {@code NumUtils.normalize}; {@code x} itself is not
     * modified.
     *
     * <p>NOTE(review): presumably {@code normalize} rescales the array to sum to 1. Confirm in
     * NumUtils.
     *
     * @param x the values to normalize
     * @return a normalized copy
     */
    static double[] norm(double[] x) {
        double[] copy = x.clone();
        NumUtils.normalize(copy);
        return copy;
    }
}

