/*
 * Decompiled with CFR 0.152.
 */
package monaco.mcmc;

import fig.basic.LogInfo;
import fig.basic.Pair;
import fig.basic.Parallelizer;
import fig.exec.Execution;
import gep.util.OutputManager;
import java.io.File;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import ma.MSAPoset;
import monaco.StandardKernel;
import monaco.mcmc.MCAlgorithm;
import monaco.mcmc.MCInitContext;
import monaco.process.ProcessSchedule;
import monaco.process.ProcessScheduleContext;
import monaco.process.ResampleStatus;
import nuts.io.OutputProducer;
import nuts.math.Sampling;
import nuts.math.StatisticsMap;
import nuts.util.CollUtils;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;

/**
 * Parallel tempering (replica-exchange) MCMC sampler.
 *
 * <p>Runs {@code nTemperedChains} chains side by side; chain {@code c} uses the supplied kernel
 * tempered with exponent {@code getTemperedExponent(c)}. At every generation the sampler either
 * (with probability {@code swapProportion}) proposes exchanges between adjacent chains, visited
 * in random order, or advances every chain by one Metropolis-Hastings step, spreading the chains
 * over at most {@code maxNThreads} threads. Only the state of chain 0 is handed to the particle
 * processor.
 *
 * <p>The per-chain statistics maps filled by the kernels, and the swap acceptance probabilities,
 * are written through an {@link OutputManager} every {@code generationPerPrintMHPeriod}
 * generations.
 *
 * @param <S> type of a chain state
 */
