/*
 * Decompiled with CFR 0.152.
 */
package ev.poi;

import ev.poi.IntegratedLengthMarginalComputations;
import ev.poi.PoissonParameters;
import ev.poi.PoissonUtils;
import fig.basic.Parallelizer;
import goblin.CognateId;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import ma.BalibaseCorpus;
import ma.MSAPoset;
import ma.MultiAlignment;
import ma.RateMatrixLoader;
import nuts.math.GMFctUtils;
import nuts.math.HashGraph;
import nuts.math.RateMtxUtils;
import nuts.math.TabularGMFct;
import nuts.math.TreeSumProd;
import nuts.maxent.SloppyMath;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.RootedTree;

/**
 * Computes the marginal log-likelihood of a multiple sequence alignment ({@link MSAPoset})
 * for a fixed {@link RootedTree} and a fixed set of {@link PoissonParameters}.
 *
 * <p>The total is the sum of two parts: a per-column term (see
 * {@link #_subDel_marginalLogLikelihood(MSAPoset)}) and a term that depends only on the
 * number of columns (see {@link #_insert_marginalLogLikelihood(int)}).
 *
 * <p><b>Thread-safety:</b> {@link #columnLogLikelihood(Map)} writes the leaf observations into
 * the shared graphical model held by this instance before running the sum-product pass, and the
 * empty-column likelihood is cached lazily without synchronization. A single instance must
 * therefore not be used from several threads at once; create one calculator per thread
 * (as {@link #main(String[])} does).
 */
public final class MSAMarginalLikelihoodCalculator {
    /** Tree (topology and branch lengths) the likelihoods are computed on. */
    public final RootedTree rootedTree;
    /** Model parameters (rate matrix, insertion rate, character indexer, ...). */
    public final PoissonParameters params;
    /** Log of the weight attached to each node of the tree (non-root nodes and the root). */
    private final Map<Taxon, Double> branchLogWeights;
    /** Factor graph over the tree; its leaf potentials are overwritten for every column. */
    private final TabularGMFct<Taxon> graphicalModel;
    /** Map built by {@code Arbre.descMap}; used here to test which subtrees cover a column. */
    private final Map<Arbre<Taxon>, Set<Taxon>> descMap;
    /** Lazily computed log-likelihood of the empty column; NaN means "not computed yet". */
    private double _emptyColLL = Double.NaN;
    /** Immutable column with no observed taxon. */
    private static final Map<Taxon, Integer> EMPTY_COLUMN = Collections.<Taxon, Integer>emptyMap();

    /**
     * Builds a calculator for the given parameters and tree.
     *
     * @param params model parameters
     * @param rootedTree tree whose branch lengths must all be strictly positive
     * @throws IllegalArgumentException if a branch length is zero, negative or NaN
     * @throws IllegalStateException if the computed branch weights do not sum to a value
     *         strictly between 0 and 1
     */
    public MSAMarginalLikelihoodCalculator(PoissonParameters params, RootedTree rootedTree) {
        for (Taxon t : rootedTree.branchLengths().keySet()) {
            double length = rootedTree.branchLengths().get(t);
            // Written as !(length > 0) so that NaN is rejected as well as non-positive values.
            if (!(length > 0.0)) {
                throw new IllegalArgumentException("Branch lengths should be positive" + ": " + t + " has length " + length);
            }
        }
        this.params = params;
        this.rootedTree = rootedTree;
        this.descMap = Arbre.descMap(rootedTree.topology());
        this.branchLogWeights = this.branchLogWeights();
        this.graphicalModel = this.initGraphicalModel();
    }

    /**
     * Copy constructor: builds a fresh calculator with the same parameters and tree.
     *
     * @param model calculator to copy the parameters and tree from
     */
    public MSAMarginalLikelihoodCalculator(MSAMarginalLikelihoodCalculator model) {
        this(model.params, model.rootedTree);
    }

