/*
 * Decompiled with CFR 0.152.
 */
package unsupalg;

import functional.Fct;
import nuts.math.TrMtx;
import poly.Complex;

/**
 * Function object that maps a state-transition matrix to a {@link Complex}
 * value built from backward-style recursions over several observation chains.
 *
 * <p>For every chain {@code j} two tables are filled from the last position
 * towards the first:
 * <ul>
 *   <li>{@code M[j][i][y]} starts at one on the last position and is propagated
 *       backwards through the state and observation matrices;</li>
 *   <li>{@code N[j][i][y]} starts at zero on the last position and follows the
 *       recursion obtained by differentiating the {@code M} recursion with
 *       respect to the single entry {@code stTr.p(0, 0)}.</li>
 * </ul>
 * In each row the entry for the last state is not read from the matrix; it is
 * taken as one minus the sum of the other entries of that row.
 *
 * <p>{@code psi[j]} and {@code psiPrime[j]} collapse position 0 of {@code M}
 * and {@code N} with a fixed weight of 0.5 per state, and {@link #applyTo}
 * returns {@code (prod_j psi[j]) * (sum_j psiPrime[j] / psi[j])}, i.e. the
 * product-rule derivative of {@code prod_j psi[j]}.
 *
 * <p>NOTE(review): the 0.5 weight looks like a uniform initial distribution
 * over two hidden states; confirm before using matrices with another number
 * of states.
 *
 * <p>Instances keep per-call scratch state in fields, so a single instance
 * must not be used by several threads at once.
 */
