/*
 * Decompiled with CFR 0.152.
 */
package goblin;

import fig.basic.NumUtils;
import fig.basic.Option;
import fig.exec.Execution;
import fig.prob.SampleUtils;
import goblin.DerivationTree;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import nuts.io.IO;
import nuts.math.MeasureZeroException;
import nuts.tui.Table;
import pepper.Edit;
import pepper.Encodings;
import pepper.editmodel.EditParam;
import pepper.editmodel.ObservedWordSampler;

/**
 * Samples a partial alignment between the ancestor ("top") word and the current ("bottom") word
 * of a {@link DerivationTree.Derivation}.
 *
 * <p>Every top character is covered by exactly one edit: a deletion, a substitution by one bottom
 * character, or a fission into two consecutive bottom characters. Edit weights are the values
 * returned by {@link EditParam}'s {@code deletionCost}/{@code substitutionCost}/{@code fissionCost}
 * (they are multiplied together, i.e. treated as unnormalized probabilities).
 *
 * <p>A backward dynamic program ({@link #computeSumPr()}) accumulates, for each pair of relative
 * positions {@code (s, d)}, the total weight of all ways to generate {@code bottom[d..]} from
 * {@code top[s..]}. {@link #sample(Random)} then draws one alignment by walking forward through
 * that table.
 *
 * <p>Top positions outside {@code topLineageWindow} and bottom positions outside
 * {@code bottomLineageWindow} are constrained: an edit touching them gets weight zero unless it is
 * consistent with the link the derivation already records for that position.
 */
public class AlignmentSampler {
    private final String topWord;
    private final String bottomWord;
    /** Phone ids of {@link #topWord}, as produced by {@code enc.string2PhoneIds}. */
    private final int[] topIds;
    /** Phone ids of {@link #bottomWord}, as produced by {@code enc.string2PhoneIds}. */
    private final int[] bottomIds;
    private final EditParam param;
    private final Encodings enc;
    /** Links bottom positions to top positions (queried with {@code ancestor(bottomPos)}). */
    private final DerivationTree.Derivation derivation;
    /** {@code derivation.invert()}: links top positions to bottom positions. */
    private final DerivationTree.Derivation inverseDerivation;
    private final DerivationTree.Window topLineageWindow;
    private final DerivationTree.Window bottomLineageWindow;
    private final RelativeIndex topIdx;
    private final RelativeIndex bottomIdx;
    /**
     * DP table: {@code sumPr[s][d]} is the total weight of aligning {@code top[s..]} with
     * {@code bottom[d..]}. It has at least {@code nRelTop()} rows of {@code nRelBottom() + 1}
     * columns and may be a caller-supplied, reused buffer.
     */
    private final double[][] sumPr;
    /** Whether {@link #sumPr} has been filled for this instance. */
    private boolean computed = false;
    /** Last successful sample, or {@code null} if none (or if the last attempt failed). */
    private PartialAlignmentSample sample = null;
    @Option(gloss="File where the params should be loaded from.", required=true)
    public static String savedparams = null;
    @Option(gloss="Top String file", required=true)
    public static String topFile = null;
    @Option(gloss="Bottom string file", required=true)
    public static String bottomFile = null;

    /**
     * Allocates a DP table large enough for a top word of length {@code topL} and a bottom word of
     * length {@code bottomL}. Callers may pass the result to the 5-argument constructor to reuse it.
     *
     * @param topL length of the top (ancestor) word
     * @param bottomL length of the bottom (current) word
     * @return a zero-filled {@code topL x (bottomL + 1)} table
     */
    public static double[][] createTable(int topL, int bottomL) {
        return new double[topL][bottomL + 1];
    }

    /**
     * Creates a sampler that allocates its own DP table.
     *
     * @param param edit weights and encodings
     * @param derivation derivation whose ancestor word is the top word and current word the bottom
     * @param bottomLineageWindow bottom positions that are free to be re-aligned
     * @param topLineageWindow top positions that are free to be re-aligned
     */
    public AlignmentSampler(EditParam param, DerivationTree.Derivation derivation, DerivationTree.Window bottomLineageWindow, DerivationTree.Window topLineageWindow) {
        this(param, derivation, bottomLineageWindow, topLineageWindow, AlignmentSampler.createTable(derivation.getAncestorWord().length(), derivation.getCurrentWord().length()));
    }

