/*
 * Decompiled with CFR 0.152.
 */
package nuts.math;

import fig.basic.Pair;
import hmm.Param;
import hmm.ParamUtils;
import hmm.RescaledBaumWelch;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.math.CliqueVisitor;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.Graph;
import nuts.math.Graphs;
import nuts.math.HashGraph;
import nuts.math.TabularGMFct;
import nuts.util.CollUtils;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import nuts.util.Tree;

/**
 * Exact sum-product (inside/outside) inference for a pairwise model given as a {@link GMFct}
 * whose graph is converted to a tree with {@link Graphs#toTree}.
 *
 * <p>The converted tree is hung below an artificial root node (created with the no-arg
 * {@link Tree} constructor) that has exactly one state and no potential. Inside vectors are
 * computed bottom-up and outside vectors top-down. Every vector is normalized and its
 * normalizer is kept in {@code insideScale}/{@code outsideScale}. Presumably this is to avoid
 * underflow, as in the rescaled Baum-Welch algorithm the {@code main} driver compares against.
 * The log partition function is the sum of the logs of the inside normalizers.
 *
 * <p>All computation happens in the constructor. No method mutates state afterwards.
 *
 * @param <V> vertex label type
 */
public final class TreeSumProd<V extends Comparable<V>> {
    /** Value stored in {@link #parentPtrs} for the artificial root, which has no parent. */
    private static final int NO_PARENT = -1000;

    private final GMFct<V> potentials;
    /**
     * Index of the artificial root. It is the first node added by {@link #index} (pre-order);
     * the hard-coded 0 assumes {@link Indexer} hands out indices 0, 1, 2, ... in insertion order.
     */
    private final int root;
    /** Number of nodes in the augmented tree, artificial root included. */
    private final int nVars;
    /** parentPtrs[i] is the index of node i's parent ({@link #NO_PARENT} for the root). */
    private final int[] parentPtrs;
    /** childrenPtrs[i] holds the indices of node i's children. */
    private final int[][] childrenPtrs;
    /** inside[i][x]: normalized inside vector of node i, indexed by the state x of i. */
    private final double[][] inside;
    /** outside[i][z]: normalized outside vector of node i, indexed by the state z of i's parent. */
    private final double[][] outside;
    /** insideScale[i]: the sum that inside[i] was divided by. */
    private final double[] insideScale;
    /** outsideScale[i]: the sum that outside[i] was divided by. */
    private final double[] outsideScale;
    /** Bijection between vertex labels and dense node indices. */
    private final Indexer<V> indexer;
    /** Per-node correction factor that turns products of scaled vectors back into posteriors. */
    private final double[] scalingFactor;
    /** Log partition function, set once by {@link #ensureDPEntries()}. */
    private double _logZ;

    /**
     * Computes node and edge posteriors for a model by running inference separately on each
     * piece returned by {@link GMFctUtils#split} (presumably its connected components) and
     * copying every piece's moments into a single table.
     *
     * @param potentials the model; each piece is converted with {@link Graphs#toTree}
     * @return a table over {@code potentials.graph()} whose vertex entries are node posteriors
     *     and whose edge entries are pairwise posteriors
     */
    public static <V extends Comparable<V>> TabularGMFct<V> computeMoments(GMFct<V> potentials) {
        final List<GMFct<V>> components = GMFctUtils.split(potentials);
        final TabularGMFct<V> result =
                new TabularGMFct<V>(potentials.graph(), GMFctUtils.domain(potentials));
        for (GMFct<V> component : components) {
            final TabularGMFct<V> componentMoments = new TreeSumProd<V>(component).moments();
            GMFctUtils.visit(componentMoments, new CliqueVisitor<V>() {
                @Override
                public void visitEdge(V n1, V n2, int s1, int s2) {
                    result.set(n1, n2, s1, s2, componentMoments.get(n1, n2, s1, s2));
                }

                @Override
                public void visitVertex(V n, int s) {
                    result.set(n, s, componentMoments.get(n, s));
                }
            });
        }
        return result;
    }

