/*
 * Decompiled with CFR 0.152.
 */
package ev.poi;

import ev.multi.IncrementalAlignExperiment;
import ev.multi.ManyPairsAligner;
import ev.multi.Segmenter;
import ev.par.ExponentialFamily;
import ev.par.FeatureExtractor;
import ev.poi.LargeStepHomologySampler;
import ev.poi.MSAMarginalLikelihoodCalculator;
import ev.poi.PoissonModel;
import ev.poi.PoissonParameters;
import ev.poi.PoissonSampleProcessor;
import ev.poi.SampleContext;
import ev.poi.processors.CountEdgeProcessor;
import ev.poi.proposals.PoissonProposalOptions;
import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.UnorderedPair;
import goblin.CognateId;
import goblin.Taxon;
import java.io.File;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.BalibaseCorpus;
import ma.GreedyDecoder;
import ma.MSAPoset;
import ma.RateMatrixLoader;
import ma.SequenceType;
import nuts.io.CSV;
import nuts.io.IO;
import nuts.lang.StringUtils;
import nuts.math.Sampling;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Indexer;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pepper.Encodings;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.mcmc.ProposalDistribution;

/**
 * Command-line driver that runs an {@code IncrementalAlignExperiment} over a Balibase corpus, creating
 * one Poisson-indel-model aligner per alignment problem through {@link PoissonAlignerFactory}.
 *
 * <p>The {@code @Option} instance fields, plus the static option objects registered in {@link #main},
 * are handed to {@code IO.run} together with this instance.
 */
public class PoissonAlignerMain implements Runnable {
    /** Deletion rate handed to {@code PoissonParameters} for every new aligner. */
    @Option
    public double delRate = 1.0;
    /** Insertion rate; multiplied by the mean sequence length before being handed to {@code PoissonParameters}. */
    @Option
    public double insertRate = 2.0;
    /** NOTE(review): not read anywhere in this file; the legacy aligner receives its own {@code Random}. */
    @Option
    public Random rand = new Random(1L);
    /** MSA iterations with index below this value are not added to {@code PoissonAlignerOld}'s edge counts. */
    @Option
    public int burnIn = 0;
    /** When true, {@code run()} builds an exponential-family model that {@code getInit} uses for initial MSAs. */
    @Option
    public boolean useInitializer = true;
    /** Segment size passed to {@code Segmenter.chunk} when splitting the initial MSA. */
    @Option
    public int segmentSize = 50;

    /** Pairwise model used to build initial alignments; remains null when {@code useInitializer} is false. */
    private ExponentialFamily initializer;
    /** Substitution rate matrix, chosen in {@code run()} from the corpus sequence type. */
    private double[][] subRates;
    /** Indexer over non-gap characters, chosen in {@code run()} from the corpus sequence type. */
    private Indexer<Character> indexer;

    public static final IncrementalAlignExperiment exp = new IncrementalAlignExperiment();
    public static final BalibaseCorpus.BalibaseCorpusOptions baliopt = new BalibaseCorpus.BalibaseCorpusOptions();
    public static ExponentialFamily.ExponentialFamilyOptions expFamOptions = new ExponentialFamily.ExponentialFamilyOptions();
    public static FeatureExtractor.FeatureOptions featureOptions = new FeatureExtractor.FeatureOptions();
    public static ProposalDistribution.Options proposalOptions = new ProposalDistribution.Options();
    public static PoissonProposalOptions poissonProposalOptions = PoissonProposalOptions.instance;
    public static ManyPairsAligner.ManyPairsAlignerOptions manyPairsAlignerOptions = new ManyPairsAligner.ManyPairsAlignerOptions();

