/*
 * Decompiled with CFR 0.152.
 */
package nuts.math;

import fig.basic.NumUtils;
import fig.basic.UnorderedPair;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.math.CliqueVisitor;
import nuts.math.EqClasses;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.Graph;
import nuts.math.Graphs;
import nuts.math.HashGraph;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.util.Tree;

/**
 * Structured mean-field approximate inference for a pairwise graphical model.
 *
 * <p>The model is described by {@code naturalParams}: log-space vertex and edge
 * parameters over {@code baseGraph}. The approximating distribution is
 * represented by vertex/edge moments tabulated over {@code subGraph}. This graph
 * must have exactly the same vertex set as the model graph. Its connected
 * components are the blocks of the structured approximation.
 *
 * <p>Each update folds the influence of base-graph edges that are missing from
 * {@code subGraph} into modified potentials ({@link ModifiedPotentials}). It then
 * re-runs {@link TreeSumProd} on every component being updated.
 *
 * <p>NOTE(review): the code uses {@code Graphs.pathInTree} and {@code TreeSumProd}
 * on {@code subGraph}, so each component presumably has to be a tree; confirm.
 * No synchronization is performed; updates replace {@code currentMoments}.
 */
public final class StructMeanField<V extends Comparable<V>> {
    /** Graph of the full model (the graph of {@code naturalParams}). */
    private final Graph<V> baseGraph;
    /** Tractable subgraph over the same vertices; its components define the approximation. */
    private final Graph<V> subGraph;
    /** Log-space parameters of the full model. */
    private final GMFct<V> naturalParams;
    /** Current moments of the approximation, tabulated over {@code subGraph}. */
    private TabularGMFct<V> currentMoments = null;
    /** Connected components of {@code subGraph}. */
    private final EqClasses<V> subGraphCCs;

    /**
     * Creates an approximation whose moments start from a random, fully
     * factorized initialization (see {@link #init(Random)}) seeded with 1.
     *
     * @param naturalParameters log-space parameters of the full model
     * @param subGraph tractable subgraph; must have the same vertex set as
     *     {@code naturalParameters.graph()}
     * @throws IllegalArgumentException if the vertex sets differ
     */
    public StructMeanField(GMFct<V> naturalParameters, Graph<V> subGraph) {
        if (!naturalParameters.graph().vertexSet().equals(subGraph.vertexSet())) {
            throw new IllegalArgumentException(
                "subGraph must have the same vertex set as the model graph (model has "
                    + naturalParameters.graph().vertexSet().size() + " vertices, subGraph has "
                    + subGraph.vertexSet().size() + ")");
        }
        this.baseGraph = naturalParameters.graph();
        this.subGraph = subGraph;
        this.naturalParams = naturalParameters;
        this.subGraphCCs = new EqClasses<V>(Graphs.connectedComponents(subGraph));
        this.init(new Random(1L));
    }

    /**
     * Creates an approximation initialized from the edge moments of {@code init}
     * (see {@link #init(GMFct)}).
     *
     * @param naturalParameters log-space parameters of the full model
     * @param subGraph tractable subgraph with the same vertex set as the model graph
     * @param init moments to copy onto the edges of {@code subGraph}
     * @throws IllegalArgumentException if the vertex sets differ
     */
    public StructMeanField(GMFct<V> naturalParameters, Graph<V> subGraph, GMFct<V> init) {
        this(naturalParameters, subGraph);
        this.init(init);
    }

    /** Returns the graph of the full model. */
    public Graph<V> graph() {
        return this.baseGraph;
    }

    /**
     * Returns the variational estimate of the log-partition function.
     * It is the inner product of the posterior moments with the natural
     * parameters, plus the entropy of the current subgraph moments.
     */
    public double logZ() {
        return GMFctUtils.dotProd(this.moments(), this.naturalParams) + GMFctUtils.entropy(this.currentMoments);
    }

    /** Returns the approximate marginal probability that {@code n} is in state {@code s}. */
    public double vertexPosterior(V n, int s) {
        return this.currentMoments.get(n, s);
    }

