/*
 * Decompiled with CFR 0.152.
 */
package ma;

import fig.basic.NumUtils;
import fig.basic.Option;
import fig.prob.SampleUtils;
import goblin.DerivationTree;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import ma.AffineGapAlignmentSampler;
import ma.SubstitutionMatrixLoader;
import nuts.math.MeasureZeroException;
import nuts.tui.Table;
import nuts.util.Arbre;
import nuts.util.Counter;
import pepper.Encodings;

/**
 * Samples and maximizes pairwise alignments between a "top" string and a "bottom" string,
 * scored by a substitution matrix plus a per-character gap weight ({@code topGap} for every
 * skipped top character, {@code bottomGap} for every skipped bottom character).
 *
 * <p>Two dynamic programs over string suffixes are filled lazily:
 * <ul>
 *   <li>a sum DP ({@code backwardSumScores}, non-log weights) used by {@link #sample(Random)}
 *       and {@link #getSumPr()};</li>
 *   <li>a max DP ({@code backwardMaxLogScores}, log weights) used by {@link #mode()} and
 *       {@link #getLogMaxScore()}.</li>
 * </ul>
 * Alignments are returned as {@link DerivationTree.Derivation}s built from an ancestor array
 * mapping each bottom position to the top position it is aligned to, or {@link #INSERTED}.
 *
 * <p>NOTE(review): DP tables and the last {@link Random} passed to {@link #sample(Random)} are
 * cached in mutable fields without synchronization.
 */
public class LongGapAlignmentSampler {
    /** Log-space sentinel, equal to {@link Double#NEGATIVE_INFINITY}. */
    public static final double UNDEF = Double.NEGATIVE_INFINITY;
    /** Ancestor value marking a bottom character that is not aligned to any top character. */
    public static final int INSERTED = -1;
    /** Phone ids of the top string. */
    private final int[] top;
    /** Phone ids of the bottom string. */
    private final int[] bottom;
    /** Substitution weights indexed by phone id; validated as square, non-negative and symmetric. */
    private final double[][] scores;
    /** Weight multiplied in for each skipped top character. */
    private final double topGap;
    /** Weight multiplied in for each skipped bottom character. */
    private final double bottomGap;
    /** Length of the top string. */
    private final int n;
    /** Length of the bottom string. */
    private final int m;
    /**
     * Lazily filled sum DP indexed {@code [topPosition][bottomPosition][state]}, where the state
     * index comes from {@link #getAlignmentIndex(boolean, boolean)}.
     */
    private double[][][] backwardSumScores = null;
    private final String topStr;
    private final String bottomStr;
    /** Element-wise log of {@link #scores}; filled together with the max DP. */
    private double[][] logScores;
    private double logTopGap;
    private double logBottomGap;
    /** Lazily filled max DP: best log score of aligning top[i..] with bottom[j..]. */
    private double[][] backwardMaxLogScores = null;
    /** Random source used by the sampling backtrack; assigned by {@link #sample(Random)}. */
    private Random rand = null;
    /** Initial arg-max placeholder used by {@link #max(int, int)} before any candidate is seen. */
    public static final int UNDEF_INT = -666;

    /**
     * Creates a sampler for one pair of strings.
     *
     * @param topStr    the top string, converted to phone ids with {@code params.enc}
     * @param bottomStr the bottom string, converted to phone ids with {@code params.enc}
     * @param params    substitution matrix, gap weights and encoding
     * @throws IllegalArgumentException if the substitution matrix is not square, has a negative
     *                                  entry, or is not symmetric
     */
    public LongGapAlignmentSampler(String topStr, String bottomStr, LongGapAlignmentSamplerParams params) {
        checkConsistent(params.substitutionMatrix);
        this.topStr = topStr;
        this.bottomStr = bottomStr;
        this.top = params.enc.string2PhoneIds(topStr);
        this.bottom = params.enc.string2PhoneIds(bottomStr);
        this.scores = params.substitutionMatrix;
        this.topGap = params.topGap;
        this.bottomGap = params.bottomGap;
        this.n = this.top.length;
        this.m = this.bottom.length;
    }

