/*
 * Decompiled with CFR 0.152.
 */
package biparse;

import biparse.Utils;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import nuts.io.IO;
import nuts.util.Counter;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import nuts.util.Tree;

/**
 * A probabilistic context-free grammar with binary rules {@code A -> B C} (nonterminal children)
 * and unary rules {@code A -> t} (terminal child).
 *
 * <p>Symbols are identified by dense integer indices managed by a {@link GrammarIndex}. Binary
 * rule probabilities live in a dense {@code nNonTerm x nNonTerm x nNonTerm} array; unary rule
 * probabilities are stored sparsely and read as 0 when absent.
 *
 * <p>Text format read by {@link #loadGrammar(String)} and produced by {@link #toString()}:
 * <pre>{@code
 * START SYMBOL:S
 * S -> NP VP<TAB>0.9
 * NP -> dog<TAB>0.5
 * }</pre>
 * Blank lines and lines whose first non-blank character is {@code #} are ignored.
 */
public final class Grammar {
    /** Binary rule probabilities, indexed as {@code binPrs[parent][child1][child2]}. */
    private final double[][][] binPrs;
    /** Sparse unary rule probabilities keyed by (parent nonterminal, child terminal); missing means 0. */
    private final Map<Pair<Integer, Integer>, Double> unPrs = new HashMap<Pair<Integer, Integer>, Double>();
    /** Symbol index of this grammar; also holds the start symbol. */
    private final GrammarIndex idx;
    /** Symbol syntax accepted by the text format: any run of non-space characters. */
    public static final String allowedSymbolsRegex = "[^ ]+";
    /** A rule line {@code LHS -> RHS1 [RHS2]<TAB>probability}; group 3 is null for unary rules. */
    public static final Pattern rule = Pattern.compile("^([^ ]+)\\s+[-][>]\\s+([^ ]+)(?:\\s+([^ ]+))?\\t(.*)$");
    /** A comment line (first non-blank char is '#') or a blank line; meant to be used with {@code find()}. */
    public static final Pattern ignoredLines = Pattern.compile("(^\\s*[#]|^\\s*$)");
    /** The start-symbol declaration line, e.g. {@code START SYMBOL:S}. */
    public static final Pattern rootSpec = Pattern.compile("START SYMBOL[:]\\s*([^ ]+)");

    /**
     * Creates a grammar over the symbols of {@code idx} with every rule probability set to 0.
     *
     * @param idx symbol index; its current number of nonterminals fixes the binary table size
     */
    public Grammar(GrammarIndex idx) {
        this.idx = idx;
        this.binPrs = new double[idx.nNonTerm()][idx.nNonTerm()][idx.nNonTerm()];
    }

