/*
 * Decompiled with CFR 0.152.
 */
package ev.hmm;

import ev.hmm.HetPairHMM;
import ev.hmm.HetPairHMMSpecification;
import fig.basic.StrUtils;
import goblin.DerivationTree;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.math.MtxUtils;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.CounterMap;

/**
 * Weighted pairwise aligner between a "top" sequence and a "bottom" sequence, built on a
 * three-state {@link HetPairHMM}.
 *
 * <p>The sequences themselves are placeholders ({@code '*'} repeated): only their lengths matter.
 * The top sequence has one position per entry of the deletion weights, the bottom sequence one
 * position per entry of the insertion weights, and {@code sub[i][j]} is the log-weight of aligning
 * top position {@code i} with bottom position {@code j}. {@link #logSumProduct()} is intended to
 * return the log of the summed weight of all alignments. Each alignment's weight is the product
 * of the substitution weights of its aligned pairs, the deletion weights of its unaligned top
 * positions, and the insertion weights of its unaligned bottom positions.
 *
 * <p>The weight arrays are stored by reference, not copied. {@link #logSumProduct()} caches its
 * normalizer and the HMM is built lazily on first use, so mutating the arrays afterwards is not
 * supported. Lazy initialization is unsynchronized.
 */
public final class SimpleAligner {
    private final double[] insLogWeight;
    private final double[] delLogWeight;
    private final double[][] subLogWeights;
    /** Index of the trailing boundary character in {@code topS + BOUND}. */
    private final int modLastTop;
    /** Index of the trailing boundary character in {@code botS + BOUND}. */
    private final int modLastBot;
    public final String topS;
    public final String botS;
    private HetPairHMM _pairHmm = null;
    /** Sum of all insertion and deletion log-weights; NaN until first computed. */
    private double norm = Double.NaN;
    /** Boundary character appended to both sequences before they are handed to the HMM. */
    public final Character BOUND = Character.valueOf('#');
    // Not used within this class; kept because it is package-visible.
    int idx = 0;

    /**
     * Creates an aligner from log-weights.
     *
     * @param ins log insertion weight of each bottom position; its length is the bottom length
     * @param del log deletion weight of each top position; its length is the top length
     * @param sub log substitution weights indexed {@code [topPosition][bottomPosition]}; must have
     *     {@code del.length} rows of {@code ins.length} entries each
     * @throws IllegalArgumentException if {@code sub} is not {@code del.length} x
     *     {@code ins.length}
     */
    public SimpleAligner(double[] ins, double[] del, double[][] sub) {
        // Check every row, and never touch sub[0] when sub is empty, so an empty top sequence is
        // accepted instead of failing with ArrayIndexOutOfBoundsException.
        if (sub.length != del.length) {
            throw new IllegalArgumentException(
                "sub has " + sub.length + " rows but del has length " + del.length);
        }
        for (int row = 0; row < sub.length; ++row) {
            if (sub[row].length != ins.length) {
                throw new IllegalArgumentException(
                    "sub[" + row + "] has length " + sub[row].length
                        + " but ins has length " + ins.length);
            }
        }
        this.topS = StrUtils.repeat("*", del.length);
        this.botS = StrUtils.repeat("*", ins.length);
        this.insLogWeight = ins;
        this.delLogWeight = del;
        this.subLogWeights = sub;
        this.modLastTop = this.topS.length();
        this.modLastBot = this.botS.length();
    }

    /** Returns the underlying pair HMM, building it on first call. */
    public HetPairHMM getHMM() {
        this.ensureHMM();
        return this._pairHmm;
    }

    /**
     * Returns the log of the total weight of all alignments.
     *
     * <p>The HMM scores a substitution as {@code sub - del - ins} and insertions/deletions of
     * non-boundary positions as 0. Its log-sum is therefore relative to the alignment that deletes
     * every top position and inserts every bottom position. Adding back the sum of all insertion
     * and deletion log-weights (cached in {@code norm}) restores the absolute value.
     */
    public double logSumProduct() {
        if (Double.isNaN(this.norm)) {
            this.norm = 0.0;
            for (double x : this.insLogWeight) {
                this.norm += x;
            }
            for (double x : this.delLogWeight) {
                this.norm += x;
            }
        }
        return this.getHMM().logSumProduct() + this.norm;
    }

    /** Returns the HMM's Viterbi derivation with the boundary character removed. */
    public DerivationTree.Derivation viterbi() {
        return HetPairHMM.removeBoundary(this.getHMM().viterbi(), this.BOUND.charValue());
    }

    /**
     * Draws a random alignment from the HMM using {@code rand}, with the boundary character
     * removed (presumably with probability proportional to its weight — see HetPairHMM.sample).
     */
    public DerivationTree.Derivation sample(Random rand) {
        return HetPairHMM.removeBoundary(this.getHMM().sample(rand), this.BOUND.charValue());
    }

    /**
     * Like {@link #sample(Random)}, but also passes {@code fullDerivation} to the HMM sampler.
     * {@code main} uses it to collect the full state path behind each sampled alignment
     * (assumed to be filled by HetPairHMM.sample — confirm there).
     */
    private DerivationTree.Derivation sample(Random rand, List<Integer> fullDerivation) {
        return HetPairHMM.removeBoundary(this.getHMM().sample(rand, fullDerivation), this.BOUND.charValue());
    }

