/*
 * Decompiled with CFR 0.152.
 */
package ctmc;

import conifer.ml.CTMCExpFam;
import conifer.ml.ExpectedStatistics;
import conifer.ml.OptimizationOptions;
import conifer.ml.main.CTMCExpFamLoader;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.OptionSet;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.io.IO;
import nuts.math.Sampling;
import nuts.tui.Table;
import nuts.util.CollUtils;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.Precision;

/**
 * Synthetic-data check for {@link CTMCExpFam}.
 *
 * <p>The experiment runs in four steps:
 * <ol>
 *   <li>Draw random feature weights and build the corresponding reversible rate matrix.</li>
 *   <li>Simulate {@code numSites} fully observed paths of length {@code T} from that matrix.</li>
 *   <li>Refit the model on the accumulated path statistics.</li>
 *   <li>Log the true and fitted rate matrices, stationary distributions and weights.</li>
 * </ol>
 */
public class TestModel
implements Runnable {
    /** Loader for the exponential-family CTMC model under test. */
    @OptionSet(name="model")
    public CTMCExpFamLoader modelLoader = new CTMCExpFamLoader();
    /** Options forwarded to {@code fitReversibleModel}. */
    @OptionSet(name="optimization")
    public OptimizationOptions optimizationOptions = new OptimizationOptions();
    /**
     * Source of randomness for the experiment.
     *
     * <p>It drives the initial states, holding times and jumps directly. It also seeds the
     * generator used to draw the true weights.
     */
    @Option(required=true)
    public Random rand = new Random(1L);
    /** Number of independent paths (sites) to simulate. */
    @Option(required=true)
    public int numSites;
    /** Length of the observation window for every simulated path. */
    @Option(required=true)
    public double T;

    /**
     * Runs the simulate-then-refit experiment.
     *
     * <p>Logs the true and fitted rate matrices, stationary distributions and weights.
     */
    @Override
    public void run() {
        CTMCExpFam<String> model = this.modelLoader.load();
        int p = model.nFeatures();
        // True weights w ~ N(0, I_p). They are drawn through a commons-math generator seeded from
        // this.rand, so the draw is determined by this.rand.
        double[] mu = new double[p];
        double[][] identity = new double[p][p];
        for (int i = 0; i < p; ++i) {
            identity[i][i] = 1.0;
        }
        MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(new CustomRandomGenerator(this.rand.nextLong()), mu, identity);
        double[] w = mvn.sample();
        CTMCExpFam.LearnedReversibleModel learnedModel = model.reversibleModelWithParameters(w);
        ExpectedStatistics<String> stat = new ExpectedStatistics<String>(model);
        double[][] trueMatrix = learnedModel.getRateMatrix();
        double[][] transProbs = this.getTransitionProbs(trueMatrix);
        LogInfo.track("Generating data");
        try {
            for (int n = 0; n < this.numSites; ++n) {
                List<Pair<Integer, Double>> datum = this.generateData(trueMatrix, transProbs, learnedModel.pi);
                stat.addInitialAndFullyObservedPathStatistics(datum);
            }
            LogInfo.logsForce("nSeries = " + stat.nSeries());
            LogInfo.logsForce("totalTime = " + stat.totalTime());
            LogInfo.logsForce("stats =\n" + stat.toString());
        } finally {
            // Keep LogInfo's track nesting balanced even if the simulation throws.
            LogInfo.end_track();
        }
        CTMCExpFam.LearnedReversibleModel fittedModel = model.fitReversibleModel(this.optimizationOptions, stat, null);
        LogInfo.logs("# of features=" + p);
        // trueMatrix is the rate matrix the data were simulated from; no need to rebuild it.
        LogInfo.logsForce("trueMtx = \n" + Table.toString(trueMatrix));
        LogInfo.logsForce("fittedMtx = \n" + Table.toString(fittedModel.getRateMatrix()));
        LogInfo.logs("trueStationary = \n" + TestModel.vectorToString(learnedModel.pi));
        LogInfo.logs("fittedStationary = \n" + TestModel.vectorToString(fittedModel.pi));
        LogInfo.logs("trueWeights = \n" + TestModel.vectorToString(w));
        LogInfo.logs("fittedWeights = \n" + TestModel.vectorToString(fittedModel.weights));
    }

    /**
     * Formats a vector as {@code "[ a b c ]"}, rounding each entry to 5 decimal places.
     *
     * @param v the values to format
     * @return the bracketed, space-separated representation
     */
    public static String vectorToString(double[] v) {
        StringBuilder sb = new StringBuilder("[ ");
        for (double x : v) {
            sb.append(Precision.round(x, 5)).append(' ');
        }
        return sb.append(']').toString();
    }

    /**
     * Simulates one fully observed path of the CTMC on the window [0, T].
     *
     * @param trueMatrix rate matrix; the exit rate of state i is {@code -trueMatrix[i][i]}
     * @param transProbs jump-chain transition probabilities, one row per state
     * @param pi distribution from which the initial state is drawn
     * @return the visited states paired with their sojourn times; the last sojourn is truncated
     *     so that the path ends at T
     */
    private List<Pair<Integer, Double>> generateData(double[][] trueMatrix, double[][] transProbs, double[] pi) {
        int currState = SampleUtils.sampleMultinomial(this.rand, pi);
        double t = 0.0;
        List<Pair<Integer, Double>> datum = CollUtils.list();
        while (t < this.T) {
            double start = t;
            // NOTE(review): the argument is -1/q_ii, i.e. the mean holding time. Confirm that
            // Sampling.sampleExponential is parameterized by mean rather than by rate.
            double holdingTime = Sampling.sampleExponential(this.rand, -1.0 / trueMatrix[currState][currState]);
            t += holdingTime;
            if (t > this.T) {
                // Censor the final sojourn at the end of the observation window.
                holdingTime = this.T - start;
            }
            datum.add(Pair.makePair(currState, holdingTime));
            // Also drawn after the final (censored) sojourn, where the result is unused. It is
            // kept so that seeded runs consume the random stream in the same order.
            currState = SampleUtils.sampleMultinomial(this.rand, transProbs[currState]);
        }
        return datum;
    }

    /**
     * Builds the jump-chain transition matrix of a CTMC.
     *
     * <p>Each off-diagonal rate in row i is divided by that row's exit rate
     * {@code -trueMatrix[i][i]}. The diagonal is left at 0. Assumes a square matrix. A state with
     * zero exit rate yields non-finite entries in its row.
     *
     * @param trueMatrix the rate matrix
     * @return a new matrix of jump probabilities with the same number of states
     */
    private double[][] getTransitionProbs(double[][] trueMatrix) {
        int numStates = trueMatrix.length;
        double[][] transProbs = new double[numStates][numStates];
        for (int i = 0; i < numStates; ++i) {
            double exitRate = -trueMatrix[i][i];
            for (int j = 0; j < numStates; ++j) {
                if (j != i) {
                    transProbs[i][j] = trueMatrix[i][j] / exitRate;
                }
            }
        }
        return transProbs;
    }

    /** Command-line entry point; options are parsed and injected by {@code IO.run}. */
    public static void main(String[] args) {
        IO.run(args, new TestModel());
    }

    /**
     * Adapts {@link java.util.Random} to commons-math's {@link RandomGenerator} interface.
     *
     * <p>Every method delegates to a private {@code Random} instance.
     */
    public static class CustomRandomGenerator
    implements RandomGenerator {
        private final Random random;

        /**
         * Creates a generator backed by {@code new Random(seed)}.
         *
         * @param seed initial seed
         */
        public CustomRandomGenerator(long seed) {
            this.random = new Random(seed);
        }

        public boolean nextBoolean() {
            return this.random.nextBoolean();
        }

        public void nextBytes(byte[] bytes) {
            this.random.nextBytes(bytes);
        }

        public double nextDouble() {
            return this.random.nextDouble();
        }

        public float nextFloat() {
            return this.random.nextFloat();
        }

        public double nextGaussian() {
            return this.random.nextGaussian();
        }

        public int nextInt() {
            return this.random.nextInt();
        }

        public int nextInt(int n) {
            return this.random.nextInt(n);
        }

        public long nextLong() {
            return this.random.nextLong();
        }

        public void setSeed(int seed) {
            this.random.setSeed(seed);
        }

        /**
         * Reseeds from an int array.
         *
         * <p>The ints are folded into a single long, using the same mixing as commons-math's
         * {@code RandomGeneratorFactory.convertToLong}, so the reseed actually takes effect.
         *
         * @param seed seed values; must not be null
         */
        public void setSeed(int[] seed) {
            final long prime = 4294967291L; // 2^32 - 5, the largest prime below 2^32
            long combined = 0L;
            for (int s : seed) {
                combined = combined * prime + s;
            }
            this.random.setSeed(combined);
        }

        public void setSeed(long seed) {
            this.random.setSeed(seed);
        }
    }
}

