/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.Parallelizer;
import fig.prob.Multinomial;
import goblin.BayesRiskMinimizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import monaco.process.ProcessSchedule;
import monaco.process.ProcessScheduleContext;
import monaco.process.ResampleStatus;
import nuts.lang.ArrayUtils;
import nuts.math.Fct;
import nuts.math.Id;
import nuts.math.MeasureZeroException;
import nuts.math.Sampling;
import nuts.maxent.SloppyMath;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Hasher;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleKernel;

public final class ParticleFilter<S> {
    /** When true, progress is reported through {@code LogInfo} tracks and log lines. */
    @Option
    public boolean verbose = false;
    /** Number of particles; sizes the sample list, the weight arrays and the seed array. */
    @Option
    public int N = 100;
    /** Random source used for resampling and for drawing the per-particle seeds. */
    @Option
    public Random rand = new Random(1L);
    /**
     * Whether the final generation is resampled and packed before it is handed to the
     * processor. Only consulted in {@code sample} on the last iteration, and only when a
     * schedule has been installed.
     */
    @Option
    public boolean resampleLastRound = true;
    /** Thread count handed to the {@code Parallelizer} that propagates the particles. */
    @Option
    public int nThreads = 1;
    /** Decides, on every non-final generation, whether the particles are resampled. */
    @Option
    public ResamplingStrategy resamplingStrategy = ResamplingStrategy.ALWAYS;
    /** Effective-sample-size threshold, expressed as a fraction of the number of weights. */
    public static final double essRatioThreshold = 0.5;
    /** Conditioned path; when non-null, particle 0 is overwritten with element {@code t} at generation {@code t}. */
    private List<S> conditional = null;
    /** Values written into particle 0's log-weight slot alongside {@link #conditional}; same length as that list. */
    private double[] conditionalUnnormWeights = null;
    /** Optional schedule controlling intermediate processing and monitoring; null means "process on the last generation only". */
    private ProcessSchedule schedule = null;
    /** Per-particle seeds, redrawn from {@link #rand} at the start of every propagation step. */
    private long[] seeds;
    /** Current particle population; an entry is null when the kernel returned no successor for it. */
    private List<S> samples;
    /** Unnormalised log-weights of the current particles; reset to zeros after each resampling. */
    private double[] logWeights;
    /** Running sum of the per-generation log normaliser contributions; reset by {@code init}. */
    private double lognorm = 0.0;

    /**
     * Installs the schedule consulted by {@code sample} to decide when particles are
     * processed and to receive per-generation monitoring callbacks.
     *
     * @param schedule the schedule to use, or null to process only on the last generation
     */
    public void setProcessSchedule(final ProcessSchedule schedule) {
        this.schedule = schedule;
    }

    /**
     * Switches the filter to conditional mode: at generation {@code t}, particle 0 is
     * replaced by {@code conditional.get(t)} and its log-weight slot is set to
     * {@code conditionalUnnormWeights[t]}.
     *
     * @param conditional the path to condition on, one state per generation
     * @param conditionalUnnormWeights one value per element of {@code conditional}
     * @throws IllegalArgumentException if the two arguments differ in length
     */
    public void setConditional(List<S> conditional, double[] conditionalUnnormWeights) {
        if (conditional.size() != conditionalUnnormWeights.length) {
            // Still a RuntimeException, so existing callers are unaffected, but now it says what went wrong.
            throw new IllegalArgumentException("conditional path has " + conditional.size()
                    + " states but " + conditionalUnnormWeights.length + " weights were supplied");
        }
        this.conditional = conditional;
        this.conditionalUnnormWeights = conditionalUnnormWeights;
    }

    /** Leaves conditional mode by dropping the conditioned path and its weights. */
    public void setUnconditional() {
        conditionalUnnormWeights = null;
        conditional = null;
    }

    /**
     * @return true when a conditioned path is currently installed
     */
    public boolean isConditional() {
        return conditional != null;
    }