    /**
     * Returns the log-probability of alignment {@code d} under the normalized alignment
     * distribution.
     *
     * <p>Only aligned (substituted) pairs contribute, each as {@code sub - del - ins}, matching the
     * HMM's scoring. The {@code norm} offset cancels between numerator and denominator, so the HMM's
     * own log-sum is used as the normalizer. This assumes position {@code i} of
     * {@code d.getCurrentWord()} is a bottom position and {@code d.ancestor(i)} its aligned top
     * position, as the indexing below implies.
     */
    public double pathLogProbability(DerivationTree.Derivation d) {
        double num = 0.0;
        for (int i = 0; i < d.getCurrentWord().length(); ++i) {
            if (!d.hasAncestor(i)) continue;
            num += this.subLogWeights[d.ancestor(i)][i] - this.delLogWeight[d.ancestor(i)] - this.insLogWeight[i];
        }
        return num - this.getHMM().logSumProduct();
    }

    /** Builds the HMM over the boundary-terminated sequences if it has not been built yet. */
    private void ensureHMM() {
        if (this._pairHmm != null) {
            return;
        }
        this._pairHmm = new HetPairHMM(this.topS + this.BOUND, this.botS + this.BOUND, new HmmAdaptor());
    }

    /**
     * Small self-check: compares Monte Carlo alignment frequencies from 100000 samples against the
     * analytic probabilities from {@link #pathLogProbability}. It also prints how many distinct full
     * state paths were observed per alignment.
     */
    public static void main(String[] args) {
        double[] insLogWeight = new double[]{1.0E-4, 1.0, 2.0};
        MtxUtils.logInPlace(insLogWeight);
        double[] delLogWeight = new double[]{1.0, 1.0};
        MtxUtils.logInPlace(delLogWeight);
        double[][] subLogWeights = new double[][]{{1.0, 1.0, 3.0}, {1.0, 4.0, 1.0}};
        for (double[] mtx : subLogWeights) {
            MtxUtils.logInPlace(mtx);
        }
        SimpleAligner aligner = new SimpleAligner(insLogWeight, delLogWeight, subLogWeights);
        System.out.println(Math.exp(aligner.logSumProduct()));
        System.out.println(aligner.viterbi());
        Random rand = new Random(1L);
        Counter<DerivationTree.Derivation> derivs = new Counter<DerivationTree.Derivation>();
        CounterMap<DerivationTree.Derivation, ArrayList<Integer>> fullDerivations = new CounterMap<DerivationTree.Derivation, ArrayList<Integer>>();
        for (int i = 0; i < 100000; ++i) {
            ArrayList<Integer> fullDeriv = CollUtils.list();
            DerivationTree.Derivation deriv = aligner.sample(rand, fullDeriv);
            derivs.incrementCount(deriv, 1.0);
            fullDerivations.incrementCount(deriv, fullDeriv, 1.0);
        }
        derivs.normalize();
        double sum = 0.0;
        System.out.println("-------------------------------------------------------");
        for (DerivationTree.Derivation d : derivs) {
            System.out.println(d);
            System.out.println("N full derivs:" + fullDerivations.getCounter(d).size());
            double curAnalytic = Math.exp(aligner.pathLogProbability(d));
            sum += curAnalytic;
            System.out.println("MC:" + derivs.getCount(d) + "\tAnalytic:" + curAnalytic);
            System.out.println("-------------------------------------------------------");
        }
        System.out.println("\t\tSum:" + sum);
    }

    /**
     * Adapts this aligner's weights to the {@link HetPairHMMSpecification} interface.
     *
     * <p>States: 0 = substitution (also the start and end state), 1 = insertion, 2 = deletion.
     * Positions {@code x}/{@code y} index the boundary-terminated top/bottom strings, so
     * {@code modLastTop}/{@code modLastBot} is the index of the trailing boundary.
     */
    private final class HmmAdaptor
    implements HetPairHMMSpecification {
        private HmmAdaptor() {
        }

        @Override
        public double logWeight(int prevState, int currentState, int x, int y, int deltaX, int deltaY) {
            boolean isSub = deltaX == 1 && deltaY == 1;
            boolean isIns = deltaX == 0 && deltaY == 1;
            boolean isDel = deltaX == 1 && deltaY == 0;
            // Forbid deletion -> insertion. This fixes the order inside a run of unaligned positions,
            // presumably so each alignment corresponds to a single state path.
            if (prevState == 2 && currentState == 1) {
                return Double.NEGATIVE_INFINITY;
            }
            // Deleting/inserting a real (non-boundary) position costs nothing here. Its weight is
            // restored by the norm offset in SimpleAligner.logSumProduct.
            if (currentState == 2 && isDel && x != SimpleAligner.this.modLastTop) {
                return 0.0;
            }
            if (currentState == 1 && isIns && y != SimpleAligner.this.modLastBot) {
                return 0.0;
            }
            if (currentState == 0 && isSub) {
                // The two boundary characters must be aligned with each other, at no cost.
                if (x == SimpleAligner.this.modLastTop && y == SimpleAligner.this.modLastBot) {
                    return 0.0;
                }
                // A boundary may never be aligned with a real position.
                if (x == SimpleAligner.this.modLastTop || y == SimpleAligner.this.modLastBot) {
                    return Double.NEGATIVE_INFINITY;
                }
                // Relative to the all-delete/all-insert baseline, aligning x with y swaps their
                // deletion and insertion weights for the substitution weight.
                return SimpleAligner.this.subLogWeights[x][y] - SimpleAligner.this.delLogWeight[x] - SimpleAligner.this.insLogWeight[y];
            }
            return Double.NEGATIVE_INFINITY;
        }

        @Override
        public int endState() {
            return 0;
        }

        @Override
        public int startState() {
            return 0;
        }

        @Override
        public int nStates() {
            return 3;
        }
    }
}

