/*
 * Decompiled with CFR 0.152.
 */
package pty.mcmc;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.exec.Execution;
import goblin.BayesRiskMinimizer;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import nuts.util.EasyFormat;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.UnrootedTree;
import pty.eval.CladeOverlap;
import pty.eval.SymmetricDiff;
import pty.io.Dataset;
import pty.io.HGDPDataset;
import pty.io.WalsDataset;
import pty.learn.CTMCLoader;
import pty.mcmc.ParallelTemperingChain;
import pty.mcmc.PhyloSampler;
import pty.mcmc.ProposalDistribution;
import pty.mcmc.UnrootedTreeState;
import pty.smc.MapLeaves;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;
import pty.smc.models.BrownianModel;
import pty.smc.test.TestBrownianModel;

/**
 * Command-line driver for MCMC sampling of unrooted trees over one or two datasets.
 *
 * <p>For each entry of {@link #dataTypes} an initial tree is obtained (from a Newick file or from a
 * particle-filter MAP estimate), wrapped in a {@link PhyloSampler} and a
 * {@link ParallelTemperingChain}. With one data type the chain is simply sampled. With two data
 * types the chains are sampled alternately, and each chain's prior is augmented with an
 * "agreement" term that penalises disagreement between the clade sets of the two current trees
 * (see {@link JointPrior}).
 */
public class Main implements Runnable {
    /** Datasets to analyse, one tempering chain each; at most two (two entries trigger joint sampling). */
    @Option
    public ArrayList<Dataset.DatasetType> dataTypes = new ArrayList<Dataset.DatasetType>(Arrays.asList(Dataset.DatasetType.WALS));
    /** Particle kernel used when the initial tree is produced by the particle filter. */
    @Option
    public TestBrownianModel.KernelType kernelType = TestBrownianModel.KernelType.PRIOR_PRIOR;
    /** Variance passed to {@link BrownianModel} for HGDP data. */
    @Option
    public double brownianMotionVariance = 0.2;
    /**
     * Newick file with the initial tree for the data type at the same index in {@link #dataTypes};
     * an empty string (or a missing entry) means "initialise with the particle filter".
     */
    @Option
    public ArrayList<String> pathsToInitTree = new ArrayList<String>(Arrays.asList(""));
    /** In joint mode, number of rounds each chain runs before the agreement priors are refreshed. */
    @Option
    public int nStepsPerSubRound = 150;
    /** Random source for agreement-parameter resampling. */
    @Option
    public Random paramRand = new Random(1L);
    /** Metropolis-Hastings steps on the agreement parameter per resampling call. */
    @Option
    public int nParamResamplingPerRound = 10;
    /** File parsed by {@link MapLeaves#parse} to relate the leaves of the two datasets. */
    @Option
    public String mapfile = "data/language-gene-map.txt";
    /** Whether to resample the agreement parameter by MCMC (mutually exclusive with monotonic increase). */
    @Option
    public boolean resampleAgreementParam = true;
    /** Multiplier window {@code a > 1}: proposals rescale the parameter by a factor in {@code [1/a, a]}. */
    @Option
    public double paramResamplingStepSize = 2.0;
    /** Scale of the Gamma hyperprior on the agreement parameter. */
    @Option
    public double hyperParamScale = 3.0;
    /** Shape of the Gamma hyperprior on the agreement parameter. */
    @Option
    public double hyperParamShape = 2.0;
    /** Starting value of the agreement parameter. */
    @Option
    public double initialAgreementParam = 1.0;
    /** If set, the agreement parameter is increased deterministically instead of resampled. */
    @Option
    public boolean monotonicallyIncreaseAgreementParam = false;
    /** Increment used when {@link #monotonicallyIncreaseAgreementParam} is set. */
    @Option
    public double monotonicIncrease = 1.0;
    /** Loss measuring disagreement between the two clade sets. */
    @Option
    public AgreementLoss cladeLoss = AgreementLoss.CLADEOVER;
    /** For WALS data, constrain the sampler with the language-family map. */
    @Option
    public boolean conditionOnFamilies = true;