    /**
     * Returns a new calculator that keeps {@code model}'s parameters but uses {@code newTree}.
     *
     * @param model calculator providing the parameters
     * @param newTree tree for the new calculator
     * @return the new calculator
     */
    public static MSAMarginalLikelihoodCalculator copyWithNewTree(MSAMarginalLikelihoodCalculator model, RootedTree newTree) {
        return new MSAMarginalLikelihoodCalculator(model.params, newTree);
    }

    /**
     * Returns a new calculator that keeps {@code model}'s tree but uses {@code newParams}.
     *
     * @param model calculator providing the tree
     * @param newParams parameters for the new calculator
     * @return the new calculator
     */
    public static MSAMarginalLikelihoodCalculator copyWithNewParams(MSAMarginalLikelihoodCalculator model, PoissonParameters newParams) {
        return new MSAMarginalLikelihoodCalculator(newParams, model.rootedTree);
    }

    /**
     * Marginal log-likelihood of an alignment: the per-column term plus the column-count term.
     *
     * @param msa alignment over exactly the leaves of the tree
     * @return the log-likelihood
     */
    public double marginalLogLikelihood(MSAPoset msa) {
        return this._subDel_marginalLogLikelihood(msa) + this._insert_marginalLogLikelihood(msa);
    }

    /**
     * Sum of {@link #columnLogLikelihood(Map)} over all columns of the alignment.
     *
     * @param msa alignment whose sequence keys must equal the set of leaves of the tree
     * @return the summed column log-likelihoods
     * @throws IllegalArgumentException if the taxa of {@code msa} differ from the tree leaves
     */
    public double _subDel_marginalLogLikelihood(MSAPoset msa) {
        if (!msa.sequences().keySet().equals(CollUtils.set(this.rootedTree.topology().leaveContents()))) {
            throw new IllegalArgumentException("Symm diff=" + CollUtils.symmetricDifference(msa.sequences().keySet(), CollUtils.set(this.rootedTree.topology().leaveContents())));
        }
        double total = 0.0;
        for (MSAPoset.Column column : msa.columns()) {
            total += this.columnLogLikelihood(this.convertToIndices(msa, column));
        }
        return total;
    }

    /**
     * Column-count term: {@code PoissonUtils.logPhi} evaluated at the empty-column
     * log-likelihood, the number of columns and the insertion rate.
     *
     * @param nColumns number of columns of the alignment
     * @return the term, which is never positive
     * @throws IllegalStateException if the computed value is positive
     */
    public double _insert_marginalLogLikelihood(int nColumns) {
        double result = PoissonUtils.logPhi(this.emptyColLL(), nColumns, this.params.insertRate);
        if (result > 0.0) {
            throw new IllegalStateException("Positive insertion log-likelihood " + result + " for " + nColumns + " columns");
        }
        return result;
    }

    /**
     * Column-count term for the number of columns of {@code msa}.
     *
     * @param msa alignment
     * @return see {@link #_insert_marginalLogLikelihood(int)}
     */
    public double _insert_marginalLogLikelihood(MSAPoset msa) {
        return this._insert_marginalLogLikelihood(msa.columns().size());
    }

    /**
     * Log-likelihood ratio associated with merging two columns of {@code originalMSA}.
     *
     * <p>The column part is {@code LL(merged) - LL(toMerge1) - LL(toMerge2)}.
     *
     * @param originalMSA alignment containing both columns
     * @param toMerge1 first column
     * @param toMerge2 second column; must not share a taxon with {@code toMerge1}
     * @return the log ratio
     * @throws IllegalArgumentException if the two columns share a taxon
     */
    public double mergerLogLikelihoodRatio(MSAPoset originalMSA, MSAPoset.Column toMerge1, MSAPoset.Column toMerge2) {
        int originalSize = originalMSA.columns().size();
        double emptyLL = this.emptyColLL();
        // NOTE(review): merging two columns leaves the alignment with one column fewer, yet the
        // insertion term is evaluated at originalSize + 1 here (and at originalSize - 1 in
        // splitLogLikelihoodRatio). The two may be swapped; confirm against the callers and
        // PoissonUtils.logPhi before changing the numerics.
        double insertTerm = PoissonUtils.logPhi(emptyLL, originalSize + 1, this.params.insertRate)
                - PoissonUtils.logPhi(emptyLL, originalSize, this.params.insertRate);
        Map<Taxon, Integer> first = this.convertToIndices(originalMSA, toMerge1);
        Map<Taxon, Integer> second = this.convertToIndices(originalMSA, toMerge2);
        if (CollUtils.intersects(first.keySet(), second.keySet())) {
            throw new IllegalArgumentException("Columns to merge share at least one taxon: " + first.keySet() + " / " + second.keySet());
        }
        Map<Taxon, Integer> merged = new HashMap<Taxon, Integer>(first);
        merged.putAll(second);
        double columnTerm = this.columnLogLikelihood(merged) - this.columnLogLikelihood(first) - this.columnLogLikelihood(second);
        return insertTerm + columnTerm;
    }

