/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml;

import conifer.ml.AnnotatedCharacter;
import conifer.ml.CTMCExpFam;
import conifer.ml.ExpectedStatistics;
import conifer.ml.OptimizationOptions;
import conifer.ml.extractors.CategoryCollapsedFeatureExtractor;
import conifer.ml.extractors.IdentityExtractor;
import conifer.ml.extractors.RateFeatureExtractor;
import conifer.multicategories.CategoryModel;
import conifer.multicategories.PhylogeneticFactorGraph;
import fenchel.factor.multisitecat.MSCUnaryScaledFactor;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.exec.Execution;
import gep.util.OutputManager;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import nuts.io.IO;
import nuts.util.Arbre;
import nuts.util.Counter;
import nuts.util.CounterUtils;
import nuts.util.MathUtils;
import pty.Observations;
import pty.RootedTree;

/**
 * Category model whose CTMC parameters are represented by a {@link CTMCExpFam} over
 * {@link AnnotatedCharacter}s, with feature extractors selected by {@link FeatureOptions}.
 *
 * <p>Expected statistics are accumulated from posterior factor graphs via
 * {@link #addStatistics(PhylogeneticFactorGraph)} and turned into new parameters via
 * {@link #reestimateParametersAndFlushStats()}, which also resets the accumulator. When
 * {@link #saveToDiskOnUpdates} is set, every parameter update is written to the execution
 * directory.
 */
public class FeaturizedCategoryModel {
    /** Category model used for factor-graph construction, indexing and category marginalization. */
    public final CategoryModel categoryModel;
    /**
     * CTMC exponential-family object used to fit and instantiate reversible models.
     *
     * <p>NOTE(review): the constructor assigns {@code null} here (as in the decompiled original);
     * the real construction appears to have been lost — confirm against the upstream source.
     */
    public final CTMCExpFam<AnnotatedCharacter> expFam;
    /** Options forwarded to {@code expFam.fitReversibleModel} during re-estimation. */
    public OptimizationOptions optimizationOptions = new OptimizationOptions();
    /** Whether parameter updates are saved to disk; cleared automatically after a failed save. */
    public boolean saveToDiskOnUpdates = true;
    /** Most recently initialized or re-estimated parameters. */
    Parameters currentParameters;
    /** Statistics accumulated since the last re-estimation. */
    Statistics currentStatistics;
    /** Number of automatic saves so far; used to give each saved snapshot a distinct name. */
    private int nSaves = 0;

    /**
     * Creates the model and registers the feature extractors requested by {@code featureOptions}.
     *
     * <p>NOTE(review): {@code expFam} is assigned {@code null}, so construction currently fails
     * with a {@link NullPointerException} (at the latest when {@link #initFeatures} dereferences
     * {@code expFam}). The exponential family must be constructed here before this class is
     * usable.
     *
     * @param categoryModel the underlying category model
     * @param featureOptions which feature families to extract
     */
    public FeaturizedCategoryModel(CategoryModel categoryModel, FeatureOptions featureOptions) {
        this.categoryModel = categoryModel;
        this.expFam = null;
        this.currentStatistics = new Statistics(this.expFam);
        this.initFeatures(featureOptions);
    }

    /**
     * Sets the current parameters from explicit feature weights and an observation error
     * probability, then saves them if {@link #saveToDiskOnUpdates} is enabled.
     *
     * @param initial feature weights passed to {@code expFam.reversibleModelWithParameters}
     * @param errorPr observation error probability
     */
    public void initParameters(Counter<Object> initial, double errorPr) {
        this.currentParameters = new Parameters(this.expFam.reversibleModelWithParameters(initial), errorPr);
        this.saveToExec();
    }

    /**
     * Registers univariate and reversible bivariate feature extractors on {@link #expFam}
     * according to {@code featureOptions}.
     */
    private void initFeatures(FeatureOptions featureOptions) {
        boolean firstCategoryInvariant = this.categoryModel.isFirstCategoryInvariant;

        ArrayList univariateFeatures = new ArrayList();
        if (featureOptions.fineUnivariateFeatures) {
            univariateFeatures.add(new IdentityExtractor());
        }
        if (featureOptions.categoryCollapsedFeatures) {
            univariateFeatures.add(new CategoryCollapsedFeatureExtractor(firstCategoryInvariant));
        }
        this.expFam.extractUnivariateFeatures(univariateFeatures);

        ArrayList bivariateFeatures = new ArrayList();
        if (featureOptions.fineBivariateFeatures) {
            bivariateFeatures.add(new IdentityExtractor());
        }
        if (featureOptions.rateFeatures) {
            bivariateFeatures.add(new RateFeatureExtractor(firstCategoryInvariant));
        }
        if (featureOptions.categoryCollapsedFeatures) {
            bivariateFeatures.add(new CategoryCollapsedFeatureExtractor(firstCategoryInvariant));
        }
        this.expFam.extractReversibleBivariateFeatures(bivariateFeatures);
    }

