/*
 * Decompiled with CFR 0.152.
 */
package gep.model;

import fig.basic.Pair;
import gep.model.BaseMeasure;
import gep.model.ContextHierarchy;
import gep.model.HyperParamProvider;
import gep.model.SplitBaseMeasure;
import gep.model.SplitContext;
import gep.model.SplitHierarchy;
import gep.model.SplitHyperParams;
import gep.model.SufficientStatistics;
import gep.util.TP;
import java.util.Random;
import nuts.util.Counter;
import nuts.util.CounterMap;

/**
 * Draws (state, waiting-time) pairs for a context by combining two sets of
 * sufficient statistics: a read-only set held by this object and a
 * read-write set supplied by the caller, which is updated with every draw.
 *
 * <p>State sampling backs off along a {@link ContextHierarchy}; when the
 * hierarchy has no further back-off context, the {@link BaseMeasure} is used.
 *
 * @param <C> context type
 * @param <S> state type
 */
public class Predictives<C, S> {
    /** Statistics that are consulted but never modified by this class. */
    public final SufficientStatistics<C, S> readOnlySuffStats;
    /** Supplies the back-off context for a given context ({@code null} at the root). */
    public final ContextHierarchy<C> contextHierarchy;
    /** Sampler used once the context hierarchy is exhausted. */
    public final BaseMeasure<C, S> baseMeasure;
    /** Supplies the per-context hyper-parameter used as the "new draw" weight. */
    public final HyperParamProvider<C> hyperParams;
    /** Constant added to the accumulated waiting time when sampling a time. */
    public final double beta0;
    /**
     * Shared empty counter returned by {@link #getReadOnly} for missing keys.
     * It is shared between all callers, so it must never be mutated.
     */
    private static final Counter<?> DUMMY = new Counter<Object>();

    /** Returns the hyper-parameter associated with {@code context}. */
    private double alpha(C context) {
        return this.hyperParams.getHyperParam(context);
    }

    /**
     * Convenience factory wiring up the {@code Split*} implementations of the
     * hierarchy, base measure and hyper-parameter provider.
     *
     * @param readOnlySuffStats statistics that are read but not modified
     * @param bmAlpha0 forwarded to {@link SplitBaseMeasure}
     * @param alphaLevel1 forwarded to {@link SplitHyperParams}
     * @param alphaLevel2 forwarded to {@link SplitHyperParams}
     * @param alphaLevel3 forwarded to {@link SplitHyperParams}
     * @param nSym forwarded to both {@link SplitBaseMeasure} and {@link SplitHyperParams}
     * @param beta0 constant added to accumulated waiting times
     * @param gamma0 forwarded to {@link SplitHyperParams}
     * @return a new {@code Predictives} over split contexts
     */
    public static Predictives<SplitContext, SplitContext> splitPredictives(SufficientStatistics<SplitContext, SplitContext> readOnlySuffStats, double bmAlpha0, double alphaLevel1, double alphaLevel2, double alphaLevel3, int nSym, double beta0, double gamma0) {
        return new Predictives<SplitContext, SplitContext>(readOnlySuffStats, new SplitHierarchy(), new SplitBaseMeasure(bmAlpha0, nSym), new SplitHyperParams(alphaLevel1, alphaLevel2, alphaLevel3, nSym, gamma0), beta0);
    }

    /**
     * Stores the given collaborators; no copies are made.
     *
     * @param readOnlySuffStats statistics that are read but not modified
     * @param contextHierarchy back-off structure over contexts
     * @param baseMeasure sampler used when no back-off context remains
     * @param hyperParmams per-context hyper-parameter provider
     * @param beta0 constant added to accumulated waiting times
     */
    public Predictives(SufficientStatistics<C, S> readOnlySuffStats, ContextHierarchy<C> contextHierarchy, BaseMeasure<C, S> baseMeasure, HyperParamProvider<C> hyperParmams, double beta0) {
        this.readOnlySuffStats = readOnlySuffStats;
        this.contextHierarchy = contextHierarchy;
        this.baseMeasure = baseMeasure;
        this.hyperParams = hyperParmams;
        this.beta0 = beta0;
    }

    /**
     * Samples a waiting time and then a state for {@code context}, updating
     * {@code readWriteSuffStat} with both draws.
     *
     * @param rand source of randomness
     * @param context context to sample in
     * @param readWriteSuffStat statistics updated with the new draws
     * @return the sampled state paired with the sampled waiting time
     */
    public Pair<S, Double> sample(Random rand, C context, SufficientStatistics<C, S> readWriteSuffStat) {
        // The time is drawn before the state; this order determines how the
        // random stream is consumed and which counts the time draw sees.
        double time = this.sampleTime(rand, context, readWriteSuffStat);
        S state = this.sampleState(rand, context, readWriteSuffStat);
        return Pair.makePair(state, time);
    }

    /**
     * Looks up the counter for {@code key} without calling
     * {@code getCounter} for a missing key.
     *
     * @param map counter map to query
     * @param key key to look up
     * @return the map's counter for {@code key}, or a shared empty counter
     *         when the map has none; the returned counter must not be mutated
     */
    @SuppressWarnings("unchecked")
    public static <C, S> Counter<S> getReadOnly(CounterMap<C, S> map, C key) {
        // Safe cast: DUMMY never holds entries, so its element type is irrelevant.
        return map.hasCounter(key) ? map.getCounter(key) : (Counter<S>) DUMMY;
    }