    /** Fills the max DP on first use and rejects a +infinity total (overflow). */
    private void ensureBackwardMax() {
        if (this.backwardMaxLogScores == null) {
            this.computeBackwardMaxLogScores();
        }
        if (this.backwardMaxLogScores[0][0] == Double.POSITIVE_INFINITY) {
            throw new ArithmeticPrecisionException("Overflow");
        }
    }

    /**
     * Returns the highest-scoring alignment (Viterbi-style backtrack through the max DP).
     *
     * @throws MeasureZeroException if every alignment has log score -infinity
     */
    public DerivationTree.Derivation mode() throws MeasureZeroException {
        this.ensureBackwardMax();
        if (this.backwardMaxLogScores[0][0] == Double.NEGATIVE_INFINITY) {
            throw new MeasureZeroException("Alignment sampler (long gap version) ran on measure zero");
        }
        return this.backtrack(true);
    }

    /**
     * Fills the sum DP on first use and checks it for overflow and for underflow (a zero total
     * even though the max DP reports a finite best score).
     */
    private void ensureBackwardSums() {
        if (this.backwardSumScores == null) {
            this.computeBackwardSumScores();
        }
        if (this.getBackwardSumScore(0, 0) == Double.POSITIVE_INFINITY) {
            // Same exception type as the max-DP overflow check, for consistency.
            throw new ArithmeticPrecisionException("Overflow");
        }
        if (this.getBackwardSumScore(0, 0) == 0.0 && this.getLogMaxScore() != Double.NEGATIVE_INFINITY) {
            throw new ArithmeticPrecisionException("Underflow when sampling alignment between " + this.topStr + " and " + this.bottomStr);
        }
    }

    /**
     * Draws one alignment, choosing each bottom position's ancestor left to right with weights
     * taken from the sum DP.
     *
     * @param rand random source; it is stored in this instance for the duration of the backtrack
     * @throws MeasureZeroException if the total alignment weight is zero
     */
    public DerivationTree.Derivation sample(Random rand) throws MeasureZeroException {
        this.ensureBackwardSums();
        if (this.getBackwardSumScore(0, 0) == 0.0) {
            throw new MeasureZeroException("Alignment sampler (long gap version) ran on measure zero");
        }
        this.rand = rand;
        return this.backtrack(false);
    }

    /**
     * Walks the bottom string left to right, picking for each bottom position either a top
     * position at or after the current top cursor, or {@link #INSERTED}. Once the top string is
     * exhausted, the remaining bottom positions stay {@link #INSERTED}.
     *
     * @param isMax {@code true} to take the arg-max at every step, {@code false} to sample
     */
    private DerivationTree.Derivation backtrack(boolean isMax) throws MeasureZeroException {
        int[] ancestors = initAncestors(this.m);
        int topPosition = 0;
        for (int bottomPosition = 0; bottomPosition < this.m && topPosition < this.n; ++bottomPosition) {
            int choice = isMax ? this.max(bottomPosition, topPosition) : this.sample(this.rand, bottomPosition, topPosition);
            ancestors[bottomPosition] = choice;
            if (choice != INSERTED) {
                // Alignments are monotone: the next bottom character may only use later top characters.
                topPosition = choice + 1;
            }
        }
        return new DerivationTree.Derivation(ancestors, this.topStr, this.bottomStr);
    }

    /** Returns an ancestor array of the given length with every entry set to {@link #INSERTED}. */
    private static int[] initAncestors(int bottomLength) {
        int[] ancestors = new int[bottomLength];
        for (int i = 0; i < bottomLength; ++i) {
            ancestors[i] = INSERTED;
        }
        return ancestors;
    }

    /** Returns the (unnormalized) total weight of all alignments, from the sum DP. */
    public double getSumPr() {
        this.ensureBackwardSums();
        return this.getBackwardSumScore(0, 0);
    }

    /** Returns the log score of the best alignment, from the max DP. */
    public double getLogMaxScore() {
        this.ensureBackwardMax();
        return this.backwardMaxLogScores[0][0];
    }

