/*
 * Decompiled with CFR 0.152.
 */
package fenchel.factor;

import fenchel.factor.BinaryFactor;
import fenchel.factor.FactorGraph;
import fenchel.factor.IdentityFactor;
import fenchel.factor.UnaryFactor;
import fenchel.factor.multisites.MSFactorGraph;
import fenchel.factor.multisites.MSScaledFactorBuilder;
import fig.basic.Pair;
import fig.basic.UnorderedPair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import nuts.math.GMFct;
import nuts.math.Graphs;
import nuts.util.CollUtils;

/**
 * Static helpers for building and manipulating factor graphs: creating graphs, multiplying unary
 * factors, enumerating neighboring edges, and converting {@link GMFct} potentials into
 * {@link MSFactorGraph} unary/binary factors.
 */
public class FactorUtils {
    /**
     * Creates an empty multi-site factor graph backed by an {@link MSScaledFactorBuilder}.
     *
     * @param <N> node type
     * @return a new, empty factor graph
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static <N> MSFactorGraph<N> newFactorGraph() {
        return new MSFactorGraph(new MSScaledFactorBuilder());
    }

    /**
     * Builds an {@code nSites}-row table whose every row is {@code factor}.
     *
     * <p>Note: the rows are not copies — every row references the very same {@code factor} array,
     * so mutating one row (or {@code factor}) is visible through all rows.
     *
     * @param factor the row to repeat
     * @param nSites number of rows
     * @return an array of length {@code nSites} whose entries all alias {@code factor}
     */
    public static double[][] simpleUnary(double[] factor, int nSites) {
        double[][] result = new double[nSites][];
        for (int i = 0; i < nSites; ++i) {
            result[i] = factor;
        }
        return result;
    }

    /**
     * Ensures {@code incomingFactors} holds no more factors than {@code binary} supports.
     *
     * <p>Identity factors are first removed from {@code incomingFactors} (in place). If the list is
     * still longer than {@link BinaryFactor#maxNumberOfFactorsSupported()}, the surplus factors are
     * moved out, multiplied together, and their product is appended, so the list ends up with
     * exactly the supported number of factors. The list is modified in place.
     *
     * @param incomingFactors factors to shrink; mutated by this call
     * @param binary the binary factor whose capacity bounds the list size
     * @throws IllegalStateException if the list does not end up with the supported size
     */
    public static void multiplyIfNeeded(List<UnaryFactor> incomingFactors, BinaryFactor binary) {
        incomingFactors.removeAll(Collections.singleton(IdentityFactor.identity));
        int maxNFactors = binary.maxNumberOfFactorsSupported();
        int initialSize = incomingFactors.size();
        if (initialSize <= maxNFactors) {
            return;
        }
        List<UnaryFactor> toMultiply = new ArrayList<UnaryFactor>(initialSize - maxNFactors + 1);
        // NOTE(review): presumably moves every element from index (maxNFactors - 1) onward into
        // toMultiply, leaving maxNFactors - 1 behind — TODO confirm CollUtils.transfer semantics.
        CollUtils.transfer(incomingFactors, maxNFactors - 1, toMultiply);
        incomingFactors.add(FactorUtils.multiply(toMultiply));
        if (incomingFactors.size() != maxNFactors) {
            throw new IllegalStateException("Expected " + maxNFactors
                    + " factors after multiplication but found " + incomingFactors.size());
        }
    }

    /**
     * Multiplies all factors in {@code toMultiply} together.
     *
     * <p>Equivalent to {@code multiply(IdentityFactor.identity, toMultiply)}; see that method for
     * how {@code toMultiply} is modified.
     *
     * @param toMultiply factors to multiply; mutated by this call
     * @return the product, or {@link IdentityFactor#identity} if there is nothing to multiply
     */
    public static UnaryFactor multiply(List<UnaryFactor> toMultiply) {
        return FactorUtils.multiply(IdentityFactor.identity, toMultiply);
    }