    /**
     * Returns the reciprocal of the sum of the squared entries of {@code ws}.
     *
     * @param ws the weights (callers in this class pass normalised weights)
     * @return {@code 1 / sum(w_i^2)}; positive infinity when the sum of squares is zero
     */
    public static double ess(double[] ws) {
        double squaredMass = 0.0;
        for (int i = 0; i < ws.length; i++) {
            squaredMass += ws[i] * ws[i];
        }
        return 1.0 / squaredMass;
    }

    /**
     * Convenience entry point: runs a filter that resamples at every generation and does
     * not request the last-round resample, feeding the particles to {@code processor}.
     *
     * @param kernel the proposal kernel that defines the initial state and the transitions
     * @param processor receives the particles and their weights
     * @param N number of particles
     * @param rand random source for the run
     */
    public static <S> void bootstrapFilter(ParticleKernel<S> kernel, ParticleProcessor<S> processor, int N, Random rand) {
        final ParticleFilter<S> filter = new ParticleFilter<S>();
        filter.resamplingStrategy = ResamplingStrategy.ALWAYS;
        filter.resampleLastRound = false;
        filter.rand = rand;
        filter.N = N;
        filter.sample(kernel, processor);
    }

    /**
     * @return the live particle list (not a copy); entries may be null
     */
    public List<S> getSamples() {
        return samples;
    }

    /**
     * @return the live array of unnormalised log-weights (not a copy)
     */
    public double[] getLogWeights() {
        return logWeights;
    }

    /**
     * Advances every particle by one kernel step and updates the weights.
     *
     * <p>For each particle a fresh {@link Random} is built from a seed drawn up front, the
     * kernel proposes a successor, and the returned log-weight increment is added to the
     * particle's entry in {@code logWeights}. The log of the sum over particles of
     * (previous normalised weight x incremental weight) is then added to {@code lognorm}.
     *
     * <p>A particle for which the kernel returns null is stored as null with log-weight
     * negative infinity, and contributes nothing to the normaliser update.
     *
     * @param kernel the proposal kernel
     * @param t index of the generation being produced; also indexes the conditioned path
     */
    private void propagateAndComputeWeights(final ParticleKernel<S> kernel, final int t) {
        if (this.verbose) {
            LogInfo.track((Object)"Processing...", false);
        }
        // Normalised weights carried over from the previous generation.
        final double[] normalizedWeights0 = this.logWeights.clone();
        NumUtils.expNormalize(normalizedWeights0);
        // Per-particle log of (previous normalised weight * incremental weight).
        final double[] logWeights2 = new double[this.N];
        this.seeds = Sampling.createSeeds(this.N, this.rand);
        Parallelizer<Integer> parallelizer = new Parallelizer<Integer>(this.nThreads);
        parallelizer.setPrimaryThread();
        parallelizer.process(CollUtils.ints(this.N), new Parallelizer.Processor<Integer>(){

            @Override
            public void process(Integer x, int _i, int _n, boolean log) {
                final int idx = x;
                if (log && (idx + 1) % 5 == 0 && ParticleFilter.this.verbose) {
                    LogInfo.logs("Particle " + (idx + 1) + "/" + ParticleFilter.this.N);
                }
                Random particleRand = new Random(ParticleFilter.this.seeds[idx]);
                if (idx == 0 && ParticleFilter.this.isConditional()) {
                    // Conditional mode: slot 0 always carries the conditioned path.
                    ParticleFilter.this.samples.set(idx, ParticleFilter.this.conditional.get(t));
                    ParticleFilter.this.logWeights[idx] = ParticleFilter.this.conditionalUnnormWeights[t];
                    // NOTE(review): logWeights2[idx] is left at 0.0 here, so this slot adds
                    // exp(0) to the normaliser sum - confirm whether that is intended.
                    return;
                }
                Pair<S, Double> current = kernel.next(particleRand, ParticleFilter.this.samples.get(idx));
                if (current == null) {
                    ParticleFilter.this.samples.set(idx, null);
                    ParticleFilter.this.logWeights[idx] = Double.NEGATIVE_INFINITY;
                    // A dead particle has zero weight; without this its slot would stay at
                    // 0.0 and add exp(0) = 1 to the normaliser sum.
                    logWeights2[idx] = Double.NEGATIVE_INFINITY;
                    return;
                }
                final double logIncrement = current.getSecond();
                ParticleFilter.this.samples.set(idx, current.getFirst());
                ParticleFilter.this.logWeights[idx] += logIncrement;
                logWeights2[idx] = Math.log(normalizedWeights0[idx]) + logIncrement;
            }
        });
        this.lognorm += SloppyMath.logAdd(logWeights2);
        if (this.verbose) {
            LogInfo.end_track();
        }
    }

