/*
 * Decompiled with CFR 0.152.
 */
package sand;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.tui.Table;
import nuts.util.Arbre;
import sand.ObservedEmissionKernel;
import sand.TransitionKernel;

public class SandwitchSampler {
    // Kernel attached to each node, indexed by node id; trans[t].pr(parentState, state) is the
    // transition into node t from its parent. The root's entry is null.
    private final TransitionKernel[] trans;
    // Number of nodes in the flattened tree.
    private final int length;
    // Node id of the tree root.
    private final int root;
    // Tolerance added to the lower and subtracted from the upper partial sums when a
    // SampleExpander decides whether a state can be accepted.
    private final double slack;
    // Selects SecondOrderBounds (true) or FirstOrderBounds (false) in createNewBounds.
    private final boolean useSecondOrderBounds;
    // children[t] lists the node ids of the children of node t.
    private final List<Integer>[] children;
    // parents[t] is the node id of the parent of node t, or -1 for the root.
    private final int[] parents;

    /**
     * Looks up the children of a node.
     *
     * @param t node id
     * @return the node ids of the children of {@code t} (the internal list, not a copy)
     */
    private List<Integer> children(int t) {
        final List<Integer> childIds = this.children[t];
        return childIds;
    }

    /**
     * Looks up the parent of a node.
     *
     * @param t node id
     * @return the node id of the parent of {@code t}, or -1 when {@code t} is the root
     */
    private int parent(int t) {
        final int parentId = this.parents[t];
        return parentId;
    }

    public SandwitchSampler(Arbre<TransitionKernel> transTree, double slack, boolean useSecondOrderBounds) {
        this.slack = slack;
        this.length = transTree.nodes().size();
        this.trans = new TransitionKernel[this.length];
        this.children = new List[this.length];
        this.parents = new int[transTree.nodes().size()];
        this.root = this.convert(transTree);
        this.useSecondOrderBounds = useSecondOrderBounds;
    }

    /**
     * Builds a sampler from chain-style (legacy) inputs: one transition kernel and one observed
     * emission kernel per time step. The arrays are first turned into a tree by
     * {@code convertLegacy}, starting at time step 0, and then handed to the tree constructor.
     *
     * @param transition transition kernels, one per time step
     * @param observeds observed emission kernels, one per time step; must have the same length as
     *     {@code transition}
     * @param slack tolerance forwarded to the tree constructor
     * @param useSecondOrder whether to use second-order bounds
     */
    public SandwitchSampler(TransitionKernel[] transition, ObservedEmissionKernel[] observeds, double slack, boolean useSecondOrder) {
        this(SandwitchSampler.convertLegacy(transition, observeds, 0), slack, useSecondOrder);
    }

    /**
     * Converts chain-style inputs into a kernel tree covering time steps {@code t} to the end.
     *
     * <p>Each time step becomes a node with up to two children, in this order: first a leaf
     * wrapping the step's observed emission kernel as a single-next-state transition kernel, then
     * (unless it is the last step) the node of the following time step. The node of time step 0
     * carries a {@code null} kernel, so {@code transition[0]} is never read; every other step
     * {@code i} carries {@code transition[i]}.
     *
     * <p>The tree is assembled bottom-up in a loop, so the length check runs once and the stack
     * depth does not grow with the chain length.
     *
     * @param transition transition kernels, one per time step
     * @param observeds observed emission kernels, one per time step
     * @param t first time step to include
     * @return the subtree rooted at time step {@code t}
     * @throws IllegalArgumentException if the arrays differ in length or {@code t} is not a valid
     *     time step (which includes the case of empty arrays)
     */
    private static Arbre<TransitionKernel> convertLegacy(TransitionKernel[] transition, ObservedEmissionKernel[] observeds, int t) {
        if (observeds.length != transition.length) {
            throw new IllegalArgumentException("Expected one emission kernel per transition kernel but got "
                    + transition.length + " transition kernel(s) and " + observeds.length + " emission kernel(s)");
        }
        if (t < 0 || t >= transition.length) {
            throw new IllegalArgumentException("Time step " + t + " is outside [0, " + transition.length + ")");
        }
        Arbre<TransitionKernel> following = null;
        for (int step = transition.length - 1; step >= t; --step) {
            ArrayList<Arbre<TransitionKernel>> children = new ArrayList<Arbre<TransitionKernel>>();
            // Observation leaf first, then the rest of the chain: callers rely on this order
            // (see Sample.getLegacyHMMHiddenStates, which follows child 1).
            children.add(new Arbre<TransitionKernel>(new TransitionFromObservationKernel(observeds[step])));
            if (following != null) {
                children.add(following);
            }
            TransitionKernel currentTrKernel = step == 0 ? null : transition[step];
            following = new Arbre<TransitionKernel>(currentTrKernel, children);
        }
        return following;
    }

