/*
 * Decompiled with CFR 0.152.
 */
package backward;

import fig.basic.NumUtils;
import fig.basic.Option;
import fig.prob.Gaussian;
import fig.prob.SampleUtils;
import hmm.Param;
import hmm.ParamUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.math.TreeSumProd;
import nuts.util.Counter;
import pty.smc.ParticleFilter;
import pty.smc.test.TestParticleNormalization;

public class BackwardPF<S> {
    // Size of every particle population; read by initPopulation() and by the
    // final processing loop in compute().
    // NOTE(review): @Option comes from fig.basic — presumably exposes the field
    // to fig's option parsing; confirm against that library.
    @Option
    public int Nparticles = 100;
    // Random source used for ancestor selection, kernel proposals and trajectory
    // sampling. Seeded with 1 by default, so runs are repeatable unless replaced.
    @Option
    public Random rand = new Random(1L);
    // Reference trajectory for conditional runs; null means an unconditional run
    // (see isConditional()).
    private List<S> conditional = null;
    // Kernel of the current run; assigned by init() at the start of compute().
    private BackwardParticleKernel<S> kernel = null;
    // Populations recorded by compute(), one per iteration (the initial
    // population is not stored); re-created by init() on every compute() call.
    private List<ParticlePopulation<S>> populations = null;

    /**
     * Runs the particle filter for the given kernel and reports the final population.
     *
     * <p>The number of iterations is {@code kernel.nIterationsLeft(kernel.getInitial())}.
     * When that value is {@code Integer.MAX_VALUE} and a conditional path is set, the
     * length of the conditional path is used instead. Each generated population is
     * appended to {@code populations}; the initial population is not stored.
     *
     * <p>After the last iteration every particle of the last population is passed to
     * {@code processor} together with its normalized weight.
     *
     * @param kernel    kernel used to extend and weight particles
     * @param processor receives each particle of the final population and its normalized weight
     * @throws IllegalStateException if a conditional path is set and its length differs
     *                               from the number of iterations
     */
    public void compute(BackwardParticleKernel<S> kernel, ParticleFilter.ParticleProcessor<S> processor) {
        this.init(kernel);
        int T = kernel.nIterationsLeft(kernel.getInitial());
        if (T == Integer.MAX_VALUE && this.conditional != null) {
            // The kernel gave no finite horizon: fall back to the conditional path length.
            T = this.conditional.size();
        }
        // NOTE(review): if T is still Integer.MAX_VALUE here (no conditional path), the
        // loop below keeps storing populations until memory runs out — confirm whether
        // kernels are ever expected to report that value without a conditional path.
        if (this.isConditional() && this.conditional.size() != T) {
            throw new IllegalStateException("Conditional path has " + this.conditional.size()
                    + " states but the kernel requires " + T + " iterations");
        }
        ParticlePopulation<S> previous = this.initPopulation(kernel.getInitial());
        for (int t = 0; t < T; ++t) {
            S currentConditional = this.isConditional() ? this.conditional.get(t) : null;
            ParticlePopulation<S> current = this.nextPopulation(this.rand, previous, currentConditional);
            this.populations.add(current);
            previous = current;
        }
        for (int i = 0; i < this.Nparticles; ++i) {
            processor.process(previous.particles.get(i), previous.normWeights[i]);
        }
    }

    /**
     * Samples one trajectory from the recorded populations by following the stored
     * parent pointers back from a particle drawn from the last population.
     *
     * @return the sampled trajectory, ordered from the first iteration to the last
     */
    public List<S> stdPGSample() {
        final boolean useBackwardSampling = false;
        return _PGSample(useBackwardSampling);
    }

    /**
     * Samples one trajectory from the recorded populations, re-drawing the ancestor
     * at every step with {@code backwardSample} instead of following parent pointers.
     *
     * @return the sampled trajectory, ordered from the first iteration to the last
     */
    public List<S> backwardPGSample() {
        final boolean useBackwardSampling = true;
        return _PGSample(useBackwardSampling);
    }

