/*
 * Decompiled with CFR 0.152.
 */
package conifer.msa;

import conifer.msa.LinearizationProposal;
import conifer.msa.MSAUtils;
import conifer.pip.LinearizedAlignment;
import conifer.pip.PIPForwardSampler;
import conifer.pip.PIPLikelihoodCalculator;
import ev.poi.PoissonParameters;
import fig.basic.Pair;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import ma.GreedyDecoder;
import ma.MSAPoset;
import ma.RateMatrixLoader;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import nuts.util.Counter;
import org.junit.Test;
import pty.RandomRootedTrees;
import pty.RootedTree;

public class MSAProposal {
    /**
     * Cache key under which {@code sample_logRatio_independenceMSAProposal} stores, on the
     * proposed {@code LinearizedAlignment}, the map of edge decisions that produced it, and
     * under which it later looks up the decisions of the current state.
     */
    public static final Object DECISION_KEY = "DECISION_KEY";

    /**
     * Replaces the point the anchor is linked to.
     *
     * <p>If the anchor currently has an end point (see {@code getEndPoint}), the anchor's taxon
     * is first split out of its column. Then, if {@code newEndPoint} is non-null, an edge from
     * the anchor to {@code newEndPoint} is added to {@code msa}.
     *
     * @param msa the alignment to modify in place
     * @param anchor the (taxon, position) pair whose link is being changed
     * @param newEndPoint the (taxon, position) pair to link to, or {@code null} to leave the
     *     anchor unlinked
     * @return the end point the anchor had before the call, or {@code null} if it had none
     * @throws IllegalStateException if {@code msa} refuses the edge to {@code newEndPoint}
     */
    public static Pair<Taxon, Integer> setAnchorEndPoint(MSAPoset msa, Pair<Taxon, Integer> anchor, Pair<Taxon, Integer> newEndPoint) {
        Taxon t = anchor.getFirst();
        int p = anchor.getSecond();
        // Look up the column before any split so the split targets the anchor's current column.
        MSAPoset.Column curC = msa.column(t, p);
        Pair<Taxon, Integer> oldEndPt = MSAProposal.getEndPoint(msa, anchor);
        if (oldEndPt != null) {
            msa.split(curC, Collections.singleton(t));
        }
        if (newEndPoint != null) {
            GreedyDecoder.Edge link = new GreedyDecoder.Edge(p, newEndPoint.getSecond(), t, newEndPoint.getFirst());
            if (!msa.tryAdding(link)) {
                throw new IllegalStateException("Could not link anchor " + anchor + " to end point " + newEndPoint);
            }
        }
        return oldEndPt;
    }

    /**
     * Reports a point that shares a column with the anchor.
     *
     * @param msa the alignment to inspect
     * @param anchor the (taxon, position) pair to look up
     * @return a (taxon, position) pair for a different taxon in the anchor's column, or
     *     {@code null} when that column holds exactly one point
     * @throws RuntimeException if the column does not hold exactly one point yet contains no
     *     taxon other than the anchor's
     */
    public static Pair<Taxon, Integer> getEndPoint(MSAPoset msa, Pair<Taxon, Integer> anchor) {
        Taxon anchorTaxon = anchor.getFirst();
        int anchorPos = anchor.getSecond().intValue();
        MSAPoset.Column column = msa.column(anchorTaxon, anchorPos);
        if (column.getPoints().size() == 1) {
            return null;
        }
        Map<Taxon, Integer> points = column.getPoints();
        Taxon partner = null;
        // Deliberately no early exit: if several other taxa share the column,
        // the last one visited by the key iteration is the one reported.
        for (Taxon candidate : points.keySet()) {
            if (!candidate.equals(anchorTaxon)) {
                partner = candidate;
            }
        }
        if (partner == null) {
            throw new RuntimeException();
        }
        return Pair.makePair(partner, points.get(partner));
    }