    /** Number of states of node {@code node}; the artificial root has exactly one. */
    private int nStates(int node) {
        if (node == this.root) {
            return 1;
        }
        return this.potentials.nStates(this.indexer.i2o(node));
    }

    /** Returns the number of nodes in the augmented tree, including the artificial root. */
    public int nVars() {
        return this.nVars;
    }

    /**
     * Builds the augmented tree and runs the full inside/outside pass.
     *
     * @param potentials the model; validated with {@link GMFctUtils#checkIsPotential}
     * @throws IllegalArgumentException if some message vector sums to exactly zero
     */
    public TreeSumProd(GMFct<V> potentials) {
        GMFctUtils.checkIsPotential(potentials);
        this.potentials = potentials;
        this.indexer = new Indexer<V>();
        // Artificial root whose single child is the model's own tree.
        final Tree<V> tree = new Tree<V>();
        tree.getChildren().add(Graphs.toTree(potentials.graph()));
        this.nVars = tree.getPostOrderTraversal().size();
        this.parentPtrs = new int[this.nVars];
        this.childrenPtrs = new int[this.nVars][];
        this.root = 0;
        this.inside = new double[this.nVars][];
        this.outside = new double[this.nVars][];
        this.insideScale = new double[this.nVars];
        this.outsideScale = new double[this.nVars];
        this.scalingFactor = new double[this.nVars];
        this.init(tree);
        this.ensureDPEntries();
    }

    /** Indexes all labels, then fills parent and child pointer tables. */
    private void init(Tree<V> tree) {
        this.index(tree);
        this.parentPtrs(tree);
        this.childrenPtrs(tree);
    }

    /** Label-based variant of {@link #_edgePosterior(int, int, int)}. */
    private double edgePosterior(V _i, int x, int z) {
        return this._edgePosterior(this.indexer.o2i(_i), x, z);
    }

    /**
     * Returns a table of posteriors. Vertex entries are {@link #vertexPosterior}. An edge
     * entry {@code get(n1, n2, s1, s2)} is the joint posterior of n1 in s1 and n2 in s2,
     * whichever of the two is the child in the internal tree.
     */
    public TabularGMFct<V> moments() {
        return new TabularGMFct<V>(new GMFct<V>() {
            @Override
            public double get(V n1, V n2, int s1, int s2) {
                final int n1Code = TreeSumProd.this.indexer.o2i(n1);
                final int n2Code = TreeSumProd.this.indexer.o2i(n2);
                // Edge posteriors are keyed by the child node; orient the query accordingly.
                if (n1Code == TreeSumProd.this.root || TreeSumProd.this.parentPtrs[n2Code] == n1Code) {
                    return TreeSumProd.this._edgePosterior(n2Code, s2, s1);
                }
                return TreeSumProd.this._edgePosterior(n1Code, s1, s2);
            }

            @Override
            public double get(V n, int s) {
                return TreeSumProd.this.vertexPosterior(n, s);
            }

            @Override
            public Graph<V> graph() {
                return TreeSumProd.this.potentials.graph();
            }

            @Override
            public int nStates(V node) {
                return TreeSumProd.this.potentials.nStates(node);
            }
        });
    }

    /**
     * Posterior probability that node {@code i} is in state {@code x} and its parent is in
     * state {@code z}, reconstructed from the scaled inside/outside vectors.
     *
     * @throws IllegalStateException if {@code i} is the artificial root (it has no parent)
     */
    private double _edgePosterior(int i, int x, int z) {
        if (i == this.root) {
            throw new IllegalStateException("the artificial root has no parent edge");
        }
        return this.inside(i, x) * this.outside(i, z) * this.pot(i, x, z)
                * this.scalingFactor(i) * this.insideScale[i];
    }

