/*
 * Decompiled with CFR 0.152.
 */
package gep.comparisons;

import Jama.Matrix;
import fig.basic.Option;
import gep.GEPMain;
import gep.comparisons.IncrementalReconstructionMethod;
import gep.timeseries.Series;
import java.util.List;
import java.util.Random;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.Graphs;
import nuts.math.RateMtxUtils;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.util.All2OneMap;
import nuts.util.CounterMap;
import nuts.util.MathUtils;
import pty.learn.CTMCExpectations;
import pty.learn.Estimators;

/**
 * Iteratively re-estimates a rate matrix (and its stationary distribution) from a
 * list of {@link Series}: each call to {@link #iterate(Random, int)} computes chain
 * posteriors with {@link TreeSumProd}, accumulates sufficient statistics from
 * {@link CTMCExpectations}, and refits the rate matrix with
 * {@link Estimators#getGeneralRateMatrixMLE}.
 *
 * <p>Expected call order: {@link #loadData(List, int)}, then {@link #init(Random)},
 * then {@link #iterate(Random, int)} one or more times, then
 * {@link #reconstruct(int, int)}. The hidden state space has
 * {@code nChars * nSplits} states.
 */
public class EMForCTMC
implements IncrementalReconstructionMethod {
    /** Number of hidden states allotted to each observable character. */
    @Option
    public int nSplits = 1;
    private List<Series> data;
    private int nChars;
    private double[][] currentRateParam;
    private double[] currentStatParam;
    /** Transition matrix cached for interval length {@link #tpr}; null when invalid. */
    private double[][] _cachedPr = null;
    private double tpr = -1.0;
    /** Expectations cached for interval length {@link #texp}; null when invalid. */
    private double[][][][] _cachedExp = null;
    private double texp = -1.0;

    /**
     * Draws a random symmetric off-diagonal rate matrix, fills its diagonal via
     * {@link RateMtxUtils#fillRateMatrixDiagonalEntries}, and derives the stationary
     * distribution from it. Requires {@link #loadData(List, int)} to have been called,
     * since the matrix size depends on {@code nChars}.
     *
     * @param samplingRand source of the random off-diagonal rates
     */
    @Override
    public void init(Random samplingRand) {
        int n = this.nHidden();
        this.currentRateParam = new double[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = i + 1; j < n; ++j) {
                double d = samplingRand.nextDouble();
                this.currentRateParam[j][i] = d;
                this.currentRateParam[i][j] = d;
            }
        }
        RateMtxUtils.fillRateMatrixDiagonalEntries(this.currentRateParam);
        this.currentStatParam = RateMtxUtils.getStationaryDistribution(this.currentRateParam);
        // The parameters changed, so anything cached for the old rate matrix is stale.
        this._cachedPr = null;
        this._cachedExp = null;
    }

    /**
     * Runs one estimation pass over all loaded series: computes posteriors under the
     * current parameters, refreshes each series' {@code queryStats}, logs the summed
     * log normalizer, and refits the rate matrix and stationary distribution.
     *
     * @param samplingRand unused by this implementation
     * @param iterIndex    iteration number, used only for logging
     */
    @Override
    public void iterate(Random samplingRand, int iterIndex) {
        // NOTE(review): the accumulator starts from Jama's Matrix.random (not from
        // samplingRand) rather than zeros; kept as-is — confirm whether this is meant
        // as smoothing.
        Matrix suffStats = Matrix.random(this.nHidden(), this.nHidden()).times(1.0);
        double logll = 0.0;
        for (int sIdx = 0; sIdx < this.data.size(); ++sIdx) {
            Series curSeries = this.data.get(sIdx);
            GMFct<Integer> potentials = this.createPotentials(curSeries);
            TreeSumProd<Integer> tsp = new TreeSumProd<Integer>(potentials);
            TabularGMFct<Integer> post = tsp.moments();
            this.addSuffStat(suffStats, curSeries, post);
            this.updateReconstructions(curSeries, post);
            logll += tsp.logZ();
        }
        GEPMain.outputManager.printWrite("logll", "Iter", iterIndex, "LogLL", logll);
        this.currentRateParam = Estimators.getGeneralRateMatrixMLE(suffStats, 100.0);
        this.currentStatParam = RateMtxUtils.getStationaryDistribution(this.currentRateParam);
        // New parameters invalidate both per-interval caches.
        this._cachedPr = null;
        this._cachedExp = null;
    }

    /**
     * Replaces {@code curSeries.queryStats} with the node posteriors of every
     * unobserved time point. Query indices are assigned in time order, counting only
     * unobserved points.
     */
    private void updateReconstructions(Series curSeries, GMFct<Integer> post) {
        List<Double> allTimes = curSeries.allTimes();
        curSeries.queryStats = new CounterMap();
        int index = 0;
        for (int i = 0; i < allTimes.size(); ++i) {
            if (curSeries.isObserved(allTimes.get(i))) continue;
            for (int hidIdx = 0; hidIdx < this.nHidden(); ++hidIdx) {
                curSeries.queryStats.setCount(index, hidIdx, post.get(i, hidIdx));
            }
            ++index;
        }
    }

    /**
     * Builds chain-graph potentials for one series: the stationary distribution on the
     * first node, a 0/1 mask on every observed node (only hidden states mapping to the
     * observed character keep their weight), and the transition matrix for each
     * consecutive pair of time points.
     */
    private GMFct<Integer> createPotentials(Series series) {
        int i;
        List<Double> allTimes = series.allTimes();
        TabularGMFct<Integer> result = GMFctUtils.ones(new TabularGMFct<Integer>(Graphs.chainGraph(allTimes.size()), new All2OneMap(this.nHidden())));
        // Prior on the first node. This must be written BEFORE the observation mask:
        // writing it afterwards would overwrite the zeros and discard an observation
        // at the first time point.
        for (int hidIdx = 0; hidIdx < this.nHidden(); ++hidIdx) {
            result.set(0, hidIdx, this.currentStatParam[hidIdx]);
        }
        // Observation mask: zero out hidden states inconsistent with the observation.
        for (i = 0; i < allTimes.size(); ++i) {
            if (!series.isObserved(allTimes.get(i))) continue;
            int observedIndex = series.observationAt(allTimes.get(i));
            for (int hidIdx = 0; hidIdx < this.nHidden(); ++hidIdx) {
                if (this.hidden2char(hidIdx) == observedIndex) continue;
                result.set(i, hidIdx, 0.0);
            }
        }
        // Pairwise potentials: transition probabilities over each time gap.
        for (i = 0; i < allTimes.size() - 1; ++i) {
            double T = allTimes.get(i + 1) - allTimes.get(i);
            double[][] prs = this.getPr(T);
            for (int hidIdx = 0; hidIdx < this.nHidden(); ++hidIdx) {
                for (int hidIdx2 = 0; hidIdx2 < this.nHidden(); ++hidIdx2) {
                    result.set(i, i + 1, hidIdx, hidIdx2, prs[hidIdx][hidIdx2]);
                }
            }
        }
        return result;
    }

    /**
     * Adds to {@code suffStats}, for every consecutive pair of time points, the
     * end-point-conditioned expectations weighted by the pairwise posterior of those
     * end points.
     */
    private void addSuffStat(Matrix suffStats, Series curSeries, GMFct<Integer> post) {
        List<Double> allTimes = curSeries.allTimes();
        for (int i = 0; i < post.graph().vertexSet().size() - 1; ++i) {
            double T = allTimes.get(i + 1) - allTimes.get(i);
            double[][][][] expss = this.getExp(T);
            for (int x0 = 0; x0 < this.nHidden(); ++x0) {
                for (int xT = 0; xT < this.nHidden(); ++xT) {
                    double curPost = post.get(i, i + 1, x0, xT);
                    double[][] currentStat = expss[x0][xT];
                    // Jama's Matrix(double[][]) wraps the array without copying, so
                    // use the non-mutating times() — timesEquals() would scale the
                    // cached expectations in place and corrupt later cache hits.
                    suffStats.plusEquals(new Matrix(currentStat).times(curPost));
                }
            }
        }
    }

    /**
     * Stores the series to fit and the number of observable characters.
     *
     * @param data   series to process in {@link #iterate(Random, int)}
     * @param nChars number of observable characters
     */
    @Override
    public void loadData(List<Series> data, int nChars) {
        this.data = data;
        this.nChars = nChars;
    }

    /**
     * Returns the arg-max entry of the stored posterior counter for query {@code qIdx}
     * of series {@code sIdx}. Only meaningful after at least one call to
     * {@link #iterate(Random, int)}, which populates {@code queryStats}.
     */
    @Override
    public int reconstruct(int sIdx, int qIdx) {
        return this.data.get(sIdx).queryStats.getCounter(qIdx).argMax();
    }

    /** Returns the transition matrix for interval length {@code t}, cached for the last {@code t}. */
    private double[][] getPr(double t) {
        if (MathUtils.close(t, this.tpr) && this._cachedPr != null) {
            return this._cachedPr;
        }
        this._cachedPr = RateMtxUtils.marginalTransitionMtx(this.currentRateParam, t);
        this.tpr = t;
        return this._cachedPr;
    }

    /** Returns the expectations for interval length {@code t}, cached for the last {@code t}. */
    private double[][][][] getExp(double t) {
        if (MathUtils.close(t, this.texp) && this._cachedExp != null) {
            return this._cachedExp;
        }
        this._cachedExp = CTMCExpectations.expectations(t, this.currentRateParam);
        // Record the key of THIS cache (texp); writing tpr here would leave this cache
        // permanently cold and desynchronize the transition-matrix cache.
        this.texp = t;
        return this._cachedExp;
    }

    /** Size of the hidden state space: {@code nChars * nSplits}. */
    private int nHidden() {
        return this.nChars * this.nSplits;
    }

    /** Maps a hidden state index to its character index ({@code hidden / nSplits}). */
    private int hidden2char(int hidden) {
        return hidden / this.nSplits;
    }
}

