/*
 * Decompiled with CFR 0.152.
 */
package pepper;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.basic.StrUtils;
import fig.prob.Multinomial;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import pepper.Edit;
import pepper.Encodings;
import pepper.editmodel.EditParam;

/**
 * Samples a hidden "middle" string lying between a given top string and a
 * given bottom string, together with the edits that rewrite top into middle
 * and middle into bottom.
 *
 * <p>Every top symbol is rewritten into zero (deletion), one (substitution) or
 * two (fission) middle symbols, and every middle symbol is in turn rewritten
 * into zero, one or two bottom symbols. The values returned by
 * {@link EditParam}'s {@code deletionCost}, {@code substitutionCost} and
 * {@code fissionCost} are multiplied along a derivation and summed across
 * derivations, i.e. they are used as probabilities here.
 *
 * <p>Usage: construct, call {@link #sample()}, then read the result through
 * {@link #getMiddleStr()}, {@link #getTop2MiddleEdits()} and
 * {@link #getMiddle2BottomEdits()}. The getters dereference state that only
 * {@link #sample()} creates, so they must not be called before it.
 */
public class WordSampler {
    /**
     * Largest number of bottom symbols one middle symbol may rewrite to:
     * 0 = deletion, 1 = substitution, 2 = fission.
     */
    private static final int MAX_BOTTOM_PER_MIDDLE = 2;

    /**
     * Memo table indexed as {@code [i][j][c1][c2]}: summed probability of all
     * derivations that still have to consume {@code top[i..]} and
     * {@code bottom[j..]}, where {@code c1} is the equivalence class used as
     * left context for the next middle symbol and {@code c2} is the class the
     * next middle symbol is required to have. {@code NaN} marks "not computed".
     */
    private double[][][][] sumProb;
    private final EditParam param;
    private final String topStr;
    private final String bottomStr;
    private final int ntop;
    private final int nbottom;
    private final int nc;
    private final int[] top;
    private final int[] bottom;
    private final int boundaryc;
    private StringBuilder middleStrBuf;
    private List<Edit> top2middleEdits;
    private List<Edit> middle2bottomEdits;
    private final Random random;
    private final int maxDeviation;
    private final Encodings enc;

    /**
     * @param random       source of randomness for all sampling decisions
     * @param param        edit model supplying the encodings, the successor
     *                     sets and the edit probabilities
     * @param maxDeviation derivations in which the top and bottom positions
     *                     differ by more than this are given probability zero
     * @param topStr       the top string
     * @param bottomStr    the bottom string
     */
    public WordSampler(Random random, EditParam param, int maxDeviation, String topStr, String bottomStr) {
        this.enc = param.getEncodings();
        this.random = random;
        this.param = param;
        this.maxDeviation = maxDeviation;
        this.topStr = topStr;
        this.bottomStr = bottomStr;
        // NOTE(review): the string lengths are used as bounds for the phone-id
        // arrays below; this assumes string2PhoneIds yields one id per char.
        this.ntop = topStr.length();
        this.nbottom = bottomStr.length();
        this.top = this.enc.string2PhoneIds(topStr);
        this.bottom = this.enc.string2PhoneIds(bottomStr);
        this.nc = this.enc.getNumberOfEqClasses();
        this.boundaryc = this.enc.getBoundaryEqClassId();
    }

    /**
     * Draws one middle string and its two edit sequences. A first pass fills
     * the memo table with summed probabilities; a second pass walks the same
     * derivation space and picks each step in proportion to those sums.
     * Results of any previous call are discarded.
     */
    public void sample() {
        this.sumProb = new double[this.ntop + 1][this.nbottom + 1][this.nc][this.nc];
        for (double[][][] byBottomPos : this.sumProb) {
            for (double[][] byLeftClass : byBottomPos) {
                for (double[] byNextClass : byLeftClass) {
                    for (int c = 0; c < byNextClass.length; ++c) {
                        byNextClass[c] = Double.NaN;
                    }
                }
            }
        }
        this.computeSumProb(false);
        this.middleStrBuf = new StringBuilder();
        this.top2middleEdits = new ArrayList<Edit>();
        this.middle2bottomEdits = new ArrayList<Edit>();
        this.computeSumProb(true);
    }

    /** @return the middle string produced by the last {@link #sample()} call */
    public String getMiddleStr() {
        return this.middleStrBuf.toString();
    }

    /** @return the sampled edits rewriting the top string into the middle string */
    public List<Edit> getTop2MiddleEdits() {
        return this.top2middleEdits;
    }

