/*
 * Decompiled with CFR 0.152.
 */
package pty.mcmc;

import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.UnorderedPair;
import fig.exec.Execution;
import goblin.Taxon;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import nuts.io.IO;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import nuts.util.EasyFormat;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.UnrootedTree;
import pty.eval.Purity;
import pty.mcmc.ProposalDistribution;
import pty.mcmc.UnrootedTreeState;

public class PhyloSampler {
    // Options handed to ProposalDistribution.Util.proposalList when the proposal pool is first built.
    private ProposalDistribution.Options proposalOptions = ProposalDistribution.Util._defaultProposalDistributionOptions;
    // Sampler settings: prior choice, RNG, iteration count, log cadence, trees per file, conditioning anneal rate.
    private Options phyloSamplerOptions = _defaultPhyloSamplerOptions;
    // Options passed to the Prior enum factory in init().
    private PriorOptions priorOptions = _defaultPriorOptions;
    // Chain temperature; NaN until init() runs. Energy differences are divided by it in the acceptance ratio.
    private double temperature = Double.NaN;
    // Set by init(); sample() refuses to run while this is false.
    private boolean initialized = false;
    // State given to init(); also the source of the tree used to build proposals and of heated copies.
    private UnrootedTreeState initialState = null;
    // State the chain currently sits at.
    private UnrootedTreeState currentState = null;
    // Prior over trees, created in init() from phyloSamplerOptions.prior (or replaced via setPrior).
    private NonClockTreePrior prior = null;
    // Number of completed sample() calls since init().
    private int iteration = 0;
    // Acceptance ratios of every non-null proposal.
    private SummaryStatistics globalRatioStatistics = null;
    // Acceptance ratios keyed by ProposalDistribution.description().
    private Map<String, SummaryStatistics> detailedRatioStat = null;
    // 1.0/0.0 outcomes of the conditioning filter (annealAccept).
    private SummaryStatistics conditioningAcceptStat = null;
    // Callbacks invoked with the current state after each non-null proposal.
    private List<PhyloProcessor> processors = null;
    // Processor registered by init() that remembers the highest-likelihood state seen.
    private RecordHighestLikelihood rhl = null;
    // Proposal pool, filled lazily by nextProposal().
    private LinkedList<ProposalDistribution> proposalDistributions = null;
    // When true, log() prints this sampler through LogInfo and sampleManyTimes() opens a LogInfo track.
    private boolean outputText = false;
    // Prepended to the names of the files written by outputTree/outputMLETree.
    private String fileOutputPrefix = "";
    // Writer for the sample file currently being filled; null when no file is open.
    private PrintWriter currentTreeOut = null;
    // Number of trees appended to currentTreeOut so far.
    private int nTreesInCurrentFile = 0;
    // Tree most recently written to the "mle.newick" file; compared by reference in log().
    private UnrootedTree last = null;
    // Taxon-to-label map handed to Purity.purity; null means no conditioning is applied.
    private Map<Taxon, String> conditionedCluster = null;
    // Shared defaults used by every sampler that is not given its own options object.
    public static Options _defaultPhyloSamplerOptions = new Options();
    public static PriorOptions _defaultPriorOptions = new PriorOptions();

    /**
     * Turns textual progress output on or off.
     *
     * @param v {@code true} to have {@code log} print this sampler via {@code LogInfo} and
     *          {@code sampleManyTimes} wrap its run in a {@code LogInfo} track
     */
    public void setOutputText(boolean v) {
        outputText = v;
    }

    /**
     * Resets this sampler and starts a chain at temperature 1.0 from the given state.
     *
     * <p>Resets the iteration counter and all acceptance statistics, rebuilds the prior from the
     * current options, empties the proposal pool and replaces the processor list with a fresh one
     * containing only the highest-likelihood recorder. Finally logs iteration 0, which writes the
     * initial tree to the first samples file.
     *
     * @param initialState state the chain starts from; kept as both initial and current state
     */
    public void init(UnrootedTreeState initialState) {
        this.initAtTemperature(initialState, 1.0);
    }

