/*
 * Decompiled with CFR 0.152.
 */
package cs281;

import fig.basic.NumUtils;
import fig.prob.Multinomial;
import fig.prob.SampleUtils;
import java.util.Arrays;
import java.util.Random;
import nuts.lang.ArrayUtils;
import nuts.util.MathUtils;

public class PoiHMMEM {
    /** Shared random source for every sampling call in this file; seeded with 1 so runs repeat. */
    public static Random rand = new Random(1L);
    /** Sequence length (number of time steps) used when sampling and sizing posterior tables. */
    private static int length = 50;
    /** Number of hidden (backbone) states; assigned by main/testMStep before any E- or M-step. */
    private static int m;
    /** Number of mixture components per hidden state; assigned by main/testMStep before use. */
    private static int k;
    /** Parameters of the current EM iterate. */
    private Param currentParam;
    /** Expected sufficient statistics produced by the most recent E-step. */
    private SuffStat currentPost;
    /** Observation sequence the model is being fit to. */
    private int[] observations;
    /** Generating ("true") parameters; MStep also reads its pi when re-estimating. */
    public static Param trueParam;
    /** Training trajectory drawn from trueParam. */
    public static State sample;
    /** Held-out trajectory drawn from trueParam, used to report the evaluation log-likelihood. */
    public static State eval;

    /**
     * Creates an EM runner.
     *
     * @param init           parameters EM starts from (stored by reference)
     * @param inObservations observation sequence to fit (stored by reference)
     */
    public PoiHMMEM(Param init, int[] inObservations) {
        observations = inObservations;
        currentParam = init;
    }

    /**
     * Runs {@code iterations} EM iterations, printing a progress line before each one.
     *
     * @param iterations number of E-step/M-step rounds to perform
     */
    public void compute(int iterations) {
        int iteration = 0;
        while (iteration < iterations) {
            System.out.println("Iteration " + iteration + " is starting.");
            this.compute();
            iteration++;
        }
    }

    /**
     * Performs one EM iteration.
     *
     * <p>The iteration runs the exact E-step on the training sequence and reports the training and
     * held-out log-likelihood under the current parameters. It then replaces the parameters with
     * the M-step estimate.
     */
    private void compute() {
        EStep trainStep = new EStep(this.currentParam, this.observations);
        this.currentPost = trainStep.compute();
        System.out.println("Current LLL: " + trainStep.logll());
        // Held-out likelihood is reported only; its posteriors are not used for the update.
        EStep evalStep = new EStep(this.currentParam, eval.observations());
        evalStep.compute();
        System.out.println("Evaluation LLL: " + evalStep.logll());
        // Named mStep so it does not shadow the static state count m.
        MStep mStep = new MStep(this.currentPost);
        this.currentParam = mStep.compute();
    }

    /**
     * Single-state starting model with eight uniformly weighted mixture components.
     * The components' Poisson rates range from 1 to 70.
     */
    public static Param mixInitModel() {
        double[][] transitions = ArrayUtils.parseMtx("1.0");
        double[][] weights = ArrayUtils.parseMtx("0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125");
        double[] start = {1.0};
        double[][] rates = ArrayUtils.parseMtx("1 10 20 30 40 50 60 70");
        return new Param(transitions, weights, start, rates);
    }

    /**
     * Two-state toy model that starts in state 0 and then strictly alternates between the states.
     * Each state has a single component, with Poisson rates 1 and 10.
     */
    public static Param toyModel() {
        return new Param(
                ArrayUtils.parseMtx("0.0 1.0;1.0 0.0"),
                ArrayUtils.parseMtx("1.0;1.0"),
                new double[] {1.0, 0.0},
                ArrayUtils.parseMtx("1;10"));
    }

