/*
 * Decompiled with CFR 0.152.
 */
package smcsampler;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.Parallelizer;
import fig.prob.Multinomial;
import goblin.BayesRiskMinimizer;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import monaco.process.ProcessSchedule;
import monaco.process.ProcessScheduleContext;
import monaco.process.ResampleStatus;
import nuts.io.CSV;
import nuts.lang.ArrayUtils;
import nuts.math.Fct;
import nuts.math.Id;
import nuts.math.MeasureZeroException;
import nuts.math.Sampling;
import nuts.maxent.SloppyMath;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Hasher;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.PegasusSolver;
import pty.mcmc.UnrootedTreeState;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleKernel;
import smcsampler.AnnealingKernel;

public final class ParticleFilterSMCSampler<S> {
    @Option
    public boolean verbose = true;
    public int N = 100;
    @Option
    public Random rand = new Random(1L);
    @Option
    public boolean resampleLastRound = true;
    @Option
    public int nThreads = 1;
    @Option
    public ResamplingStrategy resamplingStrategy = ResamplingStrategy.ALWAYS;
    public static double essRatioThreshold = 0.5;
    private List<S> conditional = null;
    private double[] conditionalUnnormWeights = null;
    private ProcessSchedule schedule = null;
    private double ess = this.N;
    private double tempDiff;
    private double oldTempDiff;
    @Option
    public boolean adaptiveTempDiff = false;
    private boolean isLastIter = false;
    public double alpha = 0.9;
    public int adaptiveType = 0;
    public PrintWriter smcSamplerOut = null;
    private long[] seeds;
    private List<S> samples;
    private double[] logWeights;
    private double[] incrementalLogWeights;
    private double varLogZ = 0.0;
    private double lognorm = 0.0;

    public double getEss() {
        return this.ess;
    }

    public void setEss(double ess) {
        this.ess = ess;
    }

    public void setProcessSchedule(ProcessSchedule schedule) {
        this.schedule = schedule;
    }

    public void setConditional(List<S> conditional, double[] conditionalUnnormWeights) {
        if (conditional.size() != conditionalUnnormWeights.length) {
            throw new RuntimeException();
        }
        this.conditional = conditional;
        this.conditionalUnnormWeights = conditionalUnnormWeights;
    }

    public void setUnconditional() {
        this.conditional = null;
        this.conditionalUnnormWeights = null;
    }

    public boolean isConditional() {
        return this.conditional != null;
    }

    public static double ess(double[] ws) {
        double sumOfSqr = 0.0;
        for (double w : ws) {
            sumOfSqr += w * w;
        }
        return 1.0 / sumOfSqr;
    }

    public static <S> void bootstrapFilter(ParticleKernel<S> kernel, ParticleProcessor<S> processor, int N, Random rand) {
        ParticleFilterSMCSampler<S> PF = new ParticleFilterSMCSampler<S>();
        PF.N = N;
        PF.rand = rand;
        PF.resampleLastRound = false;
        PF.resamplingStrategy = ResamplingStrategy.ALWAYS;
        PF.sample(kernel, processor);
    }

    public List<S> getSamples() {
        return this.samples;
    }

    public double[] getLogWeights() {
        return this.logWeights;
    }

