/*
 * Decompiled with CFR 0.152.
 */
package biparse;

import fig.basic.Pair;
import fig.prob.SampleUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.io.IO;
import nuts.util.Counter;
import nuts.util.EasyFormat;
import nuts.util.MathUtils;

/**
 * Viterbi-style parser and top-down sampler for a weighted grammar over typed nonterminals
 * ({@code X0}, {@code X1}, ...).
 *
 * <p>Each rule rewrites one nonterminal into a sequence of terminal strings and typed
 * nonterminals (see {@link TypedRule}). {@link #parse(List)} fills a memoized chart with the
 * best product of rule scores for every (nonterminal type, span) it visits. Only right-hand
 * sides with at most two nonterminals are considered, the second one to the right of the
 * first. {@link #derivationString(List)} renders the best derivation stored in that chart.
 * {@link #sample(int, Random)} draws a sentence top-down.
 *
 * <p>The chart lives in instance fields and is replaced by every call to {@code parse}, so an
 * instance should not be shared by concurrent parses.
 */
public class TypedParser {
    /** Score of every rule read by {@link #loadParser(String)}; unknown rules score 0.0. */
    private Map<TypedRule, Double> ruleScores;
    /** Largest left-hand-side type seen while loading; bounds the child types tried by the parser. */
    private int maxNonTermIndex = 0;
    /** Best score per (type, span) cell of the most recent parse. */
    private Map<ChartKey, Double> bestScores;
    /** Rule application achieving the score in {@link #bestScores}, per (type, span) cell. */
    private Map<ChartKey, AppliedRule> bestRules;
    /** Cells currently being computed; requesting one of them again means a unary rule cycle. */
    private Set<ChartKey> inProgress;
    /** Rules grouped by left-hand-side type; built on first use by {@link #sample}. */
    private Map<Integer, List<TypedRule>> rulesByType;

    /**
     * Loads a grammar with one rule per line in the form {@code <score>\t<rule>}, where
     * {@code <rule>} is parsed by {@link TypedRule#parse(String)}. Blank lines are skipped.
     * A rule listed twice keeps the score of its last occurrence.
     *
     * @param path grammar file, read line by line through {@code IO.i}
     * @return a parser holding the loaded rule scores
     * @throws NumberFormatException if a score or a nonterminal index is not a number
     * @throws IllegalArgumentException if a line does not have exactly two tab-separated
     *         fields, or its rule is malformed
     * @throws IOException as declared by this method's signature
     */
    public static TypedParser loadParser(String path) throws NumberFormatException, IOException {
        TypedParser parser = new TypedParser();
        parser.ruleScores = new HashMap<TypedRule, Double>();
        int lineNumber = 0;
        for (String line : IO.i(path)) {
            ++lineNumber;
            if (line.trim().length() == 0) {
                continue;
            }
            String[] tokens = line.split("\\t");
            if (tokens.length != 2) {
                throw new IllegalArgumentException(path + ":" + lineNumber
                        + ": expected '<score>\\t<rule>' but found: " + line);
            }
            double score = Double.parseDouble(tokens[0]);
            TypedRule rule = TypedRule.parse(tokens[1]);
            parser.maxNonTermIndex = Math.max(parser.maxNonTermIndex, rule.type);
            parser.ruleScores.put(rule, score);
        }
        return parser;
    }

    /** Lists every rule as {@code <score>\t<rule>}, one per line, in map iteration order. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<TypedRule, Double> entry : this.ruleScores.entrySet()) {
            builder.append(EasyFormat.fmt(entry.getValue()));
            builder.append("\t");
            builder.append(entry.getKey());
            builder.append("\n");
        }
        return builder.toString();
    }

    /**
     * Parses {@code sentence} with root type 0; same as {@code parse(sentence, 0)}.
     *
     * @param sentence terminals to parse; must not be modified while its chart is in use
     * @return the best derivation score of the sentence as {@code X0}
     */
    public double parse(List<String> sentence) {
        return this.parse(sentence, 0);
    }

