/*
 * Decompiled with CFR 0.152.
 */
package pty.learn;

import Jama.Matrix;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import nuts.math.MtxUtils;
import nuts.math.RateMtxUtils;
import nuts.math.Sampling;
import nuts.tui.Table;
import nuts.util.MathUtils;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.analysis.integration.TrapezoidIntegrator;
import pty.learn.Estimators;
import pty.smc.models.CachedEigenDecomp;

/**
 * Endpoint-conditioned expectations for finite-state continuous-time Markov chains (CTMCs),
 * computed from an eigendecomposition of the rate matrix, plus helpers to simulate trajectories,
 * summarize them, and a small EM driver ({@link #testEM()}) that exercises both.
 *
 * <p>A trajectory is represented as a list of {@code (state, holdingTime)} pairs in visiting order.
 */
public class CTMCExpectations {

    /**
     * Computes endpoint-conditioned expected sufficient statistics over an interval of length
     * {@code T}.
     *
     * <p>For start state {@code i} and end state {@code j}, the returned array holds:
     * <ul>
     *   <li>{@code [i][j][a][a]}: the expected time spent in {@code a}, i.e.
     *       {@code integral_0^T P_ia(t) P_aj(T-t) dt / P_ij(T)};</li>
     *   <li>{@code [i][j][a][b]}, {@code a != b}: the expected number of {@code a -> b} jumps, i.e.
     *       {@code q_ab * integral_0^T P_ia(t) P_bj(T-t) dt / P_ij(T)}.</li>
     * </ul>
     * A value that is {@code MathUtils.close} to a bound but falls just outside it (round-off) is
     * clamped to that bound before it is stored. Every entry is divided by {@code P_ij(T)}, so
     * endpoints with zero transition probability yield non-finite values or an exception.
     *
     * @param T  length of the time interval
     * @param ed cached eigendecomposition of the rate matrix; all eigenvalues must be real
     * @return array indexed {@code [start][end][a][b]}
     * @throws RuntimeException if an eigenvalue is complex, or if an expectation lies outside its
     *         valid range (negative, or a holding time above {@code T}) beyond round-off
     */
    public static double[][][][] expectations(double T, CachedEigenDecomp ed) {
        CTMCExpectations.checkRealEigenvalues(ed);
        final int size = ed.getV().getRowDimension();
        final double[][] PT = MtxUtils.exp(ed.getV(), ed.getVinv(), ed.getD().times(T)).getArray();
        final double[][] J = CTMCExpectations.computeJ(T, ed);
        final double[][] u = ed.getV().getArray();
        final double[][] ui = ed.getV().inverse().getArray();
        // Rate matrix reconstructed from the decomposition as V * D * V^{-1}.
        final double[][] q = ed.getV().times(ed.getD()).times(new Matrix(ui)).getArray();
        final double[][][][] result = new double[size][size][size][size];
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                for (int a = 0; a < size; ++a) {
                    for (int b = 0; b < size; ++b) {
                        final double analyticConvolution = CTMCExpectations.convolution(T, u, ui, i, a, b, j, J);
                        double current;
                        if (a != b) {
                            current = q[a][b] * analyticConvolution / PT[i][j];
                            if (MathUtils.close(0.0, current) && current < 0.0) {
                                current = 0.0;
                            }
                            if (current < 0.0) {
                                throw new RuntimeException("Negative expected number of transitions " + a + "->" + b
                                        + " (" + current + ") for endpoints " + i + "," + j);
                            }
                        } else {
                            current = analyticConvolution / PT[i][j];
                            if (MathUtils.close(T, current) && current > T) {
                                current = T;
                            }
                            if (MathUtils.close(0.0, current) && current < 0.0) {
                                current = 0.0;
                            }
                            if (current < 0.0 || current > T) {
                                throw new RuntimeException("Expected holding time in state " + a + " (" + current
                                        + ") outside [0, " + T + "] for endpoints " + i + "," + j);
                            }
                        }
                        // Store only after clamping so round-off values never reach callers.
                        result[i][j][a][b] = current;
                    }
                }
            }
        }
        return result;
    }

    /**
     * Convenience overload: eigendecomposes {@code rate} with Jama and delegates to
     * {@link #expectations(double, CachedEigenDecomp)}.
     */
    public static double[][][][] expectations(double T, double[][] rate) {
        return CTMCExpectations.expectations(T, new CachedEigenDecomp(new Matrix(rate).eig()));
    }

    /** Throws if any eigenvalue of {@code ed} has a non-zero (or NaN) imaginary part. */
    private static void checkRealEigenvalues(CachedEigenDecomp ed) {
        for (double imag : ed.getImagEigenvalues()) {
            if (imag != 0.0) {
                throw new RuntimeException("Complex eigenvalue encountered (imaginary part " + imag
                        + "); only real spectra are supported");
            }
        }
    }

    /** Integrates {@code f} over {@code [0, T]} with the trapezoid rule, wrapping any failure with its cause. */
    private static double integrateFromZero(UnivariateRealFunction f, double T) {
        try {
            return new TrapezoidIntegrator(f).integrate(0.0, T);
        }
        catch (Exception e) {
            throw new RuntimeException("Numerical integration over [0, " + T + "] failed", e);
        }
    }

    /**
     * Diagnostic: numerically integrates {@code P_ab(t) * P_cd(T - t)} over {@code [0, T]} twice.
     * The first pass uses {@code MtxUtils.exp}; the second uses an explicit eigen-expansion.
     * Any divergence, pointwise or in the integrals, is printed to standard output.
     *
     * @return the integral computed via {@code MtxUtils.exp}
     */
    private static double testConvolution(final double T, final CachedEigenDecomp ed, final int a, final int b, final int c, final int d) {
        UnivariateRealFunction f = new UnivariateRealFunction(){

            public double value(double t) throws FunctionEvaluationException {
                return MtxUtils.exp(ed.getV(), ed.getVinv(), ed.getD().times(t)).getArray()[a][b] * MtxUtils.exp(ed.getV(), ed.getVinv(), ed.getD().times(T - t)).getArray()[c][d];
            }
        };
        final double result1 = CTMCExpectations.integrateFromZero(f, T);
        final double[][] u = ed.getV().getArray();
        final double[][] ui = ed.getV().inverse().getArray();
        final double[] ev = ed.getRealEigenvalues();
        UnivariateRealFunction f2 = new UnivariateRealFunction(){

            public double value(double t) throws FunctionEvaluationException {
                double sum1 = 0.0;
                for (int i = 0; i < u.length; ++i) {
                    sum1 += Math.exp(t * ev[i]) * u[a][i] * ui[i][b];
                }
                double otherMethod = MtxUtils.exp(ed.getV(), ed.getVinv(), ed.getD().times(t)).getArray()[a][b];
                if (!MathUtils.close(otherMethod, sum1)) {
                    System.out.println("Divergence:" + otherMethod + "," + sum1);
                }
                double sum2 = 0.0;
                for (int j = 0; j < u.length; ++j) {
                    sum2 += Math.exp((T - t) * ev[j]) * u[c][j] * ui[j][d];
                }
                double otherMethod2 = MtxUtils.exp(ed.getV(), ed.getVinv(), ed.getD().times(T - t)).getArray()[c][d];
                if (!MathUtils.close(otherMethod2, sum2)) {
                    System.out.println("Divergence2:" + otherMethod2 + "," + sum2);
                }
                return sum1 * sum2;
            }
        };
        final double result2 = CTMCExpectations.integrateFromZero(f2, T);
        if (!MathUtils.close(result1, result2)) {
            System.out.println("Divergence in result" + result1 + "," + result2);
        }
        return result1;
    }

    /**
     * Returns {@code sum_k sum_l u[a][k] ui[k][b] u[c][l] ui[l][d] J[k][l]}.
     *
     * <p>Assuming {@code u}/{@code ui} are the eigenvector matrix and its inverse, and {@code J}
     * comes from {@link #computeJ}, this is {@code integral_0^T P_ab(t) P_cd(T-t) dt}.
     * {@code T} itself is not read here; it is already folded into {@code J}.
     */
    public static double convolution(double T, double[][] u, double[][] ui, int a, int b, int c, int d, double[][] J) {
        double sum = 0.0;
        int size = u.length;
        for (int i = 0; i < size; ++i) {
            double subSum = 0.0;
            for (int j = 0; j < size; ++j) {
                subSum += u[c][j] * ui[j][d] * J[i][j];
            }
            sum += subSum * u[a][i] * ui[i][b];
        }
        return sum;
    }

    /**
     * Computes {@code J[k][l] = integral_0^T exp(lambda_k t) exp(lambda_l (T - t)) dt} over the
     * real eigenvalues {@code lambda} of {@code ed}.
     *
     * <p>For distinct eigenvalues the closed form is
     * {@code (exp(T lambda_k) - exp(T lambda_l)) / (lambda_k - lambda_l)}. It is evaluated as
     * {@code exp(T max) * -expm1(-T gap) / gap}, where {@code gap = |lambda_k - lambda_l|}, so the
     * difference term is bounded by 1. The entry is then not zeroed when the smaller eigenvalue's
     * exponential underflows, and {@code expm1} keeps precision for small gaps. When the two
     * eigenvalues are {@code MathUtils.close}, the limit {@code T exp(T lambda_l)} is used.
     *
     * @throws RuntimeException if an eigenvalue is complex
     */
    public static double[][] computeJ(double T, CachedEigenDecomp ed) {
        CTMCExpectations.checkRealEigenvalues(ed);
        double[] ev = ed.getRealEigenvalues();
        int size = ev.length;
        double[][] result = new double[size][size];
        for (int j = 0; j < size; ++j) {
            for (int i = 0; i < size; ++i) {
                double dLambda = ev[i] - ev[j];
                if (MathUtils.close(dLambda, 0.0)) {
                    result[i][j] = Math.exp(T * ev[j]) * T;
                } else {
                    double gap = Math.abs(dLambda);
                    double larger = Math.max(ev[i], ev[j]);
                    result[i][j] = Math.exp(T * larger) * -Math.expm1(-T * gap) / gap;
                }
            }
        }
        return result;
    }

    /**
     * Summarizes a trajectory into an {@code nChars x nChars} matrix. Diagonal entry
     * {@code [s][s]} is the total holding time in {@code s}. Off-diagonal entry {@code [s][t]}
     * counts consecutive segments {@code s} then {@code t}. Two consecutive segments in the same
     * state would also add 1.0 to the diagonal.
     */
    public static Matrix getSufficientStatistics(List<Pair<Integer, Double>> traj, int nChars) {
        Matrix result = new Matrix(nChars, nChars);
        // Jama's getArray() exposes the backing array, so writes go straight into result.
        double[][] stats = result.getArray();
        int prev = -1;
        for (Pair<Integer, Double> seg : traj) {
            int cur = seg.getFirst();
            stats[cur][cur] += seg.getSecond();
            if (prev != -1) {
                stats[prev][cur] += 1.0;
            }
            prev = cur;
        }
        return result;
    }

    /**
     * Simulates a trajectory of length {@code T}. The initial state is drawn from the stationary
     * distribution of {@code Q}.
     */
    public static List<Pair<Integer, Double>> simulate(double T, Random rand, double[][] Q) {
        double[] sd = RateMtxUtils.getStationaryDistribution(Q);
        int initState = SampleUtils.sampleMultinomial(rand, sd);
        return CTMCExpectations.simulate(initState, T, rand, Q);
    }

    /**
     * Simulates a trajectory from {@code startState} until time {@code T}. The last holding time
     * is truncated so the durations sum to {@code T}. If a state with zero exit rate is reached,
     * it is appended with an infinite holding time and the trajectory ends there. A non-positive
     * {@code T} yields an empty list.
     */
    public static List<Pair<Integer, Double>> simulate(int startState, double T, Random rand, double[][] Q) {
        double[][] embeddedMarkovChain = RateMtxUtils.getJumpProcess(Q);
        ArrayList<Pair<Integer, Double>> result = new ArrayList<Pair<Integer, Double>>();
        double totalTime = 0.0;
        int state = startState;
        while (totalTime < T) {
            double currentRate = -Q[state][state];
            if (currentRate == 0.0) {
                result.add(Pair.makePair(state, Double.POSITIVE_INFINITY));
                return result;
            }
            double time = Sampling.sampleExponential(rand, 1.0 / currentRate);
            totalTime += time;
            if (totalTime > T) {
                time -= totalTime - T;
            }
            result.add(Pair.makePair(state, time));
            // Note: one jump is still drawn after the final segment; kept so random streams are unchanged.
            state = SampleUtils.sampleMultinomial(rand, embeddedMarkovChain[state]);
        }
        return result;
    }

    /**
     * Rejection sampler: simulates {@code nTrials} trajectories of length {@code minTime} from
     * {@code startState}. Only those whose last segment is in {@code endState} are kept, so the
     * result may contain fewer than {@code nTrials} entries.
     */
    public static List<List<Pair<Integer, Double>>> simulate(int startState, int endState, int nTrials, double minTime, Random rand, double[][] Q) {
        ArrayList<List<Pair<Integer, Double>>> result = new ArrayList<List<Pair<Integer, Double>>>();
        for (int i = 0; i < nTrials; ++i) {
            List<Pair<Integer, Double>> current = CTMCExpectations.simulate(startState, minTime, rand, Q);
            // An empty trajectory (non-positive length) has no end state to match.
            if (current.isEmpty() || current.get(current.size() - 1).getFirst() != endState) {
                continue;
            }
            result.add(current);
        }
        return result;
    }

    /**
     * Returns the state occupied at {@code time}, or -1 if {@code time} is past the end of
     * {@code seq}. Times within 1e-5 after a segment's end are attributed to that segment.
     */
    public static int stateAtT(List<Pair<Integer, Double>> seq, double time) {
        double elapsed = 0.0;
        for (Pair<Integer, Double> seg : seq) {
            elapsed += seg.getSecond();
            if (time <= elapsed + 1.0E-5) {
                return seg.getFirst();
            }
        }
        return -1;
    }

    /** Returns the total holding time spent in {@code state} across all segments of {@code seq}. */
    public static double visitTime(List<Pair<Integer, Double>> seq, int state) {
        double sum = 0.0;
        for (Pair<Integer, Double> seg : seq) {
            if (seg.getFirst() == state) {
                sum += seg.getSecond();
            }
        }
        return sum;
    }

    /** Counts consecutive segment pairs {@code (a, b)} in {@code seq}; 0 for an empty trajectory. */
    public static int nTransitions(List<Pair<Integer, Double>> seq, int a, int b) {
        if (seq.isEmpty()) {
            return 0;
        }
        int sum = 0;
        int prev = seq.get(0).getFirst();
        for (int i = 1; i < seq.size(); ++i) {
            int current = seq.get(i).getFirst();
            if (a == prev && b == current) {
                ++sum;
            }
            prev = current;
        }
        return sum;
    }

    /**
     * Renders a trajectory as text. The trajectory is sampled every {@code timeIncr}. Each sample
     * is printed as a line with {@code state} spaces followed by {@code |}. If the final holding
     * time is infinite (see {@link #simulate(int, double, Random, double[][])}), this does not
     * terminate.
     *
     * @throws IllegalArgumentException if {@code timeIncr} is not positive (the loop would never advance)
     */
    public static String toString(List<Pair<Integer, Double>> seq, double timeIncr) {
        if (!(timeIncr > 0.0)) {
            throw new IllegalArgumentException("timeIncr must be positive, got " + timeIncr);
        }
        int cState;
        StringBuilder result = new StringBuilder();
        double cTime = 0.0;
        while ((cState = CTMCExpectations.stateAtT(seq, cTime)) != -1) {
            for (int i = 0; i < cState; ++i) {
                result.append(" ");
            }
            cTime += timeIncr;
            result.append("|\n");
        }
        return result.toString();
    }

    /**
     * EM smoke test. It simulates endpoint observations from a reversible 2-state generator and
     * then runs EM from an initial rate matrix. Each iteration uses the endpoint-conditioned
     * expectations as E-step and {@code Estimators.getGeneralRateMatrixMLE} as M-step. Progress
     * is printed to standard output.
     */
    public static void testEM() {
        final double bl = 0.1;
        final Random rand = new Random(1L);
        final int nObservations = 1000;
        final int nEmIters = 100;
        final double[] genSd = new double[]{0.4, 0.6};
        final double[] initSd = new double[]{0.4, 0.6};
        final double genRate = 0.1;
        final double initRate = 0.1;
        final int nChars = genSd.length;
        double[][] genQ = RateMtxUtils.reversibleRateMtx(genRate, genSd);
        double[][] initQ = RateMtxUtils.reversibleRateMtx(initRate, initSd);
        System.out.println("Gen:\n" + Table.toString(genQ));
        System.out.println("sd: " + Arrays.toString(RateMtxUtils.getStationaryDistribution(genQ)) + "\n");
        System.out.println("Init:\n" + Table.toString(initQ));
        System.out.println("sd: " + Arrays.toString(RateMtxUtils.getStationaryDistribution(initQ)) + "\n");
        List<Pair<Double, Pair<Integer, Integer>>> observations = CTMCExpectations.generateObservations(genQ, nObservations, rand, bl);
        double[][] currentParams = initQ;
        for (int emIter = 0; emIter < nEmIters; ++emIter) {
            CachedEigenDecomp ed = new CachedEigenDecomp(new Matrix(currentParams).eig());
            Matrix suffStats = new Matrix(nChars, nChars);
            for (Pair<Double, Pair<Integer, Integer>> obs : observations) {
                double[][][][] expss = CTMCExpectations.expectations(obs.getFirst(), ed);
                int s1 = obs.getSecond().getFirst();
                int s2 = obs.getSecond().getSecond();
                suffStats.plusEquals(new Matrix(expss[s1][s2]));
            }
            System.out.println("Expected suffstats:\n" + Table.toString(suffStats));
            currentParams = Estimators.getGeneralRateMatrixMLE(suffStats);
            System.out.println("After iter " + (emIter + 1) + "/" + nEmIters + ":\n" + Table.toString(currentParams));
            System.out.println("sd: " + Arrays.toString(RateMtxUtils.getStationaryDistribution(currentParams)) + "\n");
        }
    }

    /**
     * Simulates {@code observations} consecutive segments of length {@code bl}, chaining each
     * segment's end state into the next one's start. Each segment is recorded as
     * {@code (bl, (startState, endState))}. The summed sampled sufficient statistics are printed.
     */
    private static List<Pair<Double, Pair<Integer, Integer>>> generateObservations(double[][] Q, int observations, Random rand, double bl) {
        ArrayList<Pair<Double, Pair<Integer, Integer>>> result = new ArrayList<Pair<Double, Pair<Integer, Integer>>>();
        int nChars = Q.length;
        Matrix sampledSuffStats = new Matrix(nChars, nChars);
        double[] sd = RateMtxUtils.getStationaryDistribution(Q);
        int prevState = SampleUtils.sampleMultinomial(rand, sd);
        for (int i = 0; i < observations; ++i) {
            double T = bl;
            List<Pair<Integer, Double>> traj = CTMCExpectations.simulate(prevState, T, rand, Q);
            sampledSuffStats.plusEquals(CTMCExpectations.getSufficientStatistics(traj, nChars));
            int initState = traj.get(0).getFirst();
            int lastState = traj.get(traj.size() - 1).getFirst();
            result.add(Pair.makePair(T, Pair.makePair(initState, lastState)));
            prevState = lastState;
        }
        System.out.println("Sampled suffstats:\n" + Table.toString(sampledSuffStats));
        return result;
    }

    /** Runs {@link #testEM()}. */
    public static void main(String[] args) {
        CTMCExpectations.testEM();
    }
}

