/*
 * Decompiled with CFR 0.152.
 */
package sgi;

import fig.basic.Option;
import goblin.DataPrepUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import nuts.util.Counter;
import nuts.util.Indexer;
import nuts.util.Tree;
import sgi.HierarchicalAligner;

/**
 * Scores synchronous grammar rules built from pairs of rule treelets.
 *
 * <p>Treelets are encoded as integer trees: terminal labels map to non-negative codes through
 * {@code terminalIndexer}; nonterminal labels map to negative codes ({@code -index - 1}) through
 * {@code nonTerminalIndexer}. {@link Integer#MIN_VALUE} and {@link Integer#MAX_VALUE} are reserved
 * as opening/closing markers in canonical linearizations.
 *
 * <p>Rule extraction registers previously unseen labels in the supplied indexers, i.e. the indexers
 * are mutated by {@link #extractRuleTreelet(Tree, List)}.
 */
public abstract class ScoreFct {
    private final Indexer<String> terminalIndexer;
    private final Indexer<String> nonTerminalIndexer;
    /** Marker opening a subtree in a canonical linearization. */
    private static final int leftParenthesis = Integer.MIN_VALUE;
    /** Marker closing a subtree in a canonical linearization. */
    private static final int rightParenthesis = Integer.MAX_VALUE;
    /** Bijection tag for the unswapped nonterminal alignment. */
    public static final Bijection identity = new Bijection();
    /** Bijection tag for the swapped nonterminal alignment (printed as "SWAPPED" by {@link Rule}). */
    public static final Bijection swap = new Bijection();

    /**
     * @param terminalIndexer indexer for terminal labels; unseen labels are added during extraction
     * @param nonTerminalIndexer indexer for nonterminal labels; unseen labels are added during
     *     extraction
     */
    public ScoreFct(Indexer<String> terminalIndexer, Indexer<String> nonTerminalIndexer) {
        this.terminalIndexer = terminalIndexer;
        this.nonTerminalIndexer = nonTerminalIndexer;
    }

    /**
     * Returns the score of {@code rule} under this scoring function.
     *
     * @param rule the rule to score
     * @return its score
     */
    public abstract double score(Rule rule);

    /**
     * Extracts the rule treelet rooted at {@code tree}.
     *
     * <p>{@code tree} becomes the left-hand side. Every internal node listed in {@code children}
     * becomes a nonterminal leaf (its subtree is not expanded). Every other internal node is
     * collapsed, so its descendants attach directly to the nearest retained ancestor. Leaves are
     * kept as terminals.
     *
     * @param tree root of the rule
     * @param children subtrees of {@code tree} that act as the rule's nonterminal slots
     * @return the integer-encoded treelet
     */
    public RuleTreelet extractRuleTreelet(Tree<String> tree, List<Tree<String>> children) {
        Tree<Integer> result = this.extractRuleTreelet(tree, children, true).get(0);
        assert this.consistent(tree, children, result);
        return new RuleTreelet(result);
    }

    /**
     * Sanity check used under {@code assert}. The extracted treelet's yield must equal the full
     * yield, minus the yields of the cut-off children, plus one nonterminal slot per child.
     *
     * @return {@code true} when the sizes agree
     * @throws IllegalStateException when the yield sizes disagree
     */
    private boolean consistent(Tree<String> tree, List<Tree<String>> children, Tree<Integer> out) {
        int childrenYield = 0;
        for (Tree<String> child : children) {
            childrenYield += child.getYield().size();
        }
        int expected = tree.getYield().size() - childrenYield + children.size();
        int actual = out.getYield().size();
        if (actual != expected) {
            throw new IllegalStateException(
                "Inconsistent treelet yield: expected " + expected + " but got " + actual);
        }
        return true;
    }

    /**
     * Recursive worker for {@link #extractRuleTreelet(Tree, List)}.
     *
     * <p>Labels are registered in the indexers in the same order as before (descendants before the
     * root), because indexer codes depend on insertion order.
     *
     * @param isLhs whether {@code tree} is the root (left-hand side) of the rule
     * @return the encoded node(s) that {@code tree} contributes to its parent: a single node for a
     *     terminal, a nonterminal slot or the root, or the spliced-in descendants of a collapsed node
     */
    private List<Tree<Integer>> extractRuleTreelet(
            Tree<String> tree, List<Tree<String>> children, boolean isLhs) {
        String label = tree.getLabel();
        if (tree.isLeaf()) {
            return Collections.singletonList(new Tree<Integer>(this.terminal(label)));
        }
        // NOTE(review): membership relies on Tree.equals; confirm whether identity or structural
        // equality is intended here.
        if (!isLhs && children.contains(tree)) {
            // Nonterminal slot: its subtree is not expanded into this rule.
            return Collections.singletonList(new Tree<Integer>(this.nonTerm(label)));
        }
        ArrayList<Tree<Integer>> expanded = new ArrayList<Tree<Integer>>();
        for (Tree<String> child : tree.getChildren()) {
            expanded.addAll(this.extractRuleTreelet(child, children, false));
        }
        if (isLhs) {
            return Collections.singletonList(new Tree<Integer>(this.nonTerm(label), expanded));
        }
        // Collapsed internal node: splice its descendants into the parent.
        return expanded;
    }

