/*
 * Decompiled with CFR 0.152.
 */
package biparse;

import fig.basic.Pair;
import fig.prob.SampleUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.io.IO;
import nuts.util.EasyFormat;
import nuts.util.MathUtils;

/**
 * Viterbi-style parser and sampler for a weighted grammar over a single
 * nonterminal symbol.
 *
 * <p>A grammar maps {@link Rule}s to scores. A rule is a sequence of terminal
 * tokens and nonterminal slots. A slot is {@code null} in the rule's list
 * representation and is written as {@value Rule#nonTerm} in grammar files.
 * Both {@link #parse} and {@link #sample} score a derivation as the product
 * of the scores of the rules it uses.
 *
 * <p>{@link #parse} only builds candidate rules with zero, one, or two
 * nonterminal slots. Rules with more slots can be sampled but are never
 * matched when parsing.
 *
 * <p>{@link #parse} overwrites the memo tables held in this instance, so a
 * single instance should not be used to parse from several threads at once.
 */
public class Parser {
    /** Score of every known rule; rules absent from the map score 0. */
    private Map<Rule, Double> ruleScores;
    /** Memo from (sub)sequence content to its best derivation score. */
    private Map<List<String>, Double> bestScores;
    /** Memo from (sub)sequence content to the rule application achieving {@link #bestScores}. */
    private Map<List<String>, AppliedRule> bestRules;
    /** Lazily built, validated sampling cache: rules in a fixed order ... */
    private List<Rule> listOfRules = null;
    /** ... and their scores, index-aligned with {@link #listOfRules}. */
    private double[] prs;

    /**
     * Loads a grammar from a text file with one rule per line.
     *
     * <p>Each line has the form {@code <score><TAB><rule>}, where
     * {@code <rule>} is parsed by {@link Rule#parse}. Blank lines are skipped.
     * If the same rule appears more than once, the last score wins.
     *
     * @param path file to read (via {@code IO.i})
     * @return a parser over the loaded rule scores
     * @throws NumberFormatException if a score field is not a valid double
     * @throws IOException if reading fails (presumably surfaced by {@code IO.i} — confirm)
     * @throws IllegalArgumentException if a non-blank line does not have exactly
     *     two tab-separated fields, or its rule description is empty
     */
    public static Parser loadParser(String path) throws NumberFormatException, IOException {
        Parser parser = new Parser();
        parser.ruleScores = new HashMap<Rule, Double>();
        int lineNumber = 0;
        for (String line : IO.i(path)) {
            ++lineNumber;
            if (line.trim().length() == 0) {
                // Tolerate blank lines (e.g. a trailing empty line) instead of
                // failing the whole load.
                continue;
            }
            String[] tokens = line.split("\\t");
            if (tokens.length != 2) {
                throw new IllegalArgumentException("Malformed grammar line " + path + ":" + lineNumber
                        + " (expected <score><TAB><rule>): " + line);
            }
            double score = Double.parseDouble(tokens[0]);
            Rule rule = Rule.parse(tokens[1]);
            parser.ruleScores.put(rule, score);
        }
        return parser;
    }

    /**
     * Renders the grammar with one rule per line: the score formatted by
     * {@code EasyFormat.fmt}, a tab, then the rule's {@link Rule#toString}.
     * Lines follow the iteration order of the underlying {@link HashMap}.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<Rule, Double> entry : this.ruleScores.entrySet()) {
            builder.append(EasyFormat.fmt(entry.getValue()));
            builder.append("\t");
            builder.append(entry.getKey().toString());
            builder.append("\n");
        }
        return builder.toString();
    }

    /**
     * Computes the best (maximum product of rule scores) derivation of a
     * sentence and memoizes it, so that {@link #derivationString} can print it.
     *
     * @param sentence the token sequence to parse; it is copied, so later
     *     changes by the caller do not corrupt the memo tables
     * @return the best derivation score, or 0 if no derivation has a positive score
     */
    public double parse(List<String> sentence) {
        this.bestScores = new HashMap<List<String>, Double>();
        this.bestRules = new HashMap<List<String>, AppliedRule>();
        // Memo keys are this list and subList views of it; use a private copy
        // so they cannot be invalidated by the caller mutating `sentence`.
        List<String> tokens = new ArrayList<String>(sentence);
        return this.recurse(tokens);
    }