    /**
     * Shared initialisation used by {@link #init(UnrootedTreeState)} and
     * {@link #createHeatedVersion(double)}.
     *
     * <p>The temperature is assigned before the initial {@code log(0)} call because {@code log}
     * only writes sample trees when the temperature equals 1.0.
     */
    private void initAtTemperature(UnrootedTreeState initialState, double temperature) {
        this.iteration = 0;
        this.temperature = temperature;
        this.initialState = initialState;
        this.currentState = initialState;
        this.prior = this.phyloSamplerOptions.prior.prior(this.priorOptions);
        this.globalRatioStatistics = new SummaryStatistics();
        this.conditioningAcceptStat = new SummaryStatistics();
        this.detailedRatioStat = new HashMap<String, SummaryStatistics>();
        this.proposalDistributions = new LinkedList<ProposalDistribution>();
        this.processors = new ArrayList<PhyloProcessor>();
        this.rhl = new RecordHighestLikelihood();
        this.processors.add(this.rhl);
        this.initialized = true;
        this.log(0);
    }

    /**
     * Creates an independent, already initialised sampler running at the given temperature.
     *
     * <p>The copy shares this sampler's option objects and constraint map, starts from a deep
     * clone of this sampler's initial state and has text output disabled. The file output prefix
     * is not copied.
     *
     * <p>The temperature is in effect before the copy's initial log call. Previously the copy was
     * initialised at temperature 1.0 first, so every heated chain opened and wrote a
     * {@code samples-0.newick.gz} file that it never wrote to again.
     *
     * @param temperature temperature of the new chain
     * @return the new sampler
     */
    public PhyloSampler createHeatedVersion(double temperature) {
        PhyloSampler ps = new PhyloSampler();
        ps.outputText = false;
        ps.proposalOptions = this.proposalOptions;
        ps.phyloSamplerOptions = this.phyloSamplerOptions;
        ps.priorOptions = this.priorOptions;
        ps.setConstraints(this.conditionedCluster);
        ps.initAtTemperature(this.initialState.deepClone(), temperature);
        return ps;
    }

    /**
     * @return the highest-likelihood state recorded so far, or {@code null} if none has been recorded
     */
    public UnrootedTreeState mle() {
        return rhl.argmax;
    }

    /**
     * @return the log-likelihood of the state returned by {@link #mle()}, or {@code NaN} when no
     *         state has been recorded yet
     */
    public double mleLogLikelihood() {
        if (rhl.argmax == null) {
            return Double.NaN;
        }
        return mle().logLikelihood();
    }

    /**
     * @return the log-likelihood of the current state
     */
    public double logLikelihood() {
        return currentState.logLikelihood();
    }

    /**
     * Calls {@link #sample()} until the iteration counter reaches
     * {@code phyloSamplerOptions.nIteration}.
     *
     * <p>When text output is enabled the run is wrapped in a {@code LogInfo} track. The track is
     * closed in a {@code finally} block so it stays balanced even if a sampling step throws.
     */
    public void sampleManyTimes() {
        // Latch the flag so the track opened here is exactly the one closed below.
        final boolean tracking = this.outputText;
        if (tracking) {
            LogInfo.track("Sampling " + this.phyloSamplerOptions.nIteration + " MCMC steps");
        }
        try {
            while (this.iteration < this.phyloSamplerOptions.nIteration) {
                this.sample();
            }
        } finally {
            if (tracking) {
                LogInfo.end_track();
            }
        }
    }

    /**
     * @return the number of {@link #sample()} calls completed since {@code init}
     */
    public int nIterations() {
        return iteration;
    }

    /**
     * Performs one MCMC step using the options' random generator, then advances the iteration
     * counter.
     *
     * <p>The sampler is logged whenever {@code iteration + 1} is a multiple of
     * {@code phyloSamplerOptions.logFrequency}. A frequency of zero disables periodic logging
     * instead of failing with a modulo-by-zero {@code ArithmeticException}.
     *
     * @throws IllegalStateException if {@code init} has not been called on this sampler
     */
    public void sample() {
        if (!this.initialized) {
            throw new IllegalStateException("PhyloSampler.sample() called before init(UnrootedTreeState)");
        }
        this.sample(this.phyloSamplerOptions.rand);
        int logFrequency = this.phyloSamplerOptions.logFrequency;
        if (logFrequency != 0 && (this.iteration + 1) % logFrequency == 0) {
            this.log(this.iteration);
        }
        ++this.iteration;
    }

