/*
 * Decompiled with CFR 0.152.
 */
package fenchel.algo;

import fenchel.algo.BlockingPriorities;
import fenchel.algo.FactorGraphInferenceAlgorithm;
import fenchel.factor.BinaryFactor;
import fenchel.factor.FactorGraph;
import fenchel.factor.FactorUtils;
import fenchel.factor.UnaryFactor;
import fig.basic.Pair;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import nuts.math.Graphs;
import nuts.util.CollUtils;

/**
 * Message-passing (sum-product) inference on a {@link FactorGraph} whose topology is a forest.
 *
 * <p>Usage: call {@link #init(FactorGraph)} once per graph, then query {@link #moment(Object)},
 * {@link #logZ()}, {@link #getMessage(Object, Object)} or
 * {@link #distinctIncomingProduct(Object, Object)}. All messages are computed lazily, in one
 * serial pass, on the first query after {@code init}. Later queries reuse the cached messages.
 *
 * <p>This class performs no synchronization.
 *
 * @param <N> the node type of the factor graph
 */
public class FactorGraphSumProduct<N>
implements FactorGraphInferenceAlgorithm<N> {
    /** Scheduler over directed message jobs {@code (source, destination)}; set by {@link #init}. */
    private BlockingPriorities<Pair<N, N>> priorities;
    /** Messages produced so far, keyed by directed edge {@code (source, destination)}. */
    private Map<Pair<N, N>, UnaryFactor> producedMessages;
    /** The model being queried; {@code null} until {@link #init} succeeds. */
    private FactorGraph<N> factorGraph;
    /** Whether the message schedule has been fully run since the last {@link #init}. */
    private boolean computed;

    /**
     * Returns the message sent from {@code source} to {@code dest}, computing all messages first
     * if needed.
     *
     * @param source the sending node
     * @param dest the receiving node
     * @return the produced message, or {@code null} if no message was recorded for that directed
     *     pair
     * @throws IllegalStateException if {@link #init} has not been called
     */
    public UnaryFactor getMessage(N source, N dest) {
        this.ensureMessagesComputed();
        return this.producedMessages.get(Pair.makePair(source, dest));
    }

    /**
     * Prepares this algorithm to run on {@code factorGraph} and discards any messages cached for a
     * previously initialized graph. No messages are computed here; that happens lazily on the first
     * query.
     *
     * @param factorGraph the model to run inference on; must not be {@code null}
     * @throws NullPointerException if {@code factorGraph} is {@code null}
     * @throws IllegalArgumentException if the graph's topology is not a forest
     */
    @Override
    public void init(FactorGraph<N> factorGraph) {
        Objects.requireNonNull(factorGraph, "factorGraph");
        // Validate before touching any field, so a rejected graph leaves the previous
        // initialization (if any) intact.
        if (!Graphs.isForest(factorGraph.getTopology())) {
            throw new IllegalArgumentException("Only supports forests");
        }
        this.computed = false;
        this.factorGraph = factorGraph;
        this.priorities = BlockingPriorities.initPriorities(factorGraph.getTopology());
        this.producedMessages = new HashMap<Pair<N, N>, UnaryFactor>();
    }

    /**
     * Tells whether {@link #init} has completed successfully at least once.
     *
     * @return {@code true} once a factor graph has been set
     */
    public boolean initialized() {
        return this.factorGraph != null;
    }

    /**
     * Returns the (unnormalized) product of {@code node}'s unary factor with every message
     * incoming to {@code node}.
     *
     * @param node the node to query
     * @return the product factor for {@code node}
     * @throws IllegalStateException if {@link #init} has not been called
     */
    @Override
    public UnaryFactor moment(N node) {
        this.ensureMessagesComputed();
        UnaryFactor modelFactor = this.factorGraph.getUnary(node);
        return FactorUtils.multiply(modelFactor, FactorUtils.allIncoming(this.factorGraph, node, this.producedMessages));
    }

    /**
     * Sums, over every connected component of the topology, the {@code logNorm()} of the
     * {@link #moment(Object)} at one arbitrarily picked node of that component.
     *
     * <p>NOTE(review): on a tree, this presumably equals the component's log partition function,
     * assuming {@code UnaryFactor.logNorm()} returns the log of the factor's total mass. Confirm.
     *
     * @return the summed log normalizers (0.0 for a graph with no components)
     * @throws IllegalStateException if {@link #init} has not been called
     */
    @Override
    public double logZ() {
        // Check first so an uninitialized call reports "Need to first call init()" instead of
        // failing with a NullPointerException on factorGraph.getTopology().
        this.ensureMessagesComputed();
        Set<Set<N>> ccs = Graphs.connectedComponents(this.factorGraph.getTopology());
        double sum = 0.0;
        for (Set<N> cc : ccs) {
            sum += this.moment(CollUtils.pick(cc)).logNorm();
        }
        return sum;
    }

    /**
     * Guards every query: fails if not initialized, otherwise runs the message schedule once.
     *
     * @throws IllegalStateException if {@link #init} has not been called
     */
    private void ensureMessagesComputed() {
        if (!this.initialized()) {
            throw new IllegalStateException("Need to first call init()");
        }
        if (!this.computed) {
            this._computeMessagesSerially();
        }
    }

    /**
     * Repeatedly pops an unblocked message job, computes its message, records it and unblocks
     * dependent jobs. Stops when the scheduler reports it is blocked, then marks the messages as
     * computed.
     */
    private void _computeMessagesSerially() {
        while (!this.priorities.isBlocked()) {
            Pair<N, N> messageJob = this.popMessageJob();
            UnaryFactor factor = this.computeMessage(messageJob);
            this.updateMessageJobs(messageJob, factor);
        }
        this.computed = true;
    }

    /**
     * Computes the message for the directed edge {@code messageJob = (source, destination)}. It
     * gathers the messages into {@code source} (via {@code FactorUtils.distinctIncoming}), adds
     * {@code source}'s unary factor, and marginalizes the binary factor between the two nodes
     * against that list.
     *
     * @param messageJob the directed edge {@code (source, destination)}
     * @return the new message from {@code source} to {@code destination}
     */
    private UnaryFactor computeMessage(Pair<N, N> messageJob) {
        N source = messageJob.getFirst();
        N destination = messageJob.getSecond();
        UnaryFactor modelFactor = this.factorGraph.getUnary(source);
        BinaryFactor binary = this.factorGraph.getBinary(source, destination);
        // NOTE(review): relies on distinctIncoming returning a fresh, mutable list.
        List<UnaryFactor> incomingFactors = FactorUtils.distinctIncoming(this.factorGraph, messageJob, this.producedMessages);
        incomingFactors.add(modelFactor);
        // NOTE(review): return value is ignored, so this presumably mutates incomingFactors
        // and/or binary in place. Confirm against FactorUtils.
        FactorUtils.multiplyIfNeeded(incomingFactors, binary);
        return binary.marginalize(incomingFactors);
    }

    /** Removes and returns the next unblocked message job from the scheduler. */
    private Pair<N, N> popMessageJob() {
        return this.priorities.popUnblocked();
    }

    /**
     * Records the message for {@code edge}, then calls {@code removeBlock} on each edge returned
     * by {@code FactorUtils.distinctOutgoing(factorGraph, edge)}.
     *
     * @param edge the directed edge whose message was just computed
     * @param factor the computed message
     */
    private void updateMessageJobs(Pair<N, N> edge, UnaryFactor factor) {
        this.producedMessages.put(edge, factor);
        for (Pair<N, N> distinctOutgoingEdge : FactorUtils.distinctOutgoing(this.factorGraph, edge)) {
            this.priorities.removeBlock(distinctOutgoingEdge);
        }
    }

    /**
     * Returns the factor graph passed to the last successful {@link #init}.
     *
     * @return the current factor graph, or {@code null} if not initialized
     */
    public FactorGraph<N> getFactorGraph() {
        return this.factorGraph;
    }

    /**
     * Returns the product of {@code source}'s unary factor with the messages selected by
     * {@code FactorUtils.distinctIncoming} for the edge {@code (source, dest)}. These are
     * presumably all messages into {@code source} except the one from {@code dest}; confirm.
     *
     * @param source the node whose factor and incoming messages are multiplied
     * @param dest the neighbour whose edge selects which incoming messages are used
     * @return the product factor
     * @throws IllegalStateException if {@link #init} has not been called
     */
    public UnaryFactor distinctIncomingProduct(N source, N dest) {
        this.ensureMessagesComputed();
        return FactorUtils.multiply(this.factorGraph.getUnary(source), FactorUtils.distinctIncoming(this.factorGraph, Pair.makePair(source, dest), this.producedMessages));
    }
}

