/*
 * Decompiled with CFR 0.152.
 */
package automata;

import Jama.Matrix;
import fig.basic.NumUtils;
import fig.prob.SampleUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.lang.ArrayUtils;
import nuts.math.MtxUtils;
import nuts.tui.Table;
import nuts.util.CoordinatesPacker;
import nuts.util.MathUtils;

/**
 * Weighted finite-state automaton over the alphabet {@code {0, ..., alphSize - 1}}, stored as one
 * {@code size x size} weight matrix per symbol.
 *
 * <p>State {@code 0} is the start state and state {@code size - 1} is the end state (see
 * {@link #norm()} and {@link #sample(Random)}). The factories in this class reserve symbol
 * {@code 0} as the string-boundary marker (see {@link #sequenceIndicator(int, int[])}).
 *
 * <p>The closure and sampling tables are computed lazily and cached in plain fields without
 * synchronization, and the caches are never invalidated: do not modify {@link #transitions} or the
 * matrices it holds after construction.
 */
public final class Automaton {
    /**
     * Per-symbol weight matrices: {@code transitions[a].get(p, q)} is the weight of reading symbol
     * {@code a} while moving from state {@code p} to state {@code q}. Epsilon moves given to the
     * constructor have already been folded in.
     */
    public final Matrix[] transitions;
    /** Lazily computed result of {@link #closure()}; {@code null} until first requested. */
    private Matrix _closure = null;
    /** Number of states. */
    public final int size;
    /** Number of symbols, i.e. {@code transitions.length}. */
    public final int alphSize;
    /** Index of the end state, always {@code size - 1}. */
    private final int endState;
    /** Lazily computed per-state sampling distributions; see {@code samplingMtx()}. */
    private double[][] _samplingMtx = null;
    /** Packs {@code (symbol, nextState)} pairs into a single index for the sampling tables. */
    private final CoordinatesPacker.MSCoordinatePacker cp;

    /**
     * Creates an automaton from per-symbol transition matrices and an optional epsilon matrix.
     *
     * <p>When {@code epsilonTransition} is non-null it is eliminated by left-multiplying every symbol
     * matrix by {@code MtxUtils.star(epsilonTransition)}. The caller's array is copied, never written to.
     *
     * @param transitions one square matrix per symbol, all with the same dimension
     * @param epsilonTransition square epsilon matrix of that same dimension, or {@code null} for none
     * @throws IllegalArgumentException if {@code transitions} is empty or a matrix has the wrong shape
     */
    public Automaton(Matrix[] transitions, Matrix epsilonTransition) {
        if (transitions.length == 0) {
            throw new IllegalArgumentException("At least one transition matrix is required");
        }
        this.size = transitions[0].getColumnDimension();
        this.alphSize = transitions.length;
        for (int a = 0; a < this.alphSize; ++a) {
            requireSquare(transitions[a], this.size, "transitions[" + a + "]");
        }
        if (epsilonTransition != null) {
            requireSquare(epsilonTransition, this.size, "epsilonTransition");
        }
        this.transitions = removeEpsilons(transitions, epsilonTransition);
        this.cp = new CoordinatesPacker.MSCoordinatePacker(new int[]{this.alphSize, this.size});
        this.endState = this.size - 1;
    }

    /** Throws {@link IllegalArgumentException} unless {@code m} is a {@code dim x dim} matrix. */
    private static void requireSquare(Matrix m, int dim, String what) {
        if (m.getRowDimension() != dim || m.getColumnDimension() != dim) {
            throw new IllegalArgumentException(what + " must be " + dim + "x" + dim + " but is "
                    + m.getRowDimension() + "x" + m.getColumnDimension());
        }
    }

    /**
     * Returns a fresh copy of {@code transitions}; when {@code epsilonTransition} is non-null every entry
     * is replaced by {@code star(epsilonTransition) * transitions[i]}. The argument array is untouched.
     */
    private static Matrix[] removeEpsilons(Matrix[] transitions, Matrix epsilonTransition) {
        // Copy first so the caller's array is neither modified nor aliased by this instance.
        Matrix[] result = transitions.clone();
        if (epsilonTransition == null) {
            return result;
        }
        Matrix star = MtxUtils.star(epsilonTransition);
        for (int i = 0; i < result.length; ++i) {
            result[i] = star.times(result[i]);
        }
        return result;
    }