    /**
     * Returns the approximate joint probability that {@code n1 = s1} and
     * {@code n2 = s2}.
     *
     * <ul>
     *   <li>If {@code n1, n2} is an edge of {@code subGraph}, the tabulated edge
     *       moment is returned.</li>
     *   <li>If they lie in the same component of {@code subGraph}, the path
     *       between them is summed out by {@link GradientCalculator}.</li>
     *   <li>Otherwise the approximation treats them as independent, and the
     *       product of their vertex marginals is returned.</li>
     * </ul>
     */
    public double edgePosterior(V n1, V n2, int s1, int s2) {
        if (this.subGraph.hasEdge(n1, n2)) {
            return this.currentMoments.get(n1, n2, s1, s2);
        }
        if (this.subGraphCCs.areRelated(n1, n2)) {
            return new GradientCalculator(this.currentMoments, this.subGraph, n1, n2, s1, s2).value();
        }
        return this.currentMoments.get(n1, s1) * this.currentMoments.get(n2, s2);
    }

    /**
     * Tabulates the approximate posterior moments on every vertex and edge of the
     * model graph. Uses {@link #vertexPosterior} and {@link #edgePosterior}.
     */
    public GMFct<V> moments() {
        return new TabularGMFct<V>(new GMFct<V>() {

            @Override
            public double get(V n1, V n2, int s1, int s2) {
                return StructMeanField.this.edgePosterior(n1, n2, s1, s2);
            }

            @Override
            public double get(V n, int s) {
                return StructMeanField.this.vertexPosterior(n, s);
            }

            @Override
            public Graph<V> graph() {
                return StructMeanField.this.naturalParams.graph();
            }

            @Override
            public int nStates(V node) {
                return StructMeanField.this.naturalParams.nStates(node);
            }
        });
    }

    /** Performs one update of the {@code subGraph} component that contains {@code connCCRep}. */
    public void iterate(V connCCRep) {
        this._iterate(this.subGraphCCs.eqClass(connCCRep));
    }

    /** Performs one simultaneous update of every variable, using the moments from the previous step. */
    public void iterate() {
        this._iterate(this.graph().vertexSet());
    }

    /** Updates each {@code subGraph} component in turn, one component per step. */
    public void iterateCCs() {
        for (V node : this.subGraphCCs.representatives()) {
            this.iterate(node);
        }
    }

    /**
     * Runs {@code iters} component updates, cycling round-robin over the
     * components of {@code subGraph}. Prints the estimate of {@link #logZ()}
     * before the first update and after every update.
     *
     * @param iters number of single-component updates to perform
     */
    public void run(int iters) {
        System.out.println("Initial logZ=" + this.logZ());
        final ArrayList<V> ccReps = new ArrayList<V>(this.subGraphCCs.representatives());
        if (ccReps.isEmpty()) {
            // No components (empty graph): nothing to update, and i % 0 would throw.
            return;
        }
        for (int i = 0; i < iters; ++i) {
            System.out.println("Iteration " + i);
            this.iterate(ccReps.get(i % ccReps.size()));
            System.out.println("LogZ=" + this.logZ());
        }
    }

    /**
     * Performs one structured mean-field update of the variables in {@code vars}.
     *
     * <p>The potentials of {@code subGraph} are modified by the correction from
     * {@link #jacobTimeParam}. The modified model is then split into components.
     * Each component fully contained in {@code vars} gets fresh moments from
     * {@link TreeSumProd}; every other component keeps its previous moments.
     *
     * @param vars union of whole {@code subGraph} components to update
     * @throws IllegalArgumentException if {@code vars} contains only part of a component
     */
    private void _iterate(Set<V> vars) {
        for (V var : vars) {
            if (!vars.containsAll(this.subGraphCCs.eqClass(var))) {
                throw new IllegalArgumentException(
                    "vars must be a union of whole subGraph components; component of " + var
                        + " is only partially included");
            }
        }
        final TabularGMFct<V> jacobTimeParam = this.jacobTimeParam(this.currentMoments, vars);
        final ModifiedPotentials mp = new ModifiedPotentials(jacobTimeParam);
        final List<? extends GMFct<V>> ccDecomp = GMFctUtils.split(mp);
        final TabularGMFct<V> result =
            new TabularGMFct<V>(this.currentMoments.graph(), GMFctUtils.domain(this.currentMoments));
        for (GMFct<V> ccPot : ccDecomp) {
            final boolean shouldUpdate = vars.containsAll(ccPot.graph().vertexSet());
            final GMFct<V> cMoments = shouldUpdate ? new TreeSumProd<V>(ccPot).moments() : null;
            GMFctUtils.visit(ccPot, new CliqueVisitor<V>() {

                @Override
                public void visitEdge(V n1, V n2, int s1, int s2) {
                    result.set(n1, n2, s1, s2,
                        shouldUpdate ? cMoments.get(n1, n2, s1, s2)
                                     : StructMeanField.this.currentMoments.get(n1, n2, s1, s2));
                }

                @Override
                public void visitVertex(V n, int s) {
                    result.set(n, s,
                        shouldUpdate ? cMoments.get(n, s) : StructMeanField.this.currentMoments.get(n, s));
                }
            });
        }
        this.currentMoments = result;
    }

