/*
 * Decompiled with CFR 0.152.
 */
package nuts.math;

import Jama.Matrix;
import conifer.evol.GTR;
import fig.basic.NumUtils;
import java.util.Random;
import nuts.io.IO;
import nuts.math.MtxUtils;
import nuts.tui.Table;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.junit.Assert;
import org.junit.Test;

/**
 * Static helpers for continuous-time Markov chain rate matrices: construction,
 * validation, pretty-printing, and computation of transition probabilities
 * {@code P(t) = exp(Q t)}.
 *
 * <p>A rate matrix is a square {@code double[][]} whose off-diagonal entries are
 * non-negative and whose diagonal entries equal minus the sum of the off-diagonal
 * entries of the same row (see {@link #checkRateMtx(double[][])}).
 */
public class RateMtxUtils {
    /**
     * Algorithm used by {@link #marginalTransitionMtx(double[][], double)} and by the
     * helpers built on it. Public and non-final, so it can be reassigned globally.
     */
    public static MatrixExponentialAlgorithm defaultMatrixExponentialAlgorithm = MatrixExponentialAlgorithm.BLAS;

    /**
     * Renders a rate matrix as a table whose row and column headers are the labels
     * {@code indexer.i2o(i).toString()}.
     *
     * @param rate    square matrix to render
     * @param indexer maps state indices to labels; must have exactly {@code rate.length} entries
     * @return the rendered table
     * @throws IllegalArgumentException if the indexer size differs from the matrix size
     */
    public static String toString(double[][] rate, Indexer indexer) {
        if (indexer.size() != rate.length) {
            throw new IllegalArgumentException(
                "Indexer size " + indexer.size() + " does not match matrix size " + rate.length);
        }
        Table table = new Table();
        // Header row and header column share the same labels.
        for (int i = 0; i < rate.length; ++i) {
            table.set(0, i + 1, indexer.i2o(i).toString());
            table.set(i + 1, 0, indexer.i2o(i).toString());
        }
        for (int i = 0; i < rate.length; ++i) {
            for (int j = 0; j < rate.length; ++j) {
                table.set(i + 1, j + 1, rate[i][j]);
            }
        }
        return table.toString();
    }