    /**
     * Draws one trajectory through the populations recorded by the last
     * {@code compute} call.
     *
     * <p>The last state is drawn from the final population according to its
     * normalized weights. Walking backwards, the index of the preceding state is
     * either re-sampled with {@code backwardSample} ({@code useBack == true}) or
     * read from the parent pointers of the current population.
     *
     * @param useBack whether to re-sample ancestors instead of following parent pointers
     * @return the trajectory ordered from the first iteration to the last; empty when
     *         no population was recorded
     * @throws IllegalStateException if {@code compute} has not been called yet
     */
    private List<S> _PGSample(boolean useBack) {
        if (this.populations == null) {
            throw new IllegalStateException("compute() must be called before sampling a trajectory");
        }
        List<S> result = new ArrayList<S>();
        int T = this.populations.size();
        if (T == 0) {
            // Nothing was recorded (zero iterations): the trajectory is empty.
            return result;
        }
        int previousIndex = SampleUtils.sampleMultinomial(this.rand, this.populations.get(T - 1).normWeights);
        for (int t = T - 1; t >= 0; --t) {
            ParticlePopulation<S> current = this.populations.get(t);
            result.add(current.particles.get(previousIndex));
            if (t == 0) {
                break;
            }
            previousIndex = useBack
                    ? this.backwardSample(this.rand, this.populations.get(t - 1), current, previousIndex)
                    : current.parentPointers[previousIndex];
        }
        // States were collected last-to-first; flip into chronological order.
        Collections.reverse(result);
        return result;
    }

    /**
     * Tells whether a conditional (reference) path has been set.
     *
     * @return {@code true} when the conditional path is non-null
     */
    public boolean isConditional() {
        return conditional != null;
    }

    /**
     * Sets the conditional (reference) path used by subsequent {@code compute} calls.
     * The list is stored by reference, not copied; passing {@code null} switches back
     * to unconditional runs.
     *
     * @param conditional one state per iteration, or {@code null}
     */
    public void setConditional(List<S> conditional) {
        this.conditional = conditional;
    }

    /**
     * Builds the next population from {@code p}.
     *
     * <p>For each slot an ancestor index is drawn from {@code p}'s normalized weights
     * and extended with {@code kernel.sampleExtension}. In conditional mode slot 0 is
     * pinned instead: its ancestor is particle 0 of {@code p}, its value is
     * {@code currentConditional}, and no random numbers are drawn for it.
     *
     * <p>The log weight of each slot is
     * {@code kernel.logDensityRatio(new, ancestor) - kernel.normalizedExtensionLogDensity(ancestor, new)}.
     *
     * @param rand               random source for ancestor selection and proposals
     * @param p                  previous population
     * @param currentConditional state forced into slot 0 in conditional mode; ignored otherwise
     * @return the new population with its parent pointers into {@code p}
     */
    private ParticlePopulation<S> nextPopulation(Random rand, ParticlePopulation<S> p, S currentConditional) {
        S previousConditional = p.particles.get(0);
        int N = p.unNormLogWeights.length;
        int[] parentPointers = new int[N];
        List<S> newParticles = new ArrayList<S>(N);
        double[] normWeights = new double[N];
        double[] unNormLogWeights = new double[N];
        for (int m = 0; m < N; ++m) {
            S newParticle;
            S ancestor;
            if (this.isConditional() && m == 0) {
                ancestor = previousConditional;
                newParticle = currentConditional;
                parentPointers[m] = 0;
            } else {
                int ancestorIdx = SampleUtils.sampleMultinomial(rand, p.normWeights);
                ancestor = p.particles.get(ancestorIdx);
                parentPointers[m] = ancestorIdx;
                newParticle = this.kernel.sampleExtension(ancestor, rand);
            }
            double logDensityRatio = this.kernel.logDensityRatio(newParticle, ancestor);
            double logProp = this.kernel.normalizedExtensionLogDensity(ancestor, newParticle);
            double logWeight = logDensityRatio - logProp;
            newParticles.add(newParticle);
            unNormLogWeights[m] = logWeight;
            // Filled with log weights here; converted in place just below.
            normWeights[m] = logWeight;
        }
        // NOTE(review): presumably exponentiates and normalizes in place — confirm
        // against fig.basic.NumUtils.expNormalize.
        NumUtils.expNormalize(normWeights);
        return new ParticlePopulation<S>(normWeights, unNormLogWeights, newParticles, parentPointers);
    }