    /**
     * Computes the additive correction to the natural parameters of
     * {@code subGraph}'s cliques. The correction comes from base-graph edges that
     * are not in {@code subGraph} and touch {@code cc}.
     *
     * <ul>
     *   <li><b>Edge between two components:</b> its parameter is added to the
     *       vertex parameters of the endpoint(s) in {@code cc}. Each term is
     *       weighted by the other endpoint's marginal in {@code prevIterPost}.</li>
     *   <li><b>Edge inside one component:</b> its parameter is spread over the
     *       component's subgraph edges. The weights are the gradients computed by
     *       {@link GradientCalculator}.</li>
     * </ul>
     *
     * @param prevIterPost moments from the previous iteration
     * @param cc vertices being updated (union of whole components)
     * @return a table over the model's cliques holding the corrections (zero elsewhere)
     */
    private TabularGMFct<V> jacobTimeParam(GMFct<V> prevIterPost, Set<V> cc) {
        final TabularGMFct<V> result = GMFctUtils.zeroes(this.naturalParams);
        for (UnorderedPair<V, V> g : Graphs.edgeSet(this.baseGraph)) {
            final V a = g.getFirst();
            final V b = g.getSecond();
            // Subgraph edges are handled directly by sum-product; edges not touching cc are irrelevant.
            if (this.subGraph.hasEdge(a, b) || (!cc.contains(a) && !cc.contains(b))) {
                continue;
            }
            final boolean sameComponent = this.subGraphCCs.areRelated(a, b);
            for (int s1 = 0; s1 < this.naturalParams.nStates(a); ++s1) {
                for (int s2 = 0; s2 < this.naturalParams.nStates(b); ++s2) {
                    final double cParam = this.naturalParams.get(a, b, s1, s2);
                    if (sameComponent) {
                        final GradientCalculator gc =
                            new GradientCalculator(prevIterPost, this.subGraph, a, b, s1, s2);
                        for (UnorderedPair<V, V> f : Graphs.edgeSet(this.subGraph)) {
                            final V f1 = f.getFirst();
                            final V f2 = f.getSecond();
                            if (!this.subGraphCCs.areRelated(f1, a) || !this.subGraphCCs.areRelated(f2, b)) {
                                continue;
                            }
                            for (int t1 = 0; t1 < this.naturalParams.nStates(f1); ++t1) {
                                for (int t2 = 0; t2 < this.naturalParams.nStates(f2); ++t2) {
                                    result.increment(f1, f2, t1, t2, cParam * gc.gradient(f1, f2, t1, t2));
                                }
                            }
                        }
                    } else {
                        if (cc.contains(a)) {
                            result.increment(a, s1, cParam * prevIterPost.get(b, s2));
                        }
                        if (cc.contains(b)) {
                            result.increment(b, s2, cParam * prevIterPost.get(a, s1));
                        }
                    }
                }
            }
        }
        return result;
    }

