/*
 * Decompiled with CFR 0.152.
 */
package facto;

import facto.BipartiteMatching;
import facto.Factor;
import facto.MFBP;
import fig.basic.BipartiteMatcher;
import fig.prob.Gaussian;
import java.util.Random;
import nuts.util.CoordinatesPacker;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;

/**
 * Synthetic benchmark for reconstructing a permutation from noisy pairwise observations.
 *
 * <p>Each replicate observes a noisy copy of the identity permutation matrix. It turns the
 * observations into log-domain edge scores and runs two BP variants through {@link MFBP}: a
 * plain bipartite-matching model and an HMM-augmented one. It decodes each result with a
 * bipartite matcher and records the loss and elapsed time in {@code MFBP}'s named statistics.
 * Those statistics are printed (mean and standard deviation per key) at the end of {@link #main}.
 */
public class GenerateRecon {

    /**
     * Draws an {@code N2 x N2} matrix of noisy observations of the identity permutation.
     *
     * <p>Entry (i, j) is {@code Gaussian.sample(rand, mean, variance)}, where the mean is 1.0 on
     * the diagonal and 0.0 elsewhere. Entries are drawn in row-major order.
     * NOTE(review): assumes the third argument of {@code fig.prob.Gaussian.sample} is a variance
     * (not a standard deviation) — confirm against fig.
     *
     * @param rand source of randomness
     * @param variance noise level forwarded to the sampler
     * @param N2 matrix dimension
     * @return a freshly allocated {@code N2 x N2} matrix
     */
    public static double[][] noisyObservations(Random rand, double variance, int N2) {
        double[][] result = new double[N2][N2];
        for (int i = 0; i < N2; ++i) {
            for (int j = 0; j < N2; ++j) {
                double mean = (i == j) ? 1.0 : 0.0;
                result[i][j] = Gaussian.sample(rand, mean, variance);
            }
        }
        return result;
    }

    /**
     * Computes log-domain edge scores {@code (2x - 1) / (2 * variance)} for every observation x.
     *
     * <p>This is the log of the Gaussian likelihood ratio {@code N(x; 1, v) / N(x; 0, v)}. It is
     * computed directly in log space, so it stays finite for small variances where
     * {@link #posteriorWeights} would overflow to {@code +Infinity}.
     *
     * @param observations observation matrix; rows may have any length
     * @param variance observation noise variance, must be strictly positive
     * @return a freshly allocated matrix with the same row lengths as {@code observations}
     * @throws IllegalArgumentException if {@code variance} is not strictly positive (or is NaN)
     */
    public static double[][] logPosteriorWeights(double[][] observations, double variance) {
        if (!(variance > 0.0)) {
            throw new IllegalArgumentException("variance must be > 0, got " + variance);
        }
        double scale = 0.5 / variance;
        double[][] result = new double[observations.length][];
        for (int i = 0; i < observations.length; ++i) {
            double[] row = observations[i];
            double[] out = new double[row.length];
            for (int j = 0; j < row.length; ++j) {
                out[j] = scale * (2.0 * row[j] - 1.0);
            }
            result[i] = out;
        }
        return result;
    }

    /**
     * Computes edge weights {@code exp((2x - 1) / (2 * variance))} for every observation x.
     *
     * <p>This is the exponential of {@link #logPosteriorWeights}. Large scores (above roughly 709)
     * overflow to {@code +Infinity}; prefer the log-domain variant when feeding BP.
     *
     * @param observations observation matrix; rows may have any length
     * @param variance observation noise variance, must be strictly positive
     * @return a freshly allocated matrix of non-negative weights
     * @throws IllegalArgumentException if {@code variance} is not strictly positive (or is NaN)
     */
    public static double[][] posteriorWeights(double[][] observations, double variance) {
        double[][] result = logPosteriorWeights(observations, variance);
        for (double[] row : result) {
            for (int j = 0; j < row.length; ++j) {
                row[j] = Math.exp(row[j]);
            }
        }
        return result;
    }

    /**
     * Decodes a weight matrix into an assignment via
     * {@code BipartiteMatcher.findMaxWeightAssignment}.
     *
     * <p>NOTE(review): presumably {@code result[i]} is the column matched to row i (this is how
     * {@link #loss} interprets it) — confirm against fig.
     *
     * @param posteriors matrix of edge scores
     * @return the assignment returned by the matcher
     */
    public static int[] decode(double[][] posteriors) {
        return new BipartiteMatcher().findMaxWeightAssignment(posteriors);
    }

    /**
     * Returns the fraction of positions i with {@code guess[i] != i}. This is the normalized
     * Hamming distance between the guess and the identity permutation.
     *
     * @param guess decoded assignment
     * @return a value in [0, 1], or NaN when {@code guess} is empty (0/0)
     */
    public static double loss(int[] guess) {
        int errors = 0;
        for (int i = 0; i < guess.length; ++i) {
            if (guess[i] != i) {
                ++errors;
            }
        }
        return (double) errors / guess.length;
    }