    /**
     * Entry point: registers this instance and the static option groups (under the prefixes
     * "popt", "bali", "exp", "expfam", "feat", "prop", "poiprop") with {@code IO.run}.
     */
    public static void main(String[] args) {
        IO.run(args, new PoissonAlignerMain(), "popt", manyPairsAlignerOptions, "bali", baliopt, "exp", exp, "expfam", expFamOptions, "feat", featureOptions, "prop", proposalOptions, "poiprop", poissonProposalOptions);
    }

    /**
     * Loads the corpus, selects the character indexer and substitution matrix for its sequence type,
     * optionally builds the initializer model, then runs the experiment with a {@link PoissonAlignerFactory}.
     *
     * @throws RuntimeException if the corpus is neither RNA nor PROTEIN
     */
    @Override
    public void run() {
        BalibaseCorpus bc = new BalibaseCorpus(baliopt);
        if (bc.getType() == SequenceType.RNA) {
            this.indexer = Encodings.rnaEncodings().nonGapCharactersIndexer();
            this.subRates = RateMatrixLoader.k2p();
        } else if (bc.getType() == SequenceType.PROTEIN) {
            this.indexer = Encodings.proteinEncodings().nonGapCharactersIndexer();
            this.subRates = RateMatrixLoader.dayhoff();
        } else {
            throw new RuntimeException("Unsupported sequence type: " + bc.getType());
        }
        if (this.useInitializer) {
            this.initializer = ExponentialFamily.createExpfam(null, expFamOptions, featureOptions, bc.getDistances());
        }
        exp.init(bc, new PoissonAlignerFactory());
        exp.run();
    }

    /**
     * Builds an initial MSA for {@code sequences}. Uses the pairwise initializer when available and
     * falls back to a blank {@code MSAPoset} if there is no initializer or it throws.
     */
    private MSAPoset getInit(Map<Taxon, String> sequences) {
        if (this.initializer != null) {
            try {
                return ManyPairsAligner.maxRecallAlignFromMostPairs(sequences, this.initializer, manyPairsAlignerOptions);
            }
            catch (Exception e) {
                // Deliberate best-effort fallback to a blank init; keep the cause so failures are diagnosable.
                LogInfo.error("Something went wrong while trying to form one of the inits... using blank init for these sequences:" + sequences.keySet() + " (cause: " + e + ")");
            }
        }
        return new MSAPoset(sequences);
    }

    /**
     * Returns the annotation suffix of a taxon name, i.e. the text after the last '-' in
     * {@code t.toString()}, as selected by the regex {@code ^.*[-]([^-]+)$}.
     * NOTE(review): the result when the name has no '-' depends on {@code StringUtils.selectFirstRegex}.
     */
    public String retreiveAnnotation(Taxon t) {
        String str = t.toString();
        String ann = StringUtils.selectFirstRegex("^.*[-]([^-]+)$", str);
        return ann;
    }

    /** Maps each taxon in {@code ts} to {@link #retreiveAnnotation(Taxon)} of that taxon. */
    public Map<Taxon, String> retreiveAnnotations(Collection<Taxon> ts) {
        HashMap<Taxon, String> result = CollUtils.map();
        for (Taxon t : ts) {
            result.put(t, this.retreiveAnnotation(t));
        }
        return result;
    }

    /**
     * Mean of all branch lengths of {@code rootedTree}.
     * Returns NaN when the tree has no branch lengths ({@code SummaryStatistics} mean with no values).
     */
    public static double meanBL(RootedTree rootedTree) {
        SummaryStatistics stat = new SummaryStatistics();
        for (Taxon lang : rootedTree.branchLengths().keySet()) {
            stat.addValue(rootedTree.branchLengths().get(lang).doubleValue());
        }
        return stat.getMean();
    }

