/*
 * Decompiled with CFR 0.152.
 */
package fenchel.tests;

import conifer.apps.AbstractPhyloApp;
import conifer.particle.PhyloParticle;
import fenchel.algo.FactorGraphSumProduct;
import fenchel.factor.FactorUtils;
import fenchel.factor.multisites.MSFactorBuilder;
import fenchel.factor.multisites.MSFactorGraph;
import fenchel.factor.multisites.MSScaledFactorBuilder;
import fenchel.factor.multisites.MSUnaryFactor;
import fenchel.tests.utils.NaiveMultiSitesFactorBuilder;
import fig.basic.Option;
import goblin.Taxon;
import hmm.Param;
import hmm.ParamUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import nuts.math.GMFct;
import nuts.math.TreeSumProd;
import nuts.util.MathUtils;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import org.junit.Assert;
import org.junit.Test;
import pty.mcmc.UnrootedTreeState;

public class CompareToGMTest {
    /**
     * Copies a graphical model into a multi-site factor graph.
     *
     * <p>All edge potentials of {@code gm} are added to {@code factorGraph}. Node potentials are
     * added only for the vertices contained in {@code nodes}.
     *
     * @param gm the source graphical model
     * @param factorGraph the factor graph to populate
     * @param n forwarded to {@code FactorUtils._gmNodePotentials2FactorGraph}; callers in this
     *     class pass the number of sites
     * @param nodes the vertices whose node potentials are copied
     */
    public static <N> void makeMultiSite(GMFct<N> gm, MSFactorGraph<N> factorGraph, int n, Set<N> nodes) {
        FactorUtils.gmEdgePotentials2FactorGraph(gm, factorGraph);
        FactorUtils._gmNodePotentials2FactorGraph(gm, factorGraph, n, nodes);
    }

    /**
     * Copies a graphical model into a factor graph, keeping node potentials only for a random
     * subset of vertices.
     *
     * <p>Despite the name, every edge potential is always added. The randomness only decides
     * which vertices receive node potentials. A vertex is kept if a uniform draw is below
     * {@code p}. A vertex with exactly one neighbour is kept regardless of the draw.
     *
     * <p>Exactly one {@code rand.nextDouble()} is consumed per vertex, in natural-order sequence.
     * As a result, two generators with the same seed select the same subset.
     *
     * @param gm the source graphical model
     * @param factorGraph the factor graph to populate
     * @param n forwarded to {@link #makeMultiSite}
     * @param rand source of the per-vertex draws
     * @param p probability that a vertex with more than one neighbour keeps its node potential
     */
    public static <N extends Comparable<N>> void excludeRandomEdges(GMFct<N> gm, MSFactorGraph<N> factorGraph, int n, Random rand, double p) {
        ArrayList<N> sortedNodes = new ArrayList<N>(gm.graph().vertexSet());
        // Sorting fixes the order in which random draws are assigned to vertices,
        // independently of the vertex set's iteration order.
        Collections.sort(sortedNodes);
        Set<N> keptNodes = new HashSet<N>();
        for (N node : sortedNodes) {
            // The draw is evaluated first so it happens for every vertex; the neighbour
            // count is only consulted when the draw fails.
            if (rand.nextDouble() < p || gm.graph().nbrs(node).size() == 1) {
                keptNodes.add(node);
            }
        }
        makeMultiSite(gm, factorGraph, n, keptNodes);
    }

    /**
     * Checks single-site factor-graph sum-product against exact tree sum-product.
     *
     * <p>For every vertex, this asserts two things, both within {@code MathUtils.threshold}:
     * <ul>
     *   <li>the log-normalizer of the vertex moment equals the exact log partition function;
     *   <li>the first row of its normalized values matches the exact vertex posteriors.
     * </ul>
     *
     * @param graphicalModel the model to check
     * @param builder the factor builder used to construct the factor graph
     */
    public static <S extends Comparable<S>> void test(GMFct<S> graphicalModel, MSFactorBuilder builder) {
        TreeSumProd<S> tsp = new TreeSumProd<S>(graphicalModel);
        double exact = tsp.logZ();
        System.out.println("LogZ, method 1:" + exact);
        MSFactorGraph fg = new MSFactorGraph(builder);
        FactorUtils.gm2FactorGraph(graphicalModel, fg);
        FactorGraphSumProduct fgsp = new FactorGraphSumProduct();
        fgsp.init(fg);
        System.out.println("LogZ, method 2:" + fgsp.logZ());
        for (S node : graphicalModel.graph().vertexSet()) {
            MSUnaryFactor factor = (MSUnaryFactor) fgsp.moment(node);
            double[][] values = factor.normalizedValues();
            Assert.assertEquals(exact, factor.logNorm(), MathUtils.threshold);
            for (int s = 0; s < graphicalModel.nStates(node); ++s) {
                // Row 0 is compared; presumably the single site of this factor graph.
                Assert.assertEquals(tsp.vertexPosterior(node, s), values[0][s], MathUtils.threshold);
            }
        }
    }

