/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.test;

import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import hmm.Param;
import hmm.ParamUtils;
import java.util.ArrayList;
import java.util.Random;
import nuts.math.Sampling;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.util.CollUtils;
import nuts.util.Counter;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;
import pty.smc.test.TestParticleNormalization;

/**
 * Empirical check of conditional SMC used as a block-resampling (Gibbs-style) move on a small HMM.
 *
 * <p>An HMM is generated with {@code ParamUtils.randomUniParam(rand, 2, 2)}, exact per-position
 * state moments are computed with {@link TreeSumProd}, and then a Markov chain over the hidden-state
 * sequence is run in which every step resamples a random window of consecutive positions with a
 * particle filter. With {@code breakIt == false} the filter is conditioned on the current states of
 * the window (conditional SMC); with {@code breakIt == true} that conditioning is skipped. The L1
 * distance between empirical and exact moments is printed for increasing chain lengths.
 */
public class TestConditionalSMC {
    /**
     * Builds the test model (20 observations, all equal to symbol 0) and prints the error of the
     * broken and the conditional sampler for chain lengths 10, 100, ..., 10^8.
     */
    public static void main(String[] args) {
        Random rand = new Random(1L);
        Param p = ParamUtils.randomUniParam(rand, 2, 2);
        int[] obs = new int[20];
        ArrayList<Integer> obsList = new ArrayList<Integer>();
        for (int o : obs) {
            obsList.add(o);
        }
        TreeSumProd.HmmAdaptor adapt = new TreeSumProd.HmmAdaptor(p, obsList);
        TreeSumProd<Integer> tsp = new TreeSumProd<Integer>(adapt);
        TabularGMFct<Integer> exactMoments = tsp.moments();
        // i = 10, 100, ..., 10^8; the next value (10^9) fails the bound, so the int never overflows.
        for (int i = 10; i < 1000000000; i *= 10) {
            System.out.println("I=" + i);
            TestConditionalSMC.smcApprox(i, p, obs, rand, exactMoments, true);
            TestConditionalSMC.smcApprox(i, p, obs, rand, exactMoments, false);
            System.out.println();
        }
    }

    /**
     * Runs {@code maxIter} window-resampling steps and prints the summed absolute difference between
     * the empirical state frequencies and {@code exactMoments}.
     *
     * @param maxIter number of resampling steps; also the number of samples averaged per position
     * @param p HMM parameters
     * @param obs observation sequence
     * @param rand source of randomness (shared with the caller)
     * @param exactMoments reference values, read as {@code exactMoments.get(position, state)}
     * @param breakIt when true, {@code setConditional} is not called, so each move is plain SMC
     *     instead of conditional SMC
     */
    private static void smcApprox(int maxIter, Param p, int[] obs, Random rand, TabularGMFct<Integer> exactMoments, boolean breakIt) {
        // Initial hidden states are drawn uniformly from the STATE space (not the symbol space).
        int[] cStates = new int[obs.length];
        for (int i = 0; i < obs.length; ++i) {
            cStates[i] = rand.nextInt(p.nStates());
        }
        ParticleFilter<Object> pf = new ParticleFilter<Object>();
        pf.N = 10;
        // Number of consecutive positions resampled per step. The right-neighbour weight
        // correction below reads cStates[right - 2], so this is assumed to be >= 2.
        int winSize = 2;
        Counter<String> stats = new Counter<String>();
        HMMConditionalParticleKernel kernel = new HMMConditionalParticleKernel(p, obs, cStates);
        for (int iter = 0; iter < maxIter; ++iter) {
            // Window [left, right) of winSize positions lying entirely inside the sequence.
            kernel.left = rand.nextInt(obs.length - winSize + 1);
            kernel.right = kernel.left + winSize;
            if (!breakIt) {
                // Express the current window states as a particle path hanging off the same root
                // the kernel uses, with the incremental weights the kernel would give that path.
                // NOTE(review): uw holds plain (non-log) weights, as in the original code, while
                // the kernel returns log weights -- confirm this matches setConditional's contract.
                ArrayList currentCStates = CollUtils.list();
                double[] uw = new double[winSize];
                TestParticleNormalization.HMMPState prev = kernel.getInitial();
                for (int t = kernel.left; t < kernel.right; ++t) {
                    int curCState = cStates[t];
                    TestParticleNormalization.HMMPState curHMMPS = new TestParticleNormalization.HMMPState(t, curCState, prev);
                    currentCStates.add(curHMMPS);
                    uw[t - kernel.left] = p.emiMtx.p(curCState, obs[t]);
                    prev = curHMMPS;
                }
                // The kernel's last step uses a proposal tilted towards the fixed right neighbour,
                // whose normaliser enters the weight; apply the same factor to the fixed path.
                if (kernel.right < cStates.length) {
                    uw[winSize - 1] *= kernel.rightNormalizer(cStates[kernel.right - 2]);
                }
                pf.setConditional(currentCStates, uw);
            }
            ParticleFilter.StoreProcessor pro = new ParticleFilter.StoreProcessor();
            pf.sample(kernel, pro);
            // Draw one final particle index from pro.ws and write its path back into cStates.
            int index = Sampling.sample(rand, pro.ws);
            TestParticleNormalization.HMMPState sampled = (TestParticleNormalization.HMMPState)pro.particles.get(index);
            // The root (position left - 1) has no ancestor and lies outside the window; it is skipped.
            while (sampled.ancestor != null) {
                cStates[sampled.t] = sampled.state;
                sampled = sampled.ancestor;
            }
            for (int t = 0; t < obs.length; ++t) {
                stats.incrementCount("Z_" + t + "=" + cStates[t], 1.0);
            }
        }
        double error = 0.0;
        for (int t = 0; t < obs.length; ++t) {
            for (int s = 0; s < p.nStates(); ++s) {
                double delta = Math.abs(exactMoments.get(t, s) - stats.getCount("Z_" + t + "=" + s) / (double)maxIter);
                error += delta;
            }
        }
        System.out.println("Break=" + breakIt + ",error=" + error);
    }