    /**
     * Log-likelihood ratio associated with splitting one column of {@code originalMSA} in two.
     *
     * <p>The taxa listed in {@code elts} go to the first new column, the others to the second.
     * The column part is {@code LL(split1) + LL(split2) - LL(toSplit)}.
     *
     * @param originalMSA alignment containing the column
     * @param toSplit column to split
     * @param elts taxa moved to the first part
     * @return the log ratio
     * @throws IllegalArgumentException if {@code MSAPoset.isValidSplit} rejects the split
     */
    public double splitLogLikelihoodRatio(MSAPoset originalMSA, MSAPoset.Column toSplit, Set<Taxon> elts) {
        if (!MSAPoset.isValidSplit(toSplit, elts)) {
            throw new IllegalArgumentException("Invalid split of column by taxa " + elts);
        }
        int originalSize = originalMSA.columns().size();
        double emptyLL = this.emptyColLL();
        // NOTE(review): see mergerLogLikelihoodRatio; splitting adds a column but the insertion
        // term is evaluated at originalSize - 1. Behaviour kept as-is pending confirmation.
        double insertTerm = PoissonUtils.logPhi(emptyLL, originalSize - 1, this.params.insertRate)
                - PoissonUtils.logPhi(emptyLL, originalSize, this.params.insertRate);
        Map<Taxon, Integer> original = this.convertToIndices(originalMSA, toSplit);
        Map<Taxon, Integer> selected = new HashMap<Taxon, Integer>();
        Map<Taxon, Integer> remainder = new HashMap<Taxon, Integer>();
        for (Map.Entry<Taxon, Integer> entry : original.entrySet()) {
            Map<Taxon, Integer> target = elts.contains(entry.getKey()) ? selected : remainder;
            target.put(entry.getKey(), entry.getValue());
        }
        double columnTerm = -this.columnLogLikelihood(original) + this.columnLogLikelihood(selected) + this.columnLogLikelihood(remainder);
        return insertTerm + columnTerm;
    }

    /** Returns the log-likelihood of the empty column, computing and caching it on first use. */
    private double emptyColLL() {
        if (Double.isNaN(this._emptyColLL)) {
            this._emptyColLL = this.columnLogLikelihood(EMPTY_COLUMN);
        }
        return this._emptyColLL;
    }

    /**
     * Mean length of the given strings, as reported by {@code SummaryStatistics.getMean()}.
     *
     * @param sequences strings to average over
     * @return the mean length
     */
    public static double meanSequenceLength(Collection<String> sequences) {
        SummaryStatistics lengths = new SummaryStatistics();
        for (String sequence : sequences) {
            lengths.addValue((double)sequence.length());
        }
        return lengths.getMean();
    }