    /**
     * Multiplies {@code modelFactor} with every factor in {@code toMultiply}.
     *
     * <p>Side effects on {@code toMultiply}: identity factors are always removed from it; when
     * {@code modelFactor} is the identity and the list is non-empty, its last element is also
     * removed (it becomes the receiver of the multiplication).
     *
     * @param modelFactor the factor to multiply into; may be {@link IdentityFactor#identity}
     * @param toMultiply the other factors; mutated by this call
     * @return the product, or {@link IdentityFactor#identity} if both inputs are trivial
     */
    public static UnaryFactor multiply(UnaryFactor modelFactor, List<UnaryFactor> toMultiply) {
        toMultiply.removeAll(Collections.singleton(IdentityFactor.identity));
        if (modelFactor != IdentityFactor.identity) {
            return modelFactor.multiply(toMultiply);
        }
        if (toMultiply.isEmpty()) {
            return IdentityFactor.identity;
        }
        // With an identity model factor, pick a real factor as the receiver. The tail is used,
        // presumably because removing it is cheap for array-backed lists.
        UnaryFactor last = toMultiply.remove(toMultiply.size() - 1);
        return last.multiply(toMultiply);
    }

    /**
     * Lists the directed edges leaving {@code e.getSecond()}, excluding the edge back to
     * {@code e.getFirst()}.
     *
     * @param f the factor graph
     * @param e a directed edge (first, second)
     * @return edges {@code (e.getSecond(), nbr)} for each neighbor {@code nbr} of
     *     {@code e.getSecond()} other than {@code e.getFirst()}
     */
    public static <N> List<Pair<N, N>> distinctOutgoing(FactorGraph<N> f, Pair<N, N> e) {
        return FactorUtils.<N, Pair<N, N>>neighbors(f, e.getSecond(), e.getFirst(), false, null);
    }

    /**
     * Renders each node of {@code fg} together with its unary factor, for debugging.
     *
     * @param fg the factor graph
     * @return one "Node: ..." / "Factor" section per vertex of the topology
     */
    public static <N> String factorGraphUnariesToString(FactorGraph<N> fg) {
        StringBuilder result = new StringBuilder();
        for (N node : fg.getTopology().vertexSet()) {
            result.append("Node: ").append(node)
                    .append("\nFactor\n").append(fg.getUnary(node))
                    .append("\n\n");
        }
        return result.toString();
    }

    /**
     * Looks up the values attached to edges entering {@code e.getFirst()}, excluding the edge
     * coming from {@code e.getSecond()}.
     *
     * @param f the factor graph
     * @param e a directed edge (first, second)
     * @param edgeMap values keyed by directed edge; must contain every looked-up edge
     * @return {@code edgeMap.get((nbr, e.getFirst()))} for each neighbor {@code nbr} of
     *     {@code e.getFirst()} other than {@code e.getSecond()}
     * @throws IllegalStateException if an edge has no entry in {@code edgeMap}
     */
    public static <N, K> List<K> distinctIncoming(FactorGraph<N> f, Pair<N, N> e, Map<Pair<N, N>, K> edgeMap) {
        return FactorUtils.neighbors(f, e.getFirst(), e.getSecond(), true, edgeMap);
    }

    /**
     * Looks up the values attached to every edge entering {@code node}.
     *
     * @param f the factor graph
     * @param node the target node
     * @param edgeMap values keyed by directed edge; must contain every looked-up edge
     * @return {@code edgeMap.get((nbr, node))} for each neighbor {@code nbr} of {@code node}
     * @throws IllegalStateException if an edge has no entry in {@code edgeMap}
     */
    public static <N, K> List<K> allIncoming(FactorGraph<N> f, N node, Map<Pair<N, N>, K> edgeMap) {
        return FactorUtils.neighbors(f, node, null, true, edgeMap);
    }

