/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.Parallelizer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import nuts.lang.ArrayUtils;
import nuts.math.Sampling;
import nuts.maxent.SloppyMath;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.MathUtils;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;

public final class LazyParticleFilter<T> {
    private final LazyParticleKernel<T> kernel;
    private final ParticleFilterOptions options;

    /**
     * Degenerate starting distribution: every draw returns the kernel's initial state, so the
     * first generation extends {@code n} copies of it.
     */
    private final Distribution<T> initialPopulation = new Distribution<T>() {

        @Override
        public List<T> sampleNTimes(Random r, int n) {
            T initial = LazyParticleFilter.this.kernel.getInitial();
            List<T> result = new ArrayList<T>();
            for (int i = 0; i < n; ++i) {
                result.add(initial);
            }
            return result;
        }
    };

    /**
     * Creates a filter driven by {@code kernel} and configured by {@code options}.
     *
     * @param kernel proposal kernel; {@code peekNext} is used when extending a generation and
     *     {@code next} only when a particle is actually materialized
     * @param options sampler settings; the reference is stored, not copied, and its fields are
     *     read on every call
     */
    public LazyParticleFilter(LazyParticleKernel<T> kernel, ParticleFilterOptions options) {
        this.kernel = kernel;
        this.options = options;
    }

    /**
     * Prunes {@code population} by multinomial resampling until the expected number of distinct
     * particles obtained by drawing {@code options.nParticles} times from it is at most
     * {@code maxNUniqueParticles}.
     *
     * <p>The first attempt draws {@code 1 + nParticlesTokens()} samples; each further attempt
     * multiplies the (real-valued) draw count by {@code options.populationShrinkFactor}. Every
     * attempt resamples from the original {@code population}, not from the previous attempt.
     * Pruning stops at a single draw, since fewer draws are impossible.
     *
     * @return {@code population} itself when no pruning is needed, otherwise the pruned population
     */
    private ParticlePopulation pruneIfNeeded(Random rand, ParticlePopulation population, int maxNUniqueParticles) {
        double supportSize = population.expectedResampledSupportSize(this.options.nParticles);
        if (supportSize <= maxNUniqueParticles) {
            return population;
        }
        ParticlePopulation pruned;
        double nResampled = population.nParticlesTokens();
        int nDrawn;
        do {
            nDrawn = 1 + (int) nResampled;
            pruned = population.prunePopulation(rand, nDrawn);
            supportSize = pruned.expectedResampledSupportSize(this.options.nParticles);
            nResampled = this.options.populationShrinkFactor * nResampled;
            // With a single draw the pruned population has one particle; repeating cannot shrink it
            // further, so stop instead of looping forever when the bound is unreachable.
        } while (supportSize > maxNUniqueParticles && nDrawn > 1);
        // Report the draw count actually used (not the already-shrunk next candidate).
        this.logs("Current population pruned to " + nDrawn + " (" + pruned.nParticlesTokens() + " unique particles)");
        return pruned;
    }

