/*
 * Decompiled with CFR 0.152.
 */
package ev.hmm;

import ev.hmm.HetPairHMMSpecification;
import ev.par.ExponentialFamily;
import ev.par.FeatureExtractor;
import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.prob.SampleUtils;
import goblin.DerivationTree;
import java.util.List;
import java.util.Random;
import ma.SequenceType;
import nuts.maxent.MaxentClassifier;
import nuts.maxent.SloppyMath;
import nuts.tui.Table;
import nuts.util.CoordinatesPacker;
import nuts.util.MathUtils;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;

public final class HetPairHMM {
    private double[][][] prefix;
    private double[][][] suffix;
    private double[][][] maxSuffix;
    public final HetPairHMMSpecification hmm;
    private final int nStates;
    public final String str1;
    public final String str2;
    private boolean fwdInitialized = false;
    private boolean bwdInitialized = false;
    private boolean bwdMaxInitialized = false;

    private int startState() {
        return this.hmm.startState();
    }

    private int endState() {
        return this.hmm.endState();
    }

    public HetPairHMM(String str1, String str2, HetPairHMMSpecification pairHMM) {
        this.str1 = str1;
        this.str2 = str2;
        this.hmm = pairHMM;
        this.nStates = pairHMM.nStates();
    }

    public double logSumProduct() {
        double result = this.prefixLogSumProduct(this.endState(), this.str1.length(), this.str2.length());
        double other = this.suffixLogSumProduct(this.startState(), this.str1.length(), this.str2.length());
        if (!MathUtils.close(other, result)) {
            throw new RuntimeException(result + " vs " + other);
        }
        return result;
    }

    public double logSumProduct(int state1, int state2, int x, int y, int deltaX, int deltaY) {
        return this.prefixLogSumProduct(state1, x, y) + this.hmm.logWeight(state1, state2, x, y, deltaX, deltaY) + this.suffixLogSumProduct(state2, this.str1.length() - x - deltaX, this.str2.length() - y - deltaY);
    }

    public String alignmentMatrixToString() {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (int x = 0; x < this.str1.length(); ++x) {
            for (int y = 0; y < this.str2.length(); ++y) {
                double expd = Math.exp(this.logPosteriorAlignment(x, y));
                stats.addValue(expd);
            }
        }
        double small = stats.getPercentile(25.0);
        double med = stats.getPercentile(50.0);
        double big = stats.getPercentile(75.0);
        Table t = new Table();
        for (int x = 0; x < this.str1.length(); ++x) {
            for (int y = 0; y < this.str2.length(); ++y) {
                double expd = Math.exp(this.logPosteriorAlignment(x, y));
                if (expd > big) {
                    t.set(x, y, "#");
                    continue;
                }
                if (expd > med) {
                    t.set(x, y, "+");
                    continue;
                }
                if (expd > small) {
                    t.set(x, y, "-");
                    continue;
                }
                t.set(x, y, " ");
            }
        }
        t.setBorder(false);
        return t.toString();
    }

    public double logPosteriorAlignment(int x, int y) {
        double logSum = Double.NEGATIVE_INFINITY;
        for (int s1 = 0; s1 < this.nStates; ++s1) {
            for (int s2 = 0; s2 < this.nStates; ++s2) {
                logSum = SloppyMath.logAdd(logSum, this.logSumProduct(s1, s2, x, y, 1, 1));
            }
        }
        return logSum - this.logSumProduct();
    }

    public double prefixLogSumProduct(int finalState, int x, int y) {
        if (!this.fwdInitialized) {
            this.computeForward();
        }
        return this.prefix[finalState][x][y];
    }

