/*
 * Decompiled with CFR 0.152.
 */
package facto;

import facto.BipartiteMatching;
import facto.Factor;
import facto.Rasmussen;
import fig.basic.NumUtils;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import java.util.SortedMap;
import java.util.TreeMap;
import nuts.math.MtxUtils;
import nuts.util.CoordinatesPacker;
import nuts.util.MathUtils;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;

/**
 * Iterative variational approximation over a vector of {@link #J} coordinates whose base measure
 * is a product of {@link #I} factors.
 *
 * <p>Each factor {@code i} owns a message vector {@code zetas[i]} in natural-parameter space. The
 * approximate moments are {@code logistic(naturalParameters + sum_i zetas[i])} (presumably
 * Bernoulli means, e.g. edge indicators in the bipartite-matching benchmark in {@link #main}).
 *
 * <p>Two update rules are supported, selected by {@link #useMF}; see that field.
 *
 * <p>Not thread-safe: {@link #iterate()} mutates the messages in place and {@link #moments()}
 * caches its result in an unsynchronized field.
 */
public final class MFBP {
    /** Number of factors in the base measure ({@code factors.length}). */
    public final int I;
    /** Number of coordinates, i.e. the length of the natural-parameter vector. */
    public final int J;
    /** Per-factor messages in natural-parameter space; {@code zetas[i]} has length {@link #J}. */
    private final double[][] zetas;
    /** Natural (log-space) parameters, one per coordinate. */
    private final double[] naturalParameters;
    /** The base-measure factors, stored by reference (not copied). */
    public final Factor[] factors;
    /**
     * Selects the update rule used by {@link #iterate()}.
     * <ul>
     *   <li>{@code true}: a factor's input includes every factor's message, including its own, and
     *       its new message is {@code logit(gradient)}.</li>
     *   <li>{@code false}: the factor's own message is left out of its input, and the input is
     *       subtracted from {@code logit(gradient)} (a leave-one-out update).</li>
     * </ul>
     */
    public final boolean useMF;
    /** Cache for {@link #moments()}; cleared by every call to {@link #iterate()}. */
    private double[] cachedMoments = null;
    /**
     * Benchmark statistics keyed by tab-joined name parts (see {@link #getStat}). This is a
     * shared, unsynchronized {@link TreeMap}.
     */
    static SortedMap<String, SummaryStatistics> stats = new TreeMap<String, SummaryStatistics>();

    /**
     * Creates an approximation with every message initialised to zero.
     *
     * @param param                   one parameter per coordinate
     * @param baseMeasureFacto        the factors of the base measure (stored by reference)
     * @param isParamVectorInLogSpace if {@code true}, {@code param} is already in natural (log)
     *                                space; otherwise {@code Math.log} is applied element-wise
     * @param useMF                   selects the update rule; see {@link #useMF}
     */
    public MFBP(double[] param, Factor[] baseMeasureFacto, boolean isParamVectorInLogSpace, boolean useMF) {
        this.useMF = useMF;
        this.naturalParameters = new double[param.length];
        for (int j = 0; j < param.length; ++j) {
            this.naturalParameters[j] = isParamVectorInLogSpace ? param[j] : Math.log(param[j]);
        }
        this.factors = baseMeasureFacto;
        this.I = baseMeasureFacto.length;
        this.J = this.naturalParameters.length;
        this.zetas = new double[this.I][this.J];
    }

    /**
     * Performs one synchronous sweep over all factors and invalidates the {@link #moments()} cache.
     *
     * <p>All factor inputs are computed from the messages as they stood before the sweep. Only
     * then are the messages replaced, so each new message depends only on the previous sweep's
     * messages.
     *
     * @throws IllegalStateException if an updated message entry is NaN
     */
    public void iterate() {
        this.cachedMoments = null;

        // Phase 1: input vector of every factor. It is naturalParameters plus the messages of the
        // other factors, and also the factor's own message when useMF is set.
        // NOTE(review): MtxUtils.plusEqual(a, b, s) presumably does a += s * b — confirm.
        double[][] xis = new double[this.I][this.J];
        for (int i = 0; i < this.I; ++i) {
            MtxUtils.plusEqual(xis[i], this.naturalParameters, 1.0);
            for (int k = 0; k < this.I; ++k) {
                if (this.useMF || i != k) {
                    MtxUtils.plusEqual(xis[i], this.zetas[k], 1.0);
                }
            }
        }

        // Phase 2: replace each factor's message using the gradient at its input vector.
        for (int i = 0; i < this.I; ++i) {
            double[] curGradient = this.factors[i].gradient(xis[i]);
            for (int j = 0; j < this.J; ++j) {
                if (Double.isInfinite(xis[i][j])) {
                    // An infinite input is copied straight into the message.
                    this.zetas[i][j] = xis[i][j];
                } else {
                    double d = curGradient[j];
                    // Clamp numerical overshoot outside [0, 1] before taking the logit.
                    // NOTE(review): checkClose presumably rejects large deviations — confirm.
                    if (d > 1.0) {
                        MathUtils.checkClose(d, 1.0);
                        d = 1.0;
                    }
                    if (d < 0.0) {
                        MathUtils.checkClose(d, 0.0);
                        d = 0.0;
                    }
                    this.zetas[i][j] = MathUtils.logit(d) - (this.useMF ? 0.0 : xis[i][j]);
                }
                if (Double.isNaN(this.zetas[i][j])) {
                    throw new IllegalStateException("NaN message at factor " + i + ", coordinate " + j);
                }
            }
        }
    }