    /**
     * Renders the best derivation of {@code seq} found by the last
     * {@link #parse} call as a bracketed tree.
     *
     * <p>Each node is {@code (item item ... : score)}. Every item (a terminal
     * token or a nested child node) is followed by a space. {@code score} is
     * the best score of the node's span, formatted by {@code EasyFormat.fmt}.
     *
     * @param seq the parsed sentence, or a sub-span reached by its best derivation
     * @return a builder holding the rendering
     * @throws IllegalStateException if {@code seq} has no memoized derivation
     *     (e.g. {@link #parse} was not called first)
     */
    public StringBuilder derivationString(List<String> seq) {
        if (this.bestRules == null || !this.bestRules.containsKey(seq)) {
            throw new IllegalStateException("No memoized derivation for " + seq + " (call parse first)");
        }
        AppliedRule bestRule = this.bestRules.get(seq);
        double score = this.bestScores.get(seq);
        StringBuilder builder = new StringBuilder("(");
        int nonTermIndex = 0;
        for (String item : bestRule.rule.getListRep()) {
            if (item != null) {
                builder.append(item);
            } else if (nonTermIndex == 0) {
                // First slot covers seq[l, m).
                builder.append(this.derivationString(seq.subList(bestRule.l, bestRule.m)));
                ++nonTermIndex;
            } else if (nonTermIndex == 1) {
                // Second slot covers seq[p, q).
                builder.append(this.derivationString(seq.subList(bestRule.p, bestRule.q)));
                ++nonTermIndex;
            } else {
                throw new IllegalStateException("Applied rule has more than two nonterminals: " + bestRule.rule);
            }
            builder.append(" ");
        }
        builder.append(": " + EasyFormat.fmt(score) + ")");
        return builder;
    }

    /**
     * Memoized max-product search over derivations of {@code seq}.
     *
     * <p>Candidates are:
     * <ul>
     *   <li>the all-terminal rule equal to {@code seq};</li>
     *   <li>every rule with one nonterminal covering a span {@code [l, m)};</li>
     *   <li>every rule with two nonterminals covering disjoint, ordered spans
     *       {@code [l, m)} and {@code [p, q)}.</li>
     * </ul>
     * On ties the earliest candidate is kept.
     *
     * @return the best score for {@code seq}; also stored in {@link #bestScores}
     *     and {@link #bestRules}, keyed by the content of {@code seq}
     */
    private double recurse(List<String> seq) {
        Double cached = this.bestScores.get(seq);
        if (cached != null) {
            return cached;
        }
        int n = seq.size();
        // Baseline: seq produced directly by an all-terminal rule (score 0
        // when the grammar has no such rule).
        AppliedRule argMax = new AppliedRule(this.rule(seq));
        double max = this.score(argMax.rule);

        // Rules with one nonterminal covering seq[l, m).
        for (int l = 0; l < n; ++l) {
            for (int m = l + 1; m <= n; ++m) {
                if (l == 0 && m == n) {
                    // This candidate is the unary rule "X", whose child is seq
                    // itself. seq is not memoized yet, so recursing here would
                    // never terminate. Applying it only multiplies a derivation
                    // of seq by that rule's score, so with scores <= 1 it cannot
                    // raise the maximum; skip it.
                    continue;
                }
                Rule rule = this.rule(seq, l, m);
                double current = this.score(rule);
                // Only explore the child span when the rule can contribute.
                if (current > 0.0) {
                    current *= this.recurse(seq.subList(l, m));
                }
                if (current > max) {
                    max = current;
                    argMax = new AppliedRule(rule, l, m);
                }
            }
        }

        // Rules with two nonterminals covering seq[l, m) and seq[p, q), m <= p.
        for (int l = 0; l < n - 1; ++l) {
            for (int m = l + 1; m <= n - 1; ++m) {
                for (int p = m; p < n; ++p) {
                    for (int q = p + 1; q <= n; ++q) {
                        Rule rule = this.rule(seq, l, m, p, q);
                        double current = this.score(rule);
                        if (current > 0.0) {
                            current *= this.recurse(seq.subList(l, m));
                        }
                        if (current > 0.0) {
                            current *= this.recurse(seq.subList(p, q));
                        }
                        if (current > max) {
                            max = current;
                            argMax = new AppliedRule(rule, l, m, p, q);
                        }
                    }
                }
            }
        }
        this.bestScores.put(seq, max);
        this.bestRules.put(seq, argMax);
        return max;
    }

