/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml;

import conifer.ml.AnnotatedCharacter;
import conifer.ml.CTMCExpFam;
import conifer.ml.RateMtxExpectations;
import conifer.ml.data.EndPointDataset;
import conifer.multicategories.CategoryModel;
import fig.basic.Pair;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import nuts.math.RateMtxUtils;
import nuts.tui.Table;
import nuts.util.Counter;
import nuts.util.CounterMap;
import nuts.util.MathUtils;

/**
 * Accumulator for the statistics of a continuous-time Markov chain whose layout is defined by a
 * {@link CTMCExpFam}: per-state holding times, per-state transition counts restricted to each
 * state's support, and initial-state counts.
 *
 * <p>Statistics can be added either from fully observed paths or as expectations over
 * end-point-conditioned (marginalized) paths.
 *
 * @param <S> the state type indexed by {@code model.stateIndexer}
 */
public class ExpectedStatistics<S> {
    /** Accumulated (expected) holding time, indexed by state index. */
    final double[] holdTimes;
    /**
     * Accumulated (expected) transition counts: {@code nTrans[i][k]} accumulates transitions from
     * state index {@code i} to state index {@code model.supports[i][k]}.
     */
    final double[][] nTrans;
    /** Accumulated initial-state counts, indexed by state index. */
    final double[] nInit;
    /** Model supplying the state indexer and the per-state transition supports. */
    public final CTMCExpFam<S> model;

    /**
     * Creates all-zero statistics laid out according to {@code model}.
     *
     * @param model supplies {@code nStates}, {@code supports} and {@code stateIndexer}
     */
    public ExpectedStatistics(CTMCExpFam<S> model) {
        this.model = model;
        this.holdTimes = new double[model.nStates];
        this.nInit = new double[model.nStates];
        this.nTrans = new double[model.nStates][];
        for (int state = 0; state < model.nStates; ++state) {
            this.nTrans[state] = new double[model.supports[state].length];
        }
    }

    /**
     * Adds the initial-state counts of {@code data}, then the expected path statistics for the
     * end-point counts of every branch length reported by {@code data.branchLengths()}.
     *
     * @param data       end-point data set
     * @param rateMatrix rate matrix under which expectations are computed
     */
    public void addFromMarginalizedData(EndPointDataset<S> data, double[][] rateMatrix) {
        Counter<S> initialCounts = data.initialCounts;
        for (S state : initialCounts.keySet()) {
            this.addInitialValue(state, initialCounts.getCount(state));
        }
        // Branch lengths are boxed Doubles; unbox each before looking up its end-point counts.
        for (Object branchLength : data.branchLengths()) {
            double len = (Double) branchLength;
            this.addMarginalizedPath(data.getEndPointCounter(len), rateMatrix, len);
        }
    }

    /**
     * Adds {@code time} to the accumulated holding time of {@code state}.
     *
     * @param state state whose holding time is increased
     * @param time  amount of time to add
     */
    public void addHoldingTime(S state, double time) {
        this.holdTimes[this.model.stateIndexer.o2i(state)] += time;
    }

    /**
     * Adds {@code count} transitions from {@code state1} to {@code state2}.
     *
     * @param state1 source state
     * @param state2 destination state; must belong to the support of {@code state1}
     * @param count  number of transitions to add
     * @throws IllegalArgumentException if {@code state2} is not in the support of {@code state1}
     */
    public void addTransition(S state1, S state2, double count) {
        int index1 = this.model.stateIndexer.o2i(state1);
        int index2 = this.model.stateIndexer.o2i(state2);
        // binarySearch requires supports[index1] to be sorted ascending (assumed by this class).
        int supportIdx = Arrays.binarySearch(this.model.supports[index1], index2);
        if (supportIdx < 0) {
            throw new IllegalArgumentException("Transition not in the support: " + state1 + " -> " + state2);
        }
        this.nTrans[index1][supportIdx] += count;
    }