    private void propagateAndComputeWeights(final ParticleKernel<S> kernel, final int t) {
        if (this.verbose) {
            LogInfo.track((Object)"Processing...", (boolean)false);
        }
        final double[] normalizedWeights0 = (double[])this.logWeights.clone();
        NumUtils.expNormalize((double[])normalizedWeights0);
        final double[] logWeights2 = new double[this.N];
        this.seeds = Sampling.createSeeds((int)this.N, (Random)this.rand);
        Parallelizer parallelizer = new Parallelizer(this.nThreads);
        parallelizer.setPrimaryThread();
        parallelizer.process(CollUtils.ints((int)this.N), (Parallelizer.Processor)new Parallelizer.Processor<Integer>(){

            public void process(Integer x, int _i, int _n, boolean log) {
                if (log && (x + 1) % 5 == 0 && ParticleFilterSMCSampler.this.verbose) {
                    LogInfo.logs((Object)("Particle " + (x + 1) + "/" + ParticleFilterSMCSampler.this.N));
                }
                Random rand = new Random(ParticleFilterSMCSampler.this.seeds[x]);
                if (x == 0 && ParticleFilterSMCSampler.this.isConditional()) {
                    ParticleFilterSMCSampler.this.samples.set(x, ParticleFilterSMCSampler.this.conditional.get(t));
                    ((ParticleFilterSMCSampler)ParticleFilterSMCSampler.this).logWeights[x.intValue()] = ParticleFilterSMCSampler.this.conditionalUnnormWeights[t];
                } else {
                    Pair current = kernel.next(rand, ParticleFilterSMCSampler.this.samples.get(x));
                    if (current == null) {
                        ParticleFilterSMCSampler.this.samples.set(x, null);
                        ((ParticleFilterSMCSampler)ParticleFilterSMCSampler.this).logWeights[x.intValue()] = Double.NEGATIVE_INFINITY;
                    } else {
                        ParticleFilterSMCSampler.this.samples.set(x, current.getFirst());
                        ((ParticleFilterSMCSampler)ParticleFilterSMCSampler.this).incrementalLogWeights[x.intValue()] = (Double)current.getSecond();
                        double[] dArray = ParticleFilterSMCSampler.this.logWeights;
                        int n = x;
                        dArray[n] = dArray[n] + (Double)current.getSecond();
                        logWeights2[x.intValue()] = Math.log(normalizedWeights0[x]) + (Double)current.getSecond();
                    }
                }
            }
        });
        double logZRatio = SloppyMath.logAdd((double[])logWeights2);
        double varLogZRatio = 0.0;
        for (int k = 0; k < this.incrementalLogWeights.length; ++k) {
            double tmp = Math.exp(this.incrementalLogWeights[k] - logZRatio) - 1.0;
            varLogZRatio += normalizedWeights0[k] * tmp * tmp;
        }
        this.varLogZ += varLogZRatio;
        this.lognorm += logZRatio;
        if (this.verbose) {
            LogInfo.end_track();
        }
    }

    public double estimateNormalizer() {
        return this.lognorm;
    }

    public double estimateNormalizerVariance() {
        return this.varLogZ;
    }

    private void init(ParticleKernel<S> kernel) {
        this.lognorm = 0.0;
        this.samples = new ArrayList<S>(this.N);
        Object initial = kernel.getInitial();
        for (int n = 0; n < this.N; ++n) {
            this.samples.add(initial);
        }
        this.logWeights = new double[this.N];
        this.incrementalLogWeights = new double[this.N];
    }

    private void newProcess(int t, double[] normalizedWeights, ParticleProcessor<S> processor, int T) {
        if (this.schedule != null && this.schedule.shouldProcess(new ProcessScheduleContext(t, t == T - 1, ResampleStatus.NA))) {
            if (this.verbose) {
                LogInfo.track((Object)"Processing particles");
            }
            for (int n = 0; n < normalizedWeights.length; ++n) {
                if (this.verbose) {
                    LogInfo.logs((Object)("Particle " + (n + 1) + "/" + normalizedWeights.length));
                }
                if (this.samples.get(n) == null) continue;
                processor.process(this.samples.get(n), normalizedWeights[n]);
                this.samples.set(n, null);
            }
            if (this.verbose) {
                LogInfo.end_track();
            }
        }
    }