    /** Posterior of node {@code i} in state {@code x}: edge posterior marginalized over the parent. */
    private double _vertexPosterior(int i, int x) {
        if (i == this.root) {
            return this.inside(this.root, x);
        }
        double result = 0.0;
        final int parentStates = this.nStates(this.parentPtrs[i]);
        for (int z = 0; z < parentStates; ++z) {
            result += this._edgePosterior(i, x, z);
        }
        return result;
    }

    /** Returns the posterior probability that vertex {@code _i} is in state {@code x}. */
    public double vertexPosterior(V _i, int x) {
        return this._vertexPosterior(this.indexer.o2i(_i), x);
    }

    /** Registers every label in pre-order, so the artificial root is added first. */
    private void index(Tree<V> tree) {
        for (Tree<V> node : tree.getPreOrderTraversal()) {
            this.indexer.addToIndex(node.getLabel());
        }
    }

    /**
     * Local factor attached to {@code node}: its node potential times the edge potential to its
     * parent. The edge factor is omitted when the parent is the artificial root.
     */
    private double pot(int node, int stateAtNode, int stateAtParentOfNode) {
        if (node == this.root) {
            throw new IllegalStateException("the artificial root has no potential");
        }
        final V label = this.indexer.i2o(node);
        double result = this.potentials.get(label, stateAtNode);
        final int parent = this.parentPtrs[node];
        if (parent != this.root) {
            result *= this.potentials.get(label, this.indexer.i2o(parent), stateAtNode, stateAtParentOfNode);
        }
        return result;
    }

    /** Recursively records, for each node of {@code tree}, the indices of its children. */
    private void childrenPtrs(Tree<V> tree) {
        final List<Tree<V>> children = tree.getChildren();
        final int[] current = new int[children.size()];
        this.childrenPtrs[this.indexer.o2i(tree.getLabel())] = current;
        for (int i = 0; i < children.size(); ++i) {
            final Tree<V> child = children.get(i);
            current[i] = this.indexer.o2i(child.getLabel());
            this.childrenPtrs(child);
        }
    }

    /** Fills {@link #parentPtrs}; the root of {@code root} gets {@link #NO_PARENT}. */
    private void parentPtrs(Tree<V> root) {
        this.parentPtrs[this.indexer.o2i(root.getLabel())] = NO_PARENT;
        this._parentPtrs(root, null);
    }

    /** Recursive helper: {@code parent} is the label of {@code tree}'s parent, or null at the top. */
    private void _parentPtrs(Tree<V> tree, V parent) {
        final V current = tree.getLabel();
        if (parent != null) {
            this.parentPtrs[this.indexer.o2i(current)] = this.indexer.o2i(parent);
        }
        for (Tree<V> child : tree.getChildren()) {
            this._parentPtrs(child, current);
        }
    }

    private double scalingFactor(int p) {
        return this.scalingFactor[p];
    }

    /** Propagates scaling factors top-down, starting from the artificial root. */
    private void ensureScalingFactors() {
        this.scalingFactor[this.root] = 1.0 / this.insideScale[this.root];
        for (int child : this.childrenPtrs[this.root]) {
            this.ensureScalingFactors(child);
        }
    }

    /** scalingFactor[c] = scalingFactor[parent] * outsideScale[c] / insideScale[c], recursively. */
    private void ensureScalingFactors(int current) {
        this.scalingFactor[current] = this.scalingFactor[this.parentPtrs[current]]
                * this.outsideScale[current] / this.insideScale[current];
        for (int child : this.childrenPtrs[current]) {
            this.ensureScalingFactors(child);
        }
    }

    /** Returns the natural log of the partition function. */
    public double logZ() {
        return this._logZ;
    }

    /** Sum over all nodes of log(insideScale). */
    private double computeLogZ() {
        double result = 0.0;
        for (double scale : this.insideScale) {
            result += Math.log(scale);
        }
        return result;
    }