    /**
     * Creates a sampler that reuses {@code allocatedTable} as its DP table when it is large enough,
     * and allocates a fresh one otherwise.
     *
     * @param param edit weights and encodings
     * @param derivation derivation whose ancestor word is the top word and current word the bottom
     * @param bottomLineageWindow bottom positions that are free to be re-aligned
     * @param topLineageWindow top positions that are free to be re-aligned
     * @param allocatedTable candidate buffer (e.g. from {@link #createTable}); may be {@code null}
     */
    public AlignmentSampler(EditParam param, DerivationTree.Derivation derivation, DerivationTree.Window bottomLineageWindow, DerivationTree.Window topLineageWindow, double[][] allocatedTable) {
        this.enc = param.getEncodings();
        this.derivation = derivation;
        this.topWord = derivation.getAncestorWord();
        this.bottomWord = derivation.getCurrentWord();
        this.topIds = this.enc.string2PhoneIds(this.topWord);
        this.bottomIds = this.enc.string2PhoneIds(this.bottomWord);
        this.param = param;
        this.inverseDerivation = derivation.invert();
        this.topLineageWindow = topLineageWindow;
        this.bottomLineageWindow = bottomLineageWindow;
        this.bottomIdx = new RelativeIndex(0, this.nAbsBottom());
        this.topIdx = new RelativeIndex(0, this.nAbsTop());
        // Both dimensions must be checked: the table is an array of row arrays, so the number of
        // rows says nothing about the width of each row.
        this.sumPr = AlignmentSampler.isLargeEnough(allocatedTable, this.nRelTop(), this.nRelBottom() + 1)
                ? allocatedTable
                : AlignmentSampler.createTable(this.nRelTop(), this.nRelBottom());
    }

    /** Returns true if {@code table} has at least {@code rows} rows, each at least {@code cols} wide. */
    private static boolean isLargeEnough(double[][] table, int rows, int cols) {
        if (table == null || table.length < rows) {
            return false;
        }
        for (int r = 0; r < rows; ++r) {
            if (table[r] == null || table[r].length < cols) {
                return false;
            }
        }
        return true;
    }

    private int nAbsTop() {
        return this.topWord.length();
    }

    private int nAbsBottom() {
        return this.bottomWord.length();
    }

    private int nRelTop() {
        return this.topIdx.nIndices;
    }

    private int nRelBottom() {
        return this.bottomIdx.nIndices;
    }

    /**
     * Returns the alignment drawn by the last successful call to {@link #sample(Random)}.
     *
     * @throws IllegalStateException if no sample is available (never sampled, or the last
     *     attempt threw {@link MeasureZeroException})
     */
    public PartialAlignmentSample getSample() {
        if (this.sample == null) {
            throw new IllegalStateException("No alignment available: sample(Random) has not completed successfully");
        }
        return this.sample;
    }

    /**
     * Returns the total weight of all alignments of the full top word with the full bottom word,
     * filling the DP table on first use.
     *
     * <p>Two empty words have weight 1.0 (the empty alignment), matching the DP base case. An empty
     * top word with a non-empty bottom word has weight 0.0.
     */
    public double getMarginalPr() {
        this.ensureComputed();
        return this.getSumPr(0, 0);
    }

    /**
     * Draws one alignment in proportion to its weight and stores it for {@link #getSample()}.
     *
     * @param rand source of randomness
     * @throws MeasureZeroException if no alignment has positive weight
     */
    public void sample(Random rand) throws MeasureZeroException {
        this.sample = null;
        this.ensureComputed();
        // getSumPr (rather than sumPr[0][0]) also handles an empty top word, whose table has no rows.
        if (this.getSumPr(0, 0) == 0.0) {
            throw new MeasureZeroException("Alignment sampler ran a measure zero, top: " + this.topLineageWindow.toString(this.topWord) + ", bottom: " + this.bottomLineageWindow.toString(this.bottomWord));
        }
        Edit[] edits = new Edit[this.nRelTop()];
        int relTop = 0;
        int relBottom = 0;
        while (relTop < this.nRelTop()) {
            // Weights of the three ways to cover top[relTop], each including the rest of the DP.
            double[] prs = new double[]{this.delRec(relTop, relBottom), this.subRec(relTop, relBottom), this.fisRec(relTop, relBottom)};
            assert (prs[0] + prs[1] + prs[2] > 0.0);
            NumUtils.normalize(prs);
            int decision = SampleUtils.sampleMultinomial(rand, prs);
            int absTop = this.topIdx.getAbsolute(relTop);
            int c1 = this.topc(absTop - 1);
            int x = this.topIds[absTop];
            int c2 = this.topc(absTop + 1);
            switch (decision) {
                case 0: {
                    edits[relTop] = new Edit(this.enc, c1, x, c2);
                    relTop += 1;
                    break;
                }
                case 1: {
                    int y = this.bottomIds[this.bottomIdx.getAbsolute(relBottom)];
                    edits[relTop] = new Edit(this.enc, c1, x, c2, y);
                    relTop += 1;
                    relBottom += 1;
                    break;
                }
                case 2: {
                    int y = this.bottomIds[this.bottomIdx.getAbsolute(relBottom)];
                    int z = this.bottomIds[this.bottomIdx.getAbsolute(relBottom + 1)];
                    edits[relTop] = new Edit(this.enc, c1, x, c2, y, z);
                    relTop += 1;
                    relBottom += 2;
                    break;
                }
                default:
                    throw new IllegalStateException("Unexpected edit decision: " + decision);
            }
        }
        this.sample = new PartialAlignmentSample(edits);
    }

