/*
 * Decompiled with CFR 0.152.
 */
package fenchel.factor.multihmm;

import fenchel.algo.FactorGraphSumProduct;
import fenchel.factor.FactorUtils;
import fenchel.factor.multihmm.MultiInputHMM;
import fenchel.factor.multisites.MSFactorGraph;
import fenchel.factor.multisites.MSUnaryFactor;
import fenchel.factor.multisites.MSUnaryScaledFactor;
import hmm.RescaledBaumWelch;
import java.util.ArrayList;
import java.util.List;
import nuts.math.TrMtx;

/**
 * Static helpers for inference in a {@link MultiInputHMM}.
 *
 * <p>A small "star" factor graph is built in which a single hidden node (id
 * {@link #hiddenState}) is connected to every input node (ids {@code 0 .. nInputs()-1}) through the
 * model's emission matrices. Sum-product on that graph produces potentials for the hidden node,
 * which are then fed to a {@link RescaledBaumWelch} pass together with the model's transition,
 * initial and ending distributions.
 */
public class MultiInputHMMComputations {

    /**
     * Returns the node id used for the hidden variable in the graphs built by
     * {@link #createTemporaryFactorGraph}. Input nodes use ids {@code 0 .. nInputs()-1}, so the
     * hidden node takes the next free id.
     *
     * @param hmm the model whose input count determines the id
     * @return {@code hmm.nInputs()}
     */
    public static Integer hiddenState(MultiInputHMM hmm) {
        return hmm.nInputs();
    }

    /**
     * Computes the log normalization of {@code hmm} given {@code input}: the value of
     * {@link RescaledBaumWelch#logll()} after a Baum-Welch pass whose hidden-node potentials have
     * been multiplied by the sum-product normalizer of the hidden-node moment.
     *
     * @param input one unary table per input node, in input-node order
     * @param hmm the model
     * @return the log-likelihood reported by the Baum-Welch pass
     */
    public static double logNorm(List<double[][]> input, MultiInputHMM hmm) {
        return getBaumWelchAlgorithm(input, hmm, true).logll();
    }

    /**
     * Builds the hidden-node potentials from {@code input} and runs a rescaled Baum-Welch pass
     * over them.
     *
     * <p>Side effect: sets the global flag {@code TrMtx.doChecks} to {@code false} and does not
     * restore it.
     *
     * @param input unary tables for the input nodes (one per input, in input-node order)
     * @param hmm the model supplying transitions, initial/ending distributions and emissions
     * @param normMatters when {@code true}, every potential is multiplied by
     *     {@code exp(logNorm())} of the hidden-node moment so the result keeps its absolute scale
     *     (used by {@link #logNorm}); when {@code false} only the normalized values are used
     * @return the Baum-Welch object after {@code compute} has run
     */
    private static RescaledBaumWelch getBaumWelchAlgorithm(
            List<double[][]> input, MultiInputHMM hmm, boolean normMatters) {
        // NOTE(review): flips a process-wide static flag and never restores it, so every other
        // TrMtx user is affected after the first call.
        TrMtx.doChecks = false;
        final double[][] trans = hmm.transitions;
        final double[] initDist = hmm.initialDistribution;
        final double[] endDist = hmm.endingDistribution;
        final Integer hidden = hiddenState(hmm);

        final FactorGraphSumProduct<Integer> tempFactorGraph = createTemporaryFactorGraph(input, null, hmm);
        // Presumably rows are positions and columns are hidden states: the ending distribution
        // below is applied entry-wise to the last row. NOTE(review): it is not visible here whether
        // normalizedValues() returns a copy or the factor's internal array; the in-place edits
        // below would mutate the latter.
        final double[][] processedInput =
                ((MSUnaryFactor) tempFactorGraph.moment(hidden)).normalizedValues();

        if (normMatters) {
            // NOTE(review): Math.exp overflows to Infinity / underflows to 0 for large |logNorm|;
            // confirm the expected range of the normalizer.
            final double norm = Math.exp(tempFactorGraph.moment(hidden).logNorm());
            for (final double[] row : processedInput) {
                // Bound by this row's own length (previously row 0's length was used for every
                // row, which misbehaves on jagged arrays).
                for (int j = 0; j < row.length; j++) {
                    row[j] *= norm;
                }
            }
        }

        // Fold the ending distribution into the final row's potentials.
        final int len = processedInput.length;
        for (int s = 0; s < endDist.length; s++) {
            processedInput[len - 1][s] *= endDist[s];
        }

        final RescaledBaumWelch bw = new RescaledBaumWelch();
        bw.compute(initDist, trans, processedInput);
        return bw;
    }

    /**
     * Computes the moment of every input node after propagating the hidden chain back through the
     * emission factors.
     *
     * <p>First runs Baum-Welch (without normalizer scaling) on {@code input}. It then builds a
     * fresh graph whose hidden node carries the Baum-Welch one-node moments as its only unary (no
     * input unaries), and reads off the moment at each input node.
     *
     * @param input unary tables for the input nodes, used only for the Baum-Welch pass
     * @param hmm the model
     * @return one moment per input node, in input-node order
     */
    public static List<MSUnaryScaledFactor> computeMoments(List<double[][]> input, MultiInputHMM hmm) {
        final RescaledBaumWelch bw = getBaumWelchAlgorithm(input, hmm, false);
        final double[][] moments = bw.allOneNodeMoments();
        final FactorGraphSumProduct<Integer> outputFactorGraph = createTemporaryFactorGraph(null, moments, hmm);
        final int nInputs = hmm.nInputs();
        final List<MSUnaryScaledFactor> result = new ArrayList<MSUnaryScaledFactor>(nInputs);
        for (int i = 0; i < nInputs; i++) {
            result.add((MSUnaryScaledFactor) outputFactorGraph.moment(i));
        }
        return result;
    }

    /**
     * Builds and initializes a sum-product engine over the star graph described in the class
     * comment.
     *
     * <p>For each input node {@code i}, a binary factor links the hidden node to {@code i} using
     * {@code hmm.emissions[i]} and {@code hmm.transposedEmissions[i]}. When {@code input} is
     * non-null, {@code input.get(i)} is also attached as a unary on node {@code i}. When
     * {@code output} is non-null, it is attached as a unary on the hidden node.
     *
     * @param input unary tables for the input nodes, or {@code null} for none
     * @param output unary table for the hidden node, or {@code null} for none
     * @param hmm the model supplying the emission matrices
     * @return an initialized sum-product engine over the graph
     */
    public static FactorGraphSumProduct<Integer> createTemporaryFactorGraph(
            List<double[][]> input, double[][] output, MultiInputHMM hmm) {
        final MSFactorGraph<Integer> fg = FactorUtils.newFactorGraph();
        final Integer hidden = hiddenState(hmm);
        for (int i = 0; i < hmm.nInputs(); i++) {
            if (input != null) {
                fg.addUnary(i, input.get(i));
            }
            fg.addBinary(hidden, i, hmm.emissions[i], hmm.transposedEmissions[i]);
        }
        if (output != null) {
            fg.addUnary(hidden, output);
        }
        final FactorGraphSumProduct<Integer> miniSP = new FactorGraphSumProduct<Integer>();
        miniSP.init(fg);
        return miniSP;
    }
}