    /**
     * Legacy MCMC aligner that jointly resamples indel rates, tree topology, rooting and the MSA,
     * writing one CSV line per move to {@code <outputFile>/mcmc-moves-details}.
     * NOTE(review): {@link #out} is opened in the constructor and never closed; the aligner
     * interface visible here has no shutdown hook to close it.
     */
    public class PoissonAlignerOld
    implements IncrementalAlignExperiment.IncrementalAligner {
        private MSAMarginalLikelihoodCalculator calculator;
        private final Random rand;
        private final List<Taxon> observedTaxa;
        private final List<Taxon> allTaxa;
        /** Accumulated MSA edge counts after burn-in, used for the MAX_RECALL_MBR guess. */
        private final Counter<GreedyDecoder.Edge> edgeCounter = new Counter<GreedyDecoder.Edge>();
        private final MSAPoset currentState;
        private final PrintWriter out;
        /** Marginal log-likelihood of the initial MSA under the initial calculator. */
        private final double initLL;
        private final File outputFile;
        private final double treeExpBranchHyperPrior;
        private final double insertHyperPrior;
        private final double deleteHyperPrior;
        /** Current log-likelihood minus {@link #initLL}, updated after every MH step for logging. */
        private double delta = Double.NaN;
        private LinkedList<ProposalDistribution> proposalDistributions = new LinkedList<ProposalDistribution>();

        public PoissonAlignerOld(MSAMarginalLikelihoodCalculator calculator, Random rand, MSAPoset init, File outputFile) {
            this.treeExpBranchHyperPrior = PoissonAlignerMain.meanBL(calculator.rootedTree);
            this.insertHyperPrior = calculator.params.insertRate;
            this.deleteHyperPrior = calculator.params.deleteRate;
            this.outputFile = outputFile;
            this.rand = new Random(rand.nextLong());
            this.currentState = new MSAPoset(init);
            this.out = IOUtils.openOutHard(new File(outputFile, "mcmc-moves-details"));
            this.out.println(CSV.body("#move-description", "accept", "marg-ll-rel-to-init", "notes"));
            this.observedTaxa = CollUtils.list(this.currentState.sequences().keySet());
            // calculator must be assigned before currentRooted() is used below.
            this.calculator = calculator;
            this.allTaxa = this.currentRooted().topology().nodeContents();
            this.initLL = calculator.marginalLogLikelihood(init);
            IO.writeToDisk(new File(outputFile, "initial-tree.newick"), this.currentUnrooted().toNewick());
        }

        @Override
        public MSAPoset currentGuess(IncrementalAlignExperiment.GuessType guessType) {
            if (guessType == IncrementalAlignExperiment.GuessType.SAMPLE) {
                return this.currentState;
            }
            if (guessType == IncrementalAlignExperiment.GuessType.MAX_RECALL_MBR) {
                return MSAPoset.maxRecallMSA(this.currentState.sequences(), this.edgeCounter);
            }
            throw new RuntimeException("Unsupported guess type: " + guessType);
        }

        /** One sweep: indel rates, topology, rooting, tree snapshot to disk, then MSA moves. */
        @Override
        public void iterate(int iter, int nIters) {
            this.resampleInDelParams();
            this.resampleTopology();
            this.resampleRooting();
            IO.writeToDisk(new File(this.outputFile, "tree-" + iter + ".newick"), this.currentUnrooted().toNewick());
            this.resampleMSA(iter);
            this.out.flush();
        }

        /** Interleaves the configured numbers of deletion-rate and insertion-rate moves. */
        private void resampleInDelParams() {
            for (int iter = 0; iter < Math.max(PoissonAlignerMain.poissonProposalOptions.nResampleInsertParams, PoissonAlignerMain.poissonProposalOptions.nResampleDeleteParams); ++iter) {
                if (iter < PoissonAlignerMain.poissonProposalOptions.nResampleDeleteParams) {
                    boolean accept = this.resampleInDelParam(true);
                    this.out.println(CSV.body("del", accept ? 1 : 0, this.delta, "val=" + this.calculator.params.deleteRate));
                }
                if (iter < PoissonAlignerMain.poissonProposalOptions.nResampleInsertParams) {
                    boolean accept = this.resampleInDelParam(false);
                    this.out.println(CSV.body("ins", accept ? 1 : 0, this.delta, "val=" + this.calculator.params.insertRate));
                }
            }
        }

        /**
         * Multiplicative MH (or greedy, when {@code maximize}) move on the deletion or insertion rate.
         *
         * @param isDelete true to move the deletion rate, false for the insertion rate
         * @return whether the proposal was accepted
         * @throws IllegalStateException if {@code indelMultiplicativeProposalScaling <= 1}
         */
        private boolean resampleInDelParam(boolean isDelete) {
            boolean accept;
            double scaling = PoissonAlignerMain.poissonProposalOptions.indelMultiplicativeProposalScaling;
            if (scaling <= 1.0) {
                throw new IllegalStateException("indelMultiplicativeProposalScaling must be > 1, got " + scaling);
            }
            double oldValue = isDelete ? this.calculator.params.deleteRate : this.calculator.params.insertRate;
            double m = Sampling.nextDouble(this.rand, 1.0 / scaling, scaling);
            double newValue = m * oldValue;
            MSAMarginalLikelihoodCalculator newCalc = isDelete
                ? MSAMarginalLikelihoodCalculator.copyWithNewParams(this.calculator, PoissonParameters.copyWithNewDeleteRate(this.calculator.params, newValue))
                : MSAMarginalLikelihoodCalculator.copyWithNewParams(this.calculator, PoissonParameters.copyWithNewInsertRate(this.calculator.params, newValue));
            double oldLL = this.calculator.marginalLogLikelihood(this.currentState);
            double newLL = newCalc.marginalLogLikelihood(this.currentState);
            if (PoissonAlignerMain.poissonProposalOptions.maximize) {
                accept = newLL > oldLL;
            } else {
                // m is drawn uniformly on [1/s, s] (assuming Sampling.nextDouble(rand, lo, hi) is uniform on
                // [lo, hi]). For x' = m * x, q(x'|x) = 1 / ((s - 1/s) * x), so the Hastings ratio
                // q(x|x') / q(x'|x) = x / x' = 1/m, i.e. a log correction of -log(m). (+log(m) is only
                // correct for a log-uniform multiplier such as m = exp(u) with u uniform.)
                double logHastings = -Math.log(m);
                double newLogPrior = isDelete ? this.deleteLogPrior(newValue) : this.insertLogPrior(newValue);
                double oldLogPrior = isDelete ? this.deleteLogPrior(oldValue) : this.insertLogPrior(oldValue);
                double logRatio = logHastings + newLogPrior + newLL - oldLogPrior - oldLL;
                accept = this.rand.nextDouble() < Sampling.min1exp(logRatio);
            }
            if (accept) {
                this.calculator = newCalc;
            }
            double curLL = accept ? newLL : oldLL;
            this.delta = curLL - this.initLL;
            return accept;
        }

        private double insertLogPrior(double value) {
            return Sampling.exponentialLogDensity(this.insertHyperPrior, value);
        }

        private double deleteLogPrior(double value) {
            return Sampling.exponentialLogDensity(this.deleteHyperPrior, value);
        }

        /**
         * Runs the configured number of large-step homology moves, one per taxon (observed taxa in
         * shuffled order, or all tree nodes when {@code useExtendedMSAMove}), counting edges after burn-in.
         */
        private void resampleMSA(int iter) {
            // The tree does not change during MSA moves, so the clade map is computed once.
            // NOTE(review): the declared element types of Arbre.debox(...) are not visible here; values are
            // presumed to be the Set<Taxon> of leaves under each node, as the original code assumed.
            Map<?, ?> clades = Arbre.debox(Arbre.leavesMap(this.currentRooted().topology()));
            for (int i = 0; i < PoissonAlignerMain.poissonProposalOptions.nMSAResampling; ++i) {
                Collections.shuffle(this.observedTaxa, this.rand);
                for (Taxon taxon : PoissonAlignerMain.poissonProposalOptions.useExtendedMSAMove ? this.allTaxa : this.observedTaxa) {
                    @SuppressWarnings("unchecked")
                    Set<Taxon> taxaSet = (Set<Taxon>) clades.get(taxon);
                    boolean accept = LargeStepHomologySampler.largeHomologySamplingStep(taxaSet, this.currentState, this.calculator, this.rand, PoissonAlignerMain.poissonProposalOptions.maximize);
                    this.out.println(CSV.body("align-" + taxaSet, accept ? 1 : 0, this.calculator.marginalLogLikelihood(this.currentState) - this.initLL));
                    if (iter >= PoissonAlignerMain.this.burnIn) {
                        this.edgeCounter.incrementAll(this.currentState.edges(), 1.0);
                    }
                }
            }
        }

        /** Runs {@code nTreeResamplePerIterPerLang * (number of sequences)} topology moves. */
        private void resampleTopology() {
            for (int i = 0; i < PoissonAlignerMain.poissonProposalOptions.nTreeResamplePerIterPerLang * this.currentState.sequences().size(); ++i) {
                this.resampleTopology(this.nextTreeResamplingProposal());
            }
        }

        /**
         * Proposes a new unrooted topology with {@code prop}, re-roots it (keeping the current rooting when
         * possible, otherwise at a random edge) and runs an MH step on it. A null proposal is skipped.
         * NOTE(review): {@code proposed.getSecond()} is not used by the MH step; if it is the proposal's
         * (log) Hastings ratio, it should probably enter the acceptance ratio — confirm.
         */
        private void resampleTopology(ProposalDistribution prop) {
            Pair<UnrootedTree, Double> proposed = prop.propose(this.currentUnrooted(), this.rand);
            if (proposed == null) {
                return;
            }
            RootedTree.RootingInfo currentRooting = RootedTree.Util.getRootingInfo(this.currentRooted());
            RootedTree newRootedTree = proposed.getFirst().reRoot(currentRooting);
            if (newRootedTree == null) {
                newRootedTree = proposed.getFirst().reRoot(this.nextRerooting(proposed.getFirst()));
            }
            boolean accept = this.topologyMHStep(newRootedTree);
            this.out.println(CSV.body("topo-" + prop.description(), accept ? 1 : 0, this.delta));
        }

        /** Sum over branches of the exponential log density (hyper-parameter: initial mean branch length). */
        public double logPrior(RootedTree t) {
            double sum = 0.0;
            for (Taxon lang : t.branchLengths().keySet()) {
                sum += Sampling.exponentialLogDensity(this.treeExpBranchHyperPrior, t.branchLengths().get(lang));
            }
            return sum;
        }

        /** Runs {@code nRerootingPerIterPerLang * (number of sequences)} random re-rooting moves. */
        private void resampleRooting() {
            for (int i = 0; i < PoissonAlignerMain.poissonProposalOptions.nRerootingPerIterPerLang * this.currentState.sequences().size(); ++i) {
                this.resampleRooting(this.nextRerooting(this.currentUnrooted()));
            }
        }

        private void resampleRooting(RootedTree.RootingInfo nextRerooting) {
            RootedTree newRootedTree = this.currentUnrooted().reRoot(nextRerooting);
            boolean accept = this.topologyMHStep(newRootedTree);
            this.out.println(CSV.body("root-" + nextRerooting, accept ? 1 : 0, this.delta));
        }

        /**
         * MH (or greedy, when {@code maximize}) step replacing the current tree with {@code proposed};
         * the ratio uses the marginal likelihood of the current MSA plus {@link #logPrior(RootedTree)}.
         */
        private boolean topologyMHStep(RootedTree proposed) {
            boolean accept;
            MSAMarginalLikelihoodCalculator newCalc = MSAMarginalLikelihoodCalculator.copyWithNewTree(this.calculator, proposed);
            double oldLL = this.calculator.marginalLogLikelihood(this.currentState);
            double newLL = newCalc.marginalLogLikelihood(this.currentState);
            if (PoissonAlignerMain.poissonProposalOptions.maximize) {
                accept = newLL > oldLL;
            } else {
                double logRatio = newLL + this.logPrior(proposed) - oldLL - this.logPrior(this.currentRooted());
                accept = this.rand.nextDouble() < Sampling.min1exp(logRatio);
            }
            if (accept) {
                this.calculator = newCalc;
            }
            double curLL = accept ? newLL : oldLL;
            this.delta = curLL - this.initLL;
            return accept;
        }

        /** Rooting on a uniformly random edge of {@code nct} at a uniform split ratio, keeping the root label. */
        private RootedTree.RootingInfo nextRerooting(UnrootedTree nct) {
            Taxon curRoot = this.currentRooted().topology().getContents();
            UnorderedPair<Taxon, Taxon> randomEdge = nct.randomEdge(this.rand);
            double randomRatio = this.rand.nextDouble();
            return new RootedTree.RootingInfo(randomEdge.getFirst(), randomEdge.getSecond(), curRoot, randomRatio);
        }

        private UnrootedTree currentUnrooted() {
            return UnrootedTree.fromRooted(this.currentRooted());
        }

        private RootedTree currentRooted() {
            return this.calculator.rootedTree;
        }

        /** Picks a proposal uniformly from the list, building the list lazily on first use. */
        private ProposalDistribution nextTreeResamplingProposal() {
            UnrootedTree cur = this.currentUnrooted();
            if (this.proposalDistributions.isEmpty()) {
                this.proposalDistributions.addAll(ProposalDistribution.Util.proposalList(proposalOptions, cur, this.rand));
            }
            return this.proposalDistributions.get(this.rand.nextInt(this.proposalDistributions.size()));
        }
    }