    /**
     * Proposes one child per parent without materializing it.
     *
     * <p>For each parent a seed is drawn from {@code rand} (in list order, before any parallel
     * work), and {@code kernel.peekNext} is called with a {@code Random} seeded by it to obtain
     * the child's log weight. The seed and parent are stored so the same child can later be
     * regenerated with {@code kernel.next}.
     *
     * @param rand source of per-particle seeds
     * @param parentParticles parent states, one child is proposed for each
     * @return the new population (exp-normalized weights, unnormalized log weights kept as
     *     consistency checks) paired with the log of the sum of the unnormalized weights
     */
    public Pair<ParticlePopulation, Double> extend(Random rand, final List<T> parentParticles) {
        final int size = parentParticles.size();
        final long[] seeds = new long[size];
        for (int i = 0; i < size; ++i) {
            seeds[i] = rand.nextLong();
        }
        final double[] logWeights = new double[size];
        this.track("Extension.. (" + size + " particles over " + this.options.nThreads + " threads)");
        Parallelizer<Integer> parallelizer = new Parallelizer<Integer>(this.options.nThreads);
        parallelizer.setPrimaryThread();
        parallelizer.process(CollUtils.ints(size), new Parallelizer.Processor<Integer>() {

            @Override
            public void process(Integer x, int _i, int _n, boolean log) {
                int index = x;
                T parent = parentParticles.get(index);
                // Each worker writes only its own slot, so no locking is needed here.
                logWeights[index] = LazyParticleFilter.this.kernel.peekNext(new Random(seeds[index]), parent);
            }
        });
        this.end_track();
        double logSum = Double.NEGATIVE_INFINITY;
        for (double logWeight : logWeights) {
            logSum = SloppyMath.logAdd(logSum, logWeight);
        }
        // Keep the unnormalized values: they are compared against kernel.next() on materialization.
        double[] unnormalized = logWeights.clone();
        NumUtils.expNormalize(logWeights);
        Object[] currents = new Object[size];
        ParticlePopulation newPop = new ParticlePopulation(logWeights, parentParticles.toArray(), seeds, currents, unnormalized);
        return Pair.makePair(newPop, logSum);
    }

    /**
     * Runs the filter for {@code kernel.nIterationsLeft(kernel.getInitial())} generations and
     * feeds the final population to {@code particleProcessors}.
     *
     * <p>Each generation draws {@code options.nParticles} parents from the previous population,
     * extends them, and prunes to {@code options.maxNUniqueParticles}. If
     * {@code options.resampleLastRound} is set, the last population is additionally pruned to
     * {@code options.finalMaxNUniqueParticles} before processing.
     *
     * @param particleProcessors receivers of every final (state, normalized weight) pair
     * @return the sum over generations of {@code log(sum of unnormalized weights) - log(nParticles)}
     * @throws IllegalStateException if the kernel reports no iterations from the initial state
     */
    public double sample(ParticleFilter.ParticleProcessor<T>... particleProcessors) {
        this.track("Sampling PF");
        Random curRandom;
        synchronized (this) {
            // Advance the configured generator exactly once; all further randomness comes from here.
            curRandom = new Random(this.options.rand.nextLong());
        }
        Distribution<T> previousPopulation = this.initialPopulation;
        ParticlePopulation finalPopulation = null;
        double zHat = 0.0;
        int nIterations = this.kernel.nIterationsLeft(this.kernel.getInitial());
        for (int t = 0; t < nIterations; ++t) {
            this.track("Generation " + t + "/" + nIterations);
            Pair<ParticlePopulation, Double> resultPair = this.extend(curRandom, previousPopulation.sampleNTimes(curRandom, this.options.nParticles));
            ParticlePopulation newPopulation = resultPair.getFirst();
            if (this.options.verbose) {
                this.logs("maxNormalizedWeight=" + newPopulation.maxNormalizedWeight());
            }
            newPopulation = this.pruneIfNeeded(curRandom, newPopulation, this.options.maxNUniqueParticles);
            if (this.options.verbose) {
                this.logs("expectedResampledSupportSize=" + newPopulation.expectedResampledSupportSize(this.options.nParticles));
            }
            zHat += resultPair.getSecond() - Math.log(this.options.nParticles);
            previousPopulation = newPopulation;
            finalPopulation = newPopulation;
            this.end_track();
        }
        this.end_track();
        if (finalPopulation == null) {
            throw new IllegalStateException("Kernel reports " + nIterations + " iterations from the initial state; nothing to sample");
        }
        if (this.options.resampleLastRound) {
            finalPopulation = this.pruneIfNeeded(curRandom, finalPopulation, this.options.finalMaxNUniqueParticles);
        }
        finalPopulation.process(particleProcessors);
        return zHat;
    }

