/*
 * Decompiled with CFR 0.152.
 */
package conifer.multicategories;

import conifer.ml.AnnotatedCharacter;
import conifer.ml.FeaturizedCategoryModel;
import conifer.multicategories.PhylogeneticFactorGraph;
import conifer.multicategories.PhylogenyPotentials;
import nuts.math.MutableGraph;
import nuts.math.RateMtxUtils;
import nuts.util.Counter;
import nuts.util.Indexer;
import pty.Observations;
import pty.RootedTree;

/**
 * Describes a state space made of (observed character, rate category) pairs.
 *
 * <p>Every observed character from {@link #observationsIndexer} is paired with every category in
 * {@code [0, nCategories)}. Each pair is registered as an {@link AnnotatedCharacter} in
 * {@link #indexer}, in category-major order (all characters of category 0, then category 1, ...).
 * The class also records a {@link #reacheabilityGraph} that connects every pair of distinct
 * observed characters within the same category. Category 0 gets no edges when
 * {@link #isFirstCategoryInvariant} is set.
 */
public class CategoryModel {
    /** Number of rate categories. */
    public final int nCategories;
    /** When true, category 0 receives no edges in {@link #reacheabilityGraph}. */
    public final boolean isFirstCategoryInvariant;
    /** Index over all (observed character, category) pairs, populated in category-major order. */
    public final Indexer<AnnotatedCharacter> indexer = new Indexer<AnnotatedCharacter>();
    /** Index over the observed characters, shared across all categories. */
    public final Indexer<Character> observationsIndexer;
    /**
     * Directed edges between distinct observed characters of the same category. No edges are
     * added for category 0 when {@link #isFirstCategoryInvariant} is set.
     */
    public final MutableGraph<AnnotatedCharacter> reacheabilityGraph = new MutableGraph<AnnotatedCharacter>();

    /**
     * Builds the annotated-character index and the reachability graph.
     *
     * @param nCategories number of rate categories
     * @param isFirstCategoryInvariant whether category 0 should get no transitions in the graph
     * @param observationsIndexer index over the observed characters
     */
    public CategoryModel(int nCategories, boolean isFirstCategoryInvariant, Indexer<Character> observationsIndexer) {
        this.nCategories = nCategories;
        this.isFirstCategoryInvariant = isFirstCategoryInvariant;
        this.observationsIndexer = observationsIndexer;
        this.initIndexer();
        this.initGraph();
    }

    /**
     * Returns the position of {@code ann}'s observed character in {@link #observationsIndexer}.
     * The category of {@code ann} is ignored.
     *
     * @param ann annotated character whose observed character is looked up
     * @return index of {@code ann.observedChar} in the observations indexer
     */
    public int observationIndex(AnnotatedCharacter ann) {
        return this.observationsIndexer.o2i(Character.valueOf(ann.observedChar));
    }

    /**
     * Returns the number of distinct observed characters.
     *
     * @return size of {@link #observationsIndexer}
     */
    public int nObservedCharacters() {
        return this.observationsIndexer.size();
    }

    /** Registers one annotated character per (category, observed character), category-major. */
    private void initIndexer() {
        int nObserved = this.observationsIndexer.size();
        for (int c = 0; c < this.nCategories; ++c) {
            for (int oIdx = 0; oIdx < nObserved; ++oIdx) {
                char observed = this.observationsIndexer.i2o(oIdx).charValue();
                // addToIndex appears to be varargs (per the decompiled call site); pass an explicit array.
                this.indexer.addToIndex(new AnnotatedCharacter[]{new AnnotatedCharacter(observed, c)});
            }
        }
    }

    /** Adds an edge x -> y for every pair of distinct observed characters within each non-invariant category. */
    private void initGraph() {
        int firstVariableCategory = this.isFirstCategoryInvariant ? 1 : 0;
        for (int c = firstVariableCategory; c < this.nCategories; ++c) {
            for (char x : this.observationsIndexer.objects()) {
                for (char y : this.observationsIndexer.objects()) {
                    if (x != y) {
                        this.reacheabilityGraph.addEdge(new AnnotatedCharacter(x, c), new AnnotatedCharacter(y, c));
                    }
                }
            }
        }
    }

