/*
 * Decompiled with CFR 0.152.
 */
package nuts.math;

import fig.basic.NumUtils;
import fig.basic.UnorderedPair;
import fig.prob.SampleUtils;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.math.GMFct;
import nuts.math.Graph;
import nuts.math.Graphs;
import nuts.math.TabularGMFct;
import nuts.util.MathUtils;

/**
 * A {@link GMFct} whose pairwise values are stored in a {@link TabularGMFct}
 * and whose single-node values are derived from them by summing out the other
 * endpoint of an edge.
 *
 * <p>Node marginals are computed lazily and cached; every call to
 * {@link #set} invalidates the cache. Computing the marginals also verifies
 * that each sum passes {@code MathUtils.isCloseToProb} and that all edges
 * touching a node yield marginals that agree under {@code MathUtils.close}.
 *
 * <p>Only nodes that appear in at least one edge receive a marginal, so
 * {@link #getMarginals} returns {@code null} (and {@link #get(Comparable, int)}
 * fails) for a node without incident edges.
 *
 * @param <V> the node type of the underlying graph
 */
public final class EdgePosteriors<V extends Comparable<V>> implements GMFct<V> {
    /** Backing table holding the pairwise values. */
    private final TabularGMFct<V> edgePosteriors;
    /** Cached node marginals; {@code null} whenever they must be recomputed. */
    private Map<V, double[]> marginals = null;

    /**
     * Creates an instance backed by a new {@link TabularGMFct}.
     *
     * @param graph the graph passed to the backing table
     * @param randomVariableDom passed to the backing table; presumably the
     *        number of states of each node (confirm against TabularGMFct)
     */
    public EdgePosteriors(Graph<V> graph, Map<V, Integer> randomVariableDom) {
        this.edgePosteriors = new TabularGMFct<V>(graph, randomVariableDom);
    }

    /** Placeholder string representation; always returns {@code "todo"}. */
    @Override
    public String toString() {
        return "todo";
    }

    /**
     * Stores the pairwise value for nodes {@code n1}, {@code n2} in states
     * {@code s1}, {@code s2} and invalidates the cached node marginals.
     *
     * @throws IllegalArgumentException if {@code v} is NaN or infinite
     */
    public void set(V n1, V n2, int s1, int s2, double v) {
        if (Double.isNaN(v) || Double.isInfinite(v)) {
            throw new IllegalArgumentException("Edge posterior must be a finite number but was " + v);
        }
        // Any change to the table makes previously derived marginals stale.
        this.marginals = null;
        this.edgePosteriors.set(n1, n2, s1, s2, v);
    }

    /**
     * Returns the derived marginal of node {@code n} in state {@code s}.
     *
     * @throws NullPointerException if no marginal exists for {@code n}
     *         (the node is not an endpoint of any edge)
     */
    @Override
    public double get(V n, int s) {
        return this.getMarginals(n)[s];
    }

    /**
     * Returns the cached marginal array of node {@code n}, computing all
     * marginals first if the cache is empty. The returned array is the
     * internal one, not a copy.
     *
     * @return the marginal array, or {@code null} if {@code n} is not an
     *         endpoint of any edge
     * @throws RuntimeException if the marginals have to be computed and the
     *         consistency checks fail
     */
    public double[] getMarginals(V n) {
        if (this.marginals == null) {
            this.marginals = this.computeMarginals();
        }
        return this.marginals.get(n);
    }

    /**
     * Forces a recomputation of the node marginals, which runs the
     * consistency checks as a side effect.
     *
     * @throws RuntimeException if a summed value is not close to a
     *         probability or two edges disagree about a node's marginal
     */
    public void checkLocalConsistency() {
        this.marginals = this.computeMarginals();
    }

