/*
 * Decompiled with CFR 0.152.
 */
package gep.pmcmc;

import fig.basic.LogInfo;
import fig.basic.Option;
import gep.GEPMain;
import gep.comparisons.IncrementalReconstructionMethod;
import gep.model.Predictives;
import gep.model.SplitContext;
import gep.model.SplitHyperParams;
import gep.model.SufficientStatisticsImpl;
import gep.pmcmc.Generator;
import gep.pmcmc.PartialHiddenState;
import gep.pmcmc.SMCKernel;
import gep.timeseries.Series;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.math.MeasureZeroException;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.CounterMap;
import nuts.util.MathUtils;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.smc.ParticleFilter;

/**
 * Incremental reconstruction method that resamples the hidden states of one
 * series at a time with a shared sequential Monte Carlo sampler ({@link #pf}).
 * Each proposal is accepted or rejected by comparing the sampler's log
 * normalizer estimate with the value stored for the currently accepted state
 * (a particle MCMC scheme).
 */
public class PMCMC
implements IncrementalReconstructionMethod {
    /** Initial base-measure concentration passed to {@code Predictives.splitPredictives} in {@link #loadData}. */
    @Option
    public double initBaseMeasureAlpha0 = 1.0;
    /** Initial level-1 concentration passed to {@code Predictives.splitPredictives} in {@link #loadData}. */
    @Option
    public double initAlphaL1 = 1.0;
    /** Initial level-2 concentration passed to {@code Predictives.splitPredictives} in {@link #loadData}. */
    @Option
    public double initAlphaL2 = 1.0;
    /** Initial level-3 concentration passed to {@code Predictives.splitPredictives} in {@link #loadData}. */
    @Option
    public double initAlphaL3 = 1.0;
    /** Initial beta passed to {@code Predictives.splitPredictives} in {@link #loadData}. */
    @Option
    public double initBeta0 = 1.0;
    /** Initial gamma passed to {@code Predictives.splitPredictives} in {@link #loadData}. */
    @Option
    public double initGamma0 = 1.0;
    /** When true, {@code Generator.generateHitting} runs after every per-series update in {@link #iterate}. */
    @Option
    public boolean generateHittingTimes = false;
    /** When true, gammas are resampled during {@link #iterate} for every iteration with a positive index. */
    @Option
    public boolean sampleGamma = false;
    /**
     * Passed to {@link #sampleSequence(Random, int, boolean)} as
     * {@code commitSuffStats} during {@link #init}. When false, all series'
     * statistics are added to the shared statistics once initialization ends.
     */
    @Option
    public boolean useIncrementalInit = true;
    /** Number of extra attempts per series during {@link #init}; each attempt doubles the particle count again. */
    @Option
    public int nInitRetries = 10;
    /**
     * Shared particle filter used for every proposal. It is {@code null} by
     * default and must be assigned before {@link #init} or {@link #iterate}
     * is called.
     */
    public static ParticleFilter<PartialHiddenState<SplitContext>> pf = null;
    private List<Series> series;
    private Predictives<SplitContext, SplitContext> predictiveDistributions;
    /** Per-sweep acceptance ratios. Null until {@link #resetAcceptRateStats()} runs; nothing is recorded while null. */
    public SummaryStatistics acceptanceRate = null;
    /** Per-sweep proposal failures (1.0 = failure, 0.0 = success). Null until {@link #resetAcceptRateStats()} runs. */
    public SummaryStatistics totalPropFailure = null;

    /**
     * Returns the most probable answer for one query of one series.
     *
     * <p>The posterior counter's total mass is checked to be close to 1.0 via
     * {@code MathUtils.checkClose}. NOTE(review): presumably that helper fails
     * loudly when the check does not hold; confirm.
     *
     * @param sIdx index of the series
     * @param qIdx index of the query within that series
     * @return the key with the largest posterior count
     */
    @Override
    public int reconstruct(int sIdx, int qIdx) {
        CounterMap<Integer, Integer> post = this.series.get(sIdx).getQueryPosteriors();
        Counter<Integer> curPost = post.getCounter(qIdx);
        MathUtils.checkClose(curPost.totalCount(), 1.0);
        return curPost.argMax();
    }

    /**
     * Initializes every series using the configured {@link #nInitRetries} and
     * {@link #useIncrementalInit} options.
     *
     * @param samplingRand source of randomness for initialization
     */
    @Override
    public void init(Random samplingRand) {
        this.initSequence(samplingRand, this.nInitRetries, this.useIncrementalInit);
    }

    /**
     * Performs one sweep and then records its diagnostics.
     *
     * <p>Every series is resampled once, in an order shuffled with
     * {@code samplingRand}. After the sweep, the current gammas are logged,
     * the failure rate, acceptance rate and number of hidden states are
     * written out, and all query statistics are refreshed.
     *
     * @param samplingRand source of randomness for the sweep
     * @param iterIndex index of this iteration; gammas are only resampled when it is positive
     */
    @Override
    public void iterate(Random samplingRand, int iterIndex) {
        this.resetAcceptRateStats();
        List<Integer> order = CollUtils.ints(this.nSeries());
        Collections.shuffle(order, samplingRand);
        // Loop-invariant: whether this sweep resamples the gamma hyperparameters.
        boolean resampleGammas = iterIndex > 0 && this.sampleGamma;
        for (int sIdx : order) {
            this.sampleSequence(samplingRand, sIdx);
            if (this.generateHittingTimes) {
                Generator.generateHitting(this.predictiveDistributions.readOnlySuffStats, this.predictiveDistributions, samplingRand, 5000.0, 10000, SplitContext.BEG);
            }
            // Gammas are resampled after each individual series update, not once per sweep.
            if (resampleGammas) {
                this.sampleGammas(samplingRand);
            }
        }
        LogInfo.logsForce("gammas:" + this.getGammas());
        GEPMain.outputManager.printWrite("totalFailRate", "Iter", iterIndex, "FailRate", this.totalPropFailure.getMean());
        GEPMain.outputManager.printWrite("acceptRate", "Iter", iterIndex, "AcceptRate", this.acceptanceRate.getMean());
        GEPMain.outputManager.printWrite("nHiddenStates", "Iter", iterIndex, "NumHiddenStates", this.nHiddenStates());
        this.updateAllQueryStats();
    }

    /**
     * Builds split predictive distributions and stores the series to reconstruct.
     *
     * <p>The predictives start from empty sufficient statistics and use the
     * {@code init*} options as hyperparameters.
     *
     * @param series the series to reconstruct; stored by reference, not copied
     * @param nChars character-set size passed to the predictives
     */
    @Override
    public void loadData(List<Series> series, int nChars) {
        this.predictiveDistributions = Predictives.splitPredictives(new SufficientStatisticsImpl<SplitContext, SplitContext>(), this.initBaseMeasureAlpha0, this.initAlphaL1, this.initAlphaL2, this.initAlphaL3, nChars, this.initBeta0, this.initGamma0);
        this.series = series;
    }

    /** @return the number of hidden states in the shared sufficient statistics */
    public int nHiddenStates() {
        return SufficientStatisticsImpl.nHiddenStates(this.predictiveDistributions.readOnlySuffStats);
    }

    /** Refreshes the query statistics of every loaded series. */
    public void updateAllQueryStats() {
        for (Series s : this.series) {
            s.updateQueryStats();
        }
    }

    /** @return the number of loaded series */
    public int nSeries() {
        return this.series.size();
    }

    /**
     * Draws an accepted initial state for every series, in index order.
     *
     * <p>Each series gets up to {@code nInitRetries + 1} attempts. Attempt
     * {@code k} (0-based) temporarily multiplies {@code pf.N} by {@code 2^k},
     * and the original value is restored after the attempt even if it throws.
     *
     * @param rand source of randomness
     * @param nInitRetries number of extra attempts per series after the first
     * @param useIncrementalInit if true, each accepted state is committed to the
     *        shared statistics immediately; otherwise all series' statistics are
     *        added once every series has been initialized
     * @throws RuntimeException if some series is not accepted within the allowed attempts
     */
    public void initSequence(Random rand, int nInitRetries, boolean useIncrementalInit) {
        final int maxAttempts = nInitRetries + 1;
        for (int seqnIndex = 0; seqnIndex < this.series.size(); ++seqnIndex) {
            boolean accepted = false;
            for (int retry = 0; retry < maxAttempts && !accepted; ++retry) {
                int multiple = (int) Math.pow(2.0, retry);
                // NOTE(review): restoring by division assumes pf.N * multiple does not overflow.
                pf.N *= multiple;
                try {
                    accepted = this.sampleSequence(rand, seqnIndex, useIncrementalInit);
                } finally {
                    pf.N /= multiple;
                }
            }
            if (!accepted) {
                throw new RuntimeException("Failed to initialize sequence " + seqnIndex + " after " + maxAttempts + " attempt(s)");
            }
        }
        if (!useIncrementalInit) {
            // Statistics were not committed while sampling, so add them all now.
            for (Series s : this.series) {
                SufficientStatisticsImpl.plusEqual(this.predictiveDistributions.readOnlySuffStats, s.suffStats);
            }
        }
    }

    /**
     * Resamples one series, committing its statistics to the shared statistics.
     *
     * @param rand source of randomness
     * @param seqnIndex index of the series to resample
     * @return whether the proposal was accepted
     */
    public boolean sampleSequence(Random rand, int seqnIndex) {
        return this.sampleSequence(rand, seqnIndex, true);
    }

    /**
     * Proposes new hidden states for one series and accepts or rejects them.
     *
     * <p>The proposal is one particle drawn after running {@link #pf} on the
     * series' observations. It is accepted with probability
     * {@code min(1, exp(newLogLL - previousLogLL))}, where {@code newLogLL} is
     * {@code pf.estimateNormalizer()}.
     *
     * <p>On acceptance, the series' sufficient statistics, stored
     * log-likelihood, hidden variables and {@code initialized} flag are
     * replaced. A {@link MeasureZeroException} from the sampler is recorded as
     * a failed proposal and rejected.
     *
     * @param rand source of randomness for the particle draw and the accept test
     * @param seqnIndex index of the series to resample
     * @param commitSuffStats if true, the series' statistics are removed from the
     *        shared statistics before sampling. Its current statistics (the new
     *        ones if accepted) are added back afterwards, even if sampling throws.
     * @return whether the proposal was accepted
     */
    public boolean sampleSequence(Random rand, int seqnIndex, boolean commitSuffStats) {
        Series current = this.series.get(seqnIndex);
        if (commitSuffStats) {
            // Exclude this series so the proposal conditions only on the others' statistics.
            SufficientStatisticsImpl.minusEqualCheckNonNeg(this.predictiveDistributions.readOnlySuffStats, current.suffStats);
        }
        SMCKernel<SplitContext> kernel = new SMCKernel<SplitContext>(this.predictiveDistributions, current.observations, SplitContext.BEG);
        // NOTE(review): raw StoreProcessor / PartialHiddenState are decompiler residue;
        // presumably parameterizable with PartialHiddenState<SplitContext> — confirm.
        ParticleFilter.StoreProcessor processor = new ParticleFilter.StoreProcessor();
        boolean accept = false;
        double ratio = 0.0;
        try {
            pf.sample(kernel, processor);
            PartialHiddenState sampled = (PartialHiddenState) processor.sample(rand);
            double newLogLL = pf.estimateNormalizer();
            ratio = Math.min(1.0, Math.exp(newLogLL - current.previousLogLL));
            accept = rand.nextDouble() < ratio;
            if (accept) {
                current.suffStats = sampled.suffStats();
                current.previousLogLL = newLogLL;
                current.currentHiddenVariables = sampled.eventsList();
                current.initialized = true;
            }
            if (this.totalPropFailure != null) {
                this.totalPropFailure.addValue(0.0);
            }
        } catch (MeasureZeroException mze) {
            // The sampler could not produce a proposal: count it as a failure; ratio stays 0.
            if (this.totalPropFailure != null) {
                this.totalPropFailure.addValue(1.0);
            }
        } finally {
            if (this.acceptanceRate != null) {
                this.acceptanceRate.addValue(ratio);
            }
            if (commitSuffStats) {
                // Re-add whichever statistics the series now holds (new if accepted, old otherwise).
                SufficientStatisticsImpl.plusEqual(this.predictiveDistributions.readOnlySuffStats, current.suffStats);
            }
        }
        return accept;
    }

    /**
     * Resamples the gamma hyperparameters given the current predictives.
     *
     * @param rand source of randomness
     */
    public void sampleGammas(Random rand) {
        ((SplitHyperParams) this.predictiveDistributions.hyperParams).sample(this.predictiveDistributions, rand);
    }

    /** @return the current gamma hyperparameters */
    public List<Double> getGammas() {
        return ((SplitHyperParams) this.predictiveDistributions.hyperParams).getGammas();
    }

    /** Starts fresh failure-rate and acceptance-rate accumulators. */
    public void resetAcceptRateStats() {
        this.totalPropFailure = new SummaryStatistics();
        this.acceptanceRate = new SummaryStatistics();
    }
}

