/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.models;

import fig.basic.NumUtils;
import java.io.Serializable;
import java.util.ArrayList;
import nuts.maxent.SloppyMath;
import nuts.util.CollUtils;
import nuts.util.MathUtils;
import org.apache.lucene.util.OpenBitSet;
import pty.io.Dataset;
import pty.smc.models.CTMC;
import pty.smc.models.LikelihoodModelCalculator;

public final class GTRIGammaFastDiscreteModelCalculator
implements LikelihoodModelCalculator,
Serializable {
    private static final long serialVersionUID = 1L;
    private final ArrayList<double[][]> cache;
    private final double logLikelihood;
    public final ArrayList<double[]> itemizedLogLikelihoods;
    public final CTMC ctmc;
    private final OpenBitSet[] bitVectorVersion;
    private final double alpha;
    private final int nGammaCat;
    private final double[] rates;

    @Override
    public LikelihoodModelCalculator combine(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2, boolean avoidBuildCache) {
        return (LikelihoodModelCalculator)this.calculate(node1, node2, v1, v2, false, avoidBuildCache);
    }

    public static GTRIGammaFastDiscreteModelCalculator observation(CTMC ctmc, double[][] initCache, double alpha, int nGammaCat) {
        return new GTRIGammaFastDiscreteModelCalculator(initCache, ctmc, alpha, nGammaCat, false);
    }

    public static GTRIGammaFastDiscreteModelCalculator observation(CTMC ctmc, double[][] initCache, double alpha, int nGammaCat, boolean resampleRoot) {
        return new GTRIGammaFastDiscreteModelCalculator(initCache, ctmc, alpha, nGammaCat, resampleRoot);
    }

    public ArrayList<double[][]> getCache(double rate) {
        return this.cache;
    }

    private GTRIGammaFastDiscreteModelCalculator(ArrayList<double[][]> cache, CTMC ctmc, double logLikelihood, double alpha, int nGammaCat, boolean resampleRoot, ArrayList<double[]> itemizedLogLikelihoods) {
        this.bitVectorVersion = null;
        this.cache = cache;
        this.ctmc = ctmc;
        this.logLikelihood = logLikelihood;
        this.alpha = alpha;
        this.nGammaCat = nGammaCat;
        this.rates = CTMC.GTRIGammaCTMC.calculateCategoryRates(nGammaCat, alpha, 0.0);
        this.itemizedLogLikelihoods = itemizedLogLikelihoods;
    }

    private GTRIGammaFastDiscreteModelCalculator(double[][] cache, CTMC ctmc, double alpha, int nGammaCat, boolean resampleRoot) {
        if (resampleRoot) {
            throw new RuntimeException("Discontinued feature");
        }
        ArrayList<double[]> itemized = new ArrayList<double[]>();
        for (int i = 0; i < ctmc.nSites(); ++i) {
            itemized.add(new double[nGammaCat]);
        }
        this.ctmc = ctmc;
        this.itemizedLogLikelihoods = itemized;
        this.logLikelihood = GTRIGammaFastDiscreteModelCalculator.initLogLikelihood(cache, ctmc, this.itemizedLogLikelihoods, nGammaCat);
        this.cache = GTRIGammaFastDiscreteModelCalculator.initCache(cache, ctmc, nGammaCat);
        this.bitVectorVersion = GTRIGammaFastDiscreteModelCalculator.observationToBitVector(cache);
        this.alpha = alpha;
        this.rates = CTMC.GTRIGammaCTMC.calculateCategoryRates(nGammaCat, alpha, 0.0);
        this.nGammaCat = nGammaCat;
    }

    private static ArrayList<double[][]> initCache(double[][] cache, CTMC ctmc, int nRepeats) {
        if (!ctmc.isSiteTied()) {
            throw new RuntimeException();
        }
        double[] initD = ctmc.getInitialDistribution(0);
        double[][] result = new double[cache.length][];
        for (int i = 0; i < result.length; ++i) {
            result[i] = new double[cache[i].length];
            for (int j = 0; j < result[i].length; ++j) {
                result[i][j] = initD[j] * cache[i][j];
            }
            NumUtils.normalize(result[i]);
        }
        ArrayList<double[][]> allResult = CollUtils.list();
        for (int i = 0; i < nRepeats; ++i) {
            allResult.add(result);
        }
        return allResult;
    }

    private static double initLogLikelihood(double[][] cache, CTMC ctmc, ArrayList<double[]> itemized, int nGammaCat) {
        if (!ctmc.isSiteTied()) {
            throw new RuntimeException();
        }
        double[] initD = ctmc.getInitialDistribution(0);
        double result = 0.0;
        for (int s = 0; s < cache.length; ++s) {
            for (int c = 0; c < initD.length; ++c) {
                double cur = cache[s][c];
                if (cur != 0.0 && cur != 1.0) {
                    throw new RuntimeException();
                }
                double term = cur * Math.log(initD[c]);
                double[] dArray = itemized.get(s);
                dArray[0] = dArray[0] + term;
                result += term;
            }
            for (int k = 1; k < nGammaCat; ++k) {
                itemized.get((int)s)[k] = itemized.get(s)[0];
            }
        }
        return result;
    }

    @Override
    public double extendLogLikelihood(double delta) {
        throw new RuntimeException();
    }

    public static double quickPeek(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2) {
        System.out.println("Doing quick peek");
        GTRIGammaFastDiscreteModelCalculator cache1 = (GTRIGammaFastDiscreteModelCalculator)node1;
        GTRIGammaFastDiscreteModelCalculator cache2 = (GTRIGammaFastDiscreteModelCalculator)node2;
        double[][] pr = cache1.ctmc.getTransitionPr(0, v1 + v2);
        double[] initD = cache2.ctmc.getInitialDistribution(0);
        double[] logInit = new double[initD.length];
        for (int i = 0; i < logInit.length; ++i) {
            logInit[i] = Math.log(initD[i]);
        }
        int ncs = cache1.ctmc.nCharacter(0);
        double sum = 0.0;
        for (int x = 0; x < ncs; ++x) {
            double cur = logInit[x];
            double[] curAr = pr[x];
            for (int xPrime = 0; xPrime < ncs; ++xPrime) {
                long n = OpenBitSet.intersectionCount((OpenBitSet)cache1.bitVectorVersion[x], (OpenBitSet)cache2.bitVectorVersion[xPrime]);
                sum += (double)n * (cur + Math.log(curAr[xPrime]));
            }
        }
        return sum;
    }

    private final Object calculate(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2, boolean isPeek, boolean avoidBuildCache) {
        if (isPeek && !avoidBuildCache) {
            throw new RuntimeException();
        }
        ArrayList<double[][]> result = CollUtils.list();
        for (int c = 0; c < this.nGammaCat; ++c) {
            result.add(avoidBuildCache ? (double[][])null : Dataset.DatasetUtils.createObsArray(this.ctmc));
        }
        GTRIGammaFastDiscreteModelCalculator cache1 = (GTRIGammaFastDiscreteModelCalculator)node1;
        GTRIGammaFastDiscreteModelCalculator cache2 = (GTRIGammaFastDiscreteModelCalculator)node2;
        if (isPeek && cache1.bitVectorVersion != null && cache2.bitVectorVersion != null) {
            return GTRIGammaFastDiscreteModelCalculator.quickPeek(node1, node2, v1, v2);
        }
        ArrayList<double[]> newItemized = CollUtils.list();
        for (int s = 0; s < this.ctmc.nSites(); ++s) {
            newItemized.add(new double[this.nGammaCat]);
        }
        ArrayList pr1array = CollUtils.list();
        ArrayList pr2array = CollUtils.list();
        for (int c = 0; c < this.nGammaCat; ++c) {
            pr1array.add(this.ctmc.getTransitionPr(0, v1 * this.rates[c]));
            pr2array.add(this.ctmc.getTransitionPr(0, v2 * this.rates[c]));
        }
        double[] initD = this.ctmc.getInitialDistribution(0);
        double[] invD = new double[initD.length];
        for (int i = 0; i < invD.length; ++i) {
            initD[i] = initD[i];
            invD[i] = 1.0 / initD[i];
        }
        int ns = this.ctmc.nSites();
        int ncs = this.ctmc.nCharacter(0);
        double logNorm = 0.0;
        double tempNorm = 1.0;
        for (int s = 0; s < ns; ++s) {
            for (int c = 0; c < this.nGammaCat; ++c) {
                double currentNorm = 0.0;
                double[] c1 = cache1.cache.get(c)[s];
                double[] c2 = cache2.cache.get(c)[s];
                for (int x = 0; x < ncs; ++x) {
                    double[] cPr1 = ((double[][])pr1array.get(c))[x];
                    double[] cPr2 = ((double[][])pr2array.get(c))[x];
                    double s1 = 0.0;
                    double s2 = 0.0;
                    for (int xPrime = 0; xPrime < ncs; ++xPrime) {
                        double cc2;
                        double cInvD = invD[xPrime];
                        double cc1 = c1[xPrime];
                        if (cc1 != 0.0) {
                            s1 += cPr1[xPrime] * cInvD * cc1;
                        }
                        if ((cc2 = c2[xPrime]) == 0.0) continue;
                        s2 += cPr2[xPrime] * cInvD * cc2;
                    }
                    double currentProd = initD[x] * s1 * s2;
                    if (!avoidBuildCache) {
                        result.get((int)c)[s][x] = currentProd;
                    }
                    currentNorm += currentProd;
                }
                if (!avoidBuildCache) {
                    MathUtils.normalizeAndGetNorm(result.get(c)[s]);
                }
                newItemized.get((int)s)[c] = Math.log(currentNorm) + cache1.itemizedLogLikelihoods.get(s)[c] + cache2.itemizedLogLikelihoods.get(s)[c];
            }
            logNorm += SloppyMath.logAdd((double[])newItemized.get(s)) - Math.log(this.nGammaCat);
        }
        double newLogLikelihood = logNorm;
        if (isPeek) {
            return newLogLikelihood;
        }
        return new GTRIGammaFastDiscreteModelCalculator(result, this.ctmc, newLogLikelihood, this.alpha, this.nGammaCat, false, newItemized);
    }

    private static OpenBitSet[] observationToBitVector(double[][] message) {
        int nChars = message[0].length;
        int nSites = message.length;
        OpenBitSet[] result = new OpenBitSet[nChars];
        for (int i = 0; i < result.length; ++i) {
            result[i] = new OpenBitSet((long)nSites);
        }
        for (int site = 0; site < message.length; ++site) {
            int idx = GTRIGammaFastDiscreteModelCalculator.findOne(message[site]);
            if (idx == -1) {
                return null;
            }
            result[idx].fastSet(site);
        }
        return result;
    }

    private static int findOne(double[] ds) {
        int result = -1;
        for (int i = 0; i < ds.length; ++i) {
            double d = ds[i];
            if (d != 1.0) continue;
            if (result != -1) {
                return -1;
            }
            result = i;
        }
        return result;
    }

    @Override
    public double logLikelihood() {
        return this.logLikelihood;
    }

    @Override
    public boolean isReversible() {
        throw new RuntimeException();
    }

    @Override
    public double peekCoalescedLogLikelihood(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double delta1, double delta2) {
        return (Double)this.calculate(node1, node2, delta1, delta2, true, true);
    }
}