    /**
     * Draws one joint assignment by sampling, for each connected-component
     * representative, that node from its marginal and then every other node
     * conditioned on the already sampled neighbour it was reached from.
     *
     * @param rand source of randomness handed to {@code SampleUtils}
     * @return a map from each visited node to its sampled state
     * @throws IllegalStateException if {@code Graphs.isForest} rejects the graph
     */
    public Map<V, Integer> sample(Random rand) {
        Graph<V> graph = this.graph();
        if (!Graphs.isForest(graph)) {
            throw new IllegalStateException("sample() requires a graph accepted by Graphs.isForest");
        }
        Map<V, Integer> result = new HashMap<V, Integer>();
        for (V ccRep : Graphs.connectedComponentsReps(graph)) {
            this.sample(ccRep, null, -1, result, rand);
        }
        return result;
    }

    /**
     * Samples {@code node} and recurses into all neighbours except the one it
     * was reached from.
     *
     * @param node the node to sample
     * @param parentNode the already sampled neighbour, or {@code null} for a root
     * @param parentState the state sampled for {@code parentNode}; ignored for a root
     * @param result receives the sampled state of every visited node
     * @param rand source of randomness
     */
    private void sample(V node, V parentNode, int parentState, Map<V, Integer> result, Random rand) {
        double[] choices = new double[this.nStates(node)];
        if (parentNode == null) {
            for (int s = 0; s < choices.length; ++s) {
                choices[s] = this.get(node, s);
            }
        } else {
            // The parent marginal does not depend on s, so look it up once.
            double parentMarginal = this.get(parentNode, parentState);
            for (int s = 0; s < choices.length; ++s) {
                choices[s] = this.get(node, parentNode, s, parentState) / parentMarginal;
            }
        }
        NumUtils.normalize(choices);
        int choice = SampleUtils.sampleMultinomial(rand, choices);
        result.put(node, choice);
        for (V nbr : this.graph().nbrs(node)) {
            // Do not walk back along the edge we arrived on.
            if (parentNode != null && parentNode.equals(nbr)) {
                continue;
            }
            this.sample(nbr, node, choice, result, rand);
        }
    }

    /**
     * Derives a marginal for both endpoints of every edge. When a node is
     * touched by several edges, each later derivation is checked against the
     * one already stored.
     */
    private Map<V, double[]> computeMarginals() {
        Map<V, double[]> result = new HashMap<V, double[]>();
        for (UnorderedPair<V, V> edge : Graphs.edgeSet(this.graph())) {
            V first = edge.getFirst();
            V second = edge.getSecond();
            result.put(first, this.marginalFromEdge(first, second, result));
            result.put(second, this.marginalFromEdge(second, first, result));
        }
        return result;
    }

    /**
     * Computes the marginal of {@code first} by summing the pairwise values
     * of the edge ({@code first}, {@code second}) over the states of
     * {@code second}.
     *
     * @param partialMarginals marginals derived so far, used for the
     *        cross-edge consistency check
     * @throws RuntimeException if a sum is not close to a probability or
     *         disagrees with a previously derived marginal of {@code first}
     */
    private double[] marginalFromEdge(V first, V second, Map<V, double[]> partialMarginals) {
        int nStates = this.edgePosteriors.nStates(first);
        double[] known = partialMarginals.get(first);
        double[] result = new double[nStates];
        for (int s = 0; s < nStates; ++s) {
            double cur = this.sumOverSecond(s, first, second);
            if (!MathUtils.isCloseToProb(cur)) {
                throw new RuntimeException("Should be a pr:" + cur);
            }
            if (known != null && !MathUtils.close(known[s], cur)) {
                throw new RuntimeException("Marginal " + known[s] + " not consistent with " + cur);
            }
            result[s] = cur;
        }
        return result;
    }

    /** Sums the pairwise value of ({@code first}=s, {@code second}=s2) over all states s2. */
    private double sumOverSecond(int s, V first, V second) {
        int nSecond = this.edgePosteriors.nStates(second);
        double sum = 0.0;
        for (int s2 = 0; s2 < nSecond; ++s2) {
            sum += this.get(first, second, s, s2);
        }
        return sum;
    }

