/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import ev.io.PreprocessGutellData;
import ev.poi.processors.TreeDistancesProcessor;
import fig.basic.UnorderedPair;
import goblin.Taxon;
import hmm.Param;
import hmm.ParamUtils;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import ma.MSAPoset;
import ma.SequenceType;
import nuts.math.TreeSumProd;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.MathUtils;
import org.junit.Test;
import pty.io.Dataset;
import pty.smc.LazyPCS;
import pty.smc.LazyParticleFilter;
import pty.smc.LazyPriorPrior;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.PriorPriorKernel;
import pty.smc.models.CTMC;
import pty.smc.test.TestParticleNormalization;

/**
 * Checks that {@link LazyParticleFilter} gives identical normalizer estimates no matter how many
 * worker threads it uses, on both a phylogenetic (tree) model and a small HMM.
 */
public class LazyParticleFilterTest {
    /**
     * System property that can override the location of the Gutell 16S alignment read by
     * {@link #testTree()}. When unset, {@link #DEFAULT_GUTELL_PATH} is used.
     */
    private static final String GUTELL_PATH_PROPERTY = "pty.smc.gutell16S";

    /** Original hard-coded alignment location, kept as the default for backward compatibility. */
    private static final String DEFAULT_GUTELL_PATH = "/Users/bouchard/Documents/data/gutell/16S.3.alnfasta";

    public static void main(String[] args) {
        new LazyParticleFilterTest().testTree();
    }

    /**
     * Runs an eager {@link ParticleFilter} once. Then runs {@link LazyParticleFilter} with 1, 11, 21
     * and 31 threads and requires that every run gives the same normalizer estimate and mean
     * pairwise tree distances close to the first run's.
     *
     * @throws IllegalStateException if a run's normalizer estimate differs from the first run's
     */
    @Test
    public void testTree() {
        ParticleFilter<LazyPCS> pc = new ParticleFilter<LazyPCS>();
        pc.N = 1000;
        pc.nThreads = 1;
        pc.resampleLastRound = true;
        pc.verbose = false;
        Random rand = new Random(1L);
        // NOTE(review): arguments 1 and 10 presumably mean "one dataset of 10 taxa" -- confirm
        // against PreprocessGutellData.randomDataSet.
        File alignment = new File(System.getProperty(GUTELL_PATH_PROPERTY, DEFAULT_GUTELL_PATH));
        MSAPoset msa = PreprocessGutellData.randomDataSet(alignment, 1, 10, rand).get(0);
        Dataset dataset = Dataset.DatasetUtils.fromAlignment(msa, SequenceType.RNA);
        CTMC.SimpleCTMC ctmc = CTMC.SimpleCTMC.dnaCTMC(dataset.nSites());
        PartialCoalescentState init = PartialCoalescentState.initFastState(false, dataset, ctmc);
        LazyPriorPrior ppk = new LazyPriorPrior(init);
        ParticleFilter.DoNothingProcessor voidPro = new ParticleFilter.DoNothingProcessor();
        pc.sample(ppk, voidPro);
        System.out.println(pc.estimateNormalizer());
        Double newApprox = null;
        // Mean distances from the first lazy run; later runs are compared against them.
        Counter<UnorderedPair<Taxon, Taxon>> ref = null;
        for (int i = 0; i < 4; ++i) {
            PriorPriorKernel pk2 = new PriorPriorKernel(init);
            LazyParticleFilter.ParticleFilterOptions options = new LazyParticleFilter.ParticleFilterOptions();
            options.nParticles = pc.N;
            // Same seed on every run, so only the thread count changes between runs.
            options.rand = new Random(1L);
            options.resampleLastRound = true;
            options.parallelizeFinalParticleProcessing = true;
            options.nThreads = 10 * i + 1;
            options.verbose = false;
            options.check();
            LazyParticleFilter<PartialCoalescentState> lpf = new LazyParticleFilter<PartialCoalescentState>(pk2, options);
            TreeDistancesProcessor tdp = new TreeDistancesProcessor();
            double zHat = lpf.sample(tdp);
            System.out.println("Approx2=" + zHat);
            if (ref == null) {
                ref = tdp.getMeanDistances();
            }
            Counter<UnorderedPair<Taxon, Taxon>> cur = tdp.getMeanDistances();
            // Union of keys, so a pair missing from one side is still compared.
            for (UnorderedPair<Taxon, Taxon> p : CollUtils.union(ref.keySet(), cur.keySet())) {
                MathUtils.checkClose(cur.getCount(p), ref.getCount(p));
            }
            if (newApprox == null) {
                newApprox = zHat;
            }
            checkSameEstimate(newApprox.doubleValue(), zHat, options.nThreads);
        }
    }

