/*
 * Decompiled with CFR 0.152.
 */
package cs281;

import fig.basic.ListUtils;
import fig.prob.Gaussian;
import fig.prob.SampleUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import nuts.io.IO;

public class GauHMMEM {
    // EM for a two-state hidden Markov model on a binary tree: each node's hidden state switches
    // from its parent's with probability alpha, and each node emits a Gaussian observation.

    // Shared RNG with a fixed seed (1) so that State.randomState produces the same samples every run.
    public static Random rand = new Random(1L);
    // Parameters EM was started from; kept only so toString can report them.
    private Param initialization;
    // Parameters after the most recent EM iteration (equal to initialization before any iteration).
    private Param currentParam;
    // Expected sufficient statistics produced by the most recent E-step.
    private SuffStat suffStat;
    // Observed tree data that EM is fitted to.
    private Observations obs;

    /**
     * Command-line entry point; delegates to {@link #testEM()}.
     *
     * @param args ignored
     * @throws IOException if the observation file cannot be read
     */
    public static void main(String[] args) throws IOException {
        testEM();
    }

    /**
     * Runs 10 EM iterations on the first 8 tree levels of the q1 data set. It starts from a
     * symmetric initialization, then prints the parameters and the Viterbi state assignment.
     *
     * @throws IOException if the observation file cannot be read
     */
    public static void testEM() throws IOException {
        Observations data = Observations.q1Obs(8);
        Param start = Param.symInitParam(0.5, 2.0, 2.0);
        GauHMMEM learner = new GauHMMEM(start, data);
        learner.compute(10);
        System.out.println(learner.toString());
        System.out.println(learner.maxConf());
    }

    /**
     * Creates an EM run.
     *
     * @param initialization starting parameters; also becomes the current estimate
     * @param obs observed tree data to fit
     */
    public GauHMMEM(Param initialization, Observations obs) {
        this.obs = obs;
        this.initialization = initialization;
        this.currentParam = initialization;
    }

    /**
     * Runs a fixed number of EM iterations (no convergence test), logging each one to stdout.
     *
     * @param maxIters number of E-step/M-step rounds to perform
     */
    public void compute(int maxIters) {
        int iter = 0;
        while (iter < maxIters) {
            System.out.println("Iteration " + iter + " is starting.");
            this.compute();
            iter++;
        }
    }

    /**
     * Decodes the jointly most likely hidden state of every node under the current parameters.
     *
     * @return the Viterbi assignment wrapped together with the observations
     */
    public State maxConf() {
        ExactEStep decoder = new ExactEStep();
        decoder.init(this.obs, this.currentParam);
        int[] bestStates = decoder.computeMaxConfiguration();
        return new State(bestStates, this.obs, this.currentParam.getNStates());
    }

    /** Performs one EM iteration: exact E-step, then closed-form M-step. */
    private void compute() {
        ExactEStep expectation = new ExactEStep();
        expectation.init(this.obs, this.currentParam);
        this.suffStat = expectation.compute();
        this.currentParam = new MStep(this.suffStat).compute();
    }

    /** Renders the initial parameters followed by the current estimate. */
    @Override
    public String toString() {
        return "Init: \n" + this.initialization.toString()
                + "Estimate: \n" + this.currentParam.toString();
    }

    /**
     * Sanity check for {@link MStep} on a sampled tree with known parameters.
     *
     * <p>Samples a 15-level tree, then feeds the true hidden states in as one-hot "posteriors".
     * Finally it prints the resulting estimate next to the truth.
     */
    public static void testMStep() {
        Param truth = Param.symInitParam(0.05, 10.0, 5.0);
        State sample = State.randomState(truth, 15);
        Observations observations = sample.getObservations();
        int nNodes = sample.getNumberOfNodes();
        // Take the state count from the parameters rather than repeating the literal 2.
        int nStates = truth.getNStates();
        double[][][] transSuffStat = new double[nNodes][nStates][nStates];
        double[][] emiSuffStat = new double[nNodes][nStates];
        for (int i = 0; i < nNodes; ++i) {
            int currentState = sample.getHiddenStates(i);
            // The root has no incoming edge, so it contributes no transition count.
            if (!observations.isRoot(i)) {
                int parentState = sample.getHiddenStates(observations.parent(i));
                transSuffStat[i][parentState][currentState] = 1.0;
            }
            emiSuffStat[i][currentState] = 1.0;
        }
        SuffStat suffStat = new SuffStat(transSuffStat, emiSuffStat, observations);
        Param mle = new MStep(suffStat).compute();
        System.out.println("Truth was : \n" + truth.toString());
        System.out.println("MLE is : \n" + mle.toString());
    }

