/*
 * Decompiled with CFR 0.152.
 */
package unsupalg.hmm;

import fig.basic.NumUtils;
import fig.basic.Pair;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.util.MathUtils;
import unsupalg.hmm.EStep;
import unsupalg.hmm.Param;
import unsupalg.hmm.ParamUtils;

/**
 * E-step for a discrete hidden Markov model using the rescaled
 * forward-backward (Baum-Welch) recursions.
 *
 * <p>After {@link #compute(List, Param)} returns, the arrays hold the following:
 * <ul>
 *   <li>{@code forw[t]} is the forward vector at position {@code t}, normalized
 *       with {@code NumUtils.normalize}.</li>
 *   <li>{@code resc[t]} is the sum of the unnormalized forward vector at {@code t},
 *       i.e. the normalizer that was removed.</li>
 *   <li>{@code back[t]} is the backward vector. The last position is all ones, and
 *       every earlier position is divided by {@code resc[t + 1]}, so it carries the
 *       scale factors complementary to {@code forw[t]}.</li>
 * </ul>
 * With this scaling, the one- and two-node posteriors are plain products of the
 * stored quantities. The log-likelihood is the sum of the log normalizers.
 *
 * <p>Per-sequence state lives in mutable fields and is reallocated on every
 * {@code compute} call. The class performs no synchronization.
 */
