/*
 * Decompiled with CFR 0.152.
 */
package conifer.tests;

import conifer.ml.CTMCExpFam;
import conifer.ml.main.CTMCExpFamLoader;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.OptionSet;
import fig.prob.SampleUtils;
import java.util.Random;
import nuts.io.IO;
import nuts.math.Sampling;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;

/**
 * Command-line tool that draws a random weight vector for a CTMC exponential-family model,
 * builds the corresponding reversible rate matrix, and simulates one continuous-time Markov
 * chain trajectory until the elapsed time reaches {@link #T}. It logs each visited state
 * together with its exit rate and sampled holding time.
 */
public class GenerateData
implements Runnable {
    /** Number of sites (required option). NOTE(review): not read anywhere in this class. */
    @Option(required=true)
    public int numSites;
    /** Loader for the exponential-family model whose features parameterize the rate matrix. */
    @OptionSet(name="model")
    public CTMCExpFamLoader modelLoader = new CTMCExpFamLoader();
    /**
     * Single source of randomness for the run: true weights, initial state, holding times and
     * jumps are all drawn from it, so a fixed seed reproduces the whole output.
     */
    @Option(required=true)
    public Random rand = new Random(1L);
    /** Simulation horizon: jumps are simulated while the elapsed time is below this value. */
    @Option(required=true)
    public double T;

    /**
     * Samples true model weights, logs them and the resulting rate matrix, then simulates and
     * logs one trajectory of the chain on {@code [0, T)}.
     */
    public void generateDatum() {
        CTMCExpFam<String> model = this.modelLoader.load();
        // Weights ~ N(0, I). With identity covariance the components are independent standard
        // normals, so draw them from this.rand directly. Commons-Math's
        // MultivariateNormalDistribution(mean, cov) uses its own internally seeded generator,
        // which made the weights (and everything downstream) irreproducible.
        double[] w = new double[model.nFeatures()];
        for (int i = 0; i < w.length; ++i) {
            w[i] = this.rand.nextGaussian();
        }
        LogInfo.logs("trueWeights= \n" + this.getContents(w));
        CTMCExpFam.LearnedReversibleModel trueModel = model.reversibleModelWithParameters(w);
        LogInfo.logs(trueModel.rateMatrixString());
        double[][] trueMatrix = trueModel.getRateMatrix();
        // Initial state is drawn from pi (presumably the model's stationary distribution).
        double[] pi = trueModel.pi;
        int currState = SampleUtils.sampleMultinomial(this.rand, pi);
        double timeElapsed = 0.0;
        while (timeElapsed < this.T) {
            String currStateString = (String)model.stateIndexer.i2o(currState);
            LogInfo.logs("current state = " + currStateString);
            // Diagonal entry; normalize() treats its negation as the total exit rate, so it is
            // expected to be negative for a non-absorbing state.
            double rate = trueMatrix[currState][currState];
            if (rate == 0.0) {
                // No outgoing rate: the chain stays here forever, nothing more to simulate.
                LogInfo.logs("absorbing state reached; stopping simulation");
                break;
            }
            // -1/rate is the mean holding time (assumes sampleExponential takes the mean --
            // TODO confirm against nuts.math.Sampling).
            double time = Sampling.sampleExponential(this.rand, -1.0 / rate);
            LogInfo.logs("rate=" + rate + ", waitTime=" + time);
            double[] probs = this.normalize(trueMatrix[currState], currState);
            currState = SampleUtils.sampleMultinomial(this.rand, probs);
            timeElapsed += time;
        }
    }

    /**
     * Turns one row of a rate matrix into the jump distribution out of {@code state}. Each
     * off-diagonal entry is divided by {@code -row[state]} and the diagonal entry is set to 0.
     *
     * @param row   a row of a rate matrix; not modified
     * @param state index of the diagonal entry within {@code row}
     * @return a new array of jump probabilities with {@code probs[state] == 0.0}
     * @throws IllegalArgumentException if {@code -row[state]} is not strictly positive (e.g. an
     *         absorbing state), since the row cannot be normalized into probabilities
     */
    public double[] normalize(double[] row, int state) {
        double norm = -row[state];
        // Negated comparison also rejects NaN.
        if (!(norm > 0.0)) {
            throw new IllegalArgumentException(
                    "cannot normalize row for state " + state + ": exit rate " + norm + " is not positive");
        }
        double[] probs = new double[row.length];
        for (int i = 0; i < row.length; ++i) {
            probs[i] = (i == state) ? 0.0 : row[i] / norm;
        }
        return probs;
    }

    /**
     * Formats a vector as {@code "[ v0 v1 ... ]"}, with each element followed by one space.
     *
     * @param v values to format
     * @return the formatted string ({@code "[ ]"} for an empty array)
     */
    public String getContents(double[] v) {
        StringBuilder sb = new StringBuilder("[ ");
        for (double x : v) {
            sb.append(x).append(' ');
        }
        return sb.append(']').toString();
    }

    /**
     * Program entry point. It hands {@code args} and a fresh instance to {@code IO.run}, which
     * presumably parses the {@code @Option} fields and then invokes {@link #run()}.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        LogInfo.logs("Hello World!");
        IO.run(args, new GenerateData());
    }

    /** Generates and logs one simulated datum; see {@link #generateDatum()}. */
    @Override
    public void run() {
        this.generateDatum();
    }
}

