/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.models;

import Jama.Matrix;
import conifer.evol.GTR;
import dr.math.distributions.GammaDistribution;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import ma.RateMatrixLoader;
import ma.SequenceType;
import nuts.math.RateMtxUtils;
import nuts.tui.Table;
import nuts.util.MathUtils;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.MaxIterationsExceededException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.analysis.integration.TrapezoidIntegrator;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import pty.ObservationDimensions;
import pty.smc.models.CachedEigenDecomp;

public interface CTMC
extends Serializable,
ObservationDimensions {
    /**
     * Returns the transition probability matrix of one site after an elapsed time.
     * The implementations in this file compute it from the site's rate matrix Q, either
     * directly as {@code expm(Q * t)} or from the cached eigendecomposition of Q.
     *
     * @param var1 site index (ignored by implementations whose {@link #isSiteTied()} is true)
     * @param var2 elapsed time / branch length t that scales the rate matrix
     * @return square matrix whose rows and columns index the site's states
     */
    public double[][] getTransitionPr(int var1, double var2);

    /**
     * Returns the distribution over states at the root/start of the process. The
     * implementations in this file return a stationary distribution derived from the rate
     * matrix, and return their internal array (callers should not modify it).
     *
     * @param var1 site index (ignored by site-tied implementations)
     */
    public double[] getInitialDistribution(int var1);

    /**
     * Returns the cached eigendecomposition of the rate matrix used for the given site.
     *
     * @param var1 site index (ignored by site-tied implementations)
     */
    public CachedEigenDecomp getRateMtx(int var1);

    /**
     * Returns true when every site shares one and the same model, so the site argument of the
     * other methods is ignored; false when each site has its own rate matrix.
     */
    public boolean isSiteTied();

    public static class Rmean
    implements UnivariateRealFunction {
        private double alpha = 1.0;

        public Rmean(double alpha) {
            this.alpha = alpha;
        }

        public double value(double r) throws FunctionEvaluationException {
            return r * GammaDistribution.pdf(r, this.alpha, 1.0 / this.alpha);
        }
    }

    /**
     * A GTR substitution model with discretized gamma rate heterogeneity and an optional
     * invariant-sites class, shared by all sites.
     *
     * <p>With the 7-argument constructor the state space is expanded: each of the
     * {@code nGammaCat} rate categories gets its own block of {@code n} base states whose rates
     * are the GTR matrix multiplied by that category's rate. When {@code 0 < pInv < 1}, one more
     * block of {@code n} states with all-zero rates (the invariant class) is appended. Blocks
     * never exchange probability mass, so the expanded rate matrix is block diagonal.
     */
    public static final class GTRIGammaCTMC
    implements CTMC {
        private static final long serialVersionUID = 1L;
        /** Cached eigendecomposition of {@link #originalQ}. */
        private final CachedEigenDecomp Q;
        /** The (possibly category-expanded) rate matrix. */
        private final Matrix originalQ;
        /** Initial distribution over the (possibly category-expanded) state space. */
        private final double[] statDistn;
        private final int nSites;
        /** Stored only; not read elsewhere in this class. */
        private final int nGammaCat;
        /** Stored only; not read elsewhere in this class. */
        private final double pInv;

        /**
         * Builds the category-expanded GTR+I+Gamma model.
         *
         * @param stat base-state frequencies (length {@code n})
         * @param subRates GTR rate parameters, passed to {@code GTR.gtrFromOverParam}
         * @param n number of base states
         * @param nSites number of sites reported by {@link #nSites()}
         * @param alpha gamma shape parameter
         * @param nGammaCat number of discrete gamma categories
         * @param pInv proportion of invariant sites; must lie in [0, 1)
         * @throws IllegalArgumentException if {@code pInv} is outside [0, 1) or NaN
         */
        public GTRIGammaCTMC(double[] stat, double[] subRates, int n, int nSites, double alpha, int nGammaCat, double pInv) {
            // pInv >= 1 would give every gamma category zero weight (and no invariant block),
            // pInv < 0 negative weights: either way statDistn would not be a distribution.
            if (!(pInv >= 0.0 && pInv < 1.0)) {
                throw new IllegalArgumentException("pInv must lie in [0, 1), got " + pInv);
            }
            final boolean hasInvariantClass = pInv > 0.0;
            final double[] r = calculateCategoryRates(nGammaCat, alpha, pInv);
            final int dim = nGammaCat * n + (hasInvariantClass ? n : 0);
            final double[][] rateMtx = GTR.scaleGTRrateMat(stat, GTR.gtrFromOverParam(stat, subRates, n));
            final double[][] rate0 = new double[dim][dim];
            this.statDistn = new double[dim];
            for (int c = 0; c < nGammaCat; ++c) {
                final int offset = c * n;
                for (int row = 0; row < n; ++row) {
                    for (int col = 0; col < n; ++col) {
                        rate0[offset + row][offset + col] = r[c] * rateMtx[row][col];
                    }
                    // Each gamma category carries an equal share of the variable-site mass.
                    this.statDistn[offset + row] = (1.0 - pInv) * stat[row] / (double) nGammaCat;
                }
            }
            if (hasInvariantClass) {
                // Invariant block: rate rows stay zero; it carries the pInv share of the mass.
                final int offset = nGammaCat * n;
                for (int i = 0; i < n; ++i) {
                    this.statDistn[offset + i] = pInv * stat[i];
                }
            }
            this.nSites = nSites;
            this.originalQ = new Matrix(rate0);
            this.Q = new CachedEigenDecomp(this.originalQ.eig());
            this.nGammaCat = nGammaCat;
            this.pInv = pInv;
        }

        /**
         * Builds a plain GTR model over the {@code n} base states with no category expansion;
         * {@code nGammaCat} and {@code pInv} are only stored. The initial distribution is the
         * stationary distribution of the scaled GTR matrix and is checked to be a probability
         * vector.
         */
        public GTRIGammaCTMC(double[] stat, double[] subRates, int n, int nSites, int nGammaCat, double pInv) {
            double[][] rateMtx = GTR.scaleGTRrateMat(stat, GTR.gtrFromOverParam(stat, subRates, n));
            this.statDistn = RateMtxUtils.getStationaryDistribution(rateMtx);
            this.nSites = nSites;
            this.originalQ = new Matrix(rateMtx);
            this.Q = new CachedEigenDecomp(this.originalQ.eig());
            MathUtils.checkIsProb(this.statDistn);
            this.nGammaCat = nGammaCat;
            this.pInv = pInv;
        }

        /**
         * Computes the mean rate of each of {@code nGammaCat} equal-probability categories of a
         * gamma distribution with shape {@code alpha} and parameter {@code 1/alpha}.
         *
         * <p>Category boundaries are the gamma quantiles at probabilities {@code i/nGammaCat};
         * the last category is truncated at 200. Each rate is {@code nGammaCat} times the
         * trapezoid-rule integral of {@code r * pdf(r)} over its category. A single category
         * always gets rate 1.
         *
         * <p>NOTE(review): {@code pInv} does not affect the result; rates are not rescaled by
         * {@code 1 / (1 - pInv)}. Confirm this is intended.
         *
         * @return category rates, length {@code nGammaCat}
         * @throws IllegalStateException if the numerical integration fails for any category
         *     (previously such a category silently received rate 0)
         */
        public static double[] calculateCategoryRates(int nGammaCat, double alpha, double pInv) {
            double[] categoryRates = new double[nGammaCat];
            if (nGammaCat == 1) {
                categoryRates[0] = 1.0;
                return categoryRates;
            }
            final double upperRateBound = 200.0;
            double[] quantiles = new double[nGammaCat + 1];
            quantiles[0] = 0.0;
            // Only interior boundaries are evaluated: the quantile at probability 1 is unbounded
            // and the tail is truncated at a finite cap instead.
            for (int i = 1; i < nGammaCat; ++i) {
                quantiles[i] = GammaDistribution.quantile((double) i / (double) nGammaCat, alpha, 1.0 / alpha);
            }
            quantiles[nGammaCat] = upperRateBound;
            Rmean rmean = new Rmean(alpha);
            TrapezoidIntegrator trInt = new TrapezoidIntegrator();
            for (int i = 0; i < nGammaCat; ++i) {
                try {
                    categoryRates[i] = trInt.integrate((UnivariateRealFunction) rmean, quantiles[i], quantiles[i + 1]) * (double) nGammaCat;
                } catch (MaxIterationsExceededException e) {
                    throw integrationFailure(i, quantiles, e);
                } catch (FunctionEvaluationException e) {
                    throw integrationFailure(i, quantiles, e);
                } catch (IllegalArgumentException e) {
                    throw integrationFailure(i, quantiles, e);
                }
            }
            return categoryRates;
        }

        /** Builds the exception reported when one category's mean rate cannot be integrated. */
        private static IllegalStateException integrationFailure(int category, double[] quantiles, Exception cause) {
            return new IllegalStateException("Failed to integrate mean rate of gamma category " + category
                    + " over [" + quantiles[category] + ", " + quantiles[category + 1] + "]", cause);
        }

        /** Returns the internal initial distribution (same array for every site). */
        @Override
        public double[] getInitialDistribution(int site) {
            return this.statDistn;
        }

        /**
         * Returns {@code expm(Q * t)} for the (possibly category-expanded) rate matrix Q,
         * computed with jblas rather than the cached eigendecomposition. The site is ignored.
         */
        @Override
        public double[][] getTransitionPr(int site, double t) {
            // The DoubleMatrix constructor copies the data, so scaling in place leaves originalQ intact.
            DoubleMatrix scaledQ = new DoubleMatrix(this.originalQ.getArray()).muli(t);
            return MatrixFunctions.expm(scaledQ).toArray2();
        }

        @Override
        public int nCharacter(int site) {
            return this.statDistn.length;
        }

        @Override
        public int nSites() {
            return this.nSites;
        }

        @Override
        public CachedEigenDecomp getRateMtx(int site) {
            return this.Q;
        }

        @Override
        public String toString() {
            return "GTRIGammaCTMC:\n" + Table.toString(this.originalQ);
        }

        @Override
        public boolean isSiteTied() {
            return true;
        }

        /** Small demo: prints the transition matrix at t = 1 and the initial distribution. */
        public static void main(String[] args) {
            double[] stat = new double[]{0.3, 0.2, 0.2, 0.3};
            double[] rates = new double[]{0.26, 0.18, 0.17, 0.15, 0.11, 0.13};
            int nCategories = 4;
            int n = 4;
            GTRIGammaCTMC ctmc = new GTRIGammaCTMC(stat, rates, n, 100, 0.25, nCategories, 0.0);
            double[][] tranMat = ctmc.getTransitionPr(0, 1.0);
            System.out.println(Table.toString(tranMat));
            System.out.println(Arrays.toString(ctmc.getInitialDistribution(0)));
        }
    }

    /**
     * A CTMC with a single rate matrix shared by every site. The initial distribution is the
     * stationary distribution of that matrix.
     */
    public static final class SimpleCTMC
    implements CTMC {
        private static final long serialVersionUID = 1L;
        /** Cached eigendecomposition of {@link #originalQ}. */
        private final CachedEigenDecomp Q;
        /** Private copy of the rate matrix this model was built from. */
        private final Matrix originalQ;
        private final double[] statDistn;
        private final int nSites;

        /** DNA model built from {@code RateMatrixLoader.k2p()}. */
        public static SimpleCTMC dnaCTMC(int nSites) {
            return new SimpleCTMC(RateMatrixLoader.k2p(), nSites);
        }

        /** DNA model built from {@code RateMatrixLoader.k2p(trans2tranv)}. */
        public static SimpleCTMC dnaCTMC(int nSites, double trans2tranv) {
            return new SimpleCTMC(RateMatrixLoader.k2p(trans2tranv), nSites);
        }

        /** Protein model built from {@code RateMatrixLoader.dayhoff()}. */
        public static SimpleCTMC proteinCTMC(int nSites) {
            return new SimpleCTMC(RateMatrixLoader.dayhoff(), nSites);
        }

        /**
         * Builds a model for the given sequence type. Only {@code SequenceType.BINARY} is
         * supported: a symmetric 2-state matrix with off-diagonal rate {@code scale}.
         *
         * @throws UnsupportedOperationException for any other sequence type
         */
        public static SimpleCTMC fromSequenceType(int nSites, SequenceType st, double scale) {
            if (st == SequenceType.BINARY) {
                double[][] rate = new double[][]{{-scale, scale}, {scale, -scale}};
                return new SimpleCTMC(rate, nSites);
            }
            throw new UnsupportedOperationException("Unsupported sequence type: " + st);
        }

        /**
         * @param rate square rate matrix; it is copied, so later changes by the caller do not
         *     affect this model
         * @param nSites number of sites reported by {@link #nSites()}
         */
        public SimpleCTMC(double[][] rate, int nSites) {
            this.nSites = nSites;
            // Copy so originalQ (printed by toString) cannot drift from Q and statDistn.
            this.originalQ = Matrix.constructWithCopy(rate);
            this.Q = new CachedEigenDecomp(this.originalQ.eig());
            this.statDistn = RateMtxUtils.getStationaryDistribution(rate);
            MathUtils.checkIsProb(this.statDistn);
        }

        /** Returns the internal stationary distribution (same array for every site). */
        @Override
        public double[] getInitialDistribution(int site) {
            return this.statDistn;
        }

        /** Transition matrix at time {@code t}, from the cached eigendecomposition; site ignored. */
        @Override
        public double[][] getTransitionPr(int site, double t) {
            return RateMtxUtils.marginalTransitionMtx(this.Q.getV(), this.Q.getVinv(), this.Q.getD(), t);
        }

        @Override
        public int nCharacter(int site) {
            return this.statDistn.length;
        }

        @Override
        public int nSites() {
            return this.nSites;
        }

        @Override
        public CachedEigenDecomp getRateMtx(int site) {
            return this.Q;
        }

        @Override
        public String toString() {
            return "SimpleCTMC:\n" + Table.toString(this.originalQ);
        }

        @Override
        public boolean isSiteTied() {
            return true;
        }
    }

    /**
     * A CTMC with a separate rate matrix per site; the number of sites is the number of
     * matrices supplied. Each site's initial distribution is its matrix's stationary
     * distribution.
     */
    public static final class GeneralCTMC
    implements CTMC {
        private static final long serialVersionUID = 1L;
        /** Cached eigendecompositions, indexed by site. */
        private final List<CachedEigenDecomp> Qs = new ArrayList<CachedEigenDecomp>();
        /** Private copies of the per-site rate matrices, indexed by site. */
        private final List<Matrix> originalQs = new ArrayList<Matrix>();
        /** Per-site stationary distributions, indexed by site. */
        private final List<double[]> statDistn = new ArrayList<double[]>();

        /**
         * @param rateMatrices one square rate matrix per site; each is copied
         */
        public GeneralCTMC(List<double[][]> rateMatrices) {
            for (double[][] rate : rateMatrices) {
                // One private copy serves both toString and the eigendecomposition, so they
                // cannot disagree if the caller later mutates its array.
                Matrix q = Matrix.constructWithCopy(rate);
                this.originalQs.add(q);
                double[] stat = RateMtxUtils.getStationaryDistribution(rate);
                MathUtils.checkIsProb(stat);
                this.statDistn.add(stat);
                this.Qs.add(new CachedEigenDecomp(q.eig()));
            }
        }

        /** Returns the internal stationary distribution of the given site. */
        @Override
        public double[] getInitialDistribution(int site) {
            return this.statDistn.get(site);
        }

        /** Transition matrix of the given site at time {@code t}, from its cached eigendecomposition. */
        @Override
        public double[][] getTransitionPr(int site, double t) {
            CachedEigenDecomp decomp = this.Qs.get(site);
            return RateMtxUtils.marginalTransitionMtx(decomp.getV(), decomp.getVinv(), decomp.getD(), t);
        }

        @Override
        public int nCharacter(int site) {
            return this.statDistn.get(site).length;
        }

        @Override
        public int nSites() {
            return this.Qs.size();
        }

        @Override
        public CachedEigenDecomp getRateMtx(int site) {
            return this.Qs.get(site);
        }

        @Override
        public String toString() {
            StringBuilder result = new StringBuilder("GeneralCTMC:\n");
            for (int i = 0; i < this.nSites(); ++i) {
                result.append("Site ").append(i).append(":\n")
                      .append(Table.toString(this.originalQs.get(i))).append("\n");
            }
            return result.toString();
        }

        @Override
        public boolean isSiteTied() {
            return false;
        }
    }
}