    public double temperatureDifference(final double alpha, double absoluteAccuracy, double min, double max) {
        final double[] logLikePrior = new double[this.samples.size()];
        for (int n = 0; n < this.samples.size(); ++n) {
            UnrootedTreeState urt = (UnrootedTreeState)this.samples.get(n);
            logLikePrior[n] = urt.getLogLikelihood() + urt.getLogPrior();
        }
        int maxEval = 100;
        UnivariateFunction f = new UnivariateFunction(){

            public double value(double x) {
                double[] logWeightLikePriorVec = new double[ParticleFilterSMCSampler.this.samples.size()];
                for (int n = 0; n < ParticleFilterSMCSampler.this.samples.size(); ++n) {
                    logWeightLikePriorVec[n] = ParticleFilterSMCSampler.this.logWeights[n] + logLikePrior[n] * x;
                }
                NumUtils.expNormalize((double[])logWeightLikePriorVec);
                return ParticleFilterSMCSampler.ess(logWeightLikePriorVec) / (double)logWeightLikePriorVec.length - alpha;
            }
        };
        double result = 0.0;
        try {
            double relativeAccuracy = absoluteAccuracy * 1.0E-4;
            PegasusSolver solver = new PegasusSolver(relativeAccuracy, absoluteAccuracy);
            result = solver.solve(maxEval, f, min, max);
        }
        catch (RuntimeException e) {
            LogInfo.logsForce((Object)"Solver Fail!");
            result = 0.0;
        }
        return result;
    }

    public void sample(ParticleKernel<S> kernel, ParticleProcessor<S> tdp) {
        this.init(kernel);
        int T = kernel.nIterationsLeft(kernel.getInitial());
        if (this.isConditional() && this.conditional.size() != T) {
            throw new RuntimeException();
        }
        if (this.smcSamplerOut != null) {
            this.smcSamplerOut.println(CSV.header((Object[])new Object[]{"t", "ESS", "tempDiff"}));
        }
        double alpha0 = this.alpha;
        for (int t = 0; t < T && !this.isLastIter; ++t) {
            if (kernel instanceof AnnealingKernel) {
                AnnealingKernel currentKernel = (AnnealingKernel)kernel;
                currentKernel.setCurrentIter(t + 1);
                if (this.adaptiveTempDiff) {
                    this.oldTempDiff = this.tempDiff;
                    this.tempDiff = 0.0;
                    if (this.adaptiveType == 1) {
                        alpha0 += 0.5 * Math.pow(1.0 - this.alpha, 2.0) * Math.pow(this.alpha, t);
                    }
                    if (this.adaptiveType == 2) {
                        alpha0 += 0.5 * Math.pow(1.0 - this.alpha, 3.0) * (double)(t + 1) * Math.pow(this.alpha, t);
                    }
                    if (t > 0) {
                        this.tempDiff = this.temperatureDifference(alpha0 * this.ess / (double)this.samples.size(), 1.0E-7, 0.0, 0.1);
                        if (this.tempDiff == 0.0) {
                            this.tempDiff = 0.0;
                        }
                    }
                } else {
                    this.tempDiff = 1.0 / (double)T;
                }
                currentKernel.setTemperatureDifference(this.tempDiff);
                if (currentKernel.isLastIter()) {
                    this.isLastIter = true;
                }
            }
            if (this.verbose) {
                LogInfo.track((Object)("Particle generation " + (t + 1) + "/" + T), (boolean)true);
            }
            this.propagateAndComputeWeights(kernel, t);
            double[] normalizedWeights = (double[])this.logWeights.clone();
            NumUtils.expNormalize((double[])normalizedWeights);
            if (this.verbose) {
                LogInfo.logs((Object)("LargestNormalizedWeights=" + ArrayUtils.max((double[])normalizedWeights)));
            }
            this.ess = ParticleFilterSMCSampler.ess(normalizedWeights);
            if (this.smcSamplerOut != null) {
                this.smcSamplerOut.println(CSV.body((Object[])new Object[]{t, this.ess, this.tempDiff}));
            }
            if (this.verbose) {
                LogInfo.logs((Object)("RelativeESS=" + this.ess / (double)normalizedWeights.length));
            }
            this.newProcess(t, normalizedWeights, tdp, T);
            if (t < T - 1 && (this.hasNulls(this.samples) || this.resamplingStrategy.needResample(normalizedWeights))) {
                this.samples = this.resample(this.samples, normalizedWeights, this.rand);
                this.logWeights = new double[this.N];
                this.ess = this.N;
            }
            if (this.verbose) {
                LogInfo.end_track();
            }
            if ((t == T - 1 || this.isLastIter) && this.schedule != null) {
                if (this.resampleLastRound) {
                    Pair<List<S>, double[]> resampled = this.resampleAndPack(this.samples, normalizedWeights, this.rand);
                    this.samples = (List)resampled.getFirst();
                    normalizedWeights = (double[])resampled.getSecond();
                }
                if (this.verbose) {
                    LogInfo.track((Object)"Processing particles");
                }
                for (int n = 0; n < normalizedWeights.length; ++n) {
                    if (this.verbose) {
                        LogInfo.logs((Object)("Particle " + (n + 1) + "/" + normalizedWeights.length));
                    }
                    if (this.samples.get(n) == null) continue;
                    tdp.process(this.samples.get(n), normalizedWeights[n]);
                    this.samples.set(n, null);
                }
                if (this.verbose) {
                    LogInfo.end_track();
                }
            }
            if (this.schedule != null) {
                this.schedule.monitor(new ProcessScheduleContext(t, t == T - 1, ResampleStatus.NA));
            }
            if (this.smcSamplerOut == null) continue;
            this.smcSamplerOut.flush();
        }
        this.setUnconditional();
        if (this.smcSamplerOut != null) {
            this.smcSamplerOut.close();
        }
    }

