/*
 * Decompiled with CFR 0.152.
 */
package ev.poi;

import fig.basic.NumUtils;
import java.util.Arrays;
import nuts.math.RateMtxUtils;
import nuts.util.MathUtils;

/**
 * Static helpers for computing quasi-stationary distributions of chains that have N transient
 * states and one absorbing state.
 *
 * <p>The transient dynamics are supplied in one of two ways:
 * <ul>
 *   <li>as an N x N conditional transition matrix plus per-state absorption probabilities;</li>
 *   <li>as an N x N substitution-rate matrix plus per-state absorption rates.</li>
 * </ul>
 */
public class QuasiStationaryProcessUtils {

    /**
     * Builds the sub-probability chain with {@link #reversibleSubProbabilityChain} and returns
     * {@code MathUtils.findStatDistn} of it.
     *
     * @param reversibleConditionalProbabilities N x N transition matrix over the transient states
     *     (by its name, presumably conditional on not being absorbed; confirm with callers)
     * @param absorptionsProbabilities length-N vector; entry {@code i} is the absorption
     *     probability for transient state {@code i}
     * @return the vector produced by {@code MathUtils.findStatDistn} for the scaled matrix
     * @throws IllegalArgumentException if the matrix row count and vector length differ
     */
    public static double[] quasiStationaryDistribution(
            double[][] reversibleConditionalProbabilities, double[] absorptionsProbabilities) {
        double[][] subChain =
                reversibleSubProbabilityChain(reversibleConditionalProbabilities, absorptionsProbabilities);
        return MathUtils.findStatDistn(subChain);
    }

    /**
     * Computes a quasi-stationary distribution from substitution and absorption rates.
     *
     * <p>The steps are:
     * <ol>
     *   <li>build the (N+1) x (N+1) rate matrix with {@link #formQMtx};</li>
     *   <li>take {@code RateMtxUtils.marginalTransitionMtx} at time 1.0;</li>
     *   <li>drop the absorbing state and row-renormalize the transient block
     *       ({@link #extractReversibleCondPrMtx});</li>
     *   <li>return the stationary distribution of that stochastic matrix.</li>
     * </ol>
     *
     * <p>NOTE(review): unlike {@link #quasiStationaryDistribution}, which hands an unnormalized
     * sub-probability matrix to {@code findStatDistn}, this method renormalizes rows first.
     * Confirm that both definitions are the intended ones.
     *
     * @param subRates N x N substitution rates between transient states; diagonal entries are
     *     ignored
     * @param absRates length-N rates of jumping from each transient state to the absorbing state
     * @return an array of length N+1, one entry per state of the full rate matrix. The first N
     *     entries are the distribution over transient states; the last entry (the absorbing
     *     state) is always 0.
     * @throws IllegalArgumentException if {@code subRates.length != absRates.length}, or if the
     *     marginal transition matrix has a negative entry in its transient block
     * @throws IllegalStateException if the result fails {@code MathUtils.isProb}
     */
    public static double[] quasiStationaryDistributionFromRates(double[][] subRates, double[] absRates) {
        if (subRates.length != absRates.length) {
            throw new IllegalArgumentException("subRates has " + subRates.length
                    + " rows but absRates has length " + absRates.length);
        }
        double[][] Q = formQMtx(subRates, absRates);
        // Presumably the transition probabilities exp(Q * 1.0); confirm against RateMtxUtils.
        double[][] marg = RateMtxUtils.marginalTransitionMtx(Q, 1.0);
        double[][] reversibleCondPrMtx = extractReversibleCondPrMtx(marg);
        double[] sd = MathUtils.findStatDistn(reversibleCondPrMtx);
        // Embed in the full (N+1)-state space; the absorbing state keeps probability 0.
        double[] result = new double[Q.length];
        System.arraycopy(sd, 0, result, 0, sd.length);
        if (!MathUtils.isProb(result)) {
            throw new IllegalStateException(
                    "Computed distribution is not a probability vector: " + Arrays.toString(result));
        }
        return result;
    }