    /**
     * Builds a rate matrix whose off-diagonal entry {@code (i, j)} is {@code rate * stat[j]}.
     * The diagonal is filled so that every row sums to zero. Because
     * {@code stat[i] * rate * stat[j]} is symmetric in {@code i, j}, the result satisfies
     * detailed balance with respect to {@code stat}.
     *
     * @param rate overall rate multiplier
     * @param stat target stationary distribution; must satisfy {@code MathUtils.isProb}
     * @return a new {@code stat.length x stat.length} rate matrix
     * @throws IllegalArgumentException if {@code stat} is not a probability vector
     */
    public static double[][] reversibleRateMtx(double rate, double[] stat) {
        if (!MathUtils.isProb(stat)) {
            throw new IllegalArgumentException("stat is not a probability vector");
        }
        int size = stat.length;
        double[][] result = new double[size][size];
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                if (i != j) {
                    result[i][j] = stat[j] * rate;
                }
            }
        }
        fillRateMatrixDiagonalEntries(result);
        return result;
    }

    /**
     * Demo: prints {@code P(t)} for a fixed 3-state rate matrix at
     * {@code t = 1, 0.1, 0.01, ...}. It stops once repeated division by 10
     * underflows {@code t} to exactly zero.
     */
    public static void main(String[] args) {
        double[][] q = new double[][]{{0.0, 1.0, 2.0}, {1.0, 0.0, 1.0}, {3.0, 1.0, 0.0}};
        fillRateMatrixDiagonalEntries(q);
        for (double t = 1.0; t != 0.0; t /= 10.0) {
            System.out.println("Current t:" + t);
            IO.so(Table.toString(marginalTransitionMtx(q, t)));
        }
    }

    /**
     * Computes the transition matrix {@code exp(rate * t)} using
     * {@link #defaultMatrixExponentialAlgorithm}.
     *
     * @param rate rate matrix
     * @param t    elapsed time; must be non-negative
     * @return the transition probability matrix
     * @throws IllegalArgumentException if {@code t} is negative or NaN
     */
    public static double[][] marginalTransitionMtx(double[][] rate, double t) {
        // Delegate so the default path gets the same time validation as the explicit one.
        return marginalTransitionMtx(rate, t, defaultMatrixExponentialAlgorithm);
    }

    /**
     * Checks that every {@link MatrixExponentialAlgorithm} agrees with the first one,
     * within {@code MathUtils.threshold}. The comparison runs on 100 random GTR rate
     * matrices (4 states) and random times {@code |N(0, 1)|}.
     */
    @Test
    public void testMatrixExponentialAlgorithms() {
        Random rand = new Random(1L);
        for (int trial = 0; trial < 100; ++trial) {
            // Draw order (stat, then rates, then t) is kept so the seeded sequence is unchanged.
            double[] stat = new double[4];
            for (int i = 0; i < stat.length; ++i) {
                stat[i] = rand.nextDouble() * 2.0;
            }
            NumUtils.normalize(stat);
            double[] rates = new double[6];
            for (int i = 0; i < rates.length; ++i) {
                rates[i] = rand.nextDouble() * 2.0;
            }
            double[][] mtx = GTR.gtrFromOverParam(stat, rates, 4);
            double t = Math.abs(rand.nextGaussian());
            double[][] reference = null;
            for (MatrixExponentialAlgorithm algo : MatrixExponentialAlgorithm.values()) {
                double[][] current = algo.marginalTransitionMtx(mtx, t);
                if (reference == null) {
                    reference = current;
                }
                for (int m = 0; m < reference.length; ++m) {
                    for (int n = 0; n < reference.length; ++n) {
                        Assert.assertEquals(reference[m][n], current[m][n], MathUtils.threshold);
                    }
                }
            }
        }
    }

    /**
     * Computes the transition matrix {@code exp(rate * t)} with the given algorithm.
     *
     * @param rate   rate matrix
     * @param t      elapsed time; must be non-negative
     * @param method matrix-exponential implementation to use
     * @return the transition probability matrix
     * @throws IllegalArgumentException if {@code t} is negative or NaN
     */
    public static double[][] marginalTransitionMtx(double[][] rate, double t, MatrixExponentialAlgorithm method) {
        // Written as !(t >= 0) so that NaN is rejected as well.
        if (!(t >= 0.0)) {
            throw new IllegalArgumentException("t must be non-negative, got " + t);
        }
        return method.marginalTransitionMtx(rate, t);
    }

    /**
     * Computes {@code MtxUtils.exp(V, Vinv, D * t)}. For {@code |t| < 1e-10} it returns
     * the identity of size {@code D.getColumnDimension()} instead.
     * NOTE(review): presumably {@code V * D * Vinv} is an eigendecomposition of the
     * rate matrix, with {@code D} diagonal; confirm against {@code MtxUtils.exp}.
     */
    public static double[][] marginalTransitionMtx(Matrix V, Matrix Vinv, Matrix D, double t) {
        if (Math.abs(t) < 1.0E-10) {
            return MtxUtils.id(D.getColumnDimension()).getArray();
        }
        return MtxUtils.exp(V, Vinv, D.times(t)).getArray();
    }

    /**
     * Estimates the stationary distribution of {@code rate}. It takes the top eigenvector
     * of the transpose of {@code P(1) = exp(rate)} and passes it through
     * {@code NumUtils.normalize}.
     */
    public static double[] getStationaryDistribution(double[][] rate) {
        double[][] marginal = marginalTransitionMtx(rate, 1.0);
        double[] initD = MtxUtils.topEigenvector(new Matrix(marginal).transpose());
        NumUtils.normalize(initD);
        return initD;
    }

    /**
     * Returns the embedded jump chain. Each row holds the off-diagonal rates of that row
     * (diagonal set to 0), normalized with {@code NumUtils.normalize}.
     * NOTE(review): a row with all-zero off-diagonal rates (an absorbing state) is passed
     * to {@code NumUtils.normalize} as-is; its behavior on a zero vector is not visible here.
     */
    public static double[][] getJumpProcess(double[][] rate) {
        int size = rate.length;
        double[][] result = new double[size][size];
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                if (i != j) {
                    result[i][j] = rate[i][j];
                }
            }
            NumUtils.normalize(result[i]);
        }
        return result;
    }

    /**
     * Validates {@code rate} with {@link #checkRateMtx(double[][])}, then checks detailed
     * balance {@code sd[i] * P[i][j] == sd[j] * P[j][i]} using {@code MathUtils.checkClose}.
     * Here {@code sd} is the estimated stationary distribution and {@code P = exp(rate)}.
     */
    public static void checkReversibleRateMtx(double[][] rate) {
        checkRateMtx(rate);
        double[] sd = getStationaryDistribution(rate);
        double[][] p = marginalTransitionMtx(rate, 1.0);
        for (int i = 0; i < sd.length; ++i) {
            for (int j = 0; j < sd.length; ++j) {
                MathUtils.checkClose(sd[i] * p[i][j], sd[j] * p[j][i]);
            }
        }
    }

    /**
     * Validates that {@code rate} is a well-formed rate matrix. It must be square and have
     * non-negative off-diagonal entries. Each diagonal entry must be close, via
     * {@code MathUtils.checkClose}, to minus the sum of its row's off-diagonal entries.
     *
     * @throws IllegalArgumentException if the matrix is not square or has a negative off-diagonal rate
     */
    public static void checkRateMtx(double[][] rate) {
        int size = rate.length;
        for (int i = 0; i < size; ++i) {
            if (rate[i].length != size) {
                throw new IllegalArgumentException(
                    "Row " + i + " has length " + rate[i].length + ", expected " + size);
            }
            double sum = 0.0;
            for (int j = 0; j < size; ++j) {
                if (i == j) {
                    continue;
                }
                if (rate[i][j] < 0.0) {
                    throw new IllegalArgumentException(
                        "Negative rate " + rate[i][j] + " at (" + i + ", " + j + ")");
                }
                sum += rate[i][j];
            }
            MathUtils.checkClose(rate[i][i], -sum);
        }
    }

    /**
     * Sets each diagonal entry of {@code rate}, in place, to minus the sum of the
     * off-diagonal entries of its row.
     *
     * @param rate square matrix whose diagonal entries must all be exactly 0 on input
     * @throws IllegalArgumentException if any diagonal entry is non-zero; the matrix is
     *         left unmodified in that case
     */
    public static void fillRateMatrixDiagonalEntries(double[][] rate) {
        int size = rate.length;
        // Validate every row before writing, so a failure never leaves the matrix half-filled.
        for (int i = 0; i < size; ++i) {
            if (rate[i][i] != 0.0) {
                throw new IllegalArgumentException(
                    "Diagonal entry (" + i + ", " + i + ") must be 0 before filling, got " + rate[i][i]);
            }
        }
        for (int i = 0; i < size; ++i) {
            double sum = 0.0;
            for (int j = 0; j < size; ++j) {
                if (i != j) {
                    sum += rate[i][j];
                }
            }
            rate[i][i] = -sum;
        }
    }

    /** Strategies for computing the matrix exponential {@code exp(rate * t)}. */
    public enum MatrixExponentialAlgorithm {
        /**
         * Uses {@code MtxUtils.exp} on a Jama matrix. It returns the identity for
         * {@code |t| < 1e-10}. Each row is checked with {@code MathUtils.isProb}
         * (throwing otherwise) and then renormalized.
         */
        DIAGONALIZATION {
            @Override
            public double[][] marginalTransitionMtx(double[][] rate, double t) {
                if (Math.abs(t) < 1.0E-10) {
                    return MtxUtils.id(rate.length).getArray();
                }
                Matrix scaled = new Matrix(rate).times(t);
                double[][] p = MtxUtils.exp(scaled).getArray();
                for (int i = 0; i < p.length; ++i) {
                    if (!MathUtils.isProb(p[i])) {
                        throw new RuntimeException("Problem in marginalTransitionMtx(mtx," + t + "). " + "Possibly a bad rate matrix:\n" + Table.toString(rate));
                    }
                    // Remove round-off drift so each row sums to one.
                    NumUtils.normalize(p[i]);
                }
                return p;
            }
        },
        /** Uses jblas {@code MatrixFunctions.expm} on a copy of {@code rate} scaled by {@code t}. */
        BLAS {
            @Override
            public double[][] marginalTransitionMtx(double[][] rate, double t) {
                DoubleMatrix scaled = new DoubleMatrix(rate);
                scaled.muli(t);
                return MatrixFunctions.expm(scaled).toArray2();
            }
        };

        /**
         * Computes {@code exp(rate * t)}.
         *
         * @param rate rate matrix
         * @param t    elapsed time
         * @return the resulting transition matrix as a new array
         */
        public abstract double[][] marginalTransitionMtx(double[][] rate, double t);
    }
}