    /**
     * Maximization step. It computes closed-form maximum-likelihood parameters from expected
     * sufficient statistics.
     */
    public static class MStep {
        private final SuffStat suffStat;
        private final int NStates;

        /** @param suffStat expected sufficient statistics from an E-step */
        public MStep(SuffStat suffStat) {
            this.suffStat = suffStat;
            this.NStates = suffStat.nStates();
        }

        /** @return parameters built from the estimated alpha, per-state means and per-state variances */
        public Param compute() {
            double switchProb = this.computeAlpha();
            double[] means = this.computeMu();
            double[] variances = this.computeSigma();
            return new Param(switchProb, means, variances);
        }

        /**
         * Computes the weighted variance of each state: E[y^2] - E[y]^2 under that state's
         * posterior weights.
         */
        private double[] computeSigma() {
            double[] variance = new double[this.NStates];
            for (int s = 0; s < variance.length; ++s) {
                double mass = this.suffStat.getNSuffStat(s);
                double mean = this.suffStat.getSumEmissionSuffStat(s) / mass;
                variance[s] = this.suffStat.getSumSqEmissionSuffStat(s) / mass - mean * mean;
            }
            return variance;
        }

        /** Computes the weighted mean of the observations for each state. */
        private double[] computeMu() {
            double[] mean = new double[this.NStates];
            for (int s = 0; s < mean.length; ++s) {
                mean[s] = this.suffStat.getSumEmissionSuffStat(s) / this.suffStat.getNSuffStat(s);
            }
            return mean;
        }

        /** Computes the fraction of expected edges whose child state differs from the parent state. */
        private double computeAlpha() {
            assert (this.NStates == 2);
            double stayed = 0.0;
            double switched = 0.0;
            for (int from = 0; from < this.NStates; ++from) {
                for (int to = 0; to < this.NStates; ++to) {
                    double count = this.suffStat.getTransitionSuffStat(from, to);
                    if (from == to) {
                        stayed += count;
                    } else {
                        switched += count;
                    }
                }
            }
            return switched / (switched + stayed);
        }
    }

    /**
     * Immutable parameters of the two-state tree HMM. They consist of a switching probability
     * {@code alpha} shared by every edge and one Gaussian emission per state.
     *
     * <p>NOTE(review): {@code sigma} is passed straight to fig's Gaussian.logProb, and MStep fills it
     * with a variance estimate (E[y^2] - E[y]^2). It is therefore presumably a variance rather than a
     * standard deviation. Confirm this against fig.prob.Gaussian.
     */
    public static class Param {
        private final double alpha;
        private final double[] mu;
        private final double[] sigma;
        private final int nStates = 2;

        /**
         * @param alpha probability that a node's state differs from its parent's; must be in [0, 1]
         * @param mu per-state emission means, one per state (copied)
         * @param sigma per-state emission spread parameters, one per state, all positive (copied)
         * @throws IllegalArgumentException if alpha is NaN or outside [0, 1], if mu or sigma does not
         *     have exactly one entry per state, or if any sigma is not strictly positive (incl. NaN)
         */
        public Param(double alpha, double[] mu, double[] sigma) {
            // Explicit checks instead of asserts: asserts are off by default, which previously let
            // NaN estimates from a degenerate M-step through silently.
            if (!(alpha >= 0.0 && alpha <= 1.0)) {
                throw new IllegalArgumentException("alpha must be in [0, 1] but was " + alpha);
            }
            if (mu.length != this.nStates || sigma.length != this.nStates) {
                throw new IllegalArgumentException("expected " + this.nStates
                        + " means and sigmas but got " + mu.length + " and " + sigma.length);
            }
            for (int i = 0; i < sigma.length; ++i) {
                if (!(sigma[i] > 0.0)) {
                    throw new IllegalArgumentException("sigma[" + i + "] must be > 0 but was " + sigma[i]);
                }
            }
            this.alpha = alpha;
            // Defensive copies keep the instance immutable even if the caller reuses its arrays.
            this.mu = mu.clone();
            this.sigma = sigma.clone();
        }

