/*
 * Decompiled with CFR 0.152.
 */
package sand;

import fig.basic.NumUtils;
import fig.prob.SampleUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.tui.FancyTreeRenderer;
import sand.Kernels;
import sand.PoissonKernel;
import sand.TransitionKernel;
import sand.TruncatedKernel;

/**
 * Draws a joint assignment of discrete states to the nodes of a tree.
 *
 * <p>Every node holds a {@link TransitionKernel} and a list of child nodes. The kernel value
 * {@code kernel.pr(parentState, state)} weights a node's state given the state chosen for its
 * parent. A node constructed without a kernel is treated as having exactly one state (index 0).
 *
 * <p>Sampling runs in two phases: a bottom-up pass ({@link #backward()}) that caches one message
 * per state on every node, followed by a top-down pass that samples each node's state from the
 * normalized product of its cached message and its kernel weight given the parent's sample.
 *
 * <p>The cached messages are filled in lazily and without synchronization.
 */
public class TreeSampler {
    /** Weights this node's state given the parent's state; {@code null} means a single-state node. */
    private final TransitionKernel kernel;
    /** Private copy of the child nodes, in the order they were supplied to the constructor. */
    private final List<TreeSampler> children;
    /** Backward messages, one per state; stays {@code null} until {@link #backward()} has completed. */
    private double[] backPrs = null;

    /**
     * Creates a node with the given kernel and children.
     *
     * @param kernel   kernel for this node, or {@code null} for a single-state node
     * @param children child nodes; the list is copied, so later changes to it are not seen
     */
    public TreeSampler(TransitionKernel kernel, List<TreeSampler> children) {
        this.kernel = kernel;
        this.children = new ArrayList<TreeSampler>(children);
    }

    /**
     * Creates a leaf node (no children) with the given kernel.
     *
     * @param kernel kernel for this node
     */
    public TreeSampler(TransitionKernel kernel) {
        this(kernel, new ArrayList<TreeSampler>());
    }

    /**
     * Creates a kernel-less, single-state node above the given children.
     *
     * @param children child nodes; the list is copied
     */
    public TreeSampler(List<TreeSampler> children) {
        this(null, children);
    }

    /**
     * Computes and caches the backward message of this node and, recursively, of its descendants.
     *
     * <p>For each state {@code s} of this node the message is the product, over all children, of
     * {@code sum_{s2} child.kernel.pr(s, s2) * child.backPrs[s2]}. A node without children
     * therefore gets {@code 1.0} for every state. The result is cached; later calls return
     * immediately.
     */
    private void backward() {
        if (this.backPrs != null) {
            return;
        }
        final int nStates = this.nStates();
        for (TreeSampler child : this.children) {
            child.backward();
        }
        // Fill a local array and publish it only once complete, so that an exception thrown
        // part-way through cannot leave a half-initialized cache that later calls would trust.
        final double[] messages = new double[nStates];
        for (int s = 0; s < nStates; ++s) {
            double product = 1.0;
            for (TreeSampler child : this.children) {
                final int nChildStates = child.nStates();
                double sum = 0.0;
                for (int s2 = 0; s2 < nChildStates; ++s2) {
                    // A child without a kernel fails here with a NullPointerException, so
                    // kernel-less nodes are only usable as the root of the tree.
                    sum += child.kernel.pr(s, s2) * child.backPrs[s2];
                }
                product *= sum;
            }
            messages[s] = product;
        }
        this.backPrs = messages;
    }

    /**
     * Samples this node's state given the parent's sampled state, then recurses into the children.
     *
     * <p>A node with a single state always yields state 0 without consulting the kernel or the
     * random source. Otherwise the state is drawn from the weights
     * {@code backPrs[s] * kernel.pr(parentSample, s)} after normalization.
     *
     * @param parentSample state sampled for the parent node ({@code -1} when called for the root)
     * @param rand         source of randomness
     * @return the sampled state of this node together with the samples of all its descendants
     */
    private TreeSample sampleGiven(int parentSample, Random rand) {
        this.backward();
        final int nStates = this.nStates();
        int sample = 0;
        if (nStates > 1) {
            final double[] prs = new double[nStates];
            for (int s = 0; s < nStates; ++s) {
                prs[s] = this.backPrs[s] * this.kernel.pr(parentSample, s);
            }
            NumUtils.normalize(prs);
            sample = SampleUtils.sampleMultinomial(rand, prs);
        }
        final List<TreeSample> childrenSamples = new ArrayList<TreeSample>(this.children.size());
        for (TreeSampler child : this.children) {
            childrenSamples.add(child.sampleGiven(sample, rand));
        }
        return new TreeSample(sample, childrenSamples);
    }