    /**
     * Re-samples the index of the ancestor in {@code p1} for a particle of {@code p2}.
     *
     * <p>Candidate {@code m} receives log weight
     * {@code p1.unNormLogWeights[m] + kernel.normalizedExtensionLogDensity(p1.particles[m], s2)},
     * where {@code s2} is particle {@code p2index} of {@code p2}; the index is then
     * drawn from the normalized weights.
     *
     * <p>NOTE(review): the extension density is the proposal density whenever the
     * kernel is configured with a custom proposal (see {@code SSMKernel}); confirm
     * that this is the intended weighting in that case.
     *
     * @param rand    random source
     * @param p1      population the ancestor is chosen from
     * @param p2      population holding the already selected particle
     * @param p2index index of the selected particle within {@code p2}
     * @return index into {@code p1} of the sampled ancestor
     */
    private int backwardSample(Random rand, ParticlePopulation<S> p1, ParticlePopulation<S> p2, int p2index) {
        int N = p1.unNormLogWeights.length;
        S s2 = p2.particles.get(p2index);
        double[] samplingDistribution = new double[N];
        for (int m = 0; m < N; ++m) {
            samplingDistribution[m] = p1.unNormLogWeights[m]
                    + this.kernel.normalizedExtensionLogDensity(p1.particles.get(m), s2);
        }
        NumUtils.expNormalize(samplingDistribution);
        return SampleUtils.sampleMultinomial(rand, samplingDistribution);
    }

    /**
     * Prepares a fresh run: discards previously recorded populations and remembers
     * the kernel to use.
     *
     * @param kernel kernel for the upcoming run
     */
    private void init(BackwardParticleKernel<S> kernel) {
        List<ParticlePopulation<S>> freshPopulations = new ArrayList<ParticlePopulation<S>>();
        this.populations = freshPopulations;
        this.kernel = kernel;
    }

    /**
     * Repeatedly runs the filter on a discrete model, each time conditioning on the
     * trajectory produced by backward sampling in the previous iteration, and prints
     * the normalized accumulated state counts after every iteration.
     *
     * @param p       model parameters
     * @param obsList observations
     * @param np      number of particles
     */
    private static void test1(Param p, List obsList, int np) {
        final int nMcmcIterations = 10000;
        DiscreteSSM model = new DiscreteSSM(p);
        SSMKernel<Integer, Integer> kernel = new SSMKernel<Integer, Integer>(model, obsList);
        BackwardPF<IndexedParticle<Integer>> filter = new BackwardPF<IndexedParticle<Integer>>();
        filter.Nparticles = np;
        TestProcessor processor = new TestProcessor();
        List<IndexedParticle<Integer>> trajectory = null;
        int iteration = 0;
        while (iteration < nMcmcIterations) {
            // The first pass has no reference trajectory yet and runs unconditionally.
            if (trajectory != null) {
                filter.setConditional(trajectory);
            }
            filter.compute(kernel, processor);
            trajectory = filter.backwardPGSample();
            // Normalize a copy so the running totals in the processor stay untouched.
            Counter<Integer> normalized = new Counter<Integer>(processor.stateDist);
            normalized.normalize();
            System.out.println(BackwardPF.toString(normalized, p));
            iteration++;
        }
    }