    /**
     * Assembles a {@link PhylogeneticFactorGraph} from the given CTMC parameters.
     *
     * <ul>
     *   <li>Category priors: for each category, the sum of {@code pi} over that category's
     *       annotated characters.
     *   <li>Rate matrices: one {@code nObserved x nObserved} matrix per category. Off-diagonal
     *       entries are filled from {@code getRates}, and the diagonals are filled by
     *       {@link RateMtxUtils#fillRateMatrixDiagonalEntries}.
     *   <li>Stationary vectors: for each category, the raw joint {@code pi} entries of its
     *       annotated characters. They are not renormalized within the category.
     * </ul>
     *
     * @param parameters model parameters; {@code ctmcParameters.pi} must have one entry per
     *     annotated character in {@link #indexer}
     * @param rt tree on which the factor graph is built
     * @param observations observed data attached to the factor graph
     * @return the assembled factor graph
     * @throws IllegalArgumentException if {@code pi.length} differs from {@code indexer.size()}
     */
    public PhylogeneticFactorGraph getFactorGraph(FeaturizedCategoryModel.Parameters parameters, RootedTree rt, Observations observations) {
        double[] fullPi = parameters.ctmcParameters.pi;
        double[] categoryPriors = this.sumByCategory(fullPi);
        double[][][] rateMatrices = this.buildRateMatrices(parameters);
        double[][] stationaryMatrices = this.buildStationaryMatrices(fullPi);
        PhylogenyPotentials potentials = new PhylogenyPotentials(categoryPriors, rateMatrices, stationaryMatrices, parameters.observationErrorProbability);
        return new PhylogeneticFactorGraph(rt, potentials, observations);
    }

    /** Builds one rate matrix per category from {@code getRates}, then fills each diagonal. */
    private double[][][] buildRateMatrices(FeaturizedCategoryModel.Parameters parameters) {
        int nObserved = this.nObservedCharacters();
        double[][][] rateMatrices = new double[this.nCategories][nObserved][nObserved];
        for (int c = 0; c < this.nCategories; ++c) {
            double[][] currentMatrix = rateMatrices[c];
            for (char sourceChar : this.observationsIndexer.objects()) {
                AnnotatedCharacter sourceAnn = new AnnotatedCharacter(sourceChar, c);
                int sourceIndex = this.observationIndex(sourceAnn);
                Counter<AnnotatedCharacter> rates = parameters.ctmcParameters.getRates(sourceAnn);
                for (AnnotatedCharacter destAnn : rates.keySet()) {
                    // NOTE(review): observationIndex ignores the category, so a destination in another
                    // category would land in category c's matrix. This assumes getRates only returns
                    // same-category destinations (as reacheabilityGraph suggests) -- confirm.
                    currentMatrix[sourceIndex][this.observationIndex(destAnn)] = rates.getCount(destAnn);
                }
            }
            RateMtxUtils.fillRateMatrixDiagonalEntries(currentMatrix);
        }
        return rateMatrices;
    }

    /** Copies each annotated character's joint {@code pi} entry into a per-category vector. */
    private double[][] buildStationaryMatrices(double[] fullPi) {
        int nObserved = this.nObservedCharacters();
        double[][] stationaryMatrices = new double[this.nCategories][nObserved];
        for (int cat = 0; cat < this.nCategories; ++cat) {
            for (int charIndex = 0; charIndex < nObserved; ++charIndex) {
                AnnotatedCharacter ann = new AnnotatedCharacter(this.observationsIndexer.i2o(charIndex).charValue(), cat);
                stationaryMatrices[cat][charIndex] = fullPi[this.indexer.o2i(ann)];
            }
        }
        return stationaryMatrices;
    }

    /**
     * Sums per-annotated-character values into per-category totals.
     *
     * @param counts one value per entry of {@link #indexer}, in indexer order
     * @return array of length {@code nCategories} holding the sum of {@code counts} for each category
     * @throws IllegalArgumentException if {@code counts.length != indexer.size()}
     */
    public double[] marginalizeByCategory(double[] counts) {
        return this.sumByCategory(counts);
    }

    /**
     * Length-checked per-category summation. This private helper is shared by
     * {@link #getFactorGraph} and {@link #marginalizeByCategory}, so a subclass override of the
     * public method cannot change how priors are computed.
     */
    private double[] sumByCategory(double[] counts) {
        if (counts.length != this.indexer.size()) {
            throw new IllegalArgumentException("Expected " + this.indexer.size()
                    + " values (one per annotated character), got " + counts.length);
        }
        double[] result = new double[this.nCategories];
        for (int i = 0; i < counts.length; ++i) {
            result[this.indexer.i2o(i).category] += counts[i];
        }
        return result;
    }
}