    /**
     * Flattens the kernel tree into the {@code parents}, {@code trans} and {@code children} arrays.
     *
     * <p>A node's id is its position in {@code transTree.nodes()}. Nodes are matched by identity,
     * not by {@code equals}.
     *
     * @param transTree the tree to flatten
     * @return the node id of {@code transTree} itself
     */
    private int convert(Arbre<TransitionKernel> transTree) {
        List<Arbre<TransitionKernel>> nodes = transTree.nodes();
        IdentityHashMap<Arbre<TransitionKernel>, Integer> idOf = new IdentityHashMap<Arbre<TransitionKernel>, Integer>();
        int position = 0;
        while (position < nodes.size()) {
            idOf.put(nodes.get(position), position);
            position++;
        }
        position = 0;
        while (position < nodes.size()) {
            Arbre<TransitionKernel> current = nodes.get(position);
            if (!current.isRoot()) {
                this.parents[position] = idOf.get(current.getParent());
                this.trans[position] = current.getContents();
            } else {
                // The root has no parent and no incoming kernel.
                this.parents[position] = -1;
                this.trans[position] = null;
            }
            List<Integer> childIds = new ArrayList<Integer>();
            this.children[position] = childIds;
            for (Arbre<TransitionKernel> child : current.getChildren()) {
                childIds.add(idOf.get(child));
            }
            position++;
        }
        return idOf.get(transTree);
    }

    /**
     * Draws one joint sample of all node states.
     *
     * <p>One uniform number per node is drawn up front. Bounds are then recomputed with an
     * increasing approximation size (starting at 1, grown by {@code nextApproxSize}) and the
     * sample is expanded as far as each set of bounds allows, until every node is sampled.
     * Progress is traced on standard output.
     *
     * @param rand source of the uniform numbers
     * @return the completed sample
     */
    public Sample sample(Random rand) {
        final double[] draws = this.uniform(rand);
        System.out.println("Samples: " + Arrays.toString(draws));
        final Sample result = new Sample();
        int approxSize = 1;
        for (; !result.isComplete(); approxSize = this.nextApproxSize(approxSize)) {
            System.out.println("--- Current approx size: n=" + approxSize);
            final Bounds currentBounds = this.createNewBounds(approxSize, result.nSampled);
            currentBounds.compute();
            System.out.println(String.valueOf(currentBounds));
            this.tryExpand(result, currentBounds, draws);
            System.out.println("Sample after trying to expand: \n" + result);
        }
        System.out.println("---");
        System.out.println("Final approx size: " + approxSize);
        return result;
    }

    /**
     * Attempts to sample every node currently on the fringe using the given bounds.
     *
     * <p>All fringe nodes are first marked as unattempted. Nodes that become part of the fringe
     * because a parent was just sampled are attempted in the same pass. Each attempt is traced
     * on standard output, whether or not it succeeded.
     *
     * @param sample the partial sample to extend
     * @param bounds bounds computed for the current approximation size
     * @param uniforms one uniform number per node, indexed by node id
     */
    private void tryExpand(Sample sample, Bounds bounds, double[] uniforms) {
        sample.signalNewBoundComputed();
        while (sample.hasUnattemptedFringeItem()) {
            final int node = sample.attemptNextFringeItem();
            final SampleExpander attempt = new SampleExpander(sample, bounds, node, uniforms[node]);
            attempt.tryExpand();
            System.out.println(attempt);
        }
    }

