/*
 * Decompiled with CFR 0.152.
 */
package goblin;

import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.StrUtils;
import fig.exec.Execution;
import goblin.Baseline;
import goblin.BayesRiskMinimizer;
import goblin.BioDataLoaderAdaptor;
import goblin.CognateId;
import goblin.CognateSet;
import goblin.DataLoader;
import goblin.DataLoaderInterface;
import goblin.DerivationTree;
import goblin.HLFeatureExtractor;
import goblin.HLIntegrator;
import goblin.HLParams;
import goblin.HLParamsLoader;
import goblin.HLParamsUpdater;
import goblin.ObservationsTracker;
import goblin.Taxon;
import goblin.TreeSamplers;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import ma.AffineGapAlignmentSampler;
import ma.BalibaseCorpus;
import ma.MultiAlignment;
import nuts.math.Sampling;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.util.Arbre;
import nuts.util.Counter;
import nuts.util.Tree;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;
import pepper.Encodings;
import pepper.editmodel.Utils;
import sage.Fat2HLPrj;
import sage.FatContext;
import sage.FatFeatureExtractor;
import sage.FatGenerator;
import sage.LikelihoodModel;
import sage.MartinetHypothesesCalculator;

/**
 * Driver for iterative (EM-style) training of the goblin models.
 *
 * <p>{@link #run()} first selects and prepares the data loader, the parameter loaders and the
 * integrator ({@link #init()}) and optionally evaluates baselines. It then loops up to
 * {@link #emMaxNIterations} times over:
 * <ul>
 *   <li>an E step, delegated to {@code HLIntegrator.compute};</li>
 *   <li>evaluation and snapshotting of the integrator state;</li>
 *   <li>an optional M step that re-estimates the paired likelihood model and/or the proposal
 *       parameters.</li>
 * </ul>
 *
 * <p>NOTE(review): the constructor publishes its collaborators through the public static fields
 * {@link #global_ext}, {@link #global_paramLoader} and {@link #hliOptions}. Because
 * {@link #hliOptions} may only be set once, at most one instance can be constructed per JVM.
 */