    /**
     * Returns {@code logistic(paramPlusAllZetas())}, computed lazily and cached until the next
     * {@link #iterate()}.
     *
     * @return the cached array itself (not a copy); callers must not modify it
     */
    public double[] moments() {
        if (this.cachedMoments == null) {
            this.cachedMoments = MathUtils.logistic(this.paramPlusAllZetas());
        }
        return this.cachedMoments;
    }

    /**
     * Returns a fresh vector equal to the natural parameters plus the messages of all factors.
     *
     * @return a new array of length {@link #J}
     */
    public double[] paramPlusAllZetas() {
        double[] result = new double[this.J];
        MtxUtils.plusEqual(result, this.naturalParameters, 1.0);
        for (int i = 0; i < this.I; ++i) {
            MtxUtils.plusEqual(result, this.zetas[i], 1.0);
        }
        return result;
    }

    /**
     * Returns the shared statistics accumulator for the given name parts, creating it on first
     * use. The key is every part followed by a tab character.
     *
     * @param strings name parts joined into the lookup key
     * @return the (possibly new) accumulator registered in {@link #stats}
     */
    static SummaryStatistics getStat(String ... strings) {
        StringBuilder key = new StringBuilder();
        for (String s : strings) {
            key.append(s).append('\t');
        }
        String name = key.toString();
        SummaryStatistics result = stats.get(name);
        if (result == null) {
            result = new SummaryStatistics();
            stats.put(name, result);
        }
        return result;
    }

    /**
     * Computes the estimate below (available in {@code useMF} mode only), where {@code mu} is
     * {@link #moments()} and {@code H} is {@code MathUtils.entropy}:
     * {@code sum_j (theta_j * mu_j + H(mu_j)) + sum_i factors[i].entropy(theta + sum_k zetas[k])}.
     *
     * @return the estimate
     * @throws IllegalStateException if {@link #useMF} is {@code false}
     */
    public double logPartitionEstimate() {
        if (!this.useMF) {
            throw new IllegalStateException("logPartitionEstimate() requires useMF == true");
        }
        double[] mu = this.moments();
        double sum = 0.0;
        for (int j = 0; j < this.naturalParameters.length; ++j) {
            sum += this.naturalParameters[j] * mu[j];
            sum += MathUtils.entropy(mu[j]);
        }
        double[] paramPlusAllZetas = this.paramPlusAllZetas();
        for (int i = 0; i < this.factors.length; ++i) {
            sum += this.factors[i].entropy(paramPlusAllZetas);
        }
        return sum;
    }

    /**
     * Benchmark entry point.
     *
     * <p>On random matching problems it compares {@code Rasmussen.moments} and both variational
     * update rules against {@code BipartiteMatching.posteriors}. It then prints the mean and
     * standard deviation of every recorded statistic, tab-separated.
     */
    public static void main(String[] args) {
        Random rand = new Random(1L);
        final int bpIters = 3;
        // Integer counter: repeatedly subtracting 0.1 would drift (e.g. temp=0.7000000000000001).
        for (int tempTenths = 10; tempTenths >= 5; --tempTenths) {
            double temp = tempTenths / 10.0;
            for (int n = 5; n < 6; ++n) {
                for (int replication = 0; replication < 100; ++replication) {
                    double[][] pots = BipartiteMatching.randomBinaryPotentials(rand, n, temp);
                    BipartiteMatching matching = new BipartiteMatching(pots);
                    benchmarkRasmussen(rand, pots, matching, n, temp);
                    benchmarkVariational(pots, matching, n, temp, bpIters);
                }
            }
        }
        System.out.println();
        for (String key : stats.keySet()) {
            SummaryStatistics s = stats.get(key);
            System.out.println(key + "\t" + s.getMean() + "\t" + s.getStandardDeviation());
        }
    }

    /**
     * Records elapsed time (ms) and RMS error of {@code Rasmussen.moments} for i = 1, 2, 4, 8, 16.
     * The error is recorded both before and after row-normalising the result.
     */
    private static void benchmarkRasmussen(Random rand, double[][] pots, BipartiteMatching matching,
                                           int n, double temp) {
        for (int i = 1; i <= 16; i *= 2) {
            long start = System.nanoTime();
            double[][] rasmPost = Rasmussen.moments(pots, rand, i);
            double elapsedMillis = (System.nanoTime() - start) / 1.0e6;
            getStat("N=" + n, "temp=" + temp, "rasmussen-" + i + "-time").addValue(elapsedMillis);
            getStat("N=" + n, "temp=" + temp, "rasmussen-" + i).addValue(rms(rasmPost, matching.posteriors));
            for (double[] row : rasmPost) {
                NumUtils.normalize(row);
            }
            getStat("N=" + n, "temp=" + temp, "rasmussen-norm-" + i).addValue(rms(rasmPost, matching.posteriors));
        }
    }

