/*
 * Decompiled with CFR 0.152.
 */
package goblin;

import fig.basic.Option;
import fig.exec.Execution;
import goblin.DerivationTree;
import goblin.HLFeatureExtractor;
import goblin.HLParams;
import goblin.HLParamsUpdater;
import goblin.LineageSampler;
import goblin.ObservationsTracker;
import goblin.Taxon;
import goblin.TreeSamplers;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import ma.AffineGapAlignmentSampler;
import ma.MultiAlignment;
import nuts.lispparser.LispParser;
import nuts.lispparser.ParseException;
import nuts.math.MeasureZeroException;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Tree;
import pepper.Encodings;

/**
 * Synthetic-data harness for checking {@link HLParams} estimation.
 *
 * <p>{@link #go()} draws random "gold" parameters, repeatedly generates derivation trees from
 * them over the topology given by {@link #tree}, accumulates sufficient statistics, and every
 * {@link #interval} samples re-estimates parameters with {@link HLParamsUpdater} and prints a
 * comparison against the gold parameters. With {@link #microtest} set it instead compares an
 * enumerated (analytic) posterior over the top string of a two-node tree with a posterior
 * obtained by repeatedly calling {@code HLParams.resampleDerivation}.
 */
public class HLParamsTester
implements Runnable {
    /** Sufficient statistics accumulated over every sample generated in {@link #go()}. */
    public Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> suffStat = new Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>();
    @Option
    public int encodingSize = 1;
    @Option
    public Random algoRand = new Random(1L);
    @Option
    public Random paramGenRand = new Random(3L);
    @Option
    public Random dataGenRand = new Random(1L);
    @Option
    public boolean denyInDel = false;
    @Option
    public boolean printDeltaLength = false;
    @Option
    public boolean checkMissing = false;
    @Option
    public boolean microtest = false;
    @Option
    public String tree = "(a b c)";
    @Option
    public int N = 1000000;
    @Option
    public int interval = 1000;
    @Option
    public boolean resampleAlign = false;
    @Option
    public boolean resampleDeriv = true;
    private Encodings enc;
    /** Gold parameters drawn in {@link #go()}; shared with the helper routines below. */
    public static HLParams gold;
    /** Exclusive upper bound on top-string length enumerated by {@link #analytic} when indels are allowed. */
    public static final int MAX_ANALYTIC_TOP_L = 10;

    /**
     * Command-line entry point: echoes the arguments, then runs a tester under fig's
     * {@code Execution}, registering the {@link TreeSamplers} option groups under the prefixes
     * "mo", "eo", "big" and "small".
     */
    public static void main(String[] args) throws ParseException, MeasureZeroException {
        System.out.println(Arrays.toString(args));
        Execution.run(args, new HLParamsTester(), "mo", TreeSamplers.mixtureOptions, "eo", TreeSamplers.edgeOptions, "big", TreeSamplers.bigAncestryOptions, "small", TreeSamplers.smallAncestryOptions);
    }

    /** Runs {@link #go()}, rethrowing checked exceptions wrapped in a {@link RuntimeException}. */
    @Override
    public void run() {
        try {
            this.go();
        }
        catch (RuntimeException e) {
            // Already unchecked: propagate as-is rather than wrapping it a second time.
            throw e;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns every string over the non-boundary phonemes of {@link #enc} whose length is in
     * {@code [minLength, maxLength)} (upper bound exclusive).
     */
    public Set<String> allStrs(int minLength, int maxLength) {
        HashSet<String> result = new HashSet<String>();
        for (int length = minLength; length < maxLength; ++length) {
            result.addAll(this.allStrs(length));
        }
        return result;
    }

    /**
     * Returns every string of exactly {@code length} characters over the phonemes of
     * {@link #enc}, skipping the boundary phoneme. Length 0 yields the singleton {@code {""}}.
     */
    public Set<String> allStrs(int length) {
        HashSet<String> result = new HashSet<String>();
        if (length == 0) {
            result.add("");
            return result;
        }
        // The suffix set does not depend on the prefix phoneme: compute it once, not per phoneme.
        Set<String> suffixes = this.allStrs(length - 1);
        for (int i = 0; i < this.enc.getNumberOfPhonemes(); ++i) {
            if (i == this.enc.getBoundaryPhoneId()) continue;
            result.addAll(HLParamsTester.addPrefix(this.enc.phoneId2Char(i), suffixes));
        }
        return result;
    }

    /** Returns a new set containing {@code prefix + s} for every {@code s} in {@code strs}. */
    public static Set<String> addPrefix(char prefix, Set<String> strs) {
        HashSet<String> result = new HashSet<String>();
        for (String str : strs) {
            result.add(prefix + str);
        }
        return result;
    }

    /**
     * Computes, by enumeration, the normalized posterior over top strings given the bottom
     * string {@code bot} under {@link #gold}.
     *
     * @param bot the observed bottom string
     * @param denyInDels if true, only tops of the same length as {@code bot} are enumerated and
     *     scored with the monotonic derivation; otherwise all tops of length
     *     {@code 0 .. MAX_ANALYTIC_TOP_L - 1} are enumerated
     * @throws IllegalStateException if {@code denyInDels} is requested while the {@link #denyInDel}
     *     option is off, or if {@link #tree} is not {@code "(a b)"}
     */
    public Counter<String> analytic(String bot, boolean denyInDels) {
        if (!this.denyInDel && denyInDels) {
            throw new IllegalStateException("analytic(denyInDels=true) requires the denyInDel option");
        }
        if (!this.tree.equals("(a b)")) {
            throw new IllegalStateException("analytic posterior only supports tree \"(a b)\", got: " + this.tree);
        }
        Set<String> tops = denyInDels ? this.allStrs(bot.length()) : this.allStrs(0, MAX_ANALYTIC_TOP_L);
        Counter<String> posterior = new Counter<String>();
        for (String top : tops) {
            posterior.setCount(top, this.unnormalizedPosterior(top, bot, denyInDels));
        }
        posterior.normalize();
        return posterior;
    }

    /**
     * Builds a two-node derivation tree (taxon "a" holding {@code top}, taxon "b" holding
     * {@code bot}) whose derivation maps position i of the bottom word to ancestor position i.
     *
     * @throws IllegalArgumentException if {@code top} and {@code bot} differ in length
     */
    public static Arbre<DerivationTree.DerivationNode> monotonicTree(String top, String bot) {
        int l = top.length();
        if (l != bot.length()) {
            throw new IllegalArgumentException("monotonic tree needs equal lengths: \"" + top + "\" vs \"" + bot + "\"");
        }
        int[] ancestors = new int[l];
        for (int i = 0; i < l; ++i) {
            ancestors[i] = i;
        }
        DerivationTree.Derivation d = new DerivationTree.Derivation(ancestors, top, bot);
        DerivationTree.DerivationNode topNode = new DerivationTree.DerivationNode(new Taxon("a"), top);
        DerivationTree.DerivationNode botNode = new DerivationTree.DerivationNode(new Taxon("b"), bot, d);
        return Arbre.arbre(botNode, Arbre.arbre(topNode)).root();
    }

    /** Sufficient statistics of {@link #monotonicTree(String, String)} for the given pair. */
    private Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> monotonicDerivTree(String top, String bot) {
        Arbre<DerivationTree.DerivationNode> a = HLParamsTester.monotonicTree(top, bot);
        Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> suffStats = new Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>();
        HLParams.addSuffStats(suffStats, a, this.enc);
        return suffStats;
    }

    /**
     * Unnormalized posterior weight of {@code top} given {@code bot} under {@link #gold}.
     *
     * <p>With {@code denyInDels}: the product over the monotonic tree's sufficient statistics of
     * {@code gold.pr(context, outcome)^count}. Otherwise: {@code exp(gold.getRootLogPr(top, a))}
     * times {@code getSumPr()} of the affine-gap alignment sampler for branch "b" (presumably the
     * probability summed over alignments — confirm).
     */
    private double unnormalizedPosterior(String top, String bot, boolean denyInDels) {
        if (denyInDels) {
            Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> suffStats = this.monotonicDerivTree(top, bot);
            double result = 1.0;
            for (LabeledInstance<HLParams.HLContext, HLParams.HLOutcome> stat : suffStats.keySet()) {
                result *= Math.pow(gold.pr(stat.getInput(), stat.getLabel()), suffStats.getCount(stat));
            }
            return result;
        }
        if (!this.tree.equals("(a b)")) {
            throw new IllegalStateException("unnormalized posterior only supports tree \"(a b)\", got: " + this.tree);
        }
        Taxon rootLang = new Taxon("a");
        double lm = Math.exp(gold.getRootLogPr(top, rootLang));
        AffineGapAlignmentSampler agas = AffineGapAlignmentSampler.createHLAlignmentSampler(top, bot, gold.getBranchParams().get(new Taxon("b")));
        double align = agas.getSumPr();
        return lm * align;
    }

    /**
     * Main experiment. Sets up features, topology, encodings and random gold parameters. Then
     * either runs {@link #microtest()} or performs {@link #N} rounds of generate / optionally
     * resample / accumulate sufficient statistics. Every {@link #interval} rounds it re-estimates
     * parameters and prints them compared against {@link #gold}.
     *
     * @throws IllegalArgumentException if {@link #interval} is not positive
     */
    public void go() throws Exception {
        if (this.interval <= 0) {
            throw new IllegalArgumentException("interval must be positive, got: " + this.interval);
        }
        HLFeatureExtractor fe = new HLFeatureExtractor();
        fe.featureTemplates = new ArrayList<HLFeatureExtractor.HLFeatureTemplate>(Arrays.asList(HLFeatureExtractor.HLFeatureTemplate.FINEST));
        fe.init(null);
        Tree<String> topo = new LispParser(this.tree).parse();
        Taxon mainRoot = new Taxon(topo.getLabel());
        Set<Taxon> languages = HLParams.languages(topo);
        this.enc = Encodings.toyCtxFreeEncodings(this.encodingSize);
        gold = HLParams.randomHLParams(this.paramGenRand, this.enc, languages, this.denyInDel);
        // Self-comparison of the gold parameters, printed before any estimation.
        System.out.println(HLParams.compare(gold, gold, 0.0, mainRoot));
        if (this.microtest) {
            this.microtest();
            return;
        }
        for (int i = 0; i < this.N; ++i) {
            Arbre<DerivationTree.DerivationNode> generatedTree = gold.generate(topo, this.dataGenRand);
            if (this.resampleAlign) {
                generatedTree = gold.resampleAlignments(generatedTree, this.algoRand);
            }
            if (this.resampleDeriv) {
                int oldTopLength = generatedTree.root().getContents().getWord().length();
                generatedTree = gold.resampleDerivation(generatedTree, this.algoRand);
                int newTopLength = generatedTree.root().getContents().getWord().length();
                if (newTopLength != oldTopLength && this.printDeltaLength) {
                    System.out.println("delta l:" + oldTopLength + "->" + newTopLength);
                }
            }
            Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> temp = new Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>();
            HLParams.addSuffStats(temp, generatedTree, this.enc);
            this.suffStat.incrementAll(temp);
            if ((i + 1) % this.interval != 0) continue;
            System.out.println("Iteration " + i);
            MaxentClassifier.MaxentOptions op = new MaxentClassifier.MaxentOptions();
            op.iterations = 1000;
            op.sigma = 100.0;
            op.tolerance = 1.0E-10;
            HLParamsUpdater updater = new HLParamsUpdater(this.enc, languages, fe, op, new Counter<String>(), 1);
            HLParams estimated = updater.update(this.suffStat);
            System.out.println(this.suffStat);
            System.out.println(HLParams.compare(gold, estimated, 0.01, mainRoot));
            if (this.checkMissing) {
                this.printSuffStatCoverage();
            }
        }
    }

    /**
     * Prints the two set differences between the accumulated {@link #suffStat} keys and the set
     * of all non-root contexts/outcomes for taxon "b" plus all root statistics for taxon "a".
     * The taxa are hard-coded, so this presumably assumes the tree {@code "(a b)"}.
     */
    private void printSuffStatCoverage() {
        Set<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> all = new HashSet<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>();
        for (HLParams.HLContext ctxt : HLParams.allHLContexts(this.enc, new Taxon("b"))) {
            if (ctxt.type == HLParams.ChoiceType.ROOT) continue;
            for (HLParams.HLOutcome outc : ctxt.allOutcomes()) {
                all.add(new LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>(outc, ctxt));
            }
        }
        Taxon rootLang = new Taxon("a");
        for (int p = 0; p < this.enc.getNumberOfPhonemes(); ++p) {
            for (int c = 0; c < this.enc.getNumberOfPhonemes(); ++c) {
                all.add(HLParams.createRootSuffStat(rootLang, this.enc, p, c));
            }
        }
        // Two independent differences; previously the second one was written into the first set.
        Set<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> onlyInSuffStat = new HashSet<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>(this.suffStat.keySet());
        onlyInSuffStat.removeAll(all);
        Set<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> onlyInAll = new HashSet<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>(all);
        onlyInAll.removeAll(this.suffStat.keySet());
        System.out.println("Those in suff stat not in all (" + onlyInSuffStat.size() + "): " + onlyInSuffStat);
        System.out.println("Those in all not in suff stat (" + onlyInAll.size() + "): " + onlyInAll);
    }

    /**
     * Scratch routine (not called from the visible code). It builds two small hand-made
     * derivation trees, prints a lineage and the induced multi-alignment, and constructs a
     * {@link LineageSampler} for the first tree.
     */
    private void nanotest() {
        System.out.println(HLParams.compare(gold, gold, 0.0, new Taxon("a")));
        final String bot = "baab";
        final String top = "aaba";
        final String top2 = "aaa";
        // NOTE(review): -1 ancestor entries presumably mark inserted positions — confirm.
        final int insCode = -1;
        int[] ancestors = new int[]{0, insCode, insCode, insCode};
        DerivationTree.Derivation d = new DerivationTree.Derivation(ancestors, top, bot);
        DerivationTree.Derivation d2 = new DerivationTree.Derivation(ancestors, top2, bot);
        Taxon topLang = new Taxon("a");
        Taxon botLang = new Taxon("b");
        DerivationTree.DerivationNode botNode = new DerivationTree.DerivationNode(botLang, bot, d);
        DerivationTree.DerivationNode botNode2 = new DerivationTree.DerivationNode(botLang, bot, d2);
        DerivationTree.DerivationNode topNode = new DerivationTree.DerivationNode(topLang, top);
        DerivationTree.DerivationNode topNode2 = new DerivationTree.DerivationNode(topLang, top2);
        Arbre<DerivationTree.DerivationNode> a = Arbre.arbre(botNode, Arbre.arbre(topNode)).root();
        Arbre<DerivationTree.DerivationNode> a2 = Arbre.arbre(botNode2, Arbre.arbre(topNode2)).root();
        DerivationTree.Window win = new DerivationTree.Window(1, 3);
        Arbre<DerivationTree.LineagedNode> la = DerivationTree.lineage(a.getChildren().get(0), win);
        Arbre<DerivationTree.LineagedNode> la2 = DerivationTree.lineage(a2.getChildren().get(0), win);
        System.out.println(la.deepToString());
        System.out.println(MultiAlignment.fullInducedMultiAlignment(a));
        ObservationsTracker obs = ObservationsTracker.modernObservationsTracker(a);
        LineageSampler.LongGapAlignmentSamplerAdaptor lgasa = new LineageSampler.LongGapAlignmentSamplerAdaptor(gold);
        TreeSamplers.AncestryMCMCKernelOptions so = new TreeSamplers.AncestryMCMCKernelOptions();
        so.restrictToObservedSymbols = TreeSamplers.AncestryMCMCKernelOptions.RestrictType.FALSE;
        LineageSampler ls = new LineageSampler(obs, la, lgasa, so);
    }

    /**
     * Compares the enumerated posterior from {@link #analytic} for bottom string "baab" with
     * the empirical distribution of the root word. The empirical distribution comes from a
     * chain of {@code gold.resampleDerivation} calls started at the monotonic tree "aaaa" ->
     * "baab". Prints one row per string: the string, the analytic and sampled probabilities,
     * and their absolute difference.
     */
    private void microtest() throws MeasureZeroException {
        System.out.println("Doing microtest!");
        final String bot = "baab";
        final String initTop = "aaaa";
        final int numSamples = 100000;
        Counter<String> analytic = this.analytic(bot, this.denyInDel);
        Counter<String> sampled = new Counter<String>();
        Arbre<DerivationTree.DerivationNode> current = HLParamsTester.monotonicTree(initTop, bot);
        for (int i = 0; i < numSamples; ++i) {
            current = gold.resampleDerivation(current, this.algoRand);
            sampled.incrementCount(current.getContents().getWord(), 1.0);
        }
        sampled.normalize();
        System.out.println("Top\tAnalytic\tSampled\tDifference");
        for (String key : CollUtils.union(sampled.keySet(), analytic.keySet())) {
            double analyticPr = analytic.getCount(key);
            double sampledPr = sampled.getCount(key);
            System.out.println(key + "\t" + analyticPr + "\t" + sampledPr + "\t" + Math.abs(sampledPr - analyticPr));
        }
    }
}

