/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml;

import conifer.ml.ExpectedStatistics;
import conifer.ml.OptimizationOptions;
import conifer.ml.extractors.BivariateFeatureExtractor;
import conifer.ml.extractors.UnivariateFeatureExtractor;
import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Set;
import nuts.math.Graph;
import nuts.math.HashGraph;
import nuts.math.RateMtxUtils;
import nuts.math.SemiGraph;
import nuts.maxent.DifferentiableFunction;
import nuts.maxent.LBFGSMinimizer;
import nuts.maxent.SparseVector;
import nuts.util.Counter;
import nuts.util.Indexer;
import nuts.util.MathUtils;

/**
 * Exponential-family parameterization of a continuous-time Markov chain (CTMC) whose states are
 * indexed by {@link #stateIndexer}.
 *
 * <p>Every state {@code i} carries a sparse "univariate" feature vector {@code g_i} used to build the
 * distribution {@code pi}. Every allowed transition {@code i -> j} (an edge of the support graph)
 * carries a sparse "bivariate" feature vector {@code f_ij}. For a weight vector {@code w},
 * {@link LearnedReversibleModel} uses:
 * <ul>
 *   <li>{@code pi = NumUtils.expNormalize(g_i . w)};
 *   <li>{@code q(i -> j) = exp(f_ij . w) * pi[j] * normalizedValue}.
 * </ul>
 * When {@code f_ij == f_ji}, which the ordered extraction in {@link #extractReversibleBivariateFeatures}
 * aims to produce, this gives {@code pi[i] q(i -> j) == pi[j] q(j -> i)}.
 *
 * <p>Typical use: construct; call {@link #extractUnivariateFeatures} and
 * {@link #extractReversibleBivariateFeatures}; then call {@link #fitReversibleModel} or
 * {@link #reversibleModelWithParameters}.
 *
 * @param <S> the state type
 */
public class CTMCExpFam<S> {
    /** Maps states to the integer indices {@code 0 .. nStates-1} used by all internal arrays. */
    public final Indexer<S> stateIndexer;
    /** {@code supports[i]} holds the indices of the states reachable from state {@code i}, sorted ascending. */
    final int[][] supports;
    /** {@code bivariateFeatures[i][k]} is the feature vector of the transition {@code i -> supports[i][k]}. */
    private final SparseVector[][] bivariateFeatures;
    /** {@code univariateFeatures[i]} is the feature vector of state {@code i} (used to build {@code pi}). */
    private final SparseVector[] univariateFeatures;
    /** Number of states, i.e. {@code stateIndexer.size()} at construction time. */
    final int nStates;
    /** Size of {@link #featuresIndexer} after the most recent feature extraction. */
    private int nFeatures;
    /** Single index shared by univariate and bivariate features; weight vectors are indexed by it. */
    public final Indexer<Object> featuresIndexer = new Indexer<Object>();
    /** Whether learned models rescale their rates by {@link LearnedReversibleModel#normalizedValue}. */
    public final boolean isNormalized;

    /**
     * Returns the number of features known after the most recent feature extraction.
     *
     * @return the dimension of weight vectors for this family
     */
    public int nFeatures() {
        return this.nFeatures;
    }

