/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml.tests;

import conifer.ml.CTMCExpFam;
import conifer.ml.ExpectedStatistics;
import conifer.ml.extractors.IdentityExtractor;
import fig.basic.LogInfo;
import fig.basic.Pair;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import ma.RateMatrixLoader;
import nuts.math.HashGraph;
import nuts.math.RateMtxUtils;
import nuts.math.SemiGraph;
import nuts.util.CollUtils;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import org.junit.Assert;
import org.junit.Test;
import pty.learn.CTMCExpectations;

/**
 * Checks the analytic gradient returned by
 * {@link CTMCExpFam.ExpectedCompleteReversibleObjective#derivativeAt} against a forward
 * finite-difference estimate of the directional derivative of its {@code valueAt}.
 *
 * <p>The model is built over the indexer from {@code RateMatrixLoader.rnaIndexer()} with the
 * A&lt;-&gt;C transitions removed. Data are simulated from the matrix returned by
 * {@code RateMatrixLoader.k2p()} with the same two transitions zeroed out.
 */
public class TestGradient {

    /** Number of simulated paths fed into the expected statistics. */
    private static final int N_SIMULATED_PATHS = 10;

    /**
     * First argument passed to {@code CTMCExpectations.simulate}; presumably the total time of
     * each simulated path (TODO confirm).
     */
    private static final double PATH_LENGTH = 10.0;

    /** Finite-difference steps used are 10^-1, 10^-2, ..., 10^-FINITE_DIFFERENCE_STEPS. */
    private static final int FINITE_DIFFERENCE_STEPS = 5;

    @Test
    public void testAnalyticGradient() {
        // Run the same check for both values of CTMCExpFam's boolean constructor argument
        // (its meaning is not visible from this file).
        for (boolean flag : new boolean[]{true, false}) {
            checkGradient(flag);
        }
    }

    /** Builds the model for one value of the constructor flag and asserts analytic == numeric. */
    private static void checkGradient(boolean flag) {
        // Fresh, fixed-seed RNG per variant so both runs see the same simulated data.
        Random rand = new Random(1L);
        final Indexer<Character> indexer = RateMatrixLoader.rnaIndexer();

        CTMCExpFam<Character> frrm = new CTMCExpFam<Character>(
                new HashGraph<Character>(noAcSemiGraph(indexer)), indexer, flag);
        // NOTE(review): raw Set kept as in the original; the precise element type expected by
        // extractUnivariateFeatures / extractReversibleBivariateFeatures is not visible here.
        Set extractors = Collections.singleton(new IdentityExtractor());
        frrm.extractUnivariateFeatures(extractors);
        frrm.extractReversibleBivariateFeatures(extractors);

        double[][] trueQ = k2pWithoutAc(indexer);
        ExpectedStatistics<Character> stat = simulateStatistics(frrm, trueQ, rand);

        final CTMCExpFam.ExpectedCompleteReversibleObjective obj =
                frrm.getExpectedCompleteReversibleObjective(1.0, stat);
        final int dim = obj.dimension();

        // Random direction drawn AFTER simulation, preserving the RNG consumption order.
        final double[] direction = new double[dim];
        for (int i = 0; i < dim; ++i) {
            direction[i] = rand.nextDouble();
        }
        final double[] point = new double[dim];

        double analytic = MathUtils.dot(obj.derivativeAt(point), direction);
        LogInfo.logsForce("analytic = " + analytic);

        // Objective restricted to the line point + delta * direction.
        MathUtils.FPlusDelta alongDirection = new MathUtils.FPlusDelta() {
            @Override
            public double logfd(double delta) {
                double[] shifted = new double[dim];
                for (int i = 0; i < dim; ++i) {
                    shifted[i] = point[i] + delta * direction[i];
                }
                return obj.valueAt(shifted);
            }
        };

        double valueAtPoint = obj.valueAt(point);
        double approx = 0.0;
        // Integer-controlled loop: Math.pow is exact for integer arguments whose result is
        // representable, so the step sizes (in particular the last one, whose estimate is
        // asserted) do not depend on rounding accumulated by repeated division.
        for (int k = 1; k <= FINITE_DIFFERENCE_STEPS; ++k) {
            double h = 1.0 / Math.pow(10.0, k);
            approx = (alongDirection.logfd(h) - valueAtPoint) / h;
            LogInfo.logsForce("approx(h=" + h + ") = " + approx);
        }
        LogInfo.logs("approx=" + approx);
        LogInfo.logs("analytic=" + analytic);
        Assert.assertEquals(approx, analytic, MathUtils.threshold);
    }

    /**
     * Returns a graph over {@code indexer.objects()} that has a semi-edge between every pair of
     * distinct states, compared case-insensitively, except A-&gt;C and C-&gt;A.
     */
    private static SemiGraph<Character> noAcSemiGraph(final Indexer<Character> indexer) {
        return new SemiGraph<Character>() {
            @Override
            public boolean hasSemiEdge(Character from, Character to) {
                char one = Character.toUpperCase(from.charValue());
                char two = Character.toUpperCase(to.charValue());
                if (one == two) {
                    return false;
                }
                boolean isAcPair = (one == 'A' && two == 'C') || (one == 'C' && two == 'A');
                return !isAcPair;
            }

            @Override
            public Set<Character> vertexSet() {
                return CollUtils.set(indexer.objects());
            }
        };
    }

    /**
     * Returns {@code RateMatrixLoader.k2p()} with the A&lt;-&gt;C rates set to zero and the
     * diagonal recomputed by {@code RateMtxUtils.fillRateMatrixDiagonalEntries}.
     */
    private static double[][] k2pWithoutAc(Indexer<Character> indexer) {
        double[][] q = RateMatrixLoader.k2p();
        int a = indexer.o2i('A');
        int c = indexer.o2i('C');
        q[c][a] = 0.0;
        q[a][c] = 0.0;
        // Clear the old diagonal before asking RateMtxUtils to refill it.
        for (int i = 0; i < q.length; ++i) {
            q[i][i] = 0.0;
        }
        RateMtxUtils.fillRateMatrixDiagonalEntries(q);
        return q;
    }

    /**
     * Simulates {@link #N_SIMULATED_PATHS} paths from {@code trueQ} using {@code rand} and
     * accumulates them as fully observed path statistics for {@code frrm}.
     */
    private static ExpectedStatistics<Character> simulateStatistics(
            CTMCExpFam<Character> frrm, double[][] trueQ, Random rand) {
        ExpectedStatistics<Character> stat = new ExpectedStatistics<Character>(frrm);
        LogInfo.track("Generating data");
        try {
            for (int i = 0; i < N_SIMULATED_PATHS; ++i) {
                List<Pair<Integer, Double>> datum =
                        CTMCExpectations.simulate(PATH_LENGTH, rand, trueQ);
                LogInfo.logs(datum);
                stat.addInitialAndFullyObservedPathStatistics(datum);
            }
            LogInfo.logsForce("nSeries = " + stat.nSeries());
            LogInfo.logsForce("totalTime = " + stat.totalTime());
            LogInfo.logsForce("stats =\n" + stat.toString());
        } finally {
            // Keep the log track balanced even if simulation throws.
            LogInfo.end_track();
        }
        return stat;
    }
}

