/*
 * Decompiled with CFR 0.152.
 */
package pty.learn;

import Jama.Matrix;
import fig.basic.Option;
import fig.basic.Parallelizer;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.List;
import nuts.math.GMFct;
import nuts.math.RateMtxUtils;
import nuts.tui.Table;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import pty.Observations;
import pty.RootedTree;
import pty.learn.CTMCExpectations;
import pty.learn.DiscreteBP;
import pty.learn.Estimators;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.models.CTMC;

/**
 * Particle processor that accumulates, for every site, a weighted sum of the
 * matrices returned by {@link #getSufficientStatistics}, and re-estimates a
 * {@link CTMC} from those accumulated matrices.
 *
 * <p>Updates to the per-site accumulators are performed while holding the
 * monitor of the internal accumulator array, so the worker callbacks started
 * by {@link #process(RootedTree, CTMC, Observations, double)} do not interleave
 * their additions.
 *
 * <p>NOTE(review): the accumulated matrices are presumably laid out with
 * expected holding times on the diagonal and expected transition counts off
 * the diagonal (that is how {@link #getTiedReversibleMLE(double[])} reads
 * them) -- confirm against {@code CTMCExpectations.expectations}.
 */
public class LearningProcessor
implements ParticleFilter.ParticleProcessor<PartialCoalescentState> {
    /** Thread count passed to the {@link Parallelizer} that iterates over sites. */
    @Option
    public static int nThreads = 1;
    /** When true, {@link #reestimate} uses {@link #getTiedReversibleMLE(double[])}. */
    @Option
    public static boolean useReversibleMtx = false;
    /** When true, {@link #reestimate} fits a single rate matrix shared by all sites. */
    @Option
    public static boolean tieSites = true;
    /** One square accumulator per site, sized by that site's character count. */
    private final Matrix[] suffStats;

    /**
     * Computes the statistics matrix of a single site on a single tree.
     *
     * <p>For every non-root node, the branch above it contributes, for each
     * (parent state, child state) pair, the matrix
     * {@code expectations[parentState][childState]} scaled by the value that
     * {@code posterior} returns for that pair on that branch. All contributions
     * are summed into one {@code nCharacters x nCharacters} matrix.
     *
     * @param state        tree supplying the topology and the branch lengths
     * @param ctmc         model supplying the character count and rate matrix of the site
     * @param observations observations forwarded to {@link DiscreteBP}
     * @param site         index of the site to process
     * @return a freshly allocated matrix holding the summed contributions
     */
    public static Matrix getSufficientStatistics(RootedTree state, CTMC ctmc, Observations observations, int site) {
        GMFct<Taxon> posterior = DiscreteBP.posteriorMarginalTransitions(state, ctmc, observations, site);
        int nCharacters = ctmc.nCharacter(site);
        Matrix result = new Matrix(nCharacters, nCharacters);
        for (Arbre<Taxon> node : state.topology().nodes()) {
            // The root has no branch above it, hence nothing to contribute.
            if (node.isRoot()) {
                continue;
            }
            Taxon par = node.getParent().getContents();
            Taxon cur = node.getContents();
            // Branch lengths are looked up by the taxon at the lower end of the branch.
            double T = state.branchLengths().get(cur);
            double[][][][] expectations = CTMCExpectations.expectations(T, ctmc.getRateMtx(site));
            for (int topState = 0; topState < nCharacters; ++topState) {
                for (int botState = 0; botState < nCharacters; ++botState) {
                    Matrix current = new Matrix(expectations[topState][botState]);
                    double post = posterior.get(par, cur, topState, botState);
                    // times() returns a new matrix, so the expectations array is left untouched.
                    result.plusEquals(current.times(post));
                }
            }
        }
        return result;
    }

    /**
     * Creates a processor with one zero-initialised accumulator per site of
     * the given model.
     *
     * @param current model whose site count and per-site character counts
     *                determine the number and size of the accumulators
     */
    public LearningProcessor(CTMC current) {
        int nSites = current.nSites();
        this.suffStats = new Matrix[nSites];
        for (int s = 0; s < nSites; ++s) {
            int nCharacters = current.nCharacter(s);
            this.suffStats[s] = new Matrix(nCharacters, nCharacters);
        }
    }

    /**
     * Processes one weighted particle by delegating to
     * {@link #process(RootedTree, CTMC, Observations, double)} with the
     * particle's full coalescent state, CTMC and observations.
     *
     * @param pcs    particle to process
     * @param weight weight applied to the particle's statistics
     */
    @Override
    public void process(PartialCoalescentState pcs, double weight) {
        this.process(pcs.getFullCoalescentState(), pcs.getCTMC(), pcs.getObservations(), weight);
    }

    /**
     * Adds, for every site, {@code weight} times the matrix returned by
     * {@link #getSufficientStatistics} to that site's accumulator. Sites are
     * dispatched through a {@link Parallelizer} configured with
     * {@link #nThreads}.
     *
     * @param state        tree to compute statistics on
     * @param ctmc         model used to compute the statistics
     * @param observations observations used to compute the statistics
     * @param weight       multiplier applied to each site's matrix before accumulation
     */
    public void process(final RootedTree state, final CTMC ctmc, final Observations observations, final double weight) {
        List<Integer> ints = CollUtils.ints(this.suffStats.length);
        Parallelizer<Integer> par = new Parallelizer<Integer>(nThreads);
        par.setPrimaryThread();
        par.process(ints, new Parallelizer.Processor<Integer>(){

            @Override
            public void process(Integer s, int i, int n, boolean log) {
                // The per-site computation runs outside the lock; only the
                // accumulation itself is serialised.
                Matrix current = LearningProcessor.getSufficientStatistics(state, ctmc, observations, s).times(weight);
                synchronized (LearningProcessor.this.suffStats) {
                    LearningProcessor.this.suffStats[s].plusEquals(current);
                }
            }
        });
    }

    /**
     * Builds a new model from the accumulated statistics, choosing the
     * estimator from {@link #useReversibleMtx} and {@link #tieSites}.
     *
     * @param old previous model; only consulted when {@link #useReversibleMtx}
     *            is set, in which case its initial distribution for site 0 is
     *            passed to {@link #getTiedReversibleMLE(double[])}
     * @return the re-estimated model
     * @throws IllegalStateException if {@link #useReversibleMtx} is set while
     *         {@link #tieSites} is not, or while {@code old} is not site-tied
     */
    public CTMC reestimate(CTMC old) {
        if (useReversibleMtx) {
            if (!tieSites) {
                throw new IllegalStateException("useReversibleMtx requires tieSites to be enabled");
            }
            if (!old.isSiteTied()) {
                throw new IllegalStateException("useReversibleMtx requires the previous CTMC to be site-tied");
            }
            double[] sd = old.getInitialDistribution(0);
            return this.getTiedReversibleMLE(sd);
        }
        if (tieSites) {
            return this.getTiedMLE();
        }
        return this.getMLE();
    }

    /** Fits one general rate matrix per site, each from that site's own accumulator. */
    private CTMC getMLE() {
        ArrayList<double[][]> Qs = new ArrayList<double[][]>();
        for (Matrix m : this.suffStats) {
            Qs.add(Estimators.getGeneralRateMatrixMLE(m));
        }
        return new CTMC.GeneralCTMC(Qs);
    }

    /**
     * Fits a single general rate matrix from the sum of all site accumulators.
     *
     * @return a {@code CTMC.SimpleCTMC} built from that matrix and the number of sites
     * @throws IllegalStateException under the conditions of {@link #sumSuffStats()}
     */
    public CTMC getTiedMLE() {
        return new CTMC.SimpleCTMC(Estimators.getGeneralRateMatrixMLE(this.sumSuffStats()), this.suffStats.length);
    }

    /**
     * Sums the accumulators of all sites into a new matrix.
     *
     * @return the element-wise sum of every per-site accumulator
     * @throws IllegalStateException if there are no sites, or if some site's
     *         column count differs from that of site 0
     */
    public Matrix sumSuffStats() {
        int nCharacters = this.tiedCharacterCount();
        Matrix sum = new Matrix(nCharacters, nCharacters);
        for (int s = 0; s < this.suffStats.length; ++s) {
            Matrix m = this.suffStats[s];
            int columns = m.getColumnDimension();
            if (columns != nCharacters) {
                throw new IllegalStateException("Cannot sum sufficient statistics: site " + s + " has "
                        + columns + " columns but site 0 has " + nCharacters);
            }
            sum.plusEquals(m);
        }
        return sum;
    }

    /**
     * Same as {@link #getTiedReversibleMLE(double[])} with a uniform
     * distribution over the characters of site 0.
     *
     * @return the fitted model
     * @throws IllegalStateException under the conditions of {@link #sumSuffStats()}
     */
    public CTMC getTiedReversibleMLE() {
        int nChars = this.tiedCharacterCount();
        double[] statDistn = new double[nChars];
        for (int i = 0; i < statDistn.length; ++i) {
            statDistn[i] = 1.0 / (double)nChars;
        }
        return this.getTiedReversibleMLE(statDistn);
    }

    /**
     * Fits a single rate matrix from the summed accumulators using the given
     * distribution. With {@code S} the summed matrix and {@code pi} the
     * distribution, each off-diagonal rate is
     * {@code pi[c] * (S[r][c] + S[c][r]) / (S[c][c] * pi[r] + S[r][r] * pi[c])};
     * the diagonal is then filled in by
     * {@code RateMtxUtils.fillRateMatrixDiagonalEntries}.
     *
     * <p>NOTE(review): a zero denominator is not guarded against and yields a
     * non-finite rate -- confirm whether callers can reach that case.
     *
     * @param statDistn distribution indexed by character; it is read at every
     *                  index below the character count of site 0
     * @return a {@code CTMC.SimpleCTMC} built from the rate matrix and the number of sites
     * @throws IllegalStateException under the conditions of {@link #sumSuffStats()}
     */
    public CTMC getTiedReversibleMLE(double[] statDistn) {
        int nCharacters = this.tiedCharacterCount();
        double[][] rateMtx = new double[nCharacters][nCharacters];
        Matrix sum = this.sumSuffStats();
        for (int r = 0; r < nCharacters; ++r) {
            for (int c = 0; c < nCharacters; ++c) {
                // Diagonal entries are derived afterwards from the off-diagonal ones.
                if (r == c) {
                    continue;
                }
                rateMtx[r][c] = statDistn[c] * (sum.get(r, c) + sum.get(c, r)) / (sum.get(c, c) * statDistn[r] + sum.get(r, r) * statDistn[c]);
            }
        }
        RateMtxUtils.fillRateMatrixDiagonalEntries(rateMtx);
        return new CTMC.SimpleCTMC(rateMtx, this.suffStats.length);
    }

    /** Returns the column count of site 0's accumulator, failing clearly when there are no sites. */
    private int tiedCharacterCount() {
        if (this.suffStats.length == 0) {
            throw new IllegalStateException("No sites: there are no sufficient statistics to combine");
        }
        return this.suffStats[0].getColumnDimension();
    }

    /** Renders every per-site accumulator with {@code Table.toString}, one after another, each followed by a newline. */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        for (Matrix m : this.suffStats) {
            result.append(Table.toString(m)).append('\n');
        }
        return result.toString();
    }
}