    /**
     * Returns the approximation size to use after {@code cSize}: twice the current one.
     *
     * <p>NOTE(review): plain {@code int} arithmetic, so the result wraps around once
     * {@code cSize} exceeds {@code Integer.MAX_VALUE / 2}.
     *
     * @param cSize current approximation size
     * @return {@code cSize} doubled
     */
    private int nextApproxSize(int cSize) {
        final int doubled = cSize + cSize;
        return doubled;
    }

    /**
     * Returns the number of states of node {@code t}.
     *
     * <p>For a non-root node this is the number of next states of the kernel leading into it.
     * The root has no kernel of its own, so the number of current states of its first child's
     * kernel is used instead.
     *
     * @param t node id
     * @return the size of the state space of node {@code t}
     */
    private int nStates(int t) {
        if (t != this.root) {
            return this.trans[t].nNextStates();
        }
        final int firstChild = this.children(t).get(0);
        return this.trans[firstChild].nCurrentStates();
    }

    /**
     * Creates a not-yet-computed bounds object of the kind selected at construction time.
     *
     * @param N2 approximation size (number of leading states kept per node)
     * @param nSampled number of nodes already sampled
     * @return a {@code SecondOrderBounds} when second-order bounds are enabled, otherwise a
     *     {@code FirstOrderBounds}
     */
    private Bounds createNewBounds(int N2, int nSampled) {
        if (this.useSecondOrderBounds) {
            return new SecondOrderBounds(N2, nSampled);
        }
        return new FirstOrderBounds(N2, nSampled);
    }

    /**
     * Draws one uniform number per node, in increasing node-id order.
     *
     * @param rand random source
     * @return an array of {@code length} values obtained from {@code rand.nextDouble()}
     */
    private double[] uniform(Random rand) {
        final double[] draws = new double[this.length];
        int node = 0;
        while (node < this.length) {
            draws[node] = rand.nextDouble();
            node++;
        }
        return draws;
    }

    /**
     * Bounds that only look at the first {@code N} states of every node.
     *
     * <p>Both tables are indexed {@code [node][state]} and are {@code N} columns wide. The
     * probability that a transition leaves the truncated state range ({@code escapePr}) is added
     * to the {@code U} entries and left out of the {@code L} entries.
     */
    private class FirstOrderBounds
    extends Bounds {
        protected FirstOrderBounds(int N2, int nSampled) {
            super(N2, nSampled);
        }

        /**
         * Allocates fresh tables and fills them from the highest node id downwards, stopping at
         * node id {@code nSampled - 1}.
         *
         * <p>NOTE(review): this relies on every child having a larger node id than its parent and
         * on all still-unsampled nodes having an id of at least {@code nSampled - 1}; confirm
         * both hold for the trees passed in.
         */
        @Override
        public void compute() {
            this.U = new double[SandwitchSampler.this.length][this.N];
            this.L = new double[SandwitchSampler.this.length][this.N];
            this.init();
            for (int t = SandwitchSampler.this.length - 2; t >= this.nSampled - 1; --t) {
                for (int s = 0; s < this.nTruncStates(t); ++s) {
                    this.compute(t, s);
                }
            }
        }

        /** Sets both tables to 1 for every truncated state of the node with the highest id. */
        private void init() {
            int lastChainIndex = SandwitchSampler.this.length - 1;
            for (int s = 0; s < this.nTruncStates(lastChainIndex); ++s) {
                this.U[lastChainIndex][s] = 1.0;
                this.L[lastChainIndex][s] = 1.0;
            }
        }

        /**
         * Fills {@code L[currentNode][s]} and {@code U[currentNode][s]} as a product of one
         * factor per child. A node without children gets 1 in both tables.
         *
         * <p>A single-state child contributes {@code pr(s, 0)} to both tables. Any other child
         * contributes the truncated dot product of its kernel row with its own table, plus the
         * escape probability for {@code U}.
         *
         * <p>Every child multiplies into the running product. Assigning the single-state factor
         * instead would discard the factors of children visited earlier and would only be correct
         * when the single-state child happens to come first.
         */
        private void compute(int currentNode, int s) {
            double lower = 1.0;
            double upper = 1.0;
            for (int childNode : SandwitchSampler.this.children(currentNode)) {
                if (SandwitchSampler.this.nStates(childNode) == 1) {
                    double single = SandwitchSampler.this.trans[childNode].pr(s, 0);
                    lower *= single;
                    upper *= single;
                    continue;
                }
                lower *= this.transitionTruncDotProduct(childNode, s, this.L[childNode]);
                upper *= this.transitionTruncDotProduct(childNode, s, this.U[childNode]) + this.escapePr(childNode, s);
            }
            this.L[currentNode][s] = lower;
            this.U[currentNode][s] = upper;
            assert (this.L[currentNode][s] <= this.U[currentNode][s]);
        }

        @Override
        public String toString() {
            return "L table:\n" + this.boundsToString(this.L) + "U table:\n" + this.boundsToString(this.U);
        }

        /**
         * Lower bound for {@code childState} at {@code childNode} given the parent's state:
         * {@code pr * L[child][state]} divided by the escape probability plus the truncated dot
         * product with the {@code U} table. Single-state children always return 1.
         */
        @Override
        public double postLowerBound(int currentNode, int currentState, int childNode, int childState) {
            if (SandwitchSampler.this.nStates(childNode) == 1) {
                return 1.0;
            }
            return SandwitchSampler.this.trans[childNode].pr(currentState, childState) * this.L[childNode][childState] / (this.escapePr(childNode, currentState) + this.transitionTruncDotProduct(childNode, currentState, this.U[childNode]));
        }

        /**
         * Upper bound for {@code childState} at {@code childNode} given the parent's state:
         * {@code pr * U[child][state]} divided by the truncated dot product with the {@code L}
         * table. Single-state children always return 1.
         */
        @Override
        public double postUpperBound(int currentNode, int currentState, int childNode, int childState) {
            if (SandwitchSampler.this.nStates(childNode) == 1) {
                return 1.0;
            }
            return SandwitchSampler.this.trans[childNode].pr(currentState, childState) * this.U[childNode][childState] / this.transitionTruncDotProduct(childNode, currentState, this.L[childNode]);
        }
    }