    /**
     * @return the log normaliser accumulated over the generations run since the last {@code init}
     */
    public double estimateNormalizer() {
        return lognorm;
    }

    /**
     * Resets the filter state for a new run: zero log normaliser, {@code N} references to
     * the kernel's initial state, and an all-zero log-weight array.
     *
     * @param kernel supplies the initial state shared by every particle
     */
    private void init(ParticleKernel<S> kernel) {
        this.lognorm = 0.0;
        this.samples = new ArrayList<S>(this.N);
        final S start = kernel.getInitial();
        int added = 0;
        while (added < this.N) {
            this.samples.add(start);
            added++;
        }
        this.logWeights = new double[this.N];
    }

    /**
     * Hands the current particles to {@code processor} when processing is due.
     *
     * <p>Without a schedule, processing is due only on the last generation; with one, the
     * schedule decides. Null particles are skipped.
     *
     * @param t index of the current generation
     * @param normalizedWeights weight passed along with each particle
     * @param processor receives each non-null particle and its weight
     * @param T total number of generations
     */
    private void newProcess(int t, double[] normalizedWeights, ParticleProcessor<S> processor, int T) {
        final boolean finalRound = t == T - 1;
        if (this.schedule == null) {
            if (!finalRound) {
                return;
            }
        } else if (!this.schedule.shouldProcess(new ProcessScheduleContext(t, finalRound, ResampleStatus.NA))) {
            return;
        }
        if (this.verbose) {
            LogInfo.track("Processing particles");
        }
        final int total = normalizedWeights.length;
        for (int idx = 0; idx < total; idx++) {
            if (this.verbose) {
                LogInfo.logs("Particle " + (idx + 1) + "/" + total);
            }
            final S particle = this.samples.get(idx);
            if (particle != null) {
                processor.process(particle, normalizedWeights[idx]);
            }
        }
        if (this.verbose) {
            LogInfo.end_track();
        }
    }

    /**
     * Runs the particle filter from the kernel's initial state until the kernel reports
     * no iterations left.
     *
     * <p>Each generation: propagate and reweight, normalise, optionally process, and (on
     * every generation but the last) resample when a particle died or the strategy asks
     * for it. On the last generation with a schedule installed, the particles are
     * optionally resampled and packed, processed once more, and released. The filter is
     * switched back to unconditional mode when the run completes.
     *
     * @param kernel the proposal kernel
     * @param processor receives particles and weights
     * @throws RuntimeException if a conditioned path is set whose length differs from the
     *         number of generations
     */
    public void sample(ParticleKernel<S> kernel, ParticleProcessor<S> processor) {
        this.init(kernel);
        final int T = kernel.nIterationsLeft(kernel.getInitial());
        if (this.isConditional() && this.conditional.size() != T) {
            throw new RuntimeException();
        }
        for (int t = 0; t < T; ++t) {
            final boolean lastRound = t == T - 1;
            if (this.verbose) {
                LogInfo.track((Object)("Particle generation " + (t + 1) + "/" + T), true);
            }
            this.propagateAndComputeWeights(kernel, t);
            double[] normalizedWeights = this.logWeights.clone();
            NumUtils.expNormalize(normalizedWeights);
            if (this.verbose) {
                LogInfo.logs("LargestNormalizedWeights=" + ArrayUtils.max(normalizedWeights));
                LogInfo.logs("RelativeESS=" + ParticleFilter.ess(normalizedWeights) / (double)normalizedWeights.length);
            }
            this.newProcess(t, normalizedWeights, processor, T);
            // Never resample here on the last generation; dead particles force a resample.
            if (!lastRound && (this.hasNulls(this.samples) || this.resamplingStrategy.needResample(normalizedWeights))) {
                this.samples = this.resample(this.samples, normalizedWeights, this.rand);
                this.logWeights = new double[this.N];
            }
            if (this.verbose) {
                LogInfo.end_track();
            }
            if (lastRound && this.schedule != null) {
                if (this.resampleLastRound) {
                    Pair<List<S>, double[]> packed = this.resampleAndPack(this.samples, normalizedWeights, this.rand);
                    this.samples = packed.getFirst();
                    normalizedWeights = packed.getSecond();
                }
                if (this.verbose) {
                    LogInfo.track("Processing particles");
                }
                final int total = normalizedWeights.length;
                for (int idx = 0; idx < total; idx++) {
                    if (this.verbose) {
                        LogInfo.logs("Particle " + (idx + 1) + "/" + total);
                    }
                    if (this.samples.get(idx) != null) {
                        processor.process(this.samples.get(idx), normalizedWeights[idx]);
                        // Drop the reference once the particle has been consumed.
                        this.samples.set(idx, null);
                    }
                }
                if (this.verbose) {
                    LogInfo.end_track();
                }
            }
            if (this.schedule != null) {
                this.schedule.monitor(new ProcessScheduleContext(t, lastRound, ResampleStatus.NA));
            }
        }
        this.setUnconditional();
    }