public class HMMUnivarPoly
implements Fct<TrMtx, Complex> {
    /** Observation matrix; read as {@code obsTr.p(state, symbol)}. */
    private final TrMtx obsTr;
    /** State-transition matrix of the current {@link #applyTo} call. */
    private TrMtx stTr;
    /** Observed symbols, indexed {@code [chain][position]}. */
    private final int[][] observations;
    /** Number of observation chains. */
    private final int nChains;
    /** Common length of every chain. */
    private final int cLength;
    /** Backward table, indexed {@code [chain][position][state]}. */
    private Complex[][][] M;
    /** Derivative of {@link #M} w.r.t. {@code stTr.p(0, 0)}, same indexing. */
    private Complex[][][] N;
    /** Per-chain value derived from position 0 of {@link #M}. */
    private Complex[] psi;
    /** Per-chain value derived from position 0 of {@link #N}. */
    private Complex[] psiPrime;

    /**
     * Stores the observations and the observation matrix.
     *
     * @param observations observed symbols, one row per chain; all rows must
     *                     have the same length
     * @param obsTr        observation matrix
     * @throws RuntimeException if the chains do not all have the same length
     */
    public HMMUnivarPoly(int[][] observations, TrMtx obsTr) {
        this.observations = observations;
        this.obsTr = obsTr;
        this.nChains = observations.length;
        this.cLength = observations[0].length;
        for (int i = 0; i < observations.length; ++i) {
            if (observations[i].length == this.cLength) continue;
            throw new RuntimeException("Assume all chains have same length");
        }
    }

    /**
     * Evaluates the derivative of the product of the per-chain values
     * {@code psi[j]} for the given state-transition matrix.
     *
     * @param stTr state-transition matrix; its number of source states fixes
     *             the number of hidden states
     * @return {@code (prod_j psi[j]) * (sum_j psiPrime[j] / psi[j])}
     */
    @Override
    public Complex applyTo(TrMtx stTr) {
        this.stTr = stTr;
        final int nStates = this.nSt();
        final int last = this.cLength - 1;
        this.M = new Complex[this.nChains][this.cLength][nStates];
        this.N = new Complex[this.nChains][this.cLength][nStates];
        // The per-chain accumulators must exist before computePsis() writes to them.
        this.psi = new Complex[this.nChains];
        this.psiPrime = new Complex[this.nChains];
        for (int j = 0; j < this.nChains; ++j) {
            for (int y = 0; y < nStates; ++y) {
                // Terminal condition, set for every state y: M = 1 and dM/dx = 0.
                this.M[j][last][y] = Complex.plusOne();
                this.N[j][last][y] = Complex.zero();
            }
        }
        this.computeM();
        this.computeN();
        this.computePsis();
        return this.deriv();
    }

    /**
     * Combines the per-chain values by the product rule.
     *
     * <p>The division by {@code psi[j]} is left to {@code Complex.over}; a
     * chain whose {@code psi} is zero is not special-cased here.
     *
     * @return {@code (prod_j psi[j]) * (sum_j psiPrime[j] / psi[j])}
     */
    private Complex deriv() {
        Complex sum = Complex.zero();
        Complex prod = Complex.plusOne();
        for (int j = 0; j < this.nChains; ++j) {
            sum = Complex.plus(sum, Complex.over(this.psiPrime[j], this.psi[j]));
            prod = Complex.times(prod, this.psi[j]);
        }
        return Complex.times(sum, prod);
    }

    /**
     * Fills {@link #psi} and {@link #psiPrime} from position 0 of {@link #M}
     * and {@link #N}: each state contributes
     * {@code 0.5 * obsTr.p(y, firstSymbol)} times its table entry.
     */
    private void computePsis() {
        final int nStates = this.nSt();
        for (int j = 0; j < this.nChains; ++j) {
            final int firstSymbol = this.observations[j][0];
            Complex value = Complex.zero();
            Complex derivative = Complex.zero();
            for (int y = 0; y < nStates; ++y) {
                final double weight = 0.5 * this.obsTr.p(y, firstSymbol);
                value = Complex.plus(value, Complex.times(weight, this.M[j][0][y]));
                // Accumulate the derivative on its own running sum, not on the value.
                derivative = Complex.plus(derivative, Complex.times(weight, this.N[j][0][y]));
            }
            this.psi[j] = value;
            this.psiPrime[j] = derivative;
        }
    }

    /**
     * Fills {@link #M} for every chain, moving from the last position towards
     * the first. Each step reads position {@code i} and writes position
     * {@code i - 1}, so {@code i} runs from {@code cLength - 1} down to 1.
     */
    private void computeM() {
        final int nStates = this.nSt();
        for (int j = 0; j < this.nChains; ++j) {
            for (int i = this.cLength - 1; i >= 1; --i) {
                for (int y = 0; y < nStates; ++y) {
                    this.computeM(j, i, y);
                }
            }
        }
    }

    /**
     * Computes {@code M[j][i - 1][y]} from the already filled slice
     * {@code M[j][i]}.
     *
     * @param j chain index
     * @param i position whose slice is read; must be at least 1
     * @param y state whose entry at position {@code i - 1} is written
     */
    private void computeM(int j, int i, int y) {
        final int lastState = this.nSt() - 1;
        final int symbol = this.observations[j][i];
        double innerSum = 0.0;
        for (int yp = 0; yp < lastState; ++yp) {
            innerSum += this.stTr.p(y, yp);
        }
        // Weight of the last state is implied: one minus the rest of the row.
        Complex constraintTerm = Complex.times((1.0 - innerSum) * this.obsTr.p(lastState, symbol), this.M[j][i][lastState]);
        Complex sum = Complex.zero();
        for (int yn = 0; yn < lastState; ++yn) {
            sum = Complex.plus(sum, Complex.times(this.stTr.p(y, yn) * this.obsTr.p(yn, symbol), this.M[j][i][yn]));
        }
        this.M[j][i - 1][y] = Complex.plus(constraintTerm, sum);
    }

    /**
     * Fills {@link #N} for every chain, moving from the last position towards
     * the first. State 0 has its own recursion because its row contains the
     * differentiation variable {@code stTr.p(0, 0)}. Requires {@link #M} to be
     * filled already.
     */
    private void computeN() {
        final int nStates = this.nSt();
        for (int j = 0; j < this.nChains; ++j) {
            for (int i = this.cLength - 1; i >= 1; --i) {
                this.computeNYZero(j, i);
                for (int y = 1; y < nStates; ++y) {
                    this.computeNYNotZero(j, i, y);
                }
            }
        }
    }

    /**
     * Computes {@code N[j][i - 1][0]}, the derivative of {@code M[j][i - 1][0]}
     * with respect to {@code x = stTr.p(0, 0)}.
     *
     * @param j chain index
     * @param i position whose slices of {@code M} and {@code N} are read; must
     *          be at least 1
     */
    private void computeNYZero(int j, int i) {
        final int lastState = this.nSt() - 1;
        final int symbol = this.observations[j][i];
        final double obsLast = this.obsTr.p(lastState, symbol);
        final double obsZero = this.obsTr.p(0, symbol);
        final double x = this.stTr.p(0, 0);
        // Row sum without x; the x part of the implied last weight is handled below.
        double innerSum = 0.0;
        for (int yp = 1; yp < lastState; ++yp) {
            innerSum += this.stTr.p(0, yp);
        }
        Complex constraintTerm = Complex.times((1.0 - innerSum) * obsLast, this.N[j][i][lastState]);
        // Differentiating the implied weight (1 - x - ...) yields -M and -x * N.
        Complex chainConstraintTerm = Complex.plus(Complex.times(obsLast, this.M[j][i][lastState]), Complex.times(obsLast, Complex.times(x, this.N[j][i][lastState])));
        constraintTerm = Complex.minus(constraintTerm, chainConstraintTerm);
        Complex sum = Complex.zero();
        for (int yn = 1; yn < lastState; ++yn) {
            sum = Complex.plus(sum, Complex.times(this.stTr.p(0, yn) * this.obsTr.p(yn, symbol), this.N[j][i][yn]));
        }
        // Product rule on the x * M[0] term: M[0] + x * N[0].
        Complex chainMain = Complex.plus(Complex.times(obsZero, this.M[j][i][0]), Complex.times(obsZero, Complex.times(x, this.N[j][i][0])));
        sum = Complex.plus(sum, chainMain);
        this.N[j][i - 1][0] = Complex.plus(constraintTerm, sum);
    }

    /**
     * Computes {@code N[j][i - 1][y]} for a state {@code y != 0}. The row of
     * such a state does not contain {@code stTr.p(0, 0)}, so the recursion has
     * the same shape as the one for {@code M}, applied to {@code N}.
     *
     * @param j chain index
     * @param i position whose slice of {@code N} is read; must be at least 1
     * @param y state whose entry at position {@code i - 1} is written
     */
    private void computeNYNotZero(int j, int i, int y) {
        final int lastState = this.nSt() - 1;
        final int symbol = this.observations[j][i];
        double innerSum = 0.0;
        for (int yp = 0; yp < lastState; ++yp) {
            innerSum += this.stTr.p(y, yp);
        }
        Complex constraintTerm = Complex.times((1.0 - innerSum) * this.obsTr.p(lastState, symbol), this.N[j][i][lastState]);
        Complex sum = Complex.zero();
        for (int yn = 0; yn < lastState; ++yn) {
            sum = Complex.plus(sum, Complex.times(this.stTr.p(y, yn) * this.obsTr.p(yn, symbol), this.N[j][i][yn]));
        }
        this.N[j][i - 1][y] = Complex.plus(constraintTerm, sum);
    }

    /** @return number of hidden states, taken from the current transition matrix */
    private int nSt() {
        return this.stTr.nSrcStates();
    }

    /** @return number of destination states of the observation matrix */
    private int nObs() {
        return this.obsTr.nDestStates();
    }
}