    /**
     * Extracts and row-normalizes the transient block of a full transition matrix.
     *
     * <p>The last row and column of the input correspond to the absorbing state. The top-left
     * N x N block (with N = size - 1) is copied, and each copied row is normalized in place with
     * {@code NumUtils.normalize}.
     *
     * @param fullIrreversiblePrMtx (N+1) x (N+1) transition matrix whose last index is the
     *     absorbing state
     * @return a new N x N matrix with normalized rows
     * @throws IllegalArgumentException if the matrix is empty, or if any entry of the transient
     *     block is negative
     */
    public static double[][] extractReversibleCondPrMtx(double[][] fullIrreversiblePrMtx) {
        if (fullIrreversiblePrMtx.length == 0) {
            throw new IllegalArgumentException(
                    "fullIrreversiblePrMtx must contain at least the absorbing state");
        }
        int N = fullIrreversiblePrMtx.length - 1;
        double[][] result = new double[N][N];
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                double p = fullIrreversiblePrMtx[i][j];
                if (p < 0.0) {
                    throw new IllegalArgumentException(
                            "Negative transition probability " + p + " at (" + i + ", " + j + ")");
                }
                result[i][j] = p;
            }
            NumUtils.normalize(result[i]);
        }
        return result;
    }

    /**
     * Scales each row of a transition matrix by the probability of not being absorbed.
     *
     * <p>Entry {@code (i, j)} of the result is
     * {@code reversibleConditionalChain[i][j] * (1 - absorptionsProbabilities[i])}. The input
     * arrays are not modified.
     *
     * @param reversibleConditionalChain N x N transition matrix over the transient states
     * @param absorptionsProbabilities length-N per-state absorption probabilities
     * @return a new N x N sub-probability matrix
     * @throws IllegalArgumentException if the vector length differs from the matrix row count
     */
    public static double[][] reversibleSubProbabilityChain(
            double[][] reversibleConditionalChain, double[] absorptionsProbabilities) {
        int N = reversibleConditionalChain.length;
        if (absorptionsProbabilities.length != N) {
            throw new IllegalArgumentException("reversibleConditionalChain has " + N
                    + " rows but absorptionsProbabilities has length " + absorptionsProbabilities.length);
        }
        double[][] result = new double[N][N];
        for (int i = 0; i < N; ++i) {
            double survival = 1.0 - absorptionsProbabilities[i];
            for (int j = 0; j < N; ++j) {
                result[i][j] = reversibleConditionalChain[i][j] * survival;
            }
        }
        return result;
    }

    /**
     * Assembles the (N+1) x (N+1) rate matrix of the chain with an absorbing state.
     *
     * <p>The layout is:
     * <ul>
     *   <li>off-diagonal entries of the top-left N x N block come from
     *       {@code reversibleConditionalSubstitutionRates};</li>
     *   <li>column N holds {@code absorptionsRates};</li>
     *   <li>row N (the absorbing state) has no off-diagonal rates.</li>
     * </ul>
     * Diagonal entries are then filled by {@code RateMtxUtils.fillRateMatrixDiagonalEntries}
     * (presumably minus each row's off-diagonal sum; confirm).
     *
     * @param reversibleConditionalSubstitutionRates N x N rates; its diagonal is ignored
     * @param absorptionsRates length-N rates into the absorbing state
     * @return a new (N+1) x (N+1) rate matrix
     * @throws IllegalArgumentException if {@code absorptionsRates.length} differs from N
     */
    public static double[][] formQMtx(double[][] reversibleConditionalSubstitutionRates, double[] absorptionsRates) {
        int N = reversibleConditionalSubstitutionRates.length;
        if (absorptionsRates.length != N) {
            throw new IllegalArgumentException("reversibleConditionalSubstitutionRates has " + N
                    + " rows but absorptionsRates has length " + absorptionsRates.length);
        }
        double[][] result = new double[N + 1][N + 1];
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                if (i != j) {
                    result[i][j] = reversibleConditionalSubstitutionRates[i][j];
                }
            }
            result[i][N] = absorptionsRates[i];
        }
        RateMtxUtils.fillRateMatrixDiagonalEntries(result);
        return result;
    }

    /**
     * Small demo on a two-state example.
     *
     * <p>It prints the stationary distribution of the conditional chain and the
     * quasi-stationary distribution, then checks the latter with {@link #test}.
     */
    public static void main(String[] args) {
        double[][] revCondChain = new double[][]{{0.2, 0.8}, {0.4, 0.6}};
        double[] revCondStatDist = MathUtils.findStatDistn(revCondChain);
        System.out.println("Reversible cond chain stat dist:" + Arrays.toString(revCondStatDist));
        double[] absPrs = new double[]{0.2, 0.2};
        double[] quasi = quasiStationaryDistribution(revCondChain, absPrs);
        System.out.println("quasi:" + Arrays.toString(quasi));
        test(revCondChain, absPrs, quasi);
    }

    /**
     * Checks that {@code quasiStatDistn} is, up to normalization, a fixed point of the
     * sub-probability chain.
     *
     * <p>The method computes the row-vector product {@code quasiStatDistn * subTrans}, normalizes
     * it, prints it, and compares it entry-wise to {@code quasiStatDistn} with
     * {@code MathUtils.checkClose}.
     *
     * @param reversibleConditionalChain N x N conditional transition matrix
     * @param absorptionsProbabilities length-N per-state absorption probabilities
     * @param quasiStatDistn candidate distribution of length N
     */
    public static void test(double[][] reversibleConditionalChain, double[] absorptionsProbabilities, double[] quasiStatDistn) {
        double[][] subTrans = reversibleSubProbabilityChain(reversibleConditionalChain, absorptionsProbabilities);
        int N = quasiStatDistn.length;
        double[] convolution = new double[N];
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                convolution[i] += quasiStatDistn[j] * subTrans[j][i];
            }
        }
        NumUtils.normalize(convolution);
        System.out.println("convolved:" + Arrays.toString(convolution));
        for (int i = 0; i < N; ++i) {
            MathUtils.checkClose(quasiStatDistn[i], convolution[i]);
        }
    }
}