        /** Fixed parameters used for question 1: alpha 0.1, means (1, -1), sigmas (1, 1). */
        public static Param q1Param() {
            double alpha = 0.1;
            double[] mu = new double[]{1.0, -1.0};
            double[] sigma = new double[]{1.0, 1.0};
            return new Param(alpha, mu, sigma);
        }

        /** Builds a symmetric initialization: means (-inMu, inMu) and equal sigmas inSigma. */
        public static Param symInitParam(double alpha, double inMu, double inSigma) {
            double[] mu = new double[]{-inMu, inMu};
            double[] sigma = new double[]{inSigma, inSigma};
            return new Param(alpha, mu, sigma);
        }

        public double getAlpha() {
            return this.alpha;
        }

        public double getMu(int state) {
            return this.mu[state];
        }

        public double getSigma(int state) {
            return this.sigma[state];
        }

        /** Edge potential: 1 - alpha if the two states agree, alpha otherwise. */
        public double getPotential(int state1, int state2) {
            if (state1 == state2) {
                return 1.0 - this.alpha;
            }
            return this.alpha;
        }

        /** Gaussian emission density of {@code y} under {@code state} (linear scale, not log). */
        public double getLikelihood(int state, double y) {
            return Math.exp(Gaussian.logProb(this.mu[state], this.sigma[state], y));
        }

        public int getNStates() {
            return this.nStates;
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append("Alpha: " + this.getAlpha() + "\n");
            builder.append("Mu: " + Arrays.toString(this.mu) + "\n");
            builder.append("Sigma: " + Arrays.toString(this.sigma) + "\n");
            return builder.toString();
        }
    }

    /**
     * Expected sufficient statistics of the tree HMM, aggregated over all nodes.
     *
     * <p>For each state s it holds the posterior mass and the posterior-weighted sums of the
     * observations and of their squares. For each (parent state, child state) pair it holds the
     * posterior-weighted number of edges.
     */
    public static class SuffStat {
        private final double[][] transitionSuffStat;
        private final double[] sumEmissionSuffStat;
        private final double[] sumSqEmissionSuffStat;
        private final double[] NSuffStat;
        private final int NStates;
        private final int NNodes;

        /**
         * @param transPost transPost[n][p][c] is the weight of node n being in state c while its
         *     parent is in state p; row 0 (the root) is ignored
         * @param emiPost emiPost[n][s] is the weight of node n being in state s; the length of row 0
         *     fixes the number of states
         * @param obs observations; its node count decides how many rows are read
         */
        public SuffStat(double[][][] transPost, double[][] emiPost, Observations obs) {
            this.NStates = emiPost[0].length;
            this.NNodes = obs.numberOfNodes();
            this.transitionSuffStat = new double[this.NStates][this.NStates];
            this.sumEmissionSuffStat = new double[this.NStates];
            this.sumSqEmissionSuffStat = new double[this.NStates];
            this.NSuffStat = new double[this.NStates];
            this.accumulateTransitions(transPost);
            this.accumulateEmissions(emiPost, obs);
        }

        public int nStates() {
            return this.NStates;
        }

        /** Adds each node's weight, weighted observation and weighted squared observation per state. */
        private void accumulateEmissions(double[][] emiPost, Observations obs) {
            for (int node = 0; node < this.NNodes; ++node) {
                double y = obs.getData(node);
                for (int s = 0; s < this.NStates; ++s) {
                    double weight = emiPost[node][s];
                    this.NSuffStat[s] += weight;
                    this.sumEmissionSuffStat[s] += weight * y;
                    this.sumSqEmissionSuffStat[s] += weight * y * y;
                }
            }
        }