    private boolean hasNulls(List<S> samples) {
        for (S item : samples) {
            if (item != null) continue;
            return true;
        }
        return false;
    }

    private <S> Pair<List<S>, double[]> resampleAndPack(List<S> list, double[] w, Random rand) {
        if (!NumUtils.normalize((double[])w)) {
            throw new MeasureZeroException();
        }
        if (list.size() != w.length) {
            throw new RuntimeException();
        }
        Counter packed = Sampling.efficientMultinomialSampling((Random)rand, (double[])w, (int)w.length);
        double[] resultWeight = new double[packed.size()];
        ArrayList<S> resultList = new ArrayList<S>(packed.size());
        int i = 0;
        Iterator iterator = packed.keySet().iterator();
        while (iterator.hasNext()) {
            int itemIdx = (Integer)iterator.next();
            S item = list.get(itemIdx);
            resultList.add(item);
            resultWeight[i++] = packed.getCount((Object)itemIdx);
        }
        return Pair.makePair(resultList, (Object)resultWeight);
    }

    private List<S> resample(List<S> list, double[] w, Random rand) {
        if (!NumUtils.normalize((double[])w)) {
            throw new MeasureZeroException();
        }
        if (list.size() != w.length) {
            throw new RuntimeException();
        }
        Counter packed = Sampling.efficientMultinomialSampling((Random)rand, (double[])w, (int)w.length);
        ArrayList<S> result = new ArrayList<S>(list.size());
        Iterator iterator = packed.keySet().iterator();
        while (iterator.hasNext()) {
            int itemIdx = (Integer)iterator.next();
            S item = list.get(itemIdx);
            int cur = 0;
            while ((double)cur < packed.getCount((Object)itemIdx)) {
                result.add(item);
                ++cur;
            }
        }
        return result;
    }

    public static class MAPDecoder<D>
    implements ParticleProcessor<D> {
        private D argmax = null;
        private double max = Double.NEGATIVE_INFINITY;

        @Override
        public void process(D state, double weight) {
            if (weight > this.max) {
                this.argmax = state;
                this.max = weight;
            }
        }

        public D map() {
            return this.argmax;
        }
    }

