/*
 * Decompiled with CFR 0.152.
 */
package smc;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import java.io.File;
import java.util.List;
import java.util.Random;
import nuts.io.IO;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.TreeSumProd;
import nuts.util.CollUtils;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;
import smc.DistributedSMC;
import smc.RandomGenerator;

/**
 * Command-line driver that repeatedly runs {@link DistributedSMC} on the HMM chain stored in
 * {@code data/question1B.gmf} and logs, per run and averaged over all runs, the statistics
 * reported by {@code getStats()} (max, mean) and {@code getNexts()}.
 */
public class TestDistributedSMC
implements Runnable {
    /** Passed on to the distributed SMC (not read directly by this class). */
    @Option(required=true)
    public DistributedSMC.ParticleOptimizationStyle particleOptimizationStyle;
    /** Total number of particles; split across {@link #numMachines}. Must be positive. */
    @Option(required=true)
    public int N;
    /** Number of machines the particles are split across. Must be positive. */
    @Option(required=true)
    public int numMachines;
    /** Number of independent runs whose statistics are averaged. Must be positive. */
    @Option(required=true)
    public int numRuns;
    /** Moments of the chain computed by {@link TreeSumProd#moments()} in {@link #run()}. */
    GMFct<Integer> moments;
    /** Number of vertices (time steps) in the loaded chain. */
    int T;

    public static void main(String[] args) {
        IO.run(args, new TestDistributedSMC());
    }

    /**
     * Loads the chain, computes its moments, then performs {@link #numRuns} runs and logs the
     * per-run and averaged statistics.
     *
     * @throws IllegalArgumentException if {@link #N}, {@link #numMachines} or {@link #numRuns}
     *     is not positive
     * @throws UnsupportedOperationException until the DistributedSMC construction in
     *     {@link #createDistributedSMC} is restored
     */
    @Override
    public void run() {
        validateOptions();
        File inputPath = new File("data", "question1B.gmf");
        GMFct<Integer> graphicalModel = GMFctUtils.readChain(inputPath);
        TreeSumProd<Integer> treeSumProduct = new TreeSumProd<Integer>(graphicalModel);
        this.moments = treeSumProduct.moments();
        this.T = graphicalModel.graph().vertexSet().size();
        HmmKernel hmmKernel = new HmmKernel(this.moments);
        int[][] machineConfig = buildMachineConfig(this.N, this.numMachines);
        int totalNext = 0;
        double totalMax = 0.0;
        double totalMean = 0.0;
        for (int i = 0; i < this.numRuns; ++i) {
            LogInfo.logs("Iteration=" + i);
            // A fresh processor per run so that estimates from different runs do not interleave
            // in the processor's accumulated lists.
            HMMProcessor processor = new HMMProcessor(this.N);
            DistributedSMC dSMC = createDistributedSMC(hmmKernel, processor, machineConfig);
            Pair<Double, Double> stat = dSMC.getStats();
            double max = stat.getFirst();
            double mean = stat.getSecond();
            int nexts = dSMC.getNexts();
            LogInfo.logs("max=" + max + ", mean=" + mean + ", next=" + nexts);
            totalMax += max;
            totalMean += mean;
            totalNext += nexts;
        }
        LogInfo.logs("Next=" + (double)totalNext / (double)this.numRuns);
        LogInfo.logs("Max=" + totalMax / (double)this.numRuns);
        LogInfo.logs("Mean=" + totalMean / (double)this.numRuns);
    }

    /**
     * Rejects option values that would divide by zero or produce meaningless averages.
     *
     * @throws IllegalArgumentException if {@link #N}, {@link #numMachines} or {@link #numRuns}
     *     is not positive
     */
    private void validateOptions() {
        if (this.N <= 0) {
            throw new IllegalArgumentException("N must be positive, got " + this.N);
        }
        if (this.numMachines <= 0) {
            throw new IllegalArgumentException("numMachines must be positive, got " + this.numMachines);
        }
        if (this.numRuns <= 0) {
            throw new IllegalArgumentException("numRuns must be positive, got " + this.numRuns);
        }
    }

    /**
     * Splits {@code totalParticles} as evenly as possible across {@code machines}. The first
     * {@code totalParticles % machines} machines receive one extra particle, so the per-machine
     * counts sum to exactly {@code totalParticles}. This matches the batch size of
     * {@link HMMProcessor}, which groups exactly N states per estimate.
     *
     * <p>Both columns of each row receive the same count, as in the original configuration.
     * NOTE(review): presumably both columns are per-machine particle counts consumed by
     * DistributedSMC; confirm against its constructor.
     *
     * @param totalParticles total number of particles, non-negative
     * @param machines number of machines, positive
     * @return a {@code machines x 2} configuration array
     */
    private static int[][] buildMachineConfig(int totalParticles, int machines) {
        int[][] config = new int[machines][2];
        int base = totalParticles / machines;
        int remainder = totalParticles % machines;
        for (int k = 0; k < machines; ++k) {
            int count = base + (k < remainder ? 1 : 0);
            config[k][0] = count;
            config[k][1] = count;
        }
        return config;
    }

    /**
     * Builds the {@link DistributedSMC} instance for one run.
     *
     * <p>NOTE(review): the decompiled original assigned {@code null} here and then dereferenced
     * it, so every run failed with a NullPointerException. The real constructor call was lost in
     * decompilation. Restore it here, presumably from the kernel, the processor, the machine
     * configuration and {@link #particleOptimizationStyle} (TODO: confirm against
     * DistributedSMC's constructor).
     *
     * @throws UnsupportedOperationException always, until the construction is restored
     */
    private DistributedSMC createDistributedSMC(HmmKernel kernel, HMMProcessor processor,
            int[][] machineConfig) {
        throw new UnsupportedOperationException(
            "DistributedSMC construction is missing (lost in decompilation); "
                + "restore it in TestDistributedSMC.createDistributedSMC");
    }

    /**
     * Particle kernel that extends a particle by one time step of the chain, using the moments
     * computed by {@link TreeSumProd}.
     */
    public static class HmmKernel
    implements ParticleKernel<HMMState> {
        GMFct<Integer> moments;

        public HmmKernel(GMFct<Integer> moments) {
            this.moments = moments;
        }

        /**
         * Samples the state at time {@code t} (the current particle's time, or 0 when
         * {@code current} is null).
         *
         * <p>At {@code t == 0} the unnormalized probabilities are {@code moments.get(0, i)}.
         * Otherwise they are {@code moments.get(t - 1, t, current.sample, i)}, presumably the
         * pairwise moment conditioned on the previous sample (confirm against GMFct).
         *
         * @return the new particle (time {@code t + 1}) paired with the weight
         *     {@code moments.get(t, sampled)}
         * @throws IllegalStateException wrapping any failure while reading the moments or
         *     sampling. Previously this exited the JVM with status 0, which reported success.
         */
        @Override
        public Pair<HMMState, Double> next(Random rand, HMMState current) {
            int t = current == null ? 0 : current.time;
            try {
                int numStates = this.moments.nStates(t);
                double[] probs = new double[numStates];
                double sum = 0.0;
                for (int i = 0; i < numStates; ++i) {
                    probs[i] = t == 0 ? this.moments.get(0, i) : this.moments.get(t - 1, t, current.sample, i);
                    sum += probs[i];
                }
                int sampled = RandomGenerator.discreteMultinomial(rand, probs, sum);
                HMMState newState = new HMMState(t + 1, sampled);
                double weight = this.moments.get(t, sampled);
                return new Pair<HMMState, Double>(newState, weight);
            }
            catch (Exception ex) {
                LogInfo.logs("time " + t);
                // Preserve the cause instead of silently exiting with a success status.
                throw new IllegalStateException("HmmKernel.next failed at time " + t, ex);
            }
        }

        /** Always returns 0. NOTE(review): looks like a stub; confirm what callers expect. */
        @Override
        public int nIterationsLeft(HMMState partialState) {
            return 0;
        }

        /** Returns the empty starting particle: time 0 with no sample. */
        @Override
        public HMMState getInitial() {
            return new HMMState(0, null);
        }
    }

    /**
     * A particle: the current time index and the last sampled state. {@code sample} is
     * {@code null} for the initial particle.
     */
    public static class HMMState {
        int time;
        Integer sample;

        public HMMState(int time, Integer sample) {
            this.time = time;
            this.sample = sample;
        }

        /** Returns the sample as a string, or {@code "null"} for the initial particle. */
        @Override
        public String toString() {
            return String.valueOf(this.sample);
        }

        /**
         * Two states are equal when their samples are equal. {@code time} is intentionally not
         * compared, matching the original behavior. Null-safe and returns false for non-HMMState
         * arguments.
         */
        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof HMMState)) {
                return false;
            }
            HMMState that = (HMMState)other;
            return this.sample == null ? that.sample == null : this.sample.equals(that.sample);
        }

        /** Consistent with {@link #equals(Object)}: hashes only {@code sample}. */
        @Override
        public int hashCode() {
            return this.sample == null ? 0 : this.sample.hashCode();
        }
    }

    /**
     * Collects processed particles in batches of {@code N}. After each full batch, it appends the
     * mean of the batch's samples to {@link #estimates}. Particle weights are ignored.
     */
    public static class HMMProcessor
    implements ParticleFilter.ParticleProcessor<HMMState> {
        public List<HMMState> states = CollUtils.list();
        public List<Double> estimates = CollUtils.list();
        public int N;

        public HMMProcessor(int N) {
            this.N = N;
        }

        @Override
        public void process(HMMState state, double weight) {
            this.states.add(state);
            if (this.states.size() == this.N) {
                this.computeEstimate();
                this.states = CollUtils.list();
            }
        }

        /** Returns the {@code t}-th completed batch estimate (0-based). */
        public double getEstimate(int t) {
            return this.estimates.get(t);
        }

        /** Appends the mean sample value of the current (full) batch to {@link #estimates}. */
        private void computeEstimate() {
            int total = 0;
            for (HMMState state : this.states) {
                total += state.sample.intValue();
            }
            // Autoboxing replaces the deprecated Double(double) constructor.
            this.estimates.add((double)total / (double)this.N);
        }
    }
}