    /**
     * Returns the automaton whose symbol-{@code a} matrix is
     * {@code kronecker(a1.transitions[a], a2.transitions[a])}, i.e. a product construction over pairs of
     * states. Assuming the usual Kronecker block layout, its start and end states are the pairs of the
     * operands' start and end states.
     *
     * @throws IllegalArgumentException if the two alphabets differ in size
     */
    public static Automaton pointwiseMultiply(Automaton a1, Automaton a2) {
        if (a1.alphSize != a2.alphSize) {
            throw new IllegalArgumentException(
                    "Alphabet sizes differ: " + a1.alphSize + " vs " + a2.alphSize);
        }
        Matrix[] transitions = new Matrix[a1.alphSize];
        for (int a = 0; a < transitions.length; ++a) {
            transitions[a] = MtxUtils.kronecker(a1.transitions[a], a2.transitions[a]);
        }
        return new Automaton(transitions, null);
    }

    /**
     * Returns {@code MtxUtils.star} of the sum of all symbol matrices, computed once and cached.
     * NOTE(review): presumably star(M) = I + M + M^2 + ... = (I - M)^-1; confirm in MtxUtils.
     * The returned matrix is the cached instance; callers must not modify it.
     */
    public Matrix closure() {
        if (this._closure == null) {
            Matrix sum = MtxUtils.zeroes(this.size);
            for (Matrix m : this.transitions) {
                sum.plusEquals(m);
            }
            this._closure = MtxUtils.star(sum);
        }
        return this._closure;
    }

    /**
     * Returns the start-to-end entry of {@link #closure()}. Under the geometric-series reading of star,
     * this is the total weight of all paths from state 0 to the end state.
     */
    public double norm() {
        return this.closure().get(0, this.endState);
    }

    /**
     * Composes {@code leafFactor} with a transducer and returns the resulting automaton.
     *
     * <p>For each index {@code i} in {@code [0, alphSize]} this sums {@code kronecker(transducer[i][j], L_j)}
     * over {@code j} in {@code [0, alphSize]}. Here {@code L_j} is {@code leafFactor.transitions[j]} for a
     * real symbol and the identity for {@code j == alphSize} (the leaf does not move). Rows
     * {@code i < alphSize} become the symbol matrices of the result. Row {@code i == alphSize} becomes its
     * epsilon matrix, which the constructor then eliminates.
     *
     * @param leafFactor automaton whose symbols are read through the second transducer index
     * @param transducer {@code (alphSize + 1) x (alphSize + 1)} array of equally sized square matrices,
     *     e.g. from {@link LevPot#transitions()}
     * @return automaton over the same alphabet with {@code transducerSize * leafFactor.size} states
     * @throws IllegalArgumentException if {@code transducer} does not have the expected shape
     */
    public static Automaton marginalize(Automaton leafFactor, Matrix[][] transducer) {
        int n = leafFactor.alphSize + 1;
        if (transducer.length != n) {
            throw new IllegalArgumentException(
                    "transducer must have " + n + " rows but has " + transducer.length);
        }
        for (int i = 0; i < n; ++i) {
            if (transducer[i].length != n) {
                throw new IllegalArgumentException("transducer[" + i + "] must have " + n
                        + " entries but has " + transducer[i].length);
            }
        }
        int dim = leafFactor.size * transducer[0][0].getRowDimension();
        // Identity for the epsilon column: the leaf stays in place. Loop-invariant, so built once.
        Matrix leafStay = MtxUtils.id(leafFactor.size);
        Matrix[] transitions = new Matrix[leafFactor.alphSize];
        Matrix epsilon = null;
        for (int i = 0; i < n; ++i) {
            Matrix sum = MtxUtils.zeroes(dim);
            for (int j = 0; j < n; ++j) {
                Matrix leafStep = j == leafFactor.alphSize ? leafStay : leafFactor.transitions[j];
                sum.plusEquals(MtxUtils.kronecker(transducer[i][j], leafStep));
            }
            if (i < leafFactor.alphSize) {
                transitions[i] = sum;
            } else {
                epsilon = sum;
            }
        }
        return new Automaton(transitions, epsilon);
    }

    /**
     * Returns, for each state {@code p}, a vector indexed by {@code cp}-packed {@code (symbol a, next q)}
     * holding {@code transitions[a](p, q) * closure(q, end)}, passed through {@code NumUtils.normalize}.
     * Presumably this is the distribution of the next step given that the path eventually ends.
     * The table is cached only once fully built, so a failure midway leaves no partial cache behind.
     */
    private double[][] samplingMtx() {
        if (this._samplingMtx != null) {
            return this._samplingMtx;
        }
        Matrix closure = this.closure();
        // Weight of completing a path from each state; independent of the source state and symbol.
        double[] toEnd = new double[this.size];
        for (int q = 0; q < this.size; ++q) {
            toEnd[q] = closure.get(q, this.endState);
        }
        double[][] mtx = new double[this.size][this.alphSize * this.size];
        for (int ini = 0; ini < this.size; ++ini) {
            for (int l = 0; l < this.alphSize; ++l) {
                for (int nxt = 0; nxt < this.size; ++nxt) {
                    mtx[ini][this.cp.coord2int(new int[]{l, nxt})] =
                            this.transitions[l].get(ini, nxt) * toEnd[nxt];
                }
            }
            NumUtils.normalize(mtx[ini]);
        }
        this._samplingMtx = mtx;
        return mtx;
    }