    /**
     * Starting point for fitting {@link #toyModel()}.
     * It perturbs the transitions and rates but keeps the true mixture weights and start distribution.
     */
    public static Param toyModelInit() {
        double[][] guessTransitions = ArrayUtils.parseMtx("0.6 0.4;0.7 0.3");
        double[][] guessRates = ArrayUtils.parseMtx("2;1");
        return new Param(
                guessTransitions,
                ArrayUtils.parseMtx("1.0;1.0"),
                new double[] {1.0, 0.0},
                guessRates);
    }

    /**
     * Three-state generating model. Each state moves on to the next (0 -> 1 -> 2 -> 0) with
     * probability 0.9 and emits from two equally weighted Poisson components.
     */
    public static Param bigModel() {
        double[][] transitions =
                ArrayUtils.parseMtx("0.05    0.9     0.05;0.05    0.05     0.9;0.9    0.05     0.05");
        double[][] weights = ArrayUtils.parseMtx("0.5    0.5;0.5    0.5;0.5    0.5");
        double[] start = ArrayUtils.parseMtx("0.33333 0.33333 0.33333")[0];
        double[][] rates = ArrayUtils.parseMtx(".5      17;35     125;185     255");
        return new Param(transitions, weights, start, rates);
    }

    /**
     * Starting point for fitting {@link #bigModel()}.
     * It uses uniform transitions and start distribution, equal mixture weights, and rough initial rate guesses.
     */
    public static Param bigInitModel() {
        return new Param(
                ArrayUtils.parseMtx("0.33333  0.33333     0.33333;0.33333  0.33333     0.33333;0.33333  0.33333     0.33333"),
                ArrayUtils.parseMtx("0.5    0.5;0.5    0.5;0.5    0.5"),
                ArrayUtils.parseMtx("0.33333 0.33333 0.33333")[0],
                ArrayUtils.parseMtx("1      5;50     100;200    300"));
    }

    /**
     * Experiment driver.
     *
     * <p>It samples a training and a held-out trajectory from {@link #bigModel()} and runs 10 EM
     * iterations starting from {@link #bigInitModel()}. It then prints the initial, true and
     * final parameters.
     */
    public static void main(String[] args) {
        trueParam = PoiHMMEM.bigModel();
        // Take the dimensions from the generating model instead of hard-coding them, so that
        // switching models cannot leave m/k inconsistent with the parameter arrays.
        m = trueParam.tau.length;
        k = trueParam.rho[0].length;
        // Training sample first, then held-out sample: the order fixes which rand draws each gets.
        sample = PoiHMMEM.randomState(trueParam, length);
        eval = PoiHMMEM.randomState(trueParam, length);
        System.out.println(sample.toString());
        System.out.println("---");
        System.out.println(eval.toString());
        System.out.println("---");
        Param init = PoiHMMEM.bigInitModel();
        PoiHMMEM em = new PoiHMMEM(init, sample.observations());
        em.compute(10);
        System.out.println("---");
        System.out.println("Init was:");
        System.out.println(init.toString());
        System.out.println("Truth was:");
        System.out.println(trueParam.toString());
        System.out.println("Final it gave:");
        System.out.println(em.currentParam.toString());
    }

    /**
     * Draws one trajectory of {@code length} steps from {@code param}.
     *
     * <p>For each step t, in order: q_t comes from pi (t = 0) or from the tau row of q_{t-1};
     * p_t comes from the rho row of q_t; y_t comes from the Poisson sampler for (q_t, p_t).
     *
     * @param param  model to sample from
     * @param length number of time steps
     * @return the sampled hidden states, mixture indicators and observations
     */
    public static State randomState(Param param, int length) {
        int[] states = new int[length];
        int[] components = new int[length];
        int[] counts = new int[length];
        State draw = new State();
        draw.q = states;
        draw.p = components;
        draw.y = counts;
        for (int step = 0; step < length; step++) {
            int state;
            if (step == 0) {
                state = param.samplePi();
            } else {
                state = param.sampleTau(states[step - 1]);
            }
            states[step] = state;
            int component = param.sampleRho(state);
            components[step] = component;
            counts[step] = param.sampleLambda(state, component);
        }
        return draw;
    }

