/*
 * Decompiled with CFR 0.152.
 */
package unsupalg.hmm;

import fig.basic.NumUtils;
import fig.prob.SampleUtils;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.util.MathUtils;
import unsupalg.hmm.EStep;
import unsupalg.hmm.Param;

/**
 * {@link EStep} implementation that fills its posterior tables by repeatedly
 * resampling one hidden state at a time.
 *
 * <p>Each call to {@link #compute(List, Param)} resets all state, then performs
 * {@code maxIter} single-position updates. The position to update is drawn with
 * {@code rand.nextInt(length)}; the new state is drawn from a distribution
 * proportional to the emission probability of the observation at that position
 * multiplied by the transition probabilities from the previous sampled state
 * (or the initial-state probability at position 0) and to the next sampled state
 * (when there is one). Visit counts are only accumulated on iterations whose
 * index is strictly greater than {@code burnin}.
 *
 * <p>The class keeps mutable per-run state in fields and performs no
 * synchronisation of its own.
 */
public class GibbsEStep implements EStep {
    /** Number of resampling steps used by {@link #GibbsEStep(Random)}. */
    private static final int DEFAULT_MAX_ITER = 10000000;
    /** Burn-in threshold used by {@link #GibbsEStep(Random)}. */
    private static final int DEFAULT_BURNIN = 1000000;

    private final Random rand;
    /** Total number of single-position resampling steps per {@link #compute} call. */
    private final int maxIter;
    /** Counts are recorded only while {@code currentIter > burnin}. */
    private final int burnin;

    /** Empty until the first {@link #compute} call so the accessors are usable straight away. */
    private List<Integer> observations = Collections.emptyList();
    private int length;
    private Param param;
    /** {@code counts[t][s]}: how often state {@code s} was drawn at position {@code t} after burn-in. */
    private double[][] counts;
    /** {@code twoNodesTotalCount[t][i][j]}: recorded occurrences of states {@code (i, j)} at positions {@code (t, t + 1)}. */
    private double[][][] twoNodesTotalCount;
    /** Current state assignment, one entry per position; starts as all zeros. */
    private int[] sample;
    private int currentIter = 0;

    /**
     * Creates a sampler with the default budget of 10,000,000 steps and a
     * burn-in threshold of 1,000,000.
     *
     * @param rand source of randomness for choosing positions and drawing states
     */
    public GibbsEStep(Random rand) {
        this(rand, DEFAULT_MAX_ITER, DEFAULT_BURNIN);
    }

    /**
     * Creates a sampler with an explicit iteration budget.
     *
     * <p>Because counts are recorded only when the iteration index is strictly
     * greater than {@code burnin}, nothing is recorded at all unless
     * {@code burnin + 1 < maxIter}.
     *
     * @param rand    source of randomness for choosing positions and drawing states
     * @param maxIter total number of single-position resampling steps per run; must not be negative
     * @param burnin  iteration index up to and including which samples are discarded; must not be negative
     * @throws IllegalArgumentException if {@code maxIter} or {@code burnin} is negative
     */
    public GibbsEStep(Random rand, int maxIter, int burnin) {
        if (maxIter < 0) {
            throw new IllegalArgumentException("maxIter must be >= 0 but was " + maxIter);
        }
        if (burnin < 0) {
            throw new IllegalArgumentException("burnin must be >= 0 but was " + burnin);
        }
        this.rand = rand;
        this.maxIter = maxIter;
        this.burnin = burnin;
    }

    /**
     * Resets all tables for the given sequence and runs the sampler.
     *
     * <p>An empty observation list is accepted: the tables are allocated empty
     * and no sampling is performed.
     *
     * @param observations observed symbols, one per position
     * @param param        model parameters supplying the state count and the
     *                     initial, transition and emission probabilities
     */
    @Override
    public void compute(List<Integer> observations, Param param) {
        this.init(observations, param);
        this.currentIter = 0;
        if (this.length == 0) {
            // Random.nextInt(0) would throw, and there is nothing to resample anyway.
            return;
        }
        while (this.currentIter < this.maxIter) {
            this.resample(this.rand.nextInt(this.length));
            ++this.currentIter;
        }
    }

