/*
 * Decompiled with CFR 0.152.
 */
package nuts.math;

import fig.basic.NumUtils;
import fig.prob.SampleUtils;
import java.util.Random;
import nuts.util.MathUtils;

/**
 * A transition matrix whose entry {@code (src, dest)} is the probability of moving from source
 * state {@code src} to destination state {@code dest}.
 *
 * <p>The constructor takes a row-by-row copy of its argument, so later mutation of the caller's
 * array does not affect the matrix (demonstrated by {@link #main}). When {@link #doChecks} is
 * true, the constructor also requires the matrix to be rectangular and every row to pass
 * {@code MathUtils.isProb}.
 */
public class TrMtx {
    /**
     * When true, constructors verify that the matrix is rectangular and that each row passes
     * {@code MathUtils.isProb}.
     */
    public static boolean doChecks = true;

    private final double[][] trans;
    private final int nSrcStates;
    private final int nDestStates;

    /**
     * Creates a transition matrix from a copy of {@code trans}.
     *
     * @param trans row {@code s} holds the probabilities of moving from source state {@code s} to
     *     each destination state; every row is cloned
     * @throws IllegalArgumentException if {@code trans} is null, empty, or contains a null row; or,
     *     when {@link #doChecks} is set, if the matrix is not rectangular or a row is not a
     *     probability vector
     */
    public TrMtx(double[][] trans) {
        if (trans == null || trans.length == 0) {
            throw new IllegalArgumentException("Trans mtx must have at least one row");
        }
        this.nSrcStates = trans.length;
        // Only the row array is allocated: each row is replaced by a clone below, so
        // pre-allocating nSrcStates x nDestStates doubles would be thrown away immediately.
        this.trans = new double[nSrcStates][];
        for (int s = 0; s < nSrcStates; ++s) {
            if (trans[s] == null) {
                throw new IllegalArgumentException("Trans mtx row " + s + " is null");
            }
            this.trans[s] = trans[s].clone();
        }
        this.nDestStates = this.trans[0].length;
        if (doChecks && !isValid()) {
            // Format via MathUtils directly instead of calling the overridable toString()
            // from inside a constructor.
            throw new IllegalArgumentException(
                    "Invalid trans mtx:\n" + MathUtils.toString(this.trans));
        }
    }

    /**
     * Returns the stored probability of moving from {@code src} to {@code dest}.
     *
     * @param src source state index
     * @param dest destination state index
     * @return entry {@code (src, dest)} of the matrix
     */
    public double p(int src, int dest) {
        return trans[src][dest];
    }

    /**
     * Samples a destination state from the row of {@code currentState} using
     * {@code SampleUtils.sampleMultinomial}.
     *
     * @param currentState source state whose row is sampled
     * @param rand randomness source
     * @return sampled destination state index
     */
    public int nextState(int currentState, Random rand) {
        return SampleUtils.sampleMultinomial(rand, trans[currentState]);
    }

    /**
     * Returns a copy of the underlying array produced by {@code MathUtils.clone}.
     * NOTE(review): assumed to be a deep (per-row) copy; confirm against MathUtils.
     *
     * @return copy of the matrix entries
     */
    public double[][] arrayCopy() {
        return MathUtils.clone(trans);
    }

    /** @return number of source states (rows) */
    public int nSrcStates() {
        return nSrcStates;
    }

    /** @return number of destination states (columns of the first row) */
    public int nDestStates() {
        return nDestStates;
    }

    /**
     * Checks the copied matrix row by row.
     *
     * @return true if every row passes {@code MathUtils.isProb}; false at the first row that does not
     * @throws IllegalArgumentException if a row's length differs from {@link #nDestStates}
     */
    private boolean isValid() {
        for (double[] row : trans) {
            if (row.length != nDestStates) {
                throw new IllegalArgumentException("Trans mtx should be rectangular");
            }
            if (!MathUtils.isProb(row)) {
                return false;
            }
        }
        return true;
    }

    @Override
    public String toString() {
        return MathUtils.toString(trans);
    }