    /**
     * Formats the counts of states {@code 0 .. p.nStates() - 1} as a single line,
     * each count followed by one space.
     *
     * @param c counter holding one count per state
     * @param p parameters providing the number of states
     * @return the space-separated counts
     */
    private static String toString(Counter c, Param p) {
        StringBuilder line = new StringBuilder();
        int nStates = 0;
        for (int state = 0; state < (nStates = p.nStates()); ++state) {
            line.append(c.getCount(state)).append(" ");
        }
        return line.toString();
    }

    /**
     * Runs the reference {@code ParticleFilter} once on the same model and prints
     * the resulting state counts.
     *
     * @param p       model parameters
     * @param obsList observations; every element must be an {@code Integer}
     * @param np      number of particles
     */
    private static void test2(Param p, List obsList, int np) {
        final int nObs = obsList.size();
        int[] observations = new int[nObs];
        int idx = 0;
        while (idx < obsList.size()) {
            observations[idx] = (Integer)obsList.get(idx);
            idx++;
        }
        TestParticleNormalization.HMMParticleKernel hmmKernel = new TestParticleNormalization.HMMParticleKernel(p, observations);
        ParticleFilter<TestParticleNormalization.HMMPState> filter = new ParticleFilter<TestParticleNormalization.HMMPState>();
        TestProcessor processor = new TestProcessor();
        filter.N = np;
        filter.sample(hmmKernel, processor);
        String summary = BackwardPF.toString(processor.stateDist, p);
        System.out.println(summary);
    }

    /**
     * Demo entry point: builds a random 5x5 model, prints the exact first-position
     * moments from {@code TreeSumProd}, then prints the estimates of the reference
     * particle filter ({@code test2}) and of the conditional runs ({@code test1}).
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Length of the all-zero observation sequence and particle count per filter.
        final int length = 1;
        final int nParticles = 10;
        Random rand = new Random(1000L);
        Param p = ParamUtils.randomUniParam(rand, 5, 5);
        System.out.println("Param:\n" + p);
        int[] obs = new int[length];
        System.out.println("Observation:" + Arrays.toString(obs));
        ArrayList<Integer> obsList = new ArrayList<Integer>();
        for (int o : obs) {
            obsList.add(o);
        }
        TreeSumProd.HmmAdaptor adapt = new TreeSumProd.HmmAdaptor(p, obsList);
        TreeSumProd<Integer> tsp = new TreeSumProd<Integer>(adapt);
        System.out.println("Exact: " + Arrays.toString(tsp.moments().get(0)));
        BackwardPF.test2(p, obsList, nParticles);
        System.out.println("----");
        BackwardPF.test1(p, obsList, nParticles);
    }

    /**
     * Creates the starting population: {@code Nparticles} references to the same
     * initial state, all-zero log weights, and no parent pointers.
     *
     * @param init initial state shared by every particle (may be {@code null})
     * @return the initial population
     */
    private ParticlePopulation<S> initPopulation(S init) {
        final int n = this.Nparticles;
        double[] weights = new double[n];
        double[] logWeights = new double[n];
        List<S> particles = new ArrayList<S>(Collections.nCopies(n, init));
        // All entries are zero on entry; expNormalize rewrites the array in place.
        NumUtils.expNormalize(weights);
        return new ParticlePopulation<S>(weights, logWeights, particles, null);
    }

