/*
 * Decompiled with CFR 0.152.
 */
package facto;

import facto.Factor;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Random;
import nuts.math.GMFct;
import nuts.math.GMFctUtils;
import nuts.math.Graph;
import nuts.math.Graphs;
import nuts.math.Permutations;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.util.CoordinatesPacker;

/**
 * Brute-force inference for a weighted bipartite perfect-matching model, plus factory methods for
 * the {@link Factor}s built on top of it.
 *
 * <p>Given an {@code N x N} potential matrix, the weight of a permutation {@code perm} is
 * {@code prod_i potential[i][perm[i]]}. The constructor enumerates permutations to compute the
 * normalizer and, for every (row, column) pair, the normalized total weight of the permutations
 * that match that row to that column.
 */
public class BipartiteMatching {
    /** The matrix passed to the constructor; stored by reference, not copied. */
    public final double[][] potential;

    /**
     * Sum of the weights of all permutations yielded by {@code Permutations.permutations(N)}.
     * Assuming that enumerates every permutation of {@code 0..N-1} exactly once, this is the
     * permanent of {@link #potential}.
     */
    public final double partitionFct;

    /**
     * {@code posteriors[i][j]} is the total weight of the permutations mapping row {@code i} to
     * column {@code j}, divided by {@link #partitionFct}. Entries are not finite when the
     * partition function is 0.
     */
    public final double[][] posteriors;

    /** Number of rows (and columns) of the potential matrix. */
    public final int N;

    /**
     * Computes {@link #partitionFct} and {@link #posteriors} by enumerating permutations
     * (presumably all {@code N!} of them), so only small matrices are practical.
     *
     * @param potentials square matrix of matching potentials; stored by reference in
     *     {@link #potential}
     * @throws IllegalArgumentException if {@code potentials} is not square
     * @throws NullPointerException if {@code potentials} or one of its rows is null
     */
    public BipartiteMatching(double[][] potentials) {
        this.potential = potentials;
        this.N = potentials.length;
        for (int i = 0; i < N; ++i) {
            if (potentials[i].length != N) {
                throw new IllegalArgumentException("potentials must be square: row " + i
                        + " has length " + potentials[i].length + ", expected " + N);
            }
        }
        this.posteriors = new double[N][N];

        double total = 0.0;
        for (int[] perm : Permutations.permutations(N)) {
            double weight = 1.0;
            for (int row = 0; row < perm.length; ++row) {
                weight *= potentials[row][perm[row]];
            }
            total += weight;
            // Credit this permutation's weight to every (row, column) pair it uses.
            for (int row = 0; row < perm.length; ++row) {
                posteriors[row][perm[row]] += weight;
            }
        }
        this.partitionFct = total;

        // Normalize. If no permutation has non-zero weight this divides by 0 (NaN entries).
        for (double[] rowPosteriors : posteriors) {
            for (int col = 0; col < rowPosteriors.length; ++col) {
                rowPosteriors[col] /= total;
            }
        }
    }

    /**
     * Draws an {@code N2 x N2} matrix whose entries are {@code rand.nextDouble() / temperature}.
     * For positive {@code temperature} that is uniform on {@code [0, 1 / temperature)}.
     *
     * @param rand source of randomness
     * @param N2 matrix dimension
     * @param temperature divisor applied to every uniform draw
     * @return a fresh {@code N2 x N2} matrix
     */
    public static double[][] randomPotentials(Random rand, int N2, double temperature) {
        double[][] result = new double[N2][N2];
        for (int i = 0; i < N2; ++i) {
            for (int j = 0; j < N2; ++j) {
                result[i][j] = rand.nextDouble() / temperature;
            }
        }
        return result;
    }

    /**
     * Draws an {@code N2 x N2} 0/1 matrix. Each entry is independently 1.0 when a uniform draw
     * falls below {@code prOfOne}, else 0.0.
     *
     * @param rand source of randomness
     * @param N2 matrix dimension
     * @param prOfOne probability that an entry is 1.0
     * @return a fresh {@code N2 x N2} matrix of 0.0 / 1.0 values
     */
    public static double[][] randomBinaryPotentials(Random rand, int N2, double prOfOne) {
        double[][] result = new double[N2][N2];
        for (int i = 0; i < N2; ++i) {
            for (int j = 0; j < N2; ++j) {
                result[i][j] = rand.nextDouble() < prOfOne ? 1.0 : 0.0;
            }
        }
        return result;
    }