    /**
     * Samples the ancestor of one bottom position.
     *
     * <p>Slot {@code k < n - topPosition} means "align with top position {@code topPosition + k}".
     * Its weight is the aligned sum-DP entry at that top position times {@code topGap^k}, one gap
     * factor per skipped top character. The last slot means "leave unaligned". Its weight is the
     * sum of the (neither aligned) and (top aligned, bottom not) states at
     * {@code (topPosition, bottomPosition)}.
     *
     * @return the chosen top position, or {@link #INSERTED}
     */
    private int sample(Random rand, int bottomPosition, int topPosition) {
        double[] prs = new double[this.n - topPosition + 1];
        double currentGapPenalty = 1.0;
        for (int topPosition2 = topPosition; topPosition2 < this.n; ++topPosition2) {
            prs[topPosition2 - topPosition] = this.getBackwardSumScore(topPosition2, bottomPosition, true, true) * currentGapPenalty;
            currentGapPenalty *= this.topGap;
        }
        prs[prs.length - 1] = this.getBackwardSumScore(topPosition, bottomPosition, false, false) + this.getBackwardSumScore(topPosition, bottomPosition, true, false);
        NumUtils.normalize(prs);
        int sample = SampleUtils.sampleMultinomial(rand, prs);
        return sample == prs.length - 1 ? INSERTED : sample + topPosition;
    }

    /**
     * Arg-max counterpart of {@link #sample(Random, int, int)}, evaluated in log space against the
     * max DP. Ties between aligning and leaving the bottom character unaligned go to aligning,
     * because "unaligned" must be strictly better to win.
     *
     * @return the best top position, or {@link #INSERTED}
     */
    private int max(int bottomPosition, int topPosition) {
        int maxTopPositionArgument = UNDEF_INT;
        double maxTopPositionValue = Double.NEGATIVE_INFINITY;
        double currentGapPenalty = 0.0;
        for (int topPosition2 = topPosition; topPosition2 < this.n; ++topPosition2) {
            double current = this.alignmentLogCost(topPosition2, bottomPosition) + this.backwardMaxLogScores[topPosition2 + 1][bottomPosition + 1] + currentGapPenalty;
            if (current > maxTopPositionValue) {
                maxTopPositionValue = current;
                maxTopPositionArgument = topPosition2;
            }
            currentGapPenalty += this.logTopGap;
        }
        double gapScore = this.logBottomGap + this.backwardMaxLogScores[topPosition][bottomPosition + 1];
        if (gapScore > maxTopPositionValue) {
            return INSERTED;
        }
        return maxTopPositionArgument;
    }

    /**
     * Validates that the substitution matrix is square, non-negative and symmetric.
     *
     * @throws IllegalArgumentException describing the first violation found
     */
    private static void checkConsistent(double[][] scores) {
        // Check the shape first so the symmetry check below cannot index out of bounds.
        for (int i = 0; i < scores.length; ++i) {
            if (scores[i].length != scores.length) {
                throw new IllegalArgumentException("Substitution matrix must be square: row " + i + " has " + scores[i].length + " entries but there are " + scores.length + " rows");
            }
        }
        for (int i = 0; i < scores.length; ++i) {
            for (int j = 0; j < scores.length; ++j) {
                if (scores[i][j] < 0.0) {
                    throw new IllegalArgumentException("Score " + i + "," + j + " is negative: " + scores[i][j]);
                }
                if (scores[i][j] != scores[j][i]) {
                    throw new IllegalArgumentException("Score " + i + "," + j + " is " + scores[i][j] + " while the symmetric entry is " + scores[j][i]);
                }
            }
        }
    }

    /** Reads one sum-DP state; positions beyond the table ({@code > n} or {@code > m}) read as 0. */
    private double getBackwardSumScore(int topPosition, int bottomPosition, boolean isTopAligned, boolean isBottomAligned) {
        if (topPosition > this.n || bottomPosition > this.m) {
            return 0.0;
        }
        return this.backwardSumScores[topPosition][bottomPosition][this.getAlignmentIndex(isTopAligned, isBottomAligned)];
    }

    /** Writes one sum-DP state. */
    private void setBackwardSumScore(int topPosition, int bottomPosition, boolean isTopAligned, boolean isBottomAligned, double value) {
        this.backwardSumScores[topPosition][bottomPosition][this.getAlignmentIndex(isTopAligned, isBottomAligned)] = value;
    }

