/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.models;

import fig.basic.NumUtils;
import java.io.Serializable;
import nuts.util.MathUtils;
import org.apache.lucene.util.OpenBitSet;
import pty.io.Dataset;
import pty.smc.models.CTMC;
import pty.smc.models.LikelihoodModelCalculator;

/**
 * Likelihood calculator for discrete characters evolving under a site-tied CTMC.
 *
 * <p>Each instance holds one per-site message ({@link #cache}, indexed {@code [site][character]})
 * and the accumulated log-likelihood of the subtree it summarizes. Only the CTMC parameters for
 * index {@code 0} are queried, which is why a site-tied CTMC is required.
 *
 * <p>Messages are kept prior-weighted and normalized per site:
 * <ul>
 *   <li>A leaf message is {@code pi(x) * obs(x)}, normalized.</li>
 *   <li>When two children are combined, each child message is divided by {@code pi}, pushed
 *       through the transition matrix, and the product is multiplied by {@code pi(x)}.</li>
 *   <li>The per-site sums discarded by normalization are accumulated (in log space) into
 *       {@link #logLikelihood()}.</li>
 * </ul>
 *
 * <p>Observation leaves that are one-hot at every site also keep a bit-set representation. This
 * lets {@link #peekCoalescedLogLikelihood} take a counting shortcut when both children are such
 * leaves.
 */
public final class FastDiscreteModelCalculator implements LikelihoodModelCalculator, Serializable {
    private static final long serialVersionUID = 1L;

    /** Per-site messages indexed {@code [site][character]}; null when built with {@code avoidBuildCache}. */
    public final double[][] cache;
    /** Log-likelihood accumulated for the subtree summarized by this calculator. */
    private final double logLikelihood;
    public final CTMC ctmc;
    /**
     * For observation leaves only: entry {@code c} marks the sites whose observed character is
     * {@code c}. Null for combined nodes, or when some site is not exactly one-hot.
     */
    private final OpenBitSet[] bitVectorVersion;

    /**
     * Combines two child calculators joined at a new parent node.
     *
     * @param node1 first child; must be a {@code FastDiscreteModelCalculator}
     * @param node2 second child; must be a {@code FastDiscreteModelCalculator}
     * @param v1 branch length from the parent to {@code node1}
     * @param v2 branch length from the parent to {@code node2}
     * @param avoidBuildCache when true, no message array is built. The returned calculator's cache
     *     is then null, so only its log-likelihood is usable and it should not be combined further.
     * @return a calculator for the parent node
     */
    @Override
    public LikelihoodModelCalculator combine(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2, boolean avoidBuildCache) {
        double[][] result = avoidBuildCache ? null : Dataset.DatasetUtils.createObsArray(this.ctmc);
        FastDiscreteModelCalculator cache1 = (FastDiscreteModelCalculator) node1;
        FastDiscreteModelCalculator cache2 = (FastDiscreteModelCalculator) node2;
        double newLogLikelihood = this.coalesce(cache1, cache2, v1, v2, result);
        return new FastDiscreteModelCalculator(result, this.ctmc, newLogLikelihood, false);
    }

    /** Creates an observation leaf; equivalent to {@code observation(ctmc, initCache, false)}. */
    public static FastDiscreteModelCalculator observation(CTMC ctmc, double[][] initCache) {
        return new FastDiscreteModelCalculator(initCache, ctmc, false);
    }

    /**
     * Creates an observation leaf from a {@code [site][character]} array of 0/1 indicators.
     *
     * @throws RuntimeException if {@code resampleRoot} is true (discontinued), if the CTMC is not
     *     site-tied, or if an observation entry is neither 0 nor 1
     */
    public static FastDiscreteModelCalculator observation(CTMC ctmc, double[][] initCache, boolean resampleRoot) {
        return new FastDiscreteModelCalculator(initCache, ctmc, resampleRoot);
    }

    /** Returns the internal (non-copied) per-site message array. */
    public double[][] getCache() {
        return this.cache;
    }

    /** Builds a combined node. {@code resampleRoot} is accepted but unused. */
    private FastDiscreteModelCalculator(double[][] cache, CTMC ctmc, double logLikelihood, boolean resampleRoot) {
        this.bitVectorVersion = null;
        this.cache = cache;
        this.ctmc = ctmc;
        this.logLikelihood = logLikelihood;
    }

