/*
 * Decompiled with CFR 0.152.
 */
package stat560;

import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.exec.Execution;
import fig.prob.DiagMultGaussian;
import fig.prob.SampleUtils;
import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import nuts.io.CSV;
import nuts.io.IO;
import nuts.math.Plot2D;
import nuts.math.Plot3D;
import nuts.maxent.Function;

public class EM
implements Runnable {
    // NOTE(review): @Option presumably registers each annotated field as a
    // configurable option with the fig framework -- confirm against fig.basic.Option.

    /** x-coordinates of the true component means; entry i is read by truth() for cluster i. */
    @Option
    public ArrayList<Double> xs = new ArrayList<Double>(Arrays.asList(0.0, 4.0, -2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0));
    /** y-coordinates of the true component means; entry i is read by truth() for cluster i. */
    @Option
    public ArrayList<Double> ys = new ArrayList<Double>(Arrays.asList(0.0, 4.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0));
    /** True mixing proportions. The length of this list is what nClusters() returns. */
    @Option
    public ArrayList<Double> pis = new ArrayList<Double>(Arrays.asList(0.75, 0.2, 0.05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0));
    /** Random source used by init() to draw the starting component means. */
    @Option
    public Random emInitRandom = new Random(1L);
    /** Random source used by generateData() to sample cluster labels and points. */
    @Option
    public Random dataGenRandom = new Random(1L);
    /**
     * Number of points to generate (or the read cap checked in loadData()).
     * heldout() overwrites it with the number of points left in observedData.
     */
    @Option
    public int datasetSize = 10;
    /** Number of E-step/M-step rounds performed by run(). */
    @Option
    public int nEMIters = 30;
    /** When true, run() calls loadData() instead of sampling from the model built by truth(). */
    @Option
    public boolean useRealData = false;
    /** Parameter estimate currently held by run(); replaced after every M-step. */
    MixtureData currentParam;
    /** All relabelings of three clusters; run() minimises crossEntropy() over them. */
    private int[][] perms3 = new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}};
    /** All relabelings of two clusters; run() minimises crossEntropy() over them. */
    private int[][] perms2 = new int[][]{{0, 1}, {1, 0}};
    /** Training points (each a 2-element array) left after heldout() has run. */
    private List<double[]> observedData = new ArrayList<double[]>();
    /** Cluster label of each entry of observedData, kept index-aligned with it. */
    private List<Integer> trueIndicators = new ArrayList<Integer>();
    /** Points moved out of observedData by heldout(); scored by testLogLikelihood(). */
    private List<double[]> heldoutData = new ArrayList<double[]>();
    /** Cluster label of each entry of heldoutData. */
    private List<Integer> heldoutIndic = new ArrayList<Integer>();
    /** Per-point cluster posteriors produced by the latest eStep(); null until the first E-step. */
    private List<double[]> posteriors = null;
    /** Per-dimension minimum of observedData, filled by computeBounds(); +Infinity until then. */
    double[] mins = new double[]{Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY};
    /** Per-dimension maximum of observedData, filled by computeBounds(); -Infinity until then. */
    double[] maxs = new double[]{Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY};

    /**
     * Number of mixture components, defined as the length of the {@code pis} option list.
     *
     * @return {@code pis.size()}
     */
    private int nClusters() {
        final int componentCount = this.pis.size();
        return componentCount;
    }

    /**
     * Program entry point: builds a fresh {@code EM} instance and hands it, together with
     * the raw command-line arguments, to {@code IO.run}.
     *
     * @param args command-line arguments, forwarded unchanged
     */
    public static void main(String[] args) {
        EM experiment = new EM();
        IO.run(args, experiment);
    }

    /**
     * Builds the generating ("true") mixture from the option lists: component {@code i}
     * has weight {@code pis.get(i)} and a Gaussian created from the mean
     * {@code (xs.get(i), ys.get(i))} and the constant {@code 1.0}.
     *
     * @return a new mixture with {@code nClusters()} components
     * @throws IllegalStateException if {@code xs} or {@code ys} has fewer entries than {@code pis}
     */
    private MixtureData truth() {
        int k = this.nClusters();
        // Fail with a readable message instead of an IndexOutOfBoundsException mid-loop.
        if (this.xs.size() < k || this.ys.size() < k) {
            throw new IllegalStateException("xs (" + this.xs.size() + ") and ys (" + this.ys.size()
                    + ") must each have at least as many entries as pis (" + k + ")");
        }
        MixtureData mx = new MixtureData();
        // Plain field write; the enclosing class may access the inner class's private field,
        // so the decompiler-generated accessor is not needed.
        mx.prs = new double[k];
        for (int i = 0; i < k; ++i) {
            mx.prs[i] = this.pis.get(i);
            mx.gaussians.add(new DiagMultGaussian(new double[]{this.xs.get(i), this.ys.get(i)}, 1.0));
        }
        return mx;
    }

    /**
     * Builds the starting point for EM: uniform weights {@code 1 / nClusters()} and, for each
     * component, a Gaussian whose two mean coordinates are drawn with
     * {@code emInitRandom.nextGaussian()}.
     *
     * @return a new mixture with {@code nClusters()} components
     */
    private MixtureData init() {
        int k = this.nClusters();
        MixtureData mx = new MixtureData();
        mx.prs = new double[k];
        // Iterate over the component count, not xs.size(): prs has exactly k slots, so a
        // longer xs list would overflow it and a shorter one would leave components without
        // a Gaussian.
        for (int i = 0; i < k; ++i) {
            mx.prs[i] = 1.0 / (double) k;
            mx.gaussians.add(new DiagMultGaussian(new double[]{this.emInitRandom.nextGaussian(), this.emInitRandom.nextGaussian()}, 1.0));
        }
        return mx;
    }

    /**
     * Runs the experiment end to end:
     * <ol>
     *   <li>obtains data, either via {@code loadData()} or by sampling from {@code truth()};</li>
     *   <li>plots the data and the initial model;</li>
     *   <li>performs {@code nEMIters} rounds of E-step / M-step, recording per round the
     *       training log-likelihood (expected complete log-likelihood plus entropy), the
     *       held-out log-likelihood and the best cross-entropy over label permutations;</li>
     *   <li>saves the curves as PDFs and the final train/test values to {@code ll.csv}.</li>
     * </ol>
     */
    @Override
    public void run() {
        if (this.useRealData) {
            this.loadData();
        } else {
            MixtureData truth = this.truth();
            LogInfo.logsForce("trueParam = " + truth);
            // Generate before plotting: generateData() ends with computeBounds(), which
            // fills mins/maxs. MixtureData.plot() uses them as axis limits, and until then
            // they are still +/-Infinity.
            this.generateData(truth);
            truth.plot("trueModel");
        }
        this.plotData();
        this.currentParam = this.init();
        LogInfo.logsForce("initParam = " + this.currentParam);
        this.currentParam.plot("initModel");
        ArrayList<Double> logLikelihoods = new ArrayList<Double>();
        ArrayList<Double> expectedCompleteLogLikelihoods = new ArrayList<Double>();
        ArrayList<Double> entropies = new ArrayList<Double>();
        ArrayList<Double> testLL = new ArrayList<Double>();
        ArrayList<Double> crossEntropies = new ArrayList<Double>();
        LogInfo.track("Running EM");
        for (int emIter = 0; emIter < this.nEMIters; ++emIter) {
            this.eStep(this.currentParam);
            crossEntropies.add(this.minCrossEntropyOverPermutations());
            double entropy = this.entropy();
            double expectedCompleteLogLikelihood = this.expectedCompleteLogLikelihood(this.currentParam);
            double logLikelihood = expectedCompleteLogLikelihood + entropy;
            logLikelihoods.add(logLikelihood);
            expectedCompleteLogLikelihoods.add(expectedCompleteLogLikelihood);
            entropies.add(entropy);
            LogInfo.logsForce("logLikelihood = " + expectedCompleteLogLikelihood + " + " + entropy + " = " + logLikelihood);
            double testLikelihood = this.testLogLikelihood(this.currentParam);
            testLL.add(testLikelihood);
            this.currentParam = this.mStep();
            LogInfo.logsForce("paramAfterEM-" + emIter + " = " + this.currentParam);
            this.currentParam.plot("emModel-" + emIter);
        }
        LogInfo.end_track();
        Plot2D plot = new Plot2D();
        plot.addTimeSeries(logLikelihoods, "LL");
        plot.addTimeSeries(expectedCompleteLogLikelihoods, "expLL");
        plot.addTimeSeries(entropies, "entro");
        plot.addTimeSeries(testLL, "testLL");
        plot.savePlot(new File(Execution.getFile("logLikelihoods.pdf")));
        Plot2D plot2 = new Plot2D();
        plot2.addTimeSeries(crossEntropies);
        plot2.savePlot(new File(Execution.getFile("crossEntropies.pdf")));
        File csvFile = new File(Execution.getFile("ll.csv"));
        PrintWriter out = IOUtils.openOutHard(csvFile);
        try {
            out.println(CSV.header("nClusters", "train", "test"));
            out.println(CSV.body(this.nClusters(), EM.lastOrNaN(logLikelihoods), EM.lastOrNaN(testLL)));
        } finally {
            // Release the file handle even if formatting or writing a row throws.
            out.close();
        }
    }

    /**
     * Smallest {@code crossEntropy(perm)} over all label permutations, so the score does not
     * depend on how EM happens to number the clusters. Permutation tables exist only for two
     * and three clusters; for any other count the result is {@code +Infinity}.
     */
    private double minCrossEntropyOverPermutations() {
        int k = this.nClusters();
        if (k != 2 && k != 3) {
            return Double.POSITIVE_INFINITY;
        }
        double best = Double.POSITIVE_INFINITY;
        for (int[] perm : k == 2 ? this.perms2 : this.perms3) {
            double cur = this.crossEntropy(perm);
            if (cur < best) {
                best = cur;
            }
        }
        return best;
    }

    /** Last element of {@code values}, or NaN when the list is empty (e.g. zero EM iterations). */
    private static Double lastOrNaN(List<Double> values) {
        return values.isEmpty() ? Double.valueOf(Double.NaN) : values.get(values.size() - 1);
    }

    /**
     * Log-likelihood of the held-out points under {@code currentParam}:
     * the sum over points of {@code log( sum_c prs[c] * exp(logProb_c(x)) )}.
     *
     * <p>The inner sum is evaluated with the log-sum-exp trick (subtract the largest log term
     * before exponentiating). The direct form underflows to {@code log(0) = -Infinity} as soon
     * as a point is far from every component, even though its true log-density is finite.
     *
     * @param currentParam mixture to score the held-out data with
     * @return total held-out log-likelihood; {@code 0.0} when there is no held-out data
     */
    private double testLogLikelihood(MixtureData currentParam) {
        int k = this.nClusters();
        double[] logJoint = new double[k];
        double sum = 0.0;
        for (double[] testDatum : this.heldoutData) {
            double max = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < k; ++c) {
                // log(0) = -Infinity for a zero-weight component, which drops out below.
                logJoint[c] = Math.log(currentParam.prs[c]) + currentParam.gaussians.get(c).logProb(testDatum);
                // Math.max propagates NaN, so a NaN term still surfaces in the result.
                max = Math.max(max, logJoint[c]);
            }
            if (max == Double.NEGATIVE_INFINITY) {
                // Every component assigns zero mass; avoid (-Inf) - (-Inf) = NaN below.
                sum += Double.NEGATIVE_INFINITY;
                continue;
            }
            double scaled = 0.0;
            for (int c = 0; c < k; ++c) {
                scaled += Math.exp(logJoint[c] - max);
            }
            sum += max + Math.log(scaled);
        }
        return sum;
    }

    /**
     * Average negative log posterior assigned to each training point's true cluster, after
     * mapping true labels to model labels through {@code perm}.
     *
     * @param perm {@code perm[t]} is the posterior index that true label {@code t} is matched to
     * @return {@code -(1/datasetSize) * sum_i log posteriors[i][perm[trueIndicators[i]]]}
     */
    private double crossEntropy(int[] perm) {
        final int n = this.datasetSize;
        double logTotal = 0.0;
        for (int idx = 0; idx < n; ++idx) {
            double[] posterior = this.posteriors.get(idx);
            int trueLabel = this.trueIndicators.get(idx);
            int matchedLabel = perm[trueLabel];
            logTotal += Math.log(posterior[matchedLabel]);
        }
        return -logTotal / (double) n;
    }

    /**
     * Expected complete-data log-likelihood under the current posteriors:
     * {@code sum_i sum_c q_i(c) * ( log prs[c] + logProb_c(x_i) )}.
     *
     * <p>The {@code log prs[c]} term is part of the complete-data likelihood of a mixture.
     * With it, and with the posteriors produced by {@link #eStep}, this value plus
     * {@link #entropy()} equals the training log-likelihood
     * {@code sum_i log sum_c prs[c] * p_c(x_i)}, the same quantity
     * {@code testLogLikelihood} computes for held-out data.
     *
     * @param param mixture whose weights and Gaussians are evaluated
     * @return the expected complete-data log-likelihood of the training points
     */
    private double expectedCompleteLogLikelihood(MixtureData param) {
        double sum = 0.0;
        for (int i = 0; i < this.datasetSize; ++i) {
            double[] post = this.posteriors.get(i);
            double[] obs = this.observedData.get(i);
            for (int c = 0; c < this.nClusters(); ++c) {
                if (post[c] == 0.0) {
                    // 0 * log(.) is taken as 0; evaluating it could give 0 * -Infinity = NaN.
                    continue;
                }
                sum += post[c] * (Math.log(param.prs[c]) + param.gaussians.get(c).logProb(obs));
            }
        }
        return sum;
    }

    /**
     * Total Shannon entropy of the per-point posteriors:
     * {@code -sum_i sum_c q_i(c) * log q_i(c)}, using the convention {@code 0 * log 0 = 0}.
     *
     * @return the summed posterior entropy over the training points
     */
    private double entropy() {
        double sum = 0.0;
        for (int i = 0; i < this.datasetSize; ++i) {
            double[] post = this.posteriors.get(i);
            for (int c = 0; c < this.nClusters(); ++c) {
                double p = post[c];
                if (p == 0.0) {
                    // A posterior that underflowed to exactly 0 would give 0 * -Infinity = NaN.
                    continue;
                }
                sum += p * Math.log(p);
            }
        }
        return -sum;
    }

    /**
     * M-step: re-estimates the mixture from the current posteriors. Each component's mean
     * becomes the posterior-weighted average of the training points, and the weights become
     * the normalised posterior mass of each component.
     *
     * @return a new mixture; the instance in {@code currentParam} is not modified
     */
    private MixtureData mStep() {
        int k = this.nClusters();
        double[][] weightedSums = new double[k][2];
        double[] mass = new double[k];
        for (int i = 0; i < this.datasetSize; ++i) {
            double[] cPost = this.posteriors.get(i);
            double[] cObs = this.observedData.get(i);
            for (int c = 0; c < k; ++c) {
                for (int d = 0; d < 2; ++d) {
                    weightedSums[c][d] += cPost[c] * cObs[d];
                }
                mass[c] += cPost[c];
            }
        }
        // NOTE(review): a component whose total posterior mass is exactly 0 gets a 0/0 = NaN
        // mean here; that did not look reachable with the default options but is not guarded.
        for (int c = 0; c < k; ++c) {
            for (int d = 0; d < 2; ++d) {
                weightedSums[c][d] = weightedSums[c][d] / mass[c];
            }
        }
        // Turn the per-component mass into mixing weights (the array is reused as prs below).
        NumUtils.normalize(mass);
        MixtureData result = new MixtureData();
        result.prs = mass;
        for (int c = 0; c < k; ++c) {
            result.gaussians.add(new DiagMultGaussian(weightedSums[c], 1.0));
        }
        return result;
    }

    /**
     * E-step: recomputes {@code posteriors} for every training point as
     * {@code q_i(c) proportional to prs[c] * p_c(x_i)}.
     *
     * <p>The mixing weight has to enter the posterior. Without it the weights estimated in
     * {@link #mStep} never influence later iterations, and the training log-likelihood
     * reported by {@code run()} does not correspond to the weighted model that
     * {@code testLogLikelihood} scores.
     *
     * @param currentParam mixture used to compute the responsibilities
     */
    private void eStep(MixtureData currentParam) {
        this.posteriors = new ArrayList<double[]>();
        for (int i = 0; i < this.observedData.size(); ++i) {
            double[] datum = this.observedData.get(i);
            double[] prs = new double[this.nClusters()];
            for (int c = 0; c < this.nClusters(); ++c) {
                // Unnormalised log posterior: log prior weight + log density.
                prs[c] = Math.log(currentParam.prs[c]) + currentParam.gaussians.get(c).logProb(datum);
            }
            // Presumably exponentiates and normalises the log scores in place; the array is
            // stored as the posterior straight afterwards.
            NumUtils.expNormalize(prs);
            this.posteriors.add(prs);
        }
    }

    /**
     * Writes {@code data.pdf}: a 2-D plot with one series per cluster index, each holding
     * the training points whose true label equals that index.
     */
    private void plotData() {
        Plot2D scatter = new Plot2D();
        final int k = this.nClusters();
        for (int cluster = 0; cluster < k; ++cluster) {
            ArrayList<Pair<Double, Double>> members = new ArrayList<Pair<Double, Double>>();
            for (int idx = 0; idx < this.observedData.size(); ++idx) {
                if (this.trueIndicators.get(idx) == cluster) {
                    double px = this.observedData.get(idx)[0];
                    double py = this.observedData.get(idx)[1];
                    members.add(Pair.makePair(px, py));
                }
            }
            String label = "cluster-" + cluster;
            scatter.addSeries(members, false, label);
        }
        File target = new File(Execution.getFile("data.pdf"));
        scatter.savePlot(target);
    }

    /**
     * Loads up to {@code datasetSize} records from {@code geyser.txt}, then splits off the
     * held-out set and computes the plotting bounds.
     *
     * <p>Each non-blank line is split on whitespace: field 0 is a 1-based cluster label
     * (stored 0-based), field 1 is multiplied by 2 and field 2 is divided by 20 to form the
     * two coordinates of the data point.
     *
     * @throws NumberFormatException if a field cannot be parsed as a number
     * @throws ArrayIndexOutOfBoundsException if a non-blank line has fewer than three fields
     */
    private void loadData() {
        int read = 0;
        for (String line : IO.i(new File("geyser.txt"))) {
            if (read >= this.datasetSize) {
                break;
            }
            // Trim first: a leading blank would otherwise make split() yield an empty
            // first field, and a blank line would fail to parse.
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            String[] fields = trimmed.split("\\s+");
            this.trueIndicators.add(Integer.parseInt(fields[0]) - 1);
            double[] datum = new double[]{2.0 * Double.parseDouble(fields[1]), Double.parseDouble(fields[2]) / 20.0};
            this.observedData.add(datum);
            // Count accepted records so the datasetSize cap above can actually trigger.
            ++read;
        }
        this.datasetSize = this.observedData.size();
        this.heldout();
        this.computeBounds();
    }

    /**
     * Moves the trailing 20% (rounded down) of the data set, points and labels alike, from
     * the training lists to the held-out lists, taking the last element each time. Afterwards
     * {@code datasetSize} is reset to the number of remaining training points.
     */
    private void heldout() {
        final int quota = (int) (0.2 * (double) this.datasetSize);
        for (int remaining = quota; remaining > 0; --remaining) {
            int lastPoint = this.observedData.size() - 1;
            double[] movedPoint = this.observedData.remove(lastPoint);
            this.heldoutData.add(movedPoint);

            int lastLabel = this.trueIndicators.size() - 1;
            // int argument selects remove(int index), not remove(Object).
            Integer movedLabel = this.trueIndicators.remove(lastLabel);
            this.heldoutIndic.add(movedLabel);
        }
        this.datasetSize = this.observedData.size();
    }

    /**
     * Widens {@code mins} / {@code maxs} so that they enclose, per dimension, the first
     * {@code datasetSize} training points. Existing bounds are only ever extended.
     */
    private void computeBounds() {
        final int dims = 2;
        for (int row = 0; row < this.datasetSize; ++row) {
            for (int dim = 0; dim < dims; ++dim) {
                double value = this.observedData.get(row)[dim];
                // Two independent checks: starting from +/-Infinity, the first point
                // updates both bounds.
                if (value > this.maxs[dim]) {
                    this.maxs[dim] = value;
                }
                if (value < this.mins[dim]) {
                    this.mins[dim] = value;
                }
            }
        }
    }

    /**
     * Samples {@code datasetSize} labelled points from {@code truth}: first a cluster index
     * from the mixture weights, then a point from that cluster's Gaussian, both drawn with
     * {@code dataGenRandom}. Afterwards splits off the held-out set and computes the
     * plotting bounds.
     *
     * @param truth mixture to sample from
     */
    private void generateData(MixtureData truth) {
        int produced = 0;
        while (produced < this.datasetSize) {
            int label = SampleUtils.sampleMultinomial(this.dataGenRandom, truth.prs);
            DiagMultGaussian component = (DiagMultGaussian) truth.gaussians.get(label);
            double[] point = component.sample(this.dataGenRandom);
            this.observedData.add(point);
            this.trueIndicators.add(label);
            ++produced;
        }
        this.heldout();
        this.computeBounds();
    }

    /**
     * Parameters of a two-dimensional Gaussian mixture: one weight in {@code prs} and one
     * Gaussian in {@code gaussians} per component, index-aligned. It is an inner class
     * because {@link #plot} reads the enclosing {@code EM}'s {@code mins} / {@code maxs}.
     */
    public class MixtureData {
        private List<DiagMultGaussian> gaussians = new ArrayList<DiagMultGaussian>();
        private double[] prs;

        /**
         * Renders every component as {@code "<weight> * N(<gaussian>) "}, concatenated in
         * component order.
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            for (int comp = 0; comp < this.prs.length; ++comp) {
                text.append(this.prs[comp])
                    .append(" * N(")
                    .append(this.gaussians.get(comp))
                    .append(") ");
            }
            return text.toString();
        }

        /**
         * Saves a 3-D plot of the mixture's log-density to {@code <prefix>.pdf}, using the
         * enclosing instance's {@code mins} / {@code maxs} as the x and y axis limits.
         *
         * @param prefix output file name without the {@code .pdf} extension
         */
        public void plot(String prefix) {
            Function logDensity = new Function(){

                /** Log of the weighted sum of the component densities at {@code x}. */
                @Override
                public double valueAt(double[] x) {
                    double density = 0.0;
                    final int k = MixtureData.this.prs.length;
                    for (int comp = 0; comp < k; ++comp) {
                        double componentLogProb = MixtureData.this.gaussians.get(comp).logProb(x);
                        density += MixtureData.this.prs[comp] * Math.exp(componentLogProb);
                    }
                    return Math.log(density);
                }

                @Override
                public int dimension() {
                    return 2;
                }
            };
            Plot3D surface = new Plot3D(logDensity);
            surface.setMin_x(EM.this.mins[0]);
            surface.setMax_x(EM.this.maxs[0]);
            surface.setMin_y(EM.this.mins[1]);
            surface.setMax_y(EM.this.maxs[1]);
            File target = new File(Execution.getFile(prefix + ".pdf"));
            surface.savePlot(target);
        }

        // Synthetic setter for prs as emitted by the decompiler; methods of the enclosing
        // class call it by this exact name, so it is kept unchanged.
        static /* synthetic */ double[] access$002(MixtureData x0, double[] x1) {
            x0.prs = x1;
            return x1;
        }
    }
}