    /**
     * Returns the total sum-DP weight at a position (all four states). Position
     * {@code (n + 1, m + 1)} is the terminal and has weight 1.
     */
    private double getBackwardSumScore(int topPosition, int bottomPosition) {
        if (topPosition == this.n + 1 && bottomPosition == this.m + 1) {
            return 1.0;
        }
        return this.getBackwardSumScore(topPosition, bottomPosition, false, false) + this.getBackwardSumScore(topPosition, bottomPosition, false, true) + this.getBackwardSumScore(topPosition, bottomPosition, true, false) + this.getBackwardSumScore(topPosition, bottomPosition, true, true);
    }

    /** Allocates a zeroed {@code (n+1) x (m+1) x 4} sum-DP table. */
    private void initBackwardSumScores() {
        this.backwardSumScores = new double[this.n + 1][this.m + 1][4];
    }

    /**
     * Precomputes log weights and allocates the max DP, seeding its borders. Row {@code n} and
     * column {@code m} hold the cost of skipping every remaining character of the other string,
     * and {@code [n][m]} is 0.
     */
    private void initBackwardMaxLogScores() {
        this.logTopGap = Math.log(this.topGap);
        this.logBottomGap = Math.log(this.bottomGap);
        int scoreSize = this.scores.length;
        this.logScores = new double[scoreSize][scoreSize];
        for (int i = 0; i < scoreSize; ++i) {
            for (int j = 0; j < scoreSize; ++j) {
                this.logScores[i][j] = Math.log(this.scores[i][j]);
            }
        }
        this.backwardMaxLogScores = new double[this.n + 1][this.m + 1];
        for (int i = 0; i < this.n + 1; ++i) {
            for (int j = 0; j < this.m + 1; ++j) {
                this.backwardMaxLogScores[i][j] = Double.NEGATIVE_INFINITY;
            }
        }
        for (int i = 0; i < this.n; ++i) {
            this.backwardMaxLogScores[i][this.m] = (double)(this.n - i) * this.logTopGap;
        }
        for (int j = 0; j < this.m; ++j) {
            this.backwardMaxLogScores[this.n][j] = (double)(this.m - j) * this.logBottomGap;
        }
        this.backwardMaxLogScores[this.n][this.m] = 0.0;
    }

    /** Fills the max DP from the bottom-right corner back to {@code [0][0]}. */
    private void computeBackwardMaxLogScores() {
        this.initBackwardMaxLogScores();
        for (int top = this.n - 1; top >= 0; --top) {
            for (int bottom = this.m - 1; bottom >= 0; --bottom) {
                this.computeBackwardMaxLogScores(top, bottom);
            }
        }
    }

    /**
     * Max-DP recurrence: the best of aligning top[topPosition] with bottom[bottomPosition],
     * skipping the top character, or skipping the bottom character.
     */
    private void computeBackwardMaxLogScores(int topPosition, int bottomPosition) {
        double max = this.alignmentLogCost(topPosition, bottomPosition) + this.backwardMaxLogScores[topPosition + 1][bottomPosition + 1];
        double topDelLogCost = this.logTopGap + this.backwardMaxLogScores[topPosition + 1][bottomPosition];
        if (topDelLogCost > max) {
            max = topDelLogCost;
        }
        double bottomDelLogCost = this.logBottomGap + this.backwardMaxLogScores[topPosition][bottomPosition + 1];
        if (bottomDelLogCost > max) {
            max = bottomDelLogCost;
        }
        this.backwardMaxLogScores[topPosition][bottomPosition] = max;
    }

    /**
     * Fills the sum DP from {@code (n, m)} back to {@code (0, 0)}. Within a cell, the states are
     * computed in the order (F,F), (F,T), (T,F), (T,T). Each one only reads cells that are later
     * in top or bottom position, so those cells are already filled.
     */
    private void computeBackwardSumScores() {
        this.initBackwardSumScores();
        for (int top = this.n; top >= 0; --top) {
            for (int bottom = this.m; bottom >= 0; --bottom) {
                for (int topAligned = 0; topAligned < 2; ++topAligned) {
                    for (int bottomAligned = 0; bottomAligned < 2; ++bottomAligned) {
                        this.computeBackwardSumScore(top, bottom, topAligned != 0, bottomAligned != 0);
                    }
                }
            }
        }
    }

