/*
 * Decompiled with CFR 0.152.
 */
package hmm;

import fig.basic.NumUtils;
import fig.basic.Pair;
import hmm.EStep;
import hmm.Param;
import hmm.ParamUtils;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.math.TrMtx;
import nuts.util.MathUtils;

/**
 * Scaled forward-backward computation (the E-step of Baum-Welch) over a chain of
 * discrete hidden states.
 *
 * <p>Two input modes are supported:
 * <ul>
 *   <li>{@link #compute(List, Param)}: integer observations, scored through the emission
 *       matrix of {@code param};</li>
 *   <li>{@link #compute(double[], double[][], double[][])}: an initial vector, a transition
 *       matrix and per-position unary potentials {@code [t][state]}, which replace the
 *       emission matrix entirely.</li>
 * </ul>
 *
 * <p>At every position the forward vector is normalized and its normalizer is stored in
 * {@code resc[t]}. The backward vector at {@code t} is divided by {@code resc[t + 1]}. With
 * this scaling, the single-node posterior is the pointwise product of the two tables, and
 * {@link #logll()} is the sum of the log-normalizers.
 *
 * <p>Tables from the most recent {@code compute} call are kept in mutable fields without
 * any synchronization. A single instance should therefore not be shared across threads.
 */
public class RescaledBaumWelch
implements EStep {
    /** Observation sequence; {@code null} when the last computation used unary potentials. */
    private List<Integer> observations;
    /** Number of positions in the current sequence. */
    private int length;
    /** Model parameters; synthesized by {@link #createParam} in unary-potential mode. */
    private Param param;
    /** Potentials indexed {@code [t][state]}; only read when {@link #useUnaryPot} is set. */
    private double[][] unaryPotentials;
    /** If true, emission scores come from {@link #unaryPotentials} instead of the emission matrix. */
    private boolean useUnaryPot;
    /** Scaled backward table {@code [t][state]}. */
    private double[][] back;
    /** Forward table {@code [t][state]}, each row normalized with {@code NumUtils.normalize}. */
    private double[][] forw;
    /** Per-position forward normalizers (sum of the unscaled forward row). */
    private double[] resc;

    /**
     * Runs forward-backward on an integer observation sequence.
     *
     * @param observations observed symbol indices, one per position
     * @param param model parameters supplying the initial, transition and emission scores
     */
    @Override
    public void compute(List<Integer> observations, Param param) {
        this.init(observations, param);
        this._compute();
    }

    /**
     * Runs forward-backward using explicit per-position unary potentials in place of
     * an emission matrix.
     *
     * @param init initial state weights, wrapped in a {@code TrMtx.PrVec}
     * @param trans transition weights, wrapped in a {@code TrMtx}
     * @param unaryPotentials potentials {@code [t][state]}; its length defines the sequence length
     */
    public void compute(double[] init, double[][] trans, double[][] unaryPotentials) {
        this.init(init, trans, unaryPotentials);
        this._compute();
    }

    /** Fills the forward table left-to-right, then the backward table right-to-left. */
    private void _compute() {
        for (int t = 0; t < this.length; ++t) {
            this.forw(t);
        }
        // The backward pass divides by resc[t + 1], so it must run after the forward pass.
        for (int t = this.length - 1; t >= 0; --t) {
            this.back(t);
        }
    }

    /** Prepares state for the observation-based mode and allocates fresh tables. */
    private void init(List<Integer> observations, Param param) {
        this.unaryPotentials = null;
        this.useUnaryPot = false;
        this.observations = observations;
        this.length = observations.size();
        this.param = param;
        this.initTables(this.length, param.nStates());
    }

    /**
     * Allocates the forward, backward and rescaling tables.
     *
     * @param len number of positions
     * @param nstates number of hidden states
     */
    private void initTables(int len, int nstates) {
        this.back = new double[len][nstates];
        this.forw = new double[len][nstates];
        this.resc = new double[len];
    }

    /** Prepares state for the unary-potential mode and allocates fresh tables. */
    private void init(double[] init, double[][] trans, double[][] unaryPotentials) {
        this.unaryPotentials = unaryPotentials;
        this.useUnaryPot = true;
        this.observations = null;
        this.length = unaryPotentials.length;
        this.param = this.createParam(init, trans);
        // Size the tables by the synthesized Param's state count (matches the observation path).
        this.initTables(this.length, this.param.nStates());
    }

    /**
     * Wraps raw initial/transition arrays into a {@link Param}.
     *
     * <p>The emission matrix is a single column of 1.0 per state. It is a placeholder only:
     * in unary-potential mode {@link #emissionPotential} never reads {@code emiMtx}.
     */
    private Param createParam(double[] initAr, double[][] transAr) {
        TrMtx tr = new TrMtx(transAr);
        double[][] emiAr = new double[transAr.length][1];
        for (int i = 0; i < transAr.length; ++i) {
            emiAr[i][0] = 1.0;
        }
        TrMtx emi = new TrMtx(emiAr);
        TrMtx.PrVec init = new TrMtx.PrVec(initAr);
        return new Param(init, tr, emi);
    }

    /**
     * Returns the sum of the logs of the per-position forward normalizers.
     *
     * <p>This is the log of the total (unnormalized) score of the sequence. When emissions are
     * probabilities, it is the log-likelihood of the observations.
     */
    @Override
    public double logll() {
        double sum = 0.0;
        for (int t = 0; t < this.length; ++t) {
            sum += Math.log(this.resc[t]);
        }
        return sum;
    }

    /** Returns {@link #oneNodePosterior(int)} for every position, indexed {@code [t][state]}. */
    public double[][] allOneNodeMoments() {
        double[][] result = new double[this.length][];
        for (int t = 0; t < this.length; ++t) {
            result[t] = this.oneNodePosterior(t);
        }
        return result;
    }

    /**
     * Returns the posterior state marginal at position {@code t}.
     *
     * <p>It is computed as the pointwise product of the scaled forward and backward rows.
     */
    @Override
    public final double[] oneNodePosterior(int t) {
        return MathUtils.pointwiseMultiply(this.back[t], this.forw[t]);
    }

    /**
     * Returns the pairwise posterior {@code [s1][s2]} for states at positions {@code t} and
     * {@code t + 1}.
     *
     * @param t left position; must satisfy {@code 0 <= t < length() - 1}
     * @throws ArrayIndexOutOfBoundsException if {@code t} has no successor position
     */
    @Override
    public double[][] twoNodesPosterior(int t) {
        if (t < 0 || t + 1 >= this.length) {
            throw new ArrayIndexOutOfBoundsException(
                "twoNodesPosterior requires 0 <= t < length - 1, got t=" + t + ", length=" + this.length);
        }
        int nStates = this.param.nStates();
        double[] forwT = this.forw[t];
        double[] backNext = this.back[t + 1];
        double rescNext = this.resc[t + 1];
        double[][] result = new double[nStates][nStates];
        for (int s1 = 0; s1 < nStates; ++s1) {
            for (int s2 = 0; s2 < nStates; ++s2) {
                result[s1][s2] = forwT[s1] * backNext[s2] * this.param.transMtx.p(s1, s2)
                    * this.emissionPotential(t + 1, s2) / rescNext;
            }
        }
        return result;
    }

    /**
     * Score of state {@code s} at position {@code t}.
     *
     * <p>In unary mode this is the unary potential; otherwise it is the emission-matrix
     * entry for the observed symbol.
     */
    private double emissionPotential(int t, int s) {
        return this.useUnaryPot ? this.unaryPotentials[t][s] : this.param.emiMtx.p(s, this.obs(t));
    }

    /** Computes, records the normalizer of, and normalizes the forward row at position {@code t}. */
    private void forw(int t) {
        int nStates = this.param.nStates();
        double[] row = this.forw[t];
        double norm = 0.0;
        for (int s = 0; s < nStates; ++s) {
            row[s] = this.unscaledForw(t, s);
            norm += row[s];
        }
        // NOTE(review): if norm == 0 (no state can explain position t), logll() becomes -inf
        // and later divisions by resc yield NaN; behavior of NumUtils.normalize on an
        // all-zero row is not visible here.
        this.resc[t] = norm;
        NumUtils.normalize(row);
    }

    /**
     * Forward score of state {@code s} at position {@code t} before normalization.
     *
     * <p>The transition/initial term is multiplied by the emission potential. Because the
     * row at {@code t - 1} is already normalized, the value stays within a safe range.
     */
    private double unscaledForw(int t, int s) {
        double trScore;
        if (t == 0) {
            trScore = this.param.initVec.p(s);
        } else {
            trScore = 0.0;
            double[] prev = this.forw[t - 1];
            for (int ps = 0; ps < this.param.nStates(); ++ps) {
                trScore += this.param.transMtx.p(ps, s) * prev[ps];
            }
        }
        return trScore * this.emissionPotential(t, s);
    }

    /** Fills the scaled backward row at position {@code t}; requires row {@code t + 1} and {@code resc}. */
    private void back(int t) {
        for (int s = 0; s < this.param.nStates(); ++s) {
            this.back[t][s] = this.back(t, s);
        }
    }

    /**
     * Scaled backward value for state {@code s} at position {@code t}.
     *
     * <p>The value is 1.0 at the last position. Elsewhere it is divided by the forward
     * normalizer of {@code t + 1}.
     */
    private double back(int t, int s) {
        if (t == this.length - 1) {
            return 1.0;
        }
        double sum = 0.0;
        double[] next = this.back[t + 1];
        for (int ns = 0; ns < this.param.nStates(); ++ns) {
            sum += this.param.transMtx.p(s, ns) * this.emissionPotential(t + 1, ns) * next[ns];
        }
        return sum / this.resc[t + 1];
    }

    /** Observed symbol at position {@code t} (observation mode only). */
    private int obs(int t) {
        return this.observations.get(t);
    }

    /** Number of positions in the most recently computed sequence. */
    @Override
    public int length() {
        return this.length;
    }

    /**
     * Returns a read-only view of the observations from the last observation-based computation.
     *
     * @throws NullPointerException if the last computation used unary potentials, because no
     *     observation list exists in that mode
     */
    @Override
    public List<Integer> observations() {
        return Collections.unmodifiableList(this.observations);
    }

    /**
     * Ad-hoc benchmark and sanity check.
     *
     * <p>It times ten runs on a randomly generated sequence. It then prints each position's
     * posterior, the generating state, and the posterior's sum.
     */
    public static void main(String[] args) {
        Random rand = new Random(2L);
        Param params = ParamUtils.randomUniParam(rand, 60, 40);
        int length = 100;
        Pair<List<Integer>, List<Integer>> pair = ParamUtils.generateStateObservations(rand, params, length);
        RescaledBaumWelch e = new RescaledBaumWelch();
        double sum = 0.0;
        double runs = 0.0;
        for (int i = 0; i < 10; ++i) {
            long t1 = System.currentTimeMillis();
            e.compute(pair.getSecond(), params);
            long t2 = System.currentTimeMillis();
            sum += (double)(t2 - t1);
            runs += 1.0;
        }
        System.out.println("time: " + sum / runs);
        for (int t = 0; t < length; ++t) {
            double[] posterior = e.oneNodePosterior(t);
            System.out.println("" + t + ":" + Arrays.toString(posterior) + " actual:" + pair.getFirst().get(t) + ", sum: " + MathUtils.sum(posterior));
        }
    }
}