    /** Logs {@code s} when {@code options.verbose} is set. */
    private void logs(Object s) {
        if (this.options.verbose) {
            LogInfo.logs(s);
        }
    }

    /** Opens a log track when {@code options.verbose} is set. */
    private void track(Object s) {
        if (this.options.verbose) {
            LogInfo.track(s);
        }
    }

    /** Closes a log track when {@code options.verbose} is set. */
    private void end_track() {
        if (this.options.verbose) {
            LogInfo.end_track();
        }
    }

    /**
     * A weighted population whose particles are materialized on demand.
     *
     * <p>Slot {@code i} holds either a (parent, seed) pair from which the particle can be
     * regenerated with {@code kernel.next}, or, once materialized, the particle itself in
     * {@code currents[i]} (the parent is then cleared and the seed set to {@code -1}).
     * {@code checks[i]} is the log weight reported by {@code peekNext}; materialization fails if
     * {@code next} reports a different weight.
     */
    public class ParticlePopulation
    implements Distribution<T> {
        private final double[] normalizedWeights;
        private final Object[] parents;
        private final Object[] currents;
        private final long[] randomSeeds;
        private final double[] checks;

        private ParticlePopulation(double[] normalizedWeights, Object[] parents, long[] randomSeeds, Object[] currents, double[] checks) {
            int size = normalizedWeights.length;
            if (parents.length != size || randomSeeds.length != size || currents.length != size || checks.length != size) {
                throw new IllegalArgumentException("Inconsistent population arrays: weights=" + size
                        + ", parents=" + parents.length + ", seeds=" + randomSeeds.length
                        + ", currents=" + currents.length + ", checks=" + checks.length);
            }
            this.normalizedWeights = normalizedWeights;
            MathUtils.checkIsProb(normalizedWeights);
            this.checks = checks;
            this.parents = parents;
            this.currents = currents;
            this.randomSeeds = randomSeeds;
        }

        /** Returns the largest normalized weight in the population. */
        public double maxNormalizedWeight() {
            return ArrayUtils.max(this.normalizedWeights);
        }

        /**
         * Draws {@code nResampling} multinomial samples and returns a population containing each
         * distinct drawn particle once, weighted by its normalized draw count. Lazy slots stay lazy.
         */
        public ParticlePopulation prunePopulation(Random r, int nResampling) {
            Counter<Integer> indices = Sampling.efficientMultinomialSampling(r, this.normalizedWeights, nResampling);
            Set<Integer> kept = indices.keySet();
            int nKept = kept.size();
            double[] newWeights = new double[nKept];
            Object[] newParents = new Object[nKept];
            double[] newChecks = new double[nKept];
            long[] newSeeds = new long[nKept];
            Object[] newCurrents = new Object[nKept];
            int i = 0;
            // Iterate the key set itself so the number of entries matches the array sizes exactly.
            for (int index : kept) {
                newWeights[i] = indices.getCount(index);
                newParents[i] = this.parents[index];
                newSeeds[i] = this.randomSeeds[index];
                newCurrents[i] = this.currents[index];
                newChecks[i] = this.checks[index];
                if (newParents[i] == null && newCurrents[i] == null) {
                    throw new IllegalStateException("Particle " + index + " has neither a parent nor a materialized state");
                }
                ++i;
            }
            NumUtils.normalize(newWeights);
            return new ParticlePopulation(newWeights, newParents, newSeeds, newCurrents, newChecks);
        }

        /**
         * Materializes every particle and passes (state, normalized weight) to each processor.
         *
         * <p>Particles are handled in batches of {@code max(1, options.processBatchSize)} indices;
         * after a particle is processed its state is cleared, so this method consumes the
         * population. Processing within a batch uses {@code options.nThreads} threads only when
         * {@code options.parallelizeFinalParticleProcessing} is set.
         */
        public void process(final ParticleFilter.ParticleProcessor<T>[] particleProcessors) {
            if (particleProcessors.length == 0) {
                return;
            }
            int nTokens = this.nParticlesTokens();
            int batchSize = Math.max(1, LazyParticleFilter.this.options.processBatchSize);
            int nThreads = LazyParticleFilter.this.options.parallelizeFinalParticleProcessing ? LazyParticleFilter.this.options.nThreads : 1;
            LazyParticleFilter.this.track("Processing particles");
            int start = 0;
            while (start < nTokens) {
                // Overflow-safe: never computes start + batchSize directly.
                int end = start + Math.min(batchSize, nTokens - start);
                List<Integer> batch = new ArrayList<Integer>(end - start);
                for (int index = start; index < end; ++index) {
                    batch.add(index);
                }
                start = end;
                this.unLazyParticles(batch);
                Parallelizer<Integer> parallelizer = new Parallelizer<Integer>(nThreads);
                parallelizer.setPrimaryThread();
                parallelizer.process(batch, new Parallelizer.Processor<Integer>() {

                    @Override
                    public void process(Integer x, int _i, int _n, boolean log) {
                        int index = x;
                        @SuppressWarnings("unchecked") // slots only ever hold T values produced by the kernel
                        T state = (T) ParticlePopulation.this.currents[index];
                        double w = ParticlePopulation.this.normalizedWeights[index];
                        if (log) {
                            LazyParticleFilter.this.logs("" + _i + "/" + _n);
                        }
                        for (ParticleFilter.ParticleProcessor<T> processor : particleProcessors) {
                            processor.process(state, w);
                        }
                        // Release the state once every processor has seen it.
                        ParticlePopulation.this.currents[index] = null;
                    }
                });
            }
            LazyParticleFilter.this.end_track();
        }

        /**
         * Materializes the not-yet-materialized particles among {@code indices} by replaying
         * {@code kernel.next} with the stored parent and seed.
         *
         * @throws RuntimeException if the replayed log weight differs from the one recorded by
         *     {@code peekNext}
         */
        private void unLazyParticles(Iterable<Integer> indices) {
            LazyParticleFilter.this.track("Unlazy...");
            List<Integer> pending = new ArrayList<Integer>();
            for (Integer index : indices) {
                if (this.currents[index] == null) {
                    pending.add(index);
                }
            }
            Parallelizer<Integer> parallelizer = new Parallelizer<Integer>(LazyParticleFilter.this.options.nThreads);
            parallelizer.setPrimaryThread();
            parallelizer.process(pending, new Parallelizer.Processor<Integer>() {

                @Override
                public void process(Integer x, int _i, int _n, boolean log) {
                    int index = x;
                    @SuppressWarnings("unchecked") // parents are the T values passed to extend()
                    T parent = (T) ParticlePopulation.this.parents[index];
                    Random rand = new Random(ParticlePopulation.this.randomSeeds[index]);
                    Pair<T, Double> pair = LazyParticleFilter.this.kernel.next(rand, parent);
                    double expected = ParticlePopulation.this.checks[index];
                    if (!MathUtils.close(pair.getSecond(), expected)) {
                        throw new RuntimeException("" + pair.getSecond() + "!=" + expected + "\n" + pair.getFirst());
                    }
                    // Drop the lazy representation; the slot now holds the materialized particle.
                    ParticlePopulation.this.parents[index] = null;
                    ParticlePopulation.this.randomSeeds[index] = -1L;
                    ParticlePopulation.this.currents[index] = pair.getFirst();
                }
            });
            LazyParticleFilter.this.end_track();
        }

        /** Returns the number of slots (distinct particle entries) in the population. */
        public int nParticlesTokens() {
            return this.normalizedWeights.length;
        }

        /** Returns the expected number of distinct particles when drawing {@code nResampling} times. */
        public double expectedResampledSupportSize(int nResampling) {
            return Sampling.expectedNumberDistinctParticles(this.normalizedWeights, nResampling);
        }

        /**
         * Draws {@code n} particles by multinomial sampling, materializing the drawn slots first;
         * each drawn particle appears in the result as many times as it was drawn.
         */
        @Override
        public List<T> sampleNTimes(Random r, int n) {
            Counter<Integer> indices = Sampling.efficientMultinomialSampling(r, this.normalizedWeights, n);
            this.unLazyParticles(indices.keySet());
            List<T> result = new ArrayList<T>();
            for (Integer index : indices.keySet()) {
                @SuppressWarnings("unchecked") // materialized slots hold T values from kernel.next
                T current = (T) this.currents[index];
                double count = indices.getCount(index);
                for (int i = 0; i < count; ++i) {
                    result.add(current);
                }
            }
            return result;
        }
    }

