/*
 * Decompiled with CFR 0.152.
 */
package facto;

import facto.Factor;
import facto.MFBP;
import fig.basic.UnorderedPair;
import infer.Exact2DIsing;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.math.CliqueVisitor;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.Graph;
import nuts.math.Graphs;
import nuts.math.HashGraph;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.util.Indexer;

/**
 * Experiment driver that repeatedly adds a new, distinct random spanning forest of a small grid
 * as a tractable {@link Factor}, runs {@link MFBP} over all forests collected so far, and prints
 * the resulting log-partition estimate next to the value reported by {@link Exact2DIsing}.
 *
 * <p>The static {@link #map} and {@link #indexer} fields are initialized by {@link #main} and are
 * read by {@link ForestFacto}; they must be set before any {@code ForestFacto} is used.
 */
public class IsingMeasureFactorization {
    /** Size argument passed to {@link Exact2DIsing} in {@link #main}. */
    public static final int N = 5;
    /** Parameter indexer built from the grid model in {@link #main}; shared with {@link ForestFacto}. */
    public static Indexer<Object> indexer;
    /** Number of states per vertex (always 2 as populated by {@link #main}). */
    public static Map<Integer, Integer> map;
    /** Fixed-seed source of randomness so runs are reproducible. */
    public static Random rand = new Random(1L);

    /**
     * Enumerates every graph obtained by adding to {@code subgraph} a single edge of
     * {@code superGraph} that {@code subgraph} does not already have, keeping only the results that
     * {@link Graphs#isForest} accepts.
     *
     * @param superGraph graph supplying candidate edges and the vertex set
     * @param subgraph current graph to extend by one edge
     * @return all acyclic one-edge extensions; empty when none exist
     */
    public static List<Graph<Integer>> oneEdgeAwayAcyclicSuccessors(Graph<Integer> superGraph, Graph<Integer> subgraph) {
        List<Graph<Integer>> result = new ArrayList<Graph<Integer>>();
        Set<UnorderedPair<Integer, Integer>> edgeSet = Graphs.edgeSet(subgraph);
        for (UnorderedPair<Integer, Integer> edge : Graphs.edgeSet(superGraph)) {
            if (subgraph.hasEdge(edge.getFirst(), edge.getSecond())) {
                continue;
            }
            // Copy so each candidate gets its own edge set; the shared edgeSet is never mutated.
            Set<UnorderedPair<Integer, Integer>> edgeCopy = new HashSet<UnorderedPair<Integer, Integer>>(edgeSet);
            edgeCopy.add(edge);
            Graph<Integer> oneEdgeAwayGraph =
                    Graphs.union(subgraph, new HashGraph<Integer>(superGraph.vertexSet(), edgeCopy));
            if (Graphs.isForest(oneEdgeAwayGraph)) {
                result.add(oneEdgeAwayGraph);
            }
        }
        return result;
    }

    /**
     * Grows a forest inside {@code superGraph} by repeatedly adding one randomly chosen edge that
     * keeps the graph acyclic, until no such edge remains. The result is therefore a maximal
     * acyclic subgraph of {@code superGraph}.
     *
     * <p>Each step picks uniformly among the current candidate edges; this does not in general make
     * the final forest uniformly distributed over all spanning forests.
     *
     * <p>NOTE(review): starts from {@code Graphs.discon(vertexCount)}, which presumably builds an
     * edgeless graph on vertices {@code 0..vertexCount-1}; this assumes {@code superGraph}'s
     * vertices are numbered that way — confirm.
     *
     * @param superGraph graph whose edges may be used
     * @param rand source of randomness for edge selection
     * @return a forest that cannot be extended by any further edge of {@code superGraph}
     */
    public static Graph<Integer> randomForestCover(Graph<Integer> superGraph, Random rand) {
        Graph<Integer> result = Graphs.discon(superGraph.vertexSet().size());
        List<Graph<Integer>> successors;
        while (!(successors = oneEdgeAwayAcyclicSuccessors(superGraph, result)).isEmpty()) {
            result = successors.get(rand.nextInt(successors.size()));
        }
        return result;
    }

