/*
 * Decompiled with CFR 0.152.
 */
package pty.smc.models;

import fig.basic.Option;
import fig.basic.Pair;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import ma.RateMatrixLoader;
import ma.SequenceType;
import nuts.maxent.SloppyMath;
import nuts.util.CollUtils;
import nuts.util.MathUtils;
import pty.io.Dataset;
import pty.smc.models.CTMC;
import pty.smc.models.LikelihoodModelCalculator;

/**
 * Likelihood calculator for a discrete-character CTMC model.
 *
 * <p>Each instance holds per-site, per-character conditional likelihoods in log space
 * ({@code cache[site][character]}) together with a log likelihood fixed at construction time.
 * Construction is refused unless {@link #allowDiscreteModelCalculator} is set; the error message
 * directs callers to {@code FastDiscreteModelCalculator}.
 */
public class DiscreteModelCalculator
implements LikelihoodModelCalculator {
    /** Must be set to {@code true} before either constructor will build an instance. */
    @Option
    public static boolean allowDiscreteModelCalculator = false;
    /** Log-space conditional likelihoods, indexed {@code [site][character]}. */
    private final double[][] cache;
    /** Log likelihood of this node, fixed at construction. */
    private final double logLikelihood;
    /** Model supplying transition probabilities, initial distributions and alphabet sizes. */
    public final CTMC ctmc;

    /**
     * Retired loader; always throws.
     *
     * @throws RuntimeException always, pointing callers to {@code Dataset.java}
     */
    public static Map<Taxon, DiscreteModelCalculator> getInit(File alignmentFile, SequenceType sequenceType) {
        throw new RuntimeException("Use Dataset.java instead!");
    }

    /**
     * Builds the parent of {@code node1} and {@code node2}, joined along branches of length
     * {@code v1} and {@code v2}.
     *
     * <p>{@code avoidBuildCache} is ignored. This implementation always materializes the parent's
     * cache, because {@link #calculate} only returns a calculator (rather than a bare log
     * likelihood) when it is not peeking.
     */
    @Override
    public LikelihoodModelCalculator combine(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2, boolean avoidBuildCache) {
        return (LikelihoodModelCalculator) this.calculate(node1, node2, v1, v2, false);
    }

    /**
     * Creates a leaf calculator from observation probabilities.
     *
     * @param ctmc the substitution model
     * @param initCache per-site, per-character observation values; they are passed through
     *     {@code Dataset.DatasetUtils.log} (presumably an elementwise natural log — confirm)
     *     before being stored
     * @return a new leaf calculator whose log likelihood is {@code extendLogLikelihood(0.0)}
     */
    public static DiscreteModelCalculator observation(CTMC ctmc, double[][] initCache) {
        double[][] logCache = Dataset.DatasetUtils.log(initCache);
        return new DiscreteModelCalculator(logCache, ctmc);
    }

    /**
     * Returns a fresh copy of the log-space cache row for {@code site}.
     *
     * <p>The copy always has length {@code ctmc.nCharacter(site)}. It is truncated, or padded with
     * {@code 0.0}, if the stored row has a different length.
     */
    public double[] getCacheCopy(int site) {
        return Arrays.copyOf(this.cache[site], this.ctmc.nCharacter(site));
    }

    /**
     * Tests whether {@code site} carries no information, i.e. every character has likelihood 1.
     *
     * <p>The cache is in log space, so such a site log-sums to {@code log(nCharacter(site))}. The
     * comparison must therefore be made against the log of the alphabet size, not the size itself.
     */
    public boolean isMissing(int site) {
        return MathUtils.close(Math.log(this.ctmc.nCharacter(site)), SloppyMath.logAdd(this.cache[site]));
    }

    /**
     * Wraps an already-computed log-space cache and its log likelihood.
     *
     * @throws RuntimeException if {@link #allowDiscreteModelCalculator} is {@code false}
     */
    public DiscreteModelCalculator(double[][] cache, CTMC ctmc, double logLikelihood) {
        checkAllowed();
        this.cache = cache;
        this.ctmc = ctmc;
        this.logLikelihood = logLikelihood;
    }

    /**
     * Wraps a leaf cache and derives its log likelihood via {@code extendLogLikelihood(0.0)}.
     *
     * <p>NOTE(review): this calls an overridable method from a constructor; a subclass override
     * would run before the subclass's own fields are initialized.
     */
    private DiscreteModelCalculator(double[][] cache, CTMC ctmc) {
        checkAllowed();
        this.cache = cache;
        this.ctmc = ctmc;
        this.logLikelihood = this.extendLogLikelihood(0.0);
    }

    /** Shared constructor guard: refuses construction unless explicitly enabled. */
    private static void checkAllowed() {
        if (!allowDiscreteModelCalculator) {
            throw new RuntimeException("Should use FastDiscreteModelCalculator");
        }
    }

    /** Convenience overload of {@link #k2pDistanceSuffStat(double[][], double[][])} on two calculators' caches. */
    public static Pair<Double, Double> k2pDistanceSuffStat(DiscreteModelCalculator dmc1, DiscreteModelCalculator dmc2) {
        return DiscreteModelCalculator.k2pDistanceSuffStat(dmc1.cache, dmc2.cache);
    }

    /**
     * Computes the fractions of sites whose MAP characters differ by a transition and by a
     * transversion, as classified by {@code RateMatrixLoader.isTransition}.
     *
     * @return {@code (transitionFraction, transversionFraction)}; both are {@code NaN} when there
     *     are no sites
     * @throws RuntimeException if the two caches cover different numbers of sites
     */
    public static Pair<Double, Double> k2pDistanceSuffStat(double[][] cache1, double[][] cache2) {
        if (cache1.length != cache2.length) {
            throw new RuntimeException("Caches cover different numbers of sites: " + cache1.length + " vs " + cache2.length);
        }
        double transitions = 0.0;
        double transversions = 0.0;
        for (int site = 0; site < cache1.length; ++site) {
            int map1 = DiscreteModelCalculator.getMAP(cache1[site]);
            int map2 = DiscreteModelCalculator.getMAP(cache2[site]);
            if (map1 == map2) {
                continue;
            }
            if (RateMatrixLoader.isTransition(map1, map2)) {
                transitions += 1.0;
            } else {
                transversions += 1.0;
            }
        }
        double count = cache1.length;
        return Pair.makePair(transitions / count, transversions / count);
    }

    /**
     * Returns the index of the first strictly largest entry.
     *
     * <p>Returns {@code -1} when no entry exceeds {@code -Infinity}, e.g. for an empty row or a
     * row of all {@code -Infinity}/{@code NaN}.
     */
    private static int getMAP(double[] cache) {
        int argmax = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < cache.length; ++c) {
            double cval = cache[c];
            if (!(cval > max)) continue;
            max = cval;
            argmax = c;
        }
        return argmax;
    }

    /**
     * Log likelihood of a star tree: one root joined to each calculator in {@code calcs} by the
     * branch length at the same index in {@code deltas}.
     *
     * <p>For each site the result is the log-sum over root characters {@code x} of
     * {@code log initD[x] + sum_i log sum_y P_i[x][y] * exp(cache_i[s][y])}, summed over sites.
     *
     * @throws IllegalArgumentException if {@code deltas} and {@code calcs} differ in size
     * @throws RuntimeException if the model is not site-tied
     */
    public static double starCombine(List<Double> deltas, List<DiscreteModelCalculator> calcs) {
        if (deltas.size() != calcs.size()) {
            throw new IllegalArgumentException("Expected one branch length per calculator, got " + deltas.size() + " deltas for " + calcs.size() + " calculators");
        }
        CTMC ctmc = CollUtils.pick(calcs).ctmc;
        if (!ctmc.isSiteTied()) {
            throw new RuntimeException("starCombine requires a site-tied CTMC");
        }
        double[] initD = ctmc.getInitialDistribution(0);
        // Log transition matrices, computed once per branch rather than once per (site, x, y).
        List<double[][]> logPrs = new ArrayList<double[][]>(deltas.size());
        for (double delta : deltas) {
            logPrs.add(elementwiseLog(ctmc.getTransitionPr(0, delta)));
        }
        double result = 0.0;
        for (int s = 0; s < ctmc.nSites(); ++s) {
            int nChars = ctmc.nCharacter(s);
            double siteResult = Double.NEGATIVE_INFINITY;
            for (int x = 0; x < nChars; ++x) {
                double sumOfFs = 0.0;
                for (int i = 0; i < calcs.size(); ++i) {
                    double[] childCache = calcs.get(i).cache[s];
                    double[] logPrRow = logPrs.get(i)[x];
                    double currentF = Double.NEGATIVE_INFINITY;
                    for (int y = 0; y < nChars; ++y) {
                        currentF = SloppyMath.logAdd(currentF, childCache[y] + logPrRow[y]);
                    }
                    sumOfFs += currentF;
                }
                siteResult = SloppyMath.logAdd(siteResult, Math.log(initD[x]) + sumOfFs);
            }
            result += siteResult;
        }
        return result;
    }

    /** Returns a new matrix holding the natural log of every entry of {@code matrix}. */
    private static double[][] elementwiseLog(double[][] matrix) {
        double[][] logs = new double[matrix.length][];
        for (int r = 0; r < matrix.length; ++r) {
            double[] row = matrix[r];
            double[] logRow = new double[row.length];
            for (int c = 0; c < row.length; ++c) {
                logRow[c] = Math.log(row[c]);
            }
            logs[r] = logRow;
        }
        return logs;
    }

    /**
     * Log likelihood of this node when placed below a branch of length {@code delta}.
     *
     * <p>Per site: {@code logsumexp_y(cache[s][y] + log sum_x pr[x][y] * initD[x])}. Sites
     * rejected by {@link #isSiteLogLikelihoodValid} are skipped.
     */
    @Override
    public double extendLogLikelihood(double delta) {
        boolean isSiteTied = this.ctmc.isSiteTied();
        double[][] pr = isSiteTied ? this.ctmc.getTransitionPr(0, delta) : null;
        double[] initD = isSiteTied ? this.ctmc.getInitialDistribution(0) : null;
        double logLikelihood = 0.0;
        for (int s = 0; s < this.ctmc.nSites(); ++s) {
            if (!isSiteTied) {
                initD = this.ctmc.getInitialDistribution(s);
                pr = this.ctmc.getTransitionPr(s, delta);
            }
            int nChars = this.ctmc.nCharacter(s);
            double[] array = new double[nChars];
            for (int y = 0; y < nChars; ++y) {
                double sum = 0.0;
                for (int x = 0; x < nChars; ++x) {
                    sum += pr[x][y] * initD[x];
                }
                array[y] = this.cache[s][y] + Math.log(sum);
            }
            double siteLogLikelihood = SloppyMath.logAdd(array);
            if (this.isSiteLogLikelihoodValid(siteLogLikelihood)) {
                logLikelihood += siteLogLikelihood;
            }
        }
        return logLikelihood;
    }

    /**
     * Joins {@code node1} and {@code node2} (both must be {@code DiscreteModelCalculator}s) under
     * a new parent.
     *
     * <p>For each site {@code s} and parent character {@code v}, the parent cache entry is
     * {@code log sum_w pr1[v][w] exp(cache1[s][w]) + log sum_w pr2[v][w] exp(cache2[s][w])}. The
     * site log likelihood adds {@code log initD[v]} and log-sums over {@code v}. Sites rejected by
     * {@link #isSiteLogLikelihoodValid} are skipped.
     *
     * @param isPeek when {@code true}, no parent cache is built and a boxed {@code Double} log
     *     likelihood is returned
     * @return a {@code Double} when peeking, otherwise a new {@code DiscreteModelCalculator}
     */
    public Object calculate(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double v1, double v2, boolean isPeek) {
        double[][] result = isPeek ? null : Dataset.DatasetUtils.createObsArray(this.ctmc);
        DiscreteModelCalculator cache1 = (DiscreteModelCalculator) node1;
        DiscreteModelCalculator cache2 = (DiscreteModelCalculator) node2;
        double logLikelihood = 0.0;
        boolean isSiteTied = this.ctmc.isSiteTied();
        double[][] pr1 = isSiteTied ? this.ctmc.getTransitionPr(0, v1) : null;
        double[][] pr2 = isSiteTied ? this.ctmc.getTransitionPr(0, v2) : null;
        double[] initD = isSiteTied ? this.ctmc.getInitialDistribution(0) : null;
        int ns = this.ctmc.nSites();
        // When site-tied, one scratch buffer sized for site 0 is reused for every site.
        double[] genericWorkingArray = isSiteTied ? new double[this.ctmc.nCharacter(0)] : null;
        for (int s = 0; s < ns; ++s) {
            if (!isSiteTied) {
                pr1 = this.ctmc.getTransitionPr(s, v1);
                pr2 = this.ctmc.getTransitionPr(s, v2);
                initD = this.ctmc.getInitialDistribution(s);
            }
            double[] array = isSiteTied ? genericWorkingArray : new double[this.ctmc.nCharacter(s)];
            double siteLogLikelihood = Double.NEGATIVE_INFINITY;
            int ncs = this.ctmc.nCharacter(s);
            for (int v = 0; v < ncs; ++v) {
                double prod = this.sum(s, pr1[v], cache1, array) + this.sum(s, pr2[v], cache2, array);
                if (!isPeek) {
                    result[s][v] = prod;
                }
                siteLogLikelihood = SloppyMath.logAdd(siteLogLikelihood, prod + Math.log(initD[v]));
            }
            if (this.isSiteLogLikelihoodValid(siteLogLikelihood)) {
                logLikelihood += siteLogLikelihood;
            }
        }
        if (isPeek) {
            return logLikelihood;
        }
        return new DiscreteModelCalculator(result, this.ctmc, logLikelihood);
    }

    /**
     * Accepts a site log likelihood that is at most {@code 1e-6}, i.e. not meaningfully positive.
     * {@code NaN} is rejected because the comparison is false.
     */
    private boolean isSiteLogLikelihoodValid(double number) {
        return number <= 1.0E-6;
    }

    /**
     * Returns {@code log sum_w prs[w] * exp(cache[site][w])} for the given child.
     *
     * <p>{@code workingArray} is scratch space and is overwritten. {@code SloppyMath.logAdd} runs
     * over the whole array, so it is expected to have exactly {@code nCharacter(site)} entries.
     */
    private double sum(int site, double[] prs, DiscreteModelCalculator cacheAtLeaf, double[] workingArray) {
        double[] cal = cacheAtLeaf.cache[site];
        int ncs = this.ctmc.nCharacter(site);
        for (int w = 0; w < ncs; ++w) {
            workingArray[w] = Math.log(prs[w]) + cal[w];
        }
        return SloppyMath.logAdd(workingArray);
    }

    /** Returns the log likelihood fixed at construction. */
    @Override
    public double logLikelihood() {
        return this.logLikelihood;
    }

    /** This calculator is never reported as reversible. */
    @Override
    public boolean isReversible() {
        return false;
    }

    /** Returns the log likelihood of coalescing the two nodes without building the parent cache. */
    @Override
    public double peekCoalescedLogLikelihood(LikelihoodModelCalculator node1, LikelihoodModelCalculator node2, double delta1, double delta2) {
        return (Double) this.calculate(node1, node2, delta1, delta2, true);
    }
}