    /**
     * Runs an eager {@link ParticleFilter} on a random HMM and computes a reference value with
     * {@link TreeSumProd#logZ()}. Then runs {@link LazyParticleFilter} (via
     * {@link LazyParticleFilter.Eager2LazyAdaptor}) with 1, 11, ..., 91 threads. Each run must
     * exactly match the first run, and be close (tolerance 1.0) to both the eager estimate and
     * {@code logZ}.
     *
     * <p>{@code MathUtils.threshold} is raised to 1.0 for the loose comparisons and restored
     * afterwards, so the looser tolerance cannot leak into other tests.
     *
     * @throws IllegalStateException if a run's estimate differs from the first run's
     */
    @Test
    public void testHMM() {
        Random rand = new Random(1L);
        Param p = ParamUtils.randomUniParam(rand, 10, 10);
        System.out.println("Param:\n" + p);
        int[] obs = new int[20];
        System.out.println("Observation:" + Arrays.toString(obs));
        TestParticleNormalization.HMMParticleKernel pk = new TestParticleNormalization.HMMParticleKernel(p, obs);
        ParticleFilter.DoNothingProcessor voidPro = new ParticleFilter.DoNothingProcessor();
        ParticleFilter<TestParticleNormalization.HMMPState> pf = new ParticleFilter<TestParticleNormalization.HMMPState>();
        pf.N = 100;
        pf.resamplingStrategy = ParticleFilter.ResamplingStrategy.ALWAYS;
        pf.resampleLastRound = false;
        pf.sample(pk, voidPro);
        double approxRef = pf.estimateNormalizer();
        System.out.println("Approx=" + approxRef);
        ArrayList<Integer> obsList = new ArrayList<Integer>();
        for (int o : obs) {
            obsList.add(o);
        }
        TreeSumProd.HmmAdaptor adapt = new TreeSumProd.HmmAdaptor(p, obsList);
        TreeSumProd<Integer> tsp = new TreeSumProd<Integer>(adapt);
        // NOTE(review): compared directly to estimateNormalizer(), which presumably also works in
        // log space -- confirm.
        double exact = tsp.logZ();
        System.out.println("Exact=" + exact);
        // Save the global tolerance so the looser value below cannot affect other tests.
        double previousThreshold = MathUtils.threshold;
        MathUtils.threshold = 1.0;
        try {
            Double newApprox = null;
            for (int i = 0; i < 10; ++i) {
                LazyParticleFilter.Eager2LazyAdaptor<TestParticleNormalization.HMMPState> adaptor = new LazyParticleFilter.Eager2LazyAdaptor<TestParticleNormalization.HMMPState>(pk);
                LazyParticleFilter.ParticleFilterOptions options = new LazyParticleFilter.ParticleFilterOptions();
                options.nParticles = pf.N;
                options.resampleLastRound = false;
                options.nThreads = 10 * i + 1;
                options.verbose = true;
                options.check();
                LazyParticleFilter<TestParticleNormalization.HMMPState> lpf = new LazyParticleFilter<TestParticleNormalization.HMMPState>(adaptor, options);
                double zHat = lpf.sample(new ParticleFilter.ParticleProcessor[0]);
                System.out.println("Approx2=" + zHat);
                if (newApprox == null) {
                    newApprox = zHat;
                }
                checkSameEstimate(newApprox.doubleValue(), zHat, options.nThreads);
                MathUtils.checkClose(newApprox, approxRef);
                MathUtils.checkClose(newApprox, exact);
            }
        } finally {
            MathUtils.threshold = previousThreshold;
        }
    }

    /**
     * Fails when a lazy run's estimate is not numerically equal to the first run's.
     *
     * @param expected estimate from the first run
     * @param actual estimate from the current run
     * @param nThreads thread count of the current run, included in the failure message
     * @throws IllegalStateException if {@code actual != expected}
     */
    private static void checkSameEstimate(double expected, double actual, int nThreads) {
        if (expected != actual) {
            throw new IllegalStateException("LazyParticleFilter estimate with nThreads=" + nThreads
                    + " was " + actual + " but the first run gave " + expected);
        }
    }
}