    /**
     * @param samples the particle list to scan
     * @return true if at least one entry is null
     */
    private boolean hasNulls(List<S> samples) {
        boolean foundNull = false;
        for (S particle : samples) {
            if (particle == null) {
                foundNull = true;
                break;
            }
        }
        return foundNull;
    }

    /**
     * Multinomially resamples {@code w.length} particles and returns them in packed form:
     * each distinct selected particle appears once, paired with the number of times it
     * was drawn.
     *
     * <p>{@code w} is normalised in place before sampling.
     *
     * @param list the particles to resample from
     * @param w the particle weights; modified in place by normalisation
     * @param rand random source for the draw
     * @return the distinct selected particles and, index-aligned, their draw counts
     * @throws MeasureZeroException if {@code w} cannot be normalised
     * @throws IllegalArgumentException if {@code list} and {@code w} differ in length
     */
    private Pair<List<S>, double[]> resampleAndPack(List<S> list, double[] w, Random rand) {
        if (!NumUtils.normalize(w)) {
            throw new MeasureZeroException();
        }
        if (list.size() != w.length) {
            throw new IllegalArgumentException("cannot resample " + list.size()
                    + " particles with " + w.length + " weights");
        }
        Counter<Integer> packed = Sampling.efficientMultinomialSampling(rand, w, w.length);
        double[] resultWeight = new double[packed.size()];
        // Declared as List so the returned Pair has exactly the declared type.
        List<S> resultList = new ArrayList<S>(packed.size());
        int i = 0;
        for (Integer itemIdx : packed.keySet()) {
            resultList.add(list.get(itemIdx));
            resultWeight[i++] = packed.getCount(itemIdx);
        }
        return Pair.makePair(resultList, resultWeight);
    }

    /**
     * Multinomially resamples {@code w.length} particles, returning each selected
     * particle as many times as it was drawn.
     *
     * <p>{@code w} is normalised in place before sampling.
     *
     * @param list the particles to resample from
     * @param w the particle weights; modified in place by normalisation
     * @param rand random source for the draw
     * @return the resampled population
     * @throws MeasureZeroException if {@code w} cannot be normalised
     * @throws RuntimeException if {@code list} and {@code w} differ in length
     */
    private List<S> resample(List<S> list, double[] w, Random rand) {
        final boolean normalizable = NumUtils.normalize(w);
        if (!normalizable) {
            throw new MeasureZeroException();
        }
        if (w.length != list.size()) {
            throw new RuntimeException();
        }
        final Counter<Integer> draws = Sampling.efficientMultinomialSampling(rand, w, w.length);
        final ArrayList<S> population = new ArrayList<S>(list.size());
        for (Integer drawn : draws.keySet()) {
            final S particle = list.get(drawn);
            // The count is re-read on every pass, exactly as the loop condition requires.
            for (int copies = 0; (double)copies < draws.getCount(drawn); copies++) {
                population.add(particle);
            }
        }
        return population;
    }