    /**
     * Resets the moments to a random, fully factorized state.
     * Each vertex marginal is a normalized vector of draws from {@code rand}.
     * Each {@code subGraph} edge moment is the product of its two vertex marginals.
     *
     * @param rand source of the random vertex weights
     */
    public void init(final Random rand) {
        this.currentMoments = new TabularGMFct<V>(this.subGraph, GMFctUtils.domain(this.naturalParams));
        GMFctUtils.visit(this.currentMoments, new CliqueVisitor<V>() {

            @Override
            public void visitEdge(V n1, V n2, int s1, int s2) {
            }

            @Override
            public void visitVertex(V n, int s) {
                StructMeanField.this.currentMoments.set(n, s, rand.nextDouble());
            }
        });
        for (V vertex : this.currentMoments.graph().vertexSet()) {
            // Normalizes the vertex's state vector in place (assumes get(vertex) exposes the backing array).
            NumUtils.normalize(this.currentMoments.get(vertex));
        }
        GMFctUtils.visit(this.currentMoments, new CliqueVisitor<V>() {

            @Override
            public void visitEdge(V n1, V n2, int s1, int s2) {
                StructMeanField.this.currentMoments.set(n1, n2, s1, s2,
                    StructMeanField.this.currentMoments.get(n1, s1) * StructMeanField.this.currentMoments.get(n2, s2));
            }

            @Override
            public void visitVertex(V n, int s) {
            }
        });
    }

    /**
     * Resets the moments from {@code other}. The edge moments of every
     * {@code subGraph} edge are copied. The vertex moments are then derived by
     * {@code GMFctUtils.computeMarginals} (presumably by marginalizing the edges).
     *
     * @param other moments defined at least on the edges of {@code subGraph}
     */
    public void init(final GMFct<V> other) {
        final TabularGMFct<V> copy = new TabularGMFct<V>(this.subGraph, GMFctUtils.domain(this.naturalParams));
        GMFctUtils.visit(copy, new CliqueVisitor<V>() {

            @Override
            public void visitEdge(V n1, V n2, int s1, int s2) {
                copy.set(n1, n2, s1, s2, other.get(n1, n2, s1, s2));
            }

            @Override
            public void visitVertex(V n, int s) {
            }
        });
        this.currentMoments = GMFctUtils.computeMarginals(copy);
    }

    /**
     * Returns a deep copy of an unbranched chain with its root and its leaf removed.
     * {@code chain} itself is not modified.
     *
     * @param chain a tree in which every node has at most one child, with at least three nodes
     * @return the interior of the chain
     * @throws IllegalArgumentException if the chain has fewer than three nodes or branches
     */
    public static <T> Tree<T> trimEnds(Tree<T> chain) {
        if (chain.getChildren().size() != 1 || chain.getChildren().get(0).getChildren().size() != 1) {
            throw new IllegalArgumentException("chain must be an unbranched path with at least three nodes");
        }
        final Tree<T> result = chain.getChildren().get(0).deepCopy();
        Tree<T> current = result;
        // Walk down to the last interior node, checking that the chain never branches.
        while (current.getChildren().get(0).getChildren().size() > 0) {
            if (current.getChildren().get(0).getChildren().size() != 1) {
                throw new IllegalArgumentException("chain must be an unbranched path with at least three nodes");
            }
            current = current.getChildren().get(0);
        }
        current.getChildren().remove(0);
        return result;
    }

    /**
     * Demo on a bounded grid. It runs mean field with a disconnected subgraph,
     * then with a comb, then with a snake. Each run is initialized from the
     * previous run's moments.
     */
    public static class Tester {
        public static void main(String[] args) {
            Graphs.Grid g = Graphs.boundedGrid(2, 5);
            Graph<Integer> discon = Graphs.discon(25);
            Graph<Integer> snake = Graphs.snake(5);
            Graph<Integer> comb = Graphs.combDecomposition(5, 2);
            Random rand = new Random(2L);
            double[][] _pot = new double[][]{{0.0, 1.0}, {1.0, 0.0}};
            GMFctUtils.SimpleGraphFct<Integer> logPot = new GMFctUtils.SimpleGraphFct<Integer>(_pot, g);
            System.out.println("Discon");
            StructMeanField<Integer> smf = new StructMeanField<Integer>(logPot, discon);
            smf.init(rand);
            smf.run(20);
            System.out.println("Comb");
            smf = new StructMeanField<Integer>(logPot, comb, smf.moments());
            smf.run(20);
            System.out.println("Snake");
            smf = new StructMeanField<Integer>(logPot, snake, smf.moments());
            smf.run(20);
        }
    }