    private MapLeaves ml = null;
    private static ParticleFilter<PartialCoalescentState> pf = new ParticleFilter<PartialCoalescentState>();
    private static CTMCLoader loader = new CTMCLoader();
    private Dataset data;
    private Set<Taxon> allLangs = new HashSet<Taxon>();
    /** Number of alternating outer iterations in joint mode; -1 until computed by initJoints. */
    private int nOuterRuns = -1;

    public static void main(String[] args) {
        Execution.monitor = true;
        Execution.makeThunk = false;
        Execution.create = true;
        Execution.useStandardExecPoolDirStrategy = true;
        Execution.run(args, new Main(), "prior", PhyloSampler._defaultPriorOptions, "prop", ProposalDistribution.Util._defaultProposalDistributionOptions, "sampler", PhyloSampler._defaultPhyloSamplerOptions, "wals", WalsDataset.class, "hddp", HGDPDataset.class, "filter", pf, "langparam", loader, "partemp", ParallelTemperingChain._defaultTemperingOptions);
    }

    @Override
    public void run() {
        this.run("");
    }

    /**
     * Builds one tempering chain per configured data type and samples them (jointly if two).
     *
     * @param prefixForSampleDirectories prefix for the per-chain output locations
     * @return the output locations reported by each chain, in {@link #dataTypes} order
     * @throws IllegalStateException if no data type, or more than two, is configured
     * @throws IllegalArgumentException if a data type is repeated or unsupported
     */
    public List<File> run(String prefixForSampleDirectories) {
        // Fail fast, before any dataset is loaded or particle filter is run.
        if (this.dataTypes.isEmpty()) {
            throw new IllegalStateException("No data type specified");
        }
        if (this.dataTypes.size() > 2) {
            throw new IllegalStateException("At most two data types are supported, got " + this.dataTypes);
        }
        List<File> result = new ArrayList<File>();
        // Insertion-ordered so that chain order (and hence the argument order of the asymmetric
        // agreement losses) is reproducible and follows dataTypes.
        Map<Dataset.DatasetType, ParallelTemperingChain> chains = new LinkedHashMap<Dataset.DatasetType, ParallelTemperingChain>();
        for (int i = 0; i < this.dataTypes.size(); ++i) {
            Dataset.DatasetType dataType = this.dataTypes.get(i);
            if (chains.containsKey(dataType)) {
                throw new IllegalArgumentException("Data type listed more than once: " + dataType);
            }
            String pathToInitTree = i < this.pathsToInitTree.size() ? this.pathsToInitTree.get(i) : "";
            this.data = dataType.loadDataset();
            loader.setData(this.data);
            this.allLangs.addAll(this.data.observations().keySet());
            UnrootedTreeState nctsInit = this.initialTreeState(dataType, pathToInitTree);
            PhyloSampler sampler = new PhyloSampler();
            sampler.init(nctsInit);
            if (dataType == Dataset.DatasetType.WALS && this.conditionOnFamilies) {
                sampler.setConstraints(WalsDataset.langDB.familyMap());
            }
            ParallelTemperingChain temperingChain = new ParallelTemperingChain();
            result.add(temperingChain.setOutputPrefix(prefixForSampleDirectories + "-" + dataType + "-"));
            temperingChain.init(sampler);
            chains.put(dataType, temperingChain);
        }
        if (chains.size() == 1) {
            LogInfo.track("Markov Chain Monte Carlo phase");
            chains.values().iterator().next().sample();
            LogInfo.end_track();
        } else {
            this.sampleJoint(chains);
        }
        return result;
    }

