/*
 * Decompiled with CFR 0.152.
 */
package sage;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.exec.Execution;
import goblin.CognateId;
import goblin.DerivationTree;
import goblin.HLFeatureExtractor;
import goblin.HLParams;
import goblin.HLParamsLoader;
import goblin.HLParamsUpdater;
import goblin.TreeSamplers;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import ma.MultiAlignment;
import nuts.lispparser.LispParser;
import nuts.math.MeasureZeroException;
import nuts.math.RejectionSampler;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.util.Arbre;
import nuts.util.Counter;
import nuts.util.Tree;
import pepper.Encodings;
import sage.FatContext;
import sage.FatFeatureExtractor;
import sage.FatGenerator;
import sage.LikelihoodModel;

/**
 * Test harness for estimating a {@link LikelihoodModel.FatLikelihoodModel}.
 *
 * <p>The harness works in four steps:
 * <ol>
 *   <li>Build a "gold" model from stored weights.</li>
 *   <li>Repeatedly generate derivation trees from that model over a fixed topology.</li>
 *   <li>Accumulate sufficient statistics from those trees.</li>
 *   <li>Every {@link #evalInterval} trees, fit a fresh model to the accumulated statistics and
 *       log gold vs. estimated probabilities for a fixed set of evaluation instances.</li>
 * </ol>
 */
public class FatModelTester implements Runnable {
    /** Source of the gold weights, handed to {@code HLParamsUpdater.restoreCounter} (presumably a file path — confirm). */
    @Option
    public String goldParams;
    /** Tree topology in Lisp notation, parsed with {@link LispParser}. */
    @Option
    public String tree = "(a b c)";
    /** Number of trees generated from the gold model in the main loop. */
    @Option
    public int nTrees = 100000;
    /** Random source for tree generation, shared by evaluation-set construction and the main loop. */
    @Option
    public Random generationRandom = new Random(1L);
    /** Random source reserved for tree resampling; not currently read anywhere in this class. */
    @Option
    public Random samplingRandom = new Random(1L);
    /** Maximum number of instances kept in the evaluation set. */
    @Option
    public int nEval = 100;
    /** Number of trees generated to collect candidate evaluation instances. */
    @Option
    public int constructionEvalIters = 1000;
    /** Run an evaluation after every this many generated trees; must be positive. */
    @Option
    public int evalInterval = 1000;
    /** Resample each generated tree via {@link #sample}; not implemented, so enabling it fails. */
    @Option
    public boolean doSample = true;
    /** Resample alignments of each generated tree; not implemented, so enabling it fails. */
    @Option
    public boolean doAlignSample = false;

    // Shared components, also registered as option groups with Execution in main().
    private static HLFeatureExtractor baseExtractor = new HLFeatureExtractor();
    private static FatFeatureExtractor extractor = new FatFeatureExtractor(baseExtractor);
    private static TreeSamplers.AncestryMCMCKernelOptions so = new TreeSamplers.AncestryMCMCKernelOptions();
    private static HLParamsLoader proposalParamsLoader = new HLParamsLoader();
    private static RejectionSampler<Arbre<DerivationTree.DerivationNode>> rs = new RejectionSampler();

    private Encodings enc = Encodings.proteinEncodings(true);
    private MaxentClassifier.MaxentOptions<Object> lbfgso = new MaxentClassifier.MaxentOptions();
    private FatGenerator generator;
    private Tree<String> topo;
    // Assigned in run() but not read elsewhere in this class.
    private HLParams proposal;
    private LikelihoodModel.FatLikelihoodModel goldModel;
    // Fixed set of instances whose probabilities are compared at each evaluation.
    private Set<LabeledInstance<FatContext, HLParams.HLOutcome>> evaluation;

    /**
     * Entry point: parses command-line options into this tester and the shared option groups,
     * then runs it.
     *
     * @param args command-line arguments forwarded to {@code Execution.run}
     */
    public static void main(String[] args) {
        Execution.run(args, new FatModelTester(), "bfe", baseExtractor, "fe", extractor, "so", so, "pr", proposalParamsLoader, "rs", rs);
    }

