/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml;

import java.util.Arrays;
import nuts.math.RateMtxUtils;
import nuts.util.MathUtils;
import pty.learn.CTMCExpectations;

public class RateMtxExpectations {
    /**
     * Computes the expectation matrix of {@link #_expectations} for every (state1, state2) pair
     * and packs all of them into a single 4-D array.
     *
     * <p>The marginal transition matrix for time {@code T} is computed once and shared. A single
     * 2n x 2n scratch buffer is reused across all n^2 calls.
     *
     * @param rateMtx square n x n rate matrix
     * @param T       time horizon
     * @return array indexed {@code result[i][j][state1][state2]}, where {@code [i][j]} indexes the
     *         matrix returned by {@link #_expectations} for that (state1, state2) pair
     */
    public static double[][][][] expectations(double[][] rateMtx, double T) {
        int n = rateMtx.length;
        double[][] simpleExp = RateMtxUtils.marginalTransitionMtx(rateMtx, T);
        double[][][][] result = new double[n][n][n][n];
        // One scratch buffer for all n^2 calls; _expectations re-initialises it on every call.
        double[][] emptyMtx = new double[2 * n][2 * n];
        for (int state1 = 0; state1 < n; ++state1) {
            for (int state2 = 0; state2 < n; ++state2) {
                double[][] current = RateMtxExpectations._expectations(rateMtx, T, state1, state2, simpleExp, emptyMtx);
                for (int i = 0; i < n; ++i) {
                    for (int j = 0; j < n; ++j) {
                        result[i][j][state1][state2] = current[i][j];
                    }
                }
            }
        }
        return result;
    }

    /**
     * Weights each expectation matrix of {@link #_expectations} by {@code marginalCounts} and sums
     * the weighted entries.
     *
     * @param marginalCounts n x n weights indexed like the matrices returned by
     *                       {@link #_expectations}; presumably counts of observed (start, end)
     *                       state pairs over an interval of length {@code T} (TODO confirm with
     *                       callers)
     * @param rateMtx        square n x n rate matrix
     * @param T              time horizon
     * @return n x n matrix with
     *         {@code result[state1][state2] = sum over i, j of E[i][j] * marginalCounts[i][j]},
     *         where {@code E} is {@link #_expectations} evaluated for (state1, state2)
     */
    public static double[][] expectations(double[][] marginalCounts, double[][] rateMtx, double T) {
        int dim = rateMtx.length;
        double[][] result = new double[dim][dim];
        // Shared scratch buffer; _expectations re-initialises it on every call.
        double[][] auxMtx = new double[2 * dim][2 * dim];
        double[][] simpleExp = RateMtxUtils.marginalTransitionMtx(rateMtx, T);
        for (int state1 = 0; state1 < dim; ++state1) {
            for (int state2 = 0; state2 < dim; ++state2) {
                double[][] current = RateMtxExpectations._expectations(rateMtx, T, state1, state2, simpleExp, auxMtx);
                double sum = 0.0;
                for (int i = 0; i < dim; ++i) {
                    for (int j = 0; j < dim; ++j) {
                        sum += current[i][j] * marginalCounts[i][j];
                    }
                }
                result[state1][state2] = sum;
            }
        }
        return result;
    }

    /**
     * Shorthand for {@link #_expectations} with {@code state1 == state2 == state}.
     *
     * <p>In that case the scale factor is 1, so entry (i, j) of the result is the top-right block
     * of the auxiliary exponential divided by {@code matrixExponential[i][j]}. This is presumably
     * the expected time spent in {@code state}, given the chain starts in i and ends in j.
     *
     * @param rateMtx           square n x n rate matrix
     * @param T                 time horizon
     * @param state             state index in [0, n)
     * @param matrixExponential n x n marginal transition matrix for time {@code T}
     * @param emptyMtx          scratch buffer of at least 2n x 2n; overwritten
     * @return n x n matrix described above
     */
    public static double[][] _expectedWaitingTimes(double[][] rateMtx, double T, int state, double[][] matrixExponential, double[][] emptyMtx) {
        return RateMtxExpectations._expectations(rateMtx, T, state, state, matrixExponential, emptyMtx);
    }