    /**
     * Bounds that keep a table entry for every state of every node.
     *
     * <p>Rows for states inside the truncated range are computed from the full child row. Rows
     * for states outside it replace the part of the sum beyond the child's truncation by the
     * escape probability times the child's extreme value: {@code Min} for the {@code L} table and
     * {@code Max} for the {@code U} table.
     */
    private class SecondOrderBounds
    extends Bounds {
        // Min[t] / Max[t]: smallest L / largest U value among the states of node t that lie
        // beyond the truncated range (infinite when there are no such states).
        private double[] Min;
        private double[] Max;

        protected SecondOrderBounds(int N2, int nSampled) {
            super(N2, nSampled);
        }

        /**
         * Allocates fresh tables and fills them from the highest node id downwards, stopping at
         * node id {@code nSampled - 1}.
         *
         * <p>NOTE(review): this relies on every child having a larger node id than its parent;
         * confirm for the trees passed in.
         */
        @Override
        public void compute() {
            final int nNodes = SandwitchSampler.this.length;
            this.L = new double[nNodes][];
            this.U = new double[nNodes][];
            this.Min = new double[nNodes];
            this.Max = new double[nNodes];
            this.init();
            for (int t = nNodes - 2; t >= this.nSampled - 1; --t) {
                this.computeLU(t, this.L, this.Min);
                this.computeLU(t, this.U, this.Max);
                this.Min[t] = this.computeExtr(t, this.L, false);
                this.Max[t] = this.computeExtr(t, this.U, true);
            }
        }

        /**
         * Sizes every row to the node's full state count, sets the highest-id node's rows to 1,
         * and computes that node's extremes.
         */
        private void init() {
            final int nNodes = SandwitchSampler.this.length;
            for (int t = 0; t < nNodes; ++t) {
                this.L[t] = new double[SandwitchSampler.this.nStates(t)];
                this.U[t] = new double[SandwitchSampler.this.nStates(t)];
            }
            final int last = nNodes - 1;
            for (int s = 0; s < SandwitchSampler.this.nStates(last); ++s) {
                this.L[last][s] = 1.0;
                this.U[last][s] = 1.0;
            }
            this.Min[last] = this.computeExtr(last, this.L, false);
            this.Max[last] = this.computeExtr(last, this.U, true);
        }

        /**
         * Fills row {@code currentNode} of {@code LU} as a product of one factor per child.
         *
         * @param currentNode node whose row is computed
         * @param LU table to fill ({@code L} or {@code U}); child rows must already be filled
         * @param Extr per-node extreme matching {@code LU} ({@code Min} for L, {@code Max} for U)
         * @throws RuntimeException if a computed entry is negative or NaN
         */
        private void computeLU(int currentNode, double[][] LU, double[] Extr) {
            for (int s = 0; s < SandwitchSampler.this.nStates(currentNode); ++s) {
                double product = 1.0;
                for (int childNode : SandwitchSampler.this.children(currentNode)) {
                    if (SandwitchSampler.this.nStates(childNode) == 1) {
                        // Multiply rather than assign, so factors of children visited earlier
                        // are not discarded.
                        product *= SandwitchSampler.this.trans[childNode].pr(s, 0);
                        continue;
                    }
                    double sum = this.transitionTruncDotProduct(childNode, s, LU[childNode]);
                    if (s < this.nTruncStates(currentNode)) {
                        // Complete the sum exactly. The truncated dot product covered the child's
                        // states below nTruncStates(childNode), so the tail starts at the CHILD's
                        // truncation point; using the parent's would double-count or skip states
                        // whenever the two nodes have different state counts.
                        for (int y = this.nTruncStates(childNode); y < SandwitchSampler.this.nStates(childNode); ++y) {
                            sum += SandwitchSampler.this.trans[childNode].pr(s, y) * LU[childNode][y];
                        }
                    } else {
                        // An infinite extreme means the child has no state beyond its truncation;
                        // skip the term instead of multiplying by infinity.
                        sum += !Double.isInfinite(Extr[childNode]) ? this.escapePr(childNode, s) * Extr[childNode] : 0.0;
                    }
                    product *= sum;
                }
                LU[currentNode][s] = product;
                if (!(product >= 0.0)) {
                    throw new RuntimeException("Bad LU: " + LU[currentNode][s]);
                }
            }
        }

        /**
         * Returns the extreme of {@code LU[t]} over the states beyond the truncated range.
         *
         * @param useMax {@code true} for the maximum, {@code false} for the minimum
         * @return the extreme, or negative/positive infinity (for max/min) when no state lies
         *     beyond the truncation
         */
        private double computeExtr(int t, double[][] LU, boolean useMax) {
            double extr = useMax ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
            for (int s = this.nTruncStates(t); s < SandwitchSampler.this.nStates(t); ++s) {
                double candidate = LU[t][s];
                if (useMax ? candidate > extr : candidate < extr) {
                    extr = candidate;
                }
            }
            return extr;
        }

        @Override
        public String toString() {
            return "L table:\n" + this.boundsToString(this.L) + "U table:\n" + this.boundsToString(this.U);
        }

        @Override
        public double postUpperBound(int currentNode, int currentState, int childNode, int childState) {
            return this.postUpperBound(currentNode, currentState, childNode, childState, this.U, this.L);
        }

        /** Same ratio as the upper bound, with the roles of the two tables swapped. */
        @Override
        public double postLowerBound(int currentNode, int currentState, int childNode, int childState) {
            return this.postUpperBound(currentNode, currentState, childNode, childState, this.L, this.U);
        }

        /**
         * Computes {@code pr(currentState, childState) * U[childNode][childState]} divided by the
         * full dot product of the kernel row with {@code L[childNode]}. Single-state children
         * always return 1.
         *
         * @param U table used in the numerator
         * @param L table used in the denominator
         */
        public double postUpperBound(int currentNode, int currentState, int childNode, int childState, double[][] U, double[][] L) {
            if (SandwitchSampler.this.nStates(childNode) == 1) {
                return 1.0;
            }
            return SandwitchSampler.this.trans[childNode].pr(currentState, childState) * U[childNode][childState] / this.transitionFullDotProduct(childNode, currentState, L[childNode]);
        }

        /** Dot product of row {@code s1} of node {@code t}'s kernel with {@code vector}, over all states of {@code t}. */
        protected double transitionFullDotProduct(int t, int s1, double[] vector) {
            double sum = 0.0;
            for (int s2 = 0; s2 < SandwitchSampler.this.nStates(t); ++s2) {
                sum += SandwitchSampler.this.trans[t].pr(s1, s2) * vector[s2];
            }
            return sum;
        }
    }