    /** Whether {@code code} encodes a terminal (non-negative and not the closing marker). */
    private boolean isTerminal(int code) {
        return code >= 0 && code != rightParenthesis;
    }

    /** Whether {@code code} encodes a nonterminal (negative and not the opening marker). */
    private boolean isNonTerm(int code) {
        return code < 0 && code != leftParenthesis;
    }

    /** Decodes a terminal code back to its label. */
    private String terminal(int code) {
        assert this.isTerminal(code);
        return this.terminalIndexer.i2o(code);
    }

    /** Decodes a nonterminal code back to its label. */
    private String nonTerm(int code) {
        assert this.isNonTerm(code);
        return this.nonTerminalIndexer.i2o(this.convertNonTerm(code));
    }

    /** Maps an indexer index to a negative nonterminal code and back (the mapping is an involution). */
    private int convertNonTerm(int x) {
        return -x - 1;
    }

    /** Returns the code of {@code terminal}, registering it in the terminal indexer if unseen. */
    private int terminal(String terminal) {
        if (!this.terminalIndexer.containsObject(terminal)) {
            this.terminalIndexer.addToIndex(new String[] {terminal});
        }
        return this.terminalIndexer.o2i(terminal);
    }

    /** Returns the (negative) code of {@code nonTerm}, registering it if unseen. */
    private int nonTerm(String nonTerm) {
        if (!this.nonTerminalIndexer.containsObject(nonTerm)) {
            this.nonTerminalIndexer.addToIndex(new String[] {nonTerm});
        }
        return this.convertNonTerm(this.nonTerminalIndexer.o2i(nonTerm));
    }

    /**
     * Appends a bracketed pre-order linearization of {@code treelet} to {@code out}. Each subtree is
     * written as opening marker, label, children, closing marker.
     */
    private static void canonicalLinearization(Tree<Integer> treelet, List<Integer> out) {
        out.add(leftParenthesis);
        out.add(treelet.getLabel());
        for (Tree<Integer> child : treelet.getChildren()) {
            ScoreFct.canonicalLinearization(child, out);
        }
        out.add(rightParenthesis);
    }

    /** Small demo: extracts and prints one rule treelet from a hard-coded tree. */
    public static void main(String[] args) {
        CollapsedDPScoreFct scoreFct =
            new CollapsedDPScoreFct(new Indexer<String>(), new Indexer<String>());
        Tree<String> tree = DataPrepUtils.lisp2tree("(a (b d (e f g)) (c h i))");
        Tree<String> st2 = tree.getChildren().get(1);
        List<Tree<String>> sts = Arrays.asList(st2);
        System.out.println("full:" + tree.toString());
        System.out.println("s trs:" + sts);
        RuleTreelet ruletreelet = scoreFct.extractRuleTreelet(tree, sts);
        System.out.println(ruletreelet);
    }

    /** Opaque alignment tag; instances are compared by reference (see {@link Rule#equals}). */
    public static class Bijection {
    }

    /**
     * Immutable, hashable encoding of a rule treelet as its canonical linearization, without the
     * root's own outer markers. Index 0 therefore holds the root (left-hand side) code.
     */
    public class RuleTreelet {
        private final int cachedHash;
        private final int[] canonicalLinearization;
        private final int nTerm;
        private final int nNonTerm;

        /** Number of terminal nodes in the treelet. */
        public int getNTerm() {
            return this.nTerm;
        }

        /** Number of nonterminal nodes, excluding the root (left-hand side). */
        public int getNNonTerm() {
            return this.nNonTerm;
        }

        /** Nonterminal-indexer index of the root (left-hand side) label. */
        public int lhsCode() {
            return ScoreFct.this.convertNonTerm(this.canonicalLinearization[0]);
        }