    /**
     * @return the mean of the acceptance ratios recorded over all proposals
     */
    public double getMeanAcceptanceRatio() {
        return globalRatioStatistics.getMean();
    }

    /**
     * Computes the energy of a state: the negated sum of its log-likelihood and the log prior
     * density of its tree.
     *
     * @param ncts state to evaluate
     * @return {@code -logLikelihood - logPrior}
     */
    public double energy(UnrootedTreeState ncts) {
        double negLogLikelihood = -ncts.logLikelihood();
        double logPrior = prior.logPriorDensity(ncts.getNonClockTree());
        return negLogLikelihood - logPrior;
    }

    /**
     * @return the {@link #energy(UnrootedTreeState) energy} of the current state
     */
    public double currentEnergy() {
        return energy(currentState);
    }

    /**
     * Computes the energy of a state divided by this chain's temperature.
     *
     * @param ncts state to evaluate
     * @return {@code (-logLikelihood - logPrior) / temperature}
     */
    public double energyOverTemperature(UnrootedTreeState ncts) {
        double negLogLikelihood = -ncts.logLikelihood();
        double logPrior = prior.logPriorDensity(ncts.getNonClockTree());
        double rawEnergy = negLogLikelihood - logPrior;
        return rawEnergy / temperature;
    }

    /**
     * @return the current state's energy divided by this chain's temperature
     */
    public double currentEnergyOverTemperature() {
        return energyOverTemperature(currentState);
    }

    /**
     * Proposes exchanging the current states of two samplers and accepts with the swap ratio.
     *
     * @param ps1 first sampler
     * @param ps2 second sampler
     * @param rand source of the uniform draw deciding acceptance
     * @param ratioStats receives the swap ratio, whether or not the swap is accepted
     */
    public static void sampleSwap(PhyloSampler ps1, PhyloSampler ps2, Random rand, SummaryStatistics ratioStats) {
        double acceptance = swapRatio(ps1, ps2);
        ratioStats.addValue(acceptance);
        boolean accepted = rand.nextDouble() < acceptance;
        if (accepted) {
            swapStates(ps1, ps2);
        }
    }

    /** Exchanges the current states of the two samplers; nothing else is touched. */
    private static void swapStates(PhyloSampler ps1, PhyloSampler ps2) {
        UnrootedTreeState formerFirst = ps1.currentState;
        ps1.currentState = ps2.currentState;
        ps2.currentState = formerFirst;
    }

    /**
     * Acceptance ratio for exchanging the states of two chains:
     * {@code min(1, exp((E1 - E2) * (1/T1 - 1/T2)))}, where {@code E} is the untempered energy of
     * each chain's current state and {@code T} its temperature.
     */
    private static double swapRatio(PhyloSampler ps1, PhyloSampler ps2) {
        double energyGap = ps1.currentEnergy() - ps2.currentEnergy();
        double inverseTemperatureGap = 1.0 / ps1.getTemperature() - 1.0 / ps2.getTemperature();
        double logRatio = energyGap * inverseTemperatureGap;
        return Math.min(1.0, Math.exp(logRatio));
    }

    /**
     * Emits the periodic outputs for one logging point.
     *
     * <p>Prints this sampler through {@code LogInfo} when text output is on, appends the current
     * tree to the samples file when the temperature is exactly 1.0, and rewrites the MLE file
     * when the best tree is a different object from the one written last time.
     *
     * @param iter iteration number, used to name a newly opened samples file
     */
    private void log(int iter) {
        if (outputText) {
            LogInfo.logs(this);
        }
        if (temperature == 1.0) {
            outputTree(currentState.getNonClockTree(), iter);
        }
        UnrootedTreeState best = mle();
        if (best == null) {
            return;
        }
        UnrootedTree bestTree = best.getNonClockTree();
        // Reference comparison: only a newly recorded tree triggers a rewrite.
        if (last != bestTree) {
            last = bestTree;
            outputMLETree(bestTree);
        }
    }

    /** Writes the Newick form of the given tree to the execution file "&lt;prefix&gt;mle.newick". */
    private void outputMLETree(UnrootedTree nct) {
        String fileName = fileOutputPrefix + "mle.newick";
        IO.writeToDisk(Execution.getFile(fileName), nct.toNewick());
    }