    /** Returns the stored pairwise value; delegates to the backing table. */
    @Override
    public double get(V n1, V n2, int s1, int s2) {
        return this.edgePosteriors.get(n1, n2, s1, s2);
    }

    /** Returns the number of states of {@code node}; delegates to the backing table. */
    @Override
    public int nStates(V node) {
        return this.edgePosteriors.nStates(node);
    }

    /** Returns the graph of the backing table. */
    @Override
    public Graph<V> graph() {
        return this.edgePosteriors.graph();
    }

    /**
     * Computes the entropy by running {@code Graphs.dfs} from every
     * connected-component representative and accumulating one term per
     * processed vertex (see {@link EntropyGraphProcessor}). Natural
     * logarithms are used, and zero-probability entries contribute nothing.
     *
     * @throws IllegalStateException if {@code Graphs.isForest} rejects the graph
     */
    public double entropy() {
        Graph<V> graph = this.graph();
        if (!Graphs.isForest(graph)) {
            throw new IllegalStateException("entropy() requires a graph accepted by Graphs.isForest");
        }
        EntropyGraphProcessor<V> gp = new EntropyGraphProcessor<V>(this);
        for (V ccRep : Graphs.connectedComponentsReps(graph)) {
            Graphs.dfs(graph, ccRep, gp);
        }
        return gp.cEntropy;
    }

    /**
     * Graph visitor that accumulates entropy terms: the marginal entropy of a
     * vertex processed without a parent, and the entropy of a vertex
     * conditioned on its parent otherwise.
     *
     * <p>NOTE(review): assumes the traversal calls {@link #process} exactly
     * once per vertex, with a {@code null} parent only for roots. Confirm
     * against {@code Graphs.dfs}.
     */
    public static class EntropyGraphProcessor<V extends Comparable<V>>
            implements Graphs.GraphProcessor<V> {
        /** Running total of all terms accumulated so far. */
        private double cEntropy = 0.0;
        private final EdgePosteriors<V> edgePosteriors;

        public EntropyGraphProcessor(EdgePosteriors<V> edgePosteriors) {
            this.edgePosteriors = edgePosteriors;
        }

        /**
         * Adds the term for {@code vertex} to the running total.
         *
         * @param vertex the vertex being processed
         * @param parent the vertex it was reached from, or {@code null}
         * @param visited unused
         */
        @Override
        public void process(V vertex, V parent, Set<V> visited) {
            if (parent == null) {
                this.cEntropy += this.entropy(vertex);
            } else {
                this.cEntropy += this.conditionalEntropy(vertex, parent);
            }
        }

        /**
         * Returns {@code -sum p(node, parent) * log(p(node, parent) / p(parent))}
         * over all state pairs.
         */
        private double conditionalEntropy(V node, V parent) {
            int nParent = this.edgePosteriors.nStates(parent);
            int nNode = this.edgePosteriors.nStates(node);
            double sum = 0.0;
            for (int parentState = 0; parentState < nParent; ++parentState) {
                double cDenom = this.edgePosteriors.get(parent, parentState);
                for (int currentState = 0; currentState < nNode; ++currentState) {
                    double edgePost = this.edgePosteriors.get(node, parent, currentState, parentState);
                    // 0 * log(0) evaluates to NaN in floating point; by the
                    // usual convention such a term contributes zero.
                    if (edgePost == 0.0) {
                        continue;
                    }
                    sum -= edgePost * Math.log(edgePost / cDenom);
                }
            }
            return sum;
        }

        /** Returns {@code -sum p(node) * log p(node)} over the states of {@code node}. */
        private double entropy(V node) {
            int n = this.edgePosteriors.nStates(node);
            double sum = 0.0;
            for (int i = 0; i < n; ++i) {
                double p = this.edgePosteriors.get(node, i);
                // Skip zero-probability states: 0 * log(0) would be NaN.
                if (p == 0.0) {
                    continue;
                }
                sum -= p * Math.log(p);
            }
            return sum;
        }
    }
}