    /**
     * Computes the best derivation score of {@code sentence} as nonterminal {@code rootType}.
     * The chart left by any previous parse is discarded.
     *
     * @param sentence terminals to parse; must not be modified while its chart is in use
     * @param rootType type of the nonterminal that must derive the whole sentence
     * @return the best product of rule scores. Rules absent from the grammar score 0.0, so
     *         this is presumably 0.0 when no derivation exists.
     * @throws IllegalStateException if the parse reaches a cycle of rules whose right-hand
     *         side is a single nonterminal covering the whole span
     */
    public double parse(List<String> sentence, int rootType) {
        this.bestScores = new HashMap<ChartKey, Double>();
        this.bestRules = new HashMap<ChartKey, AppliedRule>();
        this.inProgress = new HashSet<ChartKey>();
        return this.recurse(sentence, rootType);
    }

    /**
     * Renders the best derivation of {@code seq} as {@code X0}, from the chart of the last
     * {@link #parse(List)}. Same as {@code derivationString(seq, 0)}.
     */
    public StringBuilder derivationString(List<?> seq) {
        return this.derivationString(seq, 0);
    }

    /**
     * Renders the best derivation of {@code seq} as nonterminal {@code rootType}, using the
     * chart of the last parse. Each node prints as {@code (X<type> : <rhs items> : <score>)},
     * with nonterminal items replaced by their own nested rendering.
     *
     * @param seq a sentence, or a span of it, that the last parse stored in its chart
     * @param rootType nonterminal type of the requested cell
     * @return a new builder holding the rendering
     * @throws IllegalStateException if no parse has been run yet
     * @throws IllegalArgumentException if the chart has no cell for ({@code rootType}, {@code seq})
     */
    public StringBuilder derivationString(List<?> seq, int rootType) {
        if (this.bestRules == null) {
            throw new IllegalStateException("parse(...) must be called before derivationString(...)");
        }
        ChartKey key = new ChartKey(rootType, seq);
        AppliedRule bestRule = this.bestRules.get(key);
        if (bestRule == null) {
            throw new IllegalArgumentException("No chart entry for X" + rootType + " over " + seq);
        }
        double score = this.bestScores.get(key);
        StringBuilder builder = new StringBuilder();
        builder.append("(X" + bestRule.rule.type + " : ");
        int nonTermIndex = 0;
        for (Object item : bestRule.rule.getListRep()) {
            if (item instanceof String) {
                builder.append(item);
            } else {
                // Non-string items are child nonterminal types. The first covers [l, m) and
                // the second [p, q), matching how recurse() builds its rules.
                int childType = (Integer) item;
                List<?> childSpan;
                if (nonTermIndex == 0) {
                    childSpan = seq.subList(bestRule.l, bestRule.m);
                } else if (nonTermIndex == 1) {
                    childSpan = seq.subList(bestRule.p, bestRule.q);
                } else {
                    throw new IllegalStateException("Rule with more than two nonterminals in chart: "
                            + bestRule.rule);
                }
                builder.append(this.derivationString(childSpan, childType));
                ++nonTermIndex;
            }
            builder.append(" ");
        }
        builder.append(": " + EasyFormat.fmt(score) + ")");
        return builder;
    }

