/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.test;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import gep.util.OutputManager;
import java.io.File;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import nuts.io.IO;
import nuts.tui.Table;
import nuts.util.CollUtils;
import nuts.util.Counter;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;

/**
 * Small particle-filter experiment driven by four input files.
 *
 * <p>{@link #load()} reads a rank assignment for every state, a forward and a
 * backward proposal matrix, and an unnormalized target weight (gamma) per
 * state. {@link #run()} then runs the particle filter {@code nReplicates}
 * times and, for each replicate, reports the total variation distance between
 * the particle approximation and the normalized gamma over the states of the
 * last rank, together with the variance of the approximation's entries.
 */
public class SmallExample
implements Runnable {
    /** File with one rank per non-empty line; each line is a comma-separated list of integer state ids. */
    @Option
    public File ranks;
    /** Tab-separated {@code from, to, value} triples of the forward proposal. */
    @Option
    public File qFwd;
    /** Tab-separated triples of the backward proposal; stored transposed, see {@link #loadTrans}. */
    @Option
    public File qBwd;
    /** Tab-separated {@code state, value} pairs giving the unnormalized target weight of each state. */
    @Option
    public File gammaFile;
    /** Value substituted for the symbolic entry {@code "p"} ({@code "q"} maps to {@code 1 - p}). */
    @Option
    public double p = 0.5;
    /** Number of independent particle-filter runs performed by {@link #run()}. */
    @Option
    public int nReplicates = 1000;
    /** When true the backward proposal is ignored and a factor of 1.0 is used in the weight update. */
    @Option
    public boolean noBack = false;
    private static final ParticleFilter<Integer> pf = new ParticleFilter<Integer>();
    /** Number of non-empty lines read from {@link #ranks}. */
    private int nRanks;
    /** Total number of state entries read from {@link #ranks}. */
    private int nPStates;
    private double[][] fwd;
    private double[][] bwd;
    /** State id -> rank (0-based line position in the ranks file). */
    private Map<Integer, Integer> i2r;
    /** Rank -> state id; only the last state listed for a rank is retained. */
    private Map<Integer, Integer> r2i;
    private double[] gamma;
    /** Gamma restricted to last-rank states and normalized, as a dense array. */
    private double[] trueDist;
    /** Same content as {@link #trueDist}, as a counter keyed by state id. */
    private Counter<Integer> trueDistC;

    /**
     * Loads the inputs, then runs {@code nReplicates} particle filters and
     * writes two statistics per replicate through an {@link OutputManager}.
     */
    @Override
    public void run() {
        load();
        LogInfo.logs("Data loaded");
        OutputManager outMan = new OutputManager();
        for (int replicate = 0; replicate < nReplicates; ++replicate) {
            // Raw type kept: the processor's type parameters are not visible from this file.
            ParticleFilter.ParticleMapperProcessor processor = ParticleFilter.ParticleMapperProcessor.saveParticlesProcessor();
            pf.sample(new SmallKernel(), processor);
            Counter<Integer> approx = processor.getCounter();
            approx.normalize();
            double tvd = SmallExample.totalVariationDist(trueDistC, approx);
            double iwv = impWeightVariance(approx);
            outMan.printWrite("variationalError", "replicate", replicate, "variationalError", tvd);
            outMan.printWrite("impWeightVariance", "replicate", replicate, "impWeightVariance", iwv);
        }
    }

    /**
     * Returns the population variance of the counter's entry values, i.e. the
     * mean squared deviation from {@code totalCount() / size()}.
     *
     * @param approx counter whose values are analysed
     * @return the variance; NaN when the counter has no entries (0/0)
     */
    private double impWeightVariance(Counter<Integer> approx) {
        double n = (double)approx.size();
        double mean = approx.totalCount() / n;
        double squaredDeviations = 0.0;
        for (double w : approx.entries.values()) {
            double deviation = w - mean;
            squaredDeviations += deviation * deviation;
        }
        return squaredDeviations / n;
    }

    /**
     * Computes half the sum of absolute differences between the counts of two
     * counters, taken over the union of their key sets.
     *
     * @param c1 first counter
     * @param c2 second counter
     * @return {@code 0.5 * sum_k |c1(k) - c2(k)|}
     */
    public static <T> double totalVariationDist(Counter<T> c1, Counter<T> c2) {
        double total = 0.0;
        for (Object key : CollUtils.union(c1.keySet(), c2.keySet())) {
            total += Math.abs(c1.getCount(key) - c2.getCount(key));
        }
        return total * 0.5;
    }

    /**
     * Reads the ranks, both proposal matrices and the gamma file, and derives
     * the reference distribution over the states of the last rank.
     *
     * <p>All derived state is rebuilt from scratch, so the method may be
     * called more than once.
     *
     * @throws IllegalArgumentException if the gamma file mentions a state that
     *         does not appear in the ranks file
     */
    public void load() {
        i2r = CollUtils.map();
        r2i = CollUtils.map();
        // Reset the counters: they are incremented below, so without this a
        // second load() would size the matrices from stale, accumulated values.
        nPStates = 0;
        nRanks = 0;
        int rank = 0;
        for (String line : IO.i(ranks)) {
            if (line.isEmpty()) {
                continue;
            }
            for (String field : line.split("[,]")) {
                ++nPStates;
                int state = Integer.parseInt(field);
                i2r.put(state, rank);
                r2i.put(rank, state);
            }
            nRanks = ++rank;
        }
        fwd = loadTrans(qFwd, true);
        bwd = loadTrans(qBwd, false);
        gamma = new double[nPStates];
        trueDist = new double[nPStates];
        trueDistC = new Counter<Integer>();
        int lastRank = nRanks - 1;
        for (String line : IO.i(gammaFile)) {
            if (line.isEmpty()) {
                continue;
            }
            String[] split = line.split("\\t");
            int index = Integer.parseInt(split[0]);
            double value = Double.parseDouble(split[1]);
            gamma[index] = value;
            Integer stateRank = i2r.get(index);
            if (stateRank == null) {
                // Previously this surfaced as a bare NullPointerException from auto-unboxing.
                throw new IllegalArgumentException("State " + index + " from gamma file " + gammaFile + " has no rank in " + ranks);
            }
            // Only states of the last rank contribute to the reference distribution.
            if (stateRank.intValue() != lastRank) {
                continue;
            }
            trueDistC.setCount(index, value);
            trueDist[index] = value;
        }
        NumUtils.normalize(trueDist);
        trueDistC.normalize();
        LogInfo.logs("Gamma:" + Arrays.toString(gamma));
        LogInfo.logs("True dist:" + Arrays.toString(trueDist));
    }

    /**
     * Reads a transition matrix from tab-separated {@code a, b, value} lines,
     * normalizes each row and logs the result.
     *
     * @param transFile file to read
     * @param isFwd when true the value is stored at {@code [a][b]}; when false
     *        it is stored transposed at {@code [b][a]}
     * @return an {@code nPStates x nPStates} matrix after row normalization
     */
    private double[][] loadTrans(File transFile, boolean isFwd) {
        double[][] trans = new double[nPStates][nPStates];
        for (String line : IO.i(transFile)) {
            if (line.isEmpty()) {
                continue;
            }
            String[] fields = line.split("\\t");
            int first = Integer.parseInt(fields[0]);
            int second = Integer.parseInt(fields[1]);
            double value = parse(fields[2]);
            if (isFwd) {
                trans[first][second] = value;
            } else {
                trans[second][first] = value;
            }
        }
        NumUtils.normalizeEachRow(trans);
        LogInfo.logs((isFwd ? "Fwd" : "Bwd") + ":\n" + Table.fromMatrix(trans, true, true, false));
        return trans;
    }

    /**
     * Parses a matrix entry: {@code "p"} yields {@link #p}, {@code "q"} yields
     * {@code 1 - p}, anything else is parsed as a double literal.
     *
     * @throws NumberFormatException if the text is neither symbolic nor numeric
     */
    private double parse(String string) {
        if (string.equals("p")) {
            return p;
        }
        if (string.equals("q")) {
            return 1.0 - p;
        }
        return Double.parseDouble(string);
    }

    /** Command-line entry point; registers this example and the particle filter (under "pf") with {@code IO.run}. */
    public static void main(String[] args) {
        IO.run(args, new SmallExample(), "pf", pf);
    }

    /**
     * Particle kernel over integer states that proposes from the forward
     * matrix and weights each move using gamma and the backward matrix.
     */
    public class SmallKernel
    implements ParticleKernel<Integer> {
        /**
         * Samples a successor of {@code current} from the forward row and
         * returns it with the log of
         * {@code gamma[next] * qMinus / gamma[current] / qPlus}, where
         * {@code qPlus = fwd[current][next]} and {@code qMinus} is
         * {@code bwd[next][current]} (or 1.0 when {@code noBack} is set).
         */
        @Override
        public Pair<Integer, Double> next(Random rand, Integer current) {
            int from = current;
            int next = SampleUtils.sampleMultinomial(rand, fwd[from]);
            double qPlus = fwd[from][next];
            double qMinus = noBack ? 1.0 : bwd[next][from];
            double newGamma = gamma[next];
            double oldGamma = gamma[from];
            // Evaluation order kept as in the original expression so floating-point results are unchanged.
            double ratio = newGamma * qMinus / oldGamma / qPlus;
            return Pair.makePair(next, Math.log(ratio));
        }

        /** Returns how many ranks lie between the given state's rank and the last rank. */
        @Override
        public int nIterationsLeft(Integer partialState) {
            int rank = i2r.get(partialState);
            return nRanks - rank - 1;
        }

        /** Every particle starts in state 0. */
        @Override
        public Integer getInitial() {
            return 0;
        }
    }
}