    /**
     * Appends one tree, in Newick form followed by a newline, to the current samples file.
     *
     * <p>A new file named {@code <prefix>samples-<iter>.newick.gz} is opened when no file is open
     * or the current one already holds {@code nTreesPerFile} trees; any previous file is closed
     * first.
     *
     * @param nct tree to write
     * @param iter iteration number used in the name of a newly opened file
     */
    private void outputTree(UnrootedTree nct, int iter) {
        boolean needNewFile = currentTreeOut == null
                || nTreesInCurrentFile >= phyloSamplerOptions.nTreesPerFile;
        if (needNewFile) {
            if (currentTreeOut != null) {
                currentTreeOut.close();
            }
            String fileName = fileOutputPrefix + "samples-" + iter + ".newick.gz";
            currentTreeOut = IOUtils.openOutHard(Execution.getFile(fileName));
            nTreesInCurrentFile = 0;
        }
        String line = nct.toNewick() + "\n";
        currentTreeOut.append(line);
        nTreesInCurrentFile++;
    }

    /** Closes the currently open samples file, if any, and forgets it. Safe to call repeatedly. */
    public void closeFile() {
        if (currentTreeOut == null) {
            return;
        }
        currentTreeOut.close();
        currentTreeOut = null;
    }

    /**
     * @return {@code true} when a constraint map has been set, i.e. proposals are filtered by the
     *         conditioning rule
     */
    public boolean isConditioning() {
        return conditionedCluster != null;
    }

    /**
     * Sets the taxon-to-label map used for conditioning.
     *
     * @param constraints the map, stored by reference; {@code null} disables conditioning
     */
    public void setConstraints(Map<Taxon, String> constraints) {
        conditionedCluster = constraints;
    }

    /** Returns one minus the purity of the tree with respect to the constraint map. */
    private double conditionedImpurity(UnrootedTree t) {
        double purity = Purity.purity(t, conditionedCluster);
        return 1.0 - purity;
    }

    /**
     * Applies the conditioning filter to a proposed state and records the outcome
     * (1.0 for pass, 0.0 for reject) in the conditioning statistics.
     */
    private boolean annealAccept(Random rand, UnrootedTreeState proposedState) {
        boolean passed = _annealAccept(rand, proposedState);
        double indicator = passed ? 1.0 : 0.0;
        conditioningAcceptStat.addValue(indicator);
        return passed;
    }

    /**
     * @return the mean of the recorded conditioning-filter outcomes (fraction of proposals that
     *         passed the filter)
     */
    public double getConditionalAnnealRatio() {
        return conditioningAcceptStat.getMean();
    }

    /**
     * @return one minus the purity of the current tree with respect to the constraint map
     */
    public double getConditionalFraction() {
        UnrootedTree currentTree = currentState.getNonClockTree();
        return conditionedImpurity(currentTree);
    }

    /**
     * Conditioning filter applied before the Metropolis test.
     *
     * <p>Rules, checked in order:
     * <ol>
     *   <li>no constraints set: pass;</li>
     *   <li>proposed tree has zero impurity: pass;</li>
     *   <li>current tree has zero impurity (and the proposal does not): reject;</li>
     *   <li>proposal is no more impure than the current tree: pass;</li>
     *   <li>otherwise pass with probability {@code phyloSamplerOptions.conditionAnneal}.</li>
     * </ol>
     * The random generator is consulted only in the last case.
     */
    private boolean _annealAccept(Random rand, UnrootedTreeState proposedState) {
        if (!isConditioning()) {
            return true;
        }
        double proposedImpurity = conditionedImpurity(proposedState.getNonClockTree());
        if (proposedImpurity == 0.0) {
            return true;
        }
        double currentImpurity = getConditionalFraction();
        if (currentImpurity == 0.0) {
            return false;
        }
        // Short-circuit keeps the random draw confined to proposals that worsen impurity.
        return proposedImpurity <= currentImpurity
                || rand.nextDouble() < phyloSamplerOptions.conditionAnneal;
    }