    /** @return the sampled edits rewriting the middle string into the bottom string */
    public List<Edit> getMiddle2BottomEdits() {
        return this.middle2bottomEdits;
    }

    /**
     * Equivalence class of the top symbol at position {@code i}, or the
     * boundary class when {@code i} lies outside the top string.
     */
    private int topc(int i) {
        if (i < 0 || i >= this.top.length) {
            return this.boundaryc;
        }
        return this.x2c(this.top[i]);
    }

    /** Maps a phone id to its equivalence class id. */
    private int x2c(int x) {
        return this.enc.phoneId2EqClassId(x);
    }

    /**
     * Runs the recursion from the start state, whose left middle context is
     * the boundary class. When summing, every candidate class for the first
     * middle symbol is evaluated; when sampling, one such class is first drawn
     * in proportion to its summed probability.
     */
    private void computeSumProb(boolean doSample) {
        int startClass = this.boundaryc;
        if (!doSample) {
            for (int firstClass = 0; firstClass < this.nc; ++firstClass) {
                this.computeSumProb(0, 0, startClass, firstClass, false);
            }
            return;
        }
        double[] probs = new double[this.nc];
        for (int firstClass = 0; firstClass < this.nc; ++firstClass) {
            probs[firstClass] = this.computeSumProb(0, 0, startClass, firstClass, false);
        }
        NumUtils.normalize(probs);
        int sampledClass = Multinomial.sample(this.random, probs);
        this.computeSumProb(0, 0, startClass, sampledClass, true);
    }

    /**
     * Core recursion over the state {@code (i, j, c1, c2)}.
     *
     * <p>With {@code doSample == false} it returns (and memoizes) the summed
     * probability of all ways to finish from this state. With
     * {@code doSample == true} it requires that sum to be memoized already,
     * draws a threshold below it, enumerates the options in the same order as
     * the summing pass, commits to the first option whose running total
     * reaches the threshold (recording its symbols and edits), recurses in
     * sampling mode and returns {@code NaN}.
     *
     * @param i        number of top symbols consumed so far
     * @param j        number of bottom symbols consumed so far
     * @param c1       class used as left context of the next middle symbol
     * @param c2       class the next middle symbol must belong to
     * @param doSample whether to sample a derivation instead of summing
     */
    private double computeSumProb(int i, int j, int c1, int c2, boolean doSample) {
        if (Math.abs(i - j) > this.maxDeviation) {
            return 0.0;
        }
        if (i == this.ntop && j < this.nbottom) {
            return 0.0;
        }
        if (i == this.ntop && j == this.nbottom) {
            return 1.0;
        }
        if (!doSample && !Double.isNaN(this.sumProb[i][j][c1][c2])) {
            return this.sumProb[i][j][c1][c2];
        }
        double accumProb = 0.0;
        // Stays NaN while summing, so "accumProb >= targetProb" is never true.
        double targetProb = Double.NaN;
        if (doSample) {
            targetProb = this.random.nextDouble() * this.sumProb[i][j][c1][c2];
        }
        int tc1 = this.topc(i - 1);
        int tx = this.top[i];
        int tc2 = this.topc(i + 1);

        // Option 1: the top symbol is deleted; the middle context is unchanged.
        double topDeletionProb = this.param.deletionCost(tc1, tx, tc2);
        accumProb += topDeletionProb * this.computeSumProb(i + 1, j, c1, c2, false);
        if (doSample && accumProb >= targetProb) {
            this.top2middleEdits.add(new Edit(this.enc, tc1, tx, tc2));
            this.computeSumProb(i + 1, j, c1, c2, true);
            return Double.NaN;
        }

        // nextClass is the class promised for the middle symbol that follows
        // whatever this step emits; it becomes c2 of the successor state.
        for (int nextClass = 0; nextClass < this.nc; ++nextClass) {
            // Option 2: the top symbol becomes a single middle symbol mx.
            // Iterating as Object with an explicit cast keeps this independent
            // of the collection's declared element type.
            for (Object successor : this.param.substitutionSuccessors(tc1, tx, tc2)) {
                int mx = (Integer) successor;
                int mc = this.x2c(mx);
                if (mc != c2) {
                    continue;
                }
                double topProb = this.param.substitutionCost(tc1, tx, tc2, mx);
                for (int dj = 0; dj <= MAX_BOTTOM_PER_MIDDLE; ++dj) {
                    if (j + dj > this.nbottom) {
                        break;
                    }
                    double midProb = this.middle2BottomCost(c1, mx, nextClass, j, dj);
                    accumProb += topProb * midProb * this.computeSumProb(i + 1, j + dj, mc, nextClass, false);
                    if (doSample && accumProb >= targetProb) {
                        this.middleStrBuf.append(this.enc.phoneId2Char(mx));
                        this.top2middleEdits.add(new Edit(this.enc, tc1, tx, tc2, mx));
                        this.middle2bottomEdits.add(this.middle2BottomEdit(c1, mx, nextClass, j, dj));
                        this.computeSumProb(i + 1, j + dj, mc, nextClass, true);
                        return Double.NaN;
                    }
                }
            }

            // Option 3: the top symbol splits into two middle symbols mxa, mxb.
            for (Pair pair : this.param.fissionSuccessors(tc1, tx, tc2)) {
                int mxa = (Integer) pair.getFirst();
                int mxb = (Integer) pair.getSecond();
                int mca = this.x2c(mxa);
                int mcb = this.x2c(mxb);
                if (mca != c2) {
                    continue;
                }
                double topProb = this.param.fissionCost(tc1, tx, tc2, mxa, mxb);
                for (int dja = 0; dja <= MAX_BOTTOM_PER_MIDDLE; ++dja) {
                    // Running out of bottom symbols only rules out the longer
                    // rewrites of this pair; the remaining pairs must still
                    // be considered, so leave just this inner loop.
                    if (j + dja > this.nbottom) {
                        break;
                    }
                    double firstProb = this.middle2BottomCost(c1, mxa, mcb, j, dja);
                    for (int djb = 0; djb <= MAX_BOTTOM_PER_MIDDLE; ++djb) {
                        int consumed = dja + djb;
                        if (j + consumed > this.nbottom) {
                            break;
                        }
                        double secondProb = this.middle2BottomCost(mca, mxb, nextClass, j + dja, djb);
                        accumProb += topProb * firstProb * secondProb
                                * this.computeSumProb(i + 1, j + consumed, mcb, nextClass, false);
                        if (doSample && accumProb >= targetProb) {
                            this.middleStrBuf.append(this.enc.phoneId2Char(mxa));
                            this.middleStrBuf.append(this.enc.phoneId2Char(mxb));
                            this.top2middleEdits.add(new Edit(this.enc, tc1, tx, tc2, mxa, mxb));
                            this.middle2bottomEdits.add(this.middle2BottomEdit(c1, mxa, mcb, j, dja));
                            this.middle2bottomEdits.add(this.middle2BottomEdit(mca, mxb, nextClass, j + dja, djb));
                            this.computeSumProb(i + 1, j + consumed, mcb, nextClass, true);
                            return Double.NaN;
                        }
                    }
                }
            }
        }
        this.sumProb[i][j][c1][c2] = accumProb;
        return accumProb;
    }

