/*
 * Decompiled with CFR 0.152.
 */
package pty.mcmc;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Parallelizer;
import fig.exec.Execution;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.util.CollUtils;
import nuts.util.EasyFormat;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.mcmc.PhyloSampler;

public class ParallelTemperingChain {
    private List<PhyloSampler> samplers;
    private List<SummaryStatistics> swapStats;
    private TemperingOptions options = _defaultTemperingOptions;
    private boolean initialized = false;
    private String outputPrefix = "";
    public static TemperingOptions _defaultTemperingOptions = new TemperingOptions();

    public int nChains() {
        return this.samplers.size();
    }

    public PhyloSampler roomTempChain() {
        return this.samplers.get(0);
    }

    public PhyloSampler getChain(int i) {
        return this.samplers.get(i);
    }

    public TemperingOptions getOptions() {
        return this.options;
    }

    public void init(PhyloSampler roomTemperatureSampler) {
        this.samplers = new ArrayList<PhyloSampler>(this.options.nChains);
        this.swapStats = new ArrayList<SummaryStatistics>(this.options.nChains);
        for (int i = 0; i < this.options.nChains; ++i) {
            if (i == 0) {
                this.samplers.add(roomTemperatureSampler);
            } else {
                this.samplers.add(roomTemperatureSampler.createHeatedVersion(1.0 + this.options.temperature * (double)i * (double)i));
            }
            if (i < this.options.nChainMovesPerRoundPerChain - 1) {
                this.swapStats.add(new SummaryStatistics());
            }
            File directory = this.getOutDir(i);
            directory.mkdir();
            this.samplers.get(i).setFileOutputPrefix(directory.getName() + "/");
        }
        this.initialized = true;
    }

    public File getOutDir(int i) {
        return new File(Execution.getFile(this.outputPrefix + "chain-" + i));
    }

    public String toString() {
        StringBuilder result = new StringBuilder("Output prefix: " + this.outputPrefix + "\n");
        for (int i = 0; i < this.nChains(); ++i) {
            PhyloSampler current = this.samplers.get(i);
            result.append("#" + i + ": T=" + EasyFormat.fmt2(current.getTemperature()) + ", LL=" + EasyFormat.fmt2(current.logLikelihood()) + ", maxLL=" + EasyFormat.fmt2(current.mleLogLikelihood()) + ", ratio=" + EasyFormat.fmt2(current.getMeanAcceptanceRatio()) + ", condAnn=" + EasyFormat.fmt2(current.getConditionalAnnealRatio()) + ", condFra=" + (current.isConditioning() ? EasyFormat.fmt2(current.getConditionalFraction()) : "n/a") + ", swapRatio=" + (i != this.nChains() - 1 ? EasyFormat.fmt2(this.swapStats.get(i).getMean()) : "n/a") + ", details={" + current.detailedRatioToString() + "}\n");
        }
        return result.toString();
    }

    public void sample() {
        long start = System.currentTimeMillis();
        if (!this.initialized) {
            throw new RuntimeException();
        }
        for (int iter = 0; iter < this.options.nRounds; ++iter) {
            if ((iter + 1) % this.options.printFreq == 0) {
                this.log(iter);
            }
            this.sampleChains();
            if (System.currentTimeMillis() - start > this.options.timeCutOff) {
                this.closeFiles();
                return;
            }
            for (int swapIter = 0; swapIter < this.nSwapsPerRound(); ++swapIter) {
                this.sampleSwap(this.options.rand);
                if (System.currentTimeMillis() - start <= this.options.timeCutOff) continue;
                this.closeFiles();
                return;
            }
        }
    }

    public void closeFiles() {
        for (PhyloSampler sampler : this.samplers) {
            sampler.closeFile();
        }
    }

    private void log(int iter) {
        LogInfo.track("Round " + iter + "/" + this.options.nRounds + " (" + this.samplers.get(0).nIterations() + " chain iters, " + iter * this.nSwapsPerRound() + " swaps)");
        LogInfo.logs(this.toString());
        LogInfo.end_track();
    }

    public int nSwapsPerRound() {
        return this.options.nSwapsPerRoundPerChain * this.nChains();
    }

    private void sampleSwap(Random rand) {
        if (this.nChains() == 1) {
            return;
        }
        int i = rand.nextInt(this.nChains() - 1);
        PhyloSampler.sampleSwap(this.samplers.get(i), this.samplers.get(i + 1), rand, this.swapStats.get(i));
    }

    private void sampleChains() {
        Parallelizer<Integer> parallelizer = new Parallelizer<Integer>(this.options.nThreads);
        parallelizer.setPrimaryThread();
        parallelizer.process(CollUtils.ints(this.nChains()), new Parallelizer.Processor<Integer>(){

            @Override
            public void process(Integer x, int _i, int _n, boolean log) {
                for (int i = 0; i < ((ParallelTemperingChain)ParallelTemperingChain.this).options.nChainMovesPerRoundPerChain; ++i) {
                    ((PhyloSampler)ParallelTemperingChain.this.samplers.get(x)).sample();
                }
            }
        });
    }

    public File setOutputPrefix(String outputPrefix) {
        this.outputPrefix = outputPrefix;
        return this.getOutDir(0);
    }

    public static class TemperingOptions {
        @Option
        public long timeCutOff = Long.MAX_VALUE;
        @Option
        public int nThreads = 1;
        @Option
        public int nChains = 4;
        @Option
        public double temperature = 0.2;
        @Option
        public int nRounds = 100000;
        @Option
        public int nSwapsPerRoundPerChain = 2;
        @Option
        public int nChainMovesPerRoundPerChain = 10;
        @Option
        public int printFreq = 100;
        @Option
        public Random rand = new Random(1L);
    }
}