    /**
     * Enumerates the directed edges between {@code node} and its neighbors.
     *
     * @param f the factor graph
     * @param node the center node
     * @param excluded a neighbor to skip, or {@code null} to skip none
     * @param incoming if true, edges are {@code (nbr, node)}; otherwise {@code (node, nbr)}
     * @param edgeMap if null, the edges themselves are returned (then {@code K} must be
     *     {@code Pair<N, N>}); otherwise each edge is mapped through it
     * @return the edges, or their mapped values, in neighbor iteration order
     * @throws IllegalStateException if {@code edgeMap} is non-null and lacks an edge
     */
    private static <N, K> List<K> neighbors(FactorGraph<N> f, N node, N excluded, boolean incoming,
            Map<Pair<N, N>, K> edgeMap) {
        List<K> result = new ArrayList<K>();
        for (N neighbor : f.getTopology().nbrs(node)) {
            if (neighbor.equals(excluded)) {
                continue;
            }
            Pair<N, N> edge = Pair.makePair(incoming ? neighbor : node, incoming ? node : neighbor);
            if (edgeMap == null) {
                // Only reached via distinctOutgoing, which fixes K = Pair<N, N>.
                @SuppressWarnings("unchecked")
                K asItem = (K) edge;
                result.add(asItem);
                continue;
            }
            K item = edgeMap.get(edge);
            if (item == null) {
                throw new IllegalStateException("No entry in edge map for edge " + edge);
            }
            // No cast here: values of arbitrary type K are added as-is (the decompiled version
            // cast them to Pair, which failed at runtime for non-Pair values).
            result.add(item);
        }
        return result;
    }

    /**
     * Copies both node and edge potentials of {@code gm} into {@code factorGraph}.
     *
     * @param gm the source model
     * @param factorGraph the destination graph; unary and binary factors are added to it
     */
    public static <N> void gm2FactorGraph(GMFct<N> gm, MSFactorGraph<N> factorGraph) {
        FactorUtils.gmNodePotentials2FactorGraph(gm, factorGraph);
        FactorUtils.gmEdgePotentials2FactorGraph(gm, factorGraph);
    }

    /**
     * Adds a single-site unary factor for every vertex of {@code gm}'s graph.
     *
     * @param gm the source model
     * @param factorGraph the destination graph
     */
    public static <N> void gmNodePotentials2FactorGraph(GMFct<N> gm, MSFactorGraph<N> factorGraph) {
        FactorUtils._gmNodePotentials2FactorGraph(gm, factorGraph, 1, gm.graph().vertexSet());
    }

    /**
     * Adds one binary factor per undirected edge of {@code gm}'s graph. The table is
     * {@code [nStates(n1)][nStates(n2)]} with entries {@code gm.get(n1, n2, s1, s2)}.
     *
     * @param gm the source model
     * @param factorGraph the destination graph
     */
    public static <N> void gmEdgePotentials2FactorGraph(GMFct<N> gm, MSFactorGraph<N> factorGraph) {
        for (UnorderedPair<N, N> undirectedEdge : Graphs.edgeSet(gm.graph())) {
            N n1 = undirectedEdge.getFirst();
            N n2 = undirectedEdge.getSecond();
            int nStates1 = gm.nStates(n1);
            int nStates2 = gm.nStates(n2);
            double[][] array = new double[nStates1][nStates2];
            for (int s1 = 0; s1 < nStates1; ++s1) {
                for (int s2 = 0; s2 < nStates2; ++s2) {
                    array[s1][s2] = gm.get(n1, n2, s1, s2);
                }
            }
            factorGraph.addBinary(n1, n2, array);
        }
    }

    /**
     * Adds, for each node in {@code subset}, a unary factor made of {@code nSites} identical rows
     * of {@code gm}'s node potentials. Each row is an independent array.
     *
     * @param gm the source model
     * @param factorGraph the destination graph
     * @param nSites number of rows (sites) per unary factor
     * @param subset the nodes to convert
     */
    public static <N> void _gmNodePotentials2FactorGraph(GMFct<N> gm, MSFactorGraph<N> factorGraph, int nSites, Set<N> subset) {
        for (N node : subset) {
            int nStates = gm.nStates(node);
            // The potentials do not depend on the site, so compute them once per node.
            double[] potentials = new double[nStates];
            for (int s = 0; s < nStates; ++s) {
                potentials[s] = gm.get(node, s);
            }
            double[][] array = new double[nSites][];
            for (int n = 0; n < nSites; ++n) {
                array[n] = potentials.clone();
            }
            factorGraph.addUnary(node, array);
        }
    }
}