        /** Sums edge posteriors over all non-root nodes. */
        private void accumulateTransitions(double[][][] transPost) {
            // Node 0 is the root and has no incoming edge.
            for (int node = 1; node < this.NNodes; ++node) {
                for (int from = 0; from < this.NStates; ++from) {
                    for (int to = 0; to < this.NStates; ++to) {
                        this.transitionSuffStat[from][to] += transPost[node][from][to];
                    }
                }
            }
        }

        public double getNSuffStat(int i) {
            return this.NSuffStat[i];
        }

        public double getSumEmissionSuffStat(int i) {
            return this.sumEmissionSuffStat[i];
        }

        public double getSumSqEmissionSuffStat(int i) {
            return this.sumSqEmissionSuffStat[i];
        }

        public double getTransitionSuffStat(int i, int j) {
            return this.transitionSuffStat[i][j];
        }
    }

    /**
     * Observed values on a binary tree stored in array (heap) order. Node 0 is the root, node n has
     * children 2n+1 and 2n+2, and its parent is (n+1)/2 - 1.
     */
    public static class Observations {
        private double[] data;
        // Observation file read by q1Obs: one number per line.
        public static String q1ObsPath = "test/hw5-1.data";

        /** Wraps {@code inData} directly (no copy). */
        public Observations(double[] inData) {
            this.data = inData;
        }

        /** Creates {@code size} zero-valued observations to be filled via setData. */
        public Observations(int size) {
            this.data = new double[size];
        }

        /**
         * Loads at most the first {@code limitDepth} tree levels (2^limitDepth - 1 nodes) from
         * {@link #q1ObsPath}. Blank lines are skipped, and lines past the limit are not parsed.
         *
         * @param limitDepth number of tree levels to keep; non-positive values yield no nodes
         * @throws IOException if the file cannot be read
         * @throws NumberFormatException if a kept, non-blank line is not a number
         */
        public static Observations q1Obs(int limitDepth) throws IOException {
            // For large depths Math.pow exceeds the int range and the cast saturates at
            // Integer.MAX_VALUE; for negative depths it yields 0, so the limit becomes -1 (nothing kept).
            int maxNNodes = (int)Math.pow(2.0, limitDepth) - 1;
            ArrayList<Double> obs = new ArrayList<Double>();
            for (String line : IO.i(q1ObsPath)) {
                // Keep iterating rather than breaking early. NOTE(review): whether IO.i releases its
                // reader on early exit is not visible here.
                if (obs.size() >= maxNNodes || line.trim().isEmpty()) {
                    continue;
                }
                obs.add(Double.parseDouble(line));
            }
            double[] obsArray = new double[obs.size()];
            for (int i = 0; i < obsArray.length; ++i) {
                obsArray[i] = obs.get(i);
            }
            return new Observations(obsArray);
        }

        /** Loads every observation in {@link #q1ObsPath}. */
        public static Observations q1Obs() throws IOException {
            return Observations.q1Obs(Integer.MAX_VALUE);
        }

        public void setData(int index, double value) {
            this.data[index] = value;
        }

        public double getData(int index) {
            return this.data[index];
        }

        /** Returns the other child of {@code nodeIndex}'s parent; meaningless for the root. */
        public int brother(int nodeIndex) {
            int parentIndex = this.parent(nodeIndex);
            int brotherIndex = this.child(parentIndex, 0);
            if (brotherIndex == nodeIndex) {
                brotherIndex = this.child(parentIndex, 1);
            }
            return brotherIndex;
        }

        public boolean isRoot(int node) {
            return node == 0;
        }

        /** Parent index in heap order; returns -1 for the root. */
        public int parent(int node) {
            return (node + 1) / 2 - 1;
        }

        /** True only when the node's second child (index 2n+2) exists, i.e. both children exist. */
        public boolean hasChild(int node) {
            return this.data.length > 2 * node + 2;
        }

        /** Index of child 0 (2n+1) or child 1 (2n+2) of {@code node}. */
        public int child(int node, int childIndex) {
            assert (childIndex == 0 || childIndex == 1);
            return 2 * (node + 1) + childIndex - 1;
        }

        public int numberOfNodes() {
            return this.data.length;
        }
    }