    /**
     * Computes one sum-DP state at {@code (topPosition, bottomPosition)}:
     * <ul>
     *   <li>(T,T): the two characters are aligned to each other;</li>
     *   <li>(F,F): both characters are skipped;</li>
     *   <li>(F,T): one or more top characters are skipped, then the bottom character aligns;</li>
     *   <li>(T,F): one or more bottom characters are skipped, then the top character aligns.</li>
     * </ul>
     */
    private void computeBackwardSumScore(int topPosition, int bottomPosition, boolean isTopAligned, boolean isBottomAligned) {
        double result;
        if (isTopAligned && isBottomAligned) {
            result = this.alignmentScore(topPosition, bottomPosition) * this.getBackwardSumScore(topPosition + 1, bottomPosition + 1);
        } else if (!isTopAligned && !isBottomAligned) {
            result = this.computeUnalignedBackwardSumScore(topPosition, bottomPosition);
        } else if (isBottomAligned) {
            result = this.computeTopUnalignedBackwardSumScore(topPosition, bottomPosition, false);
        } else {
            // Same computation with the roles of top and bottom exchanged.
            result = this.computeTopUnalignedBackwardSumScore(bottomPosition, topPosition, true);
        }
        this.setBackwardSumScore(topPosition, bottomPosition, isTopAligned, isBottomAligned, result);
    }

    /**
     * Sums over the next aligned position along one string. Each term is the (T,T) state at that
     * position times one gap factor per skipped character.
     *
     * <p>When {@code swappedTopBottom} is true, the roles are exchanged: {@code topPosition}
     * actually holds a bottom position, {@code bottomPosition} holds a top position, and the loop
     * runs over bottom positions using {@code bottomGap}.
     */
    private double computeTopUnalignedBackwardSumScore(int topPosition, int bottomPosition, boolean swappedTopBottom) {
        double gap = swappedTopBottom ? this.bottomGap : this.topGap;
        int end = swappedTopBottom ? this.m : this.n;
        double result = 0.0;
        double currentGapScore = gap;
        for (int alignedPosition = topPosition + 1; alignedPosition <= end; ++alignedPosition) {
            double alignedScore = swappedTopBottom
                    ? this.getBackwardSumScore(bottomPosition, alignedPosition, true, true)
                    : this.getBackwardSumScore(alignedPosition, bottomPosition, true, true);
            result += alignedScore * currentGapScore;
            currentGapScore *= gap;
        }
        return result;
    }

    /** (F,F) state: skip both characters ({@code topGap * bottomGap}) and continue diagonally. */
    private double computeUnalignedBackwardSumScore(int topPosition, int bottomPosition) {
        double result = 0.0;
        double squaredGap = this.topGap * this.bottomGap;
        result += squaredGap * this.getBackwardSumScore(topPosition + 1, bottomPosition + 1, true, true);
        result += squaredGap * this.getBackwardSumScore(topPosition + 1, bottomPosition + 1, true, false);
        result += squaredGap * this.getBackwardSumScore(topPosition + 1, bottomPosition + 1, false, true);
        result += squaredGap * this.getBackwardSumScore(topPosition + 1, bottomPosition + 1, false, false);
        return result;
    }

    /** Maps an alignment state to its sum-DP slot: (T,T)=0, (T,F)=1, (F,T)=2, (F,F)=3. */
    private int getAlignmentIndex(boolean isTopAligned, boolean isBottomAligned) {
        if (isTopAligned) {
            return isBottomAligned ? 0 : 1;
        }
        return isBottomAligned ? 2 : 3;
    }

    /**
     * Substitution weight of aligning the two characters. At the string ends it is 1 when both
     * strings are exhausted and 0 when only one is.
     */
    private double alignmentScore(int topPosition, int bottomPosition) {
        if (topPosition == this.n || bottomPosition == this.m) {
            if (topPosition == this.n && bottomPosition == this.m) {
                return 1.0;
            }
            return 0.0;
        }
        return this.scores[this.top[topPosition]][this.bottom[bottomPosition]];
    }