    /**
     * Builds the factor graph for a tree and its observations using the current parameters.
     */
    public PhylogeneticFactorGraph getFactorGraph(RootedTree rt, Observations observations) {
        return this.categoryModel.getFactorGraph(this.currentParameters, rt, observations);
    }

    /**
     * Re-fits the parameters from the accumulated statistics, warm-starting from the current
     * weights. It then saves the result (if enabled) and resets the accumulator.
     */
    public void reestimateParametersAndFlushStats() {
        this.currentParameters = this.reestimateParameters(this.currentStatistics, this.currentParameters.ctmcParameters.weights);
        this.saveToExec();
        this.currentStatistics = new Statistics(this.expFam);
    }

    /**
     * Fits new parameters from the given statistics.
     *
     * <p>The observation error probability is the ratio of expected error count to observation
     * count. If no observations were accumulated this ratio is {@code 0/0}, i.e. {@code NaN}.
     *
     * @param stats accumulated statistics
     * @param warmStart initial weights for the optimizer
     * @return the re-estimated parameters
     */
    public Parameters reestimateParameters(Statistics stats, double[] warmStart) {
        double errorPr = stats.observationErrorCount / stats.observationCount;
        return new Parameters(this.expFam.fitReversibleModel(this.optimizationOptions, stats.ctmcStatistics, warmStart), errorPr);
    }

    /**
     * Adds the expected statistics of {@code posteriorCalculator} to {@link #currentStatistics}.
     */
    public void addStatistics(PhylogeneticFactorGraph posteriorCalculator) {
        this.addStatistics(posteriorCalculator, this.currentStatistics);
    }

    /**
     * Adds the expected statistics of one posterior factor graph to {@code currentStatistics}.
     *
     * <p>This adds the observation and error counts, the root posterior as initial-state
     * statistics, and, for every non-root node, the category-specific marginalized path
     * statistics of the branch to its parent.
     *
     * @param posteriorCalculator factor graph whose posteriors have been computed
     * @param currentStatistics accumulator to update
     * @throws IllegalStateException if the accumulated expected error count would exceed the
     *     accumulated observation count; the counts are left unchanged in that case
     */
    public void addStatistics(PhylogeneticFactorGraph posteriorCalculator, Statistics currentStatistics) {
        // Widen before multiplying so that large alignments cannot overflow int arithmetic.
        double addedObservations = (double) posteriorCalculator.nSites * posteriorCalculator.observedTaxa().size();
        double updatedObservationCount = currentStatistics.observationCount + addedObservations;
        double updatedErrorCount = currentStatistics.observationErrorCount + posteriorCalculator.observationErrorCount();
        // Validate before committing so a failure does not leave the accumulator half-updated.
        if (updatedErrorCount > updatedObservationCount) {
            throw new IllegalStateException("Expected observation error count (" + updatedErrorCount
                    + ") exceeds observation count (" + updatedObservationCount + ")");
        }
        currentStatistics.observationCount = updatedObservationCount;
        currentStatistics.observationErrorCount = updatedErrorCount;

        Taxon rootTaxon = posteriorCalculator.rootedTree.topology().getContents();
        MSCUnaryScaledFactor rootFactor = posteriorCalculator.getNodePosterior(rootTaxon);
        for (int site = 0; site < posteriorCalculator.nSites; ++site) {
            for (int category = 0; category < posteriorCalculator.nCategories; ++category) {
                for (int observedCharacter = 0; observedCharacter < posteriorCalculator.nCharacters; ++observedCharacter) {
                    AnnotatedCharacter currentAnn = new AnnotatedCharacter(this.categoryModel.observationsIndexer.i2o(observedCharacter).charValue(), category);
                    currentStatistics.ctmcStatistics.addInitialValue(currentAnn, rootFactor.get(site, category, observedCharacter));
                }
            }
        }

        LogInfo.track("Adding statistics for " + posteriorCalculator.rootedTree.branchLengths().size() + " branches");
        try {
            int taxonIndex = 0;
            for (Arbre<Taxon> subtree : posteriorCalculator.rootedTree.topology().nodes()) {
                if (subtree.isRoot()) {
                    continue;
                }
                LogInfo.logs("Taxon " + taxonIndex++);
                Taxon currentTaxon = subtree.getContents();
                Taxon parentTaxon = subtree.getParent().getContents();
                double[][][] marginal = posteriorCalculator.getPairwisePosterior(parentTaxon, currentTaxon);
                double T = posteriorCalculator.rootedTree.branchLengths().get(currentTaxon);
                for (int cat = 0; cat < this.categoryModel.nCategories; ++cat) {
                    currentStatistics.ctmcStatistics.addCategorySpecificMarginalizedPath(marginal[cat], posteriorCalculator.potentials.rateMatrices[cat], T, this.categoryModel, cat);
                }
            }
        } finally {
            // Always close the track, even if a branch fails, so log nesting stays balanced.
            LogInfo.end_track();
        }
    }

