/*
 * Decompiled with CFR 0.152.
 */
package conifer.evol;

import fig.basic.NumUtils;
import java.util.Arrays;
import java.util.Random;
import nuts.math.RateMtxUtils;
import nuts.tui.Table;
import nuts.util.MathUtils;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.junit.Assert;
import org.junit.Test;

/**
 * Helpers for building and normalising General Time-Reversible (GTR) rate
 * matrices, together with a JUnit check that the construction has the
 * requested stationary distribution and satisfies detailed balance.
 */
public class GTR {
    /**
     * Builds a GTR rate matrix from an over-parameterised representation: a
     * stationary distribution plus one exchangeability rate per unordered pair
     * of states.
     *
     * <p>Pairs {@code (i, j)} with {@code i < j} are enumerated in
     * lexicographic order (0,1), (0,2), ..., (0,n-1), (1,2), ... and pair
     * number {@code k} uses {@code subRates[k]} for both directions:
     * {@code Q[i][j] = subRates[k] * stat[j]} and
     * {@code Q[j][i] = subRates[k] * stat[i]}. The diagonal is then filled by
     * {@link RateMtxUtils#fillRateMatrixDiagonalEntries(double[][])}.
     *
     * @param stat     stationary distribution of length {@code n}; must pass
     *                 {@code MathUtils.isProb}
     * @param subRates exchangeability rates, one per unordered pair, length
     *                 {@code n * (n - 1) / 2}
     * @param n        number of states
     * @return a newly allocated {@code n x n} rate matrix
     * @throws IllegalArgumentException if an array length does not match
     *         {@code n} or {@code stat} is not a probability vector
     */
    public static double[][] gtrFromOverParam(double[] stat, double[] subRates, int n) {
        if (stat.length != n) {
            throw new IllegalArgumentException(
                    "stat has length " + stat.length + " but n = " + n);
        }
        int expectedRates = n * (n - 1) / 2;
        if (subRates.length != expectedRates) {
            throw new IllegalArgumentException(
                    "subRates has length " + subRates.length
                    + " but n * (n - 1) / 2 = " + expectedRates);
        }
        if (!MathUtils.isProb(stat)) {
            throw new IllegalArgumentException(
                    "stat is not a probability vector: " + Arrays.toString(stat));
        }
        double[][] result = new double[n][n];
        int pair = 0;
        for (int i = 0; i < n; ++i) {
            for (int j = i + 1; j < n; ++j) {
                // One shared rate per unordered pair makes the matrix time-reversible.
                double rate = subRates[pair++];
                result[i][j] = rate * stat[j];
                result[j][i] = rate * stat[i];
            }
        }
        RateMtxUtils.fillRateMatrixDiagonalEntries(result);
        return result;
    }

    /**
     * Rescales {@code Q} so that its expected substitution rate under
     * {@code stat} is 1. Equivalent to {@code scaleGTRrateMat(stat, Q, 1.0)}.
     *
     * @param stat stationary distribution, one entry per row of {@code Q}
     * @param Q    square rate matrix
     * @return a newly allocated rescaled matrix
     * @throws IllegalArgumentException under the same conditions as
     *         {@link #scaleGTRrateMat(double[], double[][], double)}
     */
    public static double[][] scaleGTRrateMat(double[] stat, double[][] Q) {
        return GTR.scaleGTRrateMat(stat, Q, 1.0);
    }

