/*
 * Decompiled with CFR 0.152.
 */
package unsupalg;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.math.TrMtx;
import unsupalg.hmm.MStep;
import unsupalg.hmm.Param;
import unsupalg.hmm.ParamUtils;
import unsupalg.hmm.RescaledBaumWelch;
import unsupalg.hmm.SuffStat;

/**
 * Exploratory driver: fits a small HMM to synthetic data, then prints the data's summed
 * {@code RescaledBaumWelch.logll()} over grids of emission and transition parameters around
 * the fitted solution (one "x1 x2 y" line per grid point, suitable for plotting).
 */
public class PlotHMMLineParam {
    /**
     * Generates observations from a random "true" parameter set, then runs 20 re-estimation
     * rounds (posterior statistics from {@link RescaledBaumWelch}, re-estimation via
     * {@link MStep}) starting from a random initial parameter set. It then prints the summed
     * {@code logll()} over grids that vary the first two entries of rows 0 and 1 of the fitted
     * emission matrix, and then of the fitted transition matrix.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Fixed seed so that the generated data and the initialisation are reproducible.
        Random rand = new Random(2L);
        int nState = 2;
        int nObs = 2;
        Param truth = ParamUtils.randomUniEmiDetTrParam(rand, nState, nObs, 0.0);
        System.out.println("TRUE PARAMS:\n" + truth.toString());
        List<List<Integer>> observations = PlotHMMLineParam.observations(truth, 500, 500, rand);
        Param currentParam = ParamUtils.randomUniParam(rand, nState, nObs);
        System.out.println("INIT PARAMS:\n" + currentParam.toString());
        for (int iter = 0; iter < 20; ++iter) {
            // Accumulate posterior statistics over every sequence, then re-estimate the parameters.
            SuffStat suffStat = new SuffStat(nState, nObs);
            double ll = 0.0;
            for (List<Integer> sequence : observations) {
                RescaledBaumWelch e = new RescaledBaumWelch();
                e.compute(sequence, currentParam);
                suffStat.addFromPosterior(e);
                ll += e.logll();
            }
            System.out.println("LL: " + ll);
            MStep m = new MStep();
            currentParam = m.compute(suffStat);
        }
        System.out.println("PARAMS FOUND:\n" + currentParam.toString());
        for (Param x : PlotHMMLineParam.multiVary(currentParam, 20)) {
            double y = logLikelihood(observations, x);
            System.out.println(x.emiMtx.p(0, 0) + " " + x.emiMtx.p(1, 0) + " " + y);
        }
        for (Param x : PlotHMMLineParam.multiVaryTrans(currentParam, 20)) {
            double y = logLikelihood(observations, x);
            System.out.println(x.transMtx.p(0, 0) + " " + x.transMtx.p(1, 0) + " " + y);
        }
    }

    /**
     * Sums {@code RescaledBaumWelch.logll()} over all sequences, running a fresh
     * {@link RescaledBaumWelch} on each sequence under {@code param}.
     *
     * @param observations sequences to score
     * @param param        parameters to score them under
     * @return the sum of the per-sequence {@code logll()} values (0.0 for no sequences)
     */
    private static double logLikelihood(List<List<Integer>> observations, Param param) {
        double sum = 0.0;
        for (List<Integer> sequence : observations) {
            RescaledBaumWelch e = new RescaledBaumWelch();
            e.compute(sequence, param);
            sum += e.logll();
        }
        return sum;
    }

    /**
     * Draws {@code n} sequences, each from {@code ParamUtils.generateObservations(rand, param, l)}.
     *
     * @param param parameters to sample from
     * @param n     number of sequences
     * @param l     passed through to {@code generateObservations} (presumably the sequence length)
     * @param rand  random source
     * @return a new mutable list holding the {@code n} generated sequences
     */
    public static List<List<Integer>> observations(Param param, int n, int l, Random rand) {
        List<List<Integer>> result = new ArrayList<List<Integer>>();
        for (int i = 0; i < n; ++i) {
            result.add(ParamUtils.generateObservations(rand, param, l));
        }
        return result;
    }