    /**
     * Produces the initial tree state for one dataset, using {@link #data} and the CTMC loader,
     * which must already be set up for that dataset.
     */
    private UnrootedTreeState initialTreeState(Dataset.DatasetType dataType, String pathToInitTree) {
        if (pathToInitTree.equals("")) {
            LogInfo.track("Particle Filter phase");
            PartialCoalescentState initState;
            if (dataType == Dataset.DatasetType.HGDP) {
                initState = PartialCoalescentState.initState(this.data, new BrownianModel(this.data.nSites(), this.brownianMotionVariance), false);
            } else if (dataType == Dataset.DatasetType.WALS) {
                initState = PartialCoalescentState.initState(this.data, loader.load());
            } else {
                throw new IllegalArgumentException("Unsupported data type: " + dataType);
            }
            ParticleKernel<PartialCoalescentState> pk = this.kernelType.load(initState, null);
            ParticleFilter.MAPDecoder processor = new ParticleFilter.MAPDecoder();
            pf.sample(pk, processor);
            PartialCoalescentState initTree = (PartialCoalescentState) processor.map();
            LogInfo.end_track();
            return UnrootedTreeState.fromPartialCoalescentState(initTree);
        }
        UnrootedTree nct = UnrootedTree.fromNewick(new File(pathToInitTree));
        if (dataType == Dataset.DatasetType.HGDP) {
            return UnrootedTreeState.fromBrownianMotion(nct, this.data, new BrownianModel(this.data.nSites(), this.brownianMotionVariance));
        }
        if (dataType == Dataset.DatasetType.WALS) {
            return UnrootedTreeState.fromCTMC(nct, this.data, loader.load());
        }
        throw new IllegalArgumentException("Unsupported data type: " + dataType);
    }

    /**
     * Alternately samples the two chains in sub-rounds of {@link #nStepsPerSubRound}. After each
     * sub-round the joint priors are refreshed and logged, and the agreement parameters are
     * optionally updated.
     */
    private void sampleJoint(Map<Dataset.DatasetType, ParallelTemperingChain> chains) {
        if (this.monotonicallyIncreaseAgreementParam && this.resampleAgreementParam) {
            throw new IllegalStateException("monotonicallyIncreaseAgreementParam and resampleAgreementParam are mutually exclusive");
        }
        this.ml = MapLeaves.parse(this.mapfile);
        LogInfo.track((Object) "Checking map file", true);
        LogInfo.logs("Data not in map items: " + this.ml.dataNotInMapItems(this.allLangs));
        LogInfo.logs("Map items not in data: " + this.ml.mapItemsNotInData(this.allLangs));
        LogInfo.end_track();
        List<JointPrior<Set<Set<Taxon>>>> jointPriors = this.initJoints(chains);
        this.update(jointPriors, chains);
        for (int iter = 0; iter < this.nOuterRuns; ++iter) {
            for (ParallelTemperingChain chain : chains.values()) {
                chain.sample();
                this.update(jointPriors, chains);
                LogInfo.track("Agreement statistics:");
                LogInfo.logsForce(this.toString(jointPriors));
                LogInfo.end_track();
                if (this.resampleAgreementParam || this.monotonicallyIncreaseAgreementParam) {
                    this.resampleParams(jointPriors);
                }
            }
        }
    }