    /**
     * Base class for the two kinds of bound tables.
     *
     * <p>Holds the approximation size {@code N}, the number of nodes already sampled when the
     * bounds were created, and the {@code L} and {@code U} tables that subclasses allocate in
     * {@link #compute()}.
     */
    private abstract class Bounds {
        protected final int N;
        protected final int nSampled;
        protected double[][] L;
        protected double[][] U;

        protected Bounds(int N2, int nSampled) {
            this.nSampled = nSampled;
            this.N = N2;
        }

        /** Allocates and fills the {@code L} and {@code U} tables. */
        public abstract void compute();

        /**
         * Lower bound for a child state given its parent's state.
         *
         * @param var1 parent node id
         * @param var2 state of the parent node
         * @param var3 child node id
         * @param var4 candidate state of the child node
         */
        public abstract double postLowerBound(int var1, int var2, int var3, int var4);

        /**
         * Upper bound for a child state given its parent's state.
         *
         * @param var1 parent node id
         * @param var2 state of the parent node
         * @param var3 child node id
         * @param var4 candidate state of the child node
         */
        public abstract double postUpperBound(int var1, int var2, int var3, int var4);

        /**
         * Returns one minus the transition probabilities from {@code s1} into the truncated
         * states of node {@code t}.
         */
        protected double escapePr(int t, int s1) {
            double remaining = 1.0;
            int target = 0;
            while (target < this.nTruncStates(t)) {
                remaining -= SandwitchSampler.this.trans[t].pr(s1, target);
                target++;
            }
            return remaining;
        }