    /** Log-space counterpart of {@link #alignmentScore(int, int)}. */
    private double alignmentLogCost(int topPosition, int bottomPosition) {
        if (topPosition == this.n || bottomPosition == this.m) {
            if (topPosition == this.n && bottomPosition == this.m) {
                return 0.0;
            }
            return Double.NEGATIVE_INFINITY;
        }
        return this.logScores[this.top[topPosition]][this.bottom[bottomPosition]];
    }

    /**
     * Demo and sanity check on toy strings. It compares the DP sum and max against brute-force
     * enumeration of all alignments, and prints empirical sample frequencies next to the exact
     * normalized weights.
     */
    public static void main(String[] args) throws MeasureZeroException {
        Encodings enc = Encodings.toyCtxFreeEncodings(2);
        double[][] scores = new double[][]{{4.0, 3.0}, {3.0, 5.0}};
        String top = "abab";
        String bottom = "ababb";
        double topGap = 3.0;
        double bottomGap = 2.0;
        LongGapAlignmentSamplerParams params = new LongGapAlignmentSamplerParams(scores, topGap, bottomGap, enc);
        LongGapAlignmentSampler lgas = new LongGapAlignmentSampler(top, bottom, params);
        System.out.println("Sum for DP: " + lgas.getSumPr());
        DerivationTree.Derivation dpMode = lgas.mode();
        System.out.println("Arg max for DP: \n" + dpMode);
        double dpScore = Math.exp(lgas.getLogMaxScore());
        System.out.println("Max for DP: " + dpScore);
        double realScore = Math.exp(LongGapAlignmentSampler.logScore(dpMode, scores, topGap, bottomGap, enc));
        System.out.println("Real score of max: " + realScore + ", dp max(): " + dpScore);
        Counter<DerivationTree.Derivation> samples = new Counter<DerivationTree.Derivation>();
        Random rand = new Random();
        for (int i = 0; i < 1000000; ++i) {
            samples.incrementCount(lgas.sample(rand), 1.0);
        }
        System.out.println("Samples:");
        LongGapAlignmentSampler.printCounter(samples);
        System.out.println("---");
        Set<DerivationTree.Derivation> allDerivs = LongGapAlignmentSampler.allDerivations(top, bottom);
        double sum = 0.0;
        double max = Double.NEGATIVE_INFINITY;
        DerivationTree.Derivation argMax = null;
        for (DerivationTree.Derivation deriv : allDerivs) {
            double current = Math.exp(LongGapAlignmentSampler.logScore(deriv, scores, topGap, bottomGap, enc));
            if (current > max) {
                max = current;
                argMax = deriv;
            }
            sum += current;
        }
        System.out.println("Sum for naive: " + sum);
        System.out.println("Arg max for naive: \n" + argMax);
        System.out.println("Max for naive: " + max);
        System.out.println("Samples:");
        LongGapAlignmentSampler.printCounter(LongGapAlignmentSampler.score(allDerivs, scores, topGap, bottomGap, enc));
    }

    /** Normalizes {@code counter} IN PLACE, then prints each alignment with its normalized count. */
    public static void printCounter(Counter<DerivationTree.Derivation> counter) {
        counter.normalize();
        for (DerivationTree.Derivation key : counter) {
            System.out.println(key.alignmentToString() + ":" + counter.getCount(key));
        }
    }

    /** Scores each derivation (non-log) with gap-extension weights equal to the gap-open weights. */
    public static Counter<DerivationTree.Derivation> score(Set<DerivationTree.Derivation> ds, double[][] scores, double topGap, double bottomGap, Encodings enc) {
        return LongGapAlignmentSampler.score(ds, scores, topGap, bottomGap, topGap, bottomGap, enc);
    }

    /** Log score of one derivation under {@code params}. */
    public static double logScore(DerivationTree.Derivation d, LongGapAlignmentSamplerParams params) {
        return LongGapAlignmentSampler.logScore(d, params.substitutionMatrix, params.topGap, params.bottomGap, params.enc);
    }