public class ParallelTemperedMCMC<S>
implements MCAlgorithm<S>,
OutputProducer {
    /** One tempered kernel per chain, indexed by chain. */
    private final List<StandardKernel<S>> kernels = CollUtils.list();
    /** Tempering exponent of each chain, parallel to {@link #kernels}. */
    private final List<Double> exponents = CollUtils.list();
    /** Current state of each chain. */
    private final S[] particles;
    /** Per-chain statistics map passed to each kernel's {@code next} call. */
    private final List<StatisticsMap<String>> mhRatioAvgs = CollUtils.list();
    /** Acceptance probabilities of proposed swaps, keyed by the pair of chain indices. */
    private final StatisticsMap<String> swapRatioAvgs = new StatisticsMap<String>();
    private final Random rand;
    private final int nThreads;
    private final double swapProportion;
    private final ProcessSchedule processSchedule;
    private final ParticleFilter.ParticleProcessor<S> processor;
    private final int nChains;
    private final int generationPerPrintMHPeriod;
    private final OutputManager manager = new OutputManager();
    /** Lazily loaded reference alignment; populated by {@link #getRef()}. */
    public static MSAPoset ref = null;

    /**
     * Creates one tempered copy of {@code _kernel} per chain and starts every chain at the
     * kernel's initial state.
     *
     * @param _kernel base kernel; must be a {@link StandardKernel}
     * @param context supplies the sampling options, the process schedule and the processor
     * @throws IllegalArgumentException if fewer than one tempered chain is requested
     * @throws ClassCastException if {@code _kernel} is not a {@link StandardKernel}
     */
    public ParallelTemperedMCMC(ParticleKernel<S> _kernel, MCInitContext<S> context) {
        this.rand = context.getOptions().samplingRand;
        this.nChains = context.getOptions().nTemperedChains;
        if (this.nChains < 1) {
            // run() always reads chain 0, so at least one chain is required.
            throw new IllegalArgumentException(
                    "nTemperedChains must be at least 1 but was " + this.nChains);
        }
        this.nThreads = Math.min(this.nChains, context.getOptions().maxNThreads);
        this.swapProportion = context.getOptions().swapProportion;
        this.processSchedule = context.getProcessSchedule();
        this.processor = context.getProcessor();
        this.generationPerPrintMHPeriod = context.getOptions().generationPerPrintMHPeriod;
        @SuppressWarnings("unchecked")
        final StandardKernel<S> kernel = (StandardKernel<S>) _kernel;
        final S initial = kernel.getInitial();
        // Generic arrays cannot be created directly; this array never leaves the class.
        @SuppressWarnings("unchecked")
        final S[] states = (S[]) new Object[this.nChains()];
        this.particles = states;
        for (int c = 0; c < this.nChains(); ++c) {
            final double currentExponent = context.getOptions().getTemperedExponent(c);
            final StandardKernel<S> currentKernel = kernel.createTemperedVersion(currentExponent);
            this.kernels.add(currentKernel);
            this.exponents.add(currentExponent);
            // NOTE(review): every chain starts from the same object reference; confirm this is
            // safe if states are modified in place (see InPlaceProposedState).
            this.particles[c] = initial;
            this.mhRatioAvgs.add(new StatisticsMap<String>());
        }
    }

    /** Returns the number of tempered chains. */
    private int nChains() {
        return this.nChains;
    }

    /**
     * Proposes one swap for every adjacent pair of chains {@code (i, i + 1)}, visiting the pairs
     * in a random order.
     */
    private void sampleSwaps(Random rand) {
        final List<Integer> lowerIndices = CollUtils.ints(this.nChains() - 1);
        Collections.shuffle(lowerIndices, rand);
        for (int lower : lowerIndices) {
            this.sampleSwap(rand, lower, lower + 1);
        }
    }

    /**
     * Proposes exchanging the states of chains {@code i1} and {@code i2} and accepts with
     * probability {@code min(1, exp((alpha2 - alpha1) * (logDensity(x1) - logDensity(x2))))}.
     * The acceptance probability is recorded in {@link #swapRatioAvgs}.
     *
     * @throws IllegalArgumentException if {@code i1 >= i2}
     */
    private void sampleSwap(Random rand, int i1, int i2) {
        if (i1 >= i2) {
            throw new IllegalArgumentException("Expected i1 < i2 but got i1=" + i1 + ", i2=" + i2);
        }
        final StandardKernel<S> k1 = this.kernels.get(i1);
        final StandardKernel<S> k2 = this.kernels.get(i2);
        final S state1 = this.particles[i1];
        final S state2 = this.particles[i2];
        final double alpha1 = this.exponents.get(i1);
        final double alpha2 = this.exponents.get(i2);
        // NOTE(review): this is the standard replica-exchange ratio only if logDensity returns
        // the *untempered* log target; if the tempered kernels return tempered densities the
        // exponents would be applied twice. Confirm against StandardKernel.getDensity().
        final double density1 = k1.getDensity().logDensity(state1);
        final double density2 = k2.getDensity().logDensity(state2);
        final double ratio = Math.min(1.0, Math.exp((alpha2 - alpha1) * (density1 - density2)));
        this.swapRatioAvgs.addValue("swap(" + i1 + "," + i2 + ")", ratio);
        if (rand.nextDouble() < ratio) {
            this.particles[i1] = state2;
            this.particles[i2] = state1;
        }
    }

    /**
     * Advances every chain by one MH step, in parallel. Each chain gets its own {@link Random}
     * seeded from {@code rand}, so the chains do not share a generator across threads.
     */
    private void sampleChains(Random rand) {
        final long[] seeds = Sampling.createSeeds(this.nChains(), rand);
        final Parallelizer<Integer> parallelizer = new Parallelizer<Integer>(this.nThreads);
        parallelizer.setPrimaryThread();
        parallelizer.process(CollUtils.ints(this.nChains()), new Parallelizer.Processor<Integer>() {

            @Override
            public void process(Integer chain, int _i, int _n, boolean log) {
                final Random chainRand = new Random(seeds[chain]);
                ParallelTemperedMCMC.this.sampleChain(chainRand, chain);
            }
        });
    }

    /**
     * Returns the reference alignment parsed from {@code msa.fasta} inside the location given by
     * {@code Execution.getFile("full-data")}. The result is cached in {@link #ref} after the first
     * call. This method is not synchronized.
     */
    public static MSAPoset getRef() {
        if (ref == null) {
            ref = MSAPoset.parseFASTA(new File(Execution.getFile("full-data"), "msa.fasta"));
        }
        return ref;
    }

    /**
     * Performs one Metropolis-Hastings step on chain {@code i}. The kernel returns a proposal and
     * a log acceptance ratio; the proposal is accepted with probability {@code min(1, exp(ratio))}.
     * If the proposal implements {@link InPlaceProposedState} it is told whether it was accepted.
     */
    private void sampleChain(Random rand, int i) {
        final Pair<S, Double> proposal =
                this.kernels.get(i).next(rand, this.particles[i], this.mhRatioAvgs.get(i));
        final double acceptProbability = Math.min(1.0, Math.exp(proposal.getSecond()));
        final boolean accepted = rand.nextDouble() < acceptProbability;
        final S proposed = proposal.getFirst();
        if (accepted) {
            this.particles[i] = proposed;
        }
        if (proposed instanceof InPlaceProposedState) {
            ((InPlaceProposedState) proposed).wasProposalAccepted(accepted);
        }
    }

    /**
     * Runs as many generations as chain 0's kernel reports are left from chain 0's current state.
     * After each generation chain 0's state is passed to the processor when the process schedule
     * asks for it. Statistics are printed every {@code generationPerPrintMHPeriod} generations,
     * and never when that period is 0.
     */
    @Override
    public void run() {
        final int nGenerations = this.kernels.get(0).nIterationsLeft(this.particles[0]);
        LogInfo.logs("Starting ParallelTemperedMCMC (nChains=" + this.nChains + ")");
        for (int gen = 0; gen < nGenerations; ++gen) {
            if (this.rand.nextDouble() < this.swapProportion) {
                this.sampleSwaps(this.rand);
            } else {
                this.sampleChains(this.rand);
            }
            final ProcessScheduleContext context =
                    new ProcessScheduleContext(gen, gen == nGenerations - 1, ResampleStatus.NA);
            if (this.processSchedule.shouldProcess(context)) {
                LogInfo.logs("Processing sample (gen=" + (1 + gen) + "/" + nGenerations + ")");
                this.processor.process(this.particles[0], 1.0);
                this.processSchedule.monitor(context);
            }
            // Guard the modulus: a period of 0 would otherwise throw ArithmeticException.
            if (this.generationPerPrintMHPeriod != 0 && gen % this.generationPerPrintMHPeriod == 0) {
                this.printMHStatistics(gen);
            }
        }
    }

    /**
     * Writes the mean of every per-chain statistic and every swap acceptance probability
     * collected so far, tagged with the current generation.
     */
    private void printMHStatistics(int gen) {
        for (int c = 0; c < this.nChains(); ++c) {
            final StatisticsMap<String> currentMap = this.mhRatioAvgs.get(c);
            final String output = "ratio-chain-" + c + ",exp=" + this.exponents.get(c);
            for (String key : currentMap.keySet()) {
                this.manager.write(output, "generation", gen, "proposal", key,
                        "mean", currentMap.getSummaryStat(key).getMean());
            }
        }
        for (String key : this.swapRatioAvgs.keySet()) {
            this.manager.write("ratio-swaps", "generation", gen, "proposal", key,
                    "mean", this.swapRatioAvgs.getSummaryStat(key).getMean());
        }
    }

    /** Directs the statistics output of this sampler to {@code f}. */
    @Override
    public void setOutputFolder(File f) {
        this.manager.setOutputFolder(f);
    }

    /**
     * Implemented by states that need to know the outcome of an MH step in which they were the
     * proposal; {@link #sampleChain} calls it after every step.
     */
    public static interface InPlaceProposedState {
        /** @param accepted whether the proposal carrying this state was accepted */
        public void wasProposalAccepted(boolean accepted);
    }
}