        /**
         * Dot product of row {@code s1} of node {@code t}'s kernel with {@code vector}, over the
         * truncated states of {@code t} only.
         */
        protected double transitionTruncDotProduct(int t, int s1, double[] vector) {
            double total = 0.0;
            int target = 0;
            while (target < this.nTruncStates(t)) {
                total += SandwitchSampler.this.trans[t].pr(s1, target) * vector[target];
                target++;
            }
            return total;
        }

        /**
         * Renders a table with one column per node and one row per state, preceded by a header
         * row of node ids and a header column of state indices.
         */
        protected String boundsToString(final double[][] bounds) {
            Table.Populator filler = new Table.Populator(){

                @Override
                public void populate() {
                    this.set(0, 0, "s V / t ->");
                    // Header row first, then the cells column by column.
                    for (int node = 0; node < SandwitchSampler.this.length; node++) {
                        this.set(0, node + 1, String.valueOf(node));
                    }
                    for (int node = 0; node < SandwitchSampler.this.length; node++) {
                        for (int state = 0; state < bounds[node].length; state++) {
                            this.set(state + 1, 0, String.valueOf(state));
                            this.set(state + 1, node + 1, bounds[node][state]);
                        }
                    }
                }
            };
            Table rendered = new Table(filler);
            rendered.setBorder(true);
            return rendered.toString();
        }

        /** Number of states of node {@code t} that are kept: the smaller of {@code N} and the node's state count. */
        public int nTruncStates(int t) {
            final int full = SandwitchSampler.this.nStates(t);
            return Math.min(this.N, full);
        }
    }

    public class Sample {
        private final int[] samples;
        public static final int UNSAMPLED = -1;
        private int nSampled;
        private Map<Integer, Boolean> fringe;

        private boolean isComplete() {
            return this.nSampled == SandwitchSampler.this.length;
        }

        public Sample() {
            this.samples = new int[SandwitchSampler.this.length];
            this.nSampled = 0;
            this.fringe = new HashMap<Integer, Boolean>();
            for (int t = 0; t < SandwitchSampler.this.length; ++t) {
                this.samples[t] = -1;
            }
            this.fringe.put(SandwitchSampler.this.root, true);
            this.expand(SandwitchSampler.this.root, 0);
        }

