/*
 * Decompiled with CFR 0.152.
 */
package fenchel.factor.multisitecat;

/**
 * Flat-array numeric kernels over per-site tables indexed by (site, category, character).
 *
 * <p>Unary tables are flat {@code double[]} arrays of length {@code nSites * nCat * nChars},
 * laid out by {@link #unaryIdx}. Binary tables are flat arrays of length
 * {@code nCat * nChars * nChars}, laid out by {@link #binaryIdx}.
 *
 * <p>{@link #marginalize} and {@link #pointwiseMultiply} write a unary table whose every site
 * block is rescaled to sum to one. They return the accumulated log of the divisors they
 * removed, so that the un-normalized values can be recovered in log space.
 */
public class LowLevelOperations {
    /** When the pending product of site sums falls below this magnitude it is folded into the log-norm (underflow guard). */
    private static final double UNDERFLOW_GUARD = 1.0E-100;
    /** When the pending product of site sums exceeds this magnitude it is folded into the log-norm (overflow guard). */
    private static final double OVERFLOW_GUARD = 1.0E100;

    /**
     * Pushes a unary table through a per-category binary table and renormalizes each site.
     *
     * <p>For every site {@code s}, category {@code c} and character {@code y}, this writes
     * {@code sum_x unary[s, c, x] * binary[c, x, y]} into {@code result}. It then divides the
     * whole block of site {@code s} in {@code result} by that block's total, so the block sums to one.
     *
     * <p>{@code result} must not be the same array as {@code unary}. Outputs for a site are
     * written while that site's {@code unary} entries are still being read.
     *
     * @param binary       binary table of length {@code nCat * nChars * nChars}, indexed by {@link #binaryIdx}
     * @param unary        input unary table of length {@code nSites * nCat * nChars}, indexed by {@link #unaryIdx}
     * @param unaryLogNorm log normalizer already carried by {@code unary}; added to the return value
     * @param nSites       number of sites
     * @param nCat         number of categories per site
     * @param nChars       number of characters per category
     * @param result       output array with the same length as {@code unary}; fully overwritten
     * @return the sum of the logs of the per-site totals, plus {@code unaryLogNorm}
     * @throws IllegalArgumentException if an array length does not match the dimensions
     */
    public static double marginalize(double[] binary, double[] unary, double unaryLogNorm, int nSites, int nCat, int nChars, double[] result) {
        requireLength("unary", unary, nSites * nCat * nChars);
        requireLength("binary", binary, nCat * nChars * nChars);
        requireLength("result", result, unary.length);
        final LogNormAccumulator logNorm = new LogNormAccumulator();
        for (int site = 0; site < nSites; ++site) {
            double siteSum = 0.0;
            for (int c = 0; c < nCat; ++c) {
                for (int y = 0; y < nChars; ++y) {
                    double sum = 0.0;
                    for (int x = 0; x < nChars; ++x) {
                        sum += unary[unaryIdx(site, c, x, nCat, nChars)] * binary[binaryIdx(c, x, y, nChars)];
                    }
                    result[unaryIdx(site, c, y, nCat, nChars)] = sum;
                    siteSum += sum;
                }
            }
            scaleSite(result, site, nCat, nChars, 1.0 / siteSum);
            logNorm.accumulate(siteSum, site == nSites - 1);
        }
        return logNorm.value() + unaryLogNorm;
    }

    /**
     * Multiplies several unary tables entry-wise and renormalizes each site of the product.
     *
     * <p>With zero factors the product is 1 everywhere, so every site block becomes uniform.
     * {@code result} may be one of the {@code unaries} arrays: each entry is read from all
     * factors before it is written.
     *
     * @param unaries    the factors; each must have length {@code nSites * nCat * nChars}
     * @param unaryNorms one log normalizer per factor (same length as {@code unaries}); all are added to the return value
     * @param nSites     number of sites
     * @param nCat       number of categories per site
     * @param nChars     number of characters per category
     * @param result     output array of length {@code nSites * nCat * nChars}; fully overwritten
     * @return the sum of the logs of the per-site totals, plus every entry of {@code unaryNorms}
     * @throws IllegalArgumentException if {@code unaryNorms.length != unaries.length}, or if any
     *                                  array length does not match the dimensions
     */
    public static double pointwiseMultiply(double[][] unaries, double[] unaryNorms, int nSites, int nCat, int nChars, double[] result) {
        final int nFactors = unaries.length;
        // The original compared unaries.length to itself, so this check could never fire.
        if (unaryNorms.length != nFactors) {
            throw new IllegalArgumentException("unaryNorms has length " + unaryNorms.length + ", expected " + nFactors + " (one per factor)");
        }
        final int unaryLength = nSites * nCat * nChars;
        for (int f = 0; f < nFactors; ++f) {
            if (unaries[f].length != unaryLength) {
                throw new IllegalArgumentException("unaries[" + f + "] has length " + unaries[f].length + ", expected " + unaryLength);
            }
        }
        // Validate against the dimensions rather than unaries[0], so an empty factor list is accepted.
        requireLength("result", result, unaryLength);
        final LogNormAccumulator logNorm = new LogNormAccumulator();
        for (int site = 0; site < nSites; ++site) {
            double siteSum = 0.0;
            for (int c = 0; c < nCat; ++c) {
                for (int y = 0; y < nChars; ++y) {
                    final int idx = unaryIdx(site, c, y, nCat, nChars);
                    double prod = 1.0;
                    for (double[] factor : unaries) {
                        prod *= factor[idx];
                    }
                    result[idx] = prod;
                    siteSum += prod;
                }
            }
            scaleSite(result, site, nCat, nChars, 1.0 / siteSum);
            logNorm.accumulate(siteSum, site == nSites - 1);
        }
        double total = logNorm.value();
        for (double subNorm : unaryNorms) {
            total += subNorm;
        }
        return total;
    }