    /** An assignment of a hidden state to every node of an observation tree. */
    public static class State {
        private int[] hiddenStates;
        private Observations obs;
        private int numberOfStates;

        /**
         * @param hiddenStates state per node (stored without copying)
         * @param obs the observations the states belong to
         * @param numberOfStates number of valid state values
         */
        public State(int[] hiddenStates, Observations obs, int numberOfStates) {
            this.hiddenStates = hiddenStates;
            this.obs = obs;
            this.numberOfStates = numberOfStates;
        }

        /**
         * Samples a full tree of {@code size} levels (2^size - 1 nodes) from the model.
         *
         * <p>The root is fixed to state 0. Each other node switches away from its parent's state
         * with probability alpha, and every node emits a Gaussian sample for its state. Sampling
         * uses the shared, fixed-seed {@code rand}.
         */
        public static State randomState(Param param, int size) {
            int numberOfNodes = (int)Math.pow(2.0, size) - 1;
            Observations obs = new Observations(numberOfNodes);
            int[] hiddenStates = new int[numberOfNodes];
            for (int nodeIndex = 0; nodeIndex < numberOfNodes; ++nodeIndex) {
                if (nodeIndex == 0) {
                    hiddenStates[0] = 0;
                } else {
                    int oldState = hiddenStates[obs.parent(nodeIndex)];
                    // Outcome 0 (probability alpha) means "switch"; outcome 1 means "keep the parent's state".
                    // The original read prs[0] inside prs's own initializer, which does not compile.
                    double alpha = param.getAlpha();
                    double[] prs = new double[]{alpha, 1.0 - alpha};
                    int decision = SampleUtils.sampleMultinomial(rand, prs);
                    // With two states, switching maps 0 -> 1 and 1 -> 0.
                    hiddenStates[nodeIndex] = decision == 0 ? 1 - oldState : oldState;
                }
                int currentState = hiddenStates[nodeIndex];
                double newObs = Gaussian.sample(rand, param.getMu(currentState), param.getSigma(currentState));
                obs.setData(nodeIndex, newObs);
            }
            return new State(hiddenStates, obs, 2);
        }

        public int getNumberOfNodes() {
            return this.hiddenStates.length;
        }

        public boolean validState(int index) {
            return index >= 0 && index < this.numberOfStates;
        }

        public int getHiddenStates(int nodeIndex) {
            return this.hiddenStates[nodeIndex];
        }

        public void setHiddenStates(int nodeIndex, int newState) {
            assert (this.validState(newState));
            this.hiddenStates[nodeIndex] = newState;
        }

        public Observations getObservations() {
            return this.obs;
        }

        /** Two lines: the coordinates of the state-0 nodes, then those of the state-1 nodes. */
        @Override
        public String toString() {
            assert (this.numberOfStates == 2);
            StringBuilder builder = new StringBuilder();
            builder.append(this.aStateToString(0)).append('\n');
            builder.append(this.aStateToString(1));
            return builder.toString();
        }

        /**
         * Lists "x, y; " coordinates for every node in {@code requestedState}, presumably for plotting.
         * Here y is minus the node's depth, and x is a horizontal position derived from the node's
         * index within its level.
         */
        public StringBuilder aStateToString(int requestedState) {
            StringBuilder builder = new StringBuilder();
            for (int node = 0; node < this.obs.numberOfNodes(); ++node) {
                if (this.hiddenStates[node] != requestedState) {
                    continue;
                }
                int i = node + 1;
                // Exact floor(log2(i)); avoids floating-point error in log(i)/log(2) at powers of two.
                double d = 31 - Integer.numberOfLeadingZeros(i);
                double x = ((double)i - 3.0 * Math.pow(2.0, d - 1.0) + 0.5) * (double)(this.obs.numberOfNodes() + 1) * Math.pow(2.0, 3.0 - d);
                double y = -d;
                builder.append("" + x + ", " + y + "; ");
            }
            return builder;
        }
    }