    /**
     * Memoized Viterbi recursion. Returns the best score for deriving {@code seq} from
     * nonterminal {@code rootType} and records the winning rule application.
     *
     * <p>Three kinds of candidate are tried:
     * <ol>
     * <li>a rule whose right-hand side is exactly {@code seq};</li>
     * <li>one nonterminal covering {@code [l, m)};</li>
     * <li>two nonterminals covering {@code [l, m)} and {@code [p, q)}, with {@code m <= p}.</li>
     * </ol>
     * Ties keep the earliest candidate.
     */
    private double recurse(List<String> seq, int rootType) {
        ChartKey key = new ChartKey(rootType, seq);
        Double cached = this.bestScores.get(key);
        if (cached != null) {
            return cached;
        }
        // Child spans are strictly shorter than seq, except for a single nonterminal covering
        // all of [0, n). So a cell is re-requested while still being computed only through a
        // cycle of such unary rules. That recursion would never terminate, so fail clearly.
        if (!this.inProgress.add(key)) {
            throw new IllegalStateException("Unary rule cycle through X" + rootType
                    + " over " + seq + " is not supported");
        }
        int n = seq.size();

        // Candidate 1: seq itself is the right-hand side (terminals only).
        AppliedRule argMax = new AppliedRule(this.rule(rootType, seq));
        double max = this.score(argMax.rule);

        // Candidate 2: one nonterminal of type childType replaces [l, m).
        for (int l = 0; l < n; ++l) {
            for (int m = l + 1; m <= n; ++m) {
                for (int childType = 0; childType <= this.maxNonTermIndex; ++childType) {
                    TypedRule rule = this.rule(rootType, childType, seq, l, m);
                    double current = this.score(rule);
                    // Only recurse when the rule itself has a positive score.
                    if (current > 0.0) {
                        current *= this.recurse(seq.subList(l, m), childType);
                    }
                    if (current > max) {
                        max = current;
                        argMax = new AppliedRule(rule, l, m);
                    }
                }
            }
        }

        // Candidate 3: nonterminals replace [l, m) and [p, q), in that order.
        for (int l = 0; l < n - 1; ++l) {
            for (int m = l + 1; m <= n - 1; ++m) {
                for (int p = m; p < n; ++p) {
                    for (int q = p + 1; q <= n; ++q) {
                        for (int childType0 = 0; childType0 <= this.maxNonTermIndex; ++childType0) {
                            for (int childType1 = 0; childType1 <= this.maxNonTermIndex; ++childType1) {
                                TypedRule rule = this.rule(rootType, childType0, childType1, seq, l, m, p, q);
                                double current = this.score(rule);
                                if (current > 0.0) {
                                    current *= this.recurse(seq.subList(l, m), childType0);
                                }
                                if (current > 0.0) {
                                    current *= this.recurse(seq.subList(p, q), childType1);
                                }
                                if (current > max) {
                                    max = current;
                                    argMax = new AppliedRule(rule, l, m, p, q);
                                }
                            }
                        }
                    }
                }
            }
        }

        this.inProgress.remove(key);
        this.bestScores.put(key, max);
        this.bestRules.put(key, argMax);
        return max;
    }

    /** Returns the loaded score of {@code key}, or 0.0 if the grammar has no such rule. */
    private double score(TypedRule key) {
        Double value = this.ruleScores.get(key);
        return value == null ? 0.0 : value;
    }

    /** Rule {@code X<rootType> -> seq}, i.e. all terminals. */
    private TypedRule rule(int rootType, List<String> seq) {
        return new TypedRule(seq, rootType);
    }

    /** Rule obtained from {@code seq} by replacing {@code [l, m)} with nonterminal {@code childType}. */
    private TypedRule rule(int rootType, int childType, List<String> seq, int l, int m) {
        assert (l < m && l >= 0 && m <= seq.size());
        List<Object> key = new ArrayList<Object>(seq.size() - (m - l) + 1);
        key.addAll(seq.subList(0, l));
        key.add(childType);
        key.addAll(seq.subList(m, seq.size()));
        return new TypedRule(key, rootType);
    }

    /**
     * Rule obtained from {@code seq} by replacing {@code [l, m)} with {@code childType0} and
     * {@code [p, q)} with {@code childType1}.
     */
    private TypedRule rule(int rootType, int childType0, int childType1, List<String> seq, int l, int m, int p, int q) {
        assert (0 <= l && l < m && m <= p && p < q && q <= seq.size());
        List<Object> key = new ArrayList<Object>(seq.size() - (m - l) - (q - p) + 2);
        key.addAll(seq.subList(0, l));
        key.add(childType0);
        key.addAll(seq.subList(m, p));
        key.add(childType1);
        key.addAll(seq.subList(q, seq.size()));
        return new TypedRule(key, rootType);
    }