    /**
     * Adds {@code count} to the initial-state count of {@code state}.
     *
     * @param state initial state
     * @param count amount to add
     */
    public void addInitialValue(S state, double count) {
        this.nInit[this.model.stateIndexer.o2i(state)] += count;
    }

    /**
     * Adds the expected statistics of paths of length {@code T} under {@code rateMtx}, weighted by
     * the end-point (start, end) counts in {@code endPointCounts}.
     *
     * @param endPointCounts counts keyed by start state, then end state
     * @param rateMtx        rate matrix
     * @param T              path length (branch length)
     */
    public void addMarginalizedPath(CounterMap<S, S> endPointCounts, double[][] rateMtx, double T) {
        this._addMarginalizedPath(endPointCounts, null, rateMtx, T);
    }

    /**
     * Adds the expected statistics of paths of length {@code T} under {@code rateMtx}, weighted by
     * the dense end-point counts {@code marginalCounts[startIndex][endIndex]}.
     *
     * @param marginalCounts dense end-point counts, at least {@code rateMtx.length} square
     * @param rateMtx        rate matrix
     * @param T              path length (branch length)
     */
    public void addMarginalizedPath(double[][] marginalCounts, double[][] rateMtx, double T) {
        this._addMarginalizedPath(null, marginalCounts, rateMtx, T);
    }

    /**
     * Shared implementation of the two {@code addMarginalizedPath} overloads; exactly one of
     * {@code endPointCounts} and {@code marginalCounts} must be non-null.
     */
    private void _addMarginalizedPath(CounterMap<S, S> endPointCounts, double[][] marginalCounts, double[][] rateMtx, double T) {
        if ((endPointCounts == null) == (marginalCounts == null)) {
            throw new IllegalArgumentException(
                    "Exactly one of endPointCounts and marginalCounts must be non-null");
        }
        if (endPointCounts != null && endPointCounts.totalCount() == 0.0) {
            return;
        }
        int dim = rateMtx.length;
        // Resolve state indices once instead of once per (state1, state2) pair below.
        double[][] counts = endPointCounts != null ? this.toDenseCounts(endPointCounts, dim) : marginalCounts;
        double[][] auxMtx = new double[2 * dim][2 * dim];
        double[][] simpleExp = RateMtxUtils.marginalTransitionMtx(rateMtx, T);
        for (int state1 = 0; state1 < dim; ++state1) {
            int[] support = this.model.supports[state1];
            // The extra iteration (state2Idx == support.length) computes the holding time of state1.
            for (int state2Idx = 0; state2Idx <= support.length; ++state2Idx) {
                boolean isHoldTime = state2Idx == support.length;
                int state2 = isHoldTime ? state1 : support[state2Idx];
                // Presumably entry [x][y] is the expected statistic given start x and end y.
                double[][] expectations = RateMtxExpectations._expectations(rateMtx, T, state1, state2, simpleExp, auxMtx);
                this.accumulate(state1, state2Idx, weightedSum(counts, expectations, dim));
            }
        }
    }

    /**
     * Adds expected statistics for category {@code cat} of a category model, where
     * {@code rateMtx} and {@code marginalCounts} are indexed by observation (character) index and
     * the statistics are written to the corresponding annotated (character, category) states.
     *
     * @param marginalCounts dense end-point counts indexed by observation index
     * @param rateMtx        rate matrix over observations
     * @param T              path length (branch length)
     * @param catModel       maps observations and annotated characters to indices
     * @param cat            category whose states receive the statistics
     */
    public void addCategorySpecificMarginalizedPath(double[][] marginalCounts, double[][] rateMtx, double T, CategoryModel catModel, int cat) {
        int nObservations = rateMtx.length;
        double[][] auxMtx = new double[2 * nObservations][2 * nObservations];
        double[][] simpleExp = RateMtxUtils.marginalTransitionMtx(rateMtx, T);
        for (int observationIndex1 = 0; observationIndex1 < nObservations; ++observationIndex1) {
            char observed1 = catModel.observationsIndexer.i2o(observationIndex1).charValue();
            int state1 = catModel.indexer.o2i(new AnnotatedCharacter(observed1, cat));
            int[] support = this.model.supports[state1];
            // The extra iteration (stateIdx2 == support.length) computes the holding time of state1.
            for (int stateIdx2 = 0; stateIdx2 <= support.length; ++stateIdx2) {
                boolean isHoldTime = stateIdx2 == support.length;
                int observationIndex2;
                if (isHoldTime) {
                    observationIndex2 = observationIndex1;
                } else {
                    AnnotatedCharacter target = catModel.indexer.i2o(support[stateIdx2]);
                    observationIndex2 = catModel.observationsIndexer.o2i(Character.valueOf(target.observedChar));
                }
                double[][] expectations = RateMtxExpectations._expectations(rateMtx, T, observationIndex1, observationIndex2, simpleExp, auxMtx);
                this.accumulate(state1, stateIdx2, weightedSum(marginalCounts, expectations, nObservations));
            }
        }
    }