        private void expand(int node, int s) {
            this.samples[node] = s;
            ++this.nSampled;
            if (!this.fringe.containsKey(node)) {
                throw new RuntimeException();
            }
            this.fringe.remove(node);
            Iterator iterator = SandwitchSampler.this.children(node).iterator();
            while (iterator.hasNext()) {
                int childNode = (Integer)iterator.next();
                this.fringe.put(childNode, false);
            }
        }

        public boolean hasUnattemptedFringeItem() {
            for (boolean wasAttempted : this.fringe.values()) {
                if (wasAttempted) continue;
                return true;
            }
            return false;
        }

        public int attemptNextFringeItem() {
            int result = -1;
            for (int node : this.fringe.keySet()) {
                boolean wasAttempted = this.fringe.get(node);
                if (wasAttempted) continue;
                result = node;
                break;
            }
            if (result == -1) {
                throw new RuntimeException();
            }
            this.fringe.put(result, true);
            return result;
        }

        public void signalNewBoundComputed() {
            for (int node : this.fringe.keySet()) {
                this.fringe.put(node, false);
            }
        }

        public int[] getLegacyHMMHiddenStates() {
            if (SandwitchSampler.this.length % 2 != 0) {
                throw new RuntimeException();
            }
            int[] result = new int[SandwitchSampler.this.length];
            int currentState = SandwitchSampler.this.root;
            for (int i = 0; i < SandwitchSampler.this.length / 2; ++i) {
                result[i] = this.samples[currentState];
                if (i == SandwitchSampler.this.length / 2 - 1) continue;
                currentState = (Integer)SandwitchSampler.this.children(currentState).get(1);
            }
            return result;
        }

        public String toString() {
            return this.toArbre(SandwitchSampler.this.root).deepToString().replaceAll("-1", "??");
        }

        public Arbre<Integer> toArbre(int node) {
            int contents = this.samples[node];
            ArrayList children = new ArrayList();
            Iterator iterator = SandwitchSampler.this.children(node).iterator();
            while (iterator.hasNext()) {
                int childNode = (Integer)iterator.next();
                children.add(this.toArbre(childNode));
            }
            return new Arbre<Integer>(Integer.valueOf(contents), children);
        }
    }

    /**
     * One attempt at sampling the state of a single fringe node.
     *
     * <p>The node's uniform number is compared with running sums of the lower and upper bounds
     * over the node's truncated states; a state is accepted only when the bounds are tight
     * enough (up to {@code slack}) to decide.
     */
    private class SampleExpander {
        private final Sample sample;
        private final Bounds bounds;
        private final int t;
        private final int parentNode;
        private double uniformNumber;
        private final double[] lowerBoundPartialSums;
        private final double[] upperBoundPartialSums;
        private final int previousState;

        /**
         * @param sample partial sample to extend
         * @param bounds bounds to decide with
         * @param t id of the node to sample
         * @param uni uniform number assigned to node {@code t}
         */
        private SampleExpander(Sample sample, Bounds bounds, int t, double uni) {
            this.sample = sample;
            this.bounds = bounds;
            this.t = t;
            this.uniformNumber = uni;
            this.parentNode = SandwitchSampler.this.parent(t);
            this.lowerBoundPartialSums = new double[bounds.nTruncStates(t)];
            this.upperBoundPartialSums = new double[bounds.nTruncStates(t)];
            this.previousState = sample.samples[this.parentNode];
        }

        /** Fills both partial-sum arrays with running totals of the per-state bounds. */
        private void computePartialSums() {
            double lowerTotal = this.bounds.postLowerBound(this.parentNode, this.previousState, this.t, 0);
            this.lowerBoundPartialSums[0] = lowerTotal;
            double upperTotal = this.bounds.postUpperBound(this.parentNode, this.previousState, this.t, 0);
            this.upperBoundPartialSums[0] = upperTotal;
            int state = 1;
            while (state < this.bounds.nTruncStates(this.t)) {
                lowerTotal += this.bounds.postLowerBound(this.parentNode, this.previousState, this.t, state);
                this.lowerBoundPartialSums[state] = lowerTotal;
                upperTotal += this.bounds.postUpperBound(this.parentNode, this.previousState, this.t, state);
                this.upperBoundPartialSums[state] = upperTotal;
                state++;
            }
        }

