/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml.tests;

import conifer.ml.CTMCExpFam;
import conifer.ml.ExpectedStatistics;
import conifer.ml.OptimizationOptions;
import conifer.ml.data.EndPointDataset;
import conifer.ml.extractors.IdentityExtractor;
import fig.basic.LogInfo;
import fig.basic.Pair;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import ma.RateMatrixLoader;
import nuts.math.RateMtxUtils;
import nuts.tui.Table;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import pty.learn.CTMCExpectations;

/**
 * Manual end-to-end check of {@link CTMCExpFam} on synthetic data.
 *
 * <p>A reversible rate matrix is built from random feature weights, data are simulated from it,
 * and a model is re-fitted to those data. The true and learned rate matrices are written to the
 * log so they can be compared.
 */
public class TestGeneratedData {
    /**
     * Simulates {@code nSeries} end-point data points at branch length {@code bl} from a random
     * reversible model. It then runs {@code emIters} EM iterations starting from a second random
     * model. After each iteration it logs the fitted rate matrix and checks detailed balance.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int emIters = 10;
        double bl = 1.5;
        int nSeries = 100000;
        OptimizationOptions optimizationOptions = new OptimizationOptions();
        Random rand = new Random(1L);
        Indexer<Character> indexer = RateMatrixLoader.rnaIndexer();
        // NOTE(review): frrm is never constructed, so the next statement that uses it throws a
        // NullPointerException. The construction appears to have been lost (this file is
        // decompiler output); TODO restore it before running this test.
        CTMCExpFam frrm = null;
        Set c1 = Collections.singleton(new IdentityExtractor());
        frrm.extractUnivariateFeatures(c1);
        frrm.extractReversibleBivariateFeatures(c1);
        double[] trueW = randomWeights(frrm.nFeatures(), rand);
        double[][] trueRateMtx = frrm.reversibleModelWithParameters(trueW).getRateMatrix();
        LogInfo.logsForce("trueMtx = \n" + Table.toString(trueRateMtx));

        EndPointDataset<Character> data = new EndPointDataset<Character>();
        LogInfo.track("Generating data");
        for (int i = 0; i < nSeries; ++i) {
            LogInfo.logs("datapoint " + data.addGeneratedData(bl, trueRateMtx, rand, indexer));
        }
        LogInfo.end_track();

        // EM starts from an independent random parameter vector, not from the truth.
        double[][] currentMtx =
                frrm.reversibleModelWithParameters(randomWeights(frrm.nFeatures(), rand)).getRateMatrix();
        LogInfo.logsForce("initMtx = \n" + Table.toString(currentMtx));
        for (int emIter = 0; emIter < emIters; ++emIter) {
            // E-step: expected statistics of the data under the current rate matrix.
            ExpectedStatistics<Character> currentStats = new ExpectedStatistics<Character>(frrm);
            currentStats.addFromMarginalizedData(data, currentMtx);
            // M-step: refit the reversible model (no warm start) to those statistics.
            CTMCExpFam.LearnedReversibleModel currentModel =
                    frrm.fitReversibleModel(optimizationOptions, currentStats, null);
            currentMtx = currentModel.getRateMatrix();
            checkReversible(currentMtx, currentModel.pi);
            LogInfo.logsForce("mtxAfterIter(" + emIter + ") = \n" + Table.toString(currentMtx));
        }
    }

    /**
     * Verifies detailed balance: {@code statio[a] * P[a][b]} is close to
     * {@code statio[b] * P[b][a]} for every pair of states.
     *
     * <p>{@code P} is obtained from {@code RateMtxUtils.marginalTransitionMtx(currentMtx, 1.0)},
     * presumably the transition matrix after time 1.0 (TODO confirm). Closeness is judged by
     * {@code MathUtils.close}.
     *
     * @param currentMtx rate matrix to check
     * @param statio stationary distribution. Its length defines the number of states.
     * @throws IllegalStateException if some pair of states violates detailed balance
     */
    private static void checkReversible(double[][] currentMtx, double[] statio) {
        double[][] trans = RateMtxUtils.marginalTransitionMtx(currentMtx, 1.0);
        int nStates = statio.length;
        for (int s1 = 0; s1 < nStates; ++s1) {
            for (int s2 = 0; s2 < nStates; ++s2) {
                double m1 = statio[s1] * trans[s1][s2];
                double m2 = statio[s2] * trans[s2][s1];
                if (!MathUtils.close(m1, m2)) {
                    throw new IllegalStateException("Detailed balance violated for states ("
                            + s1 + ", " + s2 + "): " + m1 + " vs " + m2);
                }
            }
        }
        System.out.println("Reversibility checked");
    }