    /**
     * Saves the current parameters under a fresh, numbered name when
     * {@link #saveToDiskOnUpdates} is set. This is best-effort: on failure the error is logged
     * and further automatic saves are disabled.
     */
    private void saveToExec() {
        if (this.saveToDiskOnUpdates) {
            try {
                this.saveToExec("FeaturizedCategoryModel-" + this.nSaves++);
            }
            catch (Exception e) {
                LogInfo.logs("Failed to save FeaturizedCategoryModel; disabling saveToDiskOnUpdates: " + e);
                this.saveToDiskOnUpdates = false;
            }
        }
    }

    /**
     * Saves the current parameters to {@code Execution.getFile(str)}.
     *
     * @param str name of the output directory inside the execution directory
     */
    public void saveToExec(String str) {
        FeaturizedCategoryModel.save(this.currentParameters, new File(Execution.getFile(str)));
    }

    /**
     * Logs the parameters and writes them into directory {@code file}. The files written are
     * weights, observation error probability, rate matrix and stationary distribution.
     *
     * @param p parameters to save
     * @param file output directory; created (with any missing parents) if absent
     */
    public static void save(Parameters p, File file) {
        LogInfo.logs("Current matrix:\n" + p.ctmcParameters.rateMatrixString());
        LogInfo.logs("observationErrorProbability: " + p.observationErrorProbability);
        LogInfo.logs("Saving FeaturizedCategoryModel to " + file.getAbsolutePath());
        file.mkdirs();
        File weights = new File(file, "weights.txt");
        CounterUtils.saveStringCounter(p.ctmcParameters.getWeights(), weights);
        File errorPrFile = new File(file, "observationErrorProbability.txt");
        IO.writeToDisk(errorPrFile, "" + p.observationErrorProbability);
        File rateMtxFile = new File(file, "rateMatrix.txt");
        IO.writeToDisk(rateMtxFile, p.ctmcParameters.getRateMatrix());
        File stationaryDistFile = new File(file, "stationaryDist.txt");
        IO.writeToDisk(stationaryDistFile, p.ctmcParameters.pi);
    }

    /**
     * Accumulated expected statistics. It holds the CTMC sufficient statistics plus the
     * observation and expected error counts used to estimate the observation error probability.
     */
    public class Statistics {
        /** CTMC expected statistics (initial-state and transition counts). */
        public final ExpectedStatistics<AnnotatedCharacter> ctmcStatistics;
        /** Accumulated expected number of erroneous observations. */
        public double observationErrorCount = 0.0;
        /** Accumulated number of observations (sites times observed taxa). */
        public double observationCount = 0.0;

        public Statistics(CTMCExpFam<AnnotatedCharacter> expFam) {
            this.ctmcStatistics = new ExpectedStatistics<AnnotatedCharacter>(expFam);
        }

        /** Returns the initial-state counts marginalized by category. */
        public double[] getCategoryCounts() {
            return FeaturizedCategoryModel.this.categoryModel.marginalizeByCategory(this.ctmcStatistics.nInit);
        }

        /** Returns, per category, the summed transition counts out of its states. */
        public double[] getCategoryNTransitions() {
            double[] transitions = new double[this.ctmcStatistics.model.nStates];
            for (int i = 0; i < transitions.length; ++i) {
                transitions[i] = MathUtils.sum(this.ctmcStatistics.nTrans[i]);
            }
            return FeaturizedCategoryModel.this.categoryModel.marginalizeByCategory(transitions);
        }

        /** Writes per-category counts, transitions and observation/error counts to {@code output}. */
        public void report(OutputManager output) {
            double[] categoryCounts = this.getCategoryCounts();
            double[] transitions = this.getCategoryNTransitions();
            for (int i = 0; i < categoryCounts.length; ++i) {
                output.printWrite("categoryCounts", "category", i, "count", categoryCounts[i]);
                output.printWrite("getCategoryNTransitions", "category", i, "count", transitions[i]);
            }
            output.printWrite("observationErrorCount", "observationErrorCount", this.observationErrorCount, "observationCount", this.observationCount);
        }
    }

    /** Immutable pair of learned CTMC parameters and observation error probability. */
    public static class Parameters {
        public final CTMCExpFam.LearnedReversibleModel ctmcParameters;
        public final double observationErrorProbability;

        public Parameters(CTMCExpFam.LearnedReversibleModel ctmcParameters, double observationErrorProbability) {
            this.ctmcParameters = ctmcParameters;
            this.observationErrorProbability = observationErrorProbability;
        }
    }

    /** Command-line selectable feature families; all are enabled by default. */
    public static class FeatureOptions {
        @Option
        public boolean categoryCollapsedFeatures = true;
        @Option
        public boolean rateFeatures = true;
        @Option
        public boolean fineBivariateFeatures = true;
        @Option
        public boolean fineUnivariateFeatures = true;
    }
}