    private void computeForward() {
        int len1 = this.str1.length();
        int len2 = this.str2.length();
        if (this.startState() != 0 || this.endState() != 0) {
            throw new RuntimeException();
        }
        double[][][] prefix = this.prefix = new double[this.hmm.nStates()][len1 + 1][len2 + 1];
        for (int x = 0; x <= len1; ++x) {
            for (int y = 0; y <= len2; ++y) {
                for (int finalState = 0; finalState < this.nStates; ++finalState) {
                    double result = Double.NEGATIVE_INFINITY;
                    if (x == 0 && y == 0) {
                        result = finalState == this.startState() ? 0.0 : Double.NEGATIVE_INFINITY;
                    } else {
                        int previousState;
                        if (x > 0) {
                            for (previousState = 0; previousState < this.nStates; ++previousState) {
                                result = SloppyMath.logAdd(result, prefix[previousState][x - 1][y] + this.hmm.logWeight(previousState, finalState, x - 1, y, 1, 0));
                            }
                        }
                        if (y > 0) {
                            for (previousState = 0; previousState < this.nStates; ++previousState) {
                                result = SloppyMath.logAdd(result, prefix[previousState][x][y - 1] + this.hmm.logWeight(previousState, finalState, x, y - 1, 0, 1));
                            }
                        }
                        if (x > 0 && y > 0) {
                            for (previousState = 0; previousState < this.nStates; ++previousState) {
                                result = SloppyMath.logAdd(result, prefix[previousState][x - 1][y - 1] + this.hmm.logWeight(previousState, finalState, x - 1, y - 1, 1, 1));
                            }
                        }
                    }
                    prefix[finalState][x][y] = result;
                }
            }
        }
        this.fwdInitialized = true;
    }

    public double suffixLogSumProduct(int firstState, int x, int y) {
        if (!this.bwdInitialized) {
            this.computeBackward();
        }
        return this.suffix[firstState][x][y];
    }

    private void computeBackward() {
        int len1 = this.str1.length();
        int len2 = this.str2.length();
        double[][][] suffix = this.suffix = new double[this.hmm.nStates()][len1 + 1][len2 + 1];
        for (int x = 0; x <= len1; ++x) {
            for (int y = 0; y <= len2; ++y) {
                for (int firstState = 0; firstState < this.nStates; ++firstState) {
                    double result = Double.NEGATIVE_INFINITY;
                    if (x == 0 && y == 0) {
                        result = firstState == this.endState() ? 0.0 : Double.NEGATIVE_INFINITY;
                    } else {
                        int nextState;
                        if (x > 0) {
                            for (nextState = 0; nextState < this.nStates; ++nextState) {
                                result = SloppyMath.logAdd(result, suffix[nextState][x - 1][y] + this.hmm.logWeight(firstState, nextState, len1 - x, len2 - y, 1, 0));
                            }
                        }
                        if (y > 0) {
                            for (nextState = 0; nextState < this.nStates; ++nextState) {
                                result = SloppyMath.logAdd(result, suffix[nextState][x][y - 1] + this.hmm.logWeight(firstState, nextState, len1 - x, len2 - y, 0, 1));
                            }
                        }
                        if (x > 0 && y > 0) {
                            for (nextState = 0; nextState < this.nStates; ++nextState) {
                                result = SloppyMath.logAdd(result, suffix[nextState][x - 1][y - 1] + this.hmm.logWeight(firstState, nextState, len1 - x, len2 - y, 1, 1));
                            }
                        }
                    }
                    suffix[firstState][x][y] = result;
                }
            }
        }
        this.bwdInitialized = true;
    }

    public double suffixLogMaxProduct(int firstState, int x, int y) {
        if (!this.bwdMaxInitialized) {
            this.computeMaxBackward();
        }
        return this.maxSuffix[firstState][x][y];
    }

