/*
 * Decompiled with CFR 0.152.
 */
package ev.poi;

import ev.poi.MSAMarginalLikelihoodCalculator;
import fig.basic.Pair;
import goblin.CognateId;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import ma.BalibaseCorpus;
import ma.GreedyDecoder;
import ma.MSAPoset;
import ma.MultiAlignment;
import ma.RateMatrixLoader;
import nuts.util.CollUtils;
import nuts.util.Indexer;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.RootedTree;

public final class HomologySampler {
    public final SummaryStatistics hasEdgeStats = new SummaryStatistics();
    public final SummaryStatistics invalidAddStats = new SummaryStatistics();
    public final SummaryStatistics splitRatioStats = new SummaryStatistics();
    public final SummaryStatistics mergeRatioStats = new SummaryStatistics();
    public double relativeLogLikelihood = 0.0;
    public static final double LOG2 = Math.log(2.0);

    public void homologySamplingStep(GreedyDecoder.Edge e, MSAPoset msa, MSAMarginalLikelihoodCalculator calculator, Random rand) {
        boolean hasEdge = msa.containsEdge(e);
        this.hasEdgeStats.addValue(hasEdge ? 1.0 : 0.0);
        if (!hasEdge) {
            boolean validAdd = msa.isValidAddition(e);
            this.invalidAddStats.addValue(validAdd ? 0.0 : 1.0);
            if (!validAdd) {
                return;
            }
        }
        HashSet<Taxon> split = null;
        if (hasEdge) {
            split = CollUtils.set();
            ArrayList<Taxon> available = CollUtils.list(msa.column(e.lang1(), e.index1()).getPoints().keySet());
            Collections.shuffle(available, rand);
            double nToPick = 1 + (available.size() > 2 ? rand.nextInt(available.size() - 2) : 0);
            int i = 0;
            while ((double)i < nToPick) {
                split.add((Taxon)available.get(i));
                ++i;
            }
        }
        MSAPoset.Column c1 = msa.column(e.lang1(), e.index1());
        MSAPoset.Column c2 = msa.column(e.lang2(), e.index2());
        double logLikelihoodRatio = hasEdge ? calculator.splitLogLikelihoodRatio(msa, c1, split) : calculator.mergerLogLikelihoodRatio(msa, c1, c2);
        int setSize = hasEdge ? c1.getPoints().size() : c1.getPoints().size() + c2.getPoints().size();
        double ratio = Math.min(1.0, Math.exp(logLikelihoodRatio += LOG2 * (double)(hasEdge ? -1 : 1) * (double)(setSize - 2)));
        (hasEdge ? this.splitRatioStats : this.mergeRatioStats).addValue(ratio);
        if (rand.nextDouble() < ratio) {
            this.relativeLogLikelihood += logLikelihoodRatio;
            if (hasEdge) {
                msa.split(c1, split);
            } else {
                boolean result = msa.tryAdding(e);
                if (!result) {
                    throw new RuntimeException();
                }
            }
        }
    }

    public static void main(String[] args) {
        Random rand = new Random(1L);
        BalibaseCorpus.BalibaseCorpusOptions baliopt = new BalibaseCorpus.BalibaseCorpusOptions();
        baliopt.referenceAlignmentsPath.clear();
        File path = new File("/Users/bouchard/w/evolvere/data/BAliBASE/ref1/test1/");
        for (String arg : args) {
            baliopt.referenceAlignmentsPath.add(arg);
        }
        BalibaseCorpus bc = new BalibaseCorpus(baliopt);
        CognateId id = CollUtils.pick(bc.intersectedIds());
        double[][] subRates = RateMatrixLoader.dayhoff();
        Indexer<Character> indexer = RateMatrixLoader.proteinIndexer();
        double _insertRate = 1.0;
        double _delRate = 1.0;
        MultiAlignment _ma = bc.getMultiAlignment(id);
        MSAPoset msa = MSAPoset.fromMultiAlignmentObject(_ma);
        RootedTree rt = RootedTree.Util.fromBalibase(bc, id);
        MSAMarginalLikelihoodCalculator calc = null;
        UniformEdgeSelector reg = new UniformEdgeSelector(msa.sequences());
        HomologySampler sampler = new HomologySampler();
        System.out.println("#iter\tscore\tLL\thasEdgeMean\tinvalidAddMean\tsplitRatioMean\tmergeRatioMean");
        for (int i = 0; i < 500000000; ++i) {
            sampler.homologySamplingStep(reg.next(rand), msa, calc, rand);
            if (i % 5000 != 0) continue;
            System.out.println(i + "\t" + _ma.sumOfPairsScore(msa.toMultiAlignmentObject()) + "\t" + sampler.relativeLogLikelihood + "\t" + sampler.hasEdgeStats.getMean() + "\t" + sampler.invalidAddStats.getMean() + "\t" + sampler.splitRatioStats.getMean() + "\t" + sampler.mergeRatioStats.getMean());
        }
    }

    public static class UniformEdgeSelector {
        private final List<Taxon> langs;
        private final Map<Taxon, String> seqs;

        public UniformEdgeSelector(Map<Taxon, String> sequences) {
            this.langs = CollUtils.list(sequences.keySet());
            this.seqs = sequences;
        }

        public GreedyDecoder.Edge next(Random rand) {
            Taxon l1;
            Taxon l2 = l1 = this.langs.get(rand.nextInt(this.langs.size()));
            while (l2 == l1) {
                l2 = this.langs.get(rand.nextInt(this.langs.size()));
            }
            int i1 = rand.nextInt(this.seqs.get(l1).length());
            int i2 = rand.nextInt(this.seqs.get(l2).length());
            return new GreedyDecoder.Edge(i1, i2, l1, l2);
        }
    }

    public static class InformedEdgeSelector {
        private final UniformEdgeSelector base;
        private final Map<Pair<Taxon, Taxon>, double[][]> distributions;

        public InformedEdgeSelector(Map<Pair<Taxon, Taxon>, double[][]> distributions, Map<Taxon, String> sequences) {
            this.base = new UniformEdgeSelector(sequences);
            this.distributions = distributions;
        }

        public GreedyDecoder.Edge next(Random rand) {
            GreedyDecoder.Edge fromUnif = this.base.next(rand);
            double[] prs = this.distributions.get(Pair.makePair(fromUnif.lang1(), fromUnif.lang2()))[fromUnif.index1()];
            int newIndex = rand.nextInt(prs.length);
            return new GreedyDecoder.Edge(fromUnif.index1(), newIndex, fromUnif.lang1(), fromUnif.lang2());
        }
    }
}