    private double inside(int v, int s) {
        return this.inside[v][s];
    }

    private double outside(int v, int s) {
        if (s >= this.outside[v].length) {
            throw new IllegalStateException(
                    "outside state " + s + " out of range for node " + v + " (" + this.outside[v].length + " states)");
        }
        return this.outside[v][s];
    }

    /**
     * Log of the partition function of the part of the tree on {@code subtreeRoot}'s side of
     * the edge {subtreeRoot, nbhrOutsideTree}, with {@code subtreeRoot} clamped to
     * {@code state}. That part includes subtreeRoot's node potential but not the edge potential.
     * {@link #testLogSubTree} checks that these values combine across each edge to Z. Assumes
     * the two vertices are neighbours (not checked).
     */
    public double logMessage(int state, V subtreeRoot, V nbhrOutsideTree) {
        final int subtreeRootIdx = this.indexer.o2i(subtreeRoot);
        final int nbhrOutsideTreeIdx = this.indexer.o2i(nbhrOutsideTree);
        if (this.parentPtrs[subtreeRootIdx] == nbhrOutsideTreeIdx) {
            // subtreeRoot is the child: the wanted quantity is its own (rescaled) inside value.
            return this._insideSubTreeLogZ(subtreeRoot, subtreeRootIdx, state);
        }
        // nbhrOutsideTree is the child. Take the total mass with this edge configuration and
        // divide out the child's subtree (with the child at state 0) and the edge potential.
        return this.logZ()
                + Math.log(this._edgePosterior(nbhrOutsideTreeIdx, 0, state))
                - this._insideSubTreeLogZ(nbhrOutsideTree, nbhrOutsideTreeIdx, 0)
                - Math.log(this.potentials.get(subtreeRoot, nbhrOutsideTree, state, 0));
    }

    /** Log of the unnormalized mass of all configurations in which {@code node} is in {@code state}. */
    public double logZ(V node, int state) {
        return this.logZ() + Math.log(this.vertexPosterior(node, state));
    }

    /** Log partition function of the subtree at {@code subtreeRoot}, with it clamped to {@code state}. */
    private double _insideSubTreeLogZ(V subtreeRoot, int subtreeRootIdx, int state) {
        return Math.log(this.inside(subtreeRootIdx, state))
                + this._insideSubTreeLogZ(subtreeRootIdx)
                + Math.log(this.potentials.get(subtreeRoot, state));
    }

    /** Sum of log(insideScale) over the subtree rooted at {@code subtreeRootIdx}. */
    private double _insideSubTreeLogZ(int subtreeRootIdx) {
        double result = Math.log(this.insideScale[subtreeRootIdx]);
        for (int child : this.childrenPtrs[subtreeRootIdx]) {
            result += this._insideSubTreeLogZ(child);
        }
        return result;
    }

    /** Runs the inside pass, then the outside pass, then derives scaling factors and log Z. */
    private void ensureDPEntries() {
        this.ensureInsideEntries(this.root);
        this.ensureOutsideEntries(this.root);
        this.ensureScalingFactors();
        this._logZ = this.computeLogZ();
    }

    /** Post-order: children before parents. */
    private void ensureInsideEntries(int node) {
        for (int child : this.childrenPtrs[node]) {
            this.ensureInsideEntries(child);
        }
        this.computeDPEntry(node, true);
    }

    /** Pre-order: parents before children. */
    private void ensureOutsideEntries(int node) {
        this.computeDPEntry(node, false);
        for (int child : this.childrenPtrs[node]) {
            this.ensureOutsideEntries(child);
        }
    }