    /**
     * Grid search over insertion and deletion rates on a Balibase corpus.
     *
     * <p>Both rates start at 0.01 and are doubled while below 1000. For every pair the marginal
     * log-likelihoods of all alignments are summed (the insertion rate handed to the model is the
     * grid value times the mean sequence length of the alignment) and printed; the best pair is
     * printed at the end. Alignments that raise {@link NoSuchElementException} are reported once
     * and skipped for the remaining grid points.
     *
     * @param args paths added to the corpus' reference alignment paths
     */
    public static void main(String[] args) {
        final double[][] subRates = RateMatrixLoader.dayhoff();
        final Indexer<Character> indexer = RateMatrixLoader.proteinIndexer();
        BalibaseCorpus.BalibaseCorpusOptions baliopt = new BalibaseCorpus.BalibaseCorpusOptions();
        baliopt.referenceAlignmentsPath.clear();
        for (String arg : args) {
            baliopt.referenceAlignmentsPath.add(arg);
        }
        final BalibaseCorpus bc = new BalibaseCorpus(baliopt);
        // Shared between worker threads: every access (contains and add) must be synchronized.
        final Set<CognateId> toSkip = Collections.synchronizedSet(new HashSet<CognateId>());
        final ArrayList<CognateId> cognates = CollUtils.list(bc.intersectedIds());
        System.out.println("#Ins\tDel\tLL");
        double max = Double.NEGATIVE_INFINITY;
        double bestInsert = -1.0;
        double bestDel = -1.0;
        for (double insertRate = 0.01; insertRate < 1000.0; insertRate *= 2.0) {
            for (double delRate = 0.01; delRate < 1000.0; delRate *= 2.0) {
                final double _insertRate = insertRate;
                final double _delRate = delRate;
                // One slot per alignment; skipped alignments keep their initial 0.0.
                final double[] sums = new double[cognates.size()];
                Parallelizer<Integer> parallelizer = new Parallelizer<Integer>(8);
                parallelizer.setPrimaryThread();
                parallelizer.process(CollUtils.ints(cognates.size()), new Parallelizer.Processor<Integer>(){

                    @Override
                    public void process(Integer x, int _i, int _n, boolean log) {
                        CognateId id = cognates.get(x.intValue());
                        if (toSkip.contains(id)) {
                            return;
                        }
                        try {
                            MultiAlignment _ma = bc.getMultiAlignment(id);
                            MSAPoset msa = MSAPoset.fromMultiAlignmentObject(_ma);
                            RootedTree rt = RootedTree.Util.fromBalibase(bc, id);
                            double scaledInsertRate = _insertRate * MSAMarginalLikelihoodCalculator.meanSequenceLength(msa.sequences().values());
                            // A fresh calculator per task: instances are not thread-safe.
                            MSAMarginalLikelihoodCalculator calc = new MSAMarginalLikelihoodCalculator(new PoissonParameters(indexer, subRates, scaledInsertRate, _delRate), rt);
                            sums[x.intValue()] = calc.marginalLogLikelihood(msa);
                        }
                        catch (NoSuchElementException nse) {
                            System.err.println("Skipping:" + id + " because of unknown character");
                            toSkip.add(id);
                        }
                    }
                });
                double cur = MathUtils.sum(sums);
                System.out.println("" + insertRate + "\t" + delRate + "\t" + cur);
                if (cur > max) {
                    max = cur;
                    bestDel = _delRate;
                    bestInsert = _insertRate;
                }
            }
        }
        System.out.println("#BEST:" + bestInsert + "\t" + bestDel + "\t" + max);
    }

    /**
     * Converts a column of {@code msa} to a map from taxon to character index.
     *
     * @param msa alignment providing the sequences
     * @param c column of that alignment
     * @return see {@link #convertToIndices(Map, Map, Indexer)}
     */
    public Map<Taxon, Integer> convertToIndices(MSAPoset msa, MSAPoset.Column c) {
        return this.convertToIndices(msa.sequences(), c.getPoints());
    }

    /**
     * Same as {@link #convertToIndices(Map, Map, Indexer)} using this calculator's indexer.
     *
     * @param sequences sequence of each taxon
     * @param points position within its sequence for each taxon present in the column
     * @return taxon to character-index map
     */
    public Map<Taxon, Integer> convertToIndices(Map<Taxon, String> sequences, Map<Taxon, Integer> points) {
        return MSAMarginalLikelihoodCalculator.convertToIndices(sequences, points, this.params.indexer);
    }