    /** Builds an observation leaf; validation in initLogLikelihood runs before the other derived fields. */
    private FastDiscreteModelCalculator(double[][] cache, CTMC ctmc, boolean resampleRoot) {
        if (resampleRoot) {
            throw new UnsupportedOperationException("Discontinued feature");
        }
        this.ctmc = ctmc;
        this.logLikelihood = FastDiscreteModelCalculator.initLogLikelihood(cache, ctmc);
        this.cache = FastDiscreteModelCalculator.initCache(cache, ctmc);
        this.bitVectorVersion = FastDiscreteModelCalculator.observationToBitVector(cache);
    }

    /** Returns a fresh array holding, per site, {@code pi(c) * obs(c)} passed through NumUtils.normalize. */
    private static double[][] initCache(double[][] cache, CTMC ctmc) {
        requireSiteTied(ctmc);
        double[] initD = ctmc.getInitialDistribution(0);
        double[][] result = new double[cache.length][];
        for (int i = 0; i < result.length; ++i) {
            result[i] = new double[cache[i].length];
            for (int j = 0; j < result[i].length; ++j) {
                result[i][j] = initD[j] * cache[i][j];
            }
            NumUtils.normalize(result[i]);
        }
        return result;
    }

    /**
     * Returns the sum over sites of {@code log(sum_c obs(c) * pi(c))}.
     *
     * @throws IllegalArgumentException if an entry is neither 0 nor 1
     */
    private static double initLogLikelihood(double[][] cache, CTMC ctmc) {
        requireSiteTied(ctmc);
        double[] initD = ctmc.getInitialDistribution(0);
        double result = 0.0;
        for (int s = 0; s < cache.length; ++s) {
            double likeli = 0.0;
            for (int c = 0; c < initD.length; ++c) {
                double cur = cache[s][c];
                if (cur != 0.0 && cur != 1.0) {
                    throw new IllegalArgumentException(
                            "Observation entries must be 0 or 1; site " + s + ", character " + c + " has " + cur);
                }
                likeli += cur * initD[c];
            }
            result += Math.log(likeli);
        }
        return result;
    }

    /** Throws if the CTMC is not site-tied; only index-0 parameters are ever queried. */
    private static void requireSiteTied(CTMC ctmc) {
        if (!ctmc.isSiteTied()) {
            throw new IllegalArgumentException("FastDiscreteModelCalculator requires a site-tied CTMC");
        }
    }

    @Override
    public double extendLogLikelihood(double delta) {
        throw new UnsupportedOperationException("extendLogLikelihood is not supported");
    }

    /**
     * Shortcut peek for two observation leaves that both have a bit-set representation.
     *
     * <p>Returns {@code sum over site pairs of n(x,x') * (log pi(x) + log P(x -> x'; v1 + v2))},
     * where {@code n(x,x')} is the number of sites observed as {@code x} under {@code node1} and as
     * {@code x'} under {@code node2}. Merging the two branches into one of length {@code v1 + v2}
     * matches the general computation only when the CTMC is reversible.
     * NOTE(review): {@link #isReversible()} throws here, so callers presumably guarantee this.
     *
     * <p>Transitions come from {@code node1}'s CTMC and the prior from {@code node2}'s; presumably
     * both are the same CTMC.
     *
     * <p>Character pairs that share no site are skipped. Otherwise a zero prior or transition
     * probability yields {@code 0 * -Infinity = NaN} and poisons the whole sum.
     *
     * @throws NullPointerException if either node lacks a bit-set representation
     */
    public static double quickPeek(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2) {
        FastDiscreteModelCalculator cache1 = (FastDiscreteModelCalculator) node1;
        FastDiscreteModelCalculator cache2 = (FastDiscreteModelCalculator) node2;
        double[][] pr = cache1.ctmc.getTransitionPr(0, v1 + v2);
        double[] initD = cache2.ctmc.getInitialDistribution(0);
        double[] logInit = new double[initD.length];
        for (int i = 0; i < logInit.length; ++i) {
            logInit[i] = Math.log(initD[i]);
        }
        int ncs = cache1.ctmc.nCharacter(0);
        double sum = 0.0;
        for (int x = 0; x < ncs; ++x) {
            OpenBitSet sitesWithX = cache1.bitVectorVersion[x];
            double[] prFromX = pr[x];
            for (int xPrime = 0; xPrime < ncs; ++xPrime) {
                long n = OpenBitSet.intersectionCount(sitesWithX, cache2.bitVectorVersion[xPrime]);
                if (n == 0L) {
                    // No contribution; also avoids 0 * -Infinity = NaN when a probability is exactly zero.
                    continue;
                }
                sum += (double) n * (logInit[x] + Math.log(prFromX[xPrime]));
            }
        }
        return sum;
    }