    /**
     * Adapts a state-space model plus a list of observations to the
     * {@code BackwardParticleKernel} interface. Particles carry their time index,
     * and the initial particle is {@code null}.
     *
     * <p>When no proposal is supplied, particles are extended with the model's own
     * transition and the extension density is the model's transition density.
     */
    public static class SSMKernel<S, T>
    implements BackwardParticleKernel<IndexedParticle<S>> {
        private final SSM<IndexedParticle<S>, T> model;
        private final SSMProposal<IndexedParticle<S>> proposal;
        private final List<T> observations;

        /** Creates a kernel that proposes from the model's transition. */
        public SSMKernel(SSM<IndexedParticle<S>, T> model, List<T> observations) {
            this(model, observations, null);
        }

        /**
         * @param model        state-space model
         * @param observations one observation per time index
         * @param proposal     custom proposal, or {@code null} to use the model's transition
         */
        public SSMKernel(SSM<IndexedParticle<S>, T> model, List<T> observations, SSMProposal<IndexedParticle<S>> proposal) {
            this.observations = observations;
            this.proposal = proposal;
            this.model = model;
        }

        /** Extends {@code s} with the custom proposal if present, otherwise with the model's transition. */
        @Override
        public IndexedParticle<S> sampleExtension(IndexedParticle<S> s, Random rand) {
            if (this.proposal != null) {
                return this.proposal.propose(s, rand);
            }
            return this.model.sampleTransition(s, rand);
        }

        /** Log density of proposing {@code s2} from {@code s1}, under whichever mechanism {@code sampleExtension} uses. */
        @Override
        public double normalizedExtensionLogDensity(IndexedParticle<S> s1, IndexedParticle<S> s2) {
            if (this.proposal != null) {
                return this.proposal.normalizedProposalLogDensity(s1, s2);
            }
            return this.model.normalizedLogTransitionDensity(s1, s2);
        }

        /**
         * Log likelihood of the observation at {@code num.index} given {@code num},
         * plus the log transition density from {@code denom} to {@code num}.
         */
        @Override
        public double logDensityRatio(IndexedParticle<S> num, IndexedParticle<S> denom) {
            T observation = this.observations.get(num.index);
            return this.model.normalizedLogLikelihood(num, observation)
                    + this.model.normalizedLogTransitionDensity(denom, num);
        }

        /** Number of observations that come after {@code s}; all of them when {@code s} is {@code null}. */
        @Override
        public int nIterationsLeft(IndexedParticle<S> s) {
            int total = this.observations.size();
            return s == null ? total : total - s.index - 1;
        }

        /** The initial particle is represented by {@code null}. */
        @Override
        public IndexedParticle<S> getInitial() {
            return null;
        }
    }

    /**
     * Immutable pair of a time index and the state value held at that index.
     */
    private static class IndexedParticle<S> {
        /** Time index of this particle. */
        public final int index;
        /** State value at {@code index}. */
        public final S contents;

        /**
         * @param index    time index
         * @param contents state value at that index
         */
        public IndexedParticle(int index, S contents) {
            this.contents = contents;
            this.index = index;
        }
    }

    /**
     * Custom proposal distribution that {@code SSMKernel} can use in place of the
     * model's own transition.
     */
    public static interface SSMProposal<S> {
        /**
         * Draws a successor for {@code current}.
         *
         * @param current state to extend
         * @param rand    random source
         * @return the proposed successor
         */
        S propose(S current, Random rand);

        /**
         * Log density of proposing {@code proposed} from {@code current}.
         *
         * @param current  state that was extended
         * @param proposed proposed successor
         * @return the log density
         */
        double normalizedProposalLogDensity(S current, S proposed);
    }