    /** Returns the grammar score of {@code key}, or 0 if the rule is unknown. */
    private double score(Rule key) {
        Double value = this.ruleScores.get(key);
        return value == null ? 0.0 : value;
    }

    /** Builds the all-terminal rule whose tokens are exactly {@code seq}. */
    private Rule rule(List<String> seq) {
        return new Rule(seq);
    }

    /**
     * Builds the rule obtained by replacing {@code seq[l, m)} with a single
     * nonterminal slot ({@code null}).
     */
    private Rule rule(List<String> seq, int l, int m) {
        assert (l < m && l >= 0 && m <= seq.size());
        List<String> key = new ArrayList<String>(seq.subList(0, l));
        key.add(null);
        key.addAll(seq.subList(m, seq.size()));
        return new Rule(key);
    }

    /**
     * Builds the rule obtained by replacing {@code seq[l, m)} and
     * {@code seq[p, q)} each with a nonterminal slot ({@code null}); the
     * tokens in {@code seq[m, p)} stay between the two slots.
     */
    private Rule rule(List<String> seq, int l, int m, int p, int q) {
        assert (0 <= l && l < m && m <= p && p < q && q <= seq.size());
        List<String> key = new ArrayList<String>(seq.subList(0, l));
        key.add(null);
        key.addAll(seq.subList(m, p));
        key.add(null);
        key.addAll(seq.subList(q, seq.size()));
        return new Rule(key);
    }

    /**
     * Sanity check: samples 10 sentences from the grammar, parses each, and
     * compares the Viterbi score with the score of the sampled derivation.
     *
     * <ul>
     *   <li>{@code "!!!!!!!!!"} flags an optimum below the sampled derivation,
     *       which should never happen for a correct parser.</li>
     *   <li>{@code "*********"} flags an optimum strictly above it.</li>
     * </ul>
     *
     * @param args optional: {@code args[0]} is the grammar path (default {@code test/grammar2})
     */
    public static void main(String[] args) throws NumberFormatException, IOException {
        String path = args.length > 0 ? args[0] : "test/grammar2";
        Parser parser = Parser.loadParser(path);
        Random rand = new Random();
        for (int i = 0; i < 10; ++i) {
            Pair<List<String>, Double> sample = parser.sample(rand);
            System.out.println("Sentence: " + sample.getFirst());
            System.out.println("Sampled score: " + sample.getSecond());
            double bestScore = parser.parse(sample.getFirst());
            System.out.println("Optimal score: " + bestScore);
            System.out.println("Optimal derivation: " + parser.derivationString(sample.getFirst()));
            if (bestScore < sample.getSecond()) {
                System.out.println("!!!!!!!!!");
            }
            if (bestScore > sample.getSecond()) {
                System.out.println("*********");
            }
        }
    }