    /**
     * Performs one Metropolis-Hastings step.
     *
     * <p>A proposal distribution is drawn from the pool and asked for a new tree. If it returns
     * {@code null} the step is a no-op: no statistics are recorded and processors are not
     * invoked. Otherwise the acceptance ratio
     * {@code min(1, exp(logProposalRatio - (E(proposed) - E(current)) / temperature))} is
     * recorded, the proposal is adopted if it passes the conditioning filter and the uniform
     * draw, and every processor is then called with the resulting current state.
     *
     * @param rand random generator used for proposal selection, proposing and acceptance
     * @throws IllegalStateException if the acceptance ratio evaluates to NaN
     */
    private void sample(Random rand) {
        ProposalDistribution proposal = this.nextProposal(rand);
        Pair<UnrootedTree, Double> result = proposal.propose(this.currentState.getNonClockTree(), rand);
        if (result == null) {
            return;
        }
        UnrootedTreeState proposedState = this.currentState.copyAndChange(result.getFirst());
        double logProposalRatio = result.getSecond();
        double energyLogRatio = this.energy(proposedState) - this.currentEnergy();
        double ratio = Math.min(1.0, Math.exp(logProposalRatio - energyLogRatio / this.temperature));
        if (Double.isNaN(ratio)) {
            throw new IllegalStateException("Acceptance ratio is NaN (proposal=" + proposal.description()
                    + ", logProposalRatio=" + logProposalRatio
                    + ", energyLogRatio=" + energyLogRatio
                    + ", temperature=" + this.temperature + ")");
        }
        this.globalRatioStatistics.addValue(ratio);
        this.getDetailedRatioStatistics(proposal).addValue(ratio);
        // The conditioning filter runs first; the uniform draw is consumed only if it passes.
        if (this.annealAccept(rand, proposedState) && rand.nextDouble() < ratio) {
            this.currentState = proposedState;
        }
        for (PhyloProcessor processor : this.processors) {
            processor.process(this.currentState);
        }
    }

    /**
     * Picks a proposal distribution uniformly at random from the pool, building the pool on first
     * use from the proposal options and the initial state's tree.
     */
    private ProposalDistribution nextProposal(Random rand) {
        if (proposalDistributions.isEmpty()) {
            proposalDistributions.addAll(
                    ProposalDistribution.Util.proposalList(proposalOptions, initialState.getNonClockTree(), rand));
        }
        int chosen = rand.nextInt(proposalDistributions.size());
        return proposalDistributions.get(chosen);
    }

    /**
     * Looks up the acceptance statistics for a proposal type, keyed by its description.
     * NOTE(review): relies on CollUtils.getNoNull to supply the fresh instance when the key is
     * absent (presumably also storing it) - confirm against that helper.
     */
    private SummaryStatistics getDetailedRatioStatistics(ProposalDistribution pd) {
        SummaryStatistics fallback = new SummaryStatistics();
        return CollUtils.getNoNull(detailedRatioStat, pd.description(), fallback);
    }

    /**
     * Renders a multi-line summary: the global acceptance line, the best and current
     * log-likelihoods, and one acceptance line per proposal type, enclosed in braces.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append(toString("global", globalRatioStatistics)).append(" {\n");
        text.append("\tBest log likelihood: ").append(mleLogLikelihood()).append("\n");
        text.append("\tCurrent log likelihood: ").append(currentState.logLikelihood()).append("\n");
        for (Map.Entry<String, SummaryStatistics> entry : detailedRatioStat.entrySet()) {
            text.append("\t").append(toString(entry.getKey(), entry.getValue())).append("\n");
        }
        text.append("}\n");
        return text.toString();
    }

    /**
     * Formats one acceptance-statistics line.
     *
     * @param descr label inserted after "Number of"
     * @param ss statistics whose count and mean are reported
     * @return the formatted line, without a trailing newline
     */
    public String toString(Object descr, SummaryStatistics ss) {
        StringBuilder line = new StringBuilder("Number of ");
        line.append(descr).append(" sampling steps: ").append(ss.getN());
        line.append("\t").append("(Mean acceptance ratio: ").append(ss.getMean()).append(")");
        return line.toString();
    }

    /**
     * Builds a one-line summary of the mean acceptance ratio per proposal type, as
     * space-separated {@code key=value} pairs in sorted key order, starting with a single space.
     */
    public String detailedRatioToString() {
        List<String> sortedKeys = new ArrayList<String>(detailedRatioStat.keySet());
        Collections.sort(sortedKeys);
        StringBuilder line = new StringBuilder(" ");
        for (String name : sortedKeys) {
            SummaryStatistics stats = detailedRatioStat.get(name);
            line.append(name).append("=" + EasyFormat.fmt2(stats.getMean()) + " ");
        }
        return line.toString();
    }