    /**
     * One-dimensional non-linear state-space model with additive Gaussian noise:
     * the next state is {@code f(x, t)} plus a draw from {@code v}, and the
     * observation is {@code g(x)} plus a draw from {@code e}.
     *
     * <p>NOTE(review): whether the second argument of {@code Gaussian} is a variance
     * or a standard deviation is not visible here — confirm against fig.prob.Gaussian.
     */
    public static class SimpleSSMExample
    implements SSM<IndexedParticle<Double>, Double> {
        /** Transition noise, centered at 0. */
        public final Gaussian v;
        /** Observation noise, centered at 0. */
        public final Gaussian e;

        /**
         * @param sigmaV second constructor argument of the transition-noise Gaussian
         * @param sigmaE second constructor argument of the observation-noise Gaussian
         */
        public SimpleSSMExample(double sigmaV, double sigmaE) {
            this.v = new Gaussian(0.0, sigmaV);
            this.e = new Gaussian(0.0, sigmaE);
        }

        /** Deterministic part of the transition out of state {@code x} at index {@code t}. */
        public static double f(double x, double t) {
            double linear = 0.5 * x;
            double saturating = 25.0 * x / (1.0 + x * x);
            double periodic = 8.0 * Math.cos(1.2 * (t + 1.0));
            return linear + saturating + periodic;
        }

        /** Deterministic part of the observation of state {@code x}. */
        public static double g(double x) {
            return 0.05 * x * x;
        }

        /**
         * Samples the successor of {@code s}; a {@code null} input yields an index-0
         * particle drawn with {@code Gaussian.sample(rand, 0.0, 5.0)}.
         */
        @Override
        public IndexedParticle<Double> sampleTransition(IndexedParticle<Double> s, Random rand) {
            if (s != null) {
                int t = s.index;
                double mean = SimpleSSMExample.f(s.contents, t);
                return new IndexedParticle<Double>(t + 1, mean + this.v.sample(rand));
            }
            return new IndexedParticle<Double>(0, Gaussian.sample(rand, 0.0, 5.0));
        }

        /** Samples a noisy observation of {@code s}, wrapped with index {@code s.index + 1}. */
        @Override
        public IndexedParticle<Double> sampleObservation(IndexedParticle<Double> s, Random rand) {
            double mean = SimpleSSMExample.g(s.contents);
            return new IndexedParticle<Double>(s.index + 1, mean + this.e.sample(rand));
        }

        /** Log density under {@code e} of the residual {@code obs - g(hidden)}. */
        @Override
        public double normalizedLogLikelihood(IndexedParticle<Double> hidden, Double obs) {
            double residual = obs - SimpleSSMExample.g(hidden.contents);
            return this.e.logProb(residual);
        }

        /**
         * Log density of moving from {@code t1} to {@code t2}; for a {@code null}
         * {@code t1} this is {@code Gaussian.logProb(0.0, 5.0, t2)}.
         */
        @Override
        public double normalizedLogTransitionDensity(IndexedParticle<Double> t1, IndexedParticle<Double> t2) {
            if (t1 != null) {
                double residual = t2.contents - SimpleSSMExample.f(t1.contents, t1.index);
                return this.v.logProb(residual);
            }
            return Gaussian.logProb(0.0, 5.0, t2.contents);
        }
    }

    /**
     * Discrete state-space model backed by the initial vector, transition matrix
     * and emission matrix of a {@code Param}.
     */
    public static class DiscreteSSM
    implements SSM<IndexedParticle<Integer>, Integer> {
        /** Source of the initial, transition and emission probabilities. */
        public final Param p;

        public DiscreteSSM(Param p) {
            this.p = p;
        }

        /**
         * Samples the successor of {@code s}; a {@code null} input yields an index-0
         * particle drawn from the initial vector.
         */
        @Override
        public IndexedParticle<Integer> sampleTransition(IndexedParticle<Integer> s, Random rand) {
            if (s != null) {
                return new IndexedParticle<Integer>(s.index + 1, this.p.transMtx.nextState(s.contents, rand));
            }
            return new IndexedParticle<Integer>(0, this.p.initVec.nextState(rand));
        }

        /** Not implemented for this model: always returns {@code null}. */
        @Override
        public IndexedParticle<Integer> sampleObservation(IndexedParticle<Integer> s, Random rand) {
            return null;
        }

        /** Log of the emission probability of {@code obs} in state {@code hidden}. */
        @Override
        public double normalizedLogLikelihood(IndexedParticle<Integer> hidden, Integer obs) {
            Integer state = hidden.contents;
            return Math.log(this.p.emiMtx.p(state, obs));
        }

        /**
         * Log probability of moving from {@code t1} to {@code t2}; for a {@code null}
         * {@code t1} the initial probability of {@code t2} is used.
         */
        @Override
        public double normalizedLogTransitionDensity(IndexedParticle<Integer> t1, IndexedParticle<Integer> t2) {
            if (t1 != null) {
                return Math.log(this.p.transMtx.p(t1.contents, t2.contents));
            }
            return Math.log(this.p.initVec.p(t2.contents));
        }
    }