    /**
     * Returns a copy of the per-state visit counts at position {@code t} after
     * passing it through {@code NumUtils.normalize} (presumably scaling the
     * entries to sum to one; the helper is not visible from this file).
     *
     * @param t position in the observation sequence
     * @return a freshly allocated array of length {@code nStates}
     */
    @Override
    public double[] oneNodePosterior(int t) {
        double[] result = this.counts[t].clone();
        NumUtils.normalize(result);
        return result;
    }

    /**
     * Returns a copy of the recorded state-pair counts for positions
     * {@code (t, t + 1)} after passing it through {@code MathUtils.normalize}
     * (presumably scaling the whole table to sum to one; the helper is not
     * visible from this file).
     *
     * @param t index of the left position of the pair, in {@code [0, length - 2]}
     * @return a freshly allocated {@code nStates x nStates} table
     */
    @Override
    public double[][] twoNodesPosterior(int t) {
        int nStates = this.param.nStates();
        double[][] pairCounts = this.twoNodesTotalCount[t];
        double[][] result = new double[nStates][nStates];
        for (int i = 0; i < nStates; ++i) {
            for (int j = 0; j < nStates; ++j) {
                result[i][j] = pairCounts[i][j];
            }
        }
        MathUtils.normalize(result);
        return result;
    }

    /** Stores the inputs and allocates zeroed count tables and a zeroed state assignment. */
    private void init(List<Integer> observations, Param param) {
        this.observations = observations;
        this.length = observations.size();
        this.param = param;
        int nStates = param.nStates();
        this.counts = new double[this.length][nStates];
        // A sequence of n positions has n - 1 adjacent pairs; clamp so that an
        // empty sequence does not request a negative array size.
        this.twoNodesTotalCount = new double[Math.max(0, this.length - 1)][nStates][nStates];
        this.sample = new int[this.length];
    }

    /**
     * Redraws the state at position {@code t} given the currently sampled
     * neighbouring states and, once past burn-in, records the outcome.
     *
     * <p>When recording, the single-position count for {@code t} is incremented
     * and so are the pair counts for both adjacent pairs {@code (t - 1, t)} and
     * {@code (t, t + 1)} where they exist.
     *
     * @param t position to resample, in {@code [0, length - 1]}
     */
    public void resample(int t) {
        int nStates = this.param.nStates();
        int observed = this.obs(t);
        boolean isFirst = t == 0;
        boolean isLast = t == this.length - 1;

        double[] prs = new double[nStates];
        for (int s = 0; s < nStates; ++s) {
            double weight = this.param.emiMtx.p(s, observed);
            // Incoming factor: initial distribution at the first position,
            // otherwise the transition from the left neighbour's current state.
            if (isFirst) {
                weight *= this.param.initVec.p(s);
            } else {
                weight *= this.param.transMtx.p(this.sample[t - 1], s);
            }
            // Outgoing factor: transition into the right neighbour's current state.
            if (!isLast) {
                weight *= this.param.transMtx.p(s, this.sample[t + 1]);
            }
            prs[s] = weight;
        }
        NumUtils.normalize(prs);
        int drawn = SampleUtils.sampleMultinomial(this.rand, prs);
        this.sample[t] = drawn;

        if (this.currentIter > this.burnin) {
            this.counts[t][drawn] += 1.0;
            if (!isFirst) {
                this.twoNodesTotalCount[t - 1][this.sample[t - 1]][drawn] += 1.0;
            }
            if (!isLast) {
                this.twoNodesTotalCount[t][drawn][this.sample[t + 1]] += 1.0;
            }
        }
    }

    /**
     * Always returns {@code 0.0}; this implementation does not compute a
     * log-likelihood.
     */
    @Override
    public double logll() {
        return 0.0;
    }

    /** Returns the observed symbol at position {@code t}. */
    private int obs(int t) {
        return this.observations.get(t);
    }

    /**
     * Returns the length of the sequence passed to the most recent
     * {@link #compute} call, or {@code 0} if it has not been called yet.
     */
    @Override
    public int length() {
        return this.length;
    }

    /**
     * Returns a read-only view of the sequence passed to the most recent
     * {@link #compute} call, or an empty list if it has not been called yet.
     */
    @Override
    public List<Integer> observations() {
        return Collections.unmodifiableList(this.observations);
    }
}