        /**
         * Tries the truncated states in increasing order and expands the sample with the first
         * one that is accepted.
         *
         * @return whether a state was accepted
         */
        private boolean tryExpand() {
            this.computePartialSums();
            int candidate = 0;
            while (candidate < this.bounds.nTruncStates(this.t)) {
                if (this.tryExpand(candidate)) {
                    return true;
                }
                candidate++;
            }
            return false;
        }

        /**
         * Accepts state {@code s} when the uniform number is at most the lower partial sum at
         * {@code s} plus the slack, and greater than every earlier upper partial sum minus the
         * slack.
         */
        private boolean tryExpand(int s) {
            final double tolerance = SandwitchSampler.this.slack;
            if (this.uniformNumber > this.lowerBoundPartialSums[s] + tolerance) {
                return false;
            }
            for (int earlier = 0; earlier < s; earlier++) {
                // The draw may still belong to an earlier state: not decidable yet.
                if (this.uniformNumber <= this.upperBoundPartialSums[earlier] - tolerance) {
                    return false;
                }
            }
            this.sample.expand(this.t, s);
            return true;
        }

        /** Renders the per-state bounds, their gaps and their partial sums as a bordered table. */
        @Override
        public String toString() {
            Table.Populator filler = new Table.Populator(){

                @Override
                public void populate() {
                    this.set(0, 0, "norm L");
                    this.set(0, 1, "norm U");
                    this.set(0, 2, "delta");
                    this.set(0, 3, "sum norm L");
                    this.set(0, 4, "sum norm U");
                    this.set(0, 5, "delta");
                    int state = 0;
                    while (state < SampleExpander.this.bounds.nTruncStates(SampleExpander.this.t)) {
                        final int row = state + 1;
                        double lower = SampleExpander.this.bounds.postLowerBound(SampleExpander.this.parentNode, SampleExpander.this.previousState, SampleExpander.this.t, state);
                        double upper = SampleExpander.this.bounds.postUpperBound(SampleExpander.this.parentNode, SampleExpander.this.previousState, SampleExpander.this.t, state);
                        this.set(row, 0, lower);
                        this.set(row, 1, upper);
                        this.set(row, 2, upper - lower);
                        this.set(row, 3, SampleExpander.this.lowerBoundPartialSums[state]);
                        this.set(row, 4, SampleExpander.this.upperBoundPartialSums[state]);
                        this.set(row, 5, SampleExpander.this.upperBoundPartialSums[state] - SampleExpander.this.lowerBoundPartialSums[state]);
                        state++;
                    }
                }
            };
            Table rendered = new Table(filler);
            rendered.setBorder(true);
            return "Expanding t=" + this.t + ", uniform: " + this.uniformNumber + "\n" + rendered.toString();
        }
    }

    /**
     * Adapts an observed emission kernel to the {@link TransitionKernel} interface.
     *
     * <p>The adapted kernel has a single next state (index 0), and the probability of moving into
     * it from a given current state is the emission kernel's {@code pr} for that state.
     */
    private static class TransitionFromObservationKernel
    implements TransitionKernel {
        private final ObservedEmissionKernel oek;

        private TransitionFromObservationKernel(ObservedEmissionKernel oek) {
            this.oek = oek;
        }

        @Override
        public int nCurrentStates() {
            return this.oek.nStates();
        }

        @Override
        public int nNextStates() {
            return 1;
        }

        /**
         * @param currentState state of the emitting node
         * @param nextState must be 0, the only next state this kernel has
         * @return {@code oek.pr(currentState)}
         * @throws IllegalArgumentException if {@code nextState} is not 0 (negative indices
         *     included)
         */
        @Override
        public double pr(int currentState, int nextState) {
            if (nextState != 0) {
                throw new IllegalArgumentException("An observation kernel has a single next state (0), got " + nextState);
            }
            return this.oek.pr(currentState);
        }

        /**
         * Not supported: this adapter does not draw samples.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public int sample(int currentState, Random rand) {
            throw new UnsupportedOperationException("Sampling from an observation kernel is not supported");
        }
    }
}