    /**
     * Creates a family whose allowed transitions are the edges of {@code support}.
     *
     * @param support graph over the states; {@code support.nbrs(s)} gives the states reachable from {@code s}
     * @param indexer index of the states; neighbours are looked up with {@code indexer.o2i}
     * @param isNormalized whether learned models rescale their rates (see {@link LearnedReversibleModel#normalizedValue})
     * @throws IllegalArgumentException if a state is its own neighbour (compared by reference)
     */
    public CTMCExpFam(Graph<S> support, Indexer<S> indexer, boolean isNormalized) {
        this.stateIndexer = indexer;
        this.nStates = indexer.size();
        this.supports = new int[this.nStates][];
        this.bivariateFeatures = new SparseVector[this.nStates][];
        this.univariateFeatures = new SparseVector[this.nStates];
        this.isNormalized = isNormalized;
        for (int state = 0; state < this.nStates; ++state) {
            S current = indexer.i2o(state);
            Set<S> neighbours = support.nbrs(current);
            int[] targets = new int[neighbours.size()];
            int k = 0;
            for (S neighbour : neighbours) {
                if (neighbour == current) {
                    throw new IllegalArgumentException("Support graph has a self-loop at state " + current);
                }
                targets[k++] = indexer.o2i(neighbour);
            }
            // Keep each support list in increasing state-index order.
            Arrays.sort(targets);
            this.supports[state] = targets;
            this.bivariateFeatures[state] = new SparseVector[targets.length];
        }
    }

    /**
     * Creates a family in which every state can jump to every other state.
     *
     * @param indexer index of the states
     * @param isNormalized whether learned models rescale their rates
     * @return a family backed by {@link #completeGraph} over {@code indexer.objects()}
     */
    public static <S> CTMCExpFam<S> createModelWithFullSupport(Indexer<S> indexer, boolean isNormalized) {
        return new CTMCExpFam<S>(CTMCExpFam.completeGraph(indexer.objects()), indexer, isNormalized);
    }

    /**
     * Builds the complete graph without self-loops over {@code states}.
     *
     * @param states the vertex set; the returned graph is built from this set, not a copy of it
     * @return a graph with an edge between every two states that are distinct by reference
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static <S> Graph<S> completeGraph(final Set<S> states) {
        SemiGraph<S> complete = new SemiGraph<S>() {

            @Override
            public boolean hasSemiEdge(S one, S two) {
                // Reference identity: every state is linked to every other state object.
                return one != two;
            }

            @Override
            public Set<S> vertexSet() {
                return states;
            }
        };
        return new HashGraph(complete);
    }

    /**
     * Extracts one univariate feature vector per state and registers new features in {@link #featuresIndexer}.
     *
     * @param univariateFeatureExtractors extractors run, in iteration order, on each state
     */
    public void extractUnivariateFeatures(Collection<UnivariateFeatureExtractor<S>> univariateFeatureExtractors) {
        this._extractFeatures(true, this.featuresIndexer, null, univariateFeatureExtractors);
    }

    /**
     * Extracts one bivariate feature vector per support edge and registers new features in
     * {@link #featuresIndexer}. Each extractor always sees the lower-indexed state first, so the
     * transitions {@code i -> j} and {@code j -> i} are presented identically.
     *
     * @param bivariateFeatureExtractors extractors run, in iteration order, on each support edge
     */
    public void extractReversibleBivariateFeatures(Collection<BivariateFeatureExtractor<S>> bivariateFeatureExtractors) {
        this._extractFeatures(false, this.featuresIndexer, bivariateFeatureExtractors, null);
    }

    /**
     * Fits weights by minimizing {@link ExpectedCompleteReversibleObjective} with L-BFGS.
     *
     * @param optimizationOptions regularization strength, iteration cap and tolerance
     * @param currentStats expected sufficient statistics the objective is built from
     * @param warmStart initial weights, or {@code null} to start from all zeros
     * @return the model for the optimized weights
     * @throws IllegalStateException if features have not been extracted
     */
    public LearnedReversibleModel fitReversibleModel(OptimizationOptions optimizationOptions, ExpectedStatistics<S> currentStats, double[] warmStart) {
        this.checkFeaturesInitialized();
        ExpectedCompleteReversibleObjective objective = this.getExpectedCompleteReversibleObjective(optimizationOptions.regularizationStrength, currentStats);
        LBFGSMinimizer minimizer = new LBFGSMinimizer(optimizationOptions.maxIterations);
        double[] start = warmStart != null ? warmStart : new double[this.nFeatures];
        double[] optimum = minimizer.minimize(objective, start, optimizationOptions.tolerance);
        return this.reversibleModelWithParameters(optimum);
    }