        private RuleTreelet(Tree<Integer> treelet) {
            List<Integer> withMarkers = new ArrayList<Integer>();
            ScoreFct.canonicalLinearization(treelet, withMarkers);
            // Drop the root's own opening/closing markers.
            int length = withMarkers.size() - 2;
            this.canonicalLinearization = new int[length];
            int terminals = 0;
            int nonTerminals = 0;
            for (int i = 0; i < length; ++i) {
                int code = withMarkers.get(i + 1);
                this.canonicalLinearization[i] = code;
                if (ScoreFct.this.isNonTerm(code)) {
                    ++nonTerminals;
                }
                if (ScoreFct.this.isTerminal(code)) {
                    ++terminals;
                }
            }
            this.nTerm = terminals;
            // The root (left-hand side) is not counted among the right-hand-side nonterminals.
            this.nNonTerm = nonTerminals - 1;
            this.cachedHash = Arrays.hashCode(this.canonicalLinearization);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof RuleTreelet)) {
                return false;
            }
            RuleTreelet other = (RuleTreelet) o;
            return this.cachedHash == other.cachedHash
                && Arrays.equals(this.canonicalLinearization, other.canonicalLinearization);
        }

        @Override
        public int hashCode() {
            return this.cachedHash;
        }

        /** Space-separated labels in linearization order, wrapped in one pair of parentheses. */
        @Override
        public String toString() {
            StringBuilder result = new StringBuilder();
            result.append("(");
            for (int code : this.canonicalLinearization) {
                if (code == leftParenthesis || code == rightParenthesis) {
                    continue;
                }
                if (ScoreFct.this.isTerminal(code)) {
                    result.append(ScoreFct.this.terminal(code) + " ");
                } else if (ScoreFct.this.isNonTerm(code)) {
                    result.append(ScoreFct.this.nonTerm(code) + " ");
                } else {
                    throw new IllegalStateException("Unexpected code: " + code);
                }
            }
            result.append(")");
            return result.toString();
        }
    }

    /** A synchronous rule: two treelets with the same number of nonterminals plus an alignment tag. */
    public static class Rule {
        private final RuleTreelet lang1;
        private final RuleTreelet lang2;
        private final Bijection bijection;

        public Rule(RuleTreelet lang1, RuleTreelet lang2, Bijection bijection) {
            assert (lang1 != null && lang2 != null && bijection != null && lang1.nNonTerm == lang2.nNonTerm);
            this.lang1 = lang1;
            this.lang2 = lang2;
            this.bijection = bijection;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Rule)) {
                return false;
            }
            Rule other = (Rule) o;
            // Bijection tags are compared by reference.
            return this.bijection == other.bijection
                && this.lang1.equals(other.lang1)
                && this.lang2.equals(other.lang2);
        }

        @Override
        public int hashCode() {
            int hashCode = this.lang1.hashCode();
            hashCode = 29 * hashCode + this.bijection.hashCode();
            hashCode = 29 * hashCode + this.lang2.hashCode();
            return hashCode;
        }

        @Override
        public String toString() {
            return "" + this.lang1 + " ||| " + this.lang2 + (this.bijection == swap ? " ||| SWAPPED" : "");
        }
    }

    /** Assigns every rule a score of 1.0. */
    public static class UniformScoreFct
    extends ScoreFct {
        public UniformScoreFct(Indexer<String> terminalIndexer, Indexer<String> nonTerminalIndexer) {
            super(terminalIndexer, nonTerminalIndexer);
        }

        @Override
        public double score(Rule rule) {
            return 1.0;
        }
    }

    /**
     * Collapsed Dirichlet-process style score. A rule is scored as
     * {@code (count + alpha0 * prior) / (total + alpha0)}, with counts kept separately for each pair
     * of left-hand-side nonterminals.
     */
    public static final class CollapsedDPScoreFct
    extends ScoreFct {
        @Option
        private static double prior0NonTerm = 0.3333333333333333;
        @Option
        private static double prior1NonTerm = 0.3333333333333333;
        @Option
        private static double geometricLexParam = 0.9;
        @Option
        private static double priorSwap = 0.1;
        @Option
        private static double alpha0 = 100000.0;
        /** Probability assigned to each terminal token (uniform over the terminal vocabulary). */
        private final double termUnigramPr;
        /** totalCounts[lhs1][lhs2]: summed update weight of all rules with that LHS pair. */
        private double[][] totalCounts;
        /** counts[lhs1][lhs2]: per-rule update weight for that LHS pair. */
        private Counter<Rule>[][] counts;

        /**
         * Adds {@code coeff} to the counts of every matched rule. The count tables grow when a rule's
         * LHS nonterminal was registered after this object was constructed.
         */
        public final void update(List<HierarchicalAligner.RuleMatch> rules, double coeff) {
            for (HierarchicalAligner.RuleMatch ruleMatch : rules) {
                Rule rule = ruleMatch.getRule();
                int lhsCode1 = rule.lang1.lhsCode();
                int lhsCode2 = rule.lang2.lhsCode();
                this.ensureCapacity(Math.max(lhsCode1, lhsCode2) + 1);
                this.totalCounts[lhsCode1][lhsCode2] += coeff;
                this.counts[lhsCode1][lhsCode2].incrementCount(rule, coeff);
            }
        }

        /**
         * Returns the predictive score of {@code rule}. LHS pairs that were never updated (including
         * nonterminals registered after construction) have zero counts and fall back to the prior.
         */
        @Override
        public final double score(Rule rule) {
            int lhsCode1 = rule.lang1.lhsCode();
            int lhsCode2 = rule.lang2.lhsCode();
            double count = 0.0;
            double totalCount = 0.0;
            if (lhsCode1 < this.totalCounts.length && lhsCode2 < this.totalCounts.length) {
                count = this.counts[lhsCode1][lhsCode2].getCount(rule);
                totalCount = this.totalCounts[lhsCode1][lhsCode2];
            }
            double result = count + alpha0 * this.prior(rule);
            return result / (totalCount + alpha0);
        }

        /**
         * Base measure for a rule. It multiplies:
         * <ul>
         *   <li>a prior on the number of right-hand-side nonterminals (0, 1 or 2, with a swap
         *       probability applied only to binary rules);</li>
         *   <li>a geometric distribution over the total number of terminals;</li>
         *   <li>a uniform unigram probability per terminal.</li>
         * </ul>
         *
         * @throws IllegalArgumentException if the rule has more than two nonterminals
         */
        private double prior(Rule rule) {
            double result = 1.0;
            int nNonTerm = rule.lang1.nNonTerm;
            if (nNonTerm == 0) {
                result *= prior0NonTerm;
            } else if (nNonTerm == 1) {
                result *= prior1NonTerm;
            } else if (nNonTerm == 2) {
                result *= 1.0 - prior0NonTerm - prior1NonTerm;
                result *= rule.bijection == swap ? priorSwap : 1.0 - priorSwap;
            } else {
                throw new IllegalArgumentException(
                    "Unsupported number of nonterminals (" + nNonTerm + "): " + rule);
            }
            int totalLexLen = rule.lang1.getNTerm() + rule.lang2.getNTerm();
            result *= Math.pow(this.termUnigramPr, totalLexLen);
            // Geometric length model: (1 - p)^len * p.
            result *= Math.pow(1.0 - geometricLexParam, totalLexLen);
            return result * geometricLexParam;
        }

        /** Grows the count tables so that nonterminal codes below {@code size} are addressable. */
        private void ensureCapacity(int size) {
            int oldSize = this.totalCounts.length;
            if (size <= oldSize) {
                return;
            }
            double[][] newTotals = new double[size][size];
            Counter<Rule>[][] newCounts = newCounterTable(size);
            for (int i = 0; i < oldSize; ++i) {
                System.arraycopy(this.totalCounts[i], 0, newTotals[i], 0, oldSize);
                System.arraycopy(this.counts[i], 0, newCounts[i], 0, oldSize);
            }
            this.totalCounts = newTotals;
            this.counts = newCounts;
        }

        /** Allocates a {@code size x size} table of empty counters. */
        @SuppressWarnings("unchecked")
        private static Counter<Rule>[][] newCounterTable(int size) {
            // Generic array creation is not allowed; the raw array only ever holds Counter<Rule>.
            Counter<Rule>[][] table = new Counter[size][size];
            for (int i = 0; i < size; ++i) {
                for (int j = 0; j < size; ++j) {
                    table[i][j] = new Counter<Rule>();
                }
            }
            return table;
        }

        public CollapsedDPScoreFct(Indexer<String> terminalIndexer, Indexer<String> nonTerminalIndexer) {
            super(terminalIndexer, nonTerminalIndexer);
            // Uniform over the terminal vocabulary. Guard against an empty indexer, which would
            // otherwise give an infinite probability.
            this.termUnigramPr = 1.0 / (double) Math.max(1, terminalIndexer.size());
            int nNonTerm = nonTerminalIndexer.size();
            this.totalCounts = new double[nNonTerm][nNonTerm];
            this.counts = newCounterTable(nNonTerm);
        }
    }
}