    /**
     * @return the options used when the proposal pool is built
     */
    public ProposalDistribution.Options getProposalOptions() {
        return proposalOptions;
    }

    /**
     * Replaces the options used when the proposal pool is built. The pool is only built while it
     * is empty, so a pool that has already been filled is not affected.
     *
     * @param proposalOptions the new options, stored by reference
     */
    public void setProposalOptions(ProposalDistribution.Options proposalOptions) {
        this.proposalOptions = proposalOptions;
    }

    /**
     * @return the sampler options currently in use
     */
    public Options getPhyloSamplerOptions() {
        return phyloSamplerOptions;
    }

    /**
     * Replaces the sampler options. The prior is derived from these options only inside
     * {@code init}, so changing the prior choice afterwards has no effect until the next init.
     *
     * @param phyloSamplerOptions the new options, stored by reference
     */
    public void setPhyloSamplerOptions(Options phyloSamplerOptions) {
        this.phyloSamplerOptions = phyloSamplerOptions;
    }

    /**
     * @return the live list of processors invoked after each non-null proposal; {@code null}
     *         before {@code init}
     */
    public List<PhyloProcessor> getProcessors() {
        return processors;
    }

    /**
     * Replaces the processor list. The highest-likelihood recorder created by {@code init} is
     * not re-added here, so it keeps updating only if the given list contains it.
     *
     * @param processors the new list, stored by reference
     */
    public void setProcessors(List<PhyloProcessor> processors) {
        this.processors = processors;
    }

    /**
     * @return this chain's temperature; {@code NaN} before {@code init}
     */
    public double getTemperature() {
        return temperature;
    }

    /**
     * Sets this chain's temperature. Sample trees are written by {@code log} only while the
     * temperature is exactly 1.0.
     *
     * @param temperature the new temperature
     */
    public void setTemperature(double temperature) {
        this.temperature = temperature;
    }

    /**
     * Sets the prefix prepended to the names of the samples and MLE output files.
     *
     * @param fileOutputPrefix the new prefix
     */
    public void setFileOutputPrefix(String fileOutputPrefix) {
        this.fileOutputPrefix = fileOutputPrefix;
    }

    /**
     * @return the state passed to {@code init}; {@code null} before {@code init}
     */
    public UnrootedTreeState getInitialState() {
        return initialState;
    }

    /**
     * @return the state the chain currently sits at; {@code null} before {@code init}
     */
    public UnrootedTreeState getCurrentState() {
        return currentState;
    }

    /**
     * @return the tree prior used in the energy; {@code null} before {@code init} unless set
     *         explicitly
     */
    public NonClockTreePrior getPrior() {
        return prior;
    }

    /**
     * Replaces the tree prior used in the energy. {@code init} overwrites the prior with one
     * built from the options, so call this after {@code init} for it to take effect.
     *
     * @param prior the new prior
     */
    public void setPrior(NonClockTreePrior prior) {
        this.prior = prior;
    }

    /**
     * Selectable tree priors. Each constant acts as a factory that builds the corresponding
     * {@link NonClockTreePrior} from a {@link PriorOptions} instance.
     */
    public static enum Prior {
        // Builds an ExponentialPrior parameterised by options.multiplicativeBranchFactor.
        EXP{

            @Override
            public NonClockTreePrior prior(PriorOptions options) {
                return new ExponentialPrior(options.multiplicativeBranchFactor);
            }
        }
        ,
        // Builds an ImproperPrior; the options are ignored.
        IMPROPER{

            @Override
            public NonClockTreePrior prior(PriorOptions options) {
                return new ImproperPrior();
            }
        };


        // Creates the prior object for this constant from the given options.
        abstract NonClockTreePrior prior(PriorOptions var1);
    }

    /** Option holder for prior construction; fields are exposed via {@code @Option}. */
    public static class PriorOptions {
        // Passed as the single constructor argument of ExponentialPrior by Prior.EXP.
        @Option
        public double multiplicativeBranchFactor = 2.0;
    }