    /**
     * Enumerates the points the anchor could be linked to.
     *
     * <p>The anchor is temporarily unlinked, every other column is tested with
     * {@code msa.isValidAddition}, and the anchor's original link is restored before returning
     * normally. One representative point is reported per admissible column. The returned list
     * always ends with a {@code null} entry standing for "no link".
     *
     * @param msa the alignment; modified during the call and restored on normal return
     * @param anchor the (taxon, position) pair for which links are enumerated
     * @return a pair of (candidate end points followed by a trailing {@code null}, index in that
     *     list of the anchor's current choice); when the anchor is currently unlinked the index
     *     is that of the trailing {@code null}
     * @throws IllegalStateException if the anchor's current end point is rejected as a valid
     *     addition or is not found in any other column
     */
    public static Pair<List<Pair<Taxon, Integer>>, Integer> listPotentialLinks(MSAPoset msa, Pair<Taxon, Integer> anchor) {
        Taxon t = anchor.getFirst();
        int p = anchor.getSecond();
        // Detach the anchor so that all candidates are evaluated against the same unlinked state.
        Pair<Taxon, Integer> initialEndPoint = MSAProposal.setAnchorEndPoint(msa, anchor, null);
        MSAPoset.Column curC = msa.column(t, p);
        List<Pair<Taxon, Integer>> result = CollUtils.list();
        int originalIdx = -1;
        for (MSAPoset.Column c : msa.columns()) {
            if (c == curC) {
                continue;
            }
            Map<Taxon, Integer> cPoints = c.getPoints();
            // Any point of the column serves as its representative.
            Taxon t2 = CollUtils.pick(cPoints.keySet());
            int p2 = cPoints.get(t2);
            Pair<Taxon, Integer> candidate = Pair.makePair(t2, p2);
            GreedyDecoder.Edge candidateEdge = new GreedyDecoder.Edge(p, p2, t, t2);
            boolean valid = msa.isValidAddition(candidateEdge);
            boolean holdsInitialEndPoint = initialEndPoint != null
                    && cPoints.containsKey(initialEndPoint.getFirst())
                    && cPoints.get(initialEndPoint.getFirst()).intValue() == initialEndPoint.getSecond().intValue();
            if (holdsInitialEndPoint) {
                if (!valid) {
                    throw new IllegalStateException("Existing link from " + anchor + " to " + initialEndPoint + " is not a valid addition");
                }
                // The candidate is appended just below, so its index is the current size.
                originalIdx = result.size();
            }
            if (valid) {
                result.add(candidate);
            }
        }
        if (initialEndPoint == null) {
            // An unlinked anchor corresponds to the trailing null entry added below.
            originalIdx = result.size();
        } else if (originalIdx == -1) {
            throw new IllegalStateException("Existing end point " + initialEndPoint + " of anchor " + anchor + " was not found in any column");
        }
        result.add(null);
        MSAProposal.setAnchorEndPoint(msa, anchor, initialEndPoint);
        return Pair.makePair(result, originalIdx);
    }

    /**
     * Builds an alignment over {@code seqns} by visiting candidate edges in the order given by
     * {@code edgeInclPrs.orderedEdges()} and randomly deciding whether to add each one.
     *
     * <p>For every edge the recorded decision is {@code null} when the edge was not a valid
     * addition at the time it was visited, otherwise the outcome of a Bernoulli draw with
     * probability {@code edgeInclPrs.edgeInclusionPr(e)}.
     *
     * @param seqns the sequences to align, keyed by taxon
     * @param edgeInclPrs supplies the edge order and per-edge inclusion probabilities
     * @param rand the source of randomness
     * @return the per-edge decisions together with the resulting alignment
     * @throws RuntimeException if an edge reported as a valid addition cannot be added
     */
    public static Pair<Map<GreedyDecoder.Edge, Boolean>, MSAPoset> propose(Map<Taxon, String> seqns, EdgeInclusionProbabilityFunction edgeInclPrs, Random rand) {
        MSAPoset alignment = new MSAPoset(seqns);
        HashMap<GreedyDecoder.Edge, Boolean> outcomes = new HashMap<GreedyDecoder.Edge, Boolean>();
        for (GreedyDecoder.Edge edge : edgeInclPrs.orderedEdges()) {
            if (!alignment.isValidAddition(edge)) {
                // No choice was available for this edge; record that with a null decision.
                outcomes.put(edge, null);
                continue;
            }
            Boolean include = Sampling.sampleBern(edgeInclPrs.edgeInclusionPr(edge), rand);
            if (include.booleanValue() && !alignment.tryAdding(edge)) {
                throw new RuntimeException();
            }
            outcomes.put(edge, include);
        }
        return Pair.makePair(outcomes, alignment);
    }

