/*
 * Decompiled with CFR 0.152.
 */
package ev.multi;

import ev.hmm.HetPairHMM;
import fig.basic.Option;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Set;
import ma.MultiAlignment;
import nuts.math.GMFctUtils;
import nuts.math.HashGraph;
import nuts.math.SemiGraph;
import nuts.math.TabularGMFct;
import nuts.maxent.SloppyMath;
import nuts.tui.Table;

/**
 * Static routines that compute and initialize the two kinds of pairwise log-space messages
 * used by the multi-sequence aligner: "r" messages (aggregated through third taxa) and "q"
 * messages (derived from a pairwise HMM's alignment posteriors).
 *
 * <p>Every message is a pair of log values indexed by state 0 and state 1. The normalized
 * messages produced here satisfy {@code logAdd(m[0], m[1]) == 0}.
 */
public class MessageComputations {
    /**
     * Factor by which the summed r-message log values are multiplied before renormalization.
     * Values below 1 flatten the resulting two-state distribution. Annotated with fig's
     * {@code @Option}, presumably so it can be set from the command line.
     */
    @Option
    public static double anneal = 0.1;

    /** Lower bound applied to every log q-message written by {@link #qMessages}. */
    public static final double smallestLogMessage = -10.0;

    /**
     * {@code log(1 - exp(smallestLogMessage))}: the value given to the other state when one
     * state's log q-message is clamped to {@link #smallestLogMessage}, so the pair stays normalized.
     */
    private static final double complementOfSmallestLogMessage =
            Math.log(1.0 - Math.exp(smallestLogMessage));

    /**
     * Creates an initial r-message table with every log message set to 0.0.
     *
     * @param len1 size parameter for the first axis (the array gets {@code len1 + 1} rows)
     * @param len2 size parameter for the second axis (the array gets {@code len2 + 1} columns)
     * @return a zero-filled array of shape {@code [len1 + 1][len2 + 1][2]}
     */
    public static double[][][] initRMessages(int len1, int len2) {
        return new double[len1 + 1][len2 + 1][2];
    }

    /**
     * Computes the r-messages for the taxon pair {@code (l1, l2)} from a previous round of
     * q-messages.
     *
     * <p>For each position pair {@code (i, k)}, every other taxon {@code other} and every position
     * {@code j} of that taxon contributes one normalized two-state log message. The contribution
     * combines the q-messages on {@code (l1, other, i, j)} and {@code (other, l2, j, k)}:
     * <ul>
     *   <li>state 1 collects the (1,1) and (0,0) combinations;</li>
     *   <li>state 0 collects the (1,0), (0,1) and (0,0) combinations.</li>
     * </ul>
     * NOTE(review): this looks like a transitivity constraint on alignment links; please confirm.
     * The contributions are summed in log space, multiplied by {@link #anneal}, and renormalized.
     *
     * @param l1 first taxon of the pair
     * @param l2 second taxon of the pair
     * @param previousQMessages q-messages from the previous iteration; every entry read must be
     *     such that the combined messages are finite
     * @return array of shape {@code [nStates(l1)][nStates(l2)][2]} holding normalized log
     *     messages for states 0 and 1. The last index along each of the first two axes is not
     *     computed and stays 0.0.
     * @throws IllegalStateException if a NaN or infinite value arises during the computation
     */
    public static double[][][] rMessages(Taxon l1, Taxon l2, QMessages previousQMessages) {
        final TabularGMFct<Taxon> one = previousQMessages.one;
        final TabularGMFct<Taxon> zero = previousQMessages.zero;
        final int nStates1 = one.nStates(l1);
        final int nStates2 = one.nStates(l2);
        final double[][][] result = new double[nStates1][nStates2][2];
        // Every taxon of the graph except the pair itself acts as an intermediary.
        final ArrayList<Taxon> others = new ArrayList<Taxon>(one.graph().vertexSet());
        others.remove(l1);
        others.remove(l2);
        for (int i = 0; i < nStates1 - 1; ++i) {
            for (int k = 0; k < nStates2 - 1; ++k) {
                double sum0 = 0.0;
                double sum1 = 0.0;
                for (Taxon other : others) {
                    final int otherPositions = one.nStates(other) - 1;
                    for (int j = 0; j < otherPositions; ++j) {
                        final double logqIj1 = one.get(l1, other, i, j);
                        final double logqJk1 = one.get(other, l2, j, k);
                        final double logqIj0 = zero.get(l1, other, i, j);
                        final double logqJk0 = zero.get(other, l2, j, k);
                        double unnormLogr0 = Double.NEGATIVE_INFINITY;
                        unnormLogr0 = SloppyMath.logAdd(unnormLogr0, logqIj1 + logqJk0);
                        unnormLogr0 = SloppyMath.logAdd(unnormLogr0, logqIj0 + logqJk1);
                        unnormLogr0 = SloppyMath.logAdd(unnormLogr0, logqIj0 + logqJk0);
                        double unnormLogr1 = Double.NEGATIVE_INFINITY;
                        unnormLogr1 = SloppyMath.logAdd(unnormLogr1, logqIj1 + logqJk1);
                        unnormLogr1 = SloppyMath.logAdd(unnormLogr1, logqIj0 + logqJk0);
                        final double logNorm = SloppyMath.logAdd(unnormLogr0, unnormLogr1);
                        final double logr0 = unnormLogr0 - logNorm;
                        final double logr1 = unnormLogr1 - logNorm;
                        if (!isFinite(logr0) || !isFinite(logr1)) {
                            throw new IllegalStateException("Non-finite r-message for " + l1 + "["
                                    + i + "] / " + l2 + "[" + k + "] via " + other + "[" + j
                                    + "]: logr0=" + logr0 + ", logr1=" + logr1);
                        }
                        sum0 += logr0;
                        sum1 += logr1;
                    }
                }
                if (Double.isNaN(sum0) || Double.isNaN(sum1)) {
                    throw new IllegalStateException("NaN aggregated r-message for " + l1 + "["
                            + i + "] / " + l2 + "[" + k + "]: sum0=" + sum0 + ", sum1=" + sum1);
                }
                sum0 = anneal * sum0;
                sum1 = anneal * sum1;
                final double fullNorm = SloppyMath.logAdd(sum0, sum1);
                result[i][k][0] = sum0 - fullNorm;
                result[i][k][1] = sum1 - fullNorm;
            }
        }
        return result;
    }

