/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.test;

import fig.basic.Pair;
import hmm.Param;
import hmm.ParamUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Objects;
import java.util.Random;
import nuts.math.TreeSumProd;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;

/**
 * Sanity check for {@link ParticleFilter}: runs a particle filter over a small,
 * randomly parameterized HMM and prints its normalizer estimate next to the value
 * computed by {@link TreeSumProd}.
 */
public class TestParticleNormalization {
    /**
     * Builds a random HMM with a fixed seed, runs the particle filter on a fixed
     * observation sequence, and prints the approximate and exact normalizers.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Fixed seed so every run produces the same parameters and particles.
        Random rand = new Random(1L);
        // NOTE(review): presumably 10 hidden states and 10 emission symbols —
        // confirm against ParamUtils.randomUniParam.
        Param p = ParamUtils.randomUniParam(rand, 10, 10);
        System.out.println("Param:\n" + p);
        // All-zero observation sequence of length 20.
        int[] obs = new int[20];
        System.out.println("Observation:" + Arrays.toString(obs));
        HMMParticleKernel pk = new HMMParticleKernel(p, obs);
        ParticleFilter.DoNothingProcessor voidPro = new ParticleFilter.DoNothingProcessor();
        ParticleFilter<HMMPState> pf = new ParticleFilter<HMMPState>();
        pf.N = 100;
        pf.resamplingStrategy = ParticleFilter.ResamplingStrategy.ALWAYS;
        pf.resampleLastRound = false;
        pf.sample(pk, voidPro);
        // The kernel emits log-weights. NOTE(review): assumes estimateNormalizer()
        // is on the same (log) scale as TreeSumProd.logZ() — confirm.
        System.out.println("Approx=" + pf.estimateNormalizer());
        ArrayList<Integer> obsList = new ArrayList<Integer>(obs.length);
        for (int o : obs) {
            obsList.add(o);
        }
        TreeSumProd.HmmAdaptor adapt = new TreeSumProd.HmmAdaptor(p, obsList);
        TreeSumProd<Integer> tsp = new TreeSumProd<Integer>(adapt);
        System.out.println("Exact=" + tsp.logZ());
    }

    /**
     * Particle kernel that extends a hidden-state path by one step. It samples the
     * first state from the initial vector and later states from the transition
     * matrix. Each extension is weighted by the log emission probability of the
     * observation at the new time step.
     */
    public static class HMMParticleKernel
    implements ParticleKernel<HMMPState> {
        /** HMM parameters: initial vector, transition matrix and emission matrix. */
        public final Param params;
        /** Observation sequence; {@code obs[t]} is scored when a particle reaches time t. */
        public final int[] obs;
        /** Length of {@link #obs}, i.e. the number of extensions per particle. */
        public final int T;

        /**
         * Creates a kernel for the given HMM and observation sequence. The array is
         * stored by reference, not copied.
         *
         * @param params HMM parameters; must not be null
         * @param obs observation sequence; must not be null
         * @throws NullPointerException if {@code params} or {@code obs} is null
         */
        public HMMParticleKernel(Param params, int[] obs) {
            this.params = Objects.requireNonNull(params, "params");
            this.obs = Objects.requireNonNull(obs, "obs");
            this.T = obs.length;
        }

        /**
         * Returns the root state at {@code t = -1}, before any observation is scored.
         * Its {@code state} value is a placeholder: {@link #next} samples from the
         * initial vector whenever {@code t < 0}.
         */
        @Override
        public HMMPState getInitial() {
            return new HMMPState(-1, 0, null);
        }

        /**
         * Returns how many more extensions remain for the state.
         *
         * @param partialState a state produced by this kernel
         * @return the number of remaining extensions ({@code T} for the initial state)
         */
        @Override
        public int nIterationsLeft(HMMPState partialState) {
            return this.T - partialState.t - 1;
        }

        /**
         * Extends {@code current} by one time step.
         *
         * @param rand source of randomness for sampling the next hidden state
         * @param current state to extend; must satisfy {@code current.t < T - 1}
         * @return the child state paired with its log-weight, which is
         *     {@code log p(obs[t+1] | state)} and may be negative infinity if the
         *     emission probability is zero
         * @throws IllegalStateException if {@code current} is already at the last time step
         */
        @Override
        public Pair<HMMPState, Double> next(Random rand, HMMPState current) {
            final int nextT = current.t + 1;
            // The extension scores obs[nextT], so a state at t == T - 1 (or later)
            // cannot be extended. The old guard (t >= T) let t == T - 1 through and
            // then indexed obs[T], which is out of bounds.
            if (nextT >= this.T) {
                throw new IllegalStateException(
                    "Cannot extend particle at t=" + current.t + "; sequence length is " + this.T);
            }
            int nxt = current.t >= 0
                ? this.params.transMtx.nextState(current.state, rand)
                : this.params.initVec.nextState(rand);
            double logWeight = Math.log(this.params.emiMtx.p(nxt, this.obs[nextT]));
            return Pair.makePair(new HMMPState(nextT, nxt, current), logWeight);
        }
    }

    /**
     * Immutable node of a particle's hidden-state path. Following {@link #ancestor}
     * links walks back to the root created by {@link HMMParticleKernel#getInitial()}.
     */
    public static class HMMPState {
        /** Time index of this node; -1 for the root. */
        public final int t;
        /** Hidden state at time {@link #t}; a placeholder for the root. */
        public final int state;
        /** Previous node on the path, or null for the root. */
        public final HMMPState ancestor;

        /**
         * @param t time index (-1 for the root)
         * @param state hidden state at time {@code t}
         * @param ancestor previous node, or null for the root
         */
        public HMMPState(int t, int state, HMMPState ancestor) {
            this.t = t;
            this.state = state;
            this.ancestor = ancestor;
        }
    }
}