    /**
     * For every taxon in {@code points}, looks up the character at the given position of its
     * sequence and maps it to its index.
     *
     * @param sequences sequence of each taxon; must contain every key of {@code points}
     * @param points position within its sequence for each taxon present in the column
     * @param indexer character indexer
     * @return taxon to index map; the value is {@code null} for characters the indexer does
     *         not contain
     */
    public static Map<Taxon, Integer> convertToIndices(Map<Taxon, String> sequences, Map<Taxon, Integer> points, Indexer<Character> indexer) {
        Map<Taxon, Integer> result = new HashMap<Taxon, Integer>();
        for (Map.Entry<Taxon, Integer> point : points.entrySet()) {
            Taxon taxon = point.getKey();
            int position = point.getValue();
            char symbol = sequences.get(taxon).charAt(position);
            result.put(taxon, MSAMarginalLikelihoodCalculator.convertToIndex(indexer, symbol));
        }
        return result;
    }

    /**
     * Index of {@code c} in {@code indexer}.
     *
     * @param indexer character indexer
     * @param c character to look up
     * @return the index, or {@code null} if the indexer does not contain {@code c}
     */
    public static Integer convertToIndex(Indexer<Character> indexer, char c) {
        Character boxed = Character.valueOf(c);
        if (!indexer.containsObject(boxed)) {
            return null;
        }
        return indexer.o2i(boxed);
    }

    /**
     * Builds the factor graph over the tree: every potential starts at one, and each
     * (parent, child) edge is filled with the matrix returned by
     * {@code RateMtxUtils.marginalTransitionMtx} for the child's branch length.
     */
    private TabularGMFct<Taxon> initGraphicalModel() {
        HashGraph<Taxon> graph = new HashGraph<Taxon>(Arbre.arbre2Tree(this.rootedTree.topology()));
        TabularGMFct<Taxon> result = GMFctUtils.ones(graph, CollUtils.cnstMap(graph.vertexSet(), this.params.numberOfCharacterPlusGap));
        final int nStates = this.params.numberOfCharacterPlusGap;
        for (Arbre<Taxon> node : this.rootedTree.topology().nodes()) {
            if (node.isRoot()) {
                // The root has no incoming branch, hence no edge potential.
                continue;
            }
            Taxon child = node.getContents();
            Taxon parent = node.getParent().getContents();
            double branchLength = this.rootedTree.branchLengths().get(child);
            double[][] transition = RateMtxUtils.marginalTransitionMtx(this.params.Q, branchLength);
            for (int from = 0; from < nStates; ++from) {
                for (int to = 0; to < nStates; ++to) {
                    result.set(parent, child, from, to, transition[from][to]);
                }
            }
        }
        return result;
    }

    /**
     * Computes the log-weight of every node.
     *
     * <p>Branch lengths of the non-root nodes are passed to
     * {@code IntegratedLengthMarginalComputations.branchWeights}; entry {@code i} of the result
     * is used for the i-th non-root node and the last entry is used for the root.
     *
     * @throws IllegalStateException if the weights do not sum to a value strictly between 0 and 1
     */
    private Map<Taxon, Double> branchLogWeights() {
        ArrayList<Taxon> nonRoot = new ArrayList<Taxon>();
        for (Arbre<Taxon> node : this.rootedTree.topology().nodes()) {
            if (!node.isRoot()) {
                nonRoot.add(node.getContents());
            }
        }
        double[] lengths = new double[nonRoot.size()];
        for (int i = 0; i < lengths.length; ++i) {
            lengths[i] = this.rootedTree.branchLengths().get(nonRoot.get(i));
        }
        double[] weights = IntegratedLengthMarginalComputations.branchWeights(lengths, this.params.Q, this.params.quasiStatProbs);
        double sum = MathUtils.sum(weights);
        // Written positively so that a NaN sum is rejected too.
        if (!(sum > 0.0 && sum < 1.0)) {
            throw new IllegalStateException("Branch weights should sum to a value in (0, 1) but sum to " + sum);
        }
        Map<Taxon, Double> result = new HashMap<Taxon, Double>();
        for (int i = 0; i < nonRoot.size(); ++i) {
            result.put(nonRoot.get(i), Math.log(weights[i]));
        }
        result.put(this.rootedTree.topology().getContents(), Math.log(weights[weights.length - 1]));
        return result;
    }