    /**
     * Handles a query about two non-adjacent nodes, root and tail, in the same
     * component of {@code subGraph}. It takes the path between them and sums out
     * the path's interior with an auxiliary tree model ({@link AuxiliaryMRF}).
     *
     * <ul>
     *   <li>{@link #value()} is the approximate joint probability of
     *       {@code root = rootQueryState} and {@code tail = tailQueryState}.</li>
     *   <li>{@link #gradient(Comparable, Comparable, int, int)} is its derivative
     *       with respect to a subgraph edge moment on that path. It is zero for
     *       edges off the path.</li>
     * </ul>
     *
     * <p>Non-static because it shares the enclosing class's type parameter
     * {@code V}. Construct it with {@code smf.new GradientCalculator(...)} from
     * outside the class.
     */
    public final class GradientCalculator {
        private final V rootQueryNode;
        private final V tailQueryNode;
        private final int rootQueryState;
        private final int tailQueryState;
        /** Parent of each node on the root-to-tail path (the root has none). */
        private final Map<V, V> parentPtrs;
        private final Graph<V> subGraph;
        /** Interior of the root-to-tail path. */
        private final Graph<V> subsubGraph;
        private final GMFct<V> currentPosteriors;
        private final TreeSumProd<V> tsp;
        /** Normalized moments of the auxiliary model. */
        private final GMFct<V> auxPosteriors;

        /**
         * @param currentPosteriors moments defined on {@code subGraph}
         * @param subGraph graph containing a unique path from root to tail
         * @param rootQueryNode start of the path
         * @param tailQueryNode end of the path; must not be adjacent to the root
         * @param rootQueryState state clamped at the root
         * @param tailQueryState state clamped at the tail
         * @throws IllegalArgumentException if the path has fewer than three nodes
         */
        public GradientCalculator(GMFct<V> currentPosteriors, Graph<V> subGraph, V rootQueryNode,
                V tailQueryNode, int rootQueryState, int tailQueryState) {
            this.currentPosteriors = currentPosteriors;
            this.rootQueryNode = rootQueryNode;
            this.rootQueryState = rootQueryState;
            this.subGraph = subGraph;
            this.tailQueryNode = tailQueryNode;
            this.tailQueryState = tailQueryState;
            final Tree<V> subSubTree = Graphs.pathInTree(subGraph, rootQueryNode, tailQueryNode);
            final Tree<V> trimmedSubSubTree = StructMeanField.trimEnds(subSubTree);
            this.subsubGraph = new HashGraph<V>(trimmedSubSubTree);
            this.parentPtrs = Graphs.parentPtrs(subSubTree);
            // Must be built after subsubGraph/parentPtrs are set: AuxiliaryMRF reads them.
            final AuxiliaryMRF graphFct = new AuxiliaryMRF();
            this.tsp = new TreeSumProd<V>(graphFct);
            this.auxPosteriors = this.tsp.moments();
        }

        /** Returns the partition function of the auxiliary model, i.e. the approximate pair probability. */
        public double value() {
            return Math.exp(this.tsp.logZ());
        }

        /**
         * Returns the derivative of {@link #value()} with respect to the moment of
         * edge {@code (n1 = s1, n2 = s2)}. Returns 0 if the edge is not on the
         * root-to-tail path.
         */
        public double gradient(V n1, V n2, int s1, int s2) {
            final V parentOfN1 = this.parentPtrs.get(n1);
            final V parentOfN2 = this.parentPtrs.get(n2);
            if (n1.equals(parentOfN2)) {
                return this.gradient(n2, s2, s1);
            }
            if (n2.equals(parentOfN1)) {
                return this.gradient(n1, s1, s2);
            }
            return 0.0;
        }

        /**
         * Returns the derivative of {@link #value()} with respect to the moment of
         * the path edge between {@code node} and its parent.
         *
         * @throws IllegalArgumentException if {@code node} has no parent on the path
         */
        public double gradient(V node, int stateAtNode, int stateAtParent) {
            final V parent = this.parentPtrs.get(node);
            if (parent == null) {
                throw new IllegalArgumentException("node " + node + " has no parent on the query path");
            }
            if (this.rootQueryNode.equals(parent)) {
                return this._gradientAtRoot(node, parent, stateAtNode, stateAtParent);
            }
            if (this.tailQueryNode.equals(node)) {
                return this._gradientAtTail(node, parent, stateAtNode, stateAtParent);
            }
            return this._gradient(node, parent, stateAtNode, stateAtParent);
        }