    /**
     * Builds {@code numberOfIncr + 1} copies of {@code base}. In each copy, emission entries
     * (0,0) and (0,1) are replaced by {@code v} and {@code total - v}, where
     * {@code total = p(0,0) + p(0,1)} of the base emission matrix. The value {@code v} sweeps
     * evenly from 0 to {@code total} inclusive, so the sum of the two entries is preserved.
     *
     * @param base         parameters to vary; not modified
     * @param numberOfIncr number of steps; must be at least 1
     * @return the varied parameter sets, in increasing order of {@code v}
     * @throws IllegalArgumentException if {@code numberOfIncr < 1}
     */
    public static List<Param> varyEmi(Param base, int numberOfIncr) {
        double total = base.emiMtx.p(0, 0) + base.emiMtx.p(0, 1);
        List<Param> result = new ArrayList<Param>();
        for (double value : splitPoints(total, numberOfIncr)) {
            double[][] baseAr = base.emiMtx.arrayCopy();
            baseAr[0][0] = value;
            baseAr[0][1] = total - value;
            result.add(new Param(base.initVec, base.transMtx, new TrMtx(baseAr)));
        }
        return result;
    }

    /**
     * Builds a {@code (numberOfIncr + 1)^2} grid of copies of {@code base}, varying emission
     * rows 0 and 1 independently. In row r, entries (r,0) and (r,1) become {@code v} and
     * {@code total_r - v}, with {@code v} sweeping evenly from 0 to
     * {@code total_r = p(r,0) + p(r,1)} inclusive. Row 0 is the outer loop.
     *
     * @param base         parameters to vary; not modified
     * @param numberOfIncr number of steps per row; must be at least 1
     * @return the varied parameter sets
     * @throws IllegalArgumentException if {@code numberOfIncr < 1}
     */
    public static List<Param> multiVary(Param base, int numberOfIncr) {
        double total1 = base.emiMtx.p(0, 0) + base.emiMtx.p(0, 1);
        double total2 = base.emiMtx.p(1, 0) + base.emiMtx.p(1, 1);
        double[] values1 = splitPoints(total1, numberOfIncr);
        double[] values2 = splitPoints(total2, numberOfIncr);
        List<Param> result = new ArrayList<Param>();
        for (double value1 : values1) {
            for (double value2 : values2) {
                double[][] baseAr = base.emiMtx.arrayCopy();
                baseAr[0][0] = value1;
                baseAr[0][1] = total1 - value1;
                baseAr[1][0] = value2;
                baseAr[1][1] = total2 - value2;
                result.add(new Param(base.initVec, base.transMtx, new TrMtx(baseAr)));
            }
        }
        return result;
    }

    /**
     * Same grid as {@link #multiVary(Param, int)}, but over the transition matrix rather than
     * the emission matrix.
     *
     * @param base         parameters to vary; not modified
     * @param numberOfIncr number of steps per row; must be at least 1
     * @return the varied parameter sets
     * @throws IllegalArgumentException if {@code numberOfIncr < 1}
     */
    public static List<Param> multiVaryTrans(Param base, int numberOfIncr) {
        double total1 = base.transMtx.p(0, 0) + base.transMtx.p(0, 1);
        double total2 = base.transMtx.p(1, 0) + base.transMtx.p(1, 1);
        double[] values1 = splitPoints(total1, numberOfIncr);
        double[] values2 = splitPoints(total2, numberOfIncr);
        List<Param> result = new ArrayList<Param>();
        for (double value1 : values1) {
            for (double value2 : values2) {
                double[][] baseAr = base.transMtx.arrayCopy();
                baseAr[0][0] = value1;
                baseAr[0][1] = total1 - value1;
                baseAr[1][0] = value2;
                baseAr[1][1] = total2 - value2;
                result.add(new Param(base.initVec, new TrMtx(baseAr), base.emiMtx));
            }
        }
        return result;
    }

    /**
     * Returns {@code numberOfIncr + 1} evenly spaced values from 0 to {@code total} inclusive.
     * Each value is computed from an integer index instead of by repeatedly adding an increment.
     * This keeps the point count independent of rounding, makes the endpoints exactly 0 and
     * {@code total}, and keeps every value at most {@code total}, so {@code total - value} is
     * never negative for a non-negative {@code total}.
     *
     * @throws IllegalArgumentException if {@code numberOfIncr < 1}
     */
    private static double[] splitPoints(double total, int numberOfIncr) {
        if (numberOfIncr < 1) {
            throw new IllegalArgumentException("numberOfIncr must be >= 1, got " + numberOfIncr);
        }
        double[] points = new double[numberOfIncr + 1];
        for (int k = 0; k <= numberOfIncr; ++k) {
            // (double) k / numberOfIncr is exactly 0.0 at k == 0 and exactly 1.0 at k == numberOfIncr.
            points[k] = total * ((double) k / numberOfIncr);
        }
        return points;
    }
}

