/*
 * Decompiled with CFR 0.152.
 */
package ma;

import fig.basic.Option;
import goblin.DerivationTree;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import ma.LongGapAlignmentSampler;
import ma.SubstitutionMatrixLoader;
import nuts.io.IO;
import nuts.maxent.FeatureExtractor;
import nuts.util.Counter;
import pepper.Encodings;

/**
 * Feature extractor over the alignment between an ancestor ("top") word and a current
 * ("bottom") word, as exposed by a {@link DerivationTree.Derivation}.
 *
 * <p>{@link #extractFeatures} produces:
 * <ul>
 *   <li>one feature per name in {@link #matricesNames}, valued as the sum of that substitution
 *       matrix's entries over every aligned (ancestor char, current char) pair;</li>
 *   <li>{@link #TOP_GAP}: the number of ancestor characters skipped between aligned
 *       positions;</li>
 *   <li>{@link #BOTTOM_GAP}: the number of current-word characters with no ancestor.</li>
 * </ul>
 *
 * <p>Substitution matrices are loaded lazily through {@link SubstitutionMatrixLoader} on first
 * use, and every entry must be non-negative.
 */
public class AlignFeatureExtractor
implements FeatureExtractor<DerivationTree.Derivation, String> {
    private static final long serialVersionUID = 1L;
    /** Feature name counting skipped ancestor ("top") characters. */
    public static final String TOP_GAP = "TOP_GAP";
    /** Feature name counting current-word ("bottom") characters that have no ancestor. */
    public static final String BOTTOM_GAP = "BOTTOM_GAP";
    private final Encodings enc;
    /** Loaded matrices, parallel to {@link #matricesNames}; {@code null} until successfully loaded. */
    private List<double[][]> matrices;
    /** Names of the substitution matrices to load; {@code null} disables matrix features. */
    @Option
    public ArrayList<String> matricesNames = null;
    /** Negated initial matrix weights, parallel to {@link #matricesNames}; entries must not be negative. */
    @Option
    public ArrayList<Double> negOfMatricesWeightInit = null;
    /** Initial weight of the {@link #TOP_GAP} feature. */
    @Option
    public double initTopGap = -2.0;
    /** Initial weight of the {@link #BOTTOM_GAP} feature. */
    @Option
    public double initBottomGap = -3.0;

    /**
     * @param enc encoding used to map characters to phoneme indices (and back)
     */
    public AlignFeatureExtractor(Encodings enc) {
        this.enc = enc;
    }

    /** Loads the substitution matrices if no successful load has happened yet. */
    private void ensureMatricesLoaded() {
        if (this.matrices == null) {
            this.initMatrix();
        }
    }

    /**
     * Loads and validates every matrix named in {@link #matricesNames}.
     *
     * <p>The result is published to {@link #matrices} only after all matrices have loaded and
     * passed validation. Any failure therefore leaves the field {@code null`, so a later call
     * retries instead of reusing a partial or unvalidated list.
     *
     * @throws RuntimeException wrapping the {@link IOException} of a matrix that could not be
     *     loaded, or if a matrix contains a negative entry
     */
    private void initMatrix() {
        if (this.matricesNames == null) {
            this.matrices = new ArrayList<double[][]>();
            return;
        }
        List<double[][]> loaded = new ArrayList<double[][]>(this.matricesNames.size());
        for (String matrixName : this.matricesNames) {
            double[][] current;
            try {
                current = SubstitutionMatrixLoader.loadMatrix(matrixName, this.enc);
            }
            catch (IOException ioe) {
                throw new RuntimeException("Could not load substitution matrix: " + matrixName, ioe);
            }
            // Validate before keeping the matrix so that no unchecked matrix is ever used.
            this.checkEntriesNonNegative(current);
            loaded.add(current);
        }
        this.matrices = loaded;
    }

    /**
     * Verifies that no entry of the given matrix is negative.
     *
     * @throws RuntimeException naming the first negative entry found
     */
    private void checkEntriesNonNegative(double[][] current) {
        for (int i = 0; i < current.length; ++i) {
            // Bound by this row's own length so jagged matrices are fully checked and never over-indexed.
            for (int j = 0; j < current[i].length; ++j) {
                if (current[i][j] < 0.0) {
                    throw new RuntimeException("Problematic negative entry:" + i + "," + j + "=" + current[i][j]);
                }
            }
        }
    }

    /**
     * Builds the initial weight vector.
     *
     * <p>{@link #TOP_GAP} and {@link #BOTTOM_GAP} start at {@link #initTopGap} and
     * {@link #initBottomGap}. Each matrix feature starts at the negation of the corresponding
     * entry of {@link #negOfMatricesWeightInit}.
     *
     * @return the initial weights
     * @throws RuntimeException if {@link #negOfMatricesWeightInit} is missing or does not have
     *     one entry per matrix, or if any resulting matrix weight is positive
     */
    public Counter<String> initWeights() {
        Counter<String> result = new Counter<String>();
        result.incrementCount(TOP_GAP, this.initTopGap);
        result.incrementCount(BOTTOM_GAP, this.initBottomGap);
        if (this.matricesNames == null) {
            return result;
        }
        int nMatrices = this.matricesNames.size();
        if (this.negOfMatricesWeightInit == null || nMatrices != this.negOfMatricesWeightInit.size()) {
            String nWeights = this.negOfMatricesWeightInit == null
                ? "null"
                : String.valueOf(this.negOfMatricesWeightInit.size());
            throw new RuntimeException("negOfMatricesWeightInit must have one entry per matrix: "
                + nMatrices + " matrices, " + nWeights + " weights");
        }
        for (int i = 0; i < nMatrices; ++i) {
            double wInit = -1.0 * this.negOfMatricesWeightInit.get(i);
            if (wInit > 0.0) {
                throw new RuntimeException("Weight " + i + ":" + wInit + "> 0.0");
            }
            result.incrementCount(this.matricesNames.get(i), wInit);
        }
        return result;
    }

    /**
     * Extracts alignment features from a derivation.
     *
     * <p>For every current-word position {@code i} that has an ancestor, each matrix feature is
     * incremented by {@code matrix[ancestorChar][currentChar]}. {@link #TOP_GAP} is incremented by
     * the number of ancestor characters skipped since the previous aligned ancestor position.
     * Positions with no ancestor increment {@link #BOTTOM_GAP} by one.
     *
     * <p>NOTE(review): ancestor characters after the last aligned position are not added to
     * {@link #TOP_GAP}. Confirm this matches how {@link LongGapAlignmentSampler} scores trailing
     * gaps. The computation also assumes {@code instance.ancestor(i)} increases with {@code i};
     * otherwise the TOP_GAP increment can be negative.
     *
     * @param instance the derivation to featurize; may be {@code null}
     * @return the feature counts, or an empty counter when {@code instance} is {@code null}
     */
    @Override
    public Counter<String> extractFeatures(DerivationTree.Derivation instance) {
        this.ensureMatricesLoaded();
        Counter<String> result = new Counter<String>();
        if (instance == null) {
            return result;
        }
        String top = instance.getAncestorWord();
        String bottom = instance.getCurrentWord();
        int currentTopPosition = 0;
        for (int i = 0; i < bottom.length(); ++i) {
            if (!instance.hasAncestor(i)) {
                result.incrementCount(BOTTOM_GAP, 1.0);
                continue;
            }
            int ancestorIndex = instance.ancestor(i);
            int topIdx = this.enc.char2PhoneId(top.charAt(ancestorIndex));
            int bottomIdx = this.enc.char2PhoneId(bottom.charAt(i));
            for (int mIndex = 0; mIndex < this.matrices.size(); ++mIndex) {
                String name = this.matricesNames.get(mIndex);
                double[][] matrix = this.matrices.get(mIndex);
                result.incrementCount(name, matrix[topIdx][bottomIdx]);
            }
            // Ancestor characters strictly between the previous aligned position and this one were skipped.
            result.incrementCount(TOP_GAP, ancestorIndex - currentTopPosition);
            currentTopPosition = ancestorIndex + 1;
        }
        return result;
    }

    /**
     * Converts a weight vector into sampler parameters.
     *
     * <p>The mutation score for phoneme pair {@code (i, j)} is {@code exp(sum_m w_m * M_m[i][j])}
     * over all loaded matrices. The gap scores are {@code exp(w_TOP_GAP)} and
     * {@code exp(w_BOTTOM_GAP)}.
     *
     * @param weights feature weights, keyed by feature name
     * @return the sampler parameters built from {@code weights}
     */
    public LongGapAlignmentSampler.LongGapAlignmentSamplerParams computeParams(Counter<String> weights) {
        this.ensureMatricesLoaded();
        int nMatrices = this.matrices.size();
        // Look up each matrix weight once rather than once per (i, j) cell.
        double[] matrixWeights = new double[nMatrices];
        for (int mIndex = 0; mIndex < nMatrices; ++mIndex) {
            matrixWeights[mIndex] = weights.getCount(this.matricesNames.get(mIndex));
        }
        int nPhones = this.enc.getNumberOfPhonemes();
        double[][] mutationScores = new double[nPhones][nPhones];
        for (int i = 0; i < nPhones; ++i) {
            // Indices are derived through the character form, as extractFeatures does.
            int topIdx = this.enc.char2PhoneId(this.enc.phoneId2Char(i));
            for (int j = 0; j < nPhones; ++j) {
                int bottomIdx = this.enc.char2PhoneId(this.enc.phoneId2Char(j));
                double sum = 0.0;
                for (int mIndex = 0; mIndex < nMatrices; ++mIndex) {
                    sum += matrixWeights[mIndex] * this.matrices.get(mIndex)[topIdx][bottomIdx];
                }
                mutationScores[i][j] = Math.exp(sum);
            }
        }
        double topGapScore = Math.exp(weights.getCount(TOP_GAP) * 1.0);
        double bottomGapScore = Math.exp(weights.getCount(BOTTOM_GAP) * 1.0);
        return new LongGapAlignmentSampler.LongGapAlignmentSamplerParams(mutationScores, topGapScore, bottomGapScore, this.enc);
    }

    /**
     * Prints a whitespace-separated text file with every numeric field replaced by 1.0 if it
     * lies in [0, +infinity) and 0.0 otherwise. Non-numeric fields are echoed unchanged.
     *
     * @param args optional; {@code args[0]} overrides the default input path {@code data/PAM60}
     */
    public static void main(String[] args) throws NumberFormatException, IOException {
        double min = 0.0;
        double max = Double.POSITIVE_INFINITY;
        String file = args.length > 0 ? args[0] : "data/PAM60";
        for (String line : IO.i(file)) {
            for (String field : line.split("\\s+")) {
                if (AlignFeatureExtractor.isDouble(field)) {
                    System.out.print(AlignFeatureExtractor.process(Double.valueOf(field), min, max) + " ");
                    continue;
                }
                System.out.print(field + " ");
            }
            System.out.print("\n");
        }
    }

    /** Returns 1.0 if {@code d} lies in [minIncl, maxExcl), otherwise 0.0. */
    private static double process(double d, double minIncl, double maxExcl) {
        return d >= minIncl && d < maxExcl ? 1.0 : 0.0;
    }

    /** Returns whether {@code field} can be parsed by {@link Double#valueOf(String)}. */
    private static boolean isDouble(String field) {
        try {
            Double.valueOf(field);
            return true;
        }
        catch (NumberFormatException ne) {
            return false;
        }
    }

    /** Every feature is regularized with the same factor, 1.0. */
    @Override
    public double regularizationFactor(String feature) {
        return 1.0;
    }
}