    /**
     * Builds the model for the given weights.
     *
     * @param w weights indexed like {@link #featuresIndexer}; the array is kept, not copied
     * @return the corresponding model
     * @throws IllegalStateException if features have not been extracted
     */
    public LearnedReversibleModel reversibleModelWithParameters(double[] w) {
        this.checkFeaturesInitialized();
        return new LearnedReversibleModel(w, this.isNormalized);
    }

    /**
     * Builds the model for weights given as a feature-to-weight counter.
     *
     * @param counter weights keyed by feature object; features not present get weight 0
     * @return the corresponding model
     */
    public LearnedReversibleModel reversibleModelWithParameters(Counter<Object> counter) {
        return this.reversibleModelWithParameters(this.convertFeatureCounter(counter));
    }

    /**
     * Converts a feature-to-weight counter to a dense weight array indexed by {@link #featuresIndexer}.
     *
     * @param counter weights keyed by feature object
     * @return an array of length {@link #nFeatures()}; entries for absent features are 0
     */
    public double[] convertFeatureCounter(Counter<Object> counter) {
        double[] w = new double[this.nFeatures];
        for (Object feature : counter.keySet()) {
            // NOTE(review): a feature unknown to featuresIndexer fails according to Indexer.o2i's behaviour; confirm.
            w[this.featuresIndexer.o2i(feature)] = counter.getCount(feature);
        }
        return w;
    }

    /**
     * Creates the (negated) penalized expected complete-data log-likelihood for the given statistics.
     *
     * @param kappa L2 regularization strength
     * @param stats expected sufficient statistics
     * @return the objective over weight vectors of dimension {@link #nFeatures()}
     * @throws IllegalStateException if features have not been extracted
     */
    public ExpectedCompleteReversibleObjective getExpectedCompleteReversibleObjective(double kappa, ExpectedStatistics<S> stats) {
        this.checkFeaturesInitialized();
        return new ExpectedCompleteReversibleObjective(kappa, stats);
    }

    /**
     * Runs the extractors over all states (univariate) or all support edges (bivariate). It stores
     * one sparse vector per state or edge and registers newly seen features in {@code featureIndexer}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void _extractFeatures(boolean statio, Indexer featureIndexer, Collection<BivariateFeatureExtractor<S>> transitionFeatureExtractors, Collection<UnivariateFeatureExtractor<S>> stationaryFeatureExtractors) {
        LogInfo.track((Object)("Extracting features for the " + (statio ? "stationary distribution" : "transitions")), true);
        Collection<?> extractors = statio ? stationaryFeatureExtractors : transitionFeatureExtractors;
        // Scratch counter: filled by one extractor call, drained into indices/values, then cleared.
        Counter counter = new Counter();
        for (int state = 0; state < this.nStates; ++state) {
            S stateObj = this.stateIndexer.i2o(state);
            // Univariate: one vector per state. Bivariate: one vector per outgoing support edge.
            int nVectors = statio ? 1 : this.supports[state].length;
            for (int state2Idx = 0; state2Idx < nVectors; ++state2Idx) {
                ArrayList<Integer> indices = new ArrayList<Integer>();
                ArrayList<Double> values = new ArrayList<Double>();
                int state2 = statio ? -1 : this.supports[state][state2Idx];
                S stateObj2 = statio ? null : this.stateIndexer.i2o(state2);
                LogInfo.track((Object)((statio ? "state" : "trans") + "(" + state + (statio ? "" : "," + state2) + ") = " + stateObj + (statio ? "" : " -> " + stateObj2)), true);
                for (Object extractor : extractors) {
                    LogInfo.track("extractor = " + extractor);
                    if (statio) {
                        ((UnivariateFeatureExtractor) extractor).extract(counter, stateObj);
                    } else {
                        // Present the lower-indexed state first so that i -> j and j -> i are
                        // described identically; this is what makes the features reversible.
                        boolean forward = state < state2;
                        S lower = forward ? stateObj : stateObj2;
                        S higher = forward ? stateObj2 : stateObj;
                        ((BivariateFeatureExtractor) extractor).extract(counter, lower, higher);
                    }
                    for (Object feature : counter.keySet()) {
                        if (!featureIndexer.containsObject(feature)) {
                            featureIndexer.addToIndex(feature);
                        }
                        int fIndex = featureIndexer.o2i(feature);
                        double value = counter.getCount(feature);
                        LogInfo.logs("feature(" + fIndex + ") = " + feature + " [value=" + value + "]");
                        indices.add(fIndex);
                        values.add(value);
                    }
                    LogInfo.end_track();
                    counter.clear();
                }
                LogInfo.logs("nFeatures = " + indices.size());
                SparseVector extracted = new SparseVector(indices, values);
                if (statio) {
                    this.univariateFeatures[state] = extracted;
                } else {
                    this.bivariateFeatures[state][state2Idx] = extracted;
                }
                LogInfo.end_track();
            }
        }
        LogInfo.end_track();
        this.nFeatures = featureIndexer.size();
    }

    /**
     * Ensures features exist before a model or objective is built.
     *
     * @throws IllegalStateException if no feature has been indexed, or if the univariate features
     *     (needed to build {@code pi} and the objective's fixed gradient) have not been extracted
     */
    private void checkFeaturesInitialized() {
        if (this.featuresIndexer.size() == 0) {
            throw new IllegalStateException("No features have been extracted yet");
        }
        for (SparseVector stateFeatures : this.univariateFeatures) {
            if (stateFeatures == null) {
                throw new IllegalStateException("Univariate features have not been extracted; call extractUnivariateFeatures first");
            }
        }
    }