    /**
     * Runs {@code bpIters} sweeps of both update rules (useMF = true, then false) on
     * epsilon-smoothed potentials. Before each sweep it records the RMS error (raw and
     * row-normalised); after each sweep it records the cumulative sweep time in ms.
     */
    private static void benchmarkVariational(double[][] pots, BipartiteMatching matching,
                                             int n, double temp, int bpIters) {
        CoordinatesPacker cp = new CoordinatesPacker(n);
        double[][] smoothed = BipartiteMatching.addEpsilon(1.0E-4, pots);
        double[] naturalParams = new double[n * n];
        for (int k = 0; k < naturalParams.length; ++k) {
            int[] coord = cp.int2coord(k);
            naturalParams[k] = Math.log(smoothed[coord[0]][coord[1]]);
        }
        Factor[] facto = new Factor[]{
            BipartiteMatching.getFunctionFactor(false, cp, n),
            BipartiteMatching.getFunctionFactor(true, cp, n)
        };
        for (boolean useMF : new boolean[]{true, false}) {
            MFBP bp = new MFBP(naturalParams, facto, true, useMF);
            double totalMillis = 0.0;
            for (int iter = 0; iter < bpIters; ++iter) {
                double[][] unpacked = unpack(bp.moments(), cp);
                getStat("useMF=" + useMF, "N=" + n, "temp=" + temp, "vmf-iter=" + iter)
                        .addValue(rms(unpacked, matching.posteriors));
                for (double[] row : unpacked) {
                    NumUtils.normalize(row);
                }
                getStat("useMF=" + useMF, "N=" + n, "temp=" + temp, "vmf-renorm-iter=" + iter)
                        .addValue(rms(unpacked, matching.posteriors));
                long start = System.nanoTime();
                bp.iterate();
                totalMillis += (System.nanoTime() - start) / 1.0e6;
                getStat("useMF=" + useMF, "N=" + n, "temp=" + temp, "vmf-iter=" + iter + "-time")
                        .addValue(totalMillis);
            }
        }
    }

    /**
     * Expands a packed vector into a {@code cp.L x cp.L} matrix. Entry {@code (i, j)} is
     * {@code packed[cp.coord2int(i, j)]}.
     *
     * @param packed the packed vector
     * @param cp     the packer defining the index mapping
     * @return a new square matrix
     */
    public static double[][] unpack(double[] packed, CoordinatesPacker cp) {
        double[][] result = new double[cp.L][cp.L];
        for (int i = 0; i < cp.L; ++i) {
            for (int j = 0; j < cp.L; ++j) {
                result[i][j] = packed[cp.coord2int(i, j)];
            }
        }
        return result;
    }

    /**
     * Root-mean-square difference between two matrices with matching shapes (rows may be ragged,
     * but each row must match its counterpart).
     *
     * <p>A NaN difference is charged a squared error of {@code 1.0}. Two empty inputs yield
     * {@code 1.0}.
     *
     * @throws IllegalArgumentException if the row counts or any row lengths differ
     */
    public static double rms(double[][] one, double[][] two) {
        if (one.length != two.length) {
            throw new IllegalArgumentException("row count mismatch: " + one.length + " vs " + two.length);
        }
        SummaryStatistics squaredErrors = new SummaryStatistics();
        for (int row = 0; row < one.length; ++row) {
            if (one[row].length != two[row].length) {
                throw new IllegalArgumentException("row " + row + " length mismatch: "
                        + one[row].length + " vs " + two[row].length);
            }
            for (int col = 0; col < one[row].length; ++col) {
                double curDiff = one[row][col] - two[row][col];
                squaredErrors.addValue(Double.isNaN(curDiff) ? 1.0 : curDiff * curDiff);
            }
        }
        double mean = squaredErrors.getMean();
        // getMean() is NaN only when no values were added (empty input).
        return Double.isNaN(mean) ? 1.0 : Math.sqrt(mean);
    }

    /**
     * Root-mean-square difference between two equal-length vectors.
     *
     * <p>NaN differences propagate into the result. Empty inputs yield NaN.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    public static double rms(double[] one, double[] two) {
        if (one.length != two.length) {
            throw new IllegalArgumentException("length mismatch: " + one.length + " vs " + two.length);
        }
        SummaryStatistics squaredErrors = new SummaryStatistics();
        for (int row = 0; row < one.length; ++row) {
            double curDiff = one[row] - two[row];
            squaredErrors.addValue(curDiff * curDiff);
        }
        return Math.sqrt(squaredErrors.getMean());
    }
}