    /** Fills the DP table once per instance. */
    private void ensureComputed() {
        if (!this.computed) {
            this.computeSumPr();
        }
    }

    /** Fills {@link #sumPr} backwards, so every cell reads only cells that are already final. */
    private void computeSumPr() {
        for (int s = this.nRelTop() - 1; s >= 0; --s) {
            for (int d = this.nRelBottom(); d >= 0; --d) {
                this.sumPr[s][d] = this.delRec(s, d) + this.subRec(s, d) + this.fisRec(s, d);
            }
        }
        this.computed = true;
    }

    /** Weight of deleting top[s] times the weight of aligning the remainder. */
    private double delRec(int s, int d) {
        double rec = this.getSumPr(s + 1, d);
        // Skipping the edit weight when the remainder is impossible also avoids evaluating it there.
        return rec == 0.0 ? 0.0 : this.del(s) * rec;
    }

    /** Weight of substituting top[s] by bottom[d] times the weight of aligning the remainder. */
    private double subRec(int s, int d) {
        double rec = this.getSumPr(s + 1, d + 1);
        // A zero remainder also guarantees d + 1 <= nRelBottom(), so sub(s, d) is only called in range.
        return rec == 0.0 ? 0.0 : this.sub(s, d) * rec;
    }

    /** Weight of splitting top[s] into bottom[d], bottom[d+1] times the weight of the remainder. */
    private double fisRec(int s, int d) {
        double rec = this.getSumPr(s + 1, d + 2);
        // A zero remainder also guarantees d + 2 <= nRelBottom(), so both bottom indices are valid.
        return rec == 0.0 ? 0.0 : this.fis(s, d, d + 1) * rec;
    }

    /**
     * Reads the DP table with its boundary conditions: 1.0 when both words are fully consumed,
     * 0.0 when only one is (or the bottom index is past the table), the stored value otherwise.
     */
    private double getSumPr(int relTopPos, int relBottomPos) {
        if (relTopPos == this.nRelTop() && relBottomPos == this.nRelBottom()) {
            return 1.0;
        }
        if (relTopPos >= this.nRelTop()) {
            return 0.0;
        }
        if (relBottomPos >= this.nRelBottom() + 1) {
            return 0.0;
        }
        return this.sumPr[relTopPos][relBottomPos];
    }

    /**
     * Weight of deleting the top character at {@code relTopPos}. Deletion is forbidden (0.0) for an
     * out-of-window top position that the inverse derivation links to a bottom position.
     */
    private double del(int relTopPos) {
        int absTopPos = this.topIdx.getAbsolute(relTopPos);
        if (!this.topLineageWindow.contains(absTopPos) && this.inverseDerivation.hasAncestor(absTopPos)) {
            return 0.0;
        }
        int c1 = this.topc(absTopPos - 1);
        int x = this.topIds[absTopPos];
        int c2 = this.topc(absTopPos + 1);
        return this.param.deletionCost(c1, x, c2);
    }