    /**
     * Computes and stores the normalized inside (or outside) vector of {@code v} together with
     * its normalizer. The root has no outside vector.
     */
    private void computeDPEntry(int v, boolean isInside) {
        if (!isInside && v == this.root) {
            return;
        }
        // Outside vectors are indexed by the parent's state.
        final int nStates = isInside ? this.nStates(v) : this.nStates(this.parentPtrs[v]);
        final double[] result = new double[nStates];
        for (int s = 0; s < nStates; ++s) {
            result[s] = isInside ? this.computeInside(v, s) : this.computeOutside(v, s);
        }
        final double norm = TreeSumProd.normalize(result);
        if (isInside) {
            this.inside[v] = result;
            this.insideScale[v] = norm;
        } else {
            this.outside[v] = result;
            this.outsideScale[v] = norm;
        }
    }

    /**
     * Divides every entry of {@code data} in place by the sum of its entries.
     *
     * @return the original sum
     * @throws IllegalArgumentException if the entries sum to exactly zero
     */
    public static double normalize(double[] data) {
        double sum = 0.0;
        for (double x : data) {
            sum += x;
        }
        if (sum == 0.0) {
            throw new IllegalArgumentException("cannot normalize a vector whose entries sum to zero");
        }
        for (int i = 0; i < data.length; ++i) {
            data[i] /= sum;
        }
        return sum;
    }

    /** Message from {@code child} to its parent in state {@code parentState}: sum over child states. */
    private double upwardMessage(int child, int parentState) {
        double sum = 0.0;
        final int childStates = this.nStates(child);
        for (int y = 0; y < childStates; ++y) {
            sum += this.pot(child, y, parentState) * this.inside(child, y);
        }
        return sum;
    }

    /** Unnormalized inside value of node {@code i} in state {@code x}: product of child messages. */
    private double computeInside(int i, int x) {
        double prod = 1.0;
        for (int j : this.childrenPtrs[i]) {
            prod *= this.upwardMessage(j, x);
        }
        return prod;
    }

    /**
     * Unnormalized outside value of node {@code i} when its parent p is in state {@code x}.
     * It is the message from above p (unless p is the root) times the messages from i's siblings.
     */
    private double computeOutside(int i, int x) {
        if (i == this.root) {
            throw new IllegalStateException("the artificial root has no outside value");
        }
        final int p = this.parentPtrs[i];
        double prod = 1.0;
        if (p != this.root) {
            final int g = this.parentPtrs[p];
            double sum = 0.0;
            for (int z = 0; z < this.nStates(g); ++z) {
                sum += this.pot(p, x, z) * this.outside(p, z);
            }
            prod *= sum;
        }
        for (int j : this.childrenPtrs[p]) {
            if (j != i) {
                prod *= this.upwardMessage(j, x);
            }
        }
        return prod;
    }

    /**
     * Self-check: for every ordered edge (u, w) the two clamped sub-partition functions combine
     * with the edge potential to Z. Delegates the comparison to {@link MathUtils#checkClose}.
     */
    public static <V extends Comparable<V>> void testLogSubTree(TreeSumProd<V> tsp) {
        final double Z = Math.exp(tsp.logZ());
        final GMFct<V> pots = tsp.potentials;
        final Graph<V> g = pots.graph();
        for (V vertex : g.vertexSet()) {
            for (V other : g.nbrs(vertex)) {
                double sum = 0.0;
                for (int sver = 0; sver < pots.nStates(vertex); ++sver) {
                    for (int soth = 0; soth < pots.nStates(other); ++soth) {
                        sum += pots.get(vertex, other, sver, soth)
                                * Math.exp(tsp.logMessage(sver, vertex, other) + tsp.logMessage(soth, other, vertex));
                    }
                }
                MathUtils.checkClose(sum, Z);
            }
        }
        System.out.println("Tested log sub tree successfully");
    }