    /**
     * Prints a table of the ratios {@code exp(msg[i][j][1] - msg[i][j][0])} to standard output.
     * Only ratios greater than 1 are shown; all other cells are left empty.
     *
     * @param msg log-message array indexed {@code [i][j][state]}
     */
    public static void print(final double[][][] msg) {
        final Table table = new Table(new Table.Populator() {

            @Override
            public void populate() {
                for (int i = 0; i < msg.length; ++i) {
                    for (int j = 0; j < msg[i].length; ++j) {
                        final double logRatio = msg[i][j][1] - msg[i][j][0];
                        this.addLines(i, j, logRatio > 0.0 ? "" + Math.exp(logRatio) : "");
                    }
                }
            }
        });
        System.out.println(table.toString());
    }

    /**
     * Computes the q-messages for the taxon pair {@code (top, bot)} and writes them into
     * {@code destination}.
     *
     * <p>For each position pair, the HMM's log posterior alignment probability {@code loga1} and
     * its complement {@code loga0} are divided (subtracted in log space) by the incoming
     * r-message and renormalized. A posterior of exactly 0 or 1 bypasses the r-message. The
     * result is then clamped so that neither state's log value falls below
     * {@link #smallestLogMessage}.
     *
     * @param hmm pairwise HMM supplying {@code logPosteriorAlignment}. The loops run over
     *     {@code hmm.str1.length()} x {@code hmm.str2.length()} positions.
     * @param destination q-message tables to write into, for the entries of pair {@code (top, bot)}
     * @param top first taxon of the pair
     * @param bot second taxon of the pair
     * @param messagesUsed incoming log r-messages indexed {@code [topIndex][botIndex][state]}
     * @throws IllegalStateException if a non-finite or NaN value arises during normalization
     */
    public static void qMessages(HetPairHMM hmm, QMessages destination, Taxon top, Taxon bot,
            double[][][] messagesUsed) {
        for (int topIndex = 0; topIndex < hmm.str1.length(); ++topIndex) {
            for (int botIndex = 0; botIndex < hmm.str2.length(); ++botIndex) {
                final double loga1 = hmm.logPosteriorAlignment(topIndex, botIndex);
                // Math.min guards against exp(loga1) slightly exceeding 1 from rounding.
                final double loga0 = Math.log(1.0 - Math.min(1.0, Math.exp(loga1)));
                double logq0;
                double logq1;
                if (loga1 == Double.NEGATIVE_INFINITY) {
                    // Posterior is exactly 0: all mass on state 0 (the clamp below softens it).
                    logq0 = 0.0;
                    logq1 = Double.NEGATIVE_INFINITY;
                } else if (loga0 == Double.NEGATIVE_INFINITY) {
                    // Posterior is (numerically) 1: all mass on state 1.
                    logq1 = 0.0;
                    logq0 = Double.NEGATIVE_INFINITY;
                } else {
                    final double logr0 = messagesUsed[topIndex][botIndex][0];
                    final double logr1 = messagesUsed[topIndex][botIndex][1];
                    final double unnormLogq0 = loga0 - logr0;
                    final double unnormLogq1 = loga1 - logr1;
                    if (!isFinite(unnormLogq0) || !isFinite(unnormLogq1)) {
                        throw new IllegalStateException("Non-finite unnormalized q-message for "
                                + top + "[" + topIndex + "] / " + bot + "[" + botIndex
                                + "]: loga0=" + loga0 + ", loga1=" + loga1 + ", logr0=" + logr0
                                + ", logr1=" + logr1);
                    }
                    final double norm = SloppyMath.logAdd(unnormLogq0, unnormLogq1);
                    logq0 = unnormLogq0 - norm;
                    logq1 = unnormLogq1 - norm;
                    if (Double.isNaN(logq0) || Double.isNaN(logq1)) {
                        throw new IllegalStateException("NaN q-message for " + top + "["
                                + topIndex + "] / " + bot + "[" + botIndex + "]: logq0=" + logq0
                                + ", logq1=" + logq1);
                    }
                }
                // Keep both states' log values at or above the floor while staying normalized.
                if (logq0 < smallestLogMessage) {
                    logq0 = smallestLogMessage;
                    logq1 = complementOfSmallestLogMessage;
                }
                if (logq1 < smallestLogMessage) {
                    logq1 = smallestLogMessage;
                    logq0 = complementOfSmallestLogMessage;
                }
                destination.one.set(top, bot, topIndex, botIndex, logq1);
                destination.zero.set(top, bot, topIndex, botIndex, logq0);
            }
        }
    }