    /**
     * Computes one expectation matrix using a block-matrix exponential.
     *
     * <p>Let Q = {@code rateMtx} and let E be the n x n matrix holding a single 1 at
     * (state1, state2). The auxiliary 2n x 2n matrix T * [[Q, E], [0, Q]] is exponentiated, and
     * its top-right n x n block is taken. Entry (i, j) of that block is divided by
     * {@code matrixExponential[i][j]}. It is then multiplied by {@code rateMtx[state1][state2]}
     * when state1 != state2, and by 1 otherwise.
     *
     * <p>This assumes {@code RateMtxUtils.marginalTransitionMtx(A, t)} computes exp(A t) and that
     * {@code matrixExponential} is exp(Q T); TODO confirm. Under those assumptions the top-right
     * block is the integral over s in [0, T] of exp(Q s) E exp(Q (T - s)) (cf. Van Loan 1978).
     * The result is then presumably the expected number of state1 -> state2 jumps
     * (state1 != state2), or the expected time spent in state1 (state1 == state2), conditioned on
     * starting in i and ending in j.
     *
     * <p>Entries where {@code matrixExponential[i][j] == 0} come out as NaN or infinity.
     *
     * @param rateMtx           square n x n rate matrix
     * @param T                 time horizon
     * @param state1            row index of the marked entry, in [0, n)
     * @param state2            column index of the marked entry, in [0, n)
     * @param matrixExponential n x n marginal transition matrix for time {@code T}
     * @param emptyMtx          scratch buffer with at least 2n rows, each of length at least 2n.
     *                          Every entry is overwritten, so it need not be zero on entry. On
     *                          return it holds T*Q in its two diagonal n x n blocks and zero
     *                          everywhere else.
     * @return freshly allocated n x n matrix described above
     * @throws IllegalArgumentException if the dimensions are inconsistent or a state index is out
     *                                  of range
     */
    public static double[][] _expectations(double[][] rateMtx, double T, int state1, int state2, double[][] matrixExponential, double[][] emptyMtx) {
        int n = rateMtx.length;
        for (int i = 0; i < n; ++i) {
            if (rateMtx[i].length != n) {
                throw new IllegalArgumentException("rateMtx must be square: row " + i + " has length " + rateMtx[i].length + ", expected " + n);
            }
        }
        if (matrixExponential.length != n) {
            throw new IllegalArgumentException("matrixExponential has " + matrixExponential.length + " rows, expected " + n);
        }
        if (state1 < 0 || state1 >= n || state2 < 0 || state2 >= n) {
            throw new IllegalArgumentException("states (" + state1 + ", " + state2 + ") out of range [0, " + n + ")");
        }
        int size = 2 * n;
        if (emptyMtx.length < size) {
            throw new IllegalArgumentException("emptyMtx has " + emptyMtx.length + " rows, need at least " + size);
        }
        double[][] aux = emptyMtx;
        // Reset the whole buffer: the construction needs the off-diagonal blocks to be zero,
        // and a reused or caller-supplied buffer may hold stale values.
        for (double[] row : aux) {
            if (row.length < size) {
                throw new IllegalArgumentException("emptyMtx row has length " + row.length + ", need at least " + size);
            }
            Arrays.fill(row, 0.0);
        }
        // Both diagonal blocks hold T * Q.
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double d = rateMtx[i][j] * T;
                aux[i][j] = d;
                aux[i + n][j + n] = d;
            }
        }
        // Top-right block holds T * E, the single-entry indicator of (state1, state2).
        aux[state1][state2 + n] = T;
        try {
            double[][] exponentiatedAux = RateMtxUtils.marginalTransitionMtx(aux, 1.0);
            double scale = state1 == state2 ? 1.0 : rateMtx[state1][state2];
            double[][] result = new double[n][n];
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    result[i][j] = exponentiatedAux[i][n + j] / matrixExponential[i][j] * scale;
                }
            }
            return result;
        } finally {
            // Always clear the marked entry so the shared buffer is left clean, even on failure.
            aux[state1][state2 + n] = 0.0;
        }
    }

    /**
     * Benchmark and consistency harness. For growing n it builds a uniform rate matrix and
     * computes the 4-D expectations with {@link #expectations(double[][], double)}. It compares
     * them entry by entry against {@code CTMCExpectations.expectations} and prints timings and any
     * entries that {@code MathUtils.close} does not consider equal.
     *
     * <p>NOTE: each iteration allocates two n^4 double arrays. That is about 800 MB each at
     * n = 100, so this needs a very large heap.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double T = 3.0;
        for (int n = 100; n < 1001; n += 10) {
            System.out.println("n = " + n);
            double[][] rateMtx = new double[n][n];
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    if (i == j) continue;
                    rateMtx[i][j] = 1.0 / (double)n;
                }
            }
            RateMtxUtils.fillRateMatrixDiagonalEntries(rateMtx);
            long time = System.currentTimeMillis();
            double[][][][] m2 = RateMtxExpectations.expectations(rateMtx, T);
            System.out.println("Time for EXP " + (System.currentTimeMillis() - time));
            time = System.currentTimeMillis();
            double[][][][] m1 = CTMCExpectations.expectations(T, rateMtx);
            System.out.println("Time for EVD " + (System.currentTimeMillis() - time));
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    for (int k = 0; k < n; ++k) {
                        for (int l = 0; l < n; ++l) {
                            if (MathUtils.close(m1[i][j][k][l], m2[i][j][k][l])) continue;
                            System.out.println("" + m1[i][j][k][l] + " vs " + m2[i][j][k][l]);
                        }
                    }
                }
            }
        }
        System.out.println("Done");
    }
}