    /**
     * Samples a sentence top-down.
     *
     * <p>A rule is drawn with probability given by its score (via
     * {@code SampleUtils.sampleMultinomial}). Each nonterminal slot is then
     * expanded recursively, left to right.
     *
     * <p>Expansion stops only when only terminal rules are drawn; for grammars
     * whose expected number of slots per rule is high this may recurse very
     * deeply (NOTE(review): no depth limit).
     *
     * @param rand randomness source
     * @return the sampled token sequence and the product of the scores of all
     *     rules used
     * @throws RuntimeException if the rule scores fail {@code MathUtils.isProb}
     */
    public Pair<List<String>, Double> sample(Random rand) {
        if (this.listOfRules == null) {
            List<Rule> rules = new ArrayList<Rule>(this.ruleScores.keySet());
            double[] probabilities = new double[rules.size()];
            for (int i = 0; i < probabilities.length; ++i) {
                probabilities[i] = this.ruleScores.get(rules.get(i));
            }
            if (!MathUtils.isProb(probabilities)) {
                throw new RuntimeException("Rule scores are not a valid probability vector (MathUtils.isProb failed)");
            }
            // Publish the cache only after validation so that a failed call
            // does not let later calls sample from an unchecked vector.
            this.prs = probabilities;
            this.listOfRules = rules;
        }
        Rule sampled = this.listOfRules.get(SampleUtils.sampleMultinomial(rand, this.prs));
        double score = this.ruleScores.get(sampled);
        List<String> result = new ArrayList<String>();
        for (String item : sampled.getListRep()) {
            if (item == null) {
                Pair<List<String>, Double> child = this.sample(rand);
                result.addAll(child.getFirst());
                score *= child.getSecond();
            } else {
                result.add(item);
            }
        }
        return new Pair<List<String>, Double>(result, score);
    }

    /**
     * An immutable grammar rule: a sequence of terminal tokens and
     * nonterminal slots, where a slot is represented by {@code null}.
     * Equality and hashing are by the token sequence.
     */
    public static class Rule {
        /** Textual spelling of a nonterminal slot in grammar files and {@link #toString}. */
        public static final String nonTerm = "X";

        private final List<String> rule;

        /** Creates a rule from a defensive copy of {@code rule} ({@code null} entries are slots). */
        public Rule(List<String> rule) {
            this.rule = new ArrayList<String>(rule);
        }

        /** Returns how many nonterminal slots ({@code null} entries) the rule has. */
        public int getNumberOfNonTerminals() {
            int count = 0;
            for (String token : this.rule) {
                if (token == null) {
                    ++count;
                }
            }
            return count;
        }

        /** Returns an unmodifiable view of the rule's tokens ({@code null} marks a slot). */
        public List<String> getListRep() {
            return Collections.unmodifiableList(this.rule);
        }

        /**
         * Parses a whitespace-separated rule description. The token
         * {@value #nonTerm} becomes a nonterminal slot and every other token
         * is a terminal. Leading and trailing whitespace is ignored.
         *
         * @throws IllegalArgumentException if the description is empty or
         *     whitespace-only
         */
        public static Rule parse(String description) {
            String trimmed = description.trim();
            if (trimmed.length() == 0) {
                throw new IllegalArgumentException("Empty rule description");
            }
            String[] tokens = trimmed.split("\\s+");
            List<String> rule = new ArrayList<String>(tokens.length);
            for (String token : tokens) {
                rule.add(nonTerm.equals(token) ? null : token);
            }
            return new Rule(rule);
        }

        /** Renders the rule as space-separated tokens, spelling slots as {@value #nonTerm}. */
        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            for (int i = 0; i < this.rule.size(); ++i) {
                if (i > 0) {
                    builder.append(" ");
                }
                String item = this.rule.get(i);
                builder.append(item == null ? nonTerm : item);
            }
            return builder.toString();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Rule)) {
                return false;
            }
            return this.rule.equals(((Rule) o).rule);
        }

        @Override
        public int hashCode() {
            return this.rule.hashCode();
        }
    }

    /**
     * A rule applied to a span of a sequence. Slot spans are half-open
     * {@code [l, m)} (first slot) and {@code [p, q)} (second slot); unused
     * bounds are -1.
     */
    public static class AppliedRule {
        private final Rule rule;
        private final int l;
        private final int m;
        private final int p;
        private final int q;

        /** An application of an all-terminal rule (no child spans). */
        private AppliedRule(Rule rule) {
            this(rule, -1, -1, -1, -1);
        }

        /** An application with one child span {@code [l, m)}. */
        private AppliedRule(Rule rule, int l, int m) {
            this(rule, l, m, -1, -1);
        }

        /** An application with child spans {@code [l, m)} and {@code [p, q)}. */
        private AppliedRule(Rule rule, int l, int m, int p, int q) {
            this.rule = rule;
            this.l = l;
            this.m = m;
            this.p = p;
            this.q = q;
        }
    }
}