    /**
     * Core message-passing step shared by {@link #combine} and {@link #peekCoalescedLogLikelihood}.
     *
     * @param result destination for the parent's per-site messages, or null to skip building them
     * @return the parent's log-likelihood (accumulated site norms plus both children's log-likelihoods)
     */
    private double coalesce(FastDiscreteModelCalculator cache1, FastDiscreteModelCalculator cache2, double v1, double v2, double[][] result) {
        boolean buildCache = result != null;
        double[][] pr1 = this.ctmc.getTransitionPr(0, v1);
        double[][] pr2 = this.ctmc.getTransitionPr(0, v2);
        double[] initD = this.ctmc.getInitialDistribution(0);
        double[] invD = new double[initD.length];
        for (int i = 0; i < invD.length; ++i) {
            invD[i] = 1.0 / initD[i];
        }
        int ns = this.ctmc.nSites();
        int ncs = this.ctmc.nCharacter(0);
        double logNorm = 0.0;
        // Per-site norms are multiplied together in linear space and only converted to log space
        // when the running product nears under/overflow, or at the last site.
        double tempNorm = 1.0;
        for (int s = 0; s < ns; ++s) {
            double currentNorm = 0.0;
            double[] c1 = cache1.cache[s];
            double[] c2 = cache2.cache[s];
            for (int x = 0; x < ncs; ++x) {
                double[] cPr1 = pr1[x];
                double[] cPr2 = pr2[x];
                double s1 = 0.0;
                double s2 = 0.0;
                for (int xPrime = 0; xPrime < ncs; ++xPrime) {
                    double cInvD = invD[xPrime];
                    // Zero message entries are skipped so an infinite 1/pi never multiplies a zero.
                    double cc1 = c1[xPrime];
                    if (cc1 != 0.0) {
                        s1 += cPr1[xPrime] * cInvD * cc1;
                    }
                    double cc2 = c2[xPrime];
                    if (cc2 != 0.0) {
                        s2 += cPr2[xPrime] * cInvD * cc2;
                    }
                }
                double currentProd = initD[x] * s1 * s2;
                if (buildCache) {
                    result[s][x] = currentProd;
                }
                currentNorm += currentProd;
            }
            if (buildCache) {
                MathUtils.normalizeAndGetNorm(result[s]);
            }
            tempNorm *= currentNorm;
            double abs = Math.abs(tempNorm);
            if (abs < 1.0E-100 || abs > 1.0E100 || s == ns - 1) {
                logNorm += Math.log(tempNorm);
                tempNorm = 1.0;
            }
        }
        return logNorm + cache1.logLikelihood + cache2.logLikelihood;
    }

    /**
     * Converts one-hot observations into one bit set per character.
     *
     * @return the bit sets, or null if there are no sites or some site is not exactly one-hot
     */
    private static OpenBitSet[] observationToBitVector(double[][] message) {
        if (message.length == 0) {
            // The alphabet size cannot be read from an empty array; peeks then use the general path.
            return null;
        }
        int nChars = message[0].length;
        int nSites = message.length;
        OpenBitSet[] result = new OpenBitSet[nChars];
        for (int i = 0; i < result.length; ++i) {
            result[i] = new OpenBitSet((long) nSites);
        }
        for (int site = 0; site < nSites; ++site) {
            int idx = FastDiscreteModelCalculator.findOne(message[site]);
            if (idx == -1) {
                return null;
            }
            result[idx].fastSet(site);
        }
        return result;
    }

    /** Returns the index of the single entry equal to 1.0, or -1 if there is none or more than one. */
    private static int findOne(double[] ds) {
        int result = -1;
        for (int i = 0; i < ds.length; ++i) {
            if (ds[i] != 1.0) {
                continue;
            }
            if (result != -1) {
                return -1;
            }
            result = i;
        }
        return result;
    }

    @Override
    public double logLikelihood() {
        return this.logLikelihood;
    }

    @Override
    public boolean isReversible() {
        throw new UnsupportedOperationException("isReversible is not supported");
    }

    /**
     * Returns the log-likelihood that {@link #combine} would produce, without building messages.
     * Uses {@link #quickPeek} when both children carry bit-set representations.
     */
    @Override
    public double peekCoalescedLogLikelihood(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double delta1, double delta2) {
        FastDiscreteModelCalculator cache1 = (FastDiscreteModelCalculator) node1;
        FastDiscreteModelCalculator cache2 = (FastDiscreteModelCalculator) node2;
        if (cache1.bitVectorVersion != null && cache2.bitVectorVersion != null) {
            return FastDiscreteModelCalculator.quickPeek(node1, node2, delta1, delta2);
        }
        return this.coalesce(cache1, cache2, delta1, delta2, null);
    }
}