    /** Sample processor that counts unrooted clades across samples to estimate clade posteriors. */
    public static class CladeProcessor
    implements PoissonSampleProcessor {
        private final Counter<Set<Taxon>> _unrootedClades = new Counter<Set<Taxon>>();
        private double nSamples = 0.0;

        @Override
        public void process(PoissonModel sample, SampleContext context) {
            this._unrootedClades.incrementAll(sample.currentUnrooted().clades(), 1.0);
            this.nSamples += 1.0;
        }

        /** Fraction of processed samples containing each clade; empty if nothing was processed. */
        public Counter<Set<Taxon>> getUnrootedCladesPosterior() {
            Counter<Set<Taxon>> result = new Counter<Set<Taxon>>();
            for (Set<Taxon> key : this._unrootedClades.keySet()) {
                result.setCount(key, this._unrootedClades.getCount(key) / this.nSamples);
            }
            return result;
        }
    }

    /**
     * Segment-based aligner wrapping a {@link PoissonModel}; guesses are desegmented back to the full sequences.
     * NOTE(review): {@link #iterate(int, int)} is empty and nothing in this file feeds samples to
     * {@code msaProcessor}/{@code treeProcessor}, so MBR guesses rely on processing done elsewhere (if any).
     */
    public class PoissonAligner
    implements IncrementalAlignExperiment.IncrementalAligner {
        private final PoissonModel model;
        private final CountEdgeProcessor msaProcessor;
        private final CladeProcessor treeProcessor;
        private final Segmenter segmenter;
        private final File outFolder;

        public PoissonAligner(Segmenter segmenter, PoissonModel model, File outFold) {
            this.segmenter = segmenter;
            this.model = model;
            this.msaProcessor = new CountEdgeProcessor();
            this.treeProcessor = new CladeProcessor();
            this.outFolder = outFold;
        }

        /**
         * SAMPLE: the model's current alignment; MAX_RECALL_MBR: max-recall decoding of edge posteriors;
         * HALF_MBR: the same, restricted to edges with posterior at least 0.5.
         */
        @Override
        public MSAPoset currentGuess(IncrementalAlignExperiment.GuessType guessType) {
            LogInfo.logsForce("Current accept rate:" + this.msaProcessor.acceptRate.getMean());
            if (guessType == IncrementalAlignExperiment.GuessType.SAMPLE) {
                return this.segmenter.desegment(this.model.alignments);
            }
            if (guessType == IncrementalAlignExperiment.GuessType.MAX_RECALL_MBR) {
                Counter<GreedyDecoder.Edge> desegmented = this.segmenter.desegment(this.msaProcessor.getEdgePosteriors());
                return MSAPoset.maxRecallMSA(this.segmenter.sequences(), desegmented);
            }
            if (guessType == IncrementalAlignExperiment.GuessType.HALF_MBR) {
                Counter<GreedyDecoder.Edge> desegmented = this.segmenter.desegment(this.msaProcessor.getEdgePosteriors());
                Counter<GreedyDecoder.Edge> filtered = new Counter<GreedyDecoder.Edge>();
                for (GreedyDecoder.Edge e : desegmented.keySet()) {
                    double posterior = desegmented.getCount(e);
                    if (posterior >= 0.5) {
                        filtered.setCount(e, posterior);
                    }
                }
                return MSAPoset.maxRecallMSA(this.segmenter.sequences(), filtered);
            }
            throw new RuntimeException("Unsupported guess type: " + guessType);
        }

        @Override
        public void iterate(int iter, int total) {
            // Intentionally empty in this version (see class-level NOTE).
        }
    }