    /**
     * Samples a waiting time for {@code context} and adds it to the
     * read-write waiting-time statistics.
     *
     * <p>The first parameter passed to {@code TP.sample} is the total
     * transition count for the context (both statistics) plus the context's
     * hyper-parameter; the second is the accumulated waiting time for the
     * context (both statistics) plus {@link #beta0}.
     * NOTE(review): presumably these are Gamma shape/rate parameters -
     * confirm against {@code TP.sample}.
     *
     * @throws IllegalStateException if the drawn value is not strictly positive (including NaN)
     */
    private double sampleTime(Random rand, C context, SufficientStatistics<C, S> readWriteSuffStat) {
        Counter<S> currentCacheRW = readWriteSuffStat.getStateTransitionStats().getCounter(context);
        Counter<S> currentCacheR = Predictives.getReadOnly(this.readOnlySuffStats.getStateTransitionStats(), context);
        double alpha = currentCacheRW.totalCount() + currentCacheR.totalCount() + this.alpha(context);
        Counter<C> waitingTimeStatsRW = readWriteSuffStat.getWaitingTimeStats();
        Counter<C> waitingTimeStatsR = this.readOnlySuffStats.getWaitingTimeStats();
        double beta = waitingTimeStatsRW.getCount(context) + waitingTimeStatsR.getCount(context) + this.beta0;
        double sample = TP.sample(rand, alpha, beta);
        // Written as !(x > 0) so that NaN is rejected as well.
        if (!(sample > 0.0)) {
            throw new IllegalStateException("Sampled waiting time must be strictly positive but was " + sample + " (alpha=" + alpha + ", beta=" + beta + ", context=" + context + ")");
        }
        waitingTimeStatsRW.incrementCount(context, sample);
        return sample;
    }

    /**
     * Samples a state for {@code context} and increments its count in the
     * read-write transition statistics of that context.
     *
     * <p>With probability {@code n / (n + alpha)}, where {@code n} is the
     * total transition count for the context across both statistics and
     * {@code alpha} the context's hyper-parameter, the state is drawn from
     * the existing counts via {@link #sampleCache}. Otherwise it is drawn
     * from the back-off context, or from the base measure when there is none.
     *
     * @param rand source of randomness
     * @param context context to sample in
     * @param readWriteSuffStat statistics updated with the new draw
     * @return the sampled state
     */
    public S sampleState(Random rand, C context, SufficientStatistics<C, S> readWriteSuffStat) {
        Counter<S> currentCacheRW = readWriteSuffStat.getStateTransitionStats().getCounter(context);
        Counter<S> currentCacheR = Predictives.getReadOnly(this.readOnlySuffStats.getStateTransitionStats(), context);
        double nObsInCurContxt = currentCacheRW.totalCount() + currentCacheR.totalCount();
        double alpha = this.alpha(context);
        S result = rand.nextDouble() < nObsInCurContxt / (nObsInCurContxt + alpha) ? Predictives.sampleCache(rand, currentCacheRW, currentCacheR) : this.recurse(rand, context, readWriteSuffStat);
        // The draw is recorded at this level regardless of where it came from.
        currentCacheRW.incrementCount(result, 1.0);
        return result;
    }

    /**
     * Draws an item with probability proportional to its combined count in
     * the two counters. Items of {@code currentCacheRW} are visited first,
     * then those of {@code currentCacheR}.
     *
     * @param rand source of randomness
     * @param currentCacheRW first counter to draw from
     * @param currentCacheR second counter to draw from
     * @return the selected item
     * @throws IllegalStateException if no item can be selected, e.g. because
     *         both counters are empty or the total mass is NaN
     */
    public static <S> S sampleCache(Random rand, Counter<S> currentCacheRW, Counter<S> currentCacheR) {
        double totalMass = currentCacheR.totalCount() + currentCacheRW.totalCount();
        double u = rand.nextDouble() * totalMass;
        double sum = 0.0;
        S lastWithMass = null;
        boolean sawMass = false;
        for (S item : currentCacheRW) {
            double count = currentCacheRW.getCount(item);
            sum += count;
            if (sum >= u) {
                return item;
            }
            if (count > 0.0) {
                lastWithMass = item;
                sawMass = true;
            }
        }
        for (S item : currentCacheR) {
            double count = currentCacheR.getCount(item);
            sum += count;
            if (sum >= u) {
                return item;
            }
            if (count > 0.0) {
                lastWithMass = item;
                sawMass = true;
            }
        }
        // totalMass and the running sum add the same counts in a different
        // order, so rounding can leave sum marginally below u. In that case
        // the draw belongs to the last item that carried any mass.
        if (sawMass && !Double.isNaN(u)) {
            return lastWithMass;
        }
        throw new IllegalStateException("Cannot sample from cache: totalMass=" + totalMass + ", accumulated=" + sum + ", u=" + u);
    }

    /**
     * Samples a state one level up: from the back-off context of
     * {@code context} if the hierarchy provides one, otherwise from the base
     * measure.
     */
    private S recurse(Random rand, C context, SufficientStatistics<C, S> readWriteSuffStat) {
        C backoff = this.contextHierarchy.backoff(context);
        return backoff == null ? this.baseMeasure.sampleState(rand, readWriteSuffStat, this.readOnlySuffStats) : this.sampleState(rand, backoff, readWriteSuffStat);
    }
}