    private void computeMaxBackward() {
        double[][][] maxSuffix = this.maxSuffix = new double[this.hmm.nStates()][this.str1.length() + 1][this.str2.length() + 1];
        int len1 = this.str1.length();
        int len2 = this.str2.length();
        for (int x = 0; x <= len1; ++x) {
            for (int y = 0; y <= len2; ++y) {
                for (int firstState = 0; firstState < this.nStates; ++firstState) {
                    double result = Double.NEGATIVE_INFINITY;
                    if (x == 0 && y == 0) {
                        result = firstState == this.endState() ? 0.0 : Double.NEGATIVE_INFINITY;
                    } else {
                        int nextState;
                        if (x > 0) {
                            for (nextState = 0; nextState < this.nStates; ++nextState) {
                                result = Math.max(result, maxSuffix[nextState][x - 1][y] + this.hmm.logWeight(firstState, nextState, len1 - x, len2 - y, 1, 0));
                            }
                        }
                        if (y > 0) {
                            for (nextState = 0; nextState < this.nStates; ++nextState) {
                                result = Math.max(result, maxSuffix[nextState][x][y - 1] + this.hmm.logWeight(firstState, nextState, len1 - x, len2 - y, 0, 1));
                            }
                        }
                        if (x > 0 && y > 0) {
                            for (nextState = 0; nextState < this.nStates; ++nextState) {
                                result = Math.max(result, maxSuffix[nextState][x - 1][y - 1] + this.hmm.logWeight(firstState, nextState, len1 - x, len2 - y, 1, 1));
                            }
                        }
                    }
                    maxSuffix[firstState][x][y] = result;
                }
            }
        }
        this.bwdMaxInitialized = true;
    }

    public static DerivationTree.Derivation removeBoundary(DerivationTree.Derivation d, char bound) {
        int oldLastBotIndex = d.getCurrentWord().length() - 1;
        int oldLastTopIndex = d.getAncestorWord().length() - 1;
        if (d.getCurrentWord().charAt(oldLastBotIndex) != bound || d.getAncestorWord().charAt(oldLastTopIndex) != bound || d.ancestor(oldLastBotIndex) != oldLastTopIndex) {
            throw new RuntimeException();
        }
        int[] newAnc = new int[oldLastBotIndex];
        for (int i = 0; i < newAnc.length; ++i) {
            newAnc[i] = d.hasAncestor(i) ? d.ancestor(i) : -1;
        }
        return new DerivationTree.Derivation(newAnc, d.getAncestorWord().substring(0, oldLastTopIndex), d.getCurrentWord().substring(0, oldLastBotIndex));
    }

    public DerivationTree.Derivation viterbi(List<Integer> stateSequence) {
        return this._viterbiOrSample(stateSequence, false, null);
    }

    public DerivationTree.Derivation viterbi() {
        return this._viterbiOrSample(null, false, null);
    }

    public DerivationTree.Derivation sample(Random rand) {
        return this._viterbiOrSample(null, true, rand);
    }

    public DerivationTree.Derivation sample(Random rand, List<Integer> stateSequence) {
        return this._viterbiOrSample(stateSequence, true, rand);
    }