        /** Gradient for an edge between two interior path nodes. */
        private double _gradient(V node, V parent, int stateAtNode, int stateAtParent) {
            return this.value() * (this.auxPosteriors.get(node, parent, stateAtNode, stateAtParent)
                    / this.currentPosteriors.get(node, parent, stateAtNode, stateAtParent)
                - this.auxPosteriors.get(parent, stateAtParent) / this.currentPosteriors.get(parent, stateAtParent));
        }

        /** Gradient for the edge touching the (clamped) tail. */
        private double _gradientAtTail(V node, V parent, int stateAtNode, int stateAtParent) {
            return this.value() * this.auxPosteriors.get(parent, stateAtParent)
                * ((stateAtNode == this.tailQueryState
                        ? 1.0 / this.currentPosteriors.get(node, parent, stateAtNode, stateAtParent)
                        : 0.0)
                    - 1.0 / this.currentPosteriors.get(parent, stateAtParent));
        }

        /** Gradient for the edge touching the (clamped) root; zero unless the root is in its query state. */
        private double _gradientAtRoot(V node, V parent, int stateAtNode, int stateAtParent) {
            if (stateAtParent != this.rootQueryState) {
                return 0.0;
            }
            return this.value() * this.auxPosteriors.get(node, stateAtNode)
                / this.currentPosteriors.get(node, parent, stateAtNode, stateAtParent);
        }

        /**
         * Tree model over the interior of the root-to-tail path.
         *
         * <ul>
         *   <li>Edge factors are the current pairwise moments.</li>
         *   <li>Vertex factors are the inverse vertex moments.</li>
         *   <li>A node adjacent (in {@code subGraph}) to the root or tail is also
         *       multiplied by its pairwise moment with that endpoint, taken at the
         *       endpoint's clamped state.</li>
         * </ul>
         */
        public final class AuxiliaryMRF
        implements GMFct<V> {
            @Override
            public int nStates(V node) {
                return GradientCalculator.this.currentPosteriors.nStates(node);
            }

            @Override
            public Graph<V> graph() {
                return GradientCalculator.this.subsubGraph;
            }

            @Override
            public double get(V n1, V n2, int s1, int s2) {
                return GradientCalculator.this.currentPosteriors.get(n1, n2, s1, s2);
            }

            @Override
            public double get(V n, int s) {
                double result = 1.0 / GradientCalculator.this.currentPosteriors.get(n, s);
                if (GradientCalculator.this.subGraph.nbrs(n).contains(GradientCalculator.this.tailQueryNode)) {
                    result *= GradientCalculator.this.currentPosteriors.get(
                        GradientCalculator.this.tailQueryNode, n, GradientCalculator.this.tailQueryState, s);
                }
                if (GradientCalculator.this.subGraph.nbrs(n).contains(GradientCalculator.this.rootQueryNode)) {
                    result *= GradientCalculator.this.currentPosteriors.get(
                        GradientCalculator.this.rootQueryNode, n, GradientCalculator.this.rootQueryState, s);
                }
                return result;
            }
        }
    }

    /**
     * Exp-space potentials on {@code subGraph}. Each potential is
     * {@code exp(naturalParams + jacobTimeParam)} for the clique.
     */
    private class ModifiedPotentials
    implements GMFct<V> {
        private final GMFct<V> jacobTimeParam;

        public ModifiedPotentials(GMFct<V> jacobTimeParam) {
            this.jacobTimeParam = jacobTimeParam;
        }

        @Override
        public double get(V node, V parent, int stateAtNode, int stateAtParentOfNode) {
            return Math.exp(StructMeanField.this.naturalParams.get(node, parent, stateAtNode, stateAtParentOfNode)
                + this.jacobTimeParam.get(node, parent, stateAtNode, stateAtParentOfNode));
        }

        @Override
        public int nStates(V node) {
            return StructMeanField.this.naturalParams.nStates(node);
        }

        @Override
        public Graph<V> graph() {
            return StructMeanField.this.subGraph;
        }

        @Override
        public double get(V n, int s) {
            return Math.exp(StructMeanField.this.naturalParams.get(n, s) + this.jacobTimeParam.get(n, s));
        }
    }
}

