/*
 * Decompiled with CFR 0.152.
 */
package smc;

import fig.basic.LogInfo;
import java.io.File;
import java.util.List;
import java.util.TreeSet;
import nuts.io.IO;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.TreeSumProd;
import smc.Model;
import smc.RandomGenerator;
import smc.SMC;

/**
 * Sanity check for {@link SMC} on a discrete graphical model.
 *
 * <p>The constructor loads a model with {@code GMFctUtils.readChain} and computes its moments with
 * {@link TreeSumProd}. {@link #run()} then runs SMC with {@link HMMDiscrete} as the model. It logs
 * the per-generation particle mean next to the mean computed from those moments. Finally, it
 * traces the ancestry of the last generation's particles backwards and logs how many distinct
 * ancestors remain at each earlier generation.
 */
public class TestSMC2
implements Runnable {
    /** Directory of the model file used by the no-argument constructor. */
    private static final String DEFAULT_INPUT_DIR = "data";
    /** Name of the model file used by the no-argument constructor. */
    private static final String DEFAULT_INPUT_NAME = "question1B.gmf";
    /** Number of particles used by the no-argument constructor. */
    private static final int DEFAULT_NUM_PARTICLES = 1000;

    /** Moments computed by {@link TreeSumProd} for the loaded model. */
    GMFct<Integer> moments;
    /** Number of SMC generations; equals the number of vertices in the loaded graph. */
    int T;
    /** Number of particles per generation. */
    int N;

    public static void main(String[] args) {
        IO.run(args, new TestSMC2());
    }

    /** Loads {@code data/question1B.gmf} and uses 1000 particles. */
    public TestSMC2() {
        this(new File(DEFAULT_INPUT_DIR, DEFAULT_INPUT_NAME), DEFAULT_NUM_PARTICLES);
    }

    /**
     * Loads a model and prepares an SMC run over it.
     *
     * @param inputPath model file passed to {@code GMFctUtils.readChain}
     * @param numParticles number of particles per generation; must be positive
     * @throws IllegalArgumentException if {@code numParticles <= 0}
     */
    public TestSMC2(File inputPath, int numParticles) {
        if (numParticles <= 0) {
            throw new IllegalArgumentException("numParticles must be positive: " + numParticles);
        }
        GMFct<Integer> graphicalModel = GMFctUtils.readChain(inputPath);
        TreeSumProd<Integer> treeSumProduct = new TreeSumProd<Integer>(graphicalModel);
        this.moments = treeSumProduct.moments();
        this.T = graphicalModel.graph().vertexSet().size();
        this.N = numParticles;
    }

    /** Runs SMC once, then logs the mean comparison and the ancestry trace. */
    @Override
    public void run() {
        HMMDiscrete hmm = new HMMDiscrete(this.moments);
        LogInfo.logs("N=" + this.N);
        SMC<Integer> smc = SMC.runSMC(this.T, this.N, hmm);
        logMeanEstimates(smc);
        logAncestry(smc);
    }

    /** Logs, per generation, the particle average of the state next to sum_i i * moments(t, i). */
    private void logMeanEstimates(SMC<Integer> smc) {
        for (int t = 0; t < this.T; ++t) {
            double sum = 0.0;
            for (int n = 0; n < this.N; ++n) {
                int sample = smc.getParticle(t, n).sample;
                sum += sample;
            }
            LogInfo.logs("est=" + sum / (double) this.N + " actual=" + expectedState(t));
        }
    }

    /** Returns sum_i i * moments.get(t, i) over the states of node {@code t}. */
    private double expectedState(int t) {
        int numStates = this.moments.nStates(t);
        double mean = 0.0;
        for (int i = 0; i < numStates; ++i) {
            mean += (double) i * this.moments.get(t, i);
        }
        return mean;
    }

    /**
     * Follows parent links backwards from the last generation.
     *
     * <p>For each generation, logs the number of distinct ancestors and their ids. Stops after
     * logging generation 0 or the first generation with a single ancestor.
     */
    private void logAncestry(SMC<Integer> smc) {
        if (this.T <= 0) {
            return; // empty chain: there is no last generation to trace from
        }
        TreeSet<Integer> survivors = new TreeSet<Integer>();
        for (int n = 0; n < this.N; ++n) {
            survivors.add(smc.getParticle(this.T - 1, n).id);
        }
        for (int t = this.T - 1; ; --t) {
            StringBuilder ids = new StringBuilder();
            for (Integer survivor : survivors) {
                ids.append(survivor).append(' ');
            }
            LogInfo.logs("# of survivors at generation " + t + "=" + survivors.size());
            LogInfo.logs(ids.toString());
            // Log first, then stop, so the coalescence generation and generation 0 are reported.
            if (t == 0 || survivors.size() == 1) {
                break;
            }
            TreeSet<Integer> parents = new TreeSet<Integer>();
            for (Integer survivor : survivors) {
                // NOTE(review): the particle id is used as its index within generation t;
                // this assumes id == index for SMC particles -- confirm against SMC.
                parents.add(smc.getParticle(t, survivor).parent.id);
            }
            survivors = parents;
        }
    }

    /**
     * SMC model over integer states whose proposal and weights are read from a {@link GMFct} of
     * moments.
     */
    public static class HMMDiscrete
    implements Model<Integer> {
        /** Moments supplying both the proposal probabilities and the weights. */
        GMFct<Integer> moments;

        public HMMDiscrete(GMFct<Integer> moments) {
            this.moments = moments;
        }

        /**
         * Draws the state for generation {@code t}.
         *
         * <p>At {@code t == 0} the unnormalized probabilities are {@code moments.get(0, i)}.
         * Otherwise they are {@code moments.get(t - 1, t, xprev.get(t - 1), i)}. Either way the
         * sampler is given their sum as the normalizer.
         *
         * @param seed seed forwarded to {@code RandomGenerator.discreteMultinomial}
         * @param t generation index
         * @param xprev previously sampled states; presumably indexed by generation, so that
         *     {@code xprev.get(t - 1)} is the previous state -- confirm against SMC
         * @return the sampled state index
         */
        @Override
        public Integer generateSample(long seed, int t, List<Integer> xprev) {
            int numStates = this.moments.nStates(t);
            double[] probs = new double[numStates];
            double sum = 0.0;
            for (int i = 0; i < numStates; ++i) {
                probs[i] = t == 0 ? this.moments.get(0, i) : this.moments.get(t - 1, t, xprev.get(t - 1), i);
                sum += probs[i];
            }
            int sampled = RandomGenerator.discreteMultinomial(seed, probs, sum);
            return sampled;
        }

        /**
         * Returns {@code moments.get(t, xcurr)} as the particle weight.
         *
         * <p>NOTE(review): the weight ignores {@code xprev} and does not correct for the proposal
         * in {@link #generateSample}. Confirm this is the intended importance weight.
         */
        @Override
        public double evaluateWeight(int t, List<Integer> xprev, Integer xcurr) {
            double weight = this.moments.get(t, xcurr);
            return weight;
        }
    }
}