    /**
     * Exact E-step for the binary-tree HMM. Sum-product (inside-outside) gives posterior
     * statistics, and max-product (Viterbi) gives the jointly most likely hidden states.
     *
     * <p>Nodes are indexed as in {@link Observations}. All messages are kept in linear (not log)
     * space. NOTE(review): nothing rescales them, so on deep trees the products may underflow to 0
     * and normalize() would then divide by 0. Confirm that the tree sizes used stay within double
     * range. The code also appears to assume every internal node has two children, since brother()
     * is used for every non-root node.
     */
    public static class ExactEStep
    implements EStep {
        /** Relative tolerance for consistent(); the compared marginals are often far below 1. */
        private static final double CONSISTENCY_REL_TOL = 1.0E-5;

        private Observations observations;
        private Param param;
        private int NStates;
        private int NNodes;
        // I[n][s]: sum over n's descendants of all potentials and emissions strictly below n, given
        // n is in state s. It excludes n's own emission and is 1 at leaves.
        private double[][] I;
        // M[n][s]: the max-product counterpart of I.
        private double[][] M;
        // CAM0/CAM1[n][s]: arg-max state of n's first/second child when n is in state s (-1 at leaves).
        private int[][] CAM0;
        private int[][] CAM1;
        // O[n][s]: sum over everything outside n's subtree given n is in state s. It excludes n's own
        // emission. It is 1 at the root, so no root prior factor is applied.
        private double[][] O;
        // Most likely state per node, filled by computeMaxConfiguration().
        private int[] maxConfiguration;

        /** Runs inside-outside and returns the resulting expected sufficient statistics. */
        @Override
        public SuffStat compute() {
            this.inside(0, false);
            this.outside(0);
            assert (this.consistent());
            return this.normalize();
        }

        /**
         * Runs max-product and back-tracks to the jointly most likely state of every node.
         *
         * @return state per node (the internal array, not a copy)
         */
        public int[] computeMaxConfiguration() {
            this.inside(0, true);
            this.maxConf(0, -1);
            return this.maxConfiguration;
        }

        /** Emission likelihood of the observation at {@code nodeIndex} under {@code state}. */
        private double likelihood(int state, int nodeIndex) {
            return this.param.getLikelihood(state, this.observations.getData(nodeIndex));
        }

        /**
         * Follows the Viterbi back-pointers from {@code nodeIndex} downwards.
         *
         * @param currentMaximizingState chosen state of {@code nodeIndex}; ignored for the root,
         *     whose state is picked by bestRootState
         */
        private void maxConf(int nodeIndex, int currentMaximizingState) {
            if (this.observations.isRoot(nodeIndex)) {
                currentMaximizingState = this.bestRootState(nodeIndex);
            }
            this.maxConfiguration[nodeIndex] = currentMaximizingState;
            if (this.observations.hasChild(nodeIndex)) {
                int child0 = this.observations.child(nodeIndex, 0);
                int child1 = this.observations.child(nodeIndex, 1);
                this.maxConf(child0, this.CAM0[nodeIndex][currentMaximizingState]);
                this.maxConf(child1, this.CAM1[nodeIndex][currentMaximizingState]);
            }
        }

        /**
         * Picks the root state maximizing (root emission likelihood) x M[root][s]. M excludes the
         * node's own emission, so the root's likelihood must be multiplied in here; the previous
         * argmax over M alone ignored the root observation. Ties go to the lower state.
         */
        private int bestRootState(int rootIndex) {
            int best = 0;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int s = 0; s < this.NStates; ++s) {
                double score = this.likelihood(s, rootIndex) * this.M[rootIndex][s];
                if (score > bestScore) {
                    bestScore = score;
                    best = s;
                }
            }
            return best;
        }