    /** Builds a {@link PoissonAligner} for each alignment problem of the experiment. */
    public class PoissonAlignerFactory
    implements IncrementalAlignExperiment.IncrementalAlignerFactory {
        /**
         * Creates the initial MSA (also saved as {@code init.msf}), a marginal-likelihood calculator on
         * {@code heuristicTree}, and a segmented {@link PoissonModel}, and wraps them in a {@link PoissonAligner}.
         */
        @Override
        public IncrementalAlignExperiment.IncrementalAligner createNewAligner(Map<Taxon, String> sequences, RootedTree heuristicTree, File outputFolder) {
            MSAPoset msa = PoissonAlignerMain.this.getInit(sequences);
            msa.toMultiAlignmentObject().saveToMSF(new File(outputFolder, "init.msf"));
            // The insertion option is per unit of sequence length, so it is scaled by the mean length.
            double scaledInsertRate = PoissonAlignerMain.this.insertRate * MSAMarginalLikelihoodCalculator.meanSequenceLength(msa.sequences().values());
            PoissonParameters params = new PoissonParameters(PoissonAlignerMain.this.indexer, PoissonAlignerMain.this.subRates, scaledInsertRate, PoissonAlignerMain.this.delRate);
            MSAMarginalLikelihoodCalculator calc = new MSAMarginalLikelihoodCalculator(params, heuristicTree);
            double meanBL = PoissonAlignerMain.meanBL(calc.rootedTree);
            List<Segmenter.SegmentBoundary> bounds = Segmenter.chunk(msa, PoissonAlignerMain.this.segmentSize);
            Segmenter segmenter = new Segmenter(new CognateId("singleton"), sequences, bounds);
            return new PoissonAligner(segmenter, new PoissonModel(segmenter.segmentMSA(msa), calc, meanBL, calc.params.insertRate, calc.params.deleteRate), outputFolder);
        }

        @Override
        public Set<IncrementalAlignExperiment.GuessType> supportedGuessTypes() {
            return CollUtils.set(Arrays.asList(IncrementalAlignExperiment.GuessType.MAX_RECALL_MBR, IncrementalAlignExperiment.GuessType.SAMPLE, IncrementalAlignExperiment.GuessType.HALF_MBR));
        }
    }
}