    /**
     * Returns a copy of {@code ori} in which every entry exactly equal to 0.0 is replaced by
     * {@code epsilon}. Each row of the result has the same length as the corresponding input row.
     *
     * @param epsilon replacement value for zero entries
     * @param ori matrix to copy; not modified
     * @return a new matrix with the same shape as {@code ori}
     */
    public static double[][] addEpsilon(double epsilon, double[][] ori) {
        double[][] result = new double[ori.length][];
        for (int i = 0; i < ori.length; ++i) {
            double[] source = ori[i];
            double[] target = new double[source.length];
            for (int j = 0; j < source.length; ++j) {
                target[j] = source[j] == 0.0 ? epsilon : source[j];
            }
            result[i] = target;
        }
        return result;
    }

    /**
     * Builds a {@link Factor} whose gradient is obtained by sum-product on a chain with one node
     * per row, where each node's state is the column that row is assigned to.
     *
     * <p>Natural parameters are addressed through {@code cp} with a 3-coordinate key: row,
     * column, and a 0/1 "previous row aligned" flag. When {@code flip} is true the key is
     * {@code (row, col, flag)}, otherwise {@code (col, row, flag)}.
     *
     * <p>Potentials are built as follows:
     * <ul>
     *   <li>Row 0's node potential for column {@code col} is {@code exp(theta[0, col, 0])}.</li>
     *   <li>On the edge between rows {@code row-1} and {@code row}:
     *     <ul>
     *       <li>previous column {@code col-1} gets {@code exp(theta[row, col, 1])};</li>
     *       <li>the same column gets {@code epsilon};</li>
     *       <li>any other column gets {@code exp(theta[row, col, 0])}.</li>
     *     </ul>
     *   </li>
     * </ul>
     *
     * <p>The gradient collects the corresponding moments from {@link #computePartFct}. Entries
     * that come out exactly 0 are replaced by {@code epsilon}.
     *
     * <p>The returned factor reuses a scratch buffer and is not thread-safe.
     *
     * @param flip coordinate order used when packing keys (see above)
     * @param cp packer mapping 3-coordinate keys to parameter indices
     * @param N2 number of rows and columns
     * @param epsilon potential for staying in the same column, and floor for zero moments
     * @return the factor; its {@code entropy} is unsupported
     */
    public static Factor getHMMFunctionFactor(final boolean flip, final CoordinatesPacker.MSCoordinatePacker cp, final int N2, final double epsilon) {
        return new Factor() {
            // Scratch key buffer reused by coord2int (hence not thread-safe).
            private final int[] tmp = new int[3];

            /** Packs (row, col, previousAligned) into a parameter index, honouring {@code flip}. */
            private int coord2int(int row, int col, boolean previousAligned) {
                this.tmp[0] = flip ? row : col;
                this.tmp[1] = flip ? col : row;
                this.tmp[2] = previousAligned ? 1 : 0;
                return cp.coord2int(this.tmp);
            }

            @Override
            public double[] gradient(double[] naturalParameters) {
                Graph<Integer> chain = Graphs.chainGraph(N2);
                HashMap<Integer, Integer> domains = new HashMap<Integer, Integer>();
                for (int node = 0; node < N2; ++node) {
                    domains.put(node, N2);
                }
                TabularGMFct<Integer> potentials = GMFctUtils.ones(new TabularGMFct<Integer>(chain, domains));

                for (int row = 0; row < N2; ++row) {
                    for (int col = 0; col < N2; ++col) {
                        double prevAlignedPot = Math.exp(naturalParameters[coord2int(row, col, true)]);
                        double prevUnAlignPot = Math.exp(naturalParameters[coord2int(row, col, false)]);
                        if (row == 0) {
                            // The first row has no predecessor: only a node potential.
                            potentials.set(0, col, prevUnAlignPot);
                            continue;
                        }
                        for (int prevCol = 0; prevCol < N2; ++prevCol) {
                            double pot;
                            if (prevCol == col - 1) {
                                pot = prevAlignedPot;
                            } else if (prevCol == col) {
                                pot = epsilon;
                            } else {
                                pot = prevUnAlignPot;
                            }
                            potentials.set(row - 1, row, prevCol, col, pot);
                        }
                    }
                }

                GMFct<Integer> post = BipartiteMatching.computePartFct(potentials);
                double[] result = new double[N2 * N2 * 2];
                for (int row = 0; row < N2; ++row) {
                    for (int col = 0; col < N2; ++col) {
                        int alignedIndex = coord2int(row, col, true);
                        int unAlignIndex = coord2int(row, col, false);
                        if (row == 0) {
                            result[unAlignIndex] = post.get(row, col);
                        } else {
                            for (int prevCol = 0; prevCol < N2; ++prevCol) {
                                double current = post.get(row - 1, row, prevCol, col);
                                if (prevCol == col - 1) {
                                    result[alignedIndex] = current;
                                } else if (prevCol != col) {
                                    // Same-column transitions are not tied to any parameter.
                                    result[unAlignIndex] += current;
                                }
                            }
                        }
                        // NOTE(review): zero moments are floored, presumably so downstream
                        // code can take logs safely — confirm with callers.
                        if (result[alignedIndex] == 0.0) {
                            result[alignedIndex] = epsilon;
                        }
                        if (result[unAlignIndex] == 0.0) {
                            result[unAlignIndex] = epsilon;
                        }
                    }
                }
                return result;
            }

            @Override
            public double entropy(double[] naturalParameters) {
                throw new UnsupportedOperationException("entropy is not implemented for the HMM matching factor");
            }
        };
    }