    /**
     * Sanity check for the M-step.
     *
     * <p>It samples a long trajectory from {@link #bigModel()} and feeds the known states and
     * components to MStep as one-hot "posteriors", then prints the truth next to the estimate.
     * It sets the static {@code length} to 10000 as a side effect.
     */
    public static void testMStep() {
        Param param = PoiHMMEM.bigModel();
        trueParam = param;
        // Derive dimensions from the sampled model: a state count larger than the model's would
        // create a never-visited state whose M-step rows are 0/0 = NaN.
        m = param.tau.length;
        k = param.rho[0].length;
        length = 10000;
        State draw = PoiHMMEM.randomState(param, length);
        double[][][] transitionPost = new double[length][m][m];
        double[][][] mixingPost = new double[length][m][k];
        for (int t = 0; t < length; ++t) {
            mixingPost[t][draw.q[t]][draw.p[t]] = 1.0;
            // The last step has no successor, so its transition slice stays all-zero.
            if (t + 1 < length) {
                transitionPost[t][draw.q[t]][draw.q[t + 1]] = 1.0;
            }
        }
        SuffStat post = new SuffStat(transitionPost, mixingPost, draw.observations());
        Param estimate = new MStep(post).compute();
        System.out.println("Truth:");
        System.out.println(trueParam.toString());
        System.out.println("Estimate:");
        System.out.println(estimate.toString());
    }

    /**
     * Exact E-step for the Poisson-mixture HMM.
     *
     * <p>It runs an unscaled forward-backward pass over one observation sequence under fixed
     * parameters. It produces per-time-step pairwise state posteriors and state/component
     * posteriors, wrapped in a {@link SuffStat}.
     *
     * <p>Tables, all indexed {@code [t][q]} for time step t and hidden state q:
     * <ul>
     *   <li>{@code L[t][q]} = sum_p rho[q][p] * lambda(q, p, y_t): emission likelihood of y_t in
     *       state q with the mixture component summed out;</li>
     *   <li>{@code N[t][q]}: forward message, joint probability of y_0..y_{t-1} and q_t = q
     *       (excludes y_t; {@code N[0][q] = pi[q]});</li>
     *   <li>{@code M[t][q]}: backward message, likelihood of y_{t+1}..y_{T-1} given q_t = q
     *       ({@code M[T-1][q] = 1}).</li>
     * </ul>
     * For every t, sum_q N[t][q] * L[t][q] * M[t][q] equals the sequence likelihood P(y).
     *
     * <p>NOTE(review): the messages are never rescaled. For long sequences they can underflow to
     * zero, at which point logll() returns -Infinity and the posteriors become NaN.
     */
    public class EStep {
        private Param currentParam;
        private int[] observations;
        /** Number of time steps T in {@code observations}; every table has T rows. */
        private final int seqLength;
        private final double[][] L;
        private final double[][] M;
        private final double[][] N;

        /**
         * @param inParams       parameters to condition on
         * @param inObservations observed counts y_0..y_{T-1}
         */
        public EStep(Param inParams, int[] inObservations) {
            this.currentParam = inParams;
            this.observations = inObservations;
            this.seqLength = inObservations.length;
            // Sized from the sequence actually supplied, one column per hidden state (outer m).
            this.L = new double[this.seqLength][m];
            this.M = new double[this.seqLength][m];
            this.N = new double[this.seqLength][m];
        }

        /**
         * Runs the forward-backward pass and returns the normalised posteriors.
         *
         * @return per-time-step posteriors summed into sufficient statistics
         */
        public SuffStat compute() {
            this.emissions();
            this.forward();
            this.backward();
            assert this.consistent() : "forward and backward messages disagree on P(y)";
            return this.normalization();
        }

        /**
         * @return log P(y) under {@code currentParam}; only meaningful after {@link #compute()}
         */
        public double logll() {
            double likelihood = 0.0;
            for (int q = 0; q < m; ++q) {
                likelihood += this.L[0][q] * this.M[0][q] * this.N[0][q];
            }
            return Math.log(likelihood);
        }