public class HLEM
implements Runnable {
    /** Feature extractor of the constructed HLEM instance (assigned by the constructor). */
    public static HLFeatureExtractor global_ext = null;
    /** Initial parameter loader of the constructed HLEM instance (assigned by the constructor). */
    public static HLParamsLoader global_paramLoader = null;
    /**
     * When true, the initial weights are also passed as the last {@code Counter} argument of the
     * proposal updater and of the paired model; otherwise an empty {@code Counter} is passed.
     * Presumably this is the center of the regularizer — TODO confirm.
     */
    @Option
    public boolean centerReguOnInit = true;
    /** Which loader drives the run: {@code BALI} uses the bio loader, {@code LANG} the language loader. */
    @Option
    public ApplicationType application = ApplicationType.BALI;
    /** Applied in {@link #init()} via {@link #convertToPairs(CognateSet)}. */
    @Option(gloss="Breaks cognate sets into all pairs of observed word forms, removing all unobserved languages")
    public boolean pairsCognateSet = false;
    @Option(gloss="Max number of EM steps")
    public int emMaxNIterations = 20;
    /** Randomness for the RANDOM baseline; also registered with the scrambler in {@link #init()}. */
    @Option
    public Random randomBaselineRand = new Random(1L);
    /** When true, a paired {@code FatLikelihoodModel} is built and re-estimated in every M step. */
    @Option
    public boolean usePairedModel = true;
    /**
     * When true (and {@link #usePairedModel} is true), the proposal is obtained by projecting the
     * re-estimated paired model rather than by a direct proposal update.
     */
    @Option
    public boolean useProjection = false;
    /** Weight file used to initialize the paired model; {@link #SAME} reuses the proposal's init path. */
    @Option
    public String fullModelInit = "";
    /** Sentinel value for {@link #fullModelInit}. */
    public static final String SAME = "SAME";
    /** Stop right after the first E step and its reporting. */
    @Option
    public boolean onlyOneE = false;
    /** Whether to run the M step (re-estimation) after each E step. */
    @Option
    public boolean estimate = true;
    /** Save feature-extractor debug output before training. */
    @Option
    public boolean printFeatDebug = false;
    /** Collect n-grams and compute weak-point statistics every iteration (LANG application only). */
    @Option
    public boolean computeFunctionalLoad = false;
    /** Save the integrator's sufficient statistics every iteration. */
    @Option
    public boolean dumpSuffStats = false;
    /** Before EM, sample derivations from the paired model and a TKF generator and save their statistics. */
    @Option
    public boolean generateHack = false;
    /** Number of derivations sampled per cognate id when {@link #generateHack} is set. */
    private static final int GENERATE_HACK_SAMPLES_PER_ID = 100;
    /** Active loader, selected in {@link #init()} from {@link #application}. */
    private DataLoaderInterface dataLoader;
    private final DataLoader _dataLoader;
    private final BioDataLoaderAdaptor _bioDataLoader;
    private final BalibaseCorpus.BalibaseCorpusOptions bco;
    private final HLParamsLoader initParamLoader;
    /** Integrator options; set exactly once by the constructor. */
    public static HLIntegrator.HLIOptions hliOptions = null;
    private final HLFeatureExtractor featureExtractor;
    private final MaxentClassifier.MaxentOptions<Object> lbfgsOptions;
    private final String[] args;
    private final Sampling.RandomSrcsScrambler scrambler;
    private final FatFeatureExtractor fatFeatureExtractor;
    private HLIntegrator integrator;
    /** Current proposal parameters; replaced after each M step. */
    public HLParams currentProposalParams;
    private HLParamsUpdater paramUpdater;
    /** Paired model, or null when {@link #usePairedModel} is false. */
    private LikelihoodModel.FatLikelihoodModel model;

    /**
     * Stores the collaborators and publishes some of them through the static fields.
     *
     * @throws IllegalStateException if another HLEM already set {@link #hliOptions}; in that case
     *     no static field is modified
     */
    public HLEM(Sampling.RandomSrcsScrambler scrambler, String[] args, DataLoader dataLoader, BioDataLoaderAdaptor bio, BalibaseCorpus.BalibaseCorpusOptions bco, HLParamsLoader initParamLoader, HLIntegrator.HLIOptions hliOptions, HLFeatureExtractor featureExtractor, FatFeatureExtractor fatEx, MaxentClassifier.MaxentOptions<Object> lbfgsOptions) {
        // Check before touching any global so a rejected construction leaves the statics intact.
        if (HLEM.hliOptions != null) {
            throw new IllegalStateException("HLEM.hliOptions is already set: only one HLEM may be constructed per JVM");
        }
        global_ext = featureExtractor;
        global_paramLoader = initParamLoader;
        this.scrambler = scrambler;
        this.args = args;
        this._dataLoader = dataLoader;
        this._bioDataLoader = bio;
        this.bco = bco;
        this.initParamLoader = initParamLoader;
        HLEM.hliOptions = hliOptions;
        this.featureExtractor = featureExtractor;
        this.lbfgsOptions = lbfgsOptions;
        this.fatFeatureExtractor = fatEx;
    }

    /**
     * Runs the full experiment: initialization, optional baselines and diagnostics, then the EM loop.
     *
     * @throws IllegalStateException if {@link #computeFunctionalLoad} is set for a non-LANG application
     */
    @Override
    public void run() {
        Fat2HLPrj prj = new Fat2HLPrj();
        LogInfo.logs("Invocation args:" + StrUtils.join(this.args, " "));
        this.init();
        if (this.computeFunctionalLoad) {
            if (this.application != ApplicationType.LANG) {
                throw new IllegalStateException("computeFunctionalLoad is only supported for application=LANG, got " + this.application);
            }
            this.integrator.collectNGrams();
        }
        if (this.dataLoader.hasHeldout()) {
            this.baselines();
        }
        HLParams.saveHLParamsInExec(this.currentProposalParams, "init-proposal");
        if (this.dataLoader.generated()) {
            HLParams.saveHLParamsInExec(this.dataLoader.getGeneratingParams(), "generating");
            CognateSet.saveCognateSetInExec(this.dataLoader.getGeneratingCognateSet(), "generating");
        }
        if (this.generateHack) {
            this.generateFromInit();
        }
        if (this.printFeatDebug) {
            // NOTE(review): the cast assumes the active loader is a DataLoader; confirm for application=BALI.
            this.featureExtractor.saveDebug("feat-debug", Encodings.getGlobalEncodings(), ((DataLoader)this.dataLoader).allLanguages());
        }
        for (int iter = 0; iter < this.emMaxNIterations; ++iter) {
            this.integrator.setT(iter);
            this.expectationStep();
            // Fetched right after the E step, before any reporting, as in the original flow.
            Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> expSufStat = this.integrator.getSufficientStatistics();
            this.reportIteration(iter);
            if (this.onlyOneE) {
                return;
            }
            if (this.estimate) {
                this.maximizationStep(iter, expSufStat, prj);
            }
            if (this.dataLoader.generated()) {
                LogInfo.logs("Proposal:");
                // NOTE(review): same DataLoader cast assumption as above.
                LogInfo.logs(HLParams.compare(this.currentProposalParams, this.dataLoader.getGeneratingParams(), 0.01, ((DataLoader)this.dataLoader).root()));
            }
        }
    }

    /**
     * Samples derivations for every bio cognate id from the paired model and from a protein TKF
     * generator, accumulates their sufficient statistics and saves both counters.
     *
     * <p>NOTE(review): uses {@code this.model}, which is null when {@link #usePairedModel} is false,
     * and always reads from the bio loader regardless of {@link #application}.
     */
    private void generateFromInit() {
        FatGenerator fg = new FatGenerator(this.model);
        Counter<LabeledInstance<FatContext, HLParams.HLOutcome>> ss = new Counter<LabeledInstance<FatContext, HLParams.HLOutcome>>();
        Counter<LabeledInstance<FatContext, HLParams.HLOutcome>> tkfss = new Counter<LabeledInstance<FatContext, HLParams.HLOutcome>>();
        for (CognateId id : this._bioDataLoader.getCognateSet().getCognateIds()) {
            Tree<String> t = Fat2HLPrj.convert(this._bioDataLoader.getCognateSet().getTree(id));
            AffineGapAlignmentSampler.TKFGenerator tkfg = new AffineGapAlignmentSampler.TKFGenerator(new AffineGapAlignmentSampler.TKFParams(Encodings.EncodingType.PROTEIN), this._dataLoader.dataGenerationRand, t, this._bioDataLoader.getBranchLengths(id));
            for (int sample = 0; sample < GENERATE_HACK_SAMPLES_PER_ID; ++sample) {
                // Draw order matters: both generators share dataGenerationRand.
                Arbre<DerivationTree.DerivationNode> a = fg.generate(t, this._dataLoader.dataGenerationRand, id);
                // NOTE(review): meaning of the argument 100 depends on TKFGenerator.generate — TODO confirm.
                Arbre<DerivationTree.DerivationNode> tkfa = tkfg.generate(100);
                FatContext.addSuffStats(ss, a, this.fatFeatureExtractor.granularities(), this.model.getEncodings(), id);
                FatContext.addSuffStats(tkfss, tkfa, this.fatFeatureExtractor.granularities(), this.model.getEncodings(), id);
            }
        }
        HLIntegrator.saveSuffStatsInExec("generatedFromInit", null, ss);
        HLIntegrator.saveSuffStatsInExec("generatedTKFFromInit", null, tkfss);
    }

    /**
     * Runs the integrator's E step inside a log track.
     *
     * <p>Reconstructions are flushed unless the learning algorithm is a sampling one. The track is
     * closed even if the computation throws.
     */
    private void expectationStep() {
        LogInfo.track((Object)"E step", true);
        try {
            boolean flushRecons = !this.lbfgsOptions.learningAlgo.isSampling;
            this.integrator.compute(flushRecons, this.currentProposalParams, this.model, this.dataLoader.getHeldout());
        } finally {
            LogInfo.end_track();
        }
    }

    /**
     * Evaluates and snapshots the integrator state after the E step of iteration {@code iter}.
     *
     * @param iter zero-based iteration index, used in output names
     */
    private void reportIteration(int iter) {
        if (this.dataLoader.hasHeldout()) {
            this.evaluate(this.integrator.getReconstructions(), "GOBLIN-" + iter);
            this.integrator.saveReconstructionSamples("iteration", iter);
        }
        if (this.dataLoader.hasReferenceAlignments()) {
            if (this.dataLoader.hasHeldout()) {
                LogInfo.warning("Warning: using both word heldout and MSA refs.. untested!");
            }
            this.evaluateMSA(this.integrator.getMSAReconstructions(), "GOBLIN-" + iter);
            this.evaluateMSA(this.integrator.getSPBoundMSA(), "SP-UPPER-BOUND-" + iter);
            try {
                this.evaluateMSA(this.integrator.getBayesMSAReconstructions(), "GOBLIN-BAYES-" + iter);
            }
            catch (Exception e) {
                // Best effort: a failing Bayes decode is logged but must not abort training.
                LogInfo.warning(e);
            }
        }
        CognateSet.saveCognateSetInExec(this.integrator.getCognateSet(), "snapshot", iter);
        if (HLEM.hliOptions.collectContextEdit) {
            this.integrator.saveContextEditInExec("iteration", iter);
        }
        if (this.dumpSuffStats) {
            this.integrator.saveSuffStatsInExec(iter);
        }
        if (this.computeFunctionalLoad) {
            new MartinetHypothesesCalculator(this.integrator.getSufficientStatistics(), this.integrator.getNGramSuffStats(), this._dataLoader.getTopology(), this.integrator.getCognateSet().getLanguages()).weakPointStatistics("weakPointStats", iter);
        }
    }

    /**
     * Re-estimates the paired model and/or the proposal parameters, then saves the proposal weights.
     *
     * @param iter zero-based iteration index, used in output names
     * @param expSufStat expected sufficient statistics produced by the E step; replaced by counts
     *     from the current cognate set trees when the learning algorithm is a sampling one
     * @param prj projector used when {@link #useProjection} is set
     * @throws IllegalStateException if the paired model is combined with a sampling learning algorithm
     * @throws RuntimeException if the initial parameter loader loads from a serialized file
     */
    private void maximizationStep(int iter, Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> expSufStat, Fat2HLPrj prj) {
        if (this.usePairedModel) {
            if (this.lbfgsOptions.learningAlgo.isSampling) {
                throw new IllegalStateException("The paired model cannot be re-estimated with a sampling learning algorithm");
            }
            this.model.update(this.integrator.getFatSufficientStatistics());
            this.model.saveWeightsInExec("reest-model", iter);
            if (this.useProjection) {
                this.currentProposalParams = prj.project(this.model, this.paramUpdater, Fat2HLPrj.cognateSet2Trees(this.integrator.getCognateSet()));
            }
        }
        if (!this.usePairedModel || !this.useProjection) {
            if (this.initParamLoader.loadFromSeri()) {
                throw new RuntimeException("Fill this optimization or do not use seri in the future");
            }
            // Re-read on every iteration so the optimizer restarts from the initial weights file.
            this.lbfgsOptions.initialWeights = HLParamsUpdater.restoreCounter(this.initParamLoader.paramsPath);
            Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> stats = expSufStat;
            if (this.lbfgsOptions.learningAlgo.isSampling) {
                CognateSet data = this.integrator.getCognateSet();
                stats = new Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>>();
                for (CognateId id : data.getCognateIds()) {
                    HLParams.addSuffStats(stats, data.getTree(id), this.currentProposalParams.enc);
                }
            }
            this.currentProposalParams = this.paramUpdater.update(stats);
        }
        this.paramUpdater.saveWeightsInExec("reest-proposal", iter);
    }

    /**
     * Scores MSA reconstructions against the loader's reference alignments.
     *
     * <p>Per-id details go to {@code <name>.msaEval}, and the average column score (CS), sum-of-pairs
     * score (SP) and AMA similarity are written there and logged.
     *
     * <p>Ids with no reference alignment are skipped with a warning. Ids that have a reference but
     * no reconstruction are warned about and still counted in the denominator, i.e. they contribute
     * a score of 0. If no id has a reference, the averages are NaN.
     *
     * @param reconstructions reconstructed MSA per cognate id
     * @param name label used in the output file name and the report
     */
    private void evaluateMSA(Map<CognateId, MultiAlignment> reconstructions, String name) {
        PrintWriter out = IOUtils.openOutHard(Utils.safeGetExecFilePath(name + ".msaEval"));
        try {
            double numCS = 0.0;
            double numSP = 0.0;
            double nMSA = 0.0;
            double numAMA = 0.0;
            out.append("----\n");
            for (CognateId id : this.dataLoader.getCognateSet().getCognateIds()) {
                out.append("ID " + id + "\n");
                MultiAlignment gold = this.dataLoader.referenceAlignments().get(id);
                if (gold == null) {
                    // Without a reference there is nothing to score against.
                    LogInfo.warning("No reference MSA available for:" + id);
                    continue;
                }
                nMSA += 1.0;
                MultiAlignment rec = reconstructions.get(id);
                if (rec == null) {
                    LogInfo.warning("One of the MSA was not reconstructed:" + id);
                    continue;
                }
                out.append("TRUTH\n" + gold + "\n");
                out.append(name + "\n" + rec + "\n");
                double CS = gold.columnScore(rec);
                double SP = gold.sumOfPairsScore(rec);
                double AMA = MultiAlignment.amaSim(gold, rec);
                out.append("CS " + CS);
                out.append(" SP " + SP);
                out.append(" AMA " + AMA);
                numCS += CS;
                numSP += SP;
                numAMA += AMA;
                out.append("\n----\n");
            }
            String summary = "Average CS for " + name + ":" + numCS / nMSA + "\nAverage SP for " + name + ":" + numSP / nMSA + "\nAverage AMA for " + name + ":" + numAMA / nMSA + "\n";
            out.append(summary);
            LogInfo.logs(summary);
        } finally {
            out.close();
        }
    }

    /**
     * Scores word reconstructions against the held-out true reconstructions using the integrator's
     * loss.
     *
     * <p>Per-entry losses go to {@code <name>.eval}. The mean, median and variance of the losses,
     * plus the total loss divided by the total true-reconstruction length, are written there and
     * logged. Entries without a reconstruction are warned about and excluded from the statistics.
     *
     * @param reconstructions reconstructed word per cognate id
     * @param name label used in the output file name and the report
     */
    private void evaluate(Map<CognateId, String> reconstructions, String name) {
        PrintWriter out = IOUtils.openOutHard(Utils.safeGetExecFilePath(name + ".eval"));
        try {
            BayesRiskMinimizer.LossFct<String> loss = this.integrator.getLoss();
            DescriptiveStatistics stat = new DescriptiveStatistics();
            double totalTrueLen = 0.0;
            out.append("----\n");
            for (DataLoader.HeldoutEntry entry : this.dataLoader.getHeldout()) {
                out.append("ID " + entry.id + "\n");
                String rec = reconstructions.get(entry.id);
                if (rec == null) {
                    LogInfo.warning("One of the words was not reconstructed:" + entry.id);
                    continue;
                }
                out.append("TRUTH " + entry.trueReconstruction + "\n");
                out.append(name + " " + rec + "\n");
                double closs = loss.loss(rec, entry.trueReconstruction);
                stat.addValue(closs);
                totalTrueLen += (double)entry.trueReconstruction.length();
                out.append("LOSS " + closs);
                out.append("\n----\n");
            }
            String summary = "Average/median/var of losses for " + name + ": " + stat.getMean() + " / " + stat.getPercentile(50.0) + " / " + stat.getVariance();
            summary = summary + "\nAverage normalized losses for " + name + ": " + stat.getSum() / totalTrueLen;
            out.append(summary);
            LogInfo.logs(summary);
        } finally {
            out.close();
        }
    }

    /**
     * Evaluates the reference baselines on the held-out data.
     *
     * <p>The baselines are ORACLE, RSTRCT-U-BND, RANDOM and ALL-MIN-RISK. When the active loader is
     * a {@link DataLoader}, the PRJ-MIN-RISK variants are also evaluated. The *-FRNK variants are
     * added only when {@code hliOptions.useFrankDecode} is set.
     */
    private void baselines() {
        Baseline.OracleBaseline ob = new Baseline.OracleBaseline(this.integrator.getLoss());
        this.evaluate(ob.baseline(this.dataLoader), "ORACLE");
        this.evaluate(new Baseline.RestrictToModernCharsUpperBound().baseline(this.dataLoader), "RSTRCT-U-BND");
        LogInfo.logs("Languages used by Oracle:\n" + ob.toString());
        this.evaluate(new Baseline.RandomBaseline(this.randomBaselineRand).baseline(this.dataLoader), "RANDOM");
        this.evaluate(new Baseline.FlatMinRiskBaseline(this.integrator.getLoss(), false).baseline(this.dataLoader), "ALL-MIN-RISK");
        if (HLEM.hliOptions.useFrankDecode) {
            this.evaluate(new Baseline.FlatMinRiskBaseline(this.integrator.getLoss(), true).baseline(this.dataLoader), "ALL-MIN-RISK-FRNK");
        }
        if (this.dataLoader instanceof DataLoader) {
            DataLoader langLoader = (DataLoader)this.dataLoader;
            this.evaluate(new Baseline.FlatMinRiskBaseline(this.integrator.getLoss(), langLoader, false).baseline(langLoader), "PRJ-MIN-RISK");
            if (HLEM.hliOptions.useFrankDecode) {
                this.evaluate(new Baseline.FlatMinRiskBaseline(this.integrator.getLoss(), langLoader, true).baseline(langLoader), "PRJ-MIN-RISK-FRNK");
            }
        }
    }

    /**
     * Prepares the run: selects the active loader, registers every random source with the
     * scrambler and scrambles them, and copies (optionally pair-splits) the cognate set. It then
     * builds the initial proposal, the paired model, the integrator and the proposal updater.
     *
     * @throws IllegalStateException if {@link #application} is neither BALI nor LANG
     */
    private void init() {
        if (this.application == ApplicationType.BALI) {
            this._bioDataLoader.setBioCorpusOptions(this.bco);
            this.dataLoader = this._bioDataLoader;
        } else if (this.application == ApplicationType.LANG) {
            this.dataLoader = this._dataLoader;
        } else {
            throw new IllegalStateException("Unsupported application type: " + this.application);
        }
        for (Random rand : this.dataLoader.randomness()) {
            this.scrambler.addSrc(rand);
        }
        this.scrambler.addSrc(this.initParamLoader.paramGenerationRand);
        this.scrambler.addSrc(HLEM.hliOptions.combinatorialInitRandom);
        this.scrambler.addSrc(HLEM.hliOptions.treeSamplerRandom);
        this.scrambler.addSrc(this.randomBaselineRand);
        this.scrambler.scramble();
        // Work on a copy so the loader's own cognate set is left untouched.
        CognateSet cognateSet = this.dataLoader.getCognateSet().copy();
        Encodings.saveEncodingsInExec(Encodings.getGlobalEncodings(), "globalEnc");
        if (this.pairsCognateSet) {
            cognateSet = HLEM.convertToPairs(cognateSet);
            LogInfo.logs("Corpus transformed to model only pairwise evolution links");
        }
        if (this.dataLoader.hasHeldout()) {
            DataLoaderInterface.Utils.saveHeldoutToExec(this.dataLoader);
        }
        if (this.dataLoader instanceof DataLoader) {
            ((DataLoader)this.dataLoader).saveCorpusToExec();
        }
        this.featureExtractor.init(this.dataLoader, this.fatFeatureExtractor);
        this.initParamLoader.setLanguages(cognateSet.getLanguages());
        this.currentProposalParams = this.initParamLoader.getParams();
        this.model = this.getModel();
        this.integrator = new HLIntegrator(cognateSet, hliOptions, this.fatFeatureExtractor.granularities());
        if (this.dataLoader.hasReferenceAlignments()) {
            this.integrator.setReferenceAlignments(this.dataLoader.referenceAlignments());
        }
        if (HLEM.hliOptions.initWithClustalw) {
            this.integrator.setClustalwAlignments(this._bioDataLoader.getClustalwAlignments());
        }
        String propInit = this.initParamLoader.paramsPath;
        Counter initModelWeights = HLParamsUpdater.restoreCounter(propInit);
        this.paramUpdater = new HLParamsUpdater(this.currentProposalParams.enc, cognateSet.getLanguages(), this.featureExtractor, this.lbfgsOptions, this.centerReguOnInit ? initModelWeights : new Counter(), HLEM.hliOptions.numThreads);
    }

    /**
     * Builds the paired likelihood model, or returns null when {@link #usePairedModel} is false.
     *
     * <p>Initial weights are restored from {@link #fullModelInit}. If that is {@link #SAME}, the
     * proposal's init path ({@code initParamLoader.paramsPath}) is used instead.
     */
    private LikelihoodModel.FatLikelihoodModel getModel() {
        if (!this.usePairedModel) {
            return null;
        }
        String modelInit = this.fullModelInit.equals(SAME) ? this.initParamLoader.paramsPath : this.fullModelInit;
        Counter initModelWeights = HLParamsUpdater.restoreCounter(modelInit);
        return new LikelihoodModel.FatLikelihoodModel(this.currentProposalParams.enc, this.fatFeatureExtractor, this.lbfgsOptions, initModelWeights, this.centerReguOnInit ? initModelWeights : new Counter());
    }

    /**
     * Splits every cognate set into one two-language set per ordered pair (i, j), i != j, of its
     * observed languages. Unobserved languages are dropped.
     *
     * <p>New ids have the form {@code <id>-<i>-<j>}, where i and j index the observed languages in
     * the iteration order of {@code observedLanguages()}. Sets with fewer than two observed
     * languages produce no output.
     *
     * @param cs source cognate set (not modified by this method)
     * @return a new cognate set containing the pairs
     */
    public static CognateSet convertToPairs(CognateSet cs) {
        CognateSet result = new CognateSet();
        for (CognateId id : cs.getCognateIds()) {
            ArrayList<Taxon> observed = new ArrayList<Taxon>(cs.getObs(id).observedLanguages());
            int s = observed.size();
            if (s < 2) {
                continue;
            }
            // Look each observed node up once instead of twice per ordered pair.
            ArrayList<DerivationTree.DerivationNode> nodes = new ArrayList<DerivationTree.DerivationNode>(s);
            for (Taxon lang : observed) {
                nodes.add(DerivationTree.findNodeByLangName(cs.getTree(id), lang).getContents());
            }
            for (int i = 0; i < s; ++i) {
                DerivationTree.DerivationNode d1 = nodes.get(i);
                for (int j = 0; j < s; ++j) {
                    if (i == j) continue;
                    DerivationTree.DerivationNode d2 = nodes.get(j);
                    result.addCognate(new CognateId("" + id + "-" + i + "-" + j), HLEM.arbrePair(d1, d2), HLEM.obsPair(d1, d2));
                }
            }
        }
        return result;
    }

    /**
     * Builds a two-node tree from fresh copies (language and word only) of {@code d1} and
     * {@code d2}.
     *
     * <p>The copy of {@code d2} is passed first to {@code Arbre.arbre}, with the copy of {@code d1}
     * wrapped as its nested arbre, and {@code root()} of the result is returned.
     * NOTE(review): confirm which copy ends up at the root, given the "parent"/"child" variable
     * names.
     */
    public static Arbre<DerivationTree.DerivationNode> arbrePair(DerivationTree.DerivationNode d1, DerivationTree.DerivationNode d2) {
        DerivationTree.DerivationNode parent = new DerivationTree.DerivationNode(d1.getLanguage(), d1.getWord());
        DerivationTree.DerivationNode child = new DerivationTree.DerivationNode(d2.getLanguage(), d2.getWord());
        return Arbre.arbre(child, Arbre.arbre(parent)).root();
    }

    /** Returns an observation tracker marking exactly the languages of {@code d1} and {@code d2} as observed. */
    public static ObservationsTracker obsPair(DerivationTree.DerivationNode d1, DerivationTree.DerivationNode d2) {
        HashSet<Taxon> obs = new HashSet<Taxon>(Arrays.asList(d1.getLanguage(), d2.getLanguage()));
        return new ObservationsTracker(obs);
    }

    /**
     * Command-line entry point.
     *
     * <p>Configures {@code fig.exec.Execution}, wires all option holders and runs an HLEM. Passing
     * the literal argument {@code NOJARS} skips registering the jar files.
     */
    public static void main(String[] args) {
        Execution.monitor = true;
        Execution.makeThunk = false;
        Execution.create = true;
        Execution.useStandardExecPoolDirStrategy = true;
        if (!Arrays.asList(args).contains("NOJARS")) {
            // NOTE(review): hard-coded, user-specific jar locations.
            Execution.jarFiles = new ArrayList<String>(Arrays.asList("/home/eecs/bouchard/jars/pepper.jar", "/home/eecs/bouchard/jars/nuts.jar"));
        }
        Sampling.RandomSrcsScrambler scrambler = new Sampling.RandomSrcsScrambler();
        DataLoader dataLoader = new DataLoader();
        BalibaseCorpus.BalibaseCorpusOptions bco = new BalibaseCorpus.BalibaseCorpusOptions();
        BioDataLoaderAdaptor bioDataLoader = new BioDataLoaderAdaptor();
        HLParamsLoader initParamLoader = new HLParamsLoader();
        HLIntegrator.HLIOptions hliOptions = new HLIntegrator.HLIOptions();
        HLFeatureExtractor featureExtractor = new HLFeatureExtractor();
        initParamLoader.setFeatureExtractor(featureExtractor);
        dataLoader.generationParamLoader.setFeatureExtractor(featureExtractor);
        MaxentClassifier.MaxentOptions<Object> lbfgsOptions = new MaxentClassifier.MaxentOptions<Object>();
        FatFeatureExtractor fatFeatureEx = new FatFeatureExtractor(featureExtractor);
        Execution.run(args, new HLEM(scrambler, args, dataLoader, bioDataLoader, bco, initParamLoader, hliOptions, featureExtractor, fatFeatureEx, lbfgsOptions), "data", dataLoader, "biodata", bioDataLoader, "bali", bco, "artif", dataLoader.generationParamLoader, "init", initParamLoader, "hlio", hliOptions, "big", TreeSamplers.bigAncestryOptions, "small", TreeSamplers.smallAncestryOptions, "eso", TreeSamplers.edgeOptions, "mso", TreeSamplers.mixtureOptions, "fe", featureExtractor, "lo", lbfgsOptions, "enc", Encodings.class, "gap", AffineGapAlignmentSampler.class, "mr", scrambler, "ff", FatFeatureExtractor.class);
    }

    /** Data source selector for {@link HLEM#application}. */
    public static enum ApplicationType {
        /** Biological sequences through the bio loader (BAliBASE corpus options). */
        BALI,
        /** Word forms through the language {@code DataLoader}. */
        LANG;

    }
}