    /**
     * Runs {@link TreeSumProd} on {@code potentials} and returns its moments, after first
     * replacing {@code +Infinity} entries in place:
     * <ul>
     *   <li>a node with exactly one infinite state potential becomes an indicator on that
     *       state (1.0 there, 0.0 elsewhere);</li>
     *   <li>infinite edge potentials become 1.0.</li>
     * </ul>
     * The graph presumably has to be a tree for {@code TreeSumProd}; confirm with callers.
     *
     * @param potentials model to run inference on; mutated in place
     * @return the moments computed by {@code TreeSumProd}
     * @throws IllegalArgumentException if a node has more than one infinite state potential
     * @throws IllegalStateException if the log partition function is NaN or infinite; the
     *     non-zero edge potentials are dumped to standard error first
     */
    public static GMFct<Integer> computePartFct(TabularGMFct<Integer> potentials) {
        for (Integer node : potentials.graph().vertexSet()) {
            clampInfiniteNodePotential(potentials, node);
            for (Integer nbr : potentials.graph().nbrs(node)) {
                for (int s = 0; s < potentials.nStates(node); ++s) {
                    for (int t = 0; t < potentials.nStates(nbr); ++t) {
                        if (potentials.get(node, nbr, s, t) == Double.POSITIVE_INFINITY) {
                            potentials.set(node, nbr, s, t, 1.0);
                        }
                    }
                }
            }
        }
        TreeSumProd<Integer> tsp = new TreeSumProd<Integer>(potentials);
        double logZ = tsp.logZ();
        if (Double.isNaN(logZ) || Double.isInfinite(logZ)) {
            System.err.print(describeEdgePotentials(potentials));
            throw new IllegalStateException("Log partition function is not finite: logZ=" + logZ);
        }
        return tsp.moments();
    }

    /** Turns a node with a single {@code +Infinity} state potential into an indicator on that state. */
    private static void clampInfiniteNodePotential(TabularGMFct<Integer> potentials, Integer node) {
        int nStates = potentials.nStates(node);
        int infiniteState = -1;
        for (int s = 0; s < nStates; ++s) {
            if (potentials.get(node, s) != Double.POSITIVE_INFINITY) {
                continue;
            }
            if (infiniteState >= 0) {
                throw new IllegalArgumentException(
                        "Node " + node + " has more than one infinite potential (states "
                                + infiniteState + " and " + s + ")");
            }
            infiniteState = s;
        }
        if (infiniteState < 0) {
            return;
        }
        for (int s = 0; s < nStates; ++s) {
            potentials.set(node, s, s == infiniteState ? 1.0 : 0.0);
        }
    }