public class RescaledBaumWelch
implements EStep {
    /** Observation sequence passed to the last {@code compute} call. */
    private List<Integer> observations;
    /** Number of observations in {@link #observations}. */
    private int length;
    /** Model parameters passed to the last {@code compute} call. */
    private Param param;
    /** Scaled backward vectors, indexed {@code [position][state]}. */
    private double[][] back;
    /** Normalized forward vectors, indexed {@code [position][state]}. */
    private double[][] forw;
    /** Forward normalizer removed at each position. */
    private double[] resc;

    /**
     * Runs the forward pass and then the backward pass over {@code observations}.
     *
     * @param observations observation symbols, used as column indices into {@code param.emiMtx}
     * @param param HMM parameters (initial, transition and emission distributions)
     */
    @Override
    public void compute(List<Integer> observations, Param param) {
        init(observations, param);
        // The forward pass must complete first: the backward pass divides by resc[t + 1].
        for (int t = 0; t < length; ++t) {
            forwardStep(t);
        }
        for (int t = length - 1; t >= 0; --t) {
            backwardStep(t);
        }
    }

    /** Stores the inputs and allocates fresh, zeroed work arrays sized to them. */
    private void init(List<Integer> observations, Param param) {
        this.observations = observations;
        this.length = observations.size();
        this.param = param;
        this.back = new double[this.length][param.nStates()];
        this.forw = new double[this.length][param.nStates()];
        this.resc = new double[this.length];
    }

    /**
     * Returns the log-likelihood of the sequence from the last {@code compute} call.
     *
     * <p>The forward normalizers multiply to the sequence probability. Summing
     * their logarithms avoids the underflow a direct product would hit on long
     * sequences. The result is {@code 0.0} for an empty sequence, and
     * {@code -Infinity} if any normalizer is zero.
     */
    @Override
    public double logll() {
        double sum = 0.0;
        for (int t = 0; t < length; ++t) {
            sum += Math.log(resc[t]);
        }
        return sum;
    }

    /**
     * Returns the posterior state distribution at position {@code t}.
     *
     * <p>{@code forw[t]} and {@code back[t]} carry complementary scale factors, so
     * their element-wise product is already the posterior. NOTE(review):
     * {@code MathUtils.pointwiseMultiply} is assumed to return a new array rather
     * than modify an argument — confirm.
     */
    @Override
    public double[] oneNodePosterior(int t) {
        return MathUtils.pointwiseMultiply(back[t], forw[t]);
    }

    /**
     * Returns the joint posterior of the states at positions {@code t} and {@code t + 1},
     * indexed {@code [stateAtT][stateAtTPlus1]}.
     *
     * <p>Requires {@code 0 <= t} and {@code t + 1 < length()}, because it reads
     * {@code back[t + 1]}.
     */
    @Override
    public double[][] twoNodesPosterior(int t) {
        int n = param.nStates();
        // Index the arrays before any other work, so an out-of-range t fails the same way as before.
        double[] alpha = forw[t];
        double[] betaNext = back[t + 1];
        double scale = resc[t + 1];
        double[] emitNext = emissionColumn(t + 1);
        double[][] result = new double[n][n];
        for (int s1 = 0; s1 < n; ++s1) {
            for (int s2 = 0; s2 < n; ++s2) {
                // Operand order matches the original expression, so the floating-point results are identical.
                result[s1][s2] = alpha[s1] * betaNext[s2] * param.transMtx.p(s1, s2) * emitNext[s2] / scale;
            }
        }
        return result;
    }

    /**
     * Computes the unnormalized forward vector at {@code t}, stores its sum in
     * {@code resc[t]}, and then normalizes {@code forw[t]} in place.
     *
     * <p>If the sum is zero, {@code resc[t]} becomes 0. Later divisions by it then
     * yield NaN. What {@code NumUtils.normalize} leaves in {@code forw[t]} in that
     * case is not visible here — confirm.
     */
    private void forwardStep(int t) {
        int n = param.nStates();
        double[] alpha = forw[t];
        double norm = 0.0;
        for (int s = 0; s < n; ++s) {
            alpha[s] = unscaledForw(t, s);
            norm += alpha[s];
        }
        resc[t] = norm;
        NumUtils.normalize(alpha);
    }

    /**
     * Returns the forward score for state {@code s} at {@code t}, before normalization.
     *
     * <p>At {@code t == 0} it is the initial probability times the emission
     * probability. Otherwise it is the transition-weighted sum of the normalized
     * {@code forw[t - 1]}, times the emission probability.
     */
    private double unscaledForw(int t, int s) {
        double trScore;
        if (t == 0) {
            trScore = param.initVec.p(s);
        } else {
            trScore = 0.0;
            double[] prev = forw[t - 1];
            int n = param.nStates();
            for (int ps = 0; ps < n; ++ps) {
                trScore += param.transMtx.p(ps, s) * prev[ps];
            }
        }
        return trScore * param.emiMtx.p(s, obs(t));
    }

    /**
     * Fills {@code back[t]}.
     *
     * <p>The last position is all ones. Every earlier position is the
     * transition- and emission-weighted sum of {@code back[t + 1]}, divided by
     * {@code resc[t + 1]}.
     */
    private void backwardStep(int t) {
        int n = param.nStates();
        double[] beta = back[t];
        if (t == length - 1) {
            for (int s = 0; s < n; ++s) {
                beta[s] = 1.0;
            }
            return;
        }
        // The emission probabilities for obs(t + 1) do not depend on s: look them up once, not n times per state.
        double[] emitNext = emissionColumn(t + 1);
        double[] betaNext = back[t + 1];
        double scale = resc[t + 1];
        for (int s = 0; s < n; ++s) {
            double sum = 0.0;
            for (int ns = 0; ns < n; ++ns) {
                // Same operand order as before: (trans * emission) * nextBeta.
                sum += param.transMtx.p(s, ns) * emitNext[ns] * betaNext[ns];
            }
            beta[s] = sum / scale;
        }
    }

    /** Returns {@code emiMtx.p(s, obs(t))} for every state {@code s}. */
    private double[] emissionColumn(int t) {
        int n = param.nStates();
        int symbol = obs(t);
        double[] column = new double[n];
        for (int s = 0; s < n; ++s) {
            column[s] = param.emiMtx.p(s, symbol);
        }
        return column;
    }

    /** Returns the observation symbol at position {@code t}. */
    private int obs(int t) {
        return observations.get(t);
    }

    /** Returns the length of the sequence from the last {@code compute} call. */
    @Override
    public int length() {
        return length;
    }

    /** Returns a read-only view of the sequence from the last {@code compute} call. */
    @Override
    public List<Integer> observations() {
        return Collections.unmodifiableList(observations);
    }

    /**
     * Micro-benchmark. Runs {@code compute} ten times on a random 20,000-step
     * sequence and prints the mean wall-clock time in milliseconds.
     */
    public static void main(String[] args) {
        Random rand = new Random(2L);
        Param params = ParamUtils.randomUniParam(rand, 60, 40);
        int length = 20000;
        Pair<List<Integer>, List<Integer>> pair = ParamUtils.generateStateObservations(rand, params, length);
        RescaledBaumWelch e = new RescaledBaumWelch();
        final int runs = 10;
        double totalMillis = 0.0;
        for (int i = 0; i < runs; ++i) {
            long t1 = System.currentTimeMillis();
            e.compute(pair.getSecond(), params);
            long t2 = System.currentTimeMillis();
            totalMillis += (double) (t2 - t1);
        }
        // double / int performs floating-point division, matching the original sum / N.
        System.out.println("time: " + totalMillis / runs);
    }
}