    /**
     * Draws one string by walking from state 0, repeatedly sampling {@code (symbol, next state)} from the
     * sampling table, until the end state is reached. The symbol read on the final transition into the
     * end state is not included in the result.
     *
     * @param rand source of randomness
     * @return the sampled symbols in order
     * @throws IllegalStateException if {@link #norm()} is not positive, i.e. there is nothing to sample
     */
    public List<Integer> sample(Random rand) {
        double z = this.norm();
        if (!(z > 0.0)) {
            // With no mass reaching the end state the start row is all zeros and sampling is undefined.
            throw new IllegalStateException("Cannot sample from an automaton with norm " + z);
        }
        double[][] mtx = this.samplingMtx();
        List<Integer> result = new ArrayList<Integer>();
        int currentState = 0;
        while (true) {
            int[] coord = this.cp.int2coord(SampleUtils.sampleMultinomial(rand, mtx[currentState]));
            currentState = coord[1];
            if (currentState == this.endState) {
                return result;
            }
            result.add(coord[0]);
        }
    }

    /** Renders each symbol matrix with {@code Table.toString}, one per line group. */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        for (int a = 0; a < this.alphSize; ++a) {
            result.append(Table.toString(this.transitions[a])).append('\n');
        }
        return result.toString();
    }

    /**
     * Returns the automaton giving weight 1 to exactly {@code symbols} followed by the boundary symbol 0.
     * It uses a chain of {@code symbols.length + 2} states: state {@code i} reads {@code symbols[i]} to
     * {@code i + 1}, and the last real state reads symbol 0 into the end state.
     *
     * @param alphSize alphabet size, at least 1
     * @param symbols the string; every entry must lie in {@code [1, alphSize)}
     * @throws IllegalArgumentException if {@code alphSize < 1} or a symbol is 0 or out of range
     */
    public static Automaton sequenceIndicator(int alphSize, int[] symbols) {
        if (alphSize < 1) {
            throw new IllegalArgumentException("alphSize must be at least 1 but is " + alphSize);
        }
        int n = symbols.length;
        int size = n + 2;
        Matrix[] transitions = new Matrix[alphSize];
        for (int a = 0; a < alphSize; ++a) {
            transitions[a] = MtxUtils.zeroes(size);
        }
        for (int i = 0; i < n; ++i) {
            int symbol = symbols[i];
            if (symbol == 0) {
                throw new IllegalArgumentException("Symbol 0 reserved for string boundaries");
            }
            if (symbol < 0 || symbol >= alphSize) {
                throw new IllegalArgumentException("Symbol " + symbol + " at position " + i
                        + " is outside the alphabet [0, " + alphSize + ")");
            }
            transitions[symbol].set(i, i + 1, 1.0);
        }
        transitions[0].set(n, n + 1, 1.0);
        return new Automaton(transitions, null);
    }

    /**
     * Returns a 2-state automaton. Each symbol {@code i > 0} loops on the start state with weight
     * {@code prs[i]}, and symbol 0 moves to the end state with weight {@code prs[0]}. So
     * {@code a_1 ... a_k 0} has weight {@code prs[a_1] * ... * prs[a_k] * prs[0]}.
     *
     * @throws IllegalArgumentException if {@code MathUtils.isProb(prs)} rejects {@code prs}
     */
    public static Automaton geometricUnigram(double[] prs) {
        if (!MathUtils.isProb(prs)) {
            throw new IllegalArgumentException("prs is not a probability vector");
        }
        Matrix[] transitions = new Matrix[prs.length];
        for (int a = 0; a < transitions.length; ++a) {
            transitions[a] = MtxUtils.zeroes(2);
            transitions[a].set(0, a == 0 ? 1 : 0, prs[a]);
        }
        return new Automaton(transitions, null);
    }

    /**
     * Returns a {@link LevPot} over {@code alephSize} symbols (symbol 0 = boundary), built as follows:
     * <ul>
     *   <li>each real symbol has insert weight {@code 0.5 / (alephSize - 1)} and delete weight 0.5;</li>
     *   <li>substitutions among real symbols are uniform;</li>
     *   <li>the boundary maps only to itself.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code alephSize < 2} (no real symbols to distribute over)
     */
    public static LevPot uniformLevPot(int alephSize) {
        if (alephSize < 2) {
            throw new IllegalArgumentException("alephSize must be at least 2 but is " + alephSize);
        }
        double pr = 1.0 / (double)(alephSize - 1);
        double[] insert = new double[alephSize];
        double[] delete = new double[alephSize];
        double[][] sub = new double[alephSize][alephSize];
        for (int i = 1; i < alephSize; ++i) {
            insert[i] = 0.5 * pr;
            delete[i] = 0.5;
            for (int j = 1; j < alephSize; ++j) {
                sub[i][j] = pr;
            }
        }
        sub[0][0] = 1.0;
        return new LevPot(insert, delete, sub);
    }

    /** Returns a {@link LevPot} with no insertions or deletions and identity substitution. */
    public static LevPot idLevPot(int alephSize) {
        double[] insert = new double[alephSize];
        double[] delete = new double[alephSize];
        double[][] sub = new double[alephSize][alephSize];
        for (int i = 0; i < alephSize; ++i) {
            sub[i][i] = 1.0;
        }
        return new LevPot(insert, delete, sub);
    }

    /**
     * Ad-hoc demonstration, run in these steps:
     * <ol>
     *   <li>Marginalize two observed strings through a uniform Levenshtein transducer.</li>
     *   <li>Multiply the results with a geometric unigram and print its norm, then the norm of the product.</li>
     *   <li>Draw samples from the product.</li>
     *   <li>Push the unigram twice through the transposed transducer and print the result and its norm.</li>
     * </ol>
     */
    public static void main(String[] args) {
        Random rand = new Random(1L);
        int alephSize = 5;
        int[] s1 = new int[]{1, 2, 1, 2, 4, 3};
        int[] s2 = new int[]{1, 2, 2};
        double[] uni = new double[]{0.2, 0.2, 0.2, 0.2, 0.2};
        Automaton obs1 = Automaton.sequenceIndicator(alephSize, s1);
        Automaton obs2 = Automaton.sequenceIndicator(alephSize, s2);
        LevPot levPot = Automaton.uniformLevPot(alephSize);
        Matrix[][] directArrow = levPot.transitions();
        Automaton marg = Automaton.marginalize(obs1, directArrow);
        Automaton marg2 = Automaton.marginalize(obs2, directArrow);
        Automaton geoUni = Automaton.geometricUnigram(uni);
        System.out.println("norm of geo:" + geoUni.norm());
        Automaton post = Automaton.pointwiseMultiply(marg2, Automaton.pointwiseMultiply(geoUni, marg));
        System.out.println("norm of marginal:" + post.norm());
        for (int i = 0; i < 10; ++i) {
            System.out.println("Sample:" + post.sample(rand));
        }
        int n = directArrow.length;
        Matrix[][] inverseArrow = new Matrix[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                inverseArrow[i][j] = directArrow[j][i];
            }
        }
        Automaton post2 = Automaton.marginalize(geoUni, inverseArrow);
        post2 = Automaton.marginalize(post2, inverseArrow);
        System.out.println(post2);
        System.out.println("Should be close to one:" + post2.norm());
    }

    /** Prints the norm of the product of two identical sequence indicators. */
    public static void testPtwise() {
        int[] s1 = new int[]{1, 2, 3};
        int[] s2 = new int[]{1, 2, 3};
        Automaton a1 = Automaton.sequenceIndicator(5, s1);
        Automaton a2 = Automaton.sequenceIndicator(5, s2);
        System.out.println("norm=" + Automaton.pointwiseMultiply(a1, a2).norm());
    }

    /**
     * Edit-distance style transducer parameters over {@code alephSize} symbols, where symbol 0 is the
     * string boundary. The constructor copies its arguments defensively.
     */
    public static class LevPot {
        /** {@code 1 - sum(insert)}: weight of leaving the insertion loop; must be positive. */
        private final double stopInsert;
        /** Per-symbol insertion weights; {@code insert[0]} must be 0. */
        private final double[] insert;
        /** Per-symbol deletion probabilities; {@code delete[0]} must be 0. */
        private final double[] delete;
        /** Substitution rows; {@code sub[0][0] == 1} and {@code sub[i][0] == 0} for {@code i > 0}. */
        private final double[][] sub;
        /** Number of symbols, including the boundary symbol 0. */
        public final int alephSize;
        /** Lazily built result of {@link #transitions()}. */
        private Matrix[][] _transitions = null;

        /**
         * Validates and stores (copies of) the transducer parameters.
         *
         * @throws IllegalArgumentException if the array lengths disagree, the insert weights are not
         *     probabilities or leave no stop mass, a delete weight is not a probability, or the boundary
         *     symbol constraints are violated (row validity itself is checked by
         *     {@code MathUtils.checkIsProb})
         */
        public LevPot(double[] insert, double[] delete, double[][] sub) {
            this.alephSize = insert.length;
            if (delete.length != this.alephSize || sub.length != this.alephSize) {
                throw new IllegalArgumentException("insert, delete and sub must have equal lengths but have "
                        + insert.length + ", " + delete.length + ", " + sub.length);
            }
            this.insert = insert.clone();
            this.delete = delete.clone();
            this.sub = new double[this.alephSize][];
            for (int i = 0; i < this.alephSize; ++i) {
                if (sub[i].length != this.alephSize) {
                    throw new IllegalArgumentException("sub[" + i + "] must have length "
                            + this.alephSize + " but has " + sub[i].length);
                }
                this.sub[i] = sub[i].clone();
            }
            this.stopInsert = 1.0 - ArrayUtils.sum(this.insert);
            if (this.stopInsert <= 0.0) {
                throw new IllegalArgumentException(
                        "insert weights must sum to less than 1 but leave stopInsert=" + this.stopInsert);
            }
            for (int i = 0; i < this.alephSize; ++i) {
                if (!MathUtils.isProb(this.insert[i])) {
                    throw new IllegalArgumentException("insert[" + i + "]=" + this.insert[i] + " is not a probability");
                }
                if (!MathUtils.isProb(this.delete[i])) {
                    throw new IllegalArgumentException("delete[" + i + "]=" + this.delete[i] + " is not a probability");
                }
            }
            for (int i = 0; i < this.alephSize; ++i) {
                MathUtils.checkIsProb(this.sub[i]);
                if (i == 0) {
                    // The boundary symbol must map to itself and can be neither inserted nor deleted.
                    if (this.sub[0][0] != 1.0) {
                        throw new IllegalArgumentException("sub[0][0] must be 1 but is " + this.sub[0][0]);
                    }
                    if (this.insert[0] > 0.0) {
                        throw new IllegalArgumentException("insert[0] must be 0 but is " + this.insert[0]);
                    }
                    if (this.delete[0] > 0.0) {
                        throw new IllegalArgumentException("delete[0] must be 0 but is " + this.delete[0]);
                    }
                } else if (this.sub[i][0] != 0.0) {
                    // No real symbol may be substituted by the boundary.
                    throw new IllegalArgumentException("sub[" + i + "][0] must be 0 but is " + this.sub[i][0]);
                }
            }
        }

        /**
         * Returns the {@code (alephSize + 1) x (alephSize + 1)} array of 3x3 weight matrices, built once and
         * cached. Index {@code alephSize} stands for epsilon (no symbol). The entries are set as follows:
         * <ul>
         *   <li>{@code [0][0]}: weight 1 on state 1 -> 2 (boundary);</li>
         *   <li>other {@code [i][j]} with both real: {@code sub[i][j] * (1 - delete[i])} on 1 -> 0;</li>
         *   <li>{@code [i][eps]}: {@code delete[i]} on 1 -> 1;</li>
         *   <li>{@code [eps][j]}: {@code insert[j]} on 0 -> 0;</li>
         *   <li>{@code [eps][eps]}: {@code stopInsert} on 0 -> 1.</li>
         * </ul>
         * The returned array and matrices are shared with this instance; callers must not modify them.
         */
        public Matrix[][] transitions() {
            if (this._transitions != null) {
                return this._transitions;
            }
            int n = this.alephSize + 1;
            Matrix[][] result = new Matrix[n][n];
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    Matrix m = MtxUtils.zeroes(3);
                    boolean iSymbol = i < this.alephSize;
                    boolean jSymbol = j < this.alephSize;
                    if (iSymbol && jSymbol) {
                        if (i == 0 && j == 0) {
                            m.set(1, 2, 1.0);
                        } else {
                            m.set(1, 0, this.sub[i][j] * (1.0 - this.delete[i]));
                        }
                    } else if (iSymbol) {
                        m.set(1, 1, this.delete[i]);
                    } else if (jSymbol) {
                        m.set(0, 0, this.insert[j]);
                    } else {
                        m.set(0, 1, this.stopInsert);
                    }
                    result[i][j] = m;
                }
            }
            this._transitions = result;
            return result;
        }
    }
}