    /** Ad-hoc test driver: small chains plus a comparison against RescaledBaumWelch on an HMM. */
    public static void main(String[] args) {
        double[][] _pot = new double[][]{{1.0, 1.0}, {1.0, 1.0}};
        Graph<Integer> chain = Graphs.chainGraph(4);
        final GMFctUtils.SimpleGraphFct<Integer> pot = new GMFctUtils.SimpleGraphFct<Integer>(_pot, chain);
        final TreeSumProd<Integer> tsp = new TreeSumProd<Integer>(pot);
        System.out.println("log z=" + tsp.logZ());
        System.out.println("subt log z(1,0)=" + tsp.logMessage(0, 2, 1));
        System.out.println("subt log z(0,1)=" + tsp.logMessage(0, 1, 2));
        TreeSumProd.testLogSubTree(tsp);
        TreeSumProd.testCorresp();

        _pot = new double[][]{{2.0, 1.0}, {1.0, 1.0}};
        chain = Graphs.chainGraph(3);
        final GMFctUtils.SimpleGraphFct<Integer> pot2 = new GMFctUtils.SimpleGraphFct<Integer>(_pot, chain);
        final TreeSumProd<Integer> tsp2 = new TreeSumProd<Integer>(pot2);
        TreeSumProd.test(tsp2);
        TreeSumProd.testLogSubTree(tsp2);

        System.out.println("Testing HMM...");
        final Random rand = new Random(2L);
        final int stateSize = 2;
        final Param params = ParamUtils.randomUniParam(rand, 2, 2);
        final int length = 3;
        final Pair<List<Integer>, List<Integer>> pair = ParamUtils.generateStateObservations(rand, params, length);
        final HmmAdaptor hmmA = new HmmAdaptor(params, pair.getSecond());
        final TreeSumProd<Integer> tsp3 = new TreeSumProd<Integer>(hmmA);
        TreeSumProd.testLogSubTree(tsp3);
        final RescaledBaumWelch e = new RescaledBaumWelch();
        e.compute(pair.getSecond(), params);
        System.out.println("NEW:" + Math.exp(tsp3.logZ()));
        System.out.println("OLD:" + Math.exp(e.logll()));
        final TabularGMFct<Integer> newPost = tsp3.moments();
        for (int i = 0; i < length - 1; ++i) {
            for (int j = 0; j < stateSize; ++j) {
                for (int k = 0; k < stateSize; ++k) {
                    final double oldV = e.twoNodesPosterior(i)[j][k];
                    final double newV = newPost.get(i, i + 1, j, k);
                    final double transposed = newPost.get(i, i + 1, k, j);
                    // Report the value that was actually compared; show the transposed one for context.
                    if (!MathUtils.close(oldV, newV, 1.0E-6)) {
                        System.err.println("" + oldV + " not close to " + newV + " (other is " + transposed + ")");
                    }
                }
            }
            System.err.println("---");
        }
    }

    /** Smoke test: building the object runs the full inference pass on a 10-node chain. */
    private static void testCorresp() {
        final double[][] mu = new double[][]{{1.0, 2.0}, {2.0, 1.0}};
        final Graph<Integer> chain = Graphs.chainGraph(10);
        new TreeSumProd<Integer>(new GMFctUtils.SimpleGraphFct<Integer>(mu, chain));
        System.out.println("---");
    }