    /**
     * Draws a random weight vector with entries uniform in [0, 0.25).
     *
     * @param nFeatures length of the returned vector
     * @param rand source of randomness. Exactly {@code nFeatures} doubles are consumed.
     * @return a fresh array of {@code nFeatures} weights
     */
    private static double[] randomWeights(int nFeatures, Random rand) {
        double[] result = new double[nFeatures];
        for (int i = 0; i < nFeatures; ++i) {
            result[i] = rand.nextDouble() / 4.0;
        }
        return result;
    }

    /**
     * Smaller sanity check on fully observed paths.
     *
     * <p>Simulates 1000 paths of length 1.5 from a random reversible model and accumulates their
     * fully observed statistics. It then fits the model, warm-started at the true weights, with
     * regularization strength 1.0.
     */
    public static void easyTest() {
        int emIters = 1;
        OptimizationOptions optimizationOptions = new OptimizationOptions();
        Random rand = new Random(1L);
        // NOTE(review): indexer is unused below; presumably it was an argument to the missing
        // CTMCExpFam construction. TODO confirm.
        Indexer<Character> indexer = RateMatrixLoader.rnaIndexer();
        // NOTE(review): frrm is never constructed, so the next statement that uses it throws a
        // NullPointerException. TODO restore the model construction.
        CTMCExpFam frrm = null;
        Set c1 = Collections.singleton(new IdentityExtractor());
        frrm.extractUnivariateFeatures(c1);
        frrm.extractReversibleBivariateFeatures(c1);
        double[] trueW = randomWeights(frrm.nFeatures(), rand);
        optimizationOptions.regularizationStrength = 1.0;
        double[][] trueRateMtx = frrm.reversibleModelWithParameters(trueW).getRateMatrix();
        LogInfo.logsForce("trueMtx = \n" + Table.toString(trueRateMtx));

        ExpectedStatistics<Character> stat = new ExpectedStatistics<Character>(frrm);
        LogInfo.track("Generating data");
        for (int i = 0; i < 1000; ++i) {
            List<Pair<Integer, Double>> datum = CTMCExpectations.simulate(1.5, rand, trueRateMtx);
            LogInfo.logsForce(datum);
            stat.addInitialAndFullyObservedPathStatistics(datum);
        }
        LogInfo.logsForce("nSeries = " + stat.nSeries());
        LogInfo.logsForce("totalTime = " + stat.totalTime());
        LogInfo.logsForce("stats =\n" + stat.toString());
        LogInfo.end_track();

        // The optimizer is warm-started at trueW, so the initial matrix is trueRateMtx.
        double[] warmStart = trueW;
        LogInfo.logsForce("initMtx = \n" + Table.toString(trueRateMtx));
        for (int emIter = 0; emIter < emIters; ++emIter) {
            CTMCExpFam.LearnedReversibleModel currentModel =
                    frrm.fitReversibleModel(optimizationOptions, stat, warmStart);
            double[][] currentMtx = currentModel.getRateMatrix();
            LogInfo.logsForce("mtxAfterIter(" + emIter + ") = \n" + Table.toString(currentMtx));
        }
    }
}