    /**
     * Particle kernel that resamples the hidden states at positions {@code [left, right)} while all
     * other positions stay fixed at the values in {@code cStates}.
     *
     * <p>Particles start at position {@code left - 1} carrying the fixed state there (-1 when the
     * window starts at 0). Each step proposes the next state from the transition matrix (or from
     * the initial distribution when there is no predecessor) and is weighted by the log emission
     * probability. The last step of the window, when a fixed right neighbour {@code cStates[right]}
     * exists, instead proposes from {@code T(prev, s) * T(s, cStates[right])} normalised, and its log
     * weight additionally carries the log of that normaliser.
     */
    public static class HMMConditionalParticleKernel
    implements ParticleKernel<TestParticleNormalization.HMMPState> {
        public final Param params;
        public final int[] obs;
        /** First position (inclusive) of the window being resampled. */
        public int left;
        /** End position (exclusive) of the window being resampled. */
        public int right;
        /** Current full hidden-state sequence; only read by this kernel. */
        public final int[] cStates;

        public HMMConditionalParticleKernel(Param params, int[] obs, int[] cStates) {
            this.params = params;
            this.obs = obs;
            this.cStates = cStates;
        }

        /**
         * Returns the root particle at position {@code left - 1}. It carries the fixed state
         * immediately left of the window so the first transition is conditioned on it, or -1 when
         * the window starts at position 0.
         */
        @Override
        public TestParticleNormalization.HMMPState getInitial() {
            int leftNeighbour = this.left > 0 ? this.cStates[this.left - 1] : -1;
            return new TestParticleNormalization.HMMPState(this.left - 1, leftNeighbour, null);
        }

        /** Number of positions still to be generated before the particle reaches {@code right - 1}. */
        @Override
        public int nIterationsLeft(TestParticleNormalization.HMMPState partialState) {
            return this.right - partialState.t - 1;
        }

        /**
         * Extends {@code current} by one position and returns the new particle with its incremental
         * log weight.
         *
         * @throws IllegalStateException if the new position would fall outside {@code [left, right)}
         */
        @Override
        public Pair<TestParticleNormalization.HMMPState, Double> next(Random rand, TestParticleNormalization.HMMPState current) {
            int nextT = current.t + 1;
            if (nextT >= this.right) {
                throw new IllegalStateException("Position " + nextT + " is outside the window [" + this.left + ", " + this.right + ")");
            }
            int nxt;
            double logCorrection = 0.0;
            if (nextT == this.right - 1 && this.right < this.cStates.length) {
                // Last window position with a fixed right neighbour: propose s with probability
                // proportional to T(prev, s) * T(s, right) and add log(normaliser) to the weight,
                // which keeps the importance weight correct for this tilted proposal.
                // Assumes a predecessor state exists (current.state != -1), i.e. winSize >= 2 or left > 0.
                double[] prs = this.twistedWeights(current.state);
                double norm = 0.0;
                for (double v : prs) {
                    norm += v;
                }
                NumUtils.normalize(prs);
                nxt = SampleUtils.sampleMultinomial(rand, prs);
                logCorrection = Math.log(norm);
            } else {
                nxt = current.t >= 0 ? this.params.transMtx.nextState(current.state, rand) : this.params.initVec.nextState(rand);
            }
            double w = Math.log(this.params.emiMtx.p(nxt, this.obs[nextT])) + logCorrection;
            return Pair.makePair(new TestParticleNormalization.HMMPState(nextT, nxt, current), w);
        }

        /**
         * Returns {@code sum_s T(prevState, s) * T(s, cStates[right])}. This is the normaliser of the
         * tilted proposal used for the last window position; only meaningful when
         * {@code right < cStates.length}.
         */
        public double rightNormalizer(int prevState) {
            double total = 0.0;
            for (double v : this.twistedWeights(prevState)) {
                total += v;
            }
            return total;
        }

        /** Unnormalised tilted proposal weights {@code T(prevState, s) * T(s, cStates[right])} over all s. */
        private double[] twistedWeights(int prevState) {
            int rightSym = this.cStates[this.right];
            double[] weights = new double[this.params.nStates()];
            for (int s = 0; s < weights.length; ++s) {
                weights[s] = this.params.transMtx.p(prevState, s) * this.params.transMtx.p(s, rightSym);
            }
            return weights;
        }
    }
}

