/*
 * Decompiled with CFR 0.152.
 */
package conifer.msa;

import conifer.Phylogeny;
import conifer.msa.InformedProposal;
import conifer.msa.InformedProposals;
import conifer.msa.LinearizationProposal;
import conifer.msa.MSAProposal;
import conifer.msa.MSAUtils;
import conifer.msa.TreeMSAParameters;
import conifer.msa.UndoFuture;
import conifer.particle.PhyloParticle;
import conifer.particle.PhyloParticleInitContext;
import conifer.pip.LinearizedAlignment;
import conifer.proposals.ProposalOptions;
import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import ma.MSAPoset;
import monaco.Density;
import monaco.mcmc.ParallelTemperedMCMC;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import nuts.util.Counter;
import pty.RootedTree;
import pty.smc.models.LikelihoodModelCalculator;

/**
 * Mutable sampler state bundling an alignment ({@link #msa}), a linearization of its columns,
 * a rooted tree and the model parameters ({@link #params}).
 *
 * <p>Changes are made in place. Each in-place change registers exactly one {@link UndoFuture};
 * {@link #wasProposalAccepted(boolean)} must be called to resolve it (keep or roll back) before
 * the next in-place change is attempted, otherwise an {@link IllegalStateException} is thrown.
 */
public class TreeMSAState
implements ParallelTemperedMCMC.InPlaceProposedState,
PhyloParticle {
    /** Undo action of the in-place change awaiting accept/reject, or {@code null} if none is pending. */
    private UndoFuture currentUndo = null;
    public MSAPoset msa;
    /** Current ordering of the columns of {@link #msa}. */
    List<MSAPoset.Column> linearization;
    /**
     * Log-probability of the linearization sampled by the last accepted alignment move.
     * Starts at +infinity, so the value returned by the alignment proposal methods is not
     * finite until a first alignment move has been accepted.
     */
    double lastAcceptedLinearizationLogPr = Double.POSITIVE_INFINITY;
    /** Log-probability of the linearization of the pending alignment move; NaN when none is pending. */
    private double currentProposedLinearizationLogPr = Double.NaN;
    private RootedTree tree;
    public TreeMSAParameters params;
    /** Number of proposals resolved so far through {@link #wasProposalAccepted(boolean)}. */
    private int genIndex = 0;
    /** Prior used by {@link #getLogPrior()}; only assigned by {@code init}. */
    private Density<PhyloParticle> treePrior;
    /** Proposal options; only assigned by {@code init}. */
    private ProposalOptions propOpts;
    /** Lazily obtained informed proposals, see {@link #getInformedProp()}. */
    private InformedProposals _informedProp = null;

    /**
     * Builds a state from its components; the linearization is taken from
     * {@code msa.linearizedColumns()}. The prior and the proposal options are left unset.
     */
    public TreeMSAState(MSAPoset msa, RootedTree tree, TreeMSAParameters params) {
        this.tree = tree;
        this.msa = msa;
        this.linearization = msa.linearizedColumns();
        this.params = params;
    }

    /** Creates an empty state that has to be populated through {@code init}. */
    public TreeMSAState() {
    }

    /**
     * Populates this state from the given context: rooted tree, alignment (and its
     * linearization), parameters and proposal options. {@code calculators} is not used.
     *
     * @param calculators ignored by this implementation
     * @param prior       density later evaluated by {@link #getLogPrior()}
     * @param context     source of the tree, alignment, parameters and proposal options
     */
    @Override
    public void init(List<Pair<Taxon, LikelihoodModelCalculator>> calculators, Density<PhyloParticle> prior, PhyloParticleInitContext context) {
        this.treePrior = prior;
        this.tree = context.getTree().getRooted();
        this.msa = context.getMSA();
        this.linearization = this.msa.linearizedColumns();
        this.params = (TreeMSAParameters)((Object)context.getParameters());
        this.propOpts = context.proposalOptions;
    }

    /**
     * Resolves the pending in-place change: delegates to its {@link UndoFuture}, clears it and
     * increments the generation index.
     *
     * @param accept {@code true} to keep the change, {@code false} to roll it back
     * @throws IllegalStateException if no in-place change is pending
     */
    @Override
    public void wasProposalAccepted(boolean accept) {
        if (this.currentUndo == null) {
            throw new IllegalStateException("wasProposalAccepted called but no in-place proposal is pending");
        }
        this.currentUndo.wasProposalAccepted(accept, this);
        this.currentUndo = null;
        ++this.genIndex;
    }

    /**
     * Re-links {@code anchor} to {@code proposedEndPoint} in the alignment and registers the undo
     * action that restores both the previous end point and the previous linearization on rejection.
     */
    private void setAlignmentEndPoint(final Pair<Taxon, Integer> anchor, Pair<Taxon, Integer> proposedEndPoint) {
        // Check before touching the alignment: failing afterwards would leave it modified
        // with no undo action able to revert it.
        this.requireNoPendingUndo();
        // The snapshot must be taken before the alignment is modified.
        final List<Pair<Taxon, Integer>> savedLin = TreeMSAState.saveLinearization(this.msa, this.linearization);
        final Pair<Taxon, Integer> oldEndPt = MSAProposal.setAnchorEndPoint(this.msa, anchor, proposedEndPoint);
        UndoFuture future = new UndoFuture(){

            @Override
            public void wasProposalAccepted(boolean accept, TreeMSAState state) {
                if (accept) {
                    TreeMSAState.this.lastAcceptedLinearizationLogPr = TreeMSAState.this.currentProposedLinearizationLogPr;
                } else {
                    // Put the old end point back first so the saved columns can be looked up again.
                    MSAProposal.setAnchorEndPoint(TreeMSAState.this.msa, anchor, oldEndPt);
                    TreeMSAState.this.linearization = TreeMSAState.restoreLinearization(TreeMSAState.this.msa, savedLin);
                }
                TreeMSAState.this.currentProposedLinearizationLogPr = Double.NaN;
            }
        };
        this.setUndoFuture(future);
    }

    /**
     * Replaces the parameters in place; the previous parameters are restored if the change is
     * rejected.
     *
     * @throws IllegalStateException if another in-place change is still pending
     */
    public void inPlaceParamChange(TreeMSAParameters newParams) {
        // Check first so that a misuse does not overwrite the parameters without an undo action.
        this.requireNoPendingUndo();
        final TreeMSAParameters original = this.params;
        this.params = newParams;
        UndoFuture future = new UndoFuture(){

            @Override
            public void wasProposalAccepted(boolean accept, TreeMSAState state) {
                if (!accept) {
                    TreeMSAState.this.params = original;
                }
            }
        };
        this.setUndoFuture(future);
    }

    /**
     * Replaces the tree in place. A copy of the current topology and branch lengths is kept and
     * becomes the tree again if the change is rejected.
     *
     * @throws IllegalStateException if another in-place change is still pending
     */
    public void inPlaceTreeChange(RootedTree newTree) {
        // Check first so that a misuse does not overwrite the tree without an undo action.
        this.requireNoPendingUndo();
        final RootedTree original = new RootedTree.Util.RootedTreeImpl(this.tree.topology().copy(), new HashMap<Taxon, Double>(this.tree.branchLengths()));
        this.tree = newTree;
        UndoFuture future = new UndoFuture(){

            @Override
            public void wasProposalAccepted(boolean accept, TreeMSAState state) {
                if (!accept) {
                    TreeMSAState.this.tree = original;
                }
            }
        };
        this.setUndoFuture(future);
    }

    /**
     * Encodes a linearization as one (taxon, value) entry per column, in order. The entry is a
     * taxon picked from the column's points together with the value stored for it.
     *
     * @param msa           not used by this method
     * @param linearization columns to encode
     * @return one entry per column, in the same order as {@code linearization}
     */
    public static List<Pair<Taxon, Integer>> saveLinearization(MSAPoset msa, List<MSAPoset.Column> linearization) {
        List<Pair<Taxon, Integer>> result = new ArrayList<Pair<Taxon, Integer>>(linearization.size());
        for (MSAPoset.Column c : linearization) {
            Taxon aTax = CollUtils.pick(c.getPoints().keySet());
            result.add(Pair.makePair(aTax, c.getPoints().get(aTax)));
        }
        return result;
    }

    /**
     * Rebuilds a linearization from entries produced by {@link #saveLinearization}, looking each
     * entry up with {@code msa.column(taxon, value)}.
     *
     * @param msa   alignment in which the columns are looked up
     * @param saved entries to resolve, in order
     * @return the columns, in the same order as {@code saved}
     */
    public static List<MSAPoset.Column> restoreLinearization(MSAPoset msa, List<Pair<Taxon, Integer>> saved) {
        List<MSAPoset.Column> result = new ArrayList<MSAPoset.Column>(saved.size());
        for (Pair<Taxon, Integer> item : saved) {
            result.add(msa.column(item.getFirst(), item.getSecond()));
        }
        return result;
    }

    /**
     * Returns the informed proposals, fetching them on first use.
     *
     * @throws IllegalStateException if the proposal options were never set (they are only
     *                               assigned by {@code init})
     */
    private InformedProposals getInformedProp() {
        if (this._informedProp == null) {
            if (this.propOpts == null) {
                throw new IllegalStateException("Informed proposals require proposal options, which are only set by init(...)");
            }
            this._informedProp = InformedProposals.getSingleton(this.propOpts.informedProposalThreshold, this.msa.sequences(), MSAUtils.loadBasicRNAAligner());
        }
        return this._informedProp;
    }

    /**
     * Proposes, in place, a new link for a randomly chosen anchor position followed by a freshly
     * sampled linearization. The link is drawn from a mixture of a uniform weight over the
     * potential links and the informed-proposal counts for the anchor position.
     *
     * @param rand source of randomness
     * @return {@code log(prs[currentLink]) - log(prs[sampledLink])} plus the last accepted
     *         linearization log-probability minus the log-probability of the newly sampled
     *         linearization (presumably the log proposal ratio of the move)
     * @throws IllegalStateException if another in-place change is still pending, or if a null
     *                               potential link is found anywhere but in the last position
     */
    public double proposeInformedInPlaceLocalMSAMove(Random rand) {
        this.requireNoPendingUndo();
        List<Taxon> taxa = this.msa.taxa();
        List<Integer> taxaIntegers = Sampling.sampleWithoutReplacement(rand, taxa.size(), 2);
        Pair<Taxon, Taxon> taxaPair = Pair.makePair(taxa.get(taxaIntegers.get(0)), taxa.get(taxaIntegers.get(1)));
        InformedProposal ip = this.getInformedProp().proposals.get(taxaPair);
        Taxon anchorTaxon = taxaPair.getFirst();
        // Half of the time (or always, when there is no ambiguous position) pick any position of
        // the anchor sequence; otherwise pick among the ambiguous positions.
        int anchorPosition;
        if (rand.nextBoolean() || ip.ambiguousPositions.size() == 0) {
            anchorPosition = rand.nextInt(this.msa.sequences().get(anchorTaxon).length());
        } else {
            anchorPosition = ip.ambiguousPositions.get(rand.nextInt(ip.ambiguousPositions.size())).intValue();
        }
        Pair<Taxon, Integer> anchor = Pair.makePair(anchorTaxon, anchorPosition);
        Pair<List<Pair<Taxon, Integer>>, Integer> listPair = MSAProposal.listPotentialLinks(this.msa, anchor);
        List<Pair<Taxon, Integer>> potentialLinks = listPair.getFirst();
        double[] prs = new double[potentialLinks.size()];
        // Uniform component of the mixture.
        double uniformShare = 0.5 * (1.0 / (double)prs.length);
        for (int i = 0; i < prs.length; ++i) {
            prs[i] = uniformShare;
        }
        // Informed component of the mixture.
        Counter<Integer> current = ip.probabilities.getCounter(anchorPosition);
        for (int i = 0; i < prs.length; ++i) {
            Pair<Taxon, Integer> currentLink = potentialLinks.get(i);
            if (currentLink == null) {
                if (i != prs.length - 1) {
                    throw new IllegalStateException("Null potential link at index " + i + " but it is only allowed as the last of " + prs.length + " entries");
                }
                // NOTE(review): the null link adds the full count for key -1 whereas the other
                // links add half of their count - confirm that the missing 0.5 is intended.
                prs[i] += current.getCount(-1);
                continue;
            }
            prs[i] += 0.5 * current.getCount(currentLink.getSecond());
        }
        NumUtils.normalize(prs);
        int sampledEdgeIdx = SampleUtils.sampleMultinomial(rand, prs);
        this.setAlignmentEndPoint(anchor, potentialLinks.get(sampledEdgeIdx));
        // listPair.getSecond() is used as the index of the link that was in place before the move.
        double edgeLogRatio = Math.log(prs[listPair.getSecond()]) - Math.log(prs[sampledEdgeIdx]);
        double logBwdLinPr = this.lastAcceptedLinearizationLogPr;
        Pair<List<MSAPoset.Column>, Double> proposedLin = LinearizationProposal.sample_logPrLinearization(this.msa.getPoset(), rand);
        this.linearization = proposedLin.getFirst();
        double logFwdLinPr = proposedLin.getSecond().doubleValue();
        // Remembered so that it becomes the "last accepted" value if the move is accepted.
        this.currentProposedLinearizationLogPr = logFwdLinPr;
        return edgeLogRatio + logBwdLinPr - logFwdLinPr;
    }

    /**
     * Proposes, in place, a new link for the given anchor followed by a freshly sampled
     * linearization. The last potential link is chosen with probability 0.5 (1.0 when it is the
     * only one); the remaining links share the other 0.5 equally.
     *
     * @param rand   source of randomness
     * @param anchor (taxon, position) whose link is re-sampled
     * @return the last accepted linearization log-probability minus the log-probability of the
     *         newly sampled linearization
     * @throws IllegalStateException if another in-place change is still pending
     */
    public double proposeInPlaceLocalMSAMove(Random rand, Pair<Taxon, Integer> anchor) {
        this.requireNoPendingUndo();
        List<Pair<Taxon, Integer>> potentialLinks = MSAProposal.listPotentialLinks(this.msa, anchor).getFirst();
        double[] prs = new double[potentialLinks.size()];
        int lastIdx = prs.length - 1;
        for (int i = 0; i < lastIdx; ++i) {
            prs[i] = 0.5 * (1.0 / (double)lastIdx);
        }
        prs[lastIdx] = prs.length == 1 ? 1.0 : 0.5;
        int sampledEdgeIdx = SampleUtils.sampleMultinomial(rand, prs);
        this.setAlignmentEndPoint(anchor, potentialLinks.get(sampledEdgeIdx));
        double logBwdLinPr = this.lastAcceptedLinearizationLogPr;
        Pair<List<MSAPoset.Column>, Double> proposedLin = LinearizationProposal.sample_logPrLinearization(this.msa.getPoset(), rand);
        this.linearization = proposedLin.getFirst();
        double logFwdLinPr = proposedLin.getSecond().doubleValue();
        // Remembered so that it becomes the "last accepted" value if the move is accepted.
        this.currentProposedLinearizationLogPr = logFwdLinPr;
        // NOTE(review): unlike proposeInformedInPlaceLocalMSAMove, no term for the link-choice
        // probabilities is returned here although prs is not uniform - confirm this is intended.
        return logBwdLinPr - logFwdLinPr;
    }

    /** Throws if an in-place change has been made but not yet resolved. */
    private void requireNoPendingUndo() {
        if (this.currentUndo != null) {
            throw new IllegalStateException("A previous in-place proposal is still pending; call wasProposalAccepted(...) first");
        }
    }

    /** Registers the undo action of the change that was just made. */
    private void setUndoFuture(UndoFuture future) {
        this.requireNoPendingUndo();
        this.currentUndo = future;
    }

    /** Returns how many proposals have been resolved through {@link #wasProposalAccepted(boolean)}. */
    @Override
    public int generationIndex() {
        return this.genIndex;
    }

    /** Returns the current tree. */
    @Override
    public Phylogeny getPhylogeny() {
        return this.tree;
    }

    /** Evaluates the parameters' log-likelihood of the current linearized alignment and tree. */
    @Override
    public double getLogLikelihood() {
        return this.params.getLogLikelihood(new LinearizedAlignment(this.msa, this.linearization), this.tree);
    }

    /** Evaluates the prior supplied to {@code init} on this state. */
    @Override
    public double getLogPrior() {
        return this.treePrior.logDensity(this);
    }
}