    /** A source of particle samples. */
    public static interface Distribution<T> {
        /** Returns {@code n} draws (with repetition) using {@code rand}. */
        public List<T> sampleNTimes(Random rand, int n);
    }

    /**
     * A particle kernel that can report the log weight of its next proposal without the caller
     * keeping the proposed state.
     */
    public static interface LazyParticleKernel<S>
    extends ParticleKernel<S> {
        /**
         * Returns the log weight that {@code next(rand, current)} would report when given a
         * {@code Random} in the same state; the filter enforces this at materialization time.
         */
        public double peekNext(Random rand, S current);
    }

    /**
     * Adapts an ordinary {@link ParticleKernel} to {@link LazyParticleKernel}. {@code peekNext}
     * runs the full {@code next} step and discards the state, so it saves memory, not computation.
     */
    public static class Eager2LazyAdaptor<S>
    implements LazyParticleKernel<S> {
        private final ParticleKernel<S> kernel;

        public Eager2LazyAdaptor(ParticleKernel<S> kernel) {
            this.kernel = kernel;
        }

        @Override
        public Pair<S, Double> next(Random rand, S current) {
            return this.kernel.next(rand, current);
        }

        @Override
        public int nIterationsLeft(S partialState) {
            return this.kernel.nIterationsLeft(partialState);
        }

        @Override
        public S getInitial() {
            return this.kernel.getInitial();
        }

        @Override
        public double peekNext(Random rand, S current) {
            return this.kernel.next(rand, current).getSecond();
        }
    }