    /**
     * Computes the log-probability of a set of edge decisions.
     *
     * <p>Each non-null decision contributes {@code log(pr)} when the edge was included and
     * {@code log(1 - pr)} when it was not, where {@code pr} is
     * {@code edgeInclPrs.edgeInclusionPr(edge)}. Null decisions contribute nothing.
     *
     * @param decisions per-edge decisions; a {@code null} value means no choice was made
     * @param edgeInclPrs supplies the inclusion probability of each edge
     * @return the sum of the per-edge log-probabilities
     */
    public static double logProposalPr(Map<GreedyDecoder.Edge, Boolean> decisions, EdgeInclusionProbabilityFunction edgeInclPrs) {
        double total = 0.0;
        for (Map.Entry<GreedyDecoder.Edge, Boolean> entry : decisions.entrySet()) {
            Boolean included = entry.getValue();
            if (included == null) {
                continue;
            }
            double inclusionPr = edgeInclPrs.edgeInclusionPr(entry.getKey());
            double outcomePr = included.booleanValue() ? inclusionPr : 1.0 - inclusionPr;
            total += Math.log(outcomePr);
        }
        return total;
    }

    /**
     * Draws a fresh alignment and linearization independently of the current state and returns
     * it with the log of the backward-to-forward proposal probability ratio.
     *
     * <p>The returned ratio is
     * {@code logBwdEdgePr + logBwdLinPr - logFwdEdgePr - logFwdLinPr}. The backward edge term is
     * computed from the decisions cached on {@code lia} under {@code DECISION_KEY}; when no
     * decisions are cached and {@code initializing} is true, that term is taken to be 0. The
     * proposed state gets its own decisions cached under the same key.
     *
     * @param edgeInclPrs supplies the edge order and per-edge inclusion probabilities
     * @param rand the source of randomness
     * @param lia the current state; not modified
     * @param initializing whether a state without cached decisions is acceptable
     * @return the proposed state and the log proposal ratio
     * @throws IllegalStateException if {@code lia} has no cached decisions and
     *     {@code initializing} is false
     */
    public static Pair<LinearizedAlignment, Double> sample_logRatio_independenceMSAProposal(EdgeInclusionProbabilityFunction edgeInclPrs, Random rand, LinearizedAlignment lia, boolean initializing) {
        // The forward proposal is drawn first; the order of rand consumption is part of the behaviour.
        Pair<Map<GreedyDecoder.Edge, Boolean>, MSAPoset> proposedEdges = MSAProposal.propose(lia.getMsa().sequences(), edgeInclPrs, rand);
        double logFwdEdgePr = MSAProposal.logProposalPr(proposedEdges.getFirst(), edgeInclPrs);

        // The only writer of DECISION_KEY in this class stores a Map<Edge, Boolean>, so the cast is safe.
        @SuppressWarnings("unchecked")
        Map<GreedyDecoder.Edge, Boolean> initialStateDecision = (Map<GreedyDecoder.Edge, Boolean>) lia.cache.get(DECISION_KEY);
        double logBwdEdgePr;
        if (initialStateDecision == null) {
            if (!initializing) {
                throw new IllegalStateException("Current state has no cached edge decisions under DECISION_KEY and initializing is false");
            }
            logBwdEdgePr = 0.0;
        } else {
            logBwdEdgePr = MSAProposal.logProposalPr(initialStateDecision, edgeInclPrs);
        }

        double logBwdLinPr = LinearizationProposal.logPrLinearization(lia.getMsa().getPoset(), lia.getColumns());
        Pair<List<MSAPoset.Column>, Double> proposedLin = LinearizationProposal.sample_logPrLinearization(proposedEdges.getSecond().getPoset(), rand);
        double logFwdLinPr = proposedLin.getSecond();

        double logRatio = logBwdEdgePr + logBwdLinPr - logFwdEdgePr - logFwdLinPr;
        LinearizedAlignment proposedLIA = new LinearizedAlignment(proposedEdges.getSecond(), proposedLin.getFirst());
        // Remember the decisions so a later call can evaluate this state's backward probability.
        proposedLIA.cache.put(DECISION_KEY, proposedEdges.getFirst());
        return Pair.makePair(proposedLIA, logRatio);
    }