    /**
     * A CTMC obtained by plugging a fixed weight vector into the feature representation.
     */
    public class LearnedReversibleModel {
        /** Feature weights, indexed like {@link CTMCExpFam#featuresIndexer}; not a defensive copy. */
        public final double[] weights;
        /** Distribution over state indices built from the univariate features (see {@link #getStationaryDistribution()}). */
        public final double[] pi;
        /**
         * Factor applied to every off-diagonal rate. If the model is normalized, it equals
         * {@code 1 / (sum_i pi[i] * sum_j q_unnorm(i -> j))}, so that {@code sum_i pi[i] * sum_j q(i -> j) == 1}.
         * Otherwise it is {@code 1.0}.
         */
        public final double normalizedValue;

        private LearnedReversibleModel(double[] w, boolean isNormalized) {
            this.weights = w;
            this.pi = this._buildPi();
            // Must run after weights and pi are set: the normalization reads both.
            this.normalizedValue = isNormalized ? this.normalization() : 1.0;
        }

        /** Computes {@code expNormalize(g_i . weights)} over all states. */
        private double[] _buildPi() {
            double[] result = new double[CTMCExpFam.this.nStates];
            for (int i = 0; i < CTMCExpFam.this.nStates; ++i) {
                result[i] = CTMCExpFam.this.univariateFeatures[i].dotProduct(this.weights);
            }
            NumUtils.expNormalize(result);
            return result;
        }

        /**
         * Returns the rates {@code exp(f_ij . weights) * pi[j]} out of {@code startState}, before scaling,
         * aligned with {@code supports[startState]}.
         */
        private double[] unnormalizedRates(int startState) {
            int[] support = CTMCExpFam.this.supports[startState];
            SparseVector[] edgeFeatures = CTMCExpFam.this.bivariateFeatures[startState];
            double[] rates = new double[support.length];
            for (int i = 0; i < support.length; ++i) {
                rates[i] = Math.exp(edgeFeatures[i].dotProduct(this.weights)) * this.pi[support[i]];
            }
            return rates;
        }