    /**
     * Checks that, for every nonterminal, the probabilities of all of its rules (unary and binary
     * together) sum to 1, within the tolerance of {@code MathUtils.close}.
     *
     * <p>This is the same per-nonterminal distribution that {@link #generate(Random)} samples from.
     *
     * @return true if every nonterminal's rule probabilities form a distribution
     */
    public boolean isProb() {
        final int nNonTerm = this.idx.nNonTerm();
        final int nTerm = this.idx.nTerm();
        for (int parent = 0; parent < nNonTerm; ++parent) {
            double sum = 0.0;
            for (int child = 0; child < nTerm; ++child) {
                sum += this.unPr(parent, child);
            }
            for (int child1 = 0; child1 < nNonTerm; ++child1) {
                for (int child2 = 0; child2 < nNonTerm; ++child2) {
                    sum += this.binPr(parent, child1, child2);
                }
            }
            if (!MathUtils.close(sum, 1.0)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the probability of the binary rule {@code parent -> child1 child2}.
     *
     * @param parent nonterminal index of the left-hand side
     * @param child1 nonterminal index of the first child
     * @param child2 nonterminal index of the second child
     * @return the stored probability (0 if never set)
     */
    public double binPr(int parent, int child1, int child2) {
        return this.binPrs[parent][child1][child2];
    }

    /**
     * Sets the probability of the binary rule {@code parent -> child1 child2}.
     *
     * @param parent nonterminal index of the left-hand side
     * @param child1 nonterminal index of the first child
     * @param child2 nonterminal index of the second child
     * @param value  the probability to store
     */
    public void setBinPr(int parent, int child1, int child2, double value) {
        this.binPrs[parent][child1][child2] = value;
    }

    /**
     * Returns the probability of the unary rule {@code parent -> child}.
     *
     * @param parent nonterminal index of the left-hand side
     * @param child  terminal index of the right-hand side
     * @return the stored probability, or 0 if the rule was never set
     * @throws RuntimeException if either index is out of range
     */
    public double unPr(int parent, int child) {
        if (parent < 0 || parent >= this.idx.nNonTerm() || child < 0 || child >= this.idx.nTerm()) {
            throw new RuntimeException("Unary rule index out of range: parent=" + parent + " (nNonTerm="
                    + this.idx.nNonTerm() + "), child=" + child + " (nTerm=" + this.idx.nTerm() + ")");
        }
        Double result = this.unPrs.get(new Pair<Integer, Integer>(parent, child));
        return result == null ? 0.0 : result;
    }

    /**
     * Sets the probability of the unary rule {@code parent -> child}.
     *
     * @param parent nonterminal index of the left-hand side
     * @param child  terminal index of the right-hand side
     * @param value  the probability to store
     */
    public void setUnPr(int parent, int child, double value) {
        this.unPrs.put(new Pair<Integer, Integer>(parent, child), value);
    }

    /** @return the number of nonterminals in this grammar's index */
    public int nNonTerm() {
        return this.idx.nNonTerm();
    }

    /** @return the number of terminals in this grammar's index */
    public int nTerm() {
        return this.idx.nTerm();
    }

    /** @return the nonterminal index of the start symbol */
    public int startSymbol() {
        return this.idx.startSymbol;
    }

    /** @return the symbol index backing this grammar */
    public GrammarIndex getIndex() {
        return this.idx;
    }

    /**
     * Renders the grammar in the text format accepted by {@link #loadGrammar(String)}. Only rules
     * with probability strictly greater than 0 are written.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("START SYMBOL:").append(this.idx.nonTermIndex2String(this.startSymbol())).append("\n");
        for (int parent = 0; parent < this.nNonTerm(); ++parent) {
            for (int term = 0; term < this.nTerm(); ++term) {
                if (this.unPr(parent, term) > 0.0) {
                    builder.append(this.toString(parent, term)).append("\n");
                }
            }
            for (int child1 = 0; child1 < this.nNonTerm(); ++child1) {
                for (int child2 = 0; child2 < this.nNonTerm(); ++child2) {
                    if (this.binPr(parent, child1, child2) > 0.0) {
                        builder.append(this.toString(parent, child1, child2)).append("\n");
                    }
                }
            }
        }
        return builder.toString();
    }

    /** @return the unary rule {@code parent -> child} followed by a tab and its probability */
    public String toString(int parent, int child) {
        return this.idx.toString(parent, child) + "\t" + this.unPr(parent, child);
    }

    /** @return the binary rule {@code parent -> child1 child2} followed by a tab and its probability */
    public String toString(int parent, int child1, int child2) {
        return this.idx.toString(parent, child1, child2) + "\t" + this.binPr(parent, child1, child2);
    }

    /** @return the name of the nonterminal with the given index */
    public String nonTermIndex2String(int index) {
        return this.idx.nonTermIndex2String(index);
    }

    /** @return the name of the terminal with the given index */
    public String termIndex2String(int index) {
        return this.idx.termIndex2String(index);
    }

    /** @return the index of the named nonterminal, as reported by the underlying {@code Indexer} */
    public int nonTermString2Index(String string) {
        return this.idx.nonTermString2Index(string);
    }

    /** @return the index of the named terminal, as reported by the underlying {@code Indexer} */
    public int termString2Index(String string) {
        return this.idx.termString2Index(string);
    }

    /**
     * Ad-hoc driver. It loads {@code test/astroGrammar}, samples 100000 sentences, normalizes the
     * sentence counts, and prints them via {@code Counter.toString(30)}.
     */
    public static void main(String[] args) throws IOException {
        Grammar g = Grammar.loadGrammar("test/astroGrammar");
        Random rand = new Random();
        List<List<String>> test = new ArrayList<List<String>>();
        for (int i = 0; i < 100000; ++i) {
            test.add(g.generate(rand));
        }
        Counter<List<String>> counter = new Counter<List<String>>();
        counter.incrementAll(test, 1.0);
        counter.normalize();
        System.out.println(counter.toString(30));
    }

    /**
     * Samples a sentence (a sequence of terminal names) top-down from the start symbol.
     *
     * @param rand source of randomness
     * @return the sampled terminal yield
     * @throws RuntimeException if a visited nonterminal's rule probabilities are not a distribution
     */
    public List<String> generate(Random rand) {
        return this.generate(this.idx.startSymbol, rand);
    }

    /**
     * Samples one rule for {@code symbol} and recursively expands it.
     *
     * <p>The sampling distribution has layout {@code [unary rules (nTerm) | binary rules
     * (nNonTerm * nNonTerm)]}. The binary rule {@code symbol -> c1 c2} sits at offset
     * {@code nTerm + c1 + nNonTerm * c2}.
     */
    private List<String> generate(int symbol, Random rand) {
        final int nTerm = this.nTerm();
        final int nNonTerm = this.nNonTerm();
        double[] prs = new double[nTerm + nNonTerm * nNonTerm];
        // Unary rules range over terminals, so this section must be nTerm long (not nNonTerm).
        for (int term = 0; term < nTerm; ++term) {
            prs[term] = this.unPr(symbol, term);
        }
        for (int child1 = 0; child1 < nNonTerm; ++child1) {
            for (int child2 = 0; child2 < nNonTerm; ++child2) {
                prs[nTerm + child1 + nNonTerm * child2] = this.binPrs[symbol][child1][child2];
            }
        }
        if (!MathUtils.isProb(prs)) {
            throw new RuntimeException("Rule probabilities of nonterminal "
                    + this.nonTermIndex2String(symbol) + " do not form a distribution");
        }
        int sample = SampleUtils.sampleMultinomial(rand, prs);
        List<String> result = new ArrayList<String>();
        if (sample < nTerm) {
            result.add(this.termIndex2String(sample));
        } else {
            int binIndex = sample - nTerm;
            result.addAll(this.generate(binIndex % nNonTerm, rand));
            result.addAll(this.generate(binIndex / nNonTerm, rand));
        }
        return result;
    }

    /**
     * Builds a symbol index from a collection of trees. Terminal and nonterminal labels are
     * collected through {@code Utils.terminals} and {@code Utils.nonTerminals}, and the shared
     * root label becomes the start symbol.
     *
     * @param trees trees that must all share the same root label
     * @return a new index over the collected symbols
     * @throws RuntimeException if two trees have different root labels
     */
    public static GrammarIndex loadGrammarIndex(Iterable<Tree<String>> trees) {
        HashSet<String> terminals = new HashSet<String>();
        HashSet<String> nonTerminals = new HashSet<String>();
        String start = null;
        for (Tree<String> tree : trees) {
            if (start != null && !start.equals(tree.getLabel())) {
                throw new RuntimeException("All trees should have the same root");
            }
            start = tree.getLabel();
            Utils.terminals(tree, terminals);
            Utils.nonTerminals(tree, nonTerminals);
        }
        return new GrammarIndex(nonTerminals, terminals, start);
    }

    /**
     * Like {@link #loadGrammar(String)}, but wraps any {@link IOException} in an unchecked
     * exception. The original exception is preserved as the cause.
     *
     * @param path grammar file path
     * @return the loaded grammar
     * @throws RuntimeException if the file cannot be read
     */
    public static Grammar hardLoadGrammar(String path) {
        try {
            return Grammar.loadGrammar(path);
        } catch (IOException e) {
            throw new RuntimeException("Could not load grammar from " + path, e);
        }
    }

    /**
     * Loads a grammar from a file in the text format described on this class.
     *
     * <p>The file is read twice. The first pass registers every symbol (and the start symbol) in a
     * fresh index, so the binary table can be allocated at its final size. The second pass fills
     * in the rule probabilities.
     *
     * @param path grammar file path
     * @return the loaded grammar
     * @throws IOException if the file cannot be read
     * @throws RuntimeException if a line is neither ignorable, a start declaration, nor a rule
     */
    public static Grammar loadGrammar(String path) throws IOException {
        final GrammarIndex idx = new GrammarIndex();
        idx.termIndexer = new Indexer<String>();
        idx.nonTermIndexer = new Indexer<String>();
        Grammar.processRules(path, new RuleProcessor() {

            @Override
            public void processStartSymbol(String symbol) {
                idx.nonTermIndexer.addToIndex(symbol);
                idx.startSymbol = idx.nonTermString2Index(symbol);
            }

            @Override
            public void processBinary(String lhs, String rhs1, String rhs2, double pr) {
                idx.nonTermIndexer.addToIndex(lhs, rhs1, rhs2);
            }

            @Override
            public void processUnary(String lhs, String rhs, double pr) {
                idx.nonTermIndexer.addToIndex(lhs);
                idx.termIndexer.addToIndex(rhs);
            }
        });
        final Grammar g = new Grammar(idx);
        Grammar.processRules(path, new RuleProcessor() {

            @Override
            public void processStartSymbol(String symbol) {
                // Already recorded during the first pass.
            }

            @Override
            public void processBinary(String lhs, String rhs1, String rhs2, double pr) {
                int parent = g.idx.nonTermIndexer.o2i(lhs);
                int child1 = g.idx.nonTermIndexer.o2i(rhs1);
                int child2 = g.idx.nonTermIndexer.o2i(rhs2);
                g.setBinPr(parent, child1, child2, pr);
            }

            @Override
            public void processUnary(String lhs, String rhs, double pr) {
                int parent = g.idx.nonTermIndexer.o2i(lhs);
                int child = g.idx.termIndexer.o2i(rhs);
                g.setUnPr(parent, child, pr);
            }
        });
        return g;
    }

    /**
     * Parses the grammar file line by line and dispatches each line to {@code p}.
     *
     * @throws NumberFormatException if a rule's probability field is not a valid double
     * @throws RuntimeException if a line has an unrecognized format
     */
    private static void processRules(String path, RuleProcessor p) throws NumberFormatException, IOException {
        for (String line : IO.i(path)) {
            // find() rather than matches(): the comment alternative only consumes the leading '#',
            // so matches() would reject every comment line that has text after the '#'.
            if (ignoredLines.matcher(line).find()) {
                continue;
            }
            Matcher m = rootSpec.matcher(line);
            if (m.matches()) {
                p.processStartSymbol(m.group(1));
                continue;
            }
            m = rule.matcher(line);
            if (!m.matches()) {
                throw new RuntimeException("Wrong format: " + line);
            }
            double pr = Double.parseDouble(m.group(4));
            String lhs = m.group(1);
            String rhs1 = m.group(2);
            String rhs2 = m.group(3);
            if (rhs2 == null) {
                p.processUnary(lhs, rhs1, pr);
            } else {
                p.processBinary(lhs, rhs1, rhs2, pr);
            }
        }
    }

    /**
     * Computes, for each nonterminal, the {@link Hellinger} divergence between the binary rule
     * distributions of the two grammars.
     *
     * @param realGrammar reference grammar
     * @param approx      approximating grammar; must have the same number of nonterminals
     * @return one divergence value per nonterminal index
     */
    public static double[] nonTermHells(Grammar realGrammar, Grammar approx) {
        return Grammar.nonTermDivergence(realGrammar, approx, new Hellinger());
    }

    /**
     * Compares the two grammars nonterminal by nonterminal, over binary rules only. Unary rules are
     * not included.
     *
     * <p>NOTE(review): the comparison is purely by integer index. The result is only meaningful if
     * both grammars index their nonterminals identically. Confirm that callers guarantee this.
     *
     * @param factory prototype whose {@code newInstance()} supplies a fresh accumulator per nonterminal
     * @throws RuntimeException if the grammars have different numbers of nonterminals
     */
    private static double[] nonTermDivergence(Grammar realGrammar, Grammar approximation, Divergence factory) {
        int nSymb = realGrammar.idx.nNonTerm();
        if (nSymb != approximation.idx.nNonTerm()) {
            throw new RuntimeException("Grammars have different numbers of nonterminals: "
                    + nSymb + " vs " + approximation.idx.nNonTerm());
        }
        double[] result = new double[nSymb];
        for (int lhs = 0; lhs < nSymb; ++lhs) {
            Divergence d = factory.newInstance();
            for (int rhs1 = 0; rhs1 < nSymb; ++rhs1) {
                for (int rhs2 = 0; rhs2 < nSymb; ++rhs2) {
                    d.addPoints(realGrammar.binPr(lhs, rhs1, rhs2), approximation.binPr(lhs, rhs1, rhs2));
                }
            }
            result[lhs] = d.getDivergence();
        }
        return result;
    }

    /**
     * Accumulates {@code sqrt(sum_i (sqrt(p_i) - sqrt(q_i))^2)}.
     *
     * <p>This is the unnormalized form, i.e. {@code sqrt(2)} times the Hellinger distance in its
     * conventional {@code 1/sqrt(2)}-scaled definition.
     */
    public static class Hellinger
    implements Divergence {
        private double sum = 0.0;

        @Override
        public void addPoints(double real, double approx) {
            this.sum += Math.pow(Math.sqrt(real) - Math.sqrt(approx), 2.0);
        }

        @Override
        public double getDivergence() {
            return Math.sqrt(this.sum);
        }

        @Override
        public Divergence newInstance() {
            return new Hellinger();
        }
    }

    /** Stateful accumulator for a divergence between two discrete distributions. */
    private static interface Divergence {
        /** Adds one outcome with probability {@code real} under P and {@code approx} under Q. */
        public void addPoints(double real, double approx);

        /** @return the divergence over all points added so far */
        public double getDivergence();

        /** @return a fresh, empty accumulator of the same kind */
        public Divergence newInstance();
    }

    /** Callback interface for {@link #processRules(String, RuleProcessor)}. */
    private static interface RuleProcessor {
        public void processUnary(String lhs, String rhs, double pr);

        public void processBinary(String lhs, String rhs1, String rhs2, double pr);

        public void processStartSymbol(String symbol);
    }

    /** Bidirectional mapping between symbol names and dense indices, plus the start symbol. */
    public static final class GrammarIndex {
        private Indexer<String> termIndexer;
        private Indexer<String> nonTermIndexer;
        private int startSymbol;

        /** Used by {@link Grammar#loadGrammar(String)}, which fills the fields incrementally. */
        private GrammarIndex() {
        }

        /**
         * Creates an index over the given symbols, in the collections' iteration order.
         *
         * @param nonTerminals nonterminal names
         * @param terminals    terminal names
         * @param startSymbol  name of the start symbol; looked up among {@code nonTerminals}
         */
        public GrammarIndex(Collection<String> nonTerminals, Collection<String> terminals, String startSymbol) {
            this.termIndexer = new Indexer<String>();
            this.nonTermIndexer = new Indexer<String>();
            for (String nonTerminal : nonTerminals) {
                this.nonTermIndexer.addToIndex(nonTerminal);
            }
            for (String terminal : terminals) {
                this.termIndexer.addToIndex(terminal);
            }
            this.startSymbol = this.nonTermString2Index(startSymbol);
        }

        /** @return the terminal indices of the words in {@code sentence}, in order */
        public int[] convertSentence(List<String> sentence) {
            int[] result = new int[sentence.size()];
            for (int i = 0; i < sentence.size(); ++i) {
                result[i] = this.termIndexer.o2i(sentence.get(i));
            }
            return result;
        }

        public int nNonTerm() {
            return this.nonTermIndexer.size();
        }

        public int nTerm() {
            return this.termIndexer.size();
        }

        public String nonTermIndex2String(int index) {
            return this.nonTermIndexer.i2o(index);
        }

        public String termIndex2String(int index) {
            return this.termIndexer.i2o(index);
        }

        public int nonTermString2Index(String string) {
            return this.nonTermIndexer.o2i(string);
        }

        public int termString2Index(String string) {
            return this.termIndexer.o2i(string);
        }

        /** @return {@code "parent -> child"} for a unary rule */
        public String toString(int parent, int child) {
            return "" + this.nonTermIndex2String(parent) + " -> " + this.termIndex2String(child);
        }

        /** @return {@code "parent -> child1 child2"} for a binary rule */
        public String toString(int parent, int child1, int child2) {
            return "" + this.nonTermIndex2String(parent) + " -> " + this.nonTermIndex2String(child1) + " " + this.nonTermIndex2String(child2);
        }
    }
}