    /**
     * Probability of rewriting middle symbol {@code mx}, seen between classes
     * {@code left} and {@code right}, into the {@code len} bottom symbols that
     * start at {@code bottom[pos]}. The caller guarantees
     * {@code pos + len <= nbottom}.
     */
    private double middle2BottomCost(int left, int mx, int right, int pos, int len) {
        switch (len) {
            case 0:
                return this.param.deletionCost(left, mx, right);
            case 1:
                return this.param.substitutionCost(left, mx, right, this.bottom[pos]);
            case 2:
                return this.param.fissionCost(left, mx, right, this.bottom[pos], this.bottom[pos + 1]);
            default:
                throw new IllegalArgumentException("unsupported rewrite length: " + len);
        }
    }

    /**
     * Builds the {@link Edit} matching {@link #middle2BottomCost}: middle
     * symbol {@code mx} in context {@code (left, right)} rewritten into the
     * {@code len} bottom symbols starting at {@code bottom[pos]}.
     */
    private Edit middle2BottomEdit(int left, int mx, int right, int pos, int len) {
        switch (len) {
            case 0:
                return new Edit(this.enc, left, mx, right);
            case 1:
                return new Edit(this.enc, left, mx, right, this.bottom[pos]);
            case 2:
                return new Edit(this.enc, left, mx, right, this.bottom[pos], this.bottom[pos + 1]);
            default:
                throw new IllegalArgumentException("unsupported rewrite length: " + len);
        }
    }

    /** Logs the top, sampled middle and bottom strings plus both edit lists. */
    public void printDebug() {
        LogInfo.logs("%s -> %s -> %s", this.topStr, this.getMiddleStr(), this.bottomStr);
        LogInfo.logs("top2middleEdits = " + StrUtils.join(this.getTop2MiddleEdits(), " || "));
        LogInfo.logs("middle2bottomEdits = " + StrUtils.join(this.getMiddle2BottomEdits(), " || "));
    }
}