    /**
     * Weight of substituting top[relTopPos] by bottom[relBottomPos]. It is 0.0 if the link would
     * contradict an existing link of an out-of-window position on either side.
     */
    private double sub(int relTopPos, int relBottomPos) {
        int absTopPos = this.topIdx.getAbsolute(relTopPos);
        int absBottomPos = this.bottomIdx.getAbsolute(relBottomPos);
        if (!AlignmentSampler.isOutOfWindowLinkPreserved(absTopPos, absBottomPos, this.inverseDerivation, this.topLineageWindow)) {
            return 0.0;
        }
        if (!AlignmentSampler.isOutOfWindowLinkPreserved(absBottomPos, absTopPos, this.derivation, this.bottomLineageWindow)) {
            return 0.0;
        }
        int c1 = this.topc(absTopPos - 1);
        int x = this.topIds[absTopPos];
        int c2 = this.topc(absTopPos + 1);
        int y = this.bottomIds[absBottomPos];
        return this.param.substitutionCost(c1, x, c2, y);
    }

    /**
     * Weight of splitting top[relTopPos] into bottom[relBottomPos1], bottom[relBottomPos2].
     * Link constraints are checked against the bottom position whose phone equals the top phone
     * (the first one if it matches, otherwise the second).
     */
    private double fis(int relTopPos, int relBottomPos1, int relBottomPos2) {
        int absTopPos = this.topIdx.getAbsolute(relTopPos);
        int absBottomPos1 = this.bottomIdx.getAbsolute(relBottomPos1);
        int absBottomPos2 = this.bottomIdx.getAbsolute(relBottomPos2);
        int c1 = this.topc(absTopPos - 1);
        int x = this.topIds[absTopPos];
        int c2 = this.topc(absTopPos + 1);
        int y = this.bottomIds[absBottomPos1];
        int z = this.bottomIds[absBottomPos2];
        // Fission is expected to keep the top phone in one of the two outputs, or have zero weight.
        assert (x == y || x == z || this.param.fissionCost(c1, x, c2, y, z) == 0.0);
        // NOTE(review): only the matching bottom position is checked against out-of-window links;
        // the other output position is not constrained here. Confirm that this is intended.
        int absBottomPos = x == y ? absBottomPos1 : absBottomPos2;
        if (!AlignmentSampler.isOutOfWindowLinkPreserved(absTopPos, absBottomPos, this.inverseDerivation, this.topLineageWindow)) {
            return 0.0;
        }
        if (!AlignmentSampler.isOutOfWindowLinkPreserved(absBottomPos, absTopPos, this.derivation, this.bottomLineageWindow)) {
            return 0.0;
        }
        return this.param.fissionCost(c1, x, c2, y, z);
    }

    /**
     * Checks that linking {@code absPos} to {@code absOtherPos} is allowed. Positions inside
     * {@code window} are unconstrained. Positions outside it must already be linked, and linked
     * exactly to {@code absOtherPos}, according to {@code linkSource.ancestor}.
     */
    private static boolean isOutOfWindowLinkPreserved(int absPos, int absOtherPos, DerivationTree.Derivation linkSource, DerivationTree.Window window) {
        if (window.contains(absPos)) {
            return true;
        }
        return linkSource.hasAncestor(absPos) && linkSource.ancestor(absPos) == absOtherPos;
    }

    /** Equivalence-class id of the top phone at {@code i}, or the boundary class when out of range. */
    private int topc(int i) {
        return i >= 0 && i < this.topIds.length ? this.x2c(this.topIds[i]) : this.enc.getBoundaryEqClassId();
    }

    private int x2c(int x) {
        return this.enc.phoneId2EqClassId(x);
    }

    /**
     * Manual demo: samples an alignment with {@link ObservedWordSampler}, then re-samples part of
     * it with {@link AlignmentSampler} and prints the result.
     */
    public static void main(String[] args) throws NumberFormatException, IOException {
        Execution.init(args, "encodings", Encodings.class, "observedwordsampler", AlignmentSampler.class);
        String top = IO.i(topFile).iterator().next();
        String bottom = IO.i(bottomFile).iterator().next();
        EditParam params = EditParam.getSanityCheck5();
        ObservedWordSampler sampler = new ObservedWordSampler(new Random(1L), params, top, bottom);
        List<Edit> edits = null;
        for (int i = 0; i < 1; ++i) {
            edits = sampler.sample(false);
            IO.so("Sample " + i + ": " + edits);
            IO.so(sampler.sampledPath2String());
        }
        IO.so("DP matrix: ");
        IO.so(sampler.transitions());
        System.out.println("====== New sampler ======");
        DerivationTree.Derivation d = DerivationTree.Derivation.editList2Derivation(edits);
        IO.so("Original derivation:\n" + d.toString());
        DerivationTree.Window w = new DerivationTree.Window(4, bottom.length() - 1);
        AlignmentSampler newSampler = new AlignmentSampler(params, d, w, d.project(w));
        boolean success = false;
        try {
            newSampler.sample(new Random(1L));
            success = true;
        }
        catch (MeasureZeroException e) {
            // Report instead of swallowing; getSample() would throw after a failed attempt.
            IO.so("Sampling failed: " + e.getMessage());
        }
        IO.so("Sampling succeeded?: " + success);
        if (success) {
            PartialAlignmentSample pSample = newSampler.getSample();
            IO.so("Sample: \n" + pSample.toString());
        }
        Execution.finish();
    }