    /**
     * Samples states for the whole tree rooted at this node.
     *
     * <p>The root is sampled with a parent state of {@code -1}.
     * NOTE(review): if the root has a kernel with more than one state, that kernel is evaluated
     * as {@code pr(-1, s)} - confirm the kernels in use accept this value.
     *
     * @param rand source of randomness
     * @return the sampled tree, mirroring the structure of this sampler tree
     */
    public TreeSample sample(Random rand) {
        return this.sampleGiven(-1, rand);
    }

    /**
     * Returns the number of states of this node: 1 when there is no kernel, otherwise the
     * kernel's {@code nNextStates()}.
     */
    private int nStates() {
        if (this.kernel == null) {
            return 1;
        }
        return this.kernel.nNextStates();
    }

    /**
     * Demo entry point: builds a chain of nodes, each with an emission-likelihood leaf attached,
     * places it under a kernel-less root, draws one sample and prints it.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int truncation = (int)Math.pow(8.0, 4.0);
        int[] observations = new int[]{12, 9, 19, 17, 19, 13, 10, 8, 14, 11};
        final int last = observations.length - 1;
        TreeSampler sampler = null;
        TruncatedKernel kernel = new TruncatedKernel(new PoissonKernel(), truncation);
        // The chain is built from the last time step backwards so each node can take the
        // previously built node as its child.
        // NOTE(review): the loop stops at t == 1, so observations[0] is never used - confirm
        // that this is intentional.
        for (int t = last; t >= 1; --t) {
            List<TreeSampler> children = new ArrayList<TreeSampler>();
            TreeSampler observation = new TreeSampler(Kernels.emissionLikelihood(kernel, observations[t]));
            children.add(observation);
            if (t < last) {
                children.add(sampler);
            }
            sampler = t == 1 ? new TreeSampler(Kernels.initialDistKernel(kernel), children) : new TreeSampler(kernel, children);
        }
        sampler = new TreeSampler(Collections.singletonList(sampler));
        System.out.println("Starting to sample...");
        TreeSample sample = sampler.sample(new Random());
        System.out.println(sample.deepToString());
    }

    /**
     * Result of sampling: the state drawn for one node plus the results for its children.
     */
    public static class TreeSample {
        /** State index sampled for this node. */
        private final int sample;
        /** Samples of the child nodes, in the same order as the sampler's children. */
        private final List<TreeSample> children;

        /**
         * @param sample   state index sampled for this node
         * @param children samples of the child nodes; the list is stored as given, not copied
         */
        public TreeSample(int sample, List<TreeSample> children) {
            this.sample = sample;
            this.children = children;
        }

        /** Returns a read-only view of the child samples. */
        public List<TreeSample> getChildren() {
            return Collections.unmodifiableList(this.children);
        }

        /** Returns the state index sampled for this node. */
        public int getSample() {
            return this.sample;
        }

        /** Returns the sampled state index as a decimal string. */
        @Override
        public String toString() {
            return "" + this.sample;
        }

        /**
         * Returns the string produced by a {@link FancyTreeRenderer} that has been populated with
         * this node and all of its descendants.
         */
        public String deepToString() {
            final TreeSample root = this;
            return new FancyTreeRenderer(new FancyTreeRenderer.Populator(){

                /** Registers the root, then every descendant, and returns the root. */
                @Override
                public Object populate() {
                    this.add(root);
                    this.populate(root);
                    return root;
                }

                /** Registers each child of {@code current} under it, depth-first. */
                public void populate(TreeSample current) {
                    for (TreeSample child : current.children) {
                        this.add(child, current);
                        this.populate(child);
                    }
                }
            }).toString();
        }
    }
}