    /**
     * Builds a matrix whose entries are drawn from {@code rand.nextDouble()}. Each row is then
     * passed to {@code NumUtils.normalize} (presumably rescaling it to sum to 1; confirm).
     *
     * @param rand randomness source
     * @param nSrc number of rows
     * @param nDest number of columns
     * @return the random transition matrix
     */
    public static TrMtx uniRandTrMtx(Random rand, int nSrc, int nDest) {
        double[][] prs = new double[nSrc][nDest];
        for (int s = 0; s < nSrc; ++s) {
            for (int d = 0; d < nDest; ++d) {
                prs[s][d] = rand.nextDouble();
            }
            NumUtils.normalize(prs[s]);
        }
        return new TrMtx(prs);
    }

    /**
     * Builds a matrix in which every entry equals {@code 1.0 / nDest}.
     *
     * @param nSrc number of rows
     * @param nDest number of columns
     * @return the uniform transition matrix
     */
    public static TrMtx uniTrMtx(int nSrc, int nDest) {
        double[][] prs = new double[nSrc][nDest];
        double pr = 1.0 / (double) nDest;
        for (int s = 0; s < nSrc; ++s) {
            for (int d = 0; d < nDest; ++d) {
                prs[s][d] = pr;
            }
        }
        return new TrMtx(prs);
    }

    /**
     * Builds a square matrix in which state {@code s} moves to {@code (s + 1) % nStates} with
     * probability {@code 1 - epsilon}. The remaining {@code epsilon} is split evenly over the
     * other {@code nStates - 1} states. With a single state, the only valid row is {@code [1.0]},
     * which is returned regardless of {@code epsilon}.
     *
     * @param nStates number of states
     * @param epsilon total probability mass leaked off the cycle
     * @return the cyclic transition matrix
     */
    public static TrMtx cyclicTrMtx(int nStates, double epsilon) {
        double[][] prs = new double[nStates][nStates];
        if (nStates == 1) {
            // There is no other state to leak epsilon to; [[1 - epsilon]] would fail validation.
            prs[0][0] = 1.0;
            return new TrMtx(prs);
        }
        double offCycle = epsilon / ((double) nStates - 1.0);
        for (int s = 0; s < nStates; ++s) {
            int next = (s + 1) % nStates;
            for (int d = 0; d < nStates; ++d) {
                prs[s][d] = d == next ? 1.0 - epsilon : offCycle;
            }
        }
        return new TrMtx(prs);
    }

    /**
     * Builds a single-row {@link PrVec} from {@link #uniRandTrMtx}.
     *
     * @param rand randomness source
     * @param nDest number of entries
     * @return the random probability vector
     */
    public static PrVec uniRandPrVec(Random rand, int nDest) {
        return new PrVec(uniRandTrMtx(rand, 1, nDest).trans);
    }

    /**
     * Builds a single-row {@link PrVec} from {@link #uniTrMtx}.
     *
     * @param nDest number of entries
     * @return the uniform probability vector
     */
    public static PrVec uniPrVec(int nDest) {
        return new PrVec(uniTrMtx(1, nDest).trans);
    }

    /**
     * Demonstrates that mutating the source array after construction does not change the
     * matrix: both printed matrices are identical.
     */
    public static void main(String[] args) {
        double[][] mtx = new double[][] {{0.1, 0.9}, {0.5, 0.5}};
        TrMtx trMtx = new TrMtx(mtx);
        System.out.println(trMtx.toString());
        mtx[0][0] = 666.0;
        System.out.println(trMtx.toString());
    }

    /**
     * A {@link TrMtx} used as a probability vector. {@link #p(int)} and
     * {@link #nextState(Random)} read only row 0.
     */
    public static class PrVec extends TrMtx {
        /**
         * @param trans matrix whose row 0 is the vector; validated like any {@link TrMtx}
         */
        public PrVec(double[][] trans) {
            super(trans);
        }

        /**
         * @param trans the vector, converted to a matrix via {@code MathUtils.convert}
         *     (presumably a single-row matrix; confirm)
         */
        public PrVec(double[] trans) {
            super(MathUtils.convert(trans));
        }

        /**
         * @param s entry index
         * @return entry {@code s} of row 0
         */
        public double p(int s) {
            return super.p(0, s);
        }

        /**
         * Samples an index from row 0.
         *
         * @param rand randomness source
         * @return sampled index
         */
        public int nextState(Random rand) {
            return this.nextState(0, rand);
        }
    }
}