    /** Formats one line per joint prior: agreement parameter, current loss, mean MH acceptance. */
    public String toString(List<JointPrior<Set<Set<Taxon>>>> jointPriors) {
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < jointPriors.size(); ++i) {
            JointPrior<Set<Set<Taxon>>> jp = jointPriors.get(i);
            result.append("#").append(i).append(": ")
                  .append("agreeParam=").append(EasyFormat.fmt2(jp.agreementParameter)).append(", ")
                  .append("dist=").append(EasyFormat.fmt2(jp.getCurrentDistance())).append(", ")
                  .append("paramAccept=").append(EasyFormat.fmt2(jp.resamplingStats.getMean())).append("\n");
        }
        return result.toString();
    }

    /** Either bumps every agreement parameter by a fixed amount or runs MH steps on it. */
    private void resampleParams(List<JointPrior<Set<Set<Taxon>>>> jointPriors) {
        for (JointPrior<Set<Set<Taxon>>> jp : jointPriors) {
            if (this.monotonicallyIncreaseAgreementParam) {
                jp.increaseAgreementParam(this.monotonicIncrease);
                continue;
            }
            for (int iter = 0; iter < this.nParamResamplingPerRound; ++iter) {
                jp.sampleAgreementParameter(this.paramRand);
            }
        }
    }

    /** Pushes the current tree of tempering level {@code c} of every dataset into joint prior {@code c}. */
    private void update(List<JointPrior<Set<Set<Taxon>>>> jointPriors, Map<Dataset.DatasetType, ParallelTemperingChain> chains) {
        int nChains = chains.values().iterator().next().nChains();
        for (int c = 0; c < nChains; ++c) {
            JointPrior<Set<Set<Taxon>>> jp = jointPriors.get(c);
            for (Map.Entry<Dataset.DatasetType, ParallelTemperingChain> entry : chains.entrySet()) {
                jp.update(entry.getKey(), entry.getValue().getChain(c).getCurrentState().getNonClockTree());
            }
        }
    }

    /**
     * Creates one joint prior per tempering level and installs it in the matching sampler of every
     * dataset. It then switches every chain to sub-rounds of {@link #nStepsPerSubRound}. The
     * originally configured round count is converted into the number of outer iterations, taken
     * from the first chain only.
     */
    private List<JointPrior<Set<Set<Taxon>>>> initJoints(Map<Dataset.DatasetType, ParallelTemperingChain> chains) {
        if (this.nStepsPerSubRound <= 0) {
            throw new IllegalArgumentException("nStepsPerSubRound must be positive, got " + this.nStepsPerSubRound);
        }
        int nChains = chains.values().iterator().next().nChains();
        List<JointPrior<Set<Set<Taxon>>>> jointPriors = new ArrayList<JointPrior<Set<Set<Taxon>>>>();
        for (int i = 0; i < nChains; ++i) {
            JointPrior<Set<Set<Taxon>>> jp = new JointPrior<Set<Set<Taxon>>>(new CladeExtractor(), this.cladeLoss.getLoss(this), this.paramResamplingStepSize, this.hyperParamShape, this.hyperParamScale, this.initialAgreementParam);
            for (Map.Entry<Dataset.DatasetType, ParallelTemperingChain> entry : chains.entrySet()) {
                jp.addToJointModel(entry.getKey(), entry.getValue().getChain(i));
            }
            jointPriors.add(jp);
        }
        for (ParallelTemperingChain chain : chains.values()) {
            // Computed once: the options object may be shared between chains and has already been
            // overwritten by the time the second chain is visited.
            if (this.nOuterRuns == -1) {
                this.nOuterRuns = chain.getOptions().nRounds / this.nStepsPerSubRound;
            }
            chain.getOptions().nRounds = this.nStepsPerSubRound;
        }
        return jointPriors;
    }

    /**
     * Returns the element of {@code all} that is not {@code one}.
     *
     * @throws IllegalArgumentException if {@code all} does not have exactly two elements or has no
     *     element different from {@code one}
     */
    public static Dataset.DatasetType other(Collection<Dataset.DatasetType> all, Dataset.DatasetType one) {
        if (all.size() != 2) {
            throw new IllegalArgumentException("Expected exactly two data types, got " + all);
        }
        for (Dataset.DatasetType candidate : all) {
            if (!candidate.equals(one)) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("No data type other than " + one + " in " + all);
    }

    /** Sufficient statistic: the tree's clades filtered through the leaf map. */
    public class CladeExtractor implements NonClockTreeSuffStatExtractor<Set<Set<Taxon>>> {
        @Override
        public Set<Set<Taxon>> extract(UnrootedTree nct) {
            return Main.this.ml.filterClades(nct.clades());
        }
    }

    /** Maps an unrooted tree to the statistic compared by the agreement loss. */
    public static interface NonClockTreeSuffStatExtractor<SuffStat> {
        public SuffStat extract(UnrootedTree nct);
    }

    /**
     * Couples the samplers of two datasets at the same tempering level.
     *
     * <p>Each sampler's prior is replaced by its base prior plus {@code -agreementParameter * loss}.
     * The loss compares the candidate tree's statistic with that of the other dataset's most
     * recently {@link #update updated} tree. The agreement parameter has a Gamma(hyperShape,
     * hyperScale) hyperprior and can be resampled by Metropolis-Hastings.
     */
    public static class JointPrior<SuffStat> {
        public final NonClockTreeSuffStatExtractor<SuffStat> suffStatExtractor;
        public final BayesRiskMinimizer.LossFct<SuffStat> lossFct;
        /** Multiplier window: proposals rescale the agreement parameter by a factor in [1/a, a]. */
        public final double a;
        public final double hyperShape;
        public final double hyperScale;
        /** Temperature shared by all coupled samplers; NaN until the first one is added. */
        private double temp = Double.NaN;
        private double agreementParameter;
        // Insertion-ordered so statistics are always paired in the order datasets were added.
        private Map<Dataset.DatasetType, SuffStat> statesSS = new LinkedHashMap<Dataset.DatasetType, SuffStat>();
        private Map<Dataset.DatasetType, UnrootedTree> states = new LinkedHashMap<Dataset.DatasetType, UnrootedTree>();
        private Map<Dataset.DatasetType, PhyloSampler.NonClockTreePrior> basePriors = new LinkedHashMap<Dataset.DatasetType, PhyloSampler.NonClockTreePrior>();
        /** MH acceptance probabilities of agreement-parameter proposals. */
        private SummaryStatistics resamplingStats = new SummaryStatistics();

        /**
         * @throws RuntimeException if {@code parameterRescalingProposal <= 1}
         */
        public JointPrior(NonClockTreeSuffStatExtractor<SuffStat> suffStatExtractor, BayesRiskMinimizer.LossFct<SuffStat> lossFct, double parameterRescalingProposal, double hyperShape, double hyperScale, double initialAgreementParameter) {
            if (parameterRescalingProposal <= 1.0) {
                throw new IllegalArgumentException("parameterRescalingProposal must be > 1, got " + parameterRescalingProposal);
            }
            this.agreementParameter = initialAgreementParameter;
            this.suffStatExtractor = suffStatExtractor;
            this.lossFct = lossFct;
            this.a = parameterRescalingProposal;
            this.hyperShape = hyperShape;
            this.hyperScale = hyperScale;
        }

        /** Records the current tree of one dataset and caches its sufficient statistic. */
        public void update(Dataset.DatasetType datasetBeingResampled, UnrootedTree newValue) {
            this.states.put(datasetBeingResampled, newValue);
            this.statesSS.put(datasetBeingResampled, this.suffStatExtractor.extract(newValue));
        }

        /**
         * Wraps {@code ps}'s prior so it also scores agreement with the other dataset's current tree.
         *
         * @throws RuntimeException if {@code ps} runs at a different temperature than samplers
         *     already added
         */
        public void addToJointModel(final Dataset.DatasetType datasetBeingResampled, PhyloSampler ps) {
            if (Double.isNaN(this.temp)) {
                this.temp = ps.getTemperature();
            } else if (this.temp != ps.getTemperature()) {
                throw new IllegalStateException("All coupled samplers must share a temperature: " + this.temp + " vs " + ps.getTemperature());
            }
            PhyloSampler.NonClockTreePrior basePrior = ps.getPrior();
            this.basePriors.put(datasetBeingResampled, basePrior);
            PhyloSampler.NonClockTreePrior conditionalPrior = new PhyloSampler.NonClockTreePrior() {
                /** Lazily resolved: both datasets are only known once update() has run for each. */
                private Dataset.DatasetType other = null;

                @Override
                public double logPriorDensity(UnrootedTree nct) {
                    if (this.other == null) {
                        this.other = Main.other(states.keySet(), datasetBeingResampled);
                    }
                    double base = basePriors.get(datasetBeingResampled).logPriorDensity(nct);
                    // Qualified this: the agreement term lives on the enclosing JointPrior.
                    return base + JointPrior.this.logPriorAgreement(suffStatExtractor.extract(nct), statesSS.get(this.other));
                }
            };
            ps.setPrior(conditionalPrior);
        }

        /**
         * One Metropolis-Hastings step on the agreement parameter, using a multiplicative proposal.
         *
         * <p>The proposal is {@code x' = m * x} with {@code log m ~ Uniform(-log a, log a)}. For this
         * log-uniform multiplier the Hastings ratio q(x | x') / q(x' | x) is exactly {@code m},
         * hence the {@code log m} term. (A multiplier drawn uniformly on [1/a, a] would instead need
         * a 1/m correction.) The target term is divided by the chain temperature.
         */
        public void sampleAgreementParameter(Random rand) {
            double logM = Math.log(this.a) * (2.0 * rand.nextDouble() - 1.0);
            double m = Math.exp(logM);
            double newParam = m * this.agreementParameter;
            double hyperDensityLogRatio = this.hyperLogDensity(newParam) - this.hyperLogDensity(this.agreementParameter);
            double densityLogRatio = this.logJointPriorDensity(newParam) - this.logJointPriorDensity(this.agreementParameter);
            double ratio = Math.min(1.0, Math.exp(logM + (hyperDensityLogRatio + densityLogRatio) / this.temp));
            this.resamplingStats.addValue(ratio);
            if (rand.nextDouble() < ratio) {
                this.agreementParameter = newParam;
            }
        }

        public void increaseAgreementParam(double monotonicIncrease) {
            this.agreementParameter += monotonicIncrease;
        }

        /** Unnormalised Gamma(hyperShape, hyperScale) log-density. */
        private double hyperLogDensity(double x) {
            return (this.hyperShape - 1.0) * Math.log(x) - x / this.hyperScale;
        }

        /** Loss between the cached statistics of the two datasets, in the order they were added. */
        public double getCurrentDistance() {
            List<SuffStat> suffStats = this.currentSuffStats();
            return this.lossFct.loss(suffStats.get(0), suffStats.get(1));
        }

        /** Cached statistics of the current trees, in the order their datasets were added. */
        private List<SuffStat> currentSuffStats() {
            List<SuffStat> result = new ArrayList<SuffStat>();
            for (Dataset.DatasetType dt : this.basePriors.keySet()) {
                result.add(this.statesSS.get(dt));
            }
            return result;
        }

        /**
         * Sum of the base log priors of the current trees plus the agreement term for {@code param}.
         * NOTE(review): the agreement term exp(-param * loss) is not normalised in {@code param};
         * confirm that ignoring its normaliser is intended for parameter resampling.
         */
        private double logJointPriorDensity(double param) {
            double sum = 0.0;
            for (Map.Entry<Dataset.DatasetType, PhyloSampler.NonClockTreePrior> entry : this.basePriors.entrySet()) {
                sum += entry.getValue().logPriorDensity(this.states.get(entry.getKey()));
            }
            List<SuffStat> suffStats = this.currentSuffStats();
            return sum + this.logPriorAgreement(suffStats.get(0), suffStats.get(1), param);
        }

        private double logPriorAgreement(SuffStat ss1, SuffStat ss2) {
            return this.logPriorAgreement(ss1, ss2, this.agreementParameter);
        }

        private double logPriorAgreement(SuffStat ss1, SuffStat ss2, double param) {
            return -param * this.lossFct.loss(ss1, ss2);
        }
    }

    /**
     * One minus the clade overlap of {@code t1} with the mapped {@code t2}, normalised by |t1|^2.
     * An empty {@code t1} yields 0 (nothing to disagree on) instead of NaN.
     */
    class CladeOverlapLoss implements BayesRiskMinimizer.LossFct<Set<Set<Taxon>>> {
        CladeOverlapLoss() {
        }

        @Override
        public double loss(Set<Set<Taxon>> t1, Set<Set<Taxon>> t2) {
            if (t1.isEmpty()) {
                return 0.0;
            }
            double n = t1.size();
            return 1.0 - CladeOverlap.cladeOverlap(t1, Main.this.ml.mapClades(t2)) / n / n;
        }
    }

    /**
     * Half the symmetric difference between {@code t1} and the mapped {@code t2}, divided by |t1|.
     * An empty {@code t1} yields 0 instead of NaN or infinity.
     */
    class SymmDiffLoss implements BayesRiskMinimizer.LossFct<Set<Set<Taxon>>> {
        SymmDiffLoss() {
        }

        @Override
        public double loss(Set<Set<Taxon>> t1, Set<Set<Taxon>> t2) {
            if (t1.isEmpty()) {
                return 0.0;
            }
            return (double) SymmetricDiff.symmetricDifferenceSize(t1, Main.this.ml.mapClades(t2)) / 2.0 / (double) t1.size();
        }
    }

    /** Selectable agreement losses; instances are bound to a {@link Main} for its leaf map. */
    public static enum AgreementLoss {
        SYMMDIFF {
            @Override
            public BayesRiskMinimizer.LossFct<Set<Set<Taxon>>> getLoss(Main main) {
                return main.new SymmDiffLoss();
            }
        },
        CLADEOVER {
            @Override
            public BayesRiskMinimizer.LossFct<Set<Set<Taxon>>> getLoss(Main main) {
                return main.new CladeOverlapLoss();
            }
        };

        abstract BayesRiskMinimizer.LossFct<Set<Set<Taxon>>> getLoss(Main main);
    }
}

