/*
 * Decompiled with CFR 0.152.
 */
package fenchel.factor.multisites;

import fenchel.factor.UnaryFactor;
import fenchel.factor.multisites.MSUnaryFactor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import nuts.math.MeasureZeroException;
import nuts.tui.Table;

/**
 * A multi-site unary factor stored as a site-by-state value table plus a scalar log normalization.
 *
 * <p>Once normalized, each site row of {@link #normalizedValues()} sums to one. {@link #logNorm()}
 * is the accumulated log of the per-site normalization constants that were divided out, plus the
 * log normalizations of any factors this one was built from.
 *
 * <p>A factor may also carry a pending ("delayed") convolution with a binary factor. That
 * convolution is applied lazily the next time values, normalization, or a product are requested.
 * Forcing this lazy work mutates the instance's fields without synchronization.
 */
public final class MSUnaryScaledFactor
implements MSUnaryFactor {
    /** Value table indexed {@code [site][state]}; per-site normalized once normalization has run. */
    private double[][] _values;
    /** Log of the normalization constants divided out of {@link #_values}. */
    private double _logNormalization;
    /** Pending binary factor indexed {@code [destState][srcState]}, or {@code null} if none. */
    private double[][] _delayedConvolution;
    /**
     * Whether {@link #_values} are already per-site normalized. Not final: it becomes true once lazy
     * normalization runs, so repeated {@link #logNorm()} / {@link #normalizedValues()} calls reuse
     * the cached result instead of renormalizing (and perturbing) it each time.
     */
    private boolean isNormalized;

    /**
     * Returns a new factor equal to this one pushed through {@code binaryFactor}. The convolution
     * itself is deferred until the result's values are needed.
     *
     * <p>Any convolution already pending on this factor is applied first.
     *
     * @param binaryFactor matrix indexed {@code [destState][srcState]}; its row length must equal
     *     this factor's number of states
     * @return a factor with {@code binaryFactor.length} states and a pending convolution
     * @throws RuntimeException if the row length of {@code binaryFactor} does not match the number
     *     of states of this factor
     */
    public MSUnaryScaledFactor delayedMarginalization(double[][] binaryFactor) {
        this.ensureConvolutionsProcessed();
        if (binaryFactor[0].length != this._values[0].length) {
            throw new RuntimeException("binaryFactor has " + binaryFactor[0].length
                    + " source states but this factor has " + this._values[0].length + " states");
        }
        return new MSUnaryScaledFactor(this._values, this._logNormalization, binaryFactor);
    }

    /**
     * Wraps {@code siteToState} (not copied) as an unnormalized factor with zero log normalization.
     *
     * @param siteToState value table indexed {@code [site][state]}
     * @return the wrapping factor
     */
    public static MSUnaryScaledFactor buildFactor(double[][] siteToState) {
        return new MSUnaryScaledFactor(siteToState);
    }

    /** Creates an unnormalized factor with no pending convolution. */
    private MSUnaryScaledFactor(double[][] siteToState) {
        this.isNormalized = false;
        this._logNormalization = 0.0;
        this._delayedConvolution = null;
        this._values = siteToState;
    }

    /** Creates a factor flagged as normalized, optionally with a pending convolution. */
    private MSUnaryScaledFactor(double[][] normalizedValues, double logNormalization, double[][] delayedConvolution) {
        this._values = normalizedValues;
        this._logNormalization = logNormalization;
        this._delayedConvolution = delayedConvolution;
        this.isNormalized = true;
    }

    /** Creates a normalized factor with no pending convolution. */
    private MSUnaryScaledFactor(double[][] normalizedValues, double logNormalization) {
        this(normalizedValues, logNormalization, null);
    }

    /** @return the number of sites, i.e. the number of rows of the value table */
    @Override
    public int nSites() {
        return this._values.length;
    }

    /** @return whether a convolution is still waiting to be applied */
    private boolean hasDelayedConvolutions() {
        return this._delayedConvolution != null;
    }

    /**
     * Multiplies this factor with {@code otherFactors} site by site and normalizes the result.
     *
     * <p>Any pending convolutions on the inputs are applied as part of the product. The inputs are
     * not modified. The result's log normalization is the sum of the inputs' log normalizations
     * plus the log of the per-site normalization constants of the product.
     *
     * @param otherFactors factors to multiply in; each must be an {@code MSUnaryScaledFactor}
     * @return a new normalized factor with no pending convolution
     * @throws ClassCastException if any element is not an {@code MSUnaryScaledFactor}
     * @throws RuntimeException if a pending convolution's destination size differs from
     *     {@link #nStates()}
     * @throws MeasureZeroException if some site's product sums to a non-positive value (or NaN)
     */
    @Override
    public UnaryFactor multiply(List<UnaryFactor> otherFactors) {
        final int nSites = this.nSites();
        final int nDestStates = this.nStates();
        final int nFactors = otherFactors.size() + 1;

        // Factors with a pending convolution go first so that the per-site kernel can distinguish
        // them purely by index (index < nDelayedConvolutions).
        List<MSUnaryScaledFactor> delayeds = new ArrayList<MSUnaryScaledFactor>();
        List<MSUnaryScaledFactor> notDelayeds = new ArrayList<MSUnaryScaledFactor>();
        if (this.hasDelayedConvolutions()) {
            delayeds.add(this);
        } else {
            notDelayeds.add(this);
        }
        for (UnaryFactor other : otherFactors) {
            MSUnaryScaledFactor f = (MSUnaryScaledFactor) other;
            if (f.hasDelayedConvolutions()) {
                delayeds.add(f);
            } else {
                notDelayeds.add(f);
            }
        }

        final int nDelayedConvolutions = delayeds.size();
        int[] nSrcStates = new int[nDelayedConvolutions];
        double[][][] normalizedValues = new double[nFactors][][];
        double[][][] delayedConvolutions = new double[nDelayedConvolutions][][];
        double previousLogNorms = 0.0;
        for (int fIndex = 0; fIndex < nFactors; ++fIndex) {
            if (fIndex < nDelayedConvolutions) {
                MSUnaryScaledFactor f = delayeds.get(fIndex);
                normalizedValues[fIndex] = f._values;
                delayedConvolutions[fIndex] = f._delayedConvolution;
                nSrcStates[fIndex] = f._delayedConvolution[0].length;
                if (f._delayedConvolution.length != nDestStates) {
                    throw new RuntimeException("" + f._delayedConvolution.length + " vs " + nDestStates);
                }
                previousLogNorms += f._logNormalization;
            } else {
                MSUnaryScaledFactor f = notDelayeds.get(fIndex - nDelayedConvolutions);
                normalizedValues[fIndex] = f._values;
                previousLogNorms += f._logNormalization;
            }
        }

        double[][] result = new double[nSites][nDestStates];
        double logNorm = 0.0;
        // Per-site norms are multiplied into tempNorm and folded into logNorm with a single log
        // call only when the running product nears under/overflow, or at the last site.
        double tempNorm = 1.0;
        for (int site = 0; site < nSites; ++site) {
            double[] row = result[site];
            double currentNorm = fillUnnormalizedSite(nSrcStates, nDestStates, nFactors,
                    nDelayedConvolutions, normalizedValues, delayedConvolutions, site, row);
            for (int destState = 0; destState < nDestStates; ++destState) {
                row[destState] /= currentNorm;
            }
            tempNorm *= currentNorm;
            if (isUnderOverFlow(tempNorm, site, nSites)) {
                logNorm += Math.log(tempNorm);
                tempNorm = 1.0;
            }
        }
        return new MSUnaryScaledFactor(result, logNorm + previousLogNorms);
    }

    /**
     * Decides whether the running product of per-site norms must be flushed into the log
     * accumulator.
     *
     * @return true if |tempNorm| is below 1e-100 or above 1e100, or {@code site} is the last site
     */
    private static boolean isUnderOverFlow(double tempNorm, int site, int nSites) {
        double abs = Math.abs(tempNorm);
        return abs < 1.0E-100 || abs > 1.0E100 || site == nSites - 1;
    }

    /**
     * Writes the unnormalized product of all factors at one site into {@code out} and returns the
     * sum of those products, which is the site's normalization constant.
     *
     * <p>For the first {@code nDelayedConvolutions} factors, the value at {@code destState} is the
     * pending convolution {@code sum_src conv[destState][src] * values[site][src]}. For the rest it
     * is simply {@code values[site][destState]}.
     *
     * @param out destination row of length {@code nDestStates}; overwritten
     * @return the sum of {@code out}'s entries
     * @throws MeasureZeroException if that sum is not strictly positive (including NaN)
     */
    private static double fillUnnormalizedSite(int[] nSrcStates, int nDestStates, int nFactors,
            int nDelayedConvolutions, double[][][] normalizedValues, double[][][] delayedConvolutions,
            int site, double[] out) {
        double norm = 0.0;
        for (int destState = 0; destState < nDestStates; ++destState) {
            double prod = 1.0;
            for (int factor = 0; factor < nFactors; ++factor) {
                double currentFactorValue = 0.0;
                if (factor < nDelayedConvolutions) {
                    for (int srcState = 0; srcState < nSrcStates[factor]; ++srcState) {
                        currentFactorValue += delayedConvolutions[factor][destState][srcState] * normalizedValues[factor][site][srcState];
                    }
                } else {
                    currentFactorValue = normalizedValues[factor][site][destState];
                }
                prod *= currentFactorValue;
            }
            out[destState] = prod;
            norm += prod;
        }
        // Written as !(norm > 0) rather than norm <= 0 so that NaN is also rejected.
        if (!(norm > 0.0)) {
            throw new MeasureZeroException("The normalization of the factor graph is not positive.  Encountered intermediate normalization value of " + norm);
        }
        return norm;
    }

    /**
     * Returns the log normalization, normalizing (and applying any pending convolution) first.
     *
     * @return the accumulated log normalization constant
     */
    @Override
    public double logNorm() {
        this.ensureNormalized();
        return this._logNormalization;
    }

    /**
     * Returns the per-site normalized value table, normalizing first if needed. The internal array
     * is returned without copying.
     *
     * @return value table indexed {@code [site][state]}
     */
    @Override
    public double[][] normalizedValues() {
        this.ensureNormalized();
        return this._values;
    }

    /**
     * Returns the number of states. When a convolution is pending, this is its destination size.
     *
     * @return the number of states per site
     */
    @Override
    public int nStates() {
        if (!this.hasDelayedConvolutions()) {
            return this._values[0].length;
        }
        return this._delayedConvolution.length;
    }

    /** Applies a pending convolution, if any; this also normalizes the values. */
    private void ensureConvolutionsProcessed() {
        if (this.hasDelayedConvolutions()) {
            this._processConvolutionsAndNormalize();
        }
    }

    /**
     * Replaces this factor's state with the normalized product of itself alone. This applies any
     * pending convolution, clears it, and marks the factor as normalized.
     */
    private void _processConvolutionsAndNormalize() {
        MSUnaryScaledFactor processed = (MSUnaryScaledFactor) this.multiply(Collections.<UnaryFactor>emptyList());
        this._values = processed._values;
        this._logNormalization = processed._logNormalization;
        this._delayedConvolution = null;
        // multiply() always yields per-site normalized values; remember that so ensureNormalized()
        // does not redo the work on every call.
        this.isNormalized = true;
    }

    /** Normalizes (and applies any pending convolution) unless that has already been done. */
    private void ensureNormalized() {
        if (!this.isNormalized || this.hasDelayedConvolutions()) {
            this._processConvolutionsAndNormalize();
        }
    }

    /**
     * Returns the raw value table after applying any pending convolution. The internal array is
     * returned without copying.
     *
     * @return value table indexed {@code [site][state]}
     * @throws RuntimeException if the log normalization is non-zero, since rescaling is not
     *     implemented
     */
    @Override
    public double[][] values() {
        this.ensureConvolutionsProcessed();
        if (this._logNormalization != 0.0) {
            throw new RuntimeException("TODO: would have to create an array and rescale");
        }
        return this._values;
    }

    /** Applies any pending convolution, then renders the log normalization and the value table. */
    @Override
    public String toString() {
        this.ensureConvolutionsProcessed();
        return "logNormalization=" + this._logNormalization + ";values:\n" + Table.toString(this._values);
    }
}