        /**
         * Checks that sum_q N[t][q] * L[t][q] * M[t][q] takes the same value at every t.
         *
         * <p>The comparison is relative: that sum is P(y), which is minute for any non-trivial
         * sequence, so a fixed absolute tolerance would accept anything.
         */
        private boolean consistent() {
            final double relativeTolerance = 1.0E-5;
            double reference = 0.0;
            for (int t = 0; t < this.seqLength; ++t) {
                double sum = 0.0;
                for (int q = 0; q < m; ++q) {
                    sum += this.L[t][q] * this.M[t][q] * this.N[t][q];
                }
                if (t == 0) {
                    reference = sum;
                    continue;
                }
                double scale = Math.max(Math.abs(reference), Math.abs(sum));
                if (Math.abs(reference - sum) > relativeTolerance * scale) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Converts the messages into posteriors, each normalised per time step.
         *
         * <p>transitionPost[t][q][q2] is P(q_t = q, q_{t+1} = q2 | y) for t < T-1; the row for
         * t = T-1 stays zero. mixingPost[t][q][p] is P(q_t = q, p_t = p | y).
         */
        private SuffStat normalization() {
            double[][][] transitionPost = new double[this.seqLength][m][m];
            double[][][] mixingPost = new double[this.seqLength][m][k];
            for (int t = 0; t + 1 < this.seqLength; ++t) {
                double norm = 0.0;
                for (int q = 0; q < m; ++q) {
                    for (int q2 = 0; q2 < m; ++q2) {
                        transitionPost[t][q][q2] = this.N[t][q] * this.L[t][q] * this.currentParam.tau[q][q2]
                                * this.L[t + 1][q2] * this.M[t + 1][q2];
                        norm += transitionPost[t][q][q2];
                    }
                }
                for (int q = 0; q < m; ++q) {
                    for (int q2 = 0; q2 < m; ++q2) {
                        transitionPost[t][q][q2] = transitionPost[t][q][q2] / norm;
                    }
                }
            }
            for (int t = 0; t < this.seqLength; ++t) {
                double norm = 0.0;
                for (int q = 0; q < m; ++q) {
                    for (int p = 0; p < k; ++p) {
                        mixingPost[t][q][p] = this.N[t][q] * this.M[t][q] * this.currentParam.rho[q][p]
                                * this.currentParam.lambda(q, p, this.observations[t]);
                        norm += mixingPost[t][q][p];
                    }
                }
                for (int q = 0; q < m; ++q) {
                    for (int p = 0; p < k; ++p) {
                        mixingPost[t][q][p] = mixingPost[t][q][p] / norm;
                    }
                }
            }
            return new SuffStat(transitionPost, mixingPost, this.observations);
        }

        /** Fills M from the last time step backwards. */
        private void backward() {
            for (int t = this.seqLength - 1; t >= 0; --t) {
                for (int q = 0; q < m; ++q) {
                    if (t == this.seqLength - 1) {
                        this.M[t][q] = 1.0;
                        continue;
                    }
                    double sum = 0.0;
                    for (int q2 = 0; q2 < m; ++q2) {
                        sum += this.currentParam.tau[q][q2] * this.L[t + 1][q2] * this.M[t + 1][q2];
                    }
                    this.M[t][q] = sum;
                }
            }
        }

        /** Fills N from the first time step forwards; requires L to be filled already. */
        private void forward() {
            for (int t = 0; t < this.seqLength; ++t) {
                for (int q = 0; q < m; ++q) {
                    if (t == 0) {
                        this.N[t][q] = this.currentParam.pi[q];
                        continue;
                    }
                    double sum = 0.0;
                    for (int q2 = 0; q2 < m; ++q2) {
                        sum += this.currentParam.tau[q2][q] * this.N[t - 1][q2] * this.L[t - 1][q2];
                    }
                    this.N[t][q] = sum;
                }
            }
        }

        /** Fills L[t][q] with the emission likelihood of y_t in state q, component summed out. */
        private void emissions() {
            for (int t = 0; t < this.seqLength; ++t) {
                for (int q = 0; q < m; ++q) {
                    double sum = 0.0;
                    for (int p = 0; p < k; ++p) {
                        sum += this.currentParam.rho[q][p] * this.currentParam.lambda(q, p, this.observations[t]);
                    }
                    this.L[t][q] = sum;
                }
            }
        }
    }

    /**
     * Approximate E-step by Gibbs sampling.
     *
     * <p>Each sweep resamples every hidden state q_t, then every mixture indicator p_t, from its
     * full conditional, visiting time steps in a random order. Configurations visited after the
     * burn-in are averaged into per-time-step posterior estimates.
     *
     * <p>NOTE: the {@code init} State passed to the constructor is the chain's state and is
     * modified in place.
     */
    public class GibbsEStep {
        private Param currentParam;
        private int[] observations;
        /** Number of time steps T in {@code observations}. */
        private final int seqLength;
        /** Counts (frequencies after normalize) of (q_t, q_{t+1}) pairs, indexed [t][q][q2]. */
        private double[][][] transitionPost;
        /** Counts (frequencies after normalize) of (q_t, p_t) pairs, indexed [t][q][p]. */
        private double[][][] mixingPost;
        private State currentState;
        /** Number of post-burn-in sweeps collected into the counts. */
        private double numSamples;

        /**
         * @param inParams       parameters to condition on
         * @param inObservations observed counts y_0..y_{T-1}
         * @param init           initial chain state; mutated by sampling
         */
        public GibbsEStep(Param inParams, int[] inObservations, State init) {
            this.currentParam = inParams;
            this.observations = inObservations;
            this.currentState = init;
            this.seqLength = inObservations.length;
            this.resetStats();
        }

        /** Runs 1000 sweeps and discards sweeps 0..100 as burn-in. */
        public SuffStat compute() {
            return this.compute(1000, 100);
        }

        /**
         * Runs the sampler and returns the averaged posteriors.
         *
         * @param iterations total number of sweeps
         * @param burnIn     sweeps 0..burnIn (inclusive) are discarded
         * @return per-time-step posterior frequencies summed into sufficient statistics
         * @throws IllegalArgumentException if no sweep would be collected
         */
        public SuffStat compute(int iterations, int burnIn) {
            if (burnIn < 0 || burnIn >= iterations - 1) {
                throw new IllegalArgumentException(
                        "need 0 <= burnIn < iterations - 1, got iterations=" + iterations + ", burnIn=" + burnIn);
            }
            // Start from empty counts so repeated calls do not accumulate onto old frequencies.
            this.resetStats();
            for (int iter = 0; iter < iterations; ++iter) {
                this.sampleBackBone();
                this.sampleMixtureIndicators();
                if (iter > burnIn) {
                    this.collectStats();
                }
            }
            return this.normalize();
        }

        /** Clears the accumulated counts and the sample counter. */
        private void resetStats() {
            this.transitionPost = new double[this.seqLength][m][m];
            this.mixingPost = new double[this.seqLength][m][k];
            this.numSamples = 0.0;
        }

        /** Divides the counts by the number of collected sweeps, turning them into frequencies. */
        private SuffStat normalize() {
            for (int t = 0; t < this.seqLength; ++t) {
                for (int q = 0; q < m; ++q) {
                    for (int q2 = 0; q2 < m; ++q2) {
                        this.transitionPost[t][q][q2] = this.transitionPost[t][q][q2] / this.numSamples;
                    }
                    for (int p = 0; p < k; ++p) {
                        this.mixingPost[t][q][p] = this.mixingPost[t][q][p] / this.numSamples;
                    }
                }
            }
            return new SuffStat(this.transitionPost, this.mixingPost, this.observations);
        }

        /** Adds the current chain state to the counts. */
        private void collectStats() {
            this.numSamples += 1.0;
            int[] states = this.currentState.q;
            int[] components = this.currentState.p;
            for (int t = 0; t < this.seqLength; ++t) {
                this.mixingPost[t][states[t]][components[t]] += 1.0;
                if (t + 1 < this.seqLength) {
                    this.transitionPost[t][states[t]][states[t + 1]] += 1.0;
                }
            }
        }

        /** Resamples each p_t in proportion to rho[q_t][p] * lambda(q_t, p, y_t). */
        private void sampleMixtureIndicators() {
            for (int t : SampleUtils.samplePermutation(rand, this.seqLength)) {
                int state = this.currentState.q[t];
                double[] prs = new double[k];
                for (int p = 0; p < k; ++p) {
                    prs[p] = this.currentParam.rho[state][p] * this.currentParam.lambda(state, p, this.observations[t]);
                }
                NumUtils.normalize(prs);
                this.currentState.p[t] = SampleUtils.sampleMultinomial(rand, prs);
            }
        }

        /**
         * Resamples each q_t from its full conditional.
         *
         * <p>The weight of state q is rho[q][p_t] * lambda(q, p_t, y_t), times pi[q] at t = 0 or
         * tau[q_{t-1}][q] otherwise, times tau[q][q_{t+1}] except at t = T-1.
         */
        private void sampleBackBone() {
            for (int t : SampleUtils.samplePermutation(rand, this.seqLength)) {
                int component = this.currentState.p[t];
                double[] prs = new double[m];
                for (int q = 0; q < m; ++q) {
                    double pr = this.currentParam.rho[q][component]
                            * this.currentParam.lambda(q, component, this.observations[t]);
                    // The first state has no predecessor: its prior is the initial distribution.
                    pr *= t == 0
                            ? this.currentParam.pi[q]
                            : this.currentParam.tau[this.currentState.q[t - 1]][q];
                    if (t != this.seqLength - 1) {
                        pr *= this.currentParam.tau[q][this.currentState.q[t + 1]];
                    }
                    prs[q] = pr;
                }
                NumUtils.normalize(prs);
                this.currentState.q[t] = SampleUtils.sampleMultinomial(rand, prs);
            }
        }
    }

    /** M-step: re-estimates the model parameters from expected sufficient statistics. */
    public static class MStep {
        private SuffStat post;

        /** @param inPost expected sufficient statistics from an E-step */
        public MStep(SuffStat inPost) {
            this.post = inPost;
        }

        /**
         * Builds the re-estimated parameters.
         *
         * <p>tau and rho are the expected counts normalised row-wise. Each lambda is the
         * posterior-weighted mean observation. The initial distribution stays fixed at
         * {@code trueParam.pi} whenever a true model has been set; otherwise it is re-estimated
         * from the posterior of the first hidden state.
         */
        public Param compute() {
            double[][] inTau = this.updateTransitions(this.post.transitionSuffStat);
            double[][] inRho = this.updateTransitions(this.post.mixingSuffStat);
            double[][] inLambda = this.updateLambda();
            double[] inPi = trueParam != null ? trueParam.pi : this.updatePi();
            return new Param(inTau, inRho, inPi, inLambda);
        }

        /** @return the posterior of the first hidden state (shared, not copied) */
        private double[] updatePi() {
            return this.post.initSuffStat;
        }

        /**
         * Computes each rate as (sum_t y_t * post(q, p, t)) / (sum_t post(q, p, t)).
         *
         * <p>NOTE(review): a (state, component) pair with zero expected mass still yields 0/0 = NaN.
         */
        private double[][] updateLambda() {
            double[][] emission = this.post.emissionSuffStat;
            double[][] mixing = this.post.mixingSuffStat;
            double[][] result = new double[mixing.length][];
            for (int q = 0; q < mixing.length; ++q) {
                result[q] = new double[mixing[q].length];
                for (int p = 0; p < mixing[q].length; ++p) {
                    result[q][p] = emission[q][p] / mixing[q][p];
                }
            }
            return result;
        }

        /**
         * Normalises each row of an expected-count matrix into a probability distribution.
         *
         * <p>A row with no expected mass (e.g. a state never visited) becomes uniform instead of
         * 0/0 = NaN. A NaN row would poison every later E-step and fail Param's normalisation
         * assertions.
         *
         * @param expectations expected counts, indexed [src][dest]
         * @return row-stochastic matrix of the same shape
         */
        private double[][] updateTransitions(double[][] expectations) {
            int srcSize = expectations.length;
            int destSize = expectations[0].length;
            double[][] result = new double[srcSize][destSize];
            for (int src = 0; src < srcSize; ++src) {
                double rowTotal = 0.0;
                for (int dest = 0; dest < destSize; ++dest) {
                    rowTotal += expectations[src][dest];
                }
                for (int dest = 0; dest < destSize; ++dest) {
                    result[src][dest] = rowTotal > 0.0
                            ? expectations[src][dest] / rowTotal
                            : 1.0 / destSize;
                }
            }
            return result;
        }
    }

    /**
     * Parameters of the Poisson-mixture HMM.
     *
     * <p>tau[q][q2] are transition probabilities and rho[q][p] are mixture weights of component p
     * within state q. pi[q] are initial-state probabilities and lambda[q][p] is the Poisson rate
     * of component p in state q. Arrays are stored by reference, not copied.
     */
    public static class Param {
        private double[][] tau;
        private double[][] rho;
        private double[] pi;
        private double[][] lambda;

        /** Normalisation checks (ArrayUtils.condSumsToOne / sumsToOne) run only under -ea. */
        public Param(double[][] inTau, double[][] inRho, double[] inPi, double[][] inLambda) {
            tau = inTau;
            rho = inRho;
            pi = inPi;
            lambda = inLambda;
            assert ArrayUtils.condSumsToOne(tau);
            assert ArrayUtils.condSumsToOne(rho);
            assert ArrayUtils.sumsToOne(pi);
        }

        /** @return MathUtils.poisson evaluated at rate lambda[q][p] and count y */
        public double lambda(int q, int p, int y) {
            return MathUtils.poisson(this.lambda[q][p], y);
        }

        /** @return an initial state drawn from pi */
        public int samplePi() {
            return Multinomial.sample(rand, pi);
        }

        /** @return a next state drawn from the tau row of the given state */
        public int sampleTau(int conditionOnPrevState) {
            double[] row = tau[conditionOnPrevState];
            return Multinomial.sample(rand, row);
        }

        /** @return a mixture component drawn from the rho row of the given state */
        public int sampleRho(int conditionOnPrevState) {
            double[] row = rho[conditionOnPrevState];
            return Multinomial.sample(rand, row);
        }

        /** @return a count drawn from the Poisson sampler at rate lambda[state][component] */
        public int sampleLambda(int conditionOnState, int conditionOnMixtureComp) {
            double rate = lambda[conditionOnState][conditionOnMixtureComp];
            return (int) SampleUtils.samplePoisson(rand, rate);
        }

        @Override
        public String toString() {
            return new StringBuilder()
                    .append("Lambda:\n").append(ArrayUtils.printMtx(lambda))
                    .append("Pi:\n").append(Arrays.toString(pi)).append("\n")
                    .append("Rho:\n").append(ArrayUtils.printMtx(rho))
                    .append("Tau:\n").append(ArrayUtils.printMtx(tau))
                    .toString();
        }
    }

    /** Expected sufficient statistics of one sequence: per-time-step posteriors summed over time. */
    public static class SuffStat {
        /** transitionSuffStat[q][q2]: sum over t of the (q, q2) pairwise posterior. */
        private double[][] transitionSuffStat;
        /** mixingSuffStat[q][p]: sum over t of the (state q, component p) posterior. */
        private double[][] mixingSuffStat;
        /** initSuffStat[q]: row sums of transitionPost[0], i.e. the posterior mass of q at t = 0. */
        private double[] initSuffStat;
        /** emissionSuffStat[q][p]: sum over t of y_t weighted by the (q, p) posterior at t. */
        private double[][] emissionSuffStat;
        private int numStates;
        private int numComponents;

        /**
         * @param transitionPost per-time-step pairwise posteriors, indexed [t][q][q2]
         * @param mixingPost     per-time-step state/component posteriors, indexed [t][q][p]
         * @param observations   observed counts; must cover every time step of mixingPost
         */
        public SuffStat(double[][][] transitionPost, double[][][] mixingPost, int[] observations) {
            this.numStates = transitionPost[0].length;
            this.numComponents = mixingPost[0][0].length;
            this.transitionSuffStat = this.sumTransitionPosteriors(transitionPost);
            this.mixingSuffStat = this.sumTransitionPosteriors(mixingPost);
            this.emissionSuffStat = this.sumEmissionPosteriors(mixingPost, observations);
            this.initSuffStat = this.sumInitPosterior(transitionPost[0]);
        }

        /** Marginalises the first pairwise posterior over the successor state. */
        private double[] sumInitPosterior(double[][] firstTransitionPost) {
            double[] singleNodePosteriors = new double[this.numStates];
            for (int q = 0; q < this.numStates; ++q) {
                double sum = 0.0;
                for (int q2 = 0; q2 < this.numStates; ++q2) {
                    sum += firstTransitionPost[q][q2];
                }
                singleNodePosteriors[q] = sum;
            }
            return singleNodePosteriors;
        }

        /**
         * Posterior-weighted sums of the observations.
         *
         * <p>Iterates over the posterior's own time axis rather than the outer static length, so
         * it stays correct if that static changes after the posteriors were built.
         */
        private double[][] sumEmissionPosteriors(double[][][] mixingPost, int[] observations) {
            int seqnLength = mixingPost.length;
            double[][] result = new double[this.numStates][this.numComponents];
            for (int q = 0; q < this.numStates; ++q) {
                for (int p = 0; p < this.numComponents; ++p) {
                    double sum = 0.0;
                    for (int t = 0; t < seqnLength; ++t) {
                        sum += (double) observations[t] * mixingPost[t][q][p];
                    }
                    result[q][p] = sum;
                }
            }
            return result;
        }

        /** Sums an [t][src][dest] posterior table over t. */
        private double[][] sumTransitionPosteriors(double[][][] transitionPosteriors) {
            int seqnLength = transitionPosteriors.length;
            int srcSize = transitionPosteriors[0].length;
            int destSize = transitionPosteriors[0][0].length;
            double[][] expectations = new double[srcSize][destSize];
            for (int src = 0; src < srcSize; ++src) {
                for (int dest = 0; dest < destSize; ++dest) {
                    double sum = 0.0;
                    for (int t = 0; t < seqnLength; ++t) {
                        sum += transitionPosteriors[t][src][dest];
                    }
                    expectations[src][dest] = sum;
                }
            }
            return expectations;
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append("Trans:\n" + ArrayUtils.printMtx(this.transitionSuffStat));
            builder.append("Mixing:\n" + ArrayUtils.printMtx(this.mixingSuffStat));
            // Arrays.toString emits no trailing newline (unlike printMtx, whose output the other
            // sections follow directly), so end the Init line explicitly.
            builder.append("Init:\n" + Arrays.toString(this.initSuffStat) + "\n");
            builder.append("Emission:\n" + ArrayUtils.printMtx(this.emissionSuffStat));
            return builder.toString();
        }
    }

    public static class State {
        public int[] q;
        public int[] p;
        public int[] y;

        public int[] observations() {
            return this.y;
        }

        public String toString() {
            StringBuilder builder = new StringBuilder();
            for (int t = 0; t < length; ++t) {
                builder.append("" + this.y[t] + "\t" + this.p[t] + "\t" + this.q[t] + "\n");
            }
            return builder.toString();
        }
    }
}