    /** One sampled alignment: one {@link Edit} per relative top position. */
    public class PartialAlignmentSample {
        private final Edit[] editsByRelTopPos;

        public PartialAlignmentSample(Edit[] editsByRelTopPos) {
            this.editsByRelTopPos = editsByRelTopPos;
        }

        /**
         * Converts the edits to a derivation.
         *
         * @throws IllegalStateException if the sample does not cover the whole top word
         */
        public DerivationTree.Derivation getDerivation() {
            return DerivationTree.Derivation.editList2Derivation(this.getEditsList());
        }

        /**
         * Returns the edits in top-word order, as a fixed-size list view backed by this sample.
         *
         * @throws IllegalStateException if the sample does not cover the whole top word
         */
        public List<Edit> getEditsList() {
            int expected = AlignmentSampler.this.topWord.length();
            if (this.editsByRelTopPos.length != expected) {
                throw new IllegalStateException("Sample has " + this.editsByRelTopPos.length + " edits but the top word has length " + expected);
            }
            return Arrays.asList(this.editsByRelTopPos);
        }

        /**
         * Renders the alignment as a grid: top characters label the columns, bottom characters the
         * rows, and window members are marked with "!". Each link is drawn as "+".
         */
        @Override
        public String toString() {
            return new Table(new Table.Populator(){

                @Override
                public void populate() {
                    for (int topAbs = 0; topAbs < AlignmentSampler.this.nAbsTop(); ++topAbs) {
                        this.set(0, topAbs + 1, "" + AlignmentSampler.this.topWord.charAt(topAbs) + (AlignmentSampler.this.topLineageWindow.contains(topAbs) ? "!" : ""));
                    }
                    for (int bottomAbs = 0; bottomAbs < AlignmentSampler.this.nAbsBottom(); ++bottomAbs) {
                        this.set(bottomAbs + 1, 0, "" + AlignmentSampler.this.bottomWord.charAt(bottomAbs) + (AlignmentSampler.this.bottomLineageWindow.contains(bottomAbs) ? "!" : ""));
                    }
                    Edit[] edits = PartialAlignmentSample.this.editsByRelTopPos;
                    int relBottomPos = 0;
                    for (int relTopPos = 0; relTopPos < edits.length; ++relTopPos) {
                        Edit edit = edits[relTopPos];
                        int column = 1 + AlignmentSampler.this.topIdx.getAbsolute(relTopPos);
                        if (edit.isDeletion()) {
                            continue;
                        }
                        if (edit.isSubstitution()) {
                            this.set(1 + AlignmentSampler.this.bottomIdx.getAbsolute(relBottomPos), column, "+");
                            relBottomPos += 1;
                        } else if (edit.isFission()) {
                            this.set(1 + AlignmentSampler.this.bottomIdx.getAbsolute(relBottomPos), column, "+");
                            this.set(1 + AlignmentSampler.this.bottomIdx.getAbsolute(relBottomPos + 1), column, "+");
                            relBottomPos += 2;
                        } else {
                            throw new IllegalStateException("Unsupported edit type: " + edit);
                        }
                    }
                }
            }).toString();
        }
    }

    /** Maps relative indices {@code [0, nIndices)} to absolute indices {@code [offset, offset + nIndices)}. */
    private static class RelativeIndex {
        private final int offset;
        private final int nIndices;

        public RelativeIndex(int offset, int indices) {
            this.offset = offset;
            this.nIndices = indices;
        }

        private boolean isValid(int relativeIndex) {
            return relativeIndex >= 0 && relativeIndex < this.nIndices;
        }

        private int getAbsolute(int relativeIndex) {
            assert (this.isValid(relativeIndex));
            return relativeIndex + this.offset;
        }

        /** True iff {@code absoluteIndex} lies in the half-open range {@code [offset, offset + nIndices)}. */
        public boolean containsAbsolutePosition(int absoluteIndex) {
            return absoluteIndex >= this.offset && absoluteIndex < this.offset + this.nIndices;
        }
    }
}