    /**
     * Sanity checks on {@code tsp}: each edge posterior sums to 1, and it marginalizes to both
     * endpoint posteriors. Mismatches go to stderr; the "PASSED" lines are printed regardless.
     */
    public static <V extends Comparable<V>> void test(TreeSumProd<V> tsp) {
        for (int node = 0; node < tsp.nVars(); ++node) {
            if (node == tsp.root) {
                continue;
            }
            final V v = tsp.indexer.i2o(node);
            final int p = tsp.parentPtrs[node];
            double sum = 0.0;
            for (int x = 0; x < tsp.nStates(node); ++x) {
                for (int y = 0; y < tsp.nStates(p); ++y) {
                    sum += tsp.edgePosterior(v, x, y);
                }
            }
            if (!MathUtils.close(1.0, sum)) {
                System.err.println("Bad norm: " + sum);
            }
        }
        System.out.println("Normalization of post... PASSED");
        for (int node = 0; node < tsp.nVars(); ++node) {
            if (node == tsp.root) {
                continue;
            }
            final int p = tsp.parentPtrs[node];
            final V v = tsp.indexer.i2o(node);
            final V w = tsp.indexer.i2o(p);
            // Marginalizing out the parent must give the child's posterior.
            for (int x = 0; x < tsp.nStates(node); ++x) {
                double sum = 0.0;
                for (int y = 0; y < tsp.nStates(p); ++y) {
                    sum += tsp.edgePosterior(v, x, y);
                }
                if (!MathUtils.close(tsp.vertexPosterior(v, x), sum)) {
                    System.err.println("" + sum + "," + tsp.vertexPosterior(v, x));
                }
            }
            // Marginalizing out the child must give the parent's posterior.
            for (int x = 0; x < tsp.nStates(p); ++x) {
                double sum = 0.0;
                for (int y = 0; y < tsp.nStates(node); ++y) {
                    sum += tsp.edgePosterior(v, y, x);
                }
                if (!MathUtils.close(tsp.vertexPosterior(w, x), sum)) {
                    System.err.println("" + sum + "," + tsp.vertexPosterior(w, x));
                }
            }
        }
        System.out.println("Local consistency... PASSED");
    }

    /**
     * Adapts an HMM ({@link Param}) and an observation sequence to a chain-structured
     * {@link GMFct} over vertices 0 .. observations.size()-1.
     *
     * <p>The node potential is the emission probability of the observation at that position,
     * times the initial probability at vertex 0. The edge potential is the transition
     * probability from the state at the lower-index vertex to the state at the higher-index one.
     */
    public static final class HmmAdaptor
    implements GMFct<Integer> {
        private final Param hmmParam;
        private final List<Integer> observations;
        /** First position of the chain. */
        private final Integer root;

        public HmmAdaptor(Param hmmParam, List<Integer> observations) {
            this.hmmParam = hmmParam;
            this.observations = observations;
            this.root = 0;
        }

        /**
         * Builds the chain 0 -> 1 -> ... -> observations.size()-1. If there are no
         * observations, the result is the single node 0.
         */
        public Tree<Integer> getChain() {
            final Tree<Integer> result = new Tree<Integer>(this.root);
            Tree<Integer> current = result;
            while (current.getLabel() < this.observations.size() - 1) {
                final Tree<Integer> child = new Tree<Integer>(current.getLabel() + 1);
                current.getChildren().add(child);
                current = current.getChildren().get(0);
            }
            return result;
        }

        @Override
        public int nStates(Integer v) {
            return this.hmmParam.nStates();
        }

        /**
         * Transition probability between adjacent positions, oriented from the lower index to
         * the higher index regardless of argument order.
         *
         * @throws IllegalArgumentException if {@code n1} and {@code n2} are the same position
         */
        @Override
        public double get(Integer n1, Integer n2, int s1, int s2) {
            final int stateAtParentOfNode;
            final int stateAtNode;
            if (n1 < n2) {
                stateAtNode = s2;
                stateAtParentOfNode = s1;
            } else if (n1 > n2) {
                stateAtNode = s1;
                stateAtParentOfNode = s2;
            } else {
                throw new IllegalArgumentException("edge potential requested for a self-loop at " + n1);
            }
            return this.hmmParam.transMtx.p(stateAtParentOfNode, stateAtNode);
        }

        /** Emission probability of the observation at {@code n}, times the initial probability at the root. */
        @Override
        public double get(Integer n, int s) {
            double result = this.hmmParam.emiMtx.p(s, this.observations.get(n));
            if (n.equals(this.root)) {
                result *= this.hmmParam.initVec.p(s);
            }
            return result;
        }

        /** Returns a freshly built graph of the chain on every call. */
        @Override
        public Graph<Integer> graph() {
            return new HashGraph<Integer>(this.getChain());
        }
    }
}