        /**
         * Computes the normalizing factor {@code 1 / sum_i pi[i] * sum_j q_unnorm(i -> j)}.
         *
         * @throws IllegalStateException if the enclosing family is not normalized
         */
        private double normalization() {
            if (!CTMCExpFam.this.isNormalized) {
                throw new IllegalStateException("normalization requested for an unnormalized CTMCExpFam");
            }
            double betaInv = 0.0;
            for (int startState = 0; startState < CTMCExpFam.this.nStates; ++startState) {
                double sumRates = 0.0;
                for (double rate : this.unnormalizedRates(startState)) {
                    sumRates += rate;
                }
                betaInv += this.pi[startState] * sumRates;
            }
            return 1.0 / betaInv;
        }

        /** Returns the (scaled) off-diagonal rates out of {@code startState}, indexed by {@code supports[startState]}. */
        private SparseVector qs(int startState) {
            double[] values = this.unnormalizedRates(startState);
            for (int i = 0; i < values.length; ++i) {
                values[i] *= this.normalizedValue;
            }
            return new SparseVector(CTMCExpFam.this.supports[startState], values);
        }

        /**
         * Returns the off-diagonal rates out of {@code source}.
         *
         * @param source a state known to {@link CTMCExpFam#stateIndexer}
         * @return destination state to rate, for every state in the support of {@code source}
         */
        public Counter<S> getRates(S source) {
            int s = CTMCExpFam.this.stateIndexer.o2i(source);
            SparseVector qs = this.qs(s);
            int[] support = CTMCExpFam.this.supports[s];
            Counter<S> result = new Counter<S>();
            for (int j = 0; j < support.length; ++j) {
                result.setCount(CTMCExpFam.this.stateIndexer.i2o(support[j]), qs.values[j]);
            }
            return result;
        }

        /**
         * Returns {@link #pi} keyed by state.
         *
         * @return state to probability, for every state
         */
        public Counter<S> getStationaryDistribution() {
            Counter<S> result = new Counter<S>();
            for (int i = 0; i < this.pi.length; ++i) {
                result.setCount(CTMCExpFam.this.stateIndexer.i2o(i), this.pi[i]);
            }
            return result;
        }

        /**
         * Returns the dense {@code nStates x nStates} rate matrix. Off-diagonal entries outside the
         * support are 0; the diagonal is filled by {@code RateMtxUtils.fillRateMatrixDiagonalEntries}.
         *
         * @return a freshly allocated matrix
         */
        public double[][] getRateMatrix() {
            double[][] result = new double[CTMCExpFam.this.nStates][CTMCExpFam.this.nStates];
            for (int s1 = 0; s1 < CTMCExpFam.this.nStates; ++s1) {
                SparseVector qs = this.qs(s1);
                int[] support = CTMCExpFam.this.supports[s1];
                for (int j = 0; j < support.length; ++j) {
                    result[s1][support[j]] = qs.values[j];
                }
            }
            RateMtxUtils.fillRateMatrixDiagonalEntries(result);
            return result;
        }

        /**
         * Returns the weights keyed by the {@code toString()} of each feature.
         *
         * @return feature name to weight, for the first {@link CTMCExpFam#nFeatures()} features
         */
        public Counter<String> getWeights() {
            Counter<String> result = new Counter<String>();
            for (int i = 0; i < CTMCExpFam.this.nFeatures; ++i) {
                result.setCount(CTMCExpFam.this.featuresIndexer.i2o(i).toString(), this.weights[i]);
            }
            return result;
        }

        /**
         * Renders {@link #getRateMatrix()} with state labels.
         *
         * @return the output of {@code RateMtxUtils.toString}
         */
        public String rateMatrixString() {
            return RateMtxUtils.toString(this.getRateMatrix(), CTMCExpFam.this.stateIndexer);
        }
    }