        /**
         * Fills I (sum-product) or M plus back-pointers (max-product) for the subtree rooted at
         * {@code nodeIndex}, children first.
         */
        private void inside(int nodeIndex, boolean viterbi) {
            if (!this.observations.hasChild(nodeIndex)) {
                for (int currentState = 0; currentState < this.NStates; ++currentState) {
                    if (viterbi) {
                        this.M[nodeIndex][currentState] = 1.0;
                        this.CAM0[nodeIndex][currentState] = -1;
                        this.CAM1[nodeIndex][currentState] = -1;
                    } else {
                        this.I[nodeIndex][currentState] = 1.0;
                    }
                }
                return;
            }
            int childIndex0 = this.observations.child(nodeIndex, 0);
            int childIndex1 = this.observations.child(nodeIndex, 1);
            this.inside(childIndex0, viterbi);
            this.inside(childIndex1, viterbi);
            double[][] childMsg = viterbi ? this.M : this.I;
            for (int currentState = 0; currentState < this.NStates; ++currentState) {
                double sum = 0.0;
                double best = Double.NEGATIVE_INFINITY;
                int bestState0 = 0;
                int bestState1 = 0;
                for (int childState0 = 0; childState0 < this.NStates; ++childState0) {
                    // Child message: edge potential x child's emission x child's subtree value.
                    // Computed outside any assert: the old code multiplied by I inside an assert,
                    // so with assertions disabled the subtree factor was silently dropped.
                    double msg0 = this.param.getPotential(currentState, childState0)
                            * this.likelihood(childState0, childIndex0)
                            * childMsg[childIndex0][childState0];
                    for (int childState1 = 0; childState1 < this.NStates; ++childState1) {
                        double msg1 = this.param.getPotential(currentState, childState1)
                                * this.likelihood(childState1, childIndex1)
                                * childMsg[childIndex1][childState1];
                        double prod = msg0 * msg1;
                        if (viterbi) {
                            if (prod > best) {
                                best = prod;
                                bestState0 = childState0;
                                bestState1 = childState1;
                            }
                        } else {
                            assert (prod > 0.0);
                            sum += prod;
                        }
                    }
                }
                if (viterbi) {
                    this.M[nodeIndex][currentState] = best;
                    this.CAM0[nodeIndex][currentState] = bestState0;
                    this.CAM1[nodeIndex][currentState] = bestState1;
                } else {
                    this.I[nodeIndex][currentState] = sum;
                }
            }
        }

        /**
         * Fills O top-down. For a non-root node c with parent p and brother b:
         * O[c][s] = sum over p-state, b-state of pot(s, p) * lik(p) * O[p] * pot(p, b) * lik(b) * I[b].
         */
        private void outside(int nodeIndex) {
            if (this.observations.isRoot(nodeIndex)) {
                for (int currentState = 0; currentState < this.NStates; ++currentState) {
                    this.O[nodeIndex][currentState] = 1.0;
                }
            } else {
                int parentIndex = this.observations.parent(nodeIndex);
                int brotherIndex = this.observations.brother(nodeIndex);
                for (int currentState = 0; currentState < this.NStates; ++currentState) {
                    double sum = 0.0;
                    for (int parentState = 0; parentState < this.NStates; ++parentState) {
                        for (int brotherState = 0; brotherState < this.NStates; ++brotherState) {
                            double prod0 = this.param.getPotential(currentState, parentState) * this.likelihood(parentState, parentIndex) * this.O[parentIndex][parentState];
                            double prod1 = this.param.getPotential(parentState, brotherState) * this.likelihood(brotherState, brotherIndex) * this.I[brotherIndex][brotherState];
                            assert (prod0 * prod1 > 0.0);
                            sum += prod0 * prod1;
                        }
                    }
                    this.O[nodeIndex][currentState] = sum;
                }
            }
            if (this.observations.hasChild(nodeIndex)) {
                int childIndex0 = this.observations.child(nodeIndex, 0);
                int childIndex1 = this.observations.child(nodeIndex, 1);
                this.outside(childIndex0);
                this.outside(childIndex1);
            }
        }

        /**
         * Debug check run under assert: sum over s of I * O * lik should give the same total at
         * every node. The comparison is relative to the values' magnitude, because an absolute 1e-5
         * tolerance accepts anything once the likelihoods are tiny. NaN fails the check. The common
         * total is printed to stdout.
         */
        private boolean consistent() {
            double reference = Double.NaN;
            for (int nodeIndex = 0; nodeIndex < this.NNodes; ++nodeIndex) {
                double marginal = 0.0;
                for (int state = 0; state < this.NStates; ++state) {
                    marginal += this.I[nodeIndex][state] * this.O[nodeIndex][state] * this.likelihood(state, nodeIndex);
                }
                if (nodeIndex == 0) {
                    reference = marginal;
                    continue;
                }
                double scale = Math.max(Math.abs(reference), Math.abs(marginal));
                if (!(Math.abs(reference - marginal) <= CONSISTENCY_REL_TOL * scale)) {
                    return false;
                }
            }
            System.out.println(reference);
            return true;
        }