    /**
     * Samples up to {@code trials} random forest covers of {@code superGraph} and returns the first
     * one that is not {@link Graphs#equals} to any graph already in {@code currentList}.
     *
     * @param superGraph graph to cover
     * @param currentList previously chosen forests to avoid; not modified
     * @param rand source of randomness
     * @param trials maximum number of samples to draw
     * @return a new distinct forest, or {@code null} if every sample duplicated an existing one
     */
    public static Graph<Integer> nextDistinctRandomForestCover(Graph<Integer> superGraph, List<Graph<Integer>> currentList, Random rand, int trials) {
        for (int i = 0; i < trials; ++i) {
            Graph<Integer> candidate = randomForestCover(superGraph, rand);
            if (!containsEqualGraph(currentList, candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /** Returns true if some graph in {@code graphs} is {@link Graphs#equals} to {@code candidate}. */
    private static boolean containsEqualGraph(List<Graph<Integer>> graphs, Graph<Integer> candidate) {
        for (Graph<Integer> existing : graphs) {
            if (Graphs.equals(candidate, existing)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Runs the experiment. Each outer iteration:
     * <ol>
     *   <li>adds a new distinct random forest cover of the grid;</li>
     *   <li>builds one {@link ForestFacto} per collected forest;</li>
     *   <li>runs a fixed number of {@link MFBP} iterations;</li>
     *   <li>prints the log-partition estimate.</li>
     * </ol>
     * The loop stops early if no new distinct forest can be found.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final int gridRows = 2;
        final int gridCols = 5;
        final int outerIterations = 100;
        final int forestTrials = 10;
        final int mfbpIterations = 10;

        // All-ones pairwise table: log-potentials are 0 for every edge/state pair.
        double[][] nonLogPot = new double[][]{{1.0, 1.0}, {1.0, 1.0}};
        // NOTE(review): the exact model is built with size N while the factorized grid is
        // gridRows x gridCols; confirm both describe the same model before comparing the numbers.
        Exact2DIsing exact = new Exact2DIsing(N, nonLogPot);
        System.out.println("Exact logZ=" + exact.logZ());

        map = new HashMap<Integer, Integer>();
        Graphs.Grid grid = Graphs.boundedGrid(gridRows, gridCols);
        for (Integer v : grid.vertexSet()) {
            map.put(v, 2); // binary (Ising) variables
        }
        TabularGMFct<Integer> blank = new TabularGMFct<Integer>(grid, map);
        indexer = GMFctUtils.createIndexer(blank);
        double[] naturalParam = naturalParam(nonLogPot, grid, indexer);

        List<Graph<Integer>> currentList = new ArrayList<Graph<Integer>>();
        for (int iter = 0; iter < outerIterations; ++iter) {
            System.out.println("------------------------------");
            Graph<Integer> factGraph = nextDistinctRandomForestCover(grid, currentList, rand, forestTrials);
            if (factGraph == null) {
                // nextDistinctRandomForestCover signals exhaustion with null; adding it to the
                // list would break printing and factor construction below.
                System.out.println("No new distinct forest found after " + forestTrials + " trials; stopping.");
                break;
            }
            currentList.add(factGraph);
            System.out.println(Graphs.print2DGrid(factGraph, gridRows, gridCols));

            Factor[] factorArray = new Factor[currentList.size()];
            for (int i = 0; i < currentList.size(); ++i) {
                factorArray[i] = new ForestFacto(currentList.get(i));
            }
            MFBP mfBP = new MFBP(naturalParam, factorArray, true, true);
            for (int subIter = 0; subIter < mfbpIterations; ++subIter) {
                mfBP.iterate();
            }
            System.out.println("Current estimate=" + mfBP.logPartitionEstimate());
        }
    }

    /**
     * Builds the natural-parameter vector for {@code g}. For every edge of {@code g} and every
     * state pair {@code (s1, s2)}, the edge-potential entry is set to {@code log(nonLogPot[s1][s2])}.
     * Entries not touched (e.g. node potentials) stay at {@code 0.0}.
     *
     * @param nonLogPot pairwise potential table in linear (non-log) space
     * @param g graph whose edges receive the potential
     * @param indexer maps (edge, states) to positions in the result
     * @return array of length {@code indexer.size()}
     */
    private static double[] naturalParam(double[][] nonLogPot, Graph<Integer> g, Indexer<Object> indexer) {
        double[] result = new double[indexer.size()];
        for (UnorderedPair<Integer, Integer> edge : Graphs.edgeSet(g)) {
            for (int s1 = 0; s1 < nonLogPot.length; ++s1) {
                // Bound by the row length so a non-square table is not read out of range.
                for (int s2 = 0; s2 < nonLogPot[s1].length; ++s2) {
                    result[GMFctUtils.getEdgePotentialIndex(indexer, edge.getFirst(), edge.getSecond(), s1, s2)] =
                            Math.log(nonLogPot[s1][s2]);
                }
            }
        }
        return result;
    }

    /**
     * A {@link Factor} restricted to a single forest. Moments and entropy are computed exactly with
     * tree sum-product on that forest.
     *
     * <p>Reads the outer class's static {@code map} and {@code indexer}, which must be initialized
     * first.
     */
    public static class ForestFacto
    implements Factor {
        /** Forest structure this factor is defined on. */
        public final Graph<Integer> forest;

        /** @param forest forest structure for this factor */
        public ForestFacto(Graph<Integer> forest) {
            this.forest = forest;
        }

        /**
         * Builds tabular potentials on the forest as {@code exp} of the corresponding natural
         * parameters (edges and vertices), then returns the moments from
         * {@link TreeSumProd#computeMoments}.
         */
        private GMFct<Integer> getMoments(final double[] naturalParams) {
            final TabularGMFct<Integer> pots = GMFctUtils.ones(new TabularGMFct<Integer>(this.forest, map));
            GMFctUtils.visit(pots, new CliqueVisitor<Integer>(){

                @Override
                public void visitEdge(Integer n1, Integer n2, int s1, int s2) {
                    pots.set(n1, n2, s1, s2, Math.exp(naturalParams[GMFctUtils.getEdgePotentialIndex(indexer, n1, n2, s1, s2)]));
                }

                @Override
                public void visitVertex(Integer n, int s) {
                    pots.set(n, s, Math.exp(naturalParams[GMFctUtils.getNodePotentialIndex(indexer, n, s)]));
                }
            });
            return TreeSumProd.computeMoments(pots);
        }

        /** Returns {@link GMFctUtils#entropy} of this forest's moments under the given parameters. */
        @Override
        public double entropy(double[] naturalParameters) {
            return GMFctUtils.entropy(this.getMoments(naturalParameters));
        }

        /**
         * Returns a vector the same length as {@code naturalParameters}, holding this forest's
         * moments at the edge and vertex indices it covers and {@code 0.0} elsewhere.
         */
        @Override
        public double[] gradient(double[] naturalParameters) {
            final double[] result = new double[naturalParameters.length];
            final GMFct<Integer> moments = this.getMoments(naturalParameters);
            GMFctUtils.visit(moments, new CliqueVisitor<Integer>(){

                @Override
                public void visitEdge(Integer n1, Integer n2, int s1, int s2) {
                    result[GMFctUtils.getEdgePotentialIndex(indexer, n1, n2, s1, s2)] = moments.get(n1, n2, s1, s2);
                }

                @Override
                public void visitVertex(Integer n, int s) {
                    result[GMFctUtils.getNodePotentialIndex(indexer, n, s)] = moments.get(n, s);
                }
            });
            return result;
        }
    }
}