    /**
     * Creates a fresh pair of q-message tables for all taxa in {@code msa}. Every entry is NaN.
     *
     * @param msa alignment whose sequences define the taxa and their table sizes
     * @return new {@link QMessages} whose {@code zero} and {@code one} tables are independent
     */
    public static QMessages blank(MultiAlignment msa) {
        return new QMessages(MessageComputations.blankGMF(msa), MessageComputations.blankGMF(msa));
    }

    /**
     * Creates a tabular function over the complete graph on the taxa of {@code msa}: every pair of
     * distinct taxa has a semi-edge. Each entry is initialized to NaN.
     *
     * <p>Each taxon's domain size is its sequence length plus one. Because entries start as NaN,
     * reading one before it has been written propagates NaN; {@link #rMessages} rejects that.
     *
     * @param msa alignment providing the taxa (the key set of {@code msa.getSequences()}) and the
     *     sequence lengths
     * @return new NaN-filled tabular function
     */
    public static TabularGMFct<Taxon> blankGMF(MultiAlignment msa) {
        final Set<Taxon> allTaxa = msa.getSequences().keySet();
        final HashGraph<Taxon> graph = new HashGraph<Taxon>(new SemiGraph<Taxon>() {

            @Override
            public boolean hasSemiEdge(Taxon one, Taxon two) {
                return !one.equals(two);
            }

            @Override
            public Set<Taxon> vertexSet() {
                return allTaxa;
            }
        });
        final HashMap<Taxon, Integer> domains = new HashMap<Taxon, Integer>();
        for (Taxon taxon : allTaxa) {
            domains.put(taxon, msa.getSequences().get(taxon).length() + 1);
        }
        return GMFctUtils.cnst(new TabularGMFct<Taxon>(graph, domains), Double.NaN);
    }

    /** Returns true when {@code x} is neither NaN nor infinite. */
    private static boolean isFinite(double x) {
        return !Double.isNaN(x) && !Double.isInfinite(x);
    }

    /**
     * Immutable holder for the pair of q-message tables: {@code zero} holds the log messages for
     * state 0 and {@code one} holds those for state 1.
     */
    public static final class QMessages {
        /** Log q-messages for state 0. */
        public final TabularGMFct<Taxon> zero;
        /** Log q-messages for state 1. */
        public final TabularGMFct<Taxon> one;

        /**
         * @param zero table of state-0 log messages
         * @param one table of state-1 log messages
         */
        public QMessages(TabularGMFct<Taxon> zero, TabularGMFct<Taxon> one) {
            this.zero = zero;
            this.one = one;
        }
    }
}

