/*
 * Decompiled with CFR 0.152.
 */
package ev.poi;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;
import ev.poi.QuasiStationaryProcessUtils;
import fig.basic.NumUtils;
import fig.prob.SampleUtils;
import java.util.Arrays;
import java.util.Random;
import ma.RateMatrixLoader;
import nuts.math.MtxUtils;
import nuts.tui.Table;
import nuts.util.MathUtils;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.learn.CTMCExpectations;

public class IntegratedLengthMarginalComputations {
    /**
     * Branch lengths below this magnitude are treated as zero by {@link #integratedExp}, which then
     * returns the identity matrix (the limit of the averaged transition matrix as the length goes to 0).
     */
    private static final double ZERO_LENGTH_THRESHOLD = 1.0E-6;

    /**
     * Computes one weight per branch plus one trailing extra weight.
     *
     * <p>The branch lengths and the extra term {@code alpha1 / (1 - alpha1)} are first normalized
     * together via {@code NumUtils.normalize}, where
     * {@code alpha1 = 1 - integratedDeathProbability(1.0, rateMtx, initDistn)}. Each branch entry is
     * then multiplied by {@code 1 - integratedDeathProbability(branchLength, rateMtx, initDistn)}, so
     * the returned array is in general no longer normalized.
     *
     * @param branchLengths branch lengths; not modified
     * @param rateMtx rate matrix whose last state is treated as the death state
     * @param initDistn initial distribution over the states of {@code rateMtx}
     * @return array of length {@code branchLengths.length + 1}; the last entry is the normalized extra term
     */
    public static double[] branchWeights(double[] branchLengths, double[][] rateMtx, double[] initDistn) {
        double[] result = Arrays.copyOf(branchLengths, branchLengths.length + 1);
        double alpha1 = 1.0 - integratedDeathProbability(1.0, rateMtx, initDistn);
        // NOTE(review): alpha1 == 1 (no death mass at unit length) makes this term infinite.
        result[branchLengths.length] = alpha1 / (1.0 - alpha1);
        NumUtils.normalize(result);
        for (int i = 0; i < branchLengths.length; ++i) {
            result[i] *= 1.0 - integratedDeathProbability(branchLengths[i], rateMtx, initDistn);
        }
        return result;
    }

    /**
     * Returns the probability of being in the last state of {@code rateMtx} at a time drawn uniformly
     * from {@code [0, branchLength]}, when starting from {@code initDistn}. This is the
     * last entry of {@code initDistn * integratedExp(branchLength, rateMtx)}; {@link #main} checks it
     * against a Monte Carlo simulation of exactly that quantity.
     *
     * @param branchLength length of the interval the time is averaged over
     * @param rateMtx rate matrix whose last state is treated as the death state
     * @param initDistn initial distribution; its length must equal the number of states
     * @return the integrated probability of the last (death) state
     */
    public static double integratedDeathProbability(double branchLength, double[][] rateMtx, double[] initDistn) {
        int deathIndex = rateMtx.length - 1;
        double[][] averagedTransitions = integratedExp(branchLength, rateMtx);
        // Row vector (1 x n) times the n x n averaged transition matrix.
        Matrix pi = new Matrix(initDistn, 1);
        Matrix result = pi.times(new Matrix(averagedTransitions));
        for (double[] row : averagedTransitions) {
            MathUtils.checkIsProb(row);
        }
        return result.get(0, deathIndex);
    }