    /**
     * Samples 100000 sentences from {@code X0} and prints the 30 most frequent with their
     * relative frequencies.
     *
     * @param args optional grammar path; defaults to {@code test/astroGrammarOldFmt}
     */
    public static void main(String[] args) throws NumberFormatException, IOException {
        String grammarPath = args.length > 0 ? args[0] : "test/astroGrammarOldFmt";
        TypedParser parser = TypedParser.loadParser(grammarPath);
        Random rand = new Random();
        ArrayList<List<String>> test = new ArrayList<List<String>>();
        for (int i = 0; i < 100000; ++i) {
            test.add(parser.sample(0, rand).getFirst());
        }
        Counter<List<String>> counter = new Counter<List<String>>();
        counter.incrementAll(test, 1.0);
        counter.normalize();
        System.out.println(counter.toString(30));
    }

    /**
     * Samples a sentence top-down from nonterminal {@code rootType}.
     *
     * <p>Among the rules with that left-hand side, one is drawn with
     * {@code SampleUtils.sampleMultinomial}, using the rule scores as the distribution. Every
     * nonterminal on its right-hand side is then expanded recursively. There is no depth
     * limit: recursion stops only when every expansion has reached terminals.
     *
     * @param rootType nonterminal type to expand
     * @param rand source of randomness
     * @return the sampled terminals, paired with the product of the scores of the rules used
     * @throws IllegalStateException if no rule has left-hand side {@code rootType}, or if
     *         {@code MathUtils.isProb} rejects the scores of those rules
     */
    public Pair<List<String>, Double> sample(int rootType, Random rand) {
        List<TypedRule> candidates = this.rulesWithType(rootType);
        if (candidates.isEmpty()) {
            throw new IllegalStateException("No rules with left-hand side X" + rootType);
        }
        double[] prs = new double[candidates.size()];
        for (int i = 0; i < prs.length; ++i) {
            prs[i] = this.ruleScores.get(candidates.get(i));
        }
        if (!MathUtils.isProb(prs)) {
            throw new IllegalStateException("Scores of the rules for X" + rootType
                    + " do not form a probability distribution");
        }
        TypedRule sample = candidates.get(SampleUtils.sampleMultinomial(rand, prs));
        double score = this.ruleScores.get(sample);
        List<String> result = new ArrayList<String>();
        for (Object item : sample.getListRep()) {
            if (item instanceof String) {
                result.add((String) item);
                continue;
            }
            int childType = (Integer) item;
            Pair<List<String>, Double> recursion = this.sample(childType, rand);
            result.addAll(recursion.getFirst());
            score *= recursion.getSecond().doubleValue();
        }
        return new Pair<List<String>, Double>(result, score);
    }

    /**
     * Returns the rules whose left-hand side is {@code type}, in {@code ruleScores} iteration
     * order. The index is built once on first call, so later calls avoid rescanning every rule.
     */
    private List<TypedRule> rulesWithType(int type) {
        if (this.rulesByType == null) {
            Map<Integer, List<TypedRule>> index = new HashMap<Integer, List<TypedRule>>();
            for (TypedRule rule : this.ruleScores.keySet()) {
                List<TypedRule> bucket = index.get(rule.type);
                if (bucket == null) {
                    bucket = new ArrayList<TypedRule>();
                    index.put(rule.type, bucket);
                }
                bucket.add(rule);
            }
            this.rulesByType = index;
        }
        List<TypedRule> rules = this.rulesByType.get(type);
        return rules == null ? Collections.<TypedRule>emptyList() : rules;
    }

    /**
     * A grammar rule {@code X<type> -> item item ...}. Each item is either a terminal
     * {@code String} or an {@code Integer} naming a nonterminal type. Equality and hashing
     * cover both the type and the item list.
     */
    public static class TypedRule {
        private final int type;
        private final List<Object> productions;
        public static final char nonTermPrefix = 'X';

        /**
         * Creates a rule with a private copy of {@code productions}.
         *
         * @param productions right-hand side items; {@code String} terminals or {@code Integer} types
         * @param type left-hand-side nonterminal type
         */
        public TypedRule(List<?> productions, int type) {
            this.type = type;
            this.productions = new ArrayList<Object>(productions);
        }

        public int getType() {
            return this.type;
        }

        /** Counts the nonterminal ({@code Integer}) items on the right-hand side. */
        public int getNumberOfNonTerminals() {
            int count = 0;
            for (Object item : this.productions) {
                if (item instanceof Integer) {
                    ++count;
                }
            }
            return count;
        }