    /**
     * Processor that remembers the single state seen with the largest weight.
     */
    public static class MAPDecoder<D>
    implements ParticleProcessor<D> {
        private double max = Double.NEGATIVE_INFINITY;
        private D argmax = null;

        /**
         * Records {@code state} if its weight strictly exceeds the best seen so far.
         */
        @Override
        public void process(D state, double weight) {
            // Written as a negated '>' so that a NaN weight is ignored.
            if (!(weight > max)) {
                return;
            }
            max = weight;
            argmax = state;
        }

        /**
         * @return the highest-weight state seen, or null if none has been recorded
         */
        public D map() {
            return argmax;
        }
    }

    /**
     * Processor that folds every particle it sees into a single {@code Hasher}, so that
     * two runs can be compared through one integer.
     */
    public static class PCSHash
    implements ParticleProcessor<PartialCoalescentState> {
        private Hasher hasher = new Hasher();

        /**
         * Adds the weight, the state's log-likelihood, its top height and the Lisp-string
         * form of its unlabeled tree to the hasher, in that order.
         */
        @Override
        public void process(PartialCoalescentState state, double weight) {
            // NOTE(review): relies on Hasher.add accumulating into this.hasher - confirm against Hasher.
            this.hasher.add(weight).add(state.logLikelihood()).add(state.topHeight()).add(state.getUnlabeledArbre().deepToLispString());
        }

        /**
         * @return the hash code of the underlying hasher
         */
        public int getHash() {
            return this.hasher.hashCode();
        }
    }

    /**
     * Processor that projects each particle through a function and accumulates the
     * particle weights per projected value in a {@code Counter}.
     *
     * @param <D> particle type
     * @param <I> type of the projected value used as the counter key
     */
    public static class ParticleMapperProcessor<D, I>
    implements ParticleProcessor<D> {
        private final Fct<D, I> prj;
        private final Counter<I> counter = new Counter<I>();

        /**
         * @return a processor that counts the particles themselves (identity projection)
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        public static <S> ParticleMapperProcessor<S, S> saveParticlesProcessor() {
            // Id is used raw here, as in the original; the identity projection maps S to S.
            return new ParticleMapperProcessor(new Id());
        }

        /**
         * @return {@link #saveParticlesProcessor()} specialised to coalescent states
         */
        public static ParticleMapperProcessor<PartialCoalescentState, PartialCoalescentState> saveCoalescentParticlesProcessor() {
            return ParticleMapperProcessor.saveParticlesProcessor();
        }

        /**
         * @param prj projection applied to every particle before counting
         */
        public ParticleMapperProcessor(Fct<D, I> prj) {
            this.prj = prj;
        }

        /**
         * Adds {@code weight} to the count of the projection of {@code state}.
         */
        @Override
        public void process(D state, double weight) {
            this.counter.incrementCount(this.prj.evalAt(state), weight);
        }

        /**
         * @param loss loss function handed to the {@code BayesRiskMinimizer}
         * @return the minimiser found by {@code BayesRiskMinimizer.findMin} over the counter
         */
        public I centroid(BayesRiskMinimizer.LossFct<I> loss) {
            return new BayesRiskMinimizer<I>(loss).findMin(this.counter);
        }

        /**
         * @return the accumulated counts, each followed by a comma, in counter iteration order
         */
        public String printweights() {
            // StringBuilder avoids re-copying the growing string on every iteration.
            StringBuilder out = new StringBuilder();
            for (I key : this.counter) {
                out.append(this.counter.getCount(key)).append(",");
            }
            return out.toString();
        }

        /**
         * @return the key reported by {@code Counter.argMax}
         */
        public I map() {
            return this.counter.argMax();
        }

        /**
         * Draws one key using the accumulated counts as the probability vector passed to
         * {@code Multinomial.sample}.
         *
         * @param rand random source for the draw
         * @return the sampled key
         */
        public I sample(Random rand) {
            final int size = this.counter.size();
            double[] probs = new double[size];
            List<I> states = new ArrayList<I>(size);
            int i = 0;
            for (I key : this.counter.keySet()) {
                states.add(key);
                probs[i++] = this.counter.getCount(key);
            }
            return states.get(Multinomial.sample(rand, probs));
        }

        /**
         * @return the live counter (not a copy)
         */
        public Counter<I> getCounter() {
            return this.counter;
        }
    }