    /**
     * Sums the log scores of every non-root node's derivation in a derivation tree. The root
     * itself contributes 0.
     */
    public static double logScore(Arbre<DerivationTree.DerivationNode> arbre, LongGapAlignmentSamplerParams params) {
        double logSum = arbre.isRoot() ? 0.0 : LongGapAlignmentSampler.logScore(arbre.getContents().getDerivation(), params);
        for (Arbre<DerivationTree.DerivationNode> child : arbre.getChildren()) {
            logSum += LongGapAlignmentSampler.logScore(child, params);
        }
        return logSum;
    }

    /** Log score of one derivation with gap-extension weights equal to the gap-open weights. */
    public static double logScore(DerivationTree.Derivation d, double[][] scores, double topGap, double bottomGap, Encodings enc) {
        return LongGapAlignmentSampler.logScore(d, scores, topGap, bottomGap, topGap, bottomGap, enc);
    }

    /** Non-log scores of each derivation, keyed by derivation, using affine gap weights. */
    public static Counter<DerivationTree.Derivation> score(Set<DerivationTree.Derivation> ds, double[][] scores, double topGap, double bottomGap, double topGapExtend, double bottomGapExtend, Encodings enc) {
        Counter<DerivationTree.Derivation> result = new Counter<DerivationTree.Derivation>();
        for (DerivationTree.Derivation d : ds) {
            result.setCount(d, Math.exp(LongGapAlignmentSampler.logScore(d, scores, topGap, bottomGap, topGapExtend, bottomGapExtend, enc)));
        }
        return result;
    }

    /**
     * Brute-force log score of a derivation with affine gaps. Each aligned pair adds the log
     * substitution weight. Each unaligned character adds the log gap-open weight if the
     * preceding character on the same string was aligned (or it is the first character), and
     * the log gap-extend weight otherwise. Top-side gaps are found through {@code d.invert()}.
     */
    public static double logScore(DerivationTree.Derivation d, double[][] scores, double topGap, double bottomGap, double topGapExtend, double bottomGapExtend, Encodings enc) {
        double logSum = 0.0;
        for (int i = 0; i < d.getCurrentWord().length(); ++i) {
            int bottomCharId = enc.char2PhoneId(d.getCurrentWord().charAt(i));
            if (d.hasAncestor(i)) {
                int topIndex = d.ancestor(i);
                int topCharId = enc.char2PhoneId(d.getAncestorWord().charAt(topIndex));
                logSum += Math.log(scores[bottomCharId][topCharId]);
                continue;
            }
            if (i > 0 && !d.hasAncestor(i - 1)) {
                logSum += Math.log(bottomGapExtend);
                continue;
            }
            logSum += Math.log(bottomGap);
        }
        DerivationTree.Derivation inverseD = d.invert();
        for (int i = 0; i < inverseD.getCurrentWord().length(); ++i) {
            if (inverseD.hasAncestor(i)) continue;
            if (i > 0 && !inverseD.hasAncestor(i - 1)) {
                logSum += Math.log(topGapExtend);
                continue;
            }
            logSum += Math.log(topGap);
        }
        return logSum;
    }

    /** Enumerates every monotone alignment of the two strings as derivations (exponential size). */
    public static Set<DerivationTree.Derivation> allDerivations(String topStr, String bottomStr) {
        Set<List<Integer>> allDerivMaps = LongGapAlignmentSampler.allDerivations(topStr.length(), bottomStr.length());
        HashSet<DerivationTree.Derivation> result = new HashSet<DerivationTree.Derivation>();
        for (List<Integer> derivationMap : allDerivMaps) {
            int[] array = new int[derivationMap.size()];
            for (int i = 0; i < array.length; ++i) {
                array[i] = derivationMap.get(i);
            }
            result.add(new DerivationTree.Derivation(array, topStr, bottomStr));
        }
        return result;
    }

