/*
 * Decompiled with CFR 0.152.
 */
package unsupalg.hmm;

import java.util.List;
import nuts.util.ApproxConsistencyCheck;
import nuts.util.MathUtils;
import unsupalg.hmm.EStep;

/**
 * Accumulated (expected) counts for a discrete hidden Markov model.
 *
 * <p>Three tables are maintained:
 * <ul>
 *   <li>{@code sumOfInitPost[s]}: mass assigned to state {@code s} at the first position;</li>
 *   <li>{@code sumOfTwoNodesPost[s1][s2]}: mass assigned to state {@code s1} at position
 *       {@code t} followed by state {@code s2} at position {@code t + 1};</li>
 *   <li>{@code sumOfEmiPost[s][o]}: mass assigned to state {@code s} at a position whose
 *       observation is {@code o}.</li>
 * </ul>
 * Counts can be added as hard counts from a fully observed sequence
 * ({@link #addFromFullObservation}, 1.0 per event) or as soft counts from the marginals exposed
 * by an {@link EStep} ({@link #addFromPosterior}).
 *
 * <p>Instances are mutable; no synchronization is performed.
 */
public class SuffStat
implements ApproxConsistencyCheck.MetricElt {
    /** Transition mass, indexed {@code [state at t][state at t + 1]}. */
    private final double[][] sumOfTwoNodesPost;
    /** Initial-state mass, indexed {@code [state]}. */
    private final double[] sumOfInitPost;
    /** Emission mass, indexed {@code [state][observation]}. */
    private final double[][] sumOfEmiPost;
    /** Number of observation symbols; kept explicitly so {@link #nObs()} works when nStates == 0. */
    private final int nObs;

    /**
     * Creates statistics with every count set to zero.
     *
     * @param nStates number of hidden states
     * @param nObs number of distinct observation symbols
     */
    public SuffStat(int nStates, int nObs) {
        this.sumOfTwoNodesPost = new double[nStates][nStates];
        this.sumOfInitPost = new double[nStates];
        this.sumOfEmiPost = new double[nStates][nObs];
        this.nObs = nObs;
    }

    /**
     * Adds hard counts (1.0 per event) from a sequence whose hidden states are known.
     *
     * <p>Position 0 adds to the initial-state table. Every position adds one emission count.
     * Every consecutive pair of positions adds one transition count. An empty sequence adds
     * nothing.
     *
     * @param state hidden-state index at each position
     * @param obs observation index at each position
     * @throws IllegalArgumentException if {@code state} and {@code obs} differ in length
     */
    public void addFromFullObservation(List<Integer> state, List<Integer> obs) {
        if (state.size() != obs.size()) {
            throw new IllegalArgumentException("state and obs sequences differ in length: "
                + state.size() + " vs " + obs.size());
        }
        if (state.isEmpty()) {
            // Nothing observed: no initial state, emissions or transitions to count.
            return;
        }
        this.sumOfInitPost[state.get(0)] += 1.0;
        int last = state.size() - 1;
        for (int t = 0; t <= last; ++t) {
            int current = state.get(t);
            this.sumOfEmiPost[current][obs.get(t)] += 1.0;
            if (t < last) {
                this.sumOfTwoNodesPost[current][state.get(t + 1)] += 1.0;
            }
        }
    }

    /**
     * Adds soft counts from the posterior marginals held by an {@link EStep}.
     *
     * <p>The marginal at position 0 is added to the initial-state table. Each position's
     * one-node marginal is added to the emission table at that position's observation. The
     * two-node marginal of every position but the last is added to the transition table.
     * A sequence of length 0 adds nothing.
     *
     * @param eStep source of the marginals and observations
     */
    public void addFromPosterior(EStep eStep) {
        int length = eStep.length();
        if (length == 0) {
            // Without this guard the final emission update would request position -1.
            return;
        }
        this.addInitial(eStep.oneNodePosterior(0));
        for (int t = 0; t < length - 1; ++t) {
            this.add(eStep.oneNodePosterior(t), eStep.observations().get(t), eStep.twoNodesPosterior(t));
        }
        // The last position has no successor, so it contributes an emission but no transition.
        this.add(eStep.oneNodePosterior(length - 1), eStep.observations().get(length - 1));
    }

    /** Adds {@code firstNodePost[s]} to the initial-state mass of every state {@code s}. */
    private void addInitial(double[] firstNodePost) {
        for (int s = 0; s < firstNodePost.length; ++s) {
            this.sumOfInitPost[s] += firstNodePost[s];
        }
    }

    /** Adds {@code oneNodePost[s]} to the emission mass of (state {@code s}, symbol {@code obs}). */
    private void add(double[] oneNodePost, int obs) {
        for (int s = 0; s < oneNodePost.length; ++s) {
            this.sumOfEmiPost[s][obs] += oneNodePost[s];
        }
    }

    /**
     * Adds one position's emission mass and its outgoing transition mass.
     * {@code twoNodePost} is read as a square matrix whose side is {@code oneNodePost.length}.
     */
    private void add(double[] oneNodePost, int obs, double[][] twoNodePost) {
        this.add(oneNodePost, obs);
        int n = oneNodePost.length;
        for (int s1 = 0; s1 < n; ++s1) {
            double[] target = this.sumOfTwoNodesPost[s1];
            double[] source = twoNodePost[s1];
            for (int s2 = 0; s2 < n; ++s2) {
                target[s2] += source[s2];
            }
        }
    }

    /** @return the number of hidden states */
    public int nState() {
        return this.sumOfTwoNodesPost.length;
    }

    /** @return the number of observation symbols given at construction */
    public int nObs() {
        return this.nObs;
    }

    /** Renders the three count tables using {@link MathUtils#toString}. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("Init. node posterior:\n").append(MathUtils.toString(this.sumOfInitPost));
        builder.append("Two nodes posterior:\n").append(MathUtils.toString(this.sumOfTwoNodesPost));
        builder.append("Emission posterior:\n").append(MathUtils.toString(this.sumOfEmiPost));
        return builder.toString();
    }

    /**
     * Distance used for approximate consistency checks.
     *
     * @param other object to compare with
     * @return 0 if {@code other} is this same instance.
     *     {@link Double#POSITIVE_INFINITY} if it is not a {@code SuffStat} with the same
     *     dimensions. Otherwise, the sum of {@link MathUtils#frobeniusInfty} over the transition,
     *     emission and initial tables.
     */
    @Override
    public double d(Object other) {
        if (other == this) {
            return 0.0;
        }
        if (!(other instanceof SuffStat)) {
            return Double.POSITIVE_INFINITY;
        }
        SuffStat cast = (SuffStat) other;
        if (this.nState() != cast.nState() || this.nObs() != cast.nObs()) {
            return Double.POSITIVE_INFINITY;
        }
        return MathUtils.frobeniusInfty(this.sumOfTwoNodesPost, cast.sumOfTwoNodesPost)
            + MathUtils.frobeniusInfty(this.sumOfEmiPost, cast.sumOfEmiPost)
            + MathUtils.frobeniusInfty(this.sumOfInitPost, cast.sumOfInitPost);
    }

    /** @return a copy (via {@link MathUtils#clone}) of the emission table */
    public double[][] getSumOfEmiPost() {
        return MathUtils.clone(this.sumOfEmiPost);
    }

    /** @return a copy of the initial-state table */
    public double[] getSumOfInitPost() {
        return this.sumOfInitPost.clone();
    }

    /** @return a copy (via {@link MathUtils#clone}) of the transition table */
    public double[][] getSumOfTwoNodesPost() {
        return MathUtils.clone(this.sumOfTwoNodesPost);
    }
}