    /**
     * Builds a multi-site factor graph from {@code graphicalModel} and returns its log partition
     * function as computed by factor-graph sum-product.
     *
     * @param graphicalModel the model whose potentials are copied (node potentials for all vertices)
     * @param builder the factor builder used to construct the factor graph
     * @param nSites the number of sites, forwarded to {@link #makeMultiSite}
     * @return the log partition function reported by {@code FactorGraphSumProduct.logZ()}
     */
    public static <S extends Comparable<S>> double runMultiSites(GMFct<S> graphicalModel, MSFactorBuilder builder, int nSites) {
        MSFactorGraph<S> fg = new MSFactorGraph<S>(builder);
        makeMultiSite(graphicalModel, fg, nSites, graphicalModel.graph().vertexSet());
        FactorGraphSumProduct fgsp = new FactorGraphSumProduct();
        fgsp.init(fg);
        return fgsp.logZ();
    }

    /**
     * Asserts that the multi-site log partition function equals {@code nSites} times the exact
     * single-site value, within {@code MathUtils.threshold}.
     *
     * @param graphicalModel the model to check
     * @param builder the factor builder used to construct the factor graph
     * @param nSites the number of sites
     */
    public static <S extends Comparable<S>> void testMultiSites(GMFct<S> graphicalModel, MSFactorBuilder builder, int nSites) {
        TreeSumProd<S> tsp = new TreeSumProd<S>(graphicalModel);
        double exact = tsp.logZ() * nSites;
        System.out.println("LogZ, method 1:" + exact);
        double logZ = runMultiSites(graphicalModel, builder, nSites);
        System.out.println("LogZ, method 2:" + logZ);
        Assert.assertEquals(exact, logZ, MathUtils.threshold);
    }

    /**
     * Runs the single-site and three-site checks on phylogenetic models with 20 taxa, built
     * for site indices 0 and 1.
     *
     * @param builder the factor builder under test
     */
    public void testPhylogeny(MSFactorBuilder builder) {
        for (int s = 0; s < 2; ++s) {
            GMFct<Taxon> graphicalModel = phylogeneticGraphicalModel(20, s);
            test(graphicalModel, builder);
            testMultiSites(graphicalModel, builder, 3);
        }
    }

    /** Runs the HMM checks with {@link MSScaledFactorBuilder}. */
    @Test
    public void testHMM() {
        this.testHMM(new MSScaledFactorBuilder());
    }

    /** Runs the phylogeny checks with {@link MSScaledFactorBuilder}. */
    @Test
    public void testPhylogeny() {
        this.testPhylogeny(new MSScaledFactorBuilder());
    }

    /**
     * Checks that {@link MSScaledFactorBuilder} and {@code NaiveMultiSitesFactorBuilder} agree on
     * the log partition function when node potentials are kept only for a random subset of
     * vertices (see {@link #excludeRandomEdges}).
     */
    @Test
    public void testSparseEdges() {
        GMFct<Taxon> graphicalModel = phylogeneticGraphicalModel(20, 1);
        MSScaledFactorBuilder builder1 = new MSScaledFactorBuilder();
        NaiveMultiSitesFactorBuilder builder2 = new NaiveMultiSitesFactorBuilder();
        for (int i = 0; i < 10; ++i) {
            MSFactorGraph<Taxon> fg1 = new MSFactorGraph<Taxon>(builder1);
            MSFactorGraph<Taxon> fg2 = new MSFactorGraph<Taxon>(builder2);
            // Identical seeds make both graphs keep the same subset of node potentials.
            Random rand = new Random(100 * (i + 1));
            excludeRandomEdges(graphicalModel, fg1, 1, rand, 0.5);
            rand = new Random(100 * (i + 1));
            excludeRandomEdges(graphicalModel, fg2, 1, rand, 0.5);
            FactorGraphSumProduct fgsp1 = new FactorGraphSumProduct();
            FactorGraphSumProduct fgsp2 = new FactorGraphSumProduct();
            fgsp1.init(fg1);
            fgsp2.init(fg2);
            System.out.println("LogZ, method 2:" + fgsp1.logZ());
            Assert.assertEquals(fgsp1.logZ(), fgsp2.logZ(), MathUtils.threshold);
        }
    }