    /**
     * Negated, L2-penalized expected complete-data log-likelihood, as a function of the weights. It is
     * negated because {@code LBFGSMinimizer} minimizes. The value is
     * <pre>
     * -( sum_i nInit[i] log pi[i]
     *    + sum_{i,k} nTrans[i][k] log q(i -> supports[i][k])
     *    - sum_i holdTimes[i] sum_k q(i -> supports[i][k])
     *    - kappa/2 ||x||^2 )
     * </pre>
     * and {@link #derivativeAt} returns the matching negated gradient. The most recent evaluation
     * point is cached, so calling {@code valueAt} and {@code derivativeAt} at the same {@code x}
     * computes only once.
     */
    public class ExpectedCompleteReversibleObjective
    implements DifferentiableFunction {
        private final double kappa;
        private final double[] holdTimes;
        private final double[] nInit;
        private final double nInitStar;
        /** {@code nTrans[i][k]}: expected count of transitions {@code i -> supports[i][k]}. */
        private final double[][] nTrans;
        /** {@code nTransStar[j]}: expected count of transitions into state {@code j}. */
        private final double[] nTransStar;
        private final double nTransStarStar;
        /** Part of the (un-negated) gradient that does not depend on the weights. */
        private final double[] fixedDerivative;
        private double lastValue;
        private double[] lastDerivative;
        double[] lastX = null;

        private ExpectedCompleteReversibleObjective(double kappa, ExpectedStatistics<S> stats) {
            this.kappa = kappa;
            this.holdTimes = stats.holdTimes;
            this.nInit = stats.nInit;
            this.nTrans = stats.nTrans;
            this.nInitStar = stats.nSeries();
            this.nTransStar = new double[CTMCExpFam.this.nStates];
            for (int startState = 0; startState < CTMCExpFam.this.nStates; ++startState) {
                int[] support = CTMCExpFam.this.supports[startState];
                for (int k = 0; k < support.length; ++k) {
                    this.nTransStar[support[k]] += this.nTrans[startState][k];
                }
            }
            this.nTransStarStar = MathUtils.sum(this.nTransStar);
            this.fixedDerivative = this._fixedDerivative();
        }

        @Override
        public int dimension() {
            return CTMCExpFam.this.nFeatures;
        }

        @Override
        public double valueAt(double[] x) {
            this.ensureCache(x);
            return this.lastValue;
        }

        /** Returns the cached gradient array itself (not a copy). */
        @Override
        public double[] derivativeAt(double[] x) {
            this.ensureCache(x);
            return this.lastDerivative;
        }

        /** Recomputes value and gradient unless {@code x} equals the last evaluated point element-wise. */
        private void ensureCache(double[] x) {
            if (!this.requiresUpdate(this.lastX, x)) {
                return;
            }
            Pair<Double, double[]> valueAndDerivative = this.calculate(x);
            this.lastValue = valueAndDerivative.getFirst();
            this.lastDerivative = valueAndDerivative.getSecond();
            if (this.lastX == null) {
                this.lastX = new double[x.length];
            }
            System.arraycopy(x, 0, this.lastX, 0, x.length);
        }

        /**
         * Returns {@code sum_i (nInit[i] + nTransStar[i]) g_i + sum_{i,k} nTrans[i][k] f_ik}. This is the
         * weight-independent part of the gradient.
         */
        private double[] _fixedDerivative() {
            double[] result = new double[CTMCExpFam.this.nFeatures];
            for (int startState = 0; startState < CTMCExpFam.this.nStates; ++startState) {
                CTMCExpFam.this.univariateFeatures[startState].linearIncrement(this.nInit[startState] + this.nTransStar[startState], result);
                for (int endStateIdx = 0; endStateIdx < CTMCExpFam.this.supports[startState].length; ++endStateIdx) {
                    CTMCExpFam.this.bivariateFeatures[startState][endStateIdx].linearIncrement(this.nTrans[startState][endStateIdx], result);
                }
            }
            return result;
        }

        /** Evaluates the negated objective and its negated gradient at {@code x}. */
        private Pair<Double, double[]> calculate(double[] x) {
            final int nStates = CTMCExpFam.this.nStates;
            double[] gradient = this.fixedDerivative.clone();
            double value = 0.0;
            LearnedReversibleModel w = new LearnedReversibleModel(x, CTMCExpFam.this.isNormalized);
            // Rates per start state, reused by the normalized-gradient pass below.
            SparseVector[] qsByState = new SparseVector[nStates];
            double[] mStar = new double[nStates];
            double mStarStar = 0.0;
            for (int startState = 0; startState < nStates; ++startState) {
                SparseVector curPiFeatures = CTMCExpFam.this.univariateFeatures[startState];
                curPiFeatures.linearIncrement(-w.pi[startState] * (this.nInitStar + this.nTransStarStar), gradient);
                value += Math.log(w.pi[startState]) * this.nInit[startState];
                double currentHold = this.holdTimes[startState];
                SparseVector qs = w.qs(startState);
                qsByState[startState] = qs;
                double sumQs = 0.0;
                int[] curSupports = CTMCExpFam.this.supports[startState];
                for (int endStateIdx = 0; endStateIdx < curSupports.length; ++endStateIdx) {
                    double currentQ = qs.values[endStateIdx];
                    value += this.nTrans[startState][endStateIdx] * Math.log(currentQ);
                    sumQs += currentQ;
                    // Expected rate mass accumulated while holding in startState.
                    double currentM = currentHold * currentQ;
                    CTMCExpFam.this.bivariateFeatures[startState][endStateIdx].linearIncrement(-currentM, gradient);
                    mStar[curSupports[endStateIdx]] += currentM;
                    mStarStar += currentM;
                }
                value -= sumQs * currentHold;
            }
            for (int startState = 0; startState < nStates; ++startState) {
                CTMCExpFam.this.univariateFeatures[startState].linearIncrement(w.pi[startState] * mStarStar - mStar[startState], gradient);
            }
            if (CTMCExpFam.this.isNormalized) {
                // Extra gradient terms used only for the normalized parameterization.
                final double deficit = this.nTransStarStar - mStarStar;
                for (int startState = 0; startState < nStates; ++startState) {
                    SparseVector qs = qsByState[startState];
                    double sumQs = 0.0;
                    int[] curSupports = CTMCExpFam.this.supports[startState];
                    for (int endStateIdx = 0; endStateIdx < curSupports.length; ++endStateIdx) {
                        double currentQ = qs.values[endStateIdx];
                        CTMCExpFam.this.univariateFeatures[qs.indices[endStateIdx]].linearIncrement(currentQ * w.pi[startState] * deficit * -1.0, gradient);
                        CTMCExpFam.this.bivariateFeatures[startState][endStateIdx].linearIncrement(w.pi[startState] * currentQ * deficit * -1.0, gradient);
                        sumQs += currentQ;
                    }
                    CTMCExpFam.this.univariateFeatures[startState].linearIncrement(w.pi[startState] * sumQs * deficit * -1.0, gradient);
                    CTMCExpFam.this.univariateFeatures[startState].linearIncrement(2.0 * w.pi[startState] * (mStarStar - this.nTransStarStar) * -1.0, gradient);
                }
            }
            // L2 penalty, then negate everything for minimization.
            for (int f = 0; f < CTMCExpFam.this.nFeatures; ++f) {
                double curX = x[f];
                gradient[f] = -(gradient[f] - this.kappa * curX);
                value -= this.kappa * curX * curX / 2.0;
            }
            return Pair.makePair(-value, gradient);
        }

        /** Returns {@code true} if there is no cached point or any coordinate differs (via {@code ==}). */
        private boolean requiresUpdate(double[] lastX, double[] x) {
            if (lastX == null) {
                return true;
            }
            for (int i = 0; i < x.length; ++i) {
                if (lastX[i] != x[i]) {
                    return true;
                }
            }
            return false;
        }
    }
}

