/*
 * Decompiled with CFR 0.152.
 */
package goblin;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.prob.SampleUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.math.MeasureZeroException;
import nuts.maxent.SloppyMath;
import nuts.tui.Table;
import nuts.util.Arbre;

public abstract class DirectedTreeSampler
implements Serializable {
    /** Compile-time constant 2; not referenced by the visible code. */
    public static final int n = 2;
    /** Global switch for the diagnostic logging done by the samplers. */
    public static boolean verbose = true;
    /** Shared kernel stored at every dummy root built by {@code createDummyRoot()}. */
    private static final SingletonKernel singletonKernel = new SingletonKernel();
    /** Annealing exponent; {@code setAnnealExp} currently only accepts 1.0. */
    protected double annealExp = 1.0;

    /**
     * Builds a sampler over a tree of transition kernels.
     *
     * <p>The tree is validated, then copied and transformed so that {@code initDist}
     * becomes the kernel of the copy's root, and a {@link StdDirectedTreeSampler}
     * is built over the result.
     *
     * @param kernelTree tree of kernels; must have at least one child and a root with
     *        {@code null} contents
     * @param initDist distribution over the states of the tree's root
     * @param useSandwich must be {@code false}; sandwich sampling is not available here
     * @param slack currently ignored
     * @param useSecondOrder currently ignored
     * @return a standard sampler over the transformed tree
     * @throws RuntimeException if the tree is trivial or its root has contents
     * @throws UnsupportedOperationException if {@code useSandwich} is {@code true}
     */
    public static StdDirectedTreeSampler createSampler(Arbre<? extends TransitionKernel> kernelTree, InitialDistribution initDist, boolean useSandwich, double slack, boolean useSecondOrder) {
        if (kernelTree.getChildren().size() == 0) {
            throw new RuntimeException("The tree should be nontrivial");
        }
        if (kernelTree.getContents() != null) {
            throw new RuntimeException("The contents for the root of a kernel tree has no meaning");
        }
        // Reject unsupported requests before paying for the tree copy below.
        if (useSandwich) {
            throw new UnsupportedOperationException("Sandwich sampling is not supported by createSampler");
        }
        Arbre<? extends TransitionKernel> transformedTree = DirectedTreeSampler.transformKernel(kernelTree, initDist);
        return new StdDirectedTreeSampler(transformedTree);
    }

    /**
     * Prepares a kernel tree for sampling: copies {@code kernelTree}, stores an
     * {@link InitialDistributionKernel} wrapping {@code initDist} as the copy's
     * root contents, and attaches the copy via {@code addLeaves} to a fresh
     * dummy root from {@code createDummyRoot()}.
     *
     * NOTE(review): presumably {@code addLeaves} returns the dummy root with the
     * copy as its child — confirm against the Arbre API. Calling setContents on a
     * wildcard-typed Arbre looks like decompiler output; verify it compiles.
     *
     * @param kernelTree tree of kernels (not modified if copy() is independent)
     * @param initDist distribution placed on the copy's root
     * @return the result of {@code addLeaves} on the dummy root
     */
    private static Arbre<? extends TransitionKernel> transformKernel(Arbre<? extends TransitionKernel> kernelTree, InitialDistribution initDist) {
        // Mutate a copy rather than the caller's tree (assumes copy() yields an independent node).
        Arbre<? extends TransitionKernel> copy = kernelTree.copy();
        copy.setContents(new InitialDistributionKernel(initDist));
        return DirectedTreeSampler.createDummyRoot().addLeaves(copy);
    }

    /**
     * Wraps the shared {@link SingletonKernel} in a new {@link Arbre} node
     * (presumably a leaf) to serve as the dummy root of a transformed tree.
     */
    private static Arbre createDummyRoot() {
        Arbre dummyRoot = Arbre.arbre(singletonKernel);
        return dummyRoot;
    }

    /** @return the current annealing exponent */
    public double getAnnealExp() {
        final double current = annealExp;
        return current;
    }

    /**
     * Sets the annealing exponent.
     *
     * <p>Only {@code 1.0} is currently accepted; the rejection message for other
     * values records doubt about a possible bug in how the exponent is used.
     * All validation happens before assignment, so a rejected value leaves the
     * sampler unchanged.
     *
     * @param annealExp the new exponent; must be exactly {@code 1.0}
     * @throws RuntimeException if {@code annealExp} is negative, NaN, or not 1.0
     */
    public void setAnnealExp(double annealExp) {
        if (annealExp < 0.0) {
            throw new RuntimeException("Bad anneal");
        }
        if (annealExp != 1.0) {
            throw new RuntimeException("xxxxxxxxxxxxxxxxxxxx see below, not sure if it's a bug");
        }
        this.annealExp = annealExp;
    }

    /**
     * Draws a tree of sampled states.
     *
     * @param rand source of randomness
     * @return the sampled tree
     * @throws MeasureZeroException as declared by implementations
     */
    public abstract Arbre<Integer> sample(Random rand) throws MeasureZeroException;

    /**
     * Returns this sampler's probability for {@code tree} (presumably the proposal
     * probability — confirm in implementations).
     */
    public abstract double q(Arbre<Integer> tree) throws MeasureZeroException;

    /**
     * Returns a probability for {@code tree} (presumably the target, possibly
     * unnormalized — confirm in implementations).
     */
    public abstract double p(Arbre<Integer> tree);

    /** Returns the summed probability mass; {@code SandwichSampler} returns NaN. */
    public abstract double getSumPr();

    /** Returns the log of the summed probability mass; {@code SandwichSampler} returns NaN. */
    public abstract double getSumLogPr();

    /**
     * Small demo: builds a tiny tree containing an {@link UnNormUniKernel}, prints
     * it, creates a sampler with {@link UnNormUniInit}, and prints
     * {@code exp(getSumLogPr())}.
     *
     * NOTE(review): the demo tree's root appears to hold {@code unuk} as its
     * contents, but createSampler rejects a root with non-null contents, so this
     * looks like it throws as written unless Arbre.arbre's argument order differs
     * — confirm. The generic types (Arbre&lt;Object&gt; passed where
     * Arbre&lt;? extends TransitionKernel&gt; is expected) look like decompiler output.
     */
    public static void main(String[] args) {
        UnNormUniKernel unuk = new UnNormUniKernel();
        UnNormUniInit unui = new UnNormUniInit();
        Arbre<Object> a = Arbre.arbre(unuk, Arbre.arbre(null)).root();
        System.out.println(a.deepToString());
        StdDirectedTreeSampler sampler = DirectedTreeSampler.createSampler(a, unui, false, Double.NaN, false);
        System.out.println(Math.exp(((DirectedTreeSampler)sampler).getSumLogPr()));
    }

    /**
     * Unnormalized uniform initial distribution over two states: each state has
     * weight 1.0. Sampling is not supported.
     */
    public static class UnNormUniInit
    implements InitialDistribution {
        @Override
        public double pr(int ignoredState) {
            return 1.0;
        }

        @Override
        public int nStates() {
            return 2;
        }

        @Override
        public int sample(Random ignoredRand) {
            throw new RuntimeException("Not yet sup");
        }
    }

    /**
     * Unnormalized uniform kernel between two top and two bottom states: every
     * transition has weight 1.0. Sampling is not supported.
     */
    public static class UnNormUniKernel
    implements TransitionKernel {
        @Override
        public int nTopStates() {
            return 2;
        }

        @Override
        public int nBottomStates() {
            return 2;
        }

        @Override
        public double pr(int ignoredTop, int ignoredBottom) {
            return 1.0;
        }

        @Override
        public int sample(int ignoredTop, Random ignoredRand) {
            throw new RuntimeException("Not yet sup");
        }
    }

    /**
     * {@link InitialDistribution} backed by an array of per-state weights. The
     * array is kept by reference (not copied). Sampling is not supported.
     */
    public static class ArrayInitialDistribution
    implements InitialDistribution {
        private static final long serialVersionUID = 1L;
        /** prs[s] is the weight of state s. */
        private final double[] prs;

        public ArrayInitialDistribution(double[] prs) {
            this.prs = prs;
        }

        @Override
        public double pr(int state) {
            return prs[state];
        }

        @Override
        public int nStates() {
            return prs.length;
        }

        @Override
        public int sample(Random rand) {
            throw new RuntimeException();
        }
    }

    /**
     * Distribution (possibly unnormalized) over the states of a tree's root.
     */
    public static interface InitialDistribution
    extends Serializable {
        /** @return the number of states */
        public int nStates();

        /** @return the weight of {@code state} */
        public double pr(int state);

        /** @return a state drawn using {@code rand} */
        public int sample(Random rand);
    }

    /**
     * Adapts an {@link InitialDistribution} to the {@link TransitionKernel}
     * interface with a single top state (0), whose bottom-state weights are those
     * of the wrapped distribution.
     */
    public static class InitialDistributionKernel
    implements TransitionKernel {
        private static final long serialVersionUID = 1L;
        private final InitialDistribution initDist;

        public InitialDistributionKernel(InitialDistribution initDist) {
            this.initDist = initDist;
        }

        /** Rejects any top state other than the only valid one, 0. */
        private static void requireTopStateZero(int topState) {
            if (topState != 0) {
                throw new RuntimeException();
            }
        }

        @Override
        public int nTopStates() {
            return 1;
        }

        @Override
        public int nBottomStates() {
            return initDist.nStates();
        }

        @Override
        public double pr(int currentState, int nextState) {
            requireTopStateZero(currentState);
            return initDist.pr(nextState);
        }

        @Override
        public int sample(int currentState, Random rand) {
            requireTopStateZero(currentState);
            return initDist.sample(rand);
        }
    }

    /**
     * Placeholder kernel for a dummy root: it has exactly one bottom state, and
     * every other query throws.
     */
    public static class SingletonKernel
    implements TransitionKernel {
        private static final long serialVersionUID = 1L;

        @Override
        public int nBottomStates() {
            return 1;
        }

        @Override
        public int nTopStates() {
            throw new RuntimeException();
        }

        @Override
        public double pr(int unusedTop, int unusedBottom) {
            throw new RuntimeException();
        }

        @Override
        public int sample(int unusedTop, Random unusedRand) {
            throw new RuntimeException();
        }
    }

    /**
     * {@link TransitionKernel} backed by a dense matrix: {@code prs[top][bottom]}
     * is the weight of moving from top state {@code top} to bottom state
     * {@code bottom}. The matrix is kept by reference, and every row is assumed to
     * be as wide as row 0. Sampling is not supported.
     */
    public static class ArrayTransitionKernel
    implements TransitionKernel {
        private static final long serialVersionUID = 1L;
        private final double[][] prs;

        public ArrayTransitionKernel(double[][] prs) {
            this.prs = prs;
        }

        @Override
        public int nTopStates() {
            return this.prs.length;
        }

        /** @return the width of row 0, or 0 for a matrix with no rows */
        @Override
        public int nBottomStates() {
            return this.prs.length == 0 ? 0 : this.prs[0].length;
        }

        @Override
        public double pr(int currentState, int nextState) {
            return this.prs[currentState][nextState];
        }

        @Override
        public int sample(int currentState, Random rand) {
            throw new RuntimeException();
        }

        /**
         * Renders the matrix as a table with one row per top state and one column
         * per bottom state.
         */
        @Override
        public String toString() {
            // Resolve the dimensions on this kernel: inside the anonymous Populator
            // below, `this` denotes the Populator, not the kernel.
            final int nTop = this.nTopStates();
            final int nBottom = this.nBottomStates();
            return new Table(new Table.Populator(){

                @Override
                public void populate() {
                    for (int i = 0; i < nTop; ++i) {
                        this.set(i + 1, 0, "top #" + i);
                    }
                    for (int j = 0; j < nBottom; ++j) {
                        this.set(0, j + 1, "bot #" + j);
                    }
                    for (int top = 0; top < nTop; ++top) {
                        for (int bot = 0; bot < nBottom; ++bot) {
                            this.set(top + 1, bot + 1, prs[top][bot]);
                        }
                    }
                }
            }).toString();
        }
    }

    /**
     * Conditional weights (possibly unnormalized) over the states of a child
     * ("bottom") node given the state of its parent ("top") node.
     */
    public static interface TransitionKernel
    extends Serializable {
        /** @return the number of parent states */
        public int nTopStates();

        /** @return the number of child states */
        public int nBottomStates();

        /** @return the weight of {@code bottomState} given {@code topState} */
        public double pr(int topState, int bottomState);

        /** @return a child state drawn given {@code topState} */
        public int sample(int topState, Random rand);
    }

    public static class SandwichSampler
    extends DirectedTreeSampler {
        private static final long serialVersionUID = 2L;
        /** Presumably the starting truncation size for the sandwich bounds. */
        public static final int INIT_APPROX_SIZE = 10;
        /** Number of nodes in the flattened tree. */
        private final int length;
        /** Index of the root node. */
        private final int root;
        /** parents[t] is the index of t's parent, or -1 for the root. */
        private final int[] parents;
        /** children[t] lists the indices of t's children. */
        private final List<Integer>[] children;
        /** trans[t] is the kernel on the edge into node t ({@code null} for the root). */
        private final TransitionKernel[] trans;
        /** Tolerance applied to the partial sums when deciding a node's state. */
        private final double slack;
        /** Chooses SecondOrderBounds instead of FirstOrderBounds. */
        private final boolean useSecondOrderBounds;

        /** @return the indices of the children of node {@code t} */
        private List<Integer> children(int t) {
            List<Integer> kids = this.children[t];
            return kids;
        }

        /** @return the index of the parent of node {@code t}, or -1 for the root */
        private int parent(int t) {
            int parentIndex = this.parents[t];
            return parentIndex;
        }

        /**
         * Flattens {@code transTree} into index-based arrays.
         *
         * @param transTree tree whose non-root nodes carry the kernels
         * @param slack tolerance used when deciding a node's state
         * @param useSecondOrderBounds whether to use SecondOrderBounds
         */
        @SuppressWarnings("unchecked") // generic array creation for children
        public SandwichSampler(Arbre<? extends TransitionKernel> transTree, double slack, boolean useSecondOrderBounds) {
            int nodeCount = transTree.nodes().size();
            this.length = nodeCount;
            this.slack = slack;
            this.useSecondOrderBounds = useSecondOrderBounds;
            this.trans = new TransitionKernel[nodeCount];
            this.children = new List[nodeCount];
            this.parents = new int[nodeCount];
            this.root = this.convert(transTree);
        }

        /**
         * Fills {@code parents}, {@code trans} and {@code children}, numbering
         * nodes in the order returned by {@code transTree.nodes()}.
         *
         * @param transTree tree whose non-root nodes hold {@link TransitionKernel}s
         * @return the index assigned to the root of {@code transTree}
         */
        private int convert(Arbre transTree) {
            List subtrees = transTree.nodes();
            // Identity lookup so that distinct nodes always get distinct indices,
            // even if Arbre overrides equals.
            IdentityHashMap<Object, Integer> subtree2integer = new IdentityHashMap<Object, Integer>();
            for (int t = 0; t < subtrees.size(); ++t) {
                subtree2integer.put(subtrees.get(t), t);
            }
            for (int t = 0; t < subtrees.size(); ++t) {
                Arbre node = (Arbre) subtrees.get(t);
                if (node.isRoot()) {
                    this.parents[t] = -1;
                    this.trans[t] = null;
                } else {
                    this.parents[t] = subtree2integer.get(node.getParent());
                    this.trans[t] = (TransitionKernel) node.getContents();
                }
                List<Integer> kids = new ArrayList<Integer>();
                for (Object child : node.getChildren()) {
                    kids.add(subtree2integer.get(child));
                }
                this.children[t] = kids;
            }
            return subtree2integer.get(transTree);
        }

        /**
         * Draws one sample with the sandwich scheme. A uniform is pre-drawn for
         * every node. Then, starting at truncation size {@link #INIT_APPROX_SIZE}
         * and doubling each round, bounds are recomputed and each fringe node is
         * offered for expansion. This repeats until every node has a state.
         *
         * @param rand source of the per-node uniforms
         * @return a copy of the subtree under the first child of the internal root,
         *         each node holding its sampled state
         * @throws MeasureZeroException as declared by {@link DirectedTreeSampler#sample(Random)}
         */
        @Override
        public Arbre<Integer> sample(Random rand) throws MeasureZeroException {
            double[] uniform = this.uniform(rand);
            Sample sample = new Sample();
            // Capture the flag once so track/end_track stay balanced even if
            // `verbose` is toggled while sampling.
            boolean tracking = verbose;
            if (tracking) {
                LogInfo.track((Object)("Sampling with Sandwich: " + Arrays.toString(uniform)), true);
            }
            try {
                int cApproxSize = INIT_APPROX_SIZE;
                // Size of the last bounds actually built; cApproxSize has already
                // been doubled again when the loop exits.
                int lastApproxSize = cApproxSize;
                while (!sample.isComplete()) {
                    if (verbose) {
                        LogInfo.logs("--- Current approx size: n=" + cApproxSize);
                    }
                    Bounds bounds = this.createNewBounds(cApproxSize, sample.nSampled);
                    bounds.compute();
                    if (verbose) {
                        LogInfo.logs("" + bounds);
                    }
                    this.tryExpand(sample, bounds, uniform);
                    if (verbose) {
                        LogInfo.logs("Sample after trying to expand: \n" + sample);
                    }
                    lastApproxSize = cApproxSize;
                    cApproxSize = this.nextApproxSize(cApproxSize);
                }
                if (verbose) {
                    LogInfo.logs("---");
                    LogInfo.logs("Final approx size: " + lastApproxSize);
                }
            }
            finally {
                if (tracking) {
                    LogInfo.end_track();
                }
            }
            return sample.toIntegerTree().getChildren().get(0).copy();
        }

        /**
         * Offers every fringe node one expansion attempt under {@code bounds}.
         * Nodes that join the fringe during this pass are attempted too.
         */
        private void tryExpand(Sample sample, Bounds bounds, double[] uniforms) {
            sample.signalNewBoundComputed();
            for (; sample.hasUnattemptedFringeItem(); ) {
                int node = sample.attemptNextFringeItem();
                SampleExpander attempt = new SampleExpander(sample, bounds, node, uniforms[node]);
                attempt.tryExpand();
                if (verbose) {
                    LogInfo.logs(attempt.toString());
                }
            }
        }

        /** @return the next truncation size: double the current one */
        private int nextApproxSize(int cSize) {
            return 2 * cSize;
        }

        /**
         * @return the number of states of node {@code t}: for the root, the number
         *         of top states of its first child's kernel; otherwise the number
         *         of bottom states of the kernel into {@code t}
         */
        private int nStates(int t) {
            return t == this.root
                    ? this.trans[this.children(t).get(0)].nTopStates()
                    : this.trans[t].nBottomStates();
        }

        /** Creates (but does not compute) bounds of the configured order. */
        private Bounds createNewBounds(int N2, int nSampled) {
            if (this.useSecondOrderBounds) {
                return new SecondOrderBounds(N2, nSampled);
            }
            return new FirstOrderBounds(N2, nSampled);
        }

        /** @return one {@code rand.nextDouble()} draw per node, in node-index order */
        private double[] uniform(Random rand) {
            double[] draws = new double[this.length];
            int i = 0;
            while (i < draws.length) {
                draws[i++] = rand.nextDouble();
            }
            return draws;
        }

        /** Not computed by the sandwich sampler; always {@code Double.NaN}. */
        @Override
        public double getSumPr() {
            return Double.NaN;
        }

        /** Not computed by the sandwich sampler; always {@code Double.NaN}. */
        @Override
        public double getSumLogPr() {
            return Double.NaN;
        }

        /** Not computed by the sandwich sampler; always {@code Double.NaN}. */
        @Override
        public double p(Arbre<Integer> sample) {
            return Double.NaN;
        }

        /** Not computed by the sandwich sampler; always {@code Double.NaN}. */
        @Override
        public double q(Arbre<Integer> sample) throws MeasureZeroException {
            return Double.NaN;
        }

        /**
         * First-order sandwich bounds. Tables hold only the first
         * {@code nTruncStates(t)} states of each node. A child's states beyond its
         * truncation add nothing to the lower bound, and their total transition
         * probability ({@code escapePr}) to the upper bound.
         */
        private class FirstOrderBounds
        extends Bounds {
            protected FirstOrderBounds(int N2, int nSampled) {
                super(N2, nSampled);
            }

            /**
             * Fills {@code L} and {@code U} from node {@code length - 2} down to node
             * {@code nSampled - 1}; the last node is initialised to 1.
             * NOTE(review): this order relies on node indices increasing from parent
             * to child and on sampled nodes forming a prefix — confirm.
             */
            @Override
            public void compute() {
                this.U = new double[SandwichSampler.this.length][this.N];
                this.L = new double[SandwichSampler.this.length][this.N];
                this.init();
                for (int t = SandwichSampler.this.length - 2; t >= this.nSampled - 1; --t) {
                    for (int s = 0; s < this.nTruncStates(t); ++s) {
                        this.compute(t, s);
                    }
                }
            }

            /** Sets both bounds of every truncated state of the last node to 1. */
            private void init() {
                int lastChainIndex = SandwichSampler.this.length - 1;
                for (int s = 0; s < this.nTruncStates(lastChainIndex); ++s) {
                    this.U[lastChainIndex][s] = 1.0;
                    this.L[lastChainIndex][s] = 1.0;
                }
            }

            /**
             * Sets {@code L[currentNode][s]} and {@code U[currentNode][s]} to the
             * product over all children of the corresponding child messages.
             */
            private void compute(int currentNode, int s) {
                double lower = 1.0;
                double upper = 1.0;
                for (int childNode : SandwichSampler.this.children(currentNode)) {
                    if (SandwichSampler.this.nStates(childNode) == 1) {
                        // Single-state child: its factor is exact. Multiply (rather than
                        // overwrite) so factors from children visited earlier are kept.
                        double pr = SandwichSampler.this.trans[childNode].pr(s, 0);
                        lower *= pr;
                        upper *= pr;
                        continue;
                    }
                    lower *= this.transitionTruncDotProduct(childNode, s, this.L[childNode]);
                    upper *= this.transitionTruncDotProduct(childNode, s, this.U[childNode]) + this.escapePr(childNode, s);
                }
                this.L[currentNode][s] = lower;
                this.U[currentNode][s] = upper;
                assert (this.L[currentNode][s] <= this.U[currentNode][s] + 1.0E-9) : "L=" + this.L[currentNode][s] + ",U=" + this.U[currentNode][s];
            }

            @Override
            public String toString() {
                return "L table:\n" + this.boundsToString(this.L) + "U table:\n" + this.boundsToString(this.U);
            }

            /**
             * Lower bound on the posterior of {@code childState}: its weight times
             * {@code L[child]}, over the largest possible normaliser
             * (escape probability plus the truncated dot product with {@code U}).
             * Single-state children return 1.
             */
            @Override
            public double postLowerBound(int currentNode, int currentState, int childNode, int childState) {
                if (SandwichSampler.this.nStates(childNode) == 1) {
                    return 1.0;
                }
                return SandwichSampler.this.trans[childNode].pr(currentState, childState) * this.L[childNode][childState] / (this.escapePr(childNode, currentState) + this.transitionTruncDotProduct(childNode, currentState, this.U[childNode]));
            }

            /**
             * Upper bound on the posterior of {@code childState}: its weight times
             * {@code U[child]}, over the smallest normaliser (the truncated dot
             * product with {@code L}). Single-state children return 1.
             */
            @Override
            public double postUpperBound(int currentNode, int currentState, int childNode, int childState) {
                if (SandwichSampler.this.nStates(childNode) == 1) {
                    return 1.0;
                }
                return SandwichSampler.this.trans[childNode].pr(currentState, childState) * this.U[childNode][childState] / this.transitionTruncDotProduct(childNode, currentState, this.L[childNode]);
            }
        }

        private class SecondOrderBounds
        extends Bounds {
            private double[] Min;
            private double[] Max;

            protected SecondOrderBounds(int N2, int nSampled) {
                super(N2, nSampled);
            }

            @Override
            public void compute() {
                this.L = new double[SandwichSampler.this.length][];
                this.U = new double[SandwichSampler.this.length][];
                this.Min = new double[SandwichSampler.this.length];
                this.Max = new double[SandwichSampler.this.length];
                this.init();
                for (int t = SandwichSampler.this.length - 2; t >= this.nSampled - 1; --t) {
                    this.computeLU(t, this.L, this.Min);
                    this.computeLU(t, this.U, this.Max);
                    this.Min[t] = this.computeExtr(t, this.L, false);
                    this.Max[t] = this.computeExtr(t, this.U, true);
                }
            }

            private void init() {
                for (int t = 0; t < SandwichSampler.this.length; ++t) {
                    this.L[t] = new double[SandwichSampler.this.nStates(t)];
                    this.U[t] = new double[SandwichSampler.this.nStates(t)];
                }
                for (int s = 0; s < SandwichSampler.this.nStates(SandwichSampler.this.length - 1); ++s) {
                    this.L[((SandwichSampler)SandwichSampler.this).length - 1][s] = 1.0;
                    this.U[((SandwichSampler)SandwichSampler.this).length - 1][s] = 1.0;
                }
                this.Min[((SandwichSampler)SandwichSampler.this).length - 1] = this.computeExtr(SandwichSampler.this.length - 1, this.L, false);
                this.Max[((SandwichSampler)SandwichSampler.this).length - 1] = this.computeExtr(SandwichSampler.this.length - 1, this.U, true);
            }

            private void computeLU(int currentNode, double[][] LU, double[] Extr) {
                for (int s = 0; s < SandwichSampler.this.nStates(currentNode); ++s) {
                    double product = 1.0;
                    Iterator iterator = SandwichSampler.this.children(currentNode).iterator();
                    while (iterator.hasNext()) {
                        int childNode = (Integer)iterator.next();
                        if (SandwichSampler.this.nStates(childNode) == 1) {
                            product = SandwichSampler.this.trans[childNode].pr(s, 0);
                            continue;
                        }
                        double sum = this.transitionTruncDotProduct(childNode, s, LU[childNode]);
                        if (s < this.nTruncStates(currentNode)) {
                            for (int y = this.nTruncStates(currentNode); y < SandwichSampler.this.nStates(childNode); ++y) {
                                sum += SandwichSampler.this.trans[childNode].pr(s, y) * LU[childNode][y];
                            }
                        } else {
                            sum += !Double.isInfinite(Extr[childNode]) ? this.escapePr(childNode, s) * Extr[childNode] : 0.0;
                        }
                        product *= sum;
                    }
                    LU[currentNode][s] = product;
                    if (LU[currentNode][s] >= 0.0) continue;
                    throw new RuntimeException("Bad LU: " + LU[currentNode][s]);
                }
            }

            private double computeExtr(int t, double[][] LU, boolean useMax) {
                double extr = useMax ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
                for (int s = this.nTruncStates(t); s < SandwichSampler.this.nStates(t); ++s) {
                    if (!(useMax ? LU[t][s] > extr : LU[t][s] < extr)) continue;
                    extr = LU[t][s];
                }
                return extr;
            }

            public String toString() {
                return "L table:\n" + this.boundsToString(this.L) + "U table:\n" + this.boundsToString(this.U);
            }

            @Override
            public double postUpperBound(int currentNode, int currentState, int childNode, int childState) {
                return this.postUpperBound(currentNode, currentState, childNode, childState, this.U, this.L);
            }

            @Override
            public double postLowerBound(int currentNode, int currentState, int childNode, int childState) {
                return this.postUpperBound(currentNode, currentState, childNode, childState, this.L, this.U);
            }

            public double postUpperBound(int currentNode, int currentState, int childNode, int childState, double[][] U, double[][] L) {
                if (SandwichSampler.this.nStates(childNode) == 1) {
                    return 1.0;
                }
                return SandwichSampler.this.trans[childNode].pr(currentState, childState) * U[childNode][childState] / this.transitionFullDotProduct(childNode, currentState, L[childNode]);
            }

            protected double transitionFullDotProduct(int t, int s1, double[] vector) {
                double sum = 0.0;
                for (int s2 = 0; s2 < SandwichSampler.this.nStates(t); ++s2) {
                    sum += SandwichSampler.this.trans[t].pr(s1, s2) * vector[s2];
                }
                return sum;
            }
        }

        /**
         * Base class for the sandwich bounds used while sampling.
         *
         * <p>Subclasses fill the tables {@code L} and {@code U} (indexed
         * {@code [node][state]}) in {@link #compute()}. From them they derive
         * per-state bounds on a child's posterior given its parent's state. State
         * spaces are truncated to at most {@code N} states ({@link #nTruncStates(int)}).
         */
        private abstract class Bounds {
            /** Truncation level. */
            protected final int N;
            /** Number of nodes already sampled when these bounds were created. */
            protected final int nSampled;
            /** Lower-bound table, filled by {@link #compute()}. */
            protected double[][] L;
            /** Upper-bound table, filled by {@link #compute()}. */
            protected double[][] U;

            protected Bounds(int N2, int nSampled) {
                this.N = N2;
                this.nSampled = nSampled;
            }

            /** Fills {@link #L} and {@link #U}. */
            public abstract void compute();

            /** Lower bound on the posterior of {@code childState} given the parent's {@code currentState}. */
            public abstract double postLowerBound(int currentNode, int currentState, int childNode, int childState);

            /** Upper bound on the posterior of {@code childState} given the parent's {@code currentState}. */
            public abstract double postUpperBound(int currentNode, int currentState, int childNode, int childState);

            /** @return {@code min(N, nStates(t))}, the number of states of {@code t} tracked exactly */
            public int nTruncStates(int t) {
                int available = SandwichSampler.this.nStates(t);
                return this.N < available ? this.N : available;
            }

            /**
             * Weight that {@code trans[t]}, from parent state {@code s1}, sends outside
             * the truncated states of {@code t}: 1 minus each kept weight in turn.
             * The assertion (active under -ea) treats kernel rows as normalised.
             */
            protected double escapePr(int t, int s1) {
                int kept = this.nTruncStates(t);
                TransitionKernel kernel = SandwichSampler.this.trans[t];
                double remaining = 1.0;
                for (int s2 = 0; s2 < kept; ++s2) {
                    remaining -= kernel.pr(s1, s2);
                }
                assert (remaining >= 0.0 && remaining <= 1.0) : "escape pr=" + remaining + ", terms=" + this.escapePrTerms(t, s1);
                return remaining;
            }

            /** Space-separated kept weights of row {@code s1}, for the assertion message above. */
            private String escapePrTerms(int t, int s1) {
                StringBuilder terms = new StringBuilder();
                for (int s2 = 0; s2 < this.nTruncStates(t); ++s2) {
                    terms.append(SandwichSampler.this.trans[t].pr(s1, s2)).append(' ');
                }
                return terms.toString();
            }

            /** Dot product of row {@code s1} of {@code trans[t]} with {@code vector} over the truncated states. */
            protected double transitionTruncDotProduct(int t, int s1, double[] vector) {
                int kept = this.nTruncStates(t);
                TransitionKernel kernel = SandwichSampler.this.trans[t];
                double dot = 0.0;
                for (int s2 = 0; s2 < kept; ++s2) {
                    dot += kernel.pr(s1, s2) * vector[s2];
                }
                return dot;
            }

            /** Renders {@code bounds} as a bordered table with one column per node and one row per state. */
            protected String boundsToString(final double[][] bounds) {
                Table result = new Table(new Table.Populator(){

                    @Override
                    public void populate() {
                        int nNodes = SandwichSampler.this.length;
                        this.set(0, 0, "s V / t ->");
                        for (int col = 0; col < nNodes; ++col) {
                            this.set(0, col + 1, "" + col);
                        }
                        for (int col = 0; col < nNodes; ++col) {
                            double[] column = bounds[col];
                            for (int row = 0; row < column.length; ++row) {
                                this.set(row + 1, 0, "" + row);
                                this.set(row + 1, col + 1, column[row]);
                            }
                        }
                    }
                });
                result.setBorder(true);
                return result.toString();
            }
        }

        /**
         * Partial assignment of states to the tree's nodes, built top-down.
         *
         * <p>A node enters the fringe when its parent is sampled and leaves it when
         * it is sampled itself. The fringe map records whether each fringe node has
         * already been attempted since the last bounds computation.
         */
        public class Sample {
            // samples[t] is the state chosen for node t, or -1 (UNSAMPLED).
            private final int[] samples;
            public static final int UNSAMPLED = -1;
            // Number of nodes whose state has been fixed (the root included).
            private int nSampled;
            // Fringe node -> whether it was already attempted with the current bounds.
            private Map<Integer, Boolean> fringe;

            /** @return true once every node has a state */
            private boolean isComplete() {
                return this.nSampled == SandwichSampler.this.length;
            }

            /** Recursively builds an Arbre of sampled states for the subtree rooted at {@code node}. */
            private Arbre<Integer> toIntegerTree(int node) {
                int contents = this.samples[node];
                ArrayList decompiledChildren = new ArrayList();
                Iterator iterator = SandwichSampler.this.children(node).iterator();
                while (iterator.hasNext()) {
                    int childNode = (Integer)iterator.next();
                    decompiledChildren.add(this.toIntegerTree(childNode));
                }
                return Arbre.arbre(Integer.valueOf(contents), decompiledChildren);
            }

            /** @return the whole tree of sampled states, starting at the root */
            public Arbre<Integer> toIntegerTree() {
                return this.toIntegerTree(SandwichSampler.this.root);
            }

            /** Starts with every node unsampled, then samples the root in state 0. */
            public Sample() {
                this.samples = new int[SandwichSampler.this.length];
                this.nSampled = 0;
                this.fringe = new HashMap<Integer, Boolean>();
                for (int t = 0; t < SandwichSampler.this.length; ++t) {
                    this.samples[t] = -1;
                }
                // The root must be on the fringe for expand() to accept it.
                this.fringe.put(SandwichSampler.this.root, true);
                this.expand(SandwichSampler.this.root, 0);
            }

            /**
             * Fixes {@code node} to state {@code s}, removes it from the fringe and
             * adds its children as not-yet-attempted fringe nodes.
             * NOTE(review): samples/nSampled are updated before the fringe check, so
             * the "Internal error" path leaves them modified.
             */
            private void expand(int node, int s) {
                this.samples[node] = s;
                ++this.nSampled;
                if (!this.fringe.containsKey(node)) {
                    throw new RuntimeException("Internal error in Sample");
                }
                this.fringe.remove(node);
                Iterator iterator = SandwichSampler.this.children(node).iterator();
                while (iterator.hasNext()) {
                    int childNode = (Integer)iterator.next();
                    this.fringe.put(childNode, false);
                }
            }

            /** @return true if some fringe node has not been attempted with the current bounds */
            public boolean hasUnattemptedFringeItem() {
                for (boolean wasAttempted : this.fringe.values()) {
                    if (wasAttempted) continue;
                    return true;
                }
                return false;
            }

            /**
             * Marks an unattempted fringe node as attempted and returns it. Which one
             * is picked follows HashMap iteration order.
             *
             * @throws RuntimeException if every fringe node was already attempted
             */
            public int attemptNextFringeItem() {
                int result = -1;
                for (int node : this.fringe.keySet()) {
                    boolean wasAttempted = this.fringe.get(node);
                    if (wasAttempted) continue;
                    result = node;
                    break;
                }
                if (result == -1) {
                    throw new RuntimeException();
                }
                this.fringe.put(result, true);
                return result;
            }

            /** Clears the attempted flag of every fringe node (value updates only, no structural change). */
            public void signalNewBoundComputed() {
                for (int node : this.fringe.keySet()) {
                    this.fringe.put(node, false);
                }
            }

            /**
             * Collects states along a path from the root that repeatedly follows
             * {@code children(t).get(1)}, for {@code length / 2} nodes.
             * NOTE(review): the returned array has {@code length} entries but only
             * the first {@code length / 2} are filled (the rest stay 0) — confirm
             * whether a size of {@code length / 2} was intended. Also assumes the
             * chain successor is always child index 1.
             *
             * @throws RuntimeException if the node count is odd
             */
            public int[] getLegacyHMMHiddenStates() {
                if (SandwichSampler.this.length % 2 != 0) {
                    throw new RuntimeException();
                }
                int[] result = new int[SandwichSampler.this.length];
                int currentState = SandwichSampler.this.root;
                for (int i = 0; i < SandwichSampler.this.length / 2; ++i) {
                    result[i] = this.samples[currentState];
                    if (i == SandwichSampler.this.length / 2 - 1) continue;
                    currentState = (Integer)SandwichSampler.this.children(currentState).get(1);
                }
                return result;
            }

            /** Renders the tree of states, showing unsampled nodes (-1) as "??". */
            public String toString() {
                return this.toArbre(SandwichSampler.this.root).deepToString().replaceAll("-1", "??");
            }

            /** Same tree as toIntegerTree(int), built with the Arbre constructor. */
            public Arbre<Integer> toArbre(int node) {
                int contents = this.samples[node];
                ArrayList children = new ArrayList();
                Iterator iterator = SandwichSampler.this.children(node).iterator();
                while (iterator.hasNext()) {
                    int childNode = (Integer)iterator.next();
                    children.add(this.toArbre(childNode));
                }
                return new Arbre<Integer>(Integer.valueOf(contents), children);
            }
        }

        private class SampleExpander {
            private final Sample sample;
            private final Bounds bounds;
            private final int t;
            private final int parentNode;
            private double uniformNumber;
            private final double[] lowerBoundPartialSums;
            private final double[] upperBoundPartialSums;
            private final int previousState;

            private SampleExpander(Sample sample, Bounds bounds, int t, double uni) {
                this.uniformNumber = uni;
                this.sample = sample;
                this.bounds = bounds;
                this.t = t;
                this.parentNode = SandwichSampler.this.parent(t);
                this.lowerBoundPartialSums = new double[bounds.nTruncStates(t)];
                this.upperBoundPartialSums = new double[bounds.nTruncStates(t)];
                this.previousState = sample.samples[this.parentNode];
            }

            private void computePartialSums() {
                this.lowerBoundPartialSums[0] = this.bounds.postLowerBound(this.parentNode, this.previousState, this.t, 0);
                this.upperBoundPartialSums[0] = this.bounds.postUpperBound(this.parentNode, this.previousState, this.t, 0);
                for (int s = 1; s < this.bounds.nTruncStates(this.t); ++s) {
                    this.lowerBoundPartialSums[s] = this.lowerBoundPartialSums[s - 1] + this.bounds.postLowerBound(this.parentNode, this.previousState, this.t, s);
                    this.upperBoundPartialSums[s] = this.upperBoundPartialSums[s - 1] + this.bounds.postUpperBound(this.parentNode, this.previousState, this.t, s);
                }
            }

            private boolean tryExpand() {
                this.computePartialSums();
                for (int s = 0; s < this.bounds.nTruncStates(this.t); ++s) {
                    if (!this.tryExpand(s)) continue;
                    return true;
                }
                return false;
            }

            /**
             * Checks whether the bounds certify that the uniform draw selects state {@code s},
             * and if so expands the sample at node {@code t} with that state.
             *
             * <p>State {@code s} is accepted when the draw does not exceed the cumulative lower
             * bound at {@code s} (plus slack), and it exceeds the cumulative upper bound (minus
             * slack) of every earlier state.
             *
             * @param s candidate truncated state of node {@code t}
             * @return true if the sample was expanded with {@code s}
             */
            private boolean tryExpand(int s) {
                final double draw = this.uniformNumber;
                final double tolerance = SandwichSampler.this.slack;
                if (draw > this.lowerBoundPartialSums[s] + tolerance) {
                    return false;
                }
                for (int earlier = 0; earlier < s; ++earlier) {
                    if (draw <= this.upperBoundPartialSums[earlier] - tolerance) {
                        // An earlier state might still own the draw, so s is not certain.
                        return false;
                    }
                }
                this.sample.expand(this.t, s);
                return true;
            }

            /**
             * Renders a bordered table with one row per truncated state of node {@code t}.
             *
             * <p>Each row lists the lower and upper posterior bounds, their difference, the
             * cumulative lower and upper bounds, and the difference of those. The table is
             * preceded by a header line giving {@code t} and the uniform draw.
             */
            @Override
            public String toString() {
                Table result = new Table(new Table.Populator(){

                    @Override
                    public void populate() {
                        final String[] headers = {"norm L", "norm U", "delta", "sum norm L", "sum norm U", "delta"};
                        for (int col = 0; col < headers.length; ++col) {
                            this.set(0, col, headers[col]);
                        }
                        final SampleExpander owner = SampleExpander.this;
                        for (int state = 0; state < owner.bounds.nTruncStates(owner.t); ++state) {
                            final double low = owner.bounds.postLowerBound(owner.parentNode, owner.previousState, owner.t, state);
                            final double high = owner.bounds.postUpperBound(owner.parentNode, owner.previousState, owner.t, state);
                            final double cumulativeLow = owner.lowerBoundPartialSums[state];
                            final double cumulativeHigh = owner.upperBoundPartialSums[state];
                            final double[] cells = {low, high, high - low, cumulativeLow, cumulativeHigh, cumulativeHigh - cumulativeLow};
                            for (int col = 0; col < cells.length; ++col) {
                                this.set(state + 1, col, cells[col]);
                            }
                        }
                    }
                });
                result.setBorder(true);
                return new StringBuilder("Expanding t=")
                        .append(this.t)
                        .append(", uniform: ")
                        .append(this.uniformNumber)
                        .append('\n')
                        .append(result.toString())
                        .toString();
            }
        }
    }

    /**
     * Samples state assignments for a tree of transition kernels.
     *
     * <p>It first computes backward (bottom-up) tables over the kernel tree. It then draws
     * states top-down, conditioning each node on the state already drawn for its parent.
     *
     * <p>The backward tables are built lazily on first use. They hold plain probabilities, or
     * log-probabilities when log mode is active. Log mode is switched on automatically if the
     * plain-probability root table underflows to exactly zero.
     *
     * <p>Transition probabilities returned by {@code TransitionKernel.pr} are always treated as
     * plain (non-log) probabilities. They are raised to {@code annealExp} both in the backward
     * pass and when sampling.
     */
    public static class StdDirectedTreeSampler
    extends DirectedTreeSampler {
        private static final long serialVersionUID = 2L;
        // When true, backward tables store log-probabilities instead of probabilities.
        private boolean logMode = false;
        // Set once the backward tables exist; logMode and annealExp may no longer change afterwards.
        private boolean tableInsured = false;
        private final Arbre<TransitionKernel> kernelTree;
        // Lazily built backward tables, mirroring the shape of kernelTree.
        private Arbre<BackwardNode> backwardTree = null;
        private SampleMap sampleMap = null;
        // Most recent sample, still including the dummy root (whose state is always 0).
        private Arbre<Integer> unprocessedSample;

        /**
         * Creates a sampler over an already-transformed kernel tree.
         *
         * @param transformedKernelTree kernel tree expected to have a dummy root with a single
         *     child holding the initial-distribution kernel, as built by {@code createSampler}
         */
        public StdDirectedTreeSampler(Arbre transformedKernelTree) {
            this.kernelTree = transformedKernelTree;
        }

        private Arbre<TransitionKernel> getKernelTree() {
            return this.kernelTree;
        }

        /**
         * Returns the total (annealed) mass stored in the root's backward table, as a plain
         * probability.
         */
        @Override
        public double getSumPr() {
            this.init();
            double number = this.backwardTree.getContents().backPrs[0];
            return this.logMode ? Math.exp(number) : number;
        }

        /**
         * Returns the log of the total (annealed) mass stored in the root's backward table.
         */
        @Override
        public double getSumLogPr() {
            this.init();
            double number = this.backwardTree.getContents().backPrs[0];
            return this.logMode ? number : Math.log(number);
        }

        /**
         * Forces the backward tables to be computed in log space.
         *
         * @throws RuntimeException if the tables have already been built
         */
        public void activateLogMode() {
            if (this.tableInsured) {
                throw new RuntimeException("Illegal state in DirectedTreeSampler");
            }
            this.logMode = true;
        }

        /**
         * Changes the annealing exponent.
         *
         * <p>Setting the current value again is a no-op. Any other change is rejected once the
         * backward tables have been built, because the tables depend on the exponent.
         *
         * @throws RuntimeException if the value changes after the tables were built
         */
        @Override
        public void setAnnealExp(double annealingCoef) {
            if (this.getAnnealExp() == annealingCoef) {
                return;
            }
            if (this.tableInsured) {
                throw new RuntimeException("Illegal setAnneal");
            }
            super.setAnnealExp(annealingCoef);
        }

        /**
         * Draws one assignment top-down using the backward tables.
         *
         * @return a copy of the sampled tree without the dummy root
         * @throws MeasureZeroException if the total mass is zero
         */
        @Override
        public Arbre<Integer> sample(Random rand) throws MeasureZeroException {
            this.init();
            if (this.getSumLogPr() == Double.NEGATIVE_INFINITY) {
                throw new MeasureZeroException("Directed tree sampler ran on a measure zero");
            }
            this.sampleMap = new SampleMap(rand);
            this.unprocessedSample = this.backwardTree.preOrderMap(this.sampleMap);
            return this.unprocessedSample.getChildren().get(0).copy();
        }

        /**
         * Builds the backward tables once.
         *
         * <p>If the plain-probability root table underflows to 0, the tables are rebuilt in
         * log mode.
         */
        private void init() {
            if (this.backwardTree == null) {
                this.backwardTree = this.kernelTree.postOrderMap(new BackwardMap());
                if (!this.logMode && this.backwardTree.getContents().backPrs[0] == 0.0) {
                    this.logMode = true;
                    this.backwardTree = this.kernelTree.postOrderMap(new BackwardMap());
                }
                this.tableInsured = true;
            }
        }

        /** Raises a plain probability to the annealing exponent. */
        private double anneal(double x) {
            if (this.annealExp == 1.0) {
                return x;
            }
            return Math.pow(x, this.annealExp);
        }

        /** Log of a plain probability raised to the annealing exponent. */
        private double logAnneal(double a) {
            return this.annealExp * Math.log(a);
        }

        @Override
        public double q(Arbre<Integer> sample) throws MeasureZeroException {
            throw new RuntimeException();
        }

        /**
         * Returns the log-probability that {@link #sample} produces {@code sample}.
         *
         * <p>The sampler draws from the annealed distribution, proportional to
         * {@code p(sample)^annealExp}. Its log-density is therefore
         * {@code annealExp * logP(sample)} minus the log of the annealed normalizer held in the
         * root backward table.
         */
        public double logQ(Arbre<Integer> sample) throws MeasureZeroException {
            this.init();
            double logP = this.logP(sample);
            double annealedLogP = this.annealExp == 1.0 ? logP : this.annealExp * logP;
            return annealedLogP - this.getSumLogPr();
        }

        /** Wraps the sample under a dummy root with state 0, mirroring the kernel tree layout. */
        private Arbre<Integer> transformSample(Arbre<Integer> sample) {
            ArrayList child = new ArrayList();
            child.add(sample.copy());
            Arbre<Integer> result = Arbre.arbre(Integer.valueOf(0), child);
            return result;
        }

        /**
         * Returns the un-annealed probability of {@code sample}: the product of all kernel
         * transition probabilities along the tree, starting from the initial distribution.
         */
        @Override
        public double p(Arbre<Integer> sample) {
            if (this.kernelTree == null || this.kernelTree.getChildren().size() != 1) {
                throw new RuntimeException("Malformed kernelTree in DirectedTreeSamper.p");
            }
            return this.p(this.kernelTree.getChildren().get(0), this.transformSample(sample).getChildren().get(0), false);
        }

        /** Returns the un-annealed log-probability of {@code sample}. */
        public double logP(Arbre<Integer> sample) {
            if (this.kernelTree == null || this.kernelTree.getChildren().size() != 1) {
                throw new RuntimeException("Malformed kernelTree in DirectedTreeSampler.logP");
            }
            return this.p(this.kernelTree.getChildren().get(0), this.transformSample(sample).getChildren().get(0), true);
        }

        /**
         * Recursively multiplies (or, in logs, adds) the transition probabilities of the subtree
         * rooted at {@code sample}, given the state of its parent.
         *
         * @param kernelTree kernel subtree aligned with {@code sample}
         * @param sample sample subtree; it must have a parent, whose state conditions the root kernel
         * @param inLogs whether to return a log-probability
         */
        private double p(Arbre<? extends TransitionKernel> kernelTree, Arbre<Integer> sample, boolean inLogs) {
            int parentSample = sample.getParent().getContents();
            int currentSample = sample.getContents();
            if (kernelTree.getChildren().size() != sample.getChildren().size()) {
                throw new RuntimeException("Internal error in kernelTree in DirectedTreeSampler.p");
            }
            // pr() always yields a plain probability (the backward pass and the sampler both take
            // its log explicitly). So the conversion depends only on the requested scale, not on
            // logMode, which only describes how the backward tables are stored.
            double number = kernelTree.getContents().pr(parentSample, currentSample);
            double result = inLogs ? Math.log(number) : number;
            for (int i = 0; i < kernelTree.getChildren().size(); ++i) {
                double otherNumber = this.p(kernelTree.getChildren().get(i), sample.getChildren().get(i), inLogs);
                result = inLogs ? result + otherNumber : result * otherNumber;
            }
            return result;
        }

        /**
         * Pre-order map drawing each node's state given the state already drawn for its parent.
         * The value for the parent is obtained from {@code getCallerImage}.
         */
        private class SampleMap
        extends Arbre.ArbreMap<BackwardNode, Integer>
        implements Serializable {
            private static final long serialVersionUID = 1L;
            private final Random rand;

            public SampleMap(Random rand) {
                this.rand = rand;
            }

            @Override
            public Integer map(Arbre<BackwardNode> currentBackwardNode) {
                BackwardNode backwardNode = currentBackwardNode.getContents();
                double[] backPrs = backwardNode.backPrs;
                TransitionKernel kernel = backwardNode.kernel;
                if (currentBackwardNode.isRoot()) {
                    // The dummy root always takes state 0.
                    return 0;
                }
                int parentSample = (Integer)this.getCallerImage();
                double[] prs = new double[kernel.nBottomStates()];
                for (int s = 0; s < kernel.nBottomStates(); ++s) {
                    double number = kernel.pr(parentSample, s);
                    // Anneal the transition exactly as BackwardMap does, so the conditional
                    // distribution matches the backward tables.
                    prs[s] = StdDirectedTreeSampler.this.logMode
                            ? backPrs[s] + StdDirectedTreeSampler.this.logAnneal(number)
                            : backPrs[s] * StdDirectedTreeSampler.this.anneal(number);
                }
                if (StdDirectedTreeSampler.this.logMode) {
                    NumUtils.expNormalize(prs);
                } else {
                    NumUtils.normalize(prs);
                }
                try {
                    return SampleUtils.sampleMultinomial(this.rand, prs);
                }
                catch (RuntimeException re) {
                    throw new RuntimeException("Probably some bad prs: " + re.toString() + "\nroottable[0]: " + ((BackwardNode)StdDirectedTreeSampler.this.backwardTree.getContents()).backPrs[0] + "\nPrs: " + Arrays.toString(prs) + "\nBackprs: " + Arrays.toString(backPrs), re);
                }
            }
        }

        /**
         * Post-order map computing each node's backward table.
         *
         * <p>For each bottom state {@code s}, the entry is the product over children of
         * {@code sum_s2 anneal(child.pr(s, s2)) * child.backPrs[s2]}, or its log-space equivalent
         * in log mode. Leaves get 1 (log mode: 0).
         */
        private class BackwardMap
        extends Arbre.ArbreMap<TransitionKernel, BackwardNode>
        implements Serializable {
            private static final long serialVersionUID = 1L;

            private BackwardMap() {
            }

            @Override
            public BackwardNode map(Arbre<TransitionKernel> currentKernelNode) {
                TransitionKernel currentNode = currentKernelNode.getContents();
                double[] backPrs = new double[currentNode.nBottomStates()];
                for (int s = 0; s < currentKernelNode.getContents().nBottomStates(); ++s) {
                    double partial = StdDirectedTreeSampler.this.logMode ? 0.0 : 1.0;
                    for (BackwardNode child : this.getChildImage()) {
                        double subPartial = StdDirectedTreeSampler.this.logMode ? Double.NEGATIVE_INFINITY : 0.0;
                        for (int s2 = 0; s2 < child.kernel.nBottomStates(); ++s2) {
                            subPartial = StdDirectedTreeSampler.this.logMode ? SloppyMath.logAdd(subPartial, StdDirectedTreeSampler.this.logAnneal(child.kernel.pr(s, s2)) + child.backPrs[s2]) : subPartial + StdDirectedTreeSampler.this.anneal(child.kernel.pr(s, s2)) * child.backPrs[s2];
                        }
                        partial = StdDirectedTreeSampler.this.logMode ? partial + subPartial : partial * subPartial;
                    }
                    backPrs[s] = partial;
                }
                return new BackwardNode(currentKernelNode.getContents(), backPrs);
            }
        }

        /** A kernel together with its backward table (plain or log, per logMode). */
        private static class BackwardNode
        implements Serializable {
            private static final long serialVersionUID = 1L;
            private final TransitionKernel kernel;
            private final double[] backPrs;

            public BackwardNode(TransitionKernel kernel, double[] backPrs) {
                this.kernel = kernel;
                this.backPrs = backPrs;
            }
        }
    }
}