    /** Option holder for the sampler; fields are exposed via {@code @Option}. */
    public static class Options {
        // Which prior factory init() uses.
        @Option
        public Prior prior = Prior.EXP;
        // Random generator used by sample(); fixed seed 1 by default. Heated copies share this
        // options object and therefore this generator.
        @Option
        public Random rand = new Random(1L);
        // sampleManyTimes() runs until the iteration counter reaches this value.
        @Option
        public int nIteration = 1000;
        // sample() logs whenever (iteration + 1) is a multiple of this value.
        @Option
        public int logFrequency = 100;
        // Number of trees written to one samples file before a new file is started.
        @Option
        public int nTreesPerFile = 100;
        // Probability of letting through a proposal that increases conditioned impurity.
        @Option
        public double conditionAnneal = 0.01;
    }

    /**
     * Prior that treats branch lengths as independent draws from one exponential distribution.
     * NOTE(review): whether {@code meanParam} is the mean or the rate depends on the first
     * argument of {@code Sampling.exponentialLogDensity} - confirm against that helper.
     */
    public static class ExponentialPrior
    implements NonClockTreePrior {
        public final double meanParam;

        /**
         * @param meanParam parameter forwarded to {@code Sampling.exponentialLogDensity}
         */
        public ExponentialPrior(double meanParam) {
            this.meanParam = meanParam;
        }

        /**
         * Sums the exponential log-density of every edge's branch length.
         *
         * @param nct tree to score
         * @return the summed log-density; 0.0 for a tree without edges
         */
        @Override
        public double logPriorDensity(UnrootedTree nct) {
            double logDensity = 0.0;
            for (UnorderedPair<Taxon, Taxon> branch : nct.edges()) {
                logDensity += Sampling.exponentialLogDensity(meanParam, nct.branchLength(branch));
            }
            return logDensity;
        }
    }

    /** Flat prior: every tree gets log-density 0, so the energy reduces to the negated log-likelihood. */
    public static class ImproperPrior
    implements NonClockTreePrior {
        /**
         * @param nct ignored
         * @return always 0.0
         */
        @Override
        public double logPriorDensity(UnrootedTree nct) {
            return 0.0;
        }
    }

    /** A prior over unrooted trees, consumed by the sampler's energy computation. */
    public static interface NonClockTreePrior {
        /**
         * @param var1 tree to score
         * @return the log prior density of the tree; the sampler subtracts this from the negated
         *         log-likelihood to obtain the energy
         */
        public double logPriorDensity(UnrootedTree var1);
    }

    /**
     * Processor that remembers the state with the highest log-likelihood seen so far.
     * A state is recorded only when its log-likelihood is strictly greater than the running
     * maximum, which starts at negative infinity.
     */
    public static class RecordHighestLikelihood
    implements PhyloProcessor {
        private UnrootedTreeState argmax = null;
        private double max = Double.NEGATIVE_INFINITY;

        /**
         * Updates the record if the given state beats the running maximum.
         *
         * @param ncts state to consider; stored by reference when it becomes the new best
         */
        @Override
        public void process(UnrootedTreeState ncts) {
            boolean improves = ncts.logLikelihood() > max;
            if (!improves) {
                return;
            }
            max = ncts.logLikelihood();
            argmax = ncts;
        }
    }

    /**
     * Adapts a tree-level processor to the state-level {@link PhyloProcessor} interface by
     * forwarding each state's tree to it.
     */
    public static class PhyloProcessorAdaptor
    implements PhyloProcessor {
        public final UnrootedTree.UnrootedTreeProcessor rtp;

        /**
         * @param rtp tree processor that will receive the tree of every processed state
         */
        public PhyloProcessorAdaptor(UnrootedTree.UnrootedTreeProcessor rtp) {
            this.rtp = rtp;
        }

        /**
         * Passes the state's tree to the wrapped processor.
         *
         * @param ncts state whose tree is forwarded
         */
        @Override
        public void process(UnrootedTreeState ncts) {
            UnrootedTree tree = ncts.getNonClockTree();
            rtp.process(tree);
        }
    }

    /** Callback invoked by the sampler with the current state after each non-null proposal. */
    public static interface PhyloProcessor {
        /**
         * @param var1 the sampler's current state after the accept/reject decision
         */
        public void process(UnrootedTreeState var1);
    }
}