    /**
     * Returns {@code mu * Q / scale}, where
     * {@code scale = -sum_i stat[i] * Q[i][i]} is the expected substitution
     * rate of {@code Q} under {@code stat}. After rescaling, the expected rate
     * equals {@code mu}. {@code Q} itself is not modified.
     *
     * @param stat stationary distribution, one entry per row of {@code Q}
     * @param Q    square rate matrix
     * @param mu   target expected substitution rate
     * @return a newly allocated rescaled matrix
     * @throws IllegalArgumentException if {@code stat.length != Q.length}, or
     *         the expected rate is not strictly positive (zero, negative or
     *         NaN), in which case the normalisation is undefined
     */
    public static double[][] scaleGTRrateMat(double[] stat, double[][] Q, double mu) {
        int n = Q.length;
        if (stat.length != n) {
            throw new IllegalArgumentException(
                    "stat has length " + stat.length + " but Q has " + n + " rows");
        }
        double scale = 0.0;
        for (int i = 0; i < n; ++i) {
            scale += -stat[i] * Q[i][i];
        }
        // Negated form also rejects NaN, which would otherwise poison every entry.
        if (!(scale > 0.0)) {
            throw new IllegalArgumentException(
                    "expected substitution rate must be positive, got " + scale);
        }
        double[][] result = new double[n][n];
        for (int row = 0; row < n; ++row) {
            for (int col = 0; col < n; ++col) {
                result[row][col] = mu * Q[row][col] / scale;
            }
        }
        return result;
    }

    /**
     * Checks the following on random GTR matrices of sizes 2 through 9:
     * <ul>
     *   <li>the requested stationary distribution is recovered;</li>
     *   <li>the transition matrix at t = 1 satisfies detailed balance
     *       ({@code pi_j P_jk == pi_k P_kj});</li>
     *   <li>the default rescaling gives expected rate 1.</li>
     * </ul>
     */
    @Test
    public void testGTR() {
        Random rand = new Random(1L);
        for (int n = 2; n < 10; ++n) {
            for (int i = 0; i < 100; ++i) {
                double[] stat = new double[n];
                for (int j = 0; j < n; ++j) {
                    stat[j] = rand.nextDouble();
                }
                NumUtils.normalize(stat);
                int nRates = n * (n - 1) / 2;
                double[] rates = new double[nRates];
                for (int j = 0; j < nRates; ++j) {
                    rates[j] = rand.nextDouble() * 10.0;
                }
                double[][] rateMtx = GTR.gtrFromOverParam(stat, rates, n);
                double[] trueStat = RateMtxUtils.getStationaryDistribution(rateMtx);
                for (int j = 0; j < n; ++j) {
                    Assert.assertEquals(trueStat[j], stat[j], MathUtils.threshold);
                }
                double[][] marg = RateMtxUtils.marginalTransitionMtx(rateMtx, 1.0);
                for (int j = 0; j < n; ++j) {
                    for (int k = 0; k < n; ++k) {
                        Assert.assertEquals(stat[j] * marg[j][k], stat[k] * marg[k][j], MathUtils.threshold);
                    }
                }
                // Rescaling must yield unit expected rate (draws no random numbers).
                double[][] scaled = GTR.scaleGTRrateMat(stat, rateMtx);
                double expectedRate = 0.0;
                for (int j = 0; j < n; ++j) {
                    expectedRate += -stat[j] * scaled[j][j];
                }
                Assert.assertEquals(1.0, expectedRate, MathUtils.threshold);
            }
        }
    }

    /**
     * Demo: builds a 4-state GTR matrix with uniform stationary distribution
     * and prints the following:
     * <ul>
     *   <li>the matrix and its rescaled version;</li>
     *   <li>the recovered stationary distribution;</li>
     *   <li>the transition matrix at t = 100, computed both by
     *       {@code RateMtxUtils} and by jblas {@code expm}.</li>
     * </ul>
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double[] stat = new double[]{0.25, 0.25, 0.25, 0.25};
        double[] rates = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
        double[][] rateMtx = GTR.gtrFromOverParam(stat, rates, 4);
        System.out.println(Table.toString(rateMtx));
        double[][] rateMtx2 = GTR.scaleGTRrateMat(stat, rateMtx);
        System.out.println(Table.toString(rateMtx2));
        double[] stat2 = RateMtxUtils.getStationaryDistribution(rateMtx);
        System.out.println(Arrays.toString(stat2));
        System.out.println(Table.toString(RateMtxUtils.marginalTransitionMtx(rateMtx, 100.0)));
        System.out.println(MatrixFunctions.expm(new DoubleMatrix(rateMtx).mul(100.0)));
    }
}