    /**
     * Computes per-site normalized pairwise tables and sums them across sites.
     *
     * <p>For each site {@code s}, the joint table {@code unarySrc[s, c, x] * unaryDest[s, c, y] * binary[c, x, y]}
     * is divided by its total over all {@code (c, x, y)}. The normalized tables are added
     * together over all sites; they are summed, not averaged.
     *
     * @param binary    binary table of length {@code nCat * nChars * nChars}, indexed by {@link #binaryIdx}
     * @param unarySrc  unary table for the first character {@code x}, length {@code nSites * nCat * nChars}
     * @param unaryDest unary table for the second character {@code y}, same length as {@code unarySrc}
     * @param nSites    number of sites
     * @param nCat      number of categories per site
     * @param nChars    number of characters per category
     * @return a new {@code [nCat][nChars][nChars]} array indexed as {@code [c][x][y]}
     * @throws IllegalArgumentException if an array length does not match the dimensions
     */
    public static double[][][] pairwiseExpectations(double[] binary, double[] unarySrc, double[] unaryDest, int nSites, int nCat, int nChars) {
        requireLength("unarySrc", unarySrc, nSites * nCat * nChars);
        requireLength("binary", binary, nCat * nChars * nChars);
        requireLength("unaryDest", unaryDest, unarySrc.length);
        final int binaryLength = nCat * nChars * nChars;
        final double[] joint = new double[binaryLength];
        final double[] accumulated = new double[binaryLength];
        for (int site = 0; site < nSites; ++site) {
            double norm = 0.0;
            for (int c = 0; c < nCat; ++c) {
                for (int y = 0; y < nChars; ++y) {
                    final double destPot = unaryDest[unaryIdx(site, c, y, nCat, nChars)];
                    for (int x = 0; x < nChars; ++x) {
                        final double current = unarySrc[unaryIdx(site, c, x, nCat, nChars)] * destPot * binary[binaryIdx(c, x, y, nChars)];
                        joint[binaryIdx(c, x, y, nChars)] = current;
                        norm += current;
                    }
                }
            }
            final double inverseNorm = 1.0 / norm;
            for (int i = 0; i < binaryLength; ++i) {
                accumulated[i] += joint[i] * inverseNorm;
            }
        }
        final double[][][] result = new double[nCat][nChars][nChars];
        for (int c = 0; c < nCat; ++c) {
            for (int x = 0; x < nChars; ++x) {
                for (int y = 0; y < nChars; ++y) {
                    result[c][x][y] = accumulated[binaryIdx(c, x, y, nChars)];
                }
            }
        }
        return result;
    }

    /**
     * Flat index of entry {@code (c, x, y)} in a binary table.
     *
     * <p>The storage order is {@code (c, y, x)}, so {@code x} varies fastest, even though the
     * arguments are given as {@code (c, x, y)}.
     */
    public static int binaryIdx(int c, int x, int y, int nChars) {
        return c * nChars * nChars + y * nChars + x;
    }

    /** Flat index of entry {@code (s, c, x)} in a unary table; {@code x} varies fastest, then {@code c}, then {@code s}. */
    public static int unaryIdx(int s, int c, int x, int nCat, int nChars) {
        return s * nCat * nChars + c * nChars + x;
    }

    /** Throws {@link IllegalArgumentException} naming {@code name} if {@code array.length != expected}. */
    private static void requireLength(String name, double[] array, int expected) {
        if (array.length != expected) {
            throw new IllegalArgumentException(name + " has length " + array.length + ", expected " + expected);
        }
    }

    /** Multiplies every entry of one site's contiguous block of a unary table by {@code factor}. */
    private static void scaleSite(double[] table, int site, int nCat, int nChars, double factor) {
        // The block of site s is contiguous: it runs from unaryIdx(s, 0, 0) to unaryIdx(s, nCat - 1, nChars - 1).
        final int start = unaryIdx(site, 0, 0, nCat, nChars);
        final int end = start + nCat * nChars;
        for (int i = start; i < end; ++i) {
            table[i] *= factor;
        }
    }

    /**
     * Accumulates {@code log(prod of factors)} while taking as few logarithms as possible.
     *
     * <p>Factors are multiplied into a pending product. That product is converted to a log and
     * reset only when it leaves {@code [UNDERFLOW_GUARD, OVERFLOW_GUARD]} or when a flush is requested.
     */
    private static final class LogNormAccumulator {
        private double logNorm = 0.0;
        private double pending = 1.0;

        /** Multiplies {@code factor} into the pending product; folds it into the log when out of range or if {@code flush}. */
        void accumulate(double factor, boolean flush) {
            pending *= factor;
            final double magnitude = Math.abs(pending);
            if (flush || magnitude < UNDERFLOW_GUARD || magnitude > OVERFLOW_GUARD) {
                logNorm += Math.log(pending);
                pending = 1.0;
            }
        }

        /** The log-norm folded so far; callers flush on the last site so nothing stays pending. */
        double value() {
            return logNorm;
        }
    }
}