    /**
     * Returns the total initial-state count.
     *
     * @return sum of {@code nInit}
     */
    public double nSeries() {
        return MathUtils.sum(this.nInit);
    }

    /**
     * Returns the total accumulated holding time over all states.
     *
     * @return sum of {@code holdTimes}
     */
    public double totalTime() {
        return MathUtils.sum(this.holdTimes);
    }

    @Override
    public String toString() {
        return "holdTimes = " + Arrays.toString(this.holdTimes) + "\nnInitStates = " + Arrays.toString(this.nInit) + "\nnTransitions =\n" + Table.toString(this.nTrans);
    }

    /**
     * Adds the statistics of one fully observed path. The first state gets an initial count of
     * 1. Every step adds its holding time. Each consecutive pair of states adds one transition.
     * An empty path adds nothing.
     *
     * @param datum path as (state index, holding time) pairs, in order
     * @throws IllegalArgumentException if two consecutive states are not a supported transition
     */
    public void addInitialAndFullyObservedPathStatistics(List<Pair<Integer, Double>> datum) {
        Iterator<Pair<Integer, Double>> steps = datum.iterator();
        if (!steps.hasNext()) {
            return;
        }
        Pair<Integer, Double> step = steps.next();
        S previous = this.model.stateIndexer.i2o(step.getFirst());
        this.addInitialValue(previous, 1.0);
        this.addHoldingTime(previous, step.getSecond());
        while (steps.hasNext()) {
            step = steps.next();
            S current = this.model.stateIndexer.i2o(step.getFirst());
            this.addTransition(previous, current, 1.0);
            this.addHoldingTime(current, step.getSecond());
            previous = current;
        }
    }

    /**
     * Adds {@code value} to the holding time of {@code state1} when
     * {@code supportIdx == supports[state1].length}, otherwise to the transition count
     * {@code nTrans[state1][supportIdx]}.
     */
    private void accumulate(int state1, int supportIdx, double value) {
        if (supportIdx == this.model.supports[state1].length) {
            this.holdTimes[state1] += value;
        } else {
            this.nTrans[state1][supportIdx] += value;
        }
    }

    /** Converts end-point counts keyed by state into a dense {@code dim x dim} index matrix. */
    private double[][] toDenseCounts(CounterMap<S, S> endPointCounts, int dim) {
        double[][] dense = new double[dim][dim];
        for (S start : endPointCounts.keySet()) {
            int startIndex = this.model.stateIndexer.o2i(start);
            Counter<S> row = endPointCounts.getCounter(start);
            for (S end : row.keySet()) {
                dense[startIndex][this.model.stateIndexer.o2i(end)] += row.getCount(end);
            }
        }
        return dense;
    }

    /** Returns sum over x, y in [0, n) of counts[x][y] * expectations[x][y], skipping zero counts. */
    private static double weightedSum(double[][] counts, double[][] expectations, int n) {
        double sum = 0.0;
        for (int x = 0; x < n; ++x) {
            for (int y = 0; y < n; ++y) {
                double count = counts[x][y];
                if (count != 0.0) {
                    sum += count * expectations[x][y];
                }
            }
        }
        return sum;
    }
}