    /**
     * Runs the single-site and three-site checks on HMMs.
     *
     * <p>The HMMs use two random parameter sets (seeded with 1) and observation sequences of
     * length 1 to 9.
     *
     * @param builder the factor builder under test
     */
    public void testHMM(MSFactorBuilder builder) {
        Random rand = new Random(1L);
        for (int rep = 0; rep < 2; ++rep) {
            Param p = ParamUtils.randomUniParam(rand, 2, 2);
            System.out.println("Param:\n" + p);
            for (int len = 1; len < 10; ++len) {
                // NOTE(review): obs is never filled, so every observation is symbol 0.
                // Confirm whether random observations were intended.
                int[] obs = new int[len];
                System.out.println("Observation:" + Arrays.toString(obs));
                ArrayList<Integer> obsList = new ArrayList<Integer>(len);
                for (int o : obs) {
                    obsList.add(o);
                }
                TreeSumProd.HmmAdaptor adapt = new TreeSumProd.HmmAdaptor(p, obsList);
                test(adapt, builder);
                testMultiSites(adapt, builder, 3);
            }
        }
    }

    /**
     * Times the legacy likelihood computation.
     *
     * <p>Measures one call of {@code UnrootedTreeState.computeLogLikelihood} on the initial
     * particle of an {@link AbstractPhyloApp} configured with the given taxa and site counts.
     *
     * @param nTaxa number of taxa in the generating tree
     * @param nSites number of sites, used for both generation and inference
     * @return elapsed wall-clock time in nanoseconds
     */
    public static long timeOldPhyloLikelihood(final int nTaxa, final int nSites) {
        // One-element array so the anonymous class can hand the measurement back.
        final long[] elapsed = new long[1];
        AbstractPhyloApp phyloApp = new AbstractPhyloApp(){

            @Override
            public void run() {
                this.allOptions.dataOptions.generatingTreeOptions.nTaxa = nTaxa;
                this.allOptions.dataOptions.generatingEvolutionaryOptions.nSites = nSites;
                this.allOptions.inferenceEvolutionaryOptions.nSites = nSites;
                PhyloParticle pk = this.getKernel().getInitial();
                long start = System.nanoTime();
                UnrootedTreeState.computeLogLikelihood(pk.getPhylogeny().getUnrooted(), ((UnrootedTreeState)pk).getLikelihoodModels());
                elapsed[0] = System.nanoTime() - start;
            }
        };
        phyloApp.run();
        return elapsed[0];
    }

    /**
     * Intended to build the graphical model of a phylogeny with {@code nTaxa} taxa for the
     * given site.
     *
     * <p>This is currently a stub with no implementation. It throws rather than returning
     * {@code null}, so callers fail immediately with a clear message instead of later inside
     * the code that consumes the model.
     *
     * @param nTaxa number of taxa
     * @param siteIndex index of the site to model
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public static GMFct<Taxon> phylogeneticGraphicalModel(int nTaxa, int siteIndex) {
        throw new UnsupportedOperationException("phylogeneticGraphicalModel is not implemented (nTaxa="
                + nTaxa + ", siteIndex=" + siteIndex + ")");
    }

    /**
     * Benchmarks {@code run}.
     *
     * <p>The runnable is executed once untimed as a burn-in. It is then executed
     * {@code nRepeats} more times, and the running mean and standard deviation of the elapsed
     * nanoseconds are printed after each timed run.
     *
     * @param nRepeats number of timed runs
     * @param run the workload to time
     */
    public static void testRuntime(int nRepeats, Runnable run) {
        System.out.println("Burnin..");
        run.run();
        SummaryStatistics stat = new SummaryStatistics();
        for (int i = 0; i < nRepeats; ++i) {
            System.out.println("Instrumented run " + i);
            long before = System.nanoTime();
            run.run();
            long elapsed = System.nanoTime() - before;
            stat.addValue((double)elapsed);
            System.out.println("Current mean=" + stat.getMean() + ", SD=" + stat.getStandardDeviation());
        }
    }

    /**
     * Compares the runtime of repeated single-site {@link TreeSumProd} with one multi-site
     * factor-graph run over the same number of sites.
     */
    public static void main(String[] args) {
        final int nTaxa = 100;
        final int nSites = 50000;
        final GMFct<Taxon> graphicalModel = phylogeneticGraphicalModel(nTaxa, 0);
        testRuntime(10, new Runnable(){

            @Override
            public void run() {
                for (int i = 0; i < nSites; ++i) {
                    new TreeSumProd<Taxon>(graphicalModel).logZ();
                }
            }
        });
        System.out.println("New method");
        testRuntime(10, new Runnable(){

            @Override
            public void run() {
                runMultiSites(graphicalModel, new MSScaledFactorBuilder(), nSites);
            }
        });
    }

    /** Small option-parsing harness that prints one value of {@link #anOption}. */
    public static class Main
    implements Runnable {
        @Option
        public ArrayList<String> anOption;
        @Option
        public Random rand = new Random(1L);

        /**
         * Prints the second element of {@link #anOption}.
         *
         * <p>This throws if the option was not set or has fewer than two values.
         * NOTE(review): confirm whether index 1 (rather than 0) is intended.
         */
        @Override
        public void run() {
            System.out.println("option" + this.anOption.get(1));
        }
    }
}

