/*
 * Decompiled with CFR 0.152.
 */
package gep.model;

import fig.prob.SampleUtils;
import gep.model.HyperParamProvider;
import gep.model.Predictives;
import gep.model.SplitContext;
import gep.model.SufficientStatistics;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.util.CollUtils;
import nuts.util.Counter;

/**
 * Hyper-parameter provider for {@link SplitContext}s, selected by the context's level.
 *
 * <p>Level 0 contexts use a per-symbol value (one entry per symbol index, looked up with
 * {@code context.getX()}); level 1 and level 2 contexts use the single fixed values
 * {@code alphaLevel2} and {@code alphaLevel3}. The per-symbol level-0 values can be resampled
 * from a Gamma distribution via {@link #sample(Predictives, Random)}.
 */
public class SplitHyperParams
implements HyperParamProvider<SplitContext> {
    /** Per-symbol values returned for level-0 contexts, indexed by symbol; overwritten by {@link #sample}. */
    private final List<Double> alphaLevel1 = CollUtils.list();
    /** Value returned for level-1 contexts; also the additive base of the first Gamma parameter ({@code getA}). */
    private final double alphaLevel2;
    /** Value returned for level-2 contexts. */
    private final double alphaLevel3;
    /** Additive base of the second Gamma parameter ({@code getB}). */
    private final double gamma0;

    /**
     * Creates a provider whose per-symbol level-0 values all start at {@code alphaLevel1}.
     *
     * @param alphaLevel1 initial value for each of the {@code nSym} level-0 entries
     * @param alphaLevel2 value returned for level-1 contexts
     * @param alphaLevel3 value returned for level-2 contexts
     * @param nSym number of symbols, i.e. number of level-0 entries (none are created if {@code nSym <= 0})
     * @param gamma0 base value added into the second Gamma parameter when resampling
     */
    public SplitHyperParams(double alphaLevel1, double alphaLevel2, double alphaLevel3, int nSym, double gamma0) {
        for (int i = 0; i < nSym; ++i) {
            this.alphaLevel1.add(alphaLevel1);
        }
        this.alphaLevel2 = alphaLevel2;
        this.alphaLevel3 = alphaLevel3;
        this.gamma0 = gamma0;
    }

    /**
     * Returns a read-only view of the per-symbol level-0 values.
     *
     * <p>The view is live: it reflects later calls to {@link #sample(Predictives, Random)}.
     *
     * @return unmodifiable view of the level-0 values, indexed by symbol
     */
    public List<Double> getGammas() {
        return Collections.unmodifiableList(this.alphaLevel1);
    }

    /**
     * Returns the hyper parameter for the given context, chosen by {@code context.level()}.
     *
     * @param context the split context; for level 0, {@code context.getX()} selects the symbol
     * @return the level-0 per-symbol value, or the fixed level-1 / level-2 value
     * @throws IllegalArgumentException if the level is not 0, 1 or 2
     * @throws IndexOutOfBoundsException if a level-0 context's symbol index is out of range
     */
    @Override
    public double getHyperParam(SplitContext context) {
        int level = context.level();
        switch (level) {
            case 0:
                return this.alphaLevel1.get(context.getX());
            case 1:
                return this.alphaLevel2;
            case 2:
                return this.alphaLevel3;
            default:
                // IllegalArgumentException is a RuntimeException, so existing handlers still catch it,
                // but the message now says which level was unsupported.
                throw new IllegalArgumentException(
                        "Unsupported split context level " + level + " (expected 0, 1 or 2)");
        }
    }

    /**
     * First Gamma parameter for symbol {@code x}.
     *
     * <p>This is {@code alphaLevel2} plus the total count of the state-transition counter
     * stored for {@code new SplitContext(x)}.
     */
    private double getA(int x, SufficientStatistics<SplitContext, SplitContext> stats) {
        SplitContext xCtxt = new SplitContext(x);
        return this.alphaLevel2 + stats.getStateTransitionStats().getCounter(xCtxt).totalCount();
    }

    /**
     * Second Gamma parameter for symbol {@code x}.
     *
     * <p>This is {@code gamma0} plus, for every waiting-time context whose {@code getX()} equals
     * {@code x}, the term {@code log(beta0 + count) - log(beta0)}.
     *
     * @throws RuntimeException if the result is not strictly positive (including NaN)
     */
    private double getB(int x, SufficientStatistics<SplitContext, SplitContext> stats, double beta0) {
        double sum = this.gamma0;
        Counter<SplitContext> jumpStats = stats.getWaitingTimeStats();
        for (SplitContext ctxt : jumpStats.keySet()) {
            if (ctxt.getX() != x) continue;
            sum += Math.log(beta0 + jumpStats.getCount(ctxt)) - Math.log(beta0);
        }
        // Written as !(sum > 0) so that NaN is rejected as well.
        if (!(sum > 0.0)) {
            throw new RuntimeException("getB() returned " + sum + "\nbeta0=" + beta0 + ",jumpStats=" + jumpStats);
        }
        return sum;
    }

    /**
     * Resamples the level-0 value of one symbol and stores it in this instance.
     *
     * <p>The new value is drawn with {@code SampleUtils.sampleGamma(rand, a, b)}. Whether
     * {@code b} acts as a rate or a scale is determined by {@code SampleUtils} (not visible here).
     *
     * @throws RuntimeException if the sampled value is not strictly positive (including NaN)
     */
    private void sampleGamma(Predictives<SplitContext, SplitContext> predictiveDistributions, Random rand, int charIndex) {
        SufficientStatistics<SplitContext, SplitContext> stats = predictiveDistributions.readOnlySuffStats;
        double beta0 = predictiveDistributions.beta0;
        // NOTE(review): a and b are computed from predictiveDistributions.hyperParams, but the result
        // is written into this instance; presumably both are the same object — confirm at call sites.
        SplitHyperParams source = (SplitHyperParams) predictiveDistributions.hyperParams;
        double a = source.getA(charIndex, stats);
        double b = source.getB(charIndex, stats, beta0);
        double newValue = SampleUtils.sampleGamma(rand, a, b);
        if (!(newValue > 0.0)) {
            throw new RuntimeException("sampleGamma() returned " + newValue + ", a=" + a + "," + "b=" + b);
        }
        this.alphaLevel1.set(charIndex, newValue);
    }

    /**
     * Resamples every per-symbol level-0 value in index order, using {@code rand} for all draws.
     *
     * @param predictiveDistributions source of sufficient statistics, {@code beta0} and hyper params
     * @param rand random source passed to the Gamma sampler
     */
    public void sample(Predictives<SplitContext, SplitContext> predictiveDistributions, Random rand) {
        for (int charIdx = 0; charIdx < this.alphaLevel1.size(); ++charIdx) {
            this.sampleGamma(predictiveDistributions, rand, charIdx);
        }
    }
}