    /**
     * Processor that forwards every particle to a list of delegate processors, in list order.
     */
    public static class ForkedProcessor<S>
    implements ParticleProcessor<S> {
        /** The delegates; both constructors copy their argument into a new list. */
        public List<ParticleProcessor<S>> processors;

        /**
         * @param items the delegates, copied into a new list
         */
        public ForkedProcessor(ParticleProcessor<S> ... items) {
            final List<ParticleProcessor<S>> given = Arrays.asList(items);
            this.processors = new ArrayList<ParticleProcessor<S>>(given);
        }

        /**
         * @param items the delegates, copied via {@code CollUtils.list}
         */
        public ForkedProcessor(Collection<ParticleProcessor<S>> items) {
            this.processors = CollUtils.list(items);
        }

        /**
         * Passes {@code state} and {@code weight} unchanged to each delegate.
         */
        @Override
        public void process(S state, double weight) {
            for (ParticleProcessor<S> delegate : processors) {
                delegate.process(state, weight);
            }
        }
    }

    /**
     * Processor that discards everything it is given.
     */
    public static class DoNothingProcessor<S>
    implements ParticleProcessor<S> {
        @Override
        public void process(final S state, final double w) {
            // Intentionally empty.
        }
    }

    /**
     * Processor that keeps every particle and weight it receives, in arrival order.
     */
    public static class StoreProcessor<S>
    implements ParticleProcessor<S> {
        /** Stored particles; index-aligned with {@link #ws}. */
        public List<S> particles = CollUtils.list();
        /** Stored weights; index-aligned with {@link #particles}. */
        public List<Double> ws = CollUtils.list();

        /**
         * Appends {@code state} and {@code weight} to the stored lists.
         */
        @Override
        public void process(S state, double weight) {
            this.particles.add(state);
            this.ws.add(weight);
        }

        /**
         * Draws one stored particle using {@code Sampling.sample} over the stored weights.
         *
         * @param rand random source for the draw
         * @return the particle at the sampled index
         */
        public S sample(Random rand) {
            int idx = Sampling.sample(rand, this.ws);
            return this.particles.get(idx);
        }

        /**
         * Returns the stored particle with the largest weight (the first one on ties).
         *
         * @return the highest-weight particle, or null when nothing has been stored or no
         *         stored weight exceeds negative infinity
         */
        public S argmax() {
            int best = -1;
            double bestWeight = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < this.ws.size(); ++i) {
                final double candidate = this.ws.get(i);
                if (candidate > bestWeight) {
                    bestWeight = candidate;
                    best = i;
                }
            }
            // Without this guard an empty store would call particles.get(-1).
            return best < 0 ? null : this.particles.get(best);
        }
    }

    /**
     * Callback that consumes particles produced by the filter.
     */
    public interface ParticleProcessor<S> {
        /**
         * @param state the particle
         * @param weight the weight the filter associates with it
         */
        void process(S state, double weight);
    }

    /**
     * Policies deciding whether the particle population is resampled after a generation.
     */
    public static enum ResamplingStrategy {
        /** Resample after every generation. */
        ALWAYS {
            @Override
            public boolean needResample(double[] w) {
                return true;
            }
        },
        /** Never resample on the strategy's account. */
        NEVER {
            @Override
            public boolean needResample(double[] w) {
                return false;
            }
        },
        /** Resample when {@code ess(weights)} falls below the threshold fraction of the weight count. */
        ESS {
            @Override
            public boolean needResample(double[] weights) {
                final double cutoff = ParticleFilter.essRatioThreshold * (double)weights.length;
                return ParticleFilter.ess(weights) < cutoff;
            }
        };

        /**
         * @param weights the normalised particle weights of the current generation
         * @return true if the population should be resampled
         */
        abstract boolean needResample(double[] weights);
    }
}