    /** Renders every non-zero edge potential of every (node, neighbour) pair, for diagnostics. */
    private static String describeEdgePotentials(TabularGMFct<Integer> potentials) {
        StringBuilder sb = new StringBuilder();
        for (Integer node : potentials.graph().vertexSet()) {
            for (Integer nbr : potentials.graph().nbrs(node)) {
                sb.append("node ").append(node).append(" -> ").append(nbr).append('\n');
                for (int s = 0; s < potentials.nStates(node); ++s) {
                    for (int t = 0; t < potentials.nStates(nbr); ++t) {
                        double value = potentials.get(node, nbr, s, t);
                        if (value == 0.0) {
                            continue;
                        }
                        sb.append(s).append(" -> ").append(t).append(" : ").append(value).append('\n');
                    }
                }
            }
        }
        return sb.toString();
    }

    /**
     * Builds a {@link Factor} whose gradient is a row-wise softmax of the natural parameters.
     * For each row, the entries over all columns are {@code exp(theta)} normalized to sum to 1.
     *
     * <p>Parameters are addressed through {@code cp}: the key is {@code (row, col)} when
     * {@code flip} is true, otherwise {@code (col, row)}. The softmax subtracts each row's
     * maximum before exponentiating, so large parameters do not overflow to NaN.
     *
     * <p>The returned factor reuses a scratch buffer and is not thread-safe.
     *
     * @param flip coordinate order used when packing keys
     * @param cp packer mapping 2-coordinate keys to parameter indices
     * @param N2 number of rows and columns
     * @return the factor; its {@code gradient} requires exactly {@code N2 * N2} parameters and its
     *     {@code entropy} is unsupported
     */
    public static Factor getFunctionFactor(final boolean flip, final CoordinatesPacker cp, final int N2) {
        return new Factor() {
            // Scratch key buffer reused by coord2int (hence not thread-safe).
            private final int[] tmp = new int[2];

            /** Packs (row, col) into a parameter index, honouring {@code flip}. */
            private int coord2int(int row, int col) {
                this.tmp[0] = flip ? row : col;
                this.tmp[1] = flip ? col : row;
                return cp.coord2int(this.tmp);
            }

            @Override
            public double[] gradient(double[] naturalParameters) {
                double[] result = new double[naturalParameters.length];
                if (N2 * N2 != result.length) {
                    throw new IllegalArgumentException("expected " + (N2 * N2)
                            + " natural parameters, got " + naturalParameters.length);
                }
                double[] weights = new double[N2];
                for (int row = 0; row < N2; ++row) {
                    double max = Double.NEGATIVE_INFINITY;
                    for (int col = 0; col < N2; ++col) {
                        max = Math.max(max, naturalParameters[coord2int(row, col)]);
                    }
                    // Subtracting the row maximum leaves the softmax unchanged but keeps exp() finite.
                    double sum = 0.0;
                    for (int col = 0; col < N2; ++col) {
                        weights[col] = Math.exp(naturalParameters[coord2int(row, col)] - max);
                        sum += weights[col];
                    }
                    for (int col = 0; col < N2; ++col) {
                        result[coord2int(row, col)] = weights[col] / sum;
                    }
                }
                return result;
            }

            @Override
            public double entropy(double[] naturalParameters) {
                throw new UnsupportedOperationException("entropy is not implemented for the row-softmax factor");
            }
        };
    }

    /** Prints timing, partition function and posteriors for random matrices of size 0 through 11. */
    public static void main(String[] args) {
        Random rand = new Random(1L);
        for (int size = 0; size < 12; ++size) {
            double[][] pots = randomPotentials(rand, size, 1.0);
            long start = System.currentTimeMillis();
            BipartiteMatching matching = new BipartiteMatching(pots);
            System.out.println("Size=" + size);
            System.out.println("Time=" + (System.currentTimeMillis() - start));
            System.out.println("Partition fct=" + matching.partitionFct);
            System.out.println("Posteriors=" + Arrays.deepToString(matching.posteriors));
        }
    }
}