    public static class PCSHash
    implements ParticleProcessor<PartialCoalescentState> {
        private Hasher hasher = new Hasher();

        @Override
        public void process(PartialCoalescentState state, double weight) {
            this.hasher.add(weight).add(state.logLikelihood()).add(state.topHeight()).add((Object)state.getUnlabeledArbre().deepToLispString());
        }

        public int getHash() {
            return this.hasher.hashCode();
        }
    }

    public static class ParticleMapperProcessor<D, I>
    implements ParticleProcessor<D> {
        private final Fct<D, I> prj;
        private final Counter<I> counter = new Counter();

        public static <S> ParticleMapperProcessor<S, S> saveParticlesProcessor() {
            return new ParticleMapperProcessor(new Id());
        }

        public static ParticleMapperProcessor<PartialCoalescentState, PartialCoalescentState> saveCoalescentParticlesProcessor() {
            return ParticleMapperProcessor.saveParticlesProcessor();
        }

        public ParticleMapperProcessor(Fct<D, I> prj) {
            this.prj = prj;
        }

        @Override
        public void process(D state, double weight) {
            this.counter.incrementCount(this.prj.evalAt(state), weight);
        }

        public I centroid(BayesRiskMinimizer.LossFct<I> loss) {
            return (I)new BayesRiskMinimizer(loss).findMin(this.counter);
        }

        public String printweights() {
            String s = "";
            for (Object key : this.counter) {
                double value = this.counter.getCount(key);
                s = s + value + ",";
            }
            return s;
        }

        public I map() {
            return (I)this.counter.argMax();
        }

        public I sample(Random rand) {
            double[] probs = new double[this.counter.size()];
            ArrayList states = new ArrayList();
            int i = 0;
            for (Object key : this.counter.keySet()) {
                states.add(key);
                probs[i++] = this.counter.getCount(key);
            }
            int j = Multinomial.sample((Random)rand, (double[])probs);
            return (I)states.get(j);
        }

        public Counter<I> getCounter() {
            return this.counter;
        }
    }

    public static class ForkedProcessor<S>
    implements ParticleProcessor<S> {
        public List<ParticleProcessor<S>> processors = new ArrayList<ParticleProcessor<S>>();

        public ForkedProcessor(ParticleProcessor<S> ... items) {
            this.processors = new ArrayList<ParticleProcessor<S>>(Arrays.asList(items));
        }

        public ForkedProcessor(Collection<ParticleProcessor<S>> items) {
            this.processors = CollUtils.list(items);
        }

        @Override
        public void process(S state, double weight) {
            for (ParticleProcessor<S> processor : this.processors) {
                processor.process(state, weight);
            }
        }
    }

    public static class DoNothingProcessor<S>
    implements ParticleProcessor<S> {
        @Override
        public void process(S state, double w) {
        }
    }

    public static class StoreProcessor<S>
    implements ParticleProcessor<S> {
        public List<S> particles = CollUtils.list();
        public List<Double> ws = CollUtils.list();

        @Override
        public void process(S state, double weight) {
            this.particles.add(state);
            this.ws.add(weight);
        }

        public S sample(Random rand) {
            int idx = Sampling.sample((Random)rand, this.ws);
            return this.particles.get(idx);
        }
    }

    public static interface ParticleProcessor<S> {
        public void process(S var1, double var2);
    }

    public static enum ResamplingStrategy {
        ALWAYS{

            @Override
            public boolean needResample(double[] w) {
                return true;
            }
        }
        ,
        NEVER{

            @Override
            public boolean needResample(double[] w) {
                return false;
            }
        }
        ,
        ESS{

            @Override
            public boolean needResample(double[] weights) {
                double threshold;
                double ess = ParticleFilterSMCSampler.ess(weights);
                return ess < (threshold = essRatioThreshold * (double)weights.length);
            }
        };


        abstract boolean needResample(double[] var1);
    }
}