    /**
     * Sets up the gold model and evaluation set, then generates {@link #nTrees} trees,
     * accumulating their sufficient statistics and evaluating every {@link #evalInterval} trees.
     *
     * @throws IllegalArgumentException if {@link #evalInterval} is not positive
     * @throws RuntimeException wrapping any checked exception raised during setup or sampling;
     *     unchecked exceptions propagate unchanged
     */
    @Override
    public void run() {
        // Reject this before the (potentially long) setup instead of failing on the modulo below.
        if (this.evalInterval <= 0) {
            throw new IllegalArgumentException("evalInterval must be positive, got " + this.evalInterval);
        }
        try {
            Counter goldWeights = HLParamsUpdater.restoreCounter(this.goldParams);
            Encodings.registerEncodings(this.enc);
            baseExtractor.setIgnoreLanguages(extractor);
            proposalParamsLoader.setFeatureExtractor(baseExtractor);
            this.proposal = proposalParamsLoader.getParams();
            this.goldModel = new LikelihoodModel.FatLikelihoodModel(this.enc, extractor, this.lbfgso, goldWeights, new Counter<Object>());
            this.generator = new FatGenerator(this.goldModel);
            this.topo = new LispParser(this.tree).parse();
            this.evaluation = this.createEvaluation();
            Counter<LabeledInstance<FatContext, HLParams.HLOutcome>> suffStats = new Counter<LabeledInstance<FatContext, HLParams.HLOutcome>>();
            for (int i = 0; i < this.nTrees; ++i) {
                Arbre<DerivationTree.DerivationNode> cTree = this.generator.generate(this.topo, this.generationRandom, CognateId.dummy);
                if (this.doSample) {
                    cTree = this.sample(cTree);
                }
                if (this.doAlignSample) {
                    throw new UnsupportedOperationException("doAlignSample is not implemented");
                }
                FatContext.addSuffStats(suffStats, cTree, extractor.granularities(), this.enc, CognateId.dummy);
                if ((i + 1) % this.evalInterval == 0) {
                    this.eval(suffStats);
                }
            }
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            // Propagate to the caller (Execution.run) with the cause attached instead of
            // printing the trace and returning as if the run had succeeded.
            throw new RuntimeException("FatModelTester run failed", e);
        }
    }

    /**
     * Fits a fresh model to {@code suffStats}. For each evaluation instance, logs the
     * instance, its gold probability, its estimated probability and their absolute difference.
     *
     * @param suffStats sufficient statistics accumulated so far
     */
    private void eval(Counter<LabeledInstance<FatContext, HLParams.HLOutcome>> suffStats) {
        LogInfo.track((Object)"Evaluating", true);
        try {
            LikelihoodModel.FatLikelihoodModel estimated = new LikelihoodModel.FatLikelihoodModel(this.enc, extractor, this.lbfgso);
            estimated.update(suffStats);
            LogInfo.logs("Stat\tGold\tEstimated\tDelta");
            for (LabeledInstance<FatContext, HLParams.HLOutcome> stat : this.evaluation) {
                StringBuilder row = new StringBuilder();
                row.append(stat).append('\t');
                double gold = Math.exp(this.goldModel.logLikelihood(stat));
                double est = Math.exp(estimated.logLikelihood(stat));
                row.append(gold).append('\t');
                row.append(est).append('\t');
                row.append(Math.abs(gold - est)).append('\n');
                LogInfo.logs(row.toString());
            }
        } finally {
            // Keep LogInfo's track nesting balanced even if estimation or scoring throws.
            LogInfo.end_track();
        }
    }

    /**
     * Placeholder for resampling a generated tree; not implemented.
     *
     * @param arbre the generated tree
     * @return never returns normally
     * @throws MeasureZeroException declared for a future implementation
     * @throws UnsupportedOperationException always
     */
    private Arbre<DerivationTree.DerivationNode> sample(Arbre<DerivationTree.DerivationNode> arbre) throws MeasureZeroException {
        throw new UnsupportedOperationException("sample is not implemented");
    }

    /**
     * Builds the evaluation set. It generates {@link #constructionEvalIters} trees and collects
     * the instances from their sufficient statistics that pass {@code FatContext.isLongIns}.
     * It keeps at most {@link #nEval} of them. Each generated tree's induced multi-alignment is
     * printed to stdout.
     *
     * @return the evaluation instances (at most {@link #nEval})
     */
    private Set<LabeledInstance<FatContext, HLParams.HLOutcome>> createEvaluation() {
        Counter<LabeledInstance<FatContext, HLParams.HLOutcome>> all = new Counter<LabeledInstance<FatContext, HLParams.HLOutcome>>();
        for (int i = 0; i < this.constructionEvalIters; ++i) {
            Arbre<DerivationTree.DerivationNode> generated = this.generator.generate(this.topo, this.generationRandom, CognateId.dummy);
            System.out.println(MultiAlignment.fullInducedMultiAlignment(generated));
            Counter<LabeledInstance<FatContext, HLParams.HLOutcome>> counter = new Counter<LabeledInstance<FatContext, HLParams.HLOutcome>>();
            FatContext.addSuffStats(counter, generated, extractor.granularities(), this.enc, CognateId.dummy);
            for (LabeledInstance<FatContext, HLParams.HLOutcome> stat : counter.keySet()) {
                // NOTE(review): this filter used to test isLongIns twice, joined by &&. That
                // suggests a second predicate (e.g. a long-deletion check) may have been
                // intended. As written, only long insertions are kept — confirm.
                if (!FatContext.isLongIns(stat)) {
                    continue;
                }
                all.incrementCount(stat, counter.getCount(stat));
            }
        }
        HashSet<LabeledInstance<FatContext, HLParams.HLOutcome>> result = new HashSet<LabeledInstance<FatContext, HLParams.HLOutcome>>();
        // Which nEval instances are kept depends on Counter's iteration order. The element type
        // of that iteration is not visible here, hence the raw LabeledInstance.
        int taken = 0;
        for (LabeledInstance labeledInstance : all) {
            if (taken >= this.nEval) {
                break;
            }
            result.add(labeledInstance);
            ++taken;
        }
        return result;
    }
}

