/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml;

import conifer.ml.CTMCExpFam;
import conifer.ml.FeaturizedCategoryModel;
import conifer.ml.OptimizationOptions;
import conifer.ml.data.HeldoutData;
import conifer.ml.data.PhylogeneticHeldoutDataset;
import conifer.multicategories.CategoryModel;
import conifer.multicategories.PhylogeneticFactorGraph;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import gep.util.OutputManager;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import nuts.io.IO;
import nuts.util.CollUtils;
import nuts.util.Counter;

public class TrainMultiCategoyModel
implements Runnable {
    public static PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions phyloOptions = new PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions();
    public static FeaturizedCategoryModel.FeatureOptions featureOptions = new FeaturizedCategoryModel.FeatureOptions();
    public static OptimizationOptions optimizationOptions = new OptimizationOptions();
    public FeaturizedCategoryModel model;
    @Option(gloss="Number of EM iterations to perform")
    public int nEMIteration = 100;
    @Option
    public boolean evaluatePredictiveLikelihood = false;
    @Option
    public int nCategories = 3;
    @Option
    public boolean useInvar = true;
    @Option
    public Random initRand = new Random(1L);
    @Option
    public double initMaxLength = 1.0;
    @Option
    public double initialErrorPr = 0.001;
    OutputManager output = new OutputManager();
    private PhylogeneticHeldoutDataset phyloData;
    private PhylogeneticHeldoutDataset fullPhyloData;

    public Pair<Double, double[]> gradientAndJointLogDensity(double[] currentWeights) {
        FeaturizedCategoryModel featurizedCategoryModel = this.model;
        featurizedCategoryModel.getClass();
        this.model.currentStatistics = featurizedCategoryModel.new FeaturizedCategoryModel.Statistics(this.model.expFam);
        this.model.currentParameters = new FeaturizedCategoryModel.Parameters(this.model.expFam.reversibleModelWithParameters(currentWeights), this.model.currentParameters.observationErrorProbability);
        PhylogeneticFactorGraph factorGraph = this.model.getFactorGraph(this.phyloData.rootedTree, this.phyloData.obs);
        this.model.addStatistics(factorGraph);
        CTMCExpFam.ExpectedCompleteReversibleObjective obj = this.model.expFam.getExpectedCompleteReversibleObjective(TrainMultiCategoyModel.optimizationOptions.regularizationStrength, this.model.currentStatistics.ctmcStatistics);
        return Pair.makePair(obj.valueAt(currentWeights), obj.derivativeAt(currentWeights));
    }

    public void setup() {
        this.phyloData = PhylogeneticHeldoutDataset.loadData(phyloOptions);
        this.fullPhyloData = this.evaluatePredictiveLikelihood ? PhylogeneticHeldoutDataset.loadData(phyloOptions, true) : null;
        LogInfo.logsForce("Estimate RAM required: " + (long)this.nCategories * (long)this.phyloData.rootedTree.branchLengths().size() * 4L * (long)this.phyloData.obs.nSites() * (long)this.phyloData.indexer.size() * 8L / 1024L / 1024L + " MB");
        this.model = new FeaturizedCategoryModel(new CategoryModel(this.nCategories, this.useInvar, this.phyloData.indexer), featureOptions);
        this.model.optimizationOptions = optimizationOptions;
        Counter<Object> initialWeights = new Counter<Object>();
        ArrayList<Object> features = CollUtils.list(this.model.expFam.featuresIndexer.objects());
        Collections.sort(features);
        for (Object e : features) {
            initialWeights.setCount(e, this.initMaxLength * this.initRand.nextDouble());
        }
        this.model.initParameters(initialWeights, this.initialErrorPr);
    }

    @Override
    public void run() {
        this.setup();
        for (int emIter = 0; emIter < this.nEMIteration; ++emIter) {
            PhylogeneticFactorGraph factorGraph = this.model.getFactorGraph(this.phyloData.rootedTree, this.phyloData.obs);
            this.model.addStatistics(factorGraph);
            LogInfo.track("Expected counts");
            this.model.currentStatistics.report(this.output.child("emIter", emIter));
            LogInfo.end_track();
            double logPX = factorGraph.getSumProductPosteriorCalculator().logZ();
            if (this.evaluatePredictiveLikelihood) {
                double logPXH = this.model.getFactorGraph(this.phyloData.rootedTree, this.fullPhyloData.obs).getSumProductPosteriorCalculator().logZ();
                this.output.printWrite("logpredlikelihood", "emIter", emIter, "logpredlikelihood", logPXH - logPX);
            }
            this.output.printWrite("loglikelihood", "emIter", emIter, "loglikelihood", logPX);
            HeldoutData.EvalResult eval = this.phyloData.heldOut.evaluate(factorGraph);
            eval.report(this.output.child("emIter", emIter));
            if (emIter == this.nEMIteration - 1) continue;
            this.model.reestimateParametersAndFlushStats();
        }
    }

    public static void main(String[] args) {
        IO.run(args, new TrainMultiCategoyModel(), "phyloData", phyloOptions, "featureOptions", featureOptions, "optimizationOptions", optimizationOptions);
    }
}