    /**
     * Simulation driver for the independence proposal; it prints diagnostics and contains no
     * assertions.
     *
     * <p>Each of the 10000 outer iterations samples a tree and an alignment, then runs 1000
     * accept/reject steps starting from an alignment built from the sequences alone. After the
     * inner loop it accumulates the sequence lengths of the final state, and every tenth outer
     * iteration it prints the running average length per sequence.
     */
    @Test
    public void testInvariantMeasure() {
        // num accumulates sequence lengths, denom counts sequences; both span all outer iterations.
        double num = 0.0;
        double denom = 0.0;
        // Fixed seed: every random draw below comes from this single generator.
        Random rand = new Random(1L);
        for (int i = 0; i < 10000; ++i) {
            PoissonParameters pip = PoissonParameters.createFromAdditiveLengthIntensityParameterization(RateMatrixLoader.rnaIndexer(), RateMatrixLoader.k2p(), 25.0, 2.0);
            RootedTree rt = RandomRootedTrees.sampleCoalescent(rand, 10, 1.0);
            PIPForwardSampler pipf = new PIPForwardSampler(pip, rt);
            // msa is the simulated reference alignment for this iteration.
            MSAPoset msa = pipf.sampleMSA(rand);
            System.out.println(msa);
            PIPLikelihoodCalculator calc2 = new PIPLikelihoodCalculator(pip, new LinearizedAlignment(msa), rt);
            System.out.println("logLL = " + calc2.computeDataLogProbabilityGivenTree());
            System.out.println("\n\n === \n\n");
            // Pairwise edge scores computed from the sequences only; they drive the proposal below.
            Counter<GreedyDecoder.Edge> edges = MSAUtils.loadBasicRNAAligner().allPairsPosterior(msa.sequences());
            MSAPoset maxRecall = MSAPoset.maxRecallMSA(msa.sequences(), edges);
            System.out.println("maxRecall F1 = " + MSAPoset.edgeF1(msa, maxRecall));
            // The chain starts from an alignment constructed from the sequences alone.
            LinearizedAlignment lia = new LinearizedAlignment(new MSAPoset(msa.sequences()));
            // Stays false until the first acceptance; until then the state has no cached decisions.
            boolean initialized = false;
            for (int k = 0; k < 1000; ++k) {
                PIPLikelihoodCalculator calc = new PIPLikelihoodCalculator(pip, lia, rt);
                double logll = calc.computeDataLogProbabilityGivenTree();
                // expMag is one of -2, -1, 0, 1, so exponent is one of 0.01, 0.1, 1, 10.
                int expMag = rand.nextInt(4) - 2;
                double exponent = Math.pow(10.0, expMag);
                System.out.println("exponent = " + exponent);
                EdgeInclusionProbabilityFunction eipf = new EdgeInclusionProbabilityFunction(exponent, edges);
                Pair<LinearizedAlignment, Double> propPair = MSAProposal.sample_logRatio_independenceMSAProposal(eipf, rand, lia, !initialized);
                calc = new PIPLikelihoodCalculator(pip, propPair.getFirst(), rt);
                double newLogll = calc.computeDataLogProbabilityGivenTree();
                // Log-likelihood difference plus the log proposal ratio returned by the proposal.
                double logRatio = newLogll - logll + propPair.getSecond();
                // NOTE(review): min1exp presumably computes min(1, exp(logRatio)) - confirm in nuts.math.Sampling.
                double acceptPr = Sampling.min1exp(logRatio);
                System.out.println("logRatio = " + logRatio + " = " + newLogll + " - " + logll + " + " + propPair.getSecond());
                System.out.println("acceptPr = " + acceptPr);
                System.out.println();
                boolean accept = Sampling.sampleBern(acceptPr, rand);
                if (accept) {
                    initialized = true;
                    lia = propPair.getFirst();
                    System.out.println("Change accepted");
                    System.out.println(lia.getMsa());
                    System.out.println("F1 = " + MSAPoset.edgeF1(msa, lia.getMsa()));
                }
                System.out.println("\n\n --- \n\n");
            }
            // Fold the sequence lengths of the final state into the running totals.
            for (Taxon t : msa.sequences().keySet()) {
                num += (double)lia.getMsa().sequences().get(t).length();
                denom += 1.0;
            }
            // Report only on every tenth outer iteration.
            if (i % 10 != 0) continue;
            System.out.println("current avg = " + num / denom);
        }
    }

    /**
     * Assigns each edge an inclusion probability equal to its base score raised to a fixed
     * exponent, and exposes the edges in the iteration order of the underlying counter.
     */
    public static class EdgeInclusionProbabilityFunction {
        private final double exponent;
        private final Counter<GreedyDecoder.Edge> basePrs;

        /**
         * @param exponent the power applied to every base score
         * @param basePrs the base score of each edge; kept by reference, not copied
         */
        public EdgeInclusionProbabilityFunction(double exponent, Counter<GreedyDecoder.Edge> basePrs) {
            this.basePrs = basePrs;
            this.exponent = exponent;
        }

        /**
         * @return the edges to visit, in whatever order the backing counter iterates them
         */
        public Iterable<GreedyDecoder.Edge> orderedEdges() {
            return basePrs;
        }

        /**
         * @param e the edge to score
         * @return the base score of {@code e} raised to the configured exponent
         */
        public double edgeInclusionPr(GreedyDecoder.Edge e) {
            double baseScore = basePrs.getCount(e);
            return Math.pow(baseScore, exponent);
        }
    }
}