    /**
     * Log-likelihood of a single alignment column.
     *
     * <p>Leaves present in {@code alignmentColumn} are clamped to their observed index (or left
     * unconstrained over the non-gap states when the index is {@code null}); all other leaves
     * are clamped to the gap state. The result is the log-sum, over every node whose entry in
     * the descendant map contains all taxa of the column, of that node's log-weight plus its
     * subtree log-likelihood.
     *
     * <p>This method overwrites the leaf potentials of the shared graphical model.
     *
     * @param alignmentColumn taxon to character-index map; may be empty
     * @return the column log-likelihood, never positive
     * @throws IllegalStateException if the computed value is positive
     */
    public double columnLogLikelihood(Map<Taxon, Integer> alignmentColumn) {
        for (Taxon leaf : this.rootedTree.topology().leaveContents()) {
            if (alignmentColumn.containsKey(leaf)) {
                this.setObservation(leaf, alignmentColumn.get(leaf));
            } else {
                this.setGap(leaf);
            }
        }
        TreeSumProd<Taxon> tsp = new TreeSumProd<Taxon>(this.graphicalModel);
        Set<Taxon> observedTaxa = alignmentColumn.keySet();
        double logSum = Double.NEGATIVE_INFINITY;
        for (Arbre<Taxon> node : this.rootedTree.topology().nodes()) {
            if (this.descMap.get(node).containsAll(observedTaxa)) {
                double term = this.branchLogWeights.get(node.getContents()) + this.subtreeLogLikelihood(tsp, node);
                logSum = SloppyMath.logAdd(logSum, term);
            }
        }
        if (logSum > 0.0) {
            throw new IllegalStateException("Positive column log-likelihood " + logSum + " for column " + alignmentColumn);
        }
        return logSum;
    }

    /**
     * Log-sum over the non-gap states of the quasi-stationary log-probability of the state plus
     * either {@code logZ} at the node (for the root) or the log-message from the node to its
     * parent (for any other node).
     */
    private double subtreeLogLikelihood(TreeSumProd<Taxon> tsp, Arbre<Taxon> subt) {
        final boolean atRoot = subt.isRoot();
        final Taxon current = subt.getContents();
        final Taxon parent = atRoot ? null : subt.getParent().getContents();
        double logSum = Double.NEGATIVE_INFINITY;
        for (int state = 0; state < this.params.numberOfCharacter; ++state) {
            double below = atRoot ? tsp.logZ(current, state) : tsp.logMessage(state, current, parent);
            logSum = SloppyMath.logAdd(logSum, this.params.quasiStatLogProbabilities[state] + below);
        }
        return logSum;
    }

    /** Clamps {@code leaf} to the gap state: potential 1 at the gap index, 0 elsewhere. */
    private void setGap(Taxon leaf) {
        // presumably gapIndex lies outside [0, numberOfCharacter) - TODO confirm
        for (int idx = 0; idx < this.params.numberOfCharacter; ++idx) {
            this.graphicalModel.set(leaf, idx, 0.0);
        }
        this.graphicalModel.set(leaf, this.params.gapIndex, 1.0);
    }

    /**
     * Clamps {@code leaf} to an observed character: potential 1 at {@code observedValue} and 0 at
     * the other character states, or 1 at every character state when {@code observedValue} is
     * {@code null}. The gap state always gets potential 0.
     */
    private void setObservation(Taxon leaf, Integer observedValue) {
        final boolean unknown = observedValue == null;
        final int observed = unknown ? -1 : observedValue.intValue();
        for (int idx = 0; idx < this.params.numberOfCharacter; ++idx) {
            double potential = unknown || idx == observed ? 1.0 : 0.0;
            this.graphicalModel.set(leaf, idx, potential);
        }
        this.graphicalModel.set(leaf, this.params.gapIndex, 0.0);
    }
}