    /**
     * Enumerates every monotone ancestor map from {@code bottomLength} bottom positions into
     * {@code topLength} top positions ({@link #INSERTED} for unaligned). The recursion splits on
     * the last characters: the last top character is unused, the last bottom character is
     * unaligned, or the two are aligned. Each call returns freshly allocated lists, so appending
     * to them cannot affect other results.
     */
    public static Set<List<Integer>> allDerivations(int topLength, int bottomLength) {
        HashSet<List<Integer>> result = new HashSet<List<Integer>>();
        if (bottomLength == 0) {
            result.add(new ArrayList<Integer>());
            return result;
        }
        if (topLength == 0) {
            ArrayList<Integer> list = new ArrayList<Integer>();
            for (int i = 0; i < bottomLength; ++i) {
                list.add(INSERTED);
            }
            result.add(list);
            return result;
        }
        result.addAll(LongGapAlignmentSampler.allDerivations(topLength - 1, bottomLength));
        Set<List<Integer>> truncBotRec = LongGapAlignmentSampler.allDerivations(topLength, bottomLength - 1);
        LongGapAlignmentSampler.append(truncBotRec, INSERTED);
        result.addAll(truncBotRec);
        Set<List<Integer>> truncTopBot = LongGapAlignmentSampler.allDerivations(topLength - 1, bottomLength - 1);
        LongGapAlignmentSampler.append(truncTopBot, topLength - 1);
        result.addAll(truncTopBot);
        return result;
    }

    /**
     * Appends {@code inserted} to every list in the set. Hash codes of the contained lists change,
     * so the set must only be iterated afterwards, not queried.
     */
    private static void append(Set<List<Integer>> truncBotRec, int inserted) {
        for (List<Integer> item : truncBotRec) {
            item.add(inserted);
        }
    }

    /** Signals that a DP total overflowed or underflowed double precision. */
    public static class ArithmeticPrecisionException
    extends RuntimeException {
        private static final long serialVersionUID = 1L;

        public ArithmeticPrecisionException(String msg) {
            super(msg);
        }
    }

    /** Substitution matrix, gap weights and character encoding for {@link LongGapAlignmentSampler}. */
    public static class LongGapAlignmentSamplerParams
    implements AffineGapAlignmentSampler.GapAlignmentParams {
        public final double[][] substitutionMatrix;
        public final double topGap;
        public final double bottomGap;
        public final Encodings enc;
        @Option
        public static String matrixName = "DAYHOFF";
        @Option
        public static double exponent = 1.0;
        @Option
        public static double gapPr = 1.0E-4;

        public LongGapAlignmentSamplerParams(double[][] substitutionMatrix, double topGap, double bottomGap, Encodings enc) {
            this.substitutionMatrix = substitutionMatrix;
            this.bottomGap = bottomGap;
            this.topGap = topGap;
            this.enc = enc;
        }

        /**
         * Builds params from the {@code matrixName} matrix over protein encodings. The matrix is
         * transformed with {@code SubstitutionMatrixLoader.exp(..., exponent)}, and {@code gapPr}
         * is used for both gap weights.
         *
         * @throws RuntimeException wrapping the {@link IOException} if the matrix cannot be loaded
         */
        public static LongGapAlignmentSamplerParams createDefaultParams() {
            Encodings enc = Encodings.proteinEncodings(false);
            try {
                double[][] additiveMatrix = SubstitutionMatrixLoader.loadMatrix(matrixName, enc);
                double[][] expMatrix = SubstitutionMatrixLoader.exp(additiveMatrix, exponent);
                return new LongGapAlignmentSamplerParams(expMatrix, gapPr, gapPr, enc);
            }
            catch (IOException ioe) {
                throw new RuntimeException(ioe);
            }
        }

        public String toString() {
            return "Substitutions:\n" + new Table(new Table.Populator(){

                @Override
                public void populate() {
                    for (int i = 0; i < substitutionMatrix.length; ++i) {
                        char cChar = enc.phoneId2Char(i);
                        this.set(0, i + 1, "" + cChar);
                        this.set(i + 1, 0, "" + cChar);
                        for (int j = 0; j < substitutionMatrix.length; ++j) {
                            this.set(i + 1, j + 1, substitutionMatrix[i][j]);
                        }
                    }
                }
            }).toString() + "\nGap score (top, bottom): " + this.topGap + ", " + this.bottomGap;
        }

        @Override
        public Encodings getEncodings() {
            return this.enc;
        }
    }
}