        /**
         * Turns inside/outside values into per-node posteriors. transPost[n][p][c] is normalized
         * over (p, c) for each non-root n, and emiPost[n][s] is normalized over s for each n.
         * Both are then aggregated into a SuffStat.
         */
        private SuffStat normalize() {
            double[][][] transPost = new double[this.NNodes][this.NStates][this.NStates];
            for (int currentNodeIndex = 1; currentNodeIndex < this.NNodes; ++currentNodeIndex) {
                int parentState;
                int currentState;
                int parentIndex = this.observations.parent(currentNodeIndex);
                int brotherIndex = this.observations.brother(currentNodeIndex);
                double norm = 0.0;
                for (currentState = 0; currentState < this.NStates; ++currentState) {
                    for (parentState = 0; parentState < this.NStates; ++parentState) {
                        // Contribution of the brother's subtree given the parent's state.
                        double sum = 0.0;
                        for (int brotherState = 0; brotherState < this.NStates; ++brotherState) {
                            sum += this.param.getPotential(parentState, brotherState) * this.likelihood(brotherState, brotherIndex) * this.I[brotherIndex][brotherState];
                        }
                        transPost[currentNodeIndex][parentState][currentState] = sum * this.O[parentIndex][parentState] * this.likelihood(parentState, parentIndex) * this.param.getPotential(currentState, parentState) * this.likelihood(currentState, currentNodeIndex) * this.I[currentNodeIndex][currentState];
                        norm += transPost[currentNodeIndex][parentState][currentState];
                    }
                }
                for (currentState = 0; currentState < this.NStates; ++currentState) {
                    for (parentState = 0; parentState < this.NStates; ++parentState) {
                        transPost[currentNodeIndex][parentState][currentState] = transPost[currentNodeIndex][parentState][currentState] / norm;
                    }
                }
            }
            double[][] emiPost = new double[this.NNodes][this.NStates];
            for (int currentNodeIndex = 0; currentNodeIndex < this.NNodes; ++currentNodeIndex) {
                int currentState;
                double norm = 0.0;
                for (currentState = 0; currentState < this.NStates; ++currentState) {
                    emiPost[currentNodeIndex][currentState] = this.O[currentNodeIndex][currentState] * this.I[currentNodeIndex][currentState] * this.likelihood(currentState, currentNodeIndex);
                    norm += emiPost[currentNodeIndex][currentState];
                }
                for (currentState = 0; currentState < this.NStates; ++currentState) {
                    emiPost[currentNodeIndex][currentState] = emiPost[currentNodeIndex][currentState] / norm;
                }
            }
            return new SuffStat(transPost, emiPost, this.observations);
        }

        /** Binds observations and parameters and allocates all message tables. */
        @Override
        public void init(Observations observations, Param param) {
            this.observations = observations;
            this.param = param;
            this.I = new double[observations.numberOfNodes()][param.getNStates()];
            this.O = new double[observations.numberOfNodes()][param.getNStates()];
            this.M = new double[observations.numberOfNodes()][param.getNStates()];
            this.CAM0 = new int[observations.numberOfNodes()][param.getNStates()];
            this.CAM1 = new int[observations.numberOfNodes()][param.getNStates()];
            this.maxConfiguration = new int[observations.numberOfNodes()];
            this.NStates = param.getNStates();
            this.NNodes = observations.numberOfNodes();
        }

        /** Not implemented: always returns 0.0. TODO: compute the data log-likelihood. */
        @Override
        public double logll() {
            return 0.0;
        }
    }

    /** Contract for an expectation step of EM over the tree HMM. */
    public interface EStep {
        /** Computes expected sufficient statistics under the parameters passed to init. */
        SuffStat compute();

        /** Returns a log-likelihood value for the bound data. */
        double logll();

        /** Binds the observations and current parameters used by subsequent calls. */
        void init(Observations observations, Param param);
    }
}