    private DerivationTree.Derivation _viterbiOrSample(List<Integer> stateSequence, boolean sample, Random rand) {
        double[] choices;
        if (stateSequence != null && stateSequence.size() > 0) {
            throw new RuntimeException();
        }
        int[] ancestors = new int[this.str2.length()];
        for (int i = 0; i < ancestors.length; ++i) {
            ancestors[i] = -2;
        }
        int x = this.str1.length();
        int y = this.str2.length();
        int previousState = this.startState();
        if (stateSequence != null) {
            stateSequence.add(this.startState());
        }
        CoordinatesPacker.MSCoordinatePacker cp = this.getPacker();
        double[] dArray = choices = sample ? new double[4 * this.nStates] : null;
        while (x > 0 || y > 0) {
            int argmax = -1;
            double max = Double.NEGATIVE_INFINITY;
            if (sample) {
                for (int i = 0; i < choices.length; ++i) {
                    choices[i] = Double.NEGATIVE_INFINITY;
                }
            }
            for (int currentState = 0; currentState < this.nStates; ++currentState) {
                int currentCoord;
                double current;
                if (x > 0) {
                    current = (sample ? this.suffixLogSumProduct(currentState, x - 1, y) : this.suffixLogMaxProduct(currentState, x - 1, y)) + this.hmm.logWeight(previousState, currentState, this.str1.length() - x, this.str2.length() - y, 1, 0);
                    currentCoord = cp.coord2int(1, 0, currentState);
                    if (sample) {
                        choices[currentCoord] = current;
                    }
                    if (current > max) {
                        max = current;
                        argmax = currentCoord;
                    }
                }
                if (y > 0) {
                    current = (sample ? this.suffixLogSumProduct(currentState, x, y - 1) : this.suffixLogMaxProduct(currentState, x, y - 1)) + this.hmm.logWeight(previousState, currentState, this.str1.length() - x, this.str2.length() - y, 0, 1);
                    currentCoord = cp.coord2int(0, 1, currentState);
                    if (sample) {
                        choices[currentCoord] = current;
                    }
                    if (current > max) {
                        max = current;
                        argmax = currentCoord;
                    }
                }
                if (x <= 0 || y <= 0) continue;
                current = (sample ? this.suffixLogSumProduct(currentState, x - 1, y - 1) : this.suffixLogMaxProduct(currentState, x - 1, y - 1)) + this.hmm.logWeight(previousState, currentState, this.str1.length() - x, this.str2.length() - y, 1, 1);
                currentCoord = cp.coord2int(1, 1, currentState);
                if (sample) {
                    choices[currentCoord] = current;
                }
                if (!(current > max)) continue;
                max = current;
                argmax = currentCoord;
            }
            if (Double.isInfinite(max)) {
                LogInfo.error("Fixme: encountered -infinity in viterbi");
                argmax = 0;
                throw new RuntimeException();
            }
            if (sample) {
                if (!NumUtils.expNormalize(choices)) {
                    throw new RuntimeException();
                }
                argmax = SampleUtils.sampleMultinomial(rand, choices);
            }
            int[] coords = cp.int2coord(argmax);
            if (stateSequence != null) {
                stateSequence.add(coords[2]);
            }
            int positionInTop = this.str1.length() - x;
            int positionInBot = this.str2.length() - y;
            if (coords[0] == 1 && coords[1] == 1) {
                ancestors[positionInBot] = positionInTop;
            } else if (coords[0] == 0 && coords[1] == 1) {
                ancestors[positionInBot] = -1;
            }
            x -= coords[0];
            y -= coords[1];
            previousState = coords[2];
        }
        return new DerivationTree.Derivation(ancestors, this.str1, this.str2);
    }

    private CoordinatesPacker.MSCoordinatePacker getPacker() {
        int[] sizes = new int[3];
        sizes[1] = 2;
        sizes[0] = 2;
        sizes[2] = this.nStates;
        return new CoordinatesPacker.MSCoordinatePacker(sizes);
    }

    public static void main(String[] args) {
        ExponentialFamily.ExponentialFamilyOptions expFamOptions = new ExponentialFamily.ExponentialFamilyOptions();
        expFamOptions.encodingType = SequenceType.BINARY;
        FeatureExtractor.FeatureOptions featureOptions = new FeatureExtractor.FeatureOptions();
        expFamOptions.initParams = "INTERNAL";
        expFamOptions.internal.setCount("q=0,h=1,state1=1&state2=1", -40.0);
        ExponentialFamily expFam = ExponentialFamily.createExpfam(new MaxentClassifier.MaxentOptions<Object>(), expFamOptions, featureOptions, null);
        System.out.println(expFam);
        String top = "a";
        String bot = "a";
        HetPairHMM hmm = expFam.getHMM(top, bot, null, null);
        System.out.println("----");
        System.out.println("===");
        hmm.viterbi(null);
        System.out.println("===");
        System.out.println(hmm.logSumProduct());
    }

    public String toString() {
        Table table = new Table(new Table.Populator(){

            @Override
            public void populate() {
                for (int i = 0; i < HetPairHMM.this.str1.length(); ++i) {
                    for (int j = 0; j < HetPairHMM.this.str2.length(); ++j) {
                        this.set(i, j, Math.exp(HetPairHMM.this.logPosteriorAlignment(i, j)));
                    }
                }
            }
        });
        return table.toString();
    }
}