    /** Command-line configurable settings for {@link LazyParticleFilter}. */
    public static class ParticleFilterOptions {
        @Option
        public boolean verbose = false;
        @Option
        public int nParticles = 100000;
        @Option
        public int maxNUniqueParticles = 1000;
        @Option
        public Random rand = new Random(1L);
        @Option
        public boolean resampleLastRound = true;
        @Option
        public int nThreads = 8;
        @Option
        public int processBatchSize = 1000;
        @Option
        public int finalMaxNUniqueParticles = Integer.MAX_VALUE;
        @Option
        public boolean parallelizeFinalParticleProcessing = false;
        @Option
        public double populationShrinkFactor = 0.9;

        /**
         * Validates the settings.
         *
         * @throws IllegalArgumentException if {@code populationShrinkFactor} is not in [0, 1)
         *     (negative or NaN values would produce non-positive draw counts), if
         *     {@code maxNUniqueParticles <= 1}, or if {@code nParticles < 1}
         */
        public void check() {
            if (!(this.populationShrinkFactor >= 0.0 && this.populationShrinkFactor < 1.0)) {
                throw new IllegalArgumentException("populationShrinkFactor must be in [0, 1): " + this.populationShrinkFactor);
            }
            if (this.maxNUniqueParticles <= 1) {
                throw new IllegalArgumentException("maxNUniqueParticles must be > 1: " + this.maxNUniqueParticles);
            }
            if (this.nParticles < 1) {
                throw new IllegalArgumentException("nParticles must be >= 1: " + this.nParticles);
            }
        }
    }
}