    /**
     * State-space model over hidden states of type {@code S} and observations of
     * type {@code T}, as consumed by {@code SSMKernel}.
     */
    public static interface SSM<S, T> {
        /**
         * Samples the successor of {@code current}. The implementations in this file
         * treat a {@code null} argument as "draw the first state".
         */
        S sampleTransition(S current, Random rand);

        /** Samples an observation for the hidden state {@code hidden}. */
        S sampleObservation(S hidden, Random rand);

        /** Log likelihood of {@code observation} given the hidden state {@code hidden}. */
        double normalizedLogLikelihood(S hidden, T observation);

        /**
         * Log density of moving from {@code from} to {@code to}. The implementations
         * in this file treat a {@code null} {@code from} as the initial distribution.
         */
        double normalizedLogTransitionDensity(S from, S to);
    }

    /**
     * Particle processor that accumulates the total weight seen per integer state.
     * Accepts either {@code IndexedParticle} states or
     * {@code TestParticleNormalization.HMMPState} states.
     */
    public static class TestProcessor
    implements ParticleFilter.ParticleProcessor {
        /** Accumulated weight per state value. */
        Counter<Integer> stateDist = new Counter();

        /**
         * Adds {@code weight} to the count of the integer state carried by {@code state}.
         *
         * @param state  an {@code IndexedParticle} holding an {@code Integer}, or an {@code HMMPState}
         * @param weight weight to add
         */
        public void process(Object state, double weight) {
            Integer key;
            if (state instanceof IndexedParticle) {
                key = (Integer)((IndexedParticle)state).contents;
            } else {
                key = Integer.valueOf(((TestParticleNormalization.HMMPState)state).state);
            }
            this.stateDist.incrementCount(key, weight);
        }
    }

    /**
     * Kernel driving {@code BackwardPF}: proposes extensions of a particle and
     * supplies the densities used to weight them.
     */
    public static interface BackwardParticleKernel<S> {
        /**
         * Draws an extension of {@code ancestor}.
         *
         * @param ancestor particle to extend
         * @param rand     random source
         * @return the extended particle
         */
        S sampleExtension(S ancestor, Random rand);

        /**
         * Log density of proposing {@code extension} from {@code ancestor}; subtracted
         * from {@code logDensityRatio} to form a particle's log weight.
         */
        double normalizedExtensionLogDensity(S ancestor, S extension);

        /**
         * Log ratio of the target density at {@code extension} to the one at
         * {@code ancestor}. Note the argument order: the new particle comes first.
         */
        double logDensityRatio(S extension, S ancestor);

        /**
         * Number of iterations still to run from {@code current}. {@code BackwardPF}
         * replaces a result of {@code Integer.MAX_VALUE} by the conditional path
         * length when one is set.
         */
        int nIterationsLeft(S current);

        /** Particle every member of the initial population starts from. */
        S getInitial();
    }

    /**
     * Snapshot of one filter iteration: the particles, their weights, and the index
     * of each particle's ancestor in the preceding population. Arrays and the list
     * are stored by reference, not copied.
     */
    public static class ParticlePopulation<S> {
        /** Normalized weights, one per particle. */
        private final double[] normWeights;
        /** Unnormalized log weights, one per particle. */
        private final double[] unNormLogWeights;
        /** The particles themselves. */
        private final List<S> particles;
        /** Ancestor index per particle; {@code null} for the initial population. */
        private final int[] parentPointers;

        /**
         * @param normWeights      normalized weights
         * @param unNormLogWeights unnormalized log weights
         * @param particles        particles of this population
         * @param parentPointers   ancestor indices, or {@code null} when there is no preceding population
         */
        public ParticlePopulation(double[] normWeights, double[] unNormLogWeights, List<S> particles, int[] parentPointers) {
            this.particles = particles;
            this.parentPointers = parentPointers;
            this.normWeights = normWeights;
            this.unNormLogWeights = unNormLogWeights;
        }
    }
}