    /**
     * Runs the benchmark and prints one line per recorded statistic: key, mean, standard deviation.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Pre-log bonus added to the natural parameter of HMM state coord[2] == 1.
        final double diagPot = 100.0;
        final double variance = 0.1;
        final int bpIters = 4;
        final int replicates = 100;
        final int minSize = 64;
        final int maxSize = 64;
        Random rand = new Random(2L);
        String varLabel = "var=" + variance;
        for (int size = minSize; size <= maxSize; size *= 2) {
            String sizeLabel = "N=" + size;
            for (int replic = 0; replic < replicates; ++replic) {
                double[][] observations = noisyObservations(rand, variance, size);
                double[][] logWeights = logPosteriorWeights(observations, variance);
                runMatchingBP(logWeights, size, bpIters, varLabel, sizeLabel);
                runHmmMatchingBP(logWeights, size, bpIters, diagPot, varLabel, sizeLabel);
            }
        }
        System.out.println();
        printStats();
    }

    /** Runs plain bipartite-matching BP on {@code logWeights}; records loss and elapsed ms. */
    private static void runMatchingBP(double[][] logWeights, int size, int bpIters,
                                      String varLabel, String sizeLabel) {
        long start = System.nanoTime();
        CoordinatesPacker cp = new CoordinatesPacker(size);
        double[] naturalParams = new double[size * size];
        for (int k = 0; k < naturalParams.length; ++k) {
            int[] coord = cp.int2coord(k);
            naturalParams[k] = logWeights[coord[0]][coord[1]];
        }
        Factor[] factors = new Factor[]{
            BipartiteMatching.getFunctionFactor(false, cp, size),
            BipartiteMatching.getFunctionFactor(true, cp, size)
        };
        MFBP bp = new MFBP(naturalParams, factors, true, false);
        for (int it = 0; it < bpIters; ++it) {
            bp.iterate();
        }
        // nanoTime is monotonic; currentTimeMillis can jump with wall-clock adjustments.
        double elapsedMillis = (System.nanoTime() - start) / 1.0e6;
        int[] decoded = decode(MFBP.unpack(bp.moments(), cp));
        // NOTE(review): the key says "3" although bpIters (4) iterations run; the literal is kept
        // so existing result keys stay stable — confirm which is intended.
        MFBP.getStat("BP-3", varLabel, sizeLabel).addValue(loss(decoded));
        MFBP.getStat("BP-3-time", varLabel, sizeLabel).addValue(elapsedMillis);
    }

    /**
     * Runs the HMM-augmented matching BP on a {@code size x size x 2} state space and records
     * loss and elapsed ms. Each (i, j) edge score is shared by both states; state 1 additionally
     * gets {@code log(diagPot)}.
     */
    private static void runHmmMatchingBP(double[][] logWeights, int size, int bpIters,
                                         double diagPot, String varLabel, String sizeLabel) {
        long start = System.nanoTime();
        CoordinatesPacker.MSCoordinatePacker mscp =
            new CoordinatesPacker.MSCoordinatePacker(new int[]{size, size, 2});
        double logDiagPot = Math.log(diagPot);
        double[] naturalParams = new double[size * size * 2];
        for (int k = 0; k < naturalParams.length; ++k) {
            int[] coord = mscp.int2coord(k);
            double param = logWeights[coord[0]][coord[1]];
            if (coord[2] == 1) {
                param += logDiagPot;
            }
            naturalParams[k] = param;
        }
        // NOTE(review): meaning of the trailing 0.0 argument is defined in BipartiteMatching.
        Factor[] factors = new Factor[]{
            BipartiteMatching.getHMMFunctionFactor(false, mscp, size, 0.0),
            BipartiteMatching.getHMMFunctionFactor(true, mscp, size, 0.0)
        };
        MFBP bp = new MFBP(naturalParams, factors, true, false);
        for (int it = 0; it < bpIters; ++it) {
            bp.iterate();
        }
        double elapsedMillis = (System.nanoTime() - start) / 1.0e6;
        int[] decoded = decode(unpack2(bp.moments(), mscp, size));
        MFBP.getStat("BP-HMM-3", varLabel, sizeLabel).addValue(loss(decoded));
        MFBP.getStat("BP-HMM-3-time", varLabel, sizeLabel).addValue(elapsedMillis);
    }

    /** Prints every entry of {@code MFBP.stats} as "key TAB mean TAB standard deviation". */
    private static void printStats() {
        for (String key : MFBP.stats.keySet()) {
            SummaryStatistics stat = (SummaryStatistics) MFBP.stats.get(key);
            System.out.println(key + "\t" + stat.getMean() + "\t" + stat.getStandardDeviation());
        }
    }

    /**
     * Collapses a packed {@code N2 x N2 x 2} vector into an {@code N2 x N2} matrix. It does this
     * by adding the two entries that share the same (i, j), i.e. those at third coordinate 0 and 1.
     *
     * <p>NOTE(review): assumes {@code packed} follows {@code mscp}'s layout (as produced by the
     * HMM BP's {@code moments()}) — confirm.
     */
    private static double[][] unpack2(double[] packed, CoordinatesPacker.MSCoordinatePacker mscp, int N2) {
        double[][] result = new double[N2][N2];
        for (int i = 0; i < N2; ++i) {
            for (int j = 0; j < N2; ++j) {
                result[i][j] = packed[mscp.coord2int(i, j, 0)] + packed[mscp.coord2int(i, j, 1)];
            }
        }
        return result;
    }
}