        /** Read-only view of the right-hand side items. */
        public List<Object> getListRep() {
            return Collections.unmodifiableList(this.productions);
        }

        /**
         * Parses {@code X<type> -> tok tok ...}. Whitespace-separated tokens starting with
         * {@link #nonTermPrefix} become nonterminal types; all others become terminals. Only
         * the first {@code ->} separates the sides, so this inverts {@link #toString()}.
         *
         * @throws IllegalArgumentException if the arrow, the left-hand side or the
         *         right-hand side is missing
         * @throws NumberFormatException if a nonterminal index is not an integer; this
         *         includes a terminal that happens to start with {@code X}
         */
        public static TypedRule parse(String description) {
            String[] separateArrow = description.trim().split("\\s*->\\s*", 2);
            if (separateArrow.length != 2) {
                throw new IllegalArgumentException("Expected '<lhs> -> <rhs>' but found: " + description);
            }
            String lhs = separateArrow[0];
            if (lhs.length() < 2) {
                throw new IllegalArgumentException("Missing left-hand side nonterminal in rule: " + description);
            }
            int type = Integer.parseInt(lhs.substring(1));
            String rhs = separateArrow[1].trim();
            if (rhs.length() == 0) {
                throw new IllegalArgumentException("Empty right-hand side in rule: " + description);
            }
            String[] tokens = rhs.split("\\s+");
            List<Object> rule = new ArrayList<Object>(tokens.length);
            for (String token : tokens) {
                if (token.charAt(0) == nonTermPrefix) {
                    rule.add(Integer.valueOf(Integer.parseInt(token.substring(1))));
                } else {
                    rule.add(token);
                }
            }
            return new TypedRule(rule, type);
        }

        /** Formats as {@code X<type> -> item item ...}, with nonterminals prefixed by {@code X}. */
        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append(nonTermPrefix).append(this.type).append(" -> ");
            for (int i = 0; i < this.productions.size(); ++i) {
                if (i > 0) {
                    builder.append(' ');
                }
                Object item = this.productions.get(i);
                if (item instanceof Integer) {
                    builder.append(nonTermPrefix);
                }
                builder.append(item);
            }
            return builder.toString();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof TypedRule)) {
                return false;
            }
            TypedRule other = (TypedRule) o;
            return this.type == other.type && this.productions.equals(other.productions);
        }

        @Override
        public int hashCode() {
            // Keep this exact formula: HashMap iteration order depends on it, and that order
            // determines TypedParser.toString() output and the candidate order in sample().
            return 29 * this.productions.hashCode() + this.type;
        }
    }

    /**
     * A rule chosen for a chart cell, plus the sub-spans its nonterminals cover. The first
     * nonterminal covers {@code [l, m)} and the second {@code [p, q)}; unused bounds are -1.
     */
    public static class AppliedRule {
        private final TypedRule rule;
        private final int l;
        private final int m;
        private final int p;
        private final int q;

        private AppliedRule(TypedRule rule) {
            this(rule, -1, -1, -1, -1);
        }

        private AppliedRule(TypedRule rule, int l, int m) {
            this(rule, l, m, -1, -1);
        }

        private AppliedRule(TypedRule rule, int l, int m, int p, int q) {
            this.rule = rule;
            this.l = l;
            this.m = m;
            this.p = p;
            this.q = q;
        }
    }

    /**
     * Chart index: a nonterminal type paired with the span it must derive. Spans are compared
     * by content, so equal substrings at different positions share one cell.
     */
    private static final class ChartKey {
        private final int type;
        private final List<?> span;
        private final int hash;

        ChartKey(int type, List<?> span) {
            this.type = type;
            this.span = span;
            // Cached because each lookup and insert hashes the key. This assumes the span
            // (often a subList view of the sentence) is not modified while the chart is in use.
            this.hash = 31 * span.hashCode() + type;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ChartKey)) {
                return false;
            }
            ChartKey other = (ChartKey) o;
            return this.type == other.type && this.span.equals(other.span);
        }

        @Override
        public int hashCode() {
            return this.hash;
        }
    }
}