    /**
     * Computes the time-averaged transition matrix {@code (1/t) * integral_0^t exp(Q s) ds} for
     * {@code Q = rateMtx} and {@code t = branchLength}.
     *
     * <p>Uses the eigendecomposition {@code Q = V D V^-1}, so {@code Q} is assumed to be
     * diagonalizable (V invertible). Real eigenvalues {@code l} map to {@code expm1(l t) / (l t)},
     * or to 1 when {@code l} is numerically zero. Complex pairs {@code l +/- i m}, which Jama stores in
     * {@code D} as the 2x2 block {@code [l, m; -m, l]}, map to the same block form of the complex
     * value {@code (e^{z t} - 1) / (z t)} with {@code z = l + i m}.
     *
     * @param branchLength averaging horizon t; if {@code |t| < 1e-6} the identity matrix is returned
     * @param rateMtx square rate matrix Q
     * @return the averaged transition matrix; every row is checked with {@code MathUtils.isProb} and
     *     then normalized in place
     * @throws RuntimeException if some row is not a probability vector
     */
    public static double[][] integratedExp(double branchLength, double[][] rateMtx) {
        if (Math.abs(branchLength) < ZERO_LENGTH_THRESHOLD) {
            return MtxUtils.id(rateMtx.length).getArray();
        }
        int n = rateMtx.length;
        EigenvalueDecomposition ed = new Matrix(rateMtx).eig();
        double[] realParts = ed.getRealEigenvalues();
        double[] imagParts = ed.getImagEigenvalues();
        Matrix v = ed.getV();
        Matrix vInverse = v.inverse();
        Matrix central = new Matrix(n, n);
        int i = 0;
        while (i < n) {
            if (imagParts[i] > 0.0 && i + 1 < n) {
                // Complex pair: Jama places it at (i, i+1) with the positive imaginary part first.
                double[] f = expm1OverArgument(realParts[i] * branchLength, imagParts[i] * branchLength);
                central.set(i, i, f[0]);
                central.set(i, i + 1, f[1]);
                central.set(i + 1, i, -f[1]);
                central.set(i + 1, i + 1, f[0]);
                i += 2;
                continue;
            }
            if (MathUtils.close(realParts[i], 0.0)) {
                // Limit of (e^x - 1) / x as x -> 0.
                central.set(i, i, 1.0);
            } else {
                double x = branchLength * realParts[i];
                // expm1 avoids cancellation in e^x - 1 when |x| is small.
                central.set(i, i, Math.expm1(x) / x);
            }
            ++i;
        }
        double[][] result = v.times(central.times(vInverse)).getArray();
        for (int row = 0; row < result.length; ++row) {
            if (!MathUtils.isProb(result[row])) {
                throw new RuntimeException("Problem in integratedExp(" + branchLength + ",mtx). Possibly a bad rate matrix:\n" + Table.toString(rateMtx) + "Root problem: this should be a pr:" + Arrays.toString(result[row]));
            }
            NumUtils.normalize(result[row]);
        }
        return result;
    }

    /**
     * Returns {real, imaginary} parts of {@code (e^w - 1) / w} for the non-zero complex number
     * {@code w = p + i q}. The numerator's real part, {@code e^p cos q - 1}, is computed as
     * {@code expm1(p) cos q - 2 sin^2(q/2)} to stay accurate when |w| is small.
     */
    private static double[] expm1OverArgument(double p, double q) {
        double halfSin = Math.sin(0.5 * q);
        double numRe = Math.expm1(p) * Math.cos(q) - 2.0 * halfSin * halfSin;
        double numIm = Math.exp(p) * Math.sin(q);
        double denom = p * p + q * q;
        // (numRe + i numIm) / (p + i q) = (numRe + i numIm)(p - i q) / (p^2 + q^2)
        return new double[]{(numRe * p + numIm * q) / denom, (numIm * p - numRe * q) / denom};
    }

    /**
     * Sanity check: compares {@link #integratedDeathProbability} with a Monte Carlo estimate. The
     * estimate simulates the chain for a time drawn uniformly from [0, bl] and records whether it
     * ends in the last (death) state.
     */
    public static void main(String[] args) {
        double bl = 10.5;
        double[] deathRates = new double[]{1.0, 2.0, 3.0, 4.0};
        double[][] rm = QuasiStationaryProcessUtils.formQMtx(RateMatrixLoader.hky85(), deathRates);
        double[] initD = new double[]{0.25, 0.25, 0.25, 0.25, 0.0};
        int deathState = initD.length - 1;
        System.out.println("DP-integrated=" + integratedDeathProbability(bl, rm, initD));
        Random rand = new Random(1L);
        SummaryStatistics ss = new SummaryStatistics();
        for (int iter = 0; iter < 1000000; ++iter) {
            double curBL = rand.nextDouble() * bl;
            int startState = SampleUtils.sampleMultinomial(rand, initD);
            int endState = CTMCExpectations.stateAtT(CTMCExpectations.simulate(startState, curBL, rand, rm), curBL);
            ss.addValue(endState == deathState ? 1.0 : 0.0);
        }
        System.out.println("DP-simulated=" + ss.getMean());
    }
}

