/*
 * Decompiled with CFR 0.152.
 */
package phyloPMCMC;

import ev.poi.processors.TreeDistancesProcessor;
import ev.poi.processors.TreeTopologyProcessor;
import fig.basic.Pair;
import fig.exec.Execution;
import gep.util.OutputManager;
import goblin.Taxon;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.io.IO;
import nuts.math.Sampling;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.io.Dataset;
import pty.io.TreeEvaluator;
import pty.mcmc.UnrootedTreeState;
import pty.smc.LazyParticleFilter;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;
import pty.smc.PriorPriorKernel;
import pty.smc.models.CTMC;

/**
 * Particle Gibbs sampler over a rooted tree and the transition/transversion ratio of a DNA
 * substitution model. The model is built by {@code CTMC.SimpleCTMC.dnaCTMC}; the class name
 * suggests Kimura 2-parameter.
 *
 * <p>Each call to {@link #next(Random)} does the following:
 * <ol>
 *   <li>If {@link #isSampleTrans2tranv()} is true, it performs one Metropolis-Hastings update of
 *       the ratio, conditional on the current tree.</li>
 *   <li>Every {@code sampleTreeEveryNIter} iterations, it resamples the tree with a particle
 *       filter conditioned on the coalescent path that rebuilds the current tree (see
 *       {@link #restoreSequence}).</li>
 *   <li>It writes one record (iteration, tree size, ratio, Robinson-Foulds distance to the
 *       previous tree, log-likelihood) to {@link #outMan}.</li>
 * </ol>
 *
 * <p>The sampler starts from a caller-supplied tree, which must be non-null.
 */
public class PGS4K2P {
    private final Dataset dataset;
    LazyParticleFilter.ParticleFilterOptions options = null;
    private final TreeDistancesProcessor tdp;
    private boolean useTopologyProcessor = false;
    private final TreeTopologyProcessor trTopo;
    /** Log-likelihood of the current (tree, ratio) state; negative infinity while still unknown. */
    private double previousLogLLEstimate = Double.NEGATIVE_INFINITY;
    private RootedTree currentSample = null;
    public static OutputManager outMan = new OutputManager();
    private int iter = 0;
    private int treeCount = 0;
    /** Directory against which a relative {@link #nameOfAllTrees} path is resolved. */
    File output = new File(Execution.getFile("results"));
    private String nameOfAllTrees = "allTrees-PGS4K2P.trees";
    private boolean saveTreesFromPMCMC = false;
    private int sampleTreeEveryNIter = 100;
    private boolean processTree = false;
    private boolean isGS4Clock = true;
    private double trans2tranv = 2.0;
    /** Tuning constant of the multiplicative ratio proposal; must be greater than 1. */
    public double a = 1.25;
    private boolean sampleTrans2tranv = true;

    /**
     * Creates a fully configured sampler.
     *
     * @param dataset0 the alignment
     * @param options particle-filter options; {@code nParticles} and {@code nThreads} are read
     * @param tdp tree-distance processor, fed every resampled tree when {@code processTree} is set
     * @param useTopologyProcessor whether {@code trTopo} is fed every resampled tree
     * @param trTopo topology processor; ignored (stored as null) unless {@code useTopologyProcessor}
     * @param initrt initial tree (must be non-null)
     * @param processTree whether {@code tdp} processes every resampled tree
     * @param isGS4Clock whether trees are clock (ultrametric) trees; selects the restore strategy
     * @param sampleTreeEveryNIter the tree is resampled on iterations divisible by this value
     * @throws IllegalArgumentException if {@code sampleTreeEveryNIter} is not positive
     */
    public PGS4K2P(Dataset dataset0, LazyParticleFilter.ParticleFilterOptions options, TreeDistancesProcessor tdp, boolean useTopologyProcessor, TreeTopologyProcessor trTopo, RootedTree initrt, boolean processTree, boolean isGS4Clock, int sampleTreeEveryNIter) {
        if (sampleTreeEveryNIter <= 0) {
            // next() uses it as a modulus; zero would fail there with an ArithmeticException.
            throw new IllegalArgumentException("sampleTreeEveryNIter must be positive, got " + sampleTreeEveryNIter);
        }
        this.dataset = dataset0;
        this.options = options;
        this.tdp = tdp;
        this.useTopologyProcessor = useTopologyProcessor;
        this.trTopo = useTopologyProcessor ? trTopo : null;
        this.currentSample = initrt;
        this.processTree = processTree;
        this.isGS4Clock = isGS4Clock;
        this.sampleTreeEveryNIter = sampleTreeEveryNIter;
    }

    /**
     * Creates a sampler with defaults for the remaining settings: clock trees, resampling every
     * 100 iterations, no tree/topology processing.
     */
    public PGS4K2P(Dataset dataset0, LazyParticleFilter.ParticleFilterOptions options, TreeDistancesProcessor tdp, RootedTree initrt, TreeTopologyProcessor trTopo) {
        this.dataset = dataset0;
        this.options = options;
        this.tdp = tdp;
        this.currentSample = initrt;
        this.trTopo = trTopo;
    }

    public void setSaveTreesFromPMCMC(boolean saveTreesFromPMCMC) {
        this.saveTreesFromPMCMC = saveTreesFromPMCMC;
    }

    public void setNameOfAllTrees(String nameOfAllTrees) {
        this.nameOfAllTrees = nameOfAllTrees;
    }

    public String getNameOfAllTrees() {
        return this.nameOfAllTrees;
    }

    public boolean getSaveTreesFromPMCMC() {
        return this.saveTreesFromPMCMC;
    }

    public void setProcessTree(boolean processTree) {
        this.processTree = processTree;
    }

    /** Returns the current tree sample. */
    public RootedTree getRootedTree() {
        return this.currentSample;
    }

    /**
     * Log-likelihood of the current tree (unrooted) under a DNA CTMC with the given ratio.
     */
    private double treeLogLikelihood(double trans2tranvValue) {
        CTMC.SimpleCTMC ctmc = CTMC.SimpleCTMC.dnaCTMC((int)this.dataset.nSites(), trans2tranvValue);
        UnrootedTreeState state = UnrootedTreeState.initFastState((UnrootedTree)this.currentSample.getUnrooted(), (Dataset)this.dataset, (CTMC)ctmc);
        return state.logLikelihood();
    }

    /**
     * Performs one Metropolis-Hastings update of the transition/transversion ratio, keeping the
     * current tree fixed. No prior term appears in the ratio (an implicit flat prior).
     *
     * @param currentTrans2tranv the ratio to propose from
     * @param rand random source for the proposal and the accept/reject draw
     * @return the acceptance probability of the proposal
     */
    private double MHTrans2tranv(double currentTrans2tranv, Random rand) {
        Trans2tranvProposal kappaProposal = new Trans2tranvProposal(this.a, rand);
        Pair<Double, Double> proposed = kappaProposal.propose(currentTrans2tranv);
        double proposedTrans2tranv = proposed.getFirst();
        double logHastings = proposed.getSecond();
        if (this.previousLogLLEstimate == Double.NEGATIVE_INFINITY) {
            // No likelihood is known for the current state yet (e.g. the chain started from a
            // supplied tree). Compute it so this is a real MH ratio rather than an automatic accept.
            this.previousLogLLEstimate = this.treeLogLikelihood(currentTrans2tranv);
        }
        // NOTE(review): after a tree resample, previousLogLLEstimate comes from
        // PartialCoalescentState.logLikelihood(), while the proposal is scored with
        // UnrootedTreeState. This assumes both give the same tree likelihood — confirm.
        double proposedLogLL = this.treeLogLikelihood(proposedTrans2tranv);
        double logratio = proposedLogLL - this.previousLogLLEstimate + logHastings;
        double acceptPr = Math.min(1.0, Math.exp(logratio));
        if (Sampling.sampleBern(acceptPr, rand)) {
            this.trans2tranv = proposedTrans2tranv;
            // Keep the cached likelihood in sync with the accepted state; otherwise later
            // proposals would be compared against the likelihood under a stale ratio.
            this.previousLogLLEstimate = proposedLogLL;
        }
        return acceptPr;
    }

    /**
     * Advances the sampler by one iteration (see the class documentation) and logs the new state
     * to {@link #outMan}.
     *
     * @param rand random source for all stochastic steps
     */
    public void next(Random rand) {
        ++this.iter;
        RootedTree previousSample = this.currentSample;
        if (this.sampleTrans2tranv) {
            this.MHTrans2tranv(this.trans2tranv, rand);
        }
        if (this.iter % this.sampleTreeEveryNIter == 0) {
            this.resampleTree(rand);
        }
        int tSize = this.currentSample.topology().nLeaves();
        outMan.write("PGS4K2P", new Object[]{"Iter", this.iter, "treeSize", tSize, "trans2tranv", this.trans2tranv, "rfDist", previousSample == null ? 0.0 : new TreeEvaluator.RobinsonFouldsMetric().score(this.currentSample, previousSample), "LogLikelihood", this.previousLogLLEstimate});
    }

    /**
     * Replaces the current tree with one drawn by a conditional particle filter. The filter keeps
     * the coalescent path of the current tree as its conditioning path. Afterwards the new tree is
     * passed to the configured processors and optionally appended to the trees file.
     */
    private void resampleTree(Random rand) {
        CTMC.SimpleCTMC ctmc = CTMC.SimpleCTMC.dnaCTMC((int)this.dataset.nSites(), this.trans2tranv);
        PartialCoalescentState init = PartialCoalescentState.initFastState((Dataset)this.dataset, (CTMC)ctmc, true);
        PriorPriorKernel kernel = new PriorPriorKernel(init);
        ParticleFilter pf = new ParticleFilter();
        pf.N = this.options.nParticles;
        pf.rand = rand;
        pf.nThreads = this.options.nThreads;
        List<Pair<PartialCoalescentState, Double>> restorePCS = PGS4K2P.restoreSequence((ParticleKernel<PartialCoalescentState>)kernel, this.currentSample, this.isGS4Clock);
        List<PartialCoalescentState> path = new ArrayList<PartialCoalescentState>(restorePCS.size());
        double[] weights = new double[restorePCS.size()];
        for (int i = 0; i < restorePCS.size(); ++i) {
            Pair<PartialCoalescentState, Double> step = restorePCS.get(i);
            path.add(step.getFirst());
            weights[i] = step.getSecond();
        }
        pf.setConditional((List)path, weights);
        ParticleFilter.StoreProcessor pro = new ParticleFilter.StoreProcessor();
        pf.sample((ParticleKernel)kernel, (ParticleFilter.ParticleProcessor)pro);
        PartialCoalescentState sampled = (PartialCoalescentState)pro.sample(rand);
        this.currentSample = sampled.getFullCoalescentState();
        this.previousLogLLEstimate = sampled.logLikelihood();
        if (this.processTree) {
            this.tdp.process(this.currentSample);
        }
        if (this.useTopologyProcessor) {
            this.trTopo.process(this.currentSample);
        }
        ++this.treeCount;
        if (this.saveTreesFromPMCMC) {
            this.appendTreeToFile();
        }
    }

    /**
     * Appends {@code TREE gsTree_<n>=<newick>} and a newline to {@link #nameOfAllTrees}. The file
     * name is resolved against {@link #output} when relative.
     *
     * <p>Before writing, {@code internal_<digits>_<digits>} node names and {@code leaf_} prefixes
     * are removed from the Newick string. This is done directly in Java instead of through
     * {@code bash}/{@code sed}, so quote characters in labels or the file name cannot break (or
     * inject into) a shell command.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be written
     */
    private void appendTreeToFile() {
        String stringOfTree = RootedTree.Util.toNewick((RootedTree)this.currentSample);
        String cleaned = stringOfTree.replaceAll("internal_[0-9]*_[0-9]*", "").replace("leaf_", "");
        File target = new File(this.nameOfAllTrees);
        if (!target.isAbsolute()) {
            target = new File(this.output, this.nameOfAllTrees);
        }
        String record = "TREE gsTree_" + this.treeCount + "=" + cleaned + "\n";
        try {
            Files.write(target.toPath(), record.getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE, StandardOpenOption.APPEND);
        } catch (IOException e) {
            throw new RuntimeException("Could not append tree " + this.treeCount + " to " + target, e);
        }
    }

    /**
     * Sums the branch lengths on the path from {@code arbre} down through first children to a
     * leaf. The branch above {@code arbre} itself is excluded. For an ultrametric tree this is the
     * node's height.
     *
     * @param branchLengths branch length of the edge above each node, keyed by node taxon
     * @param arbre an internal node; presumably fails on a leaf, since it reads child 0 first
     * @return the accumulated path length
     */
    public static double height(Map<Taxon, Double> branchLengths, Arbre<Taxon> arbre) {
        double sum = 0.0;
        do {
            arbre = (Arbre)arbre.getChildren().get(0);
            sum += branchLengths.get(arbre.getContents()).doubleValue();
        } while (!arbre.isLeaf());
        return sum;
    }

    /**
     * Returns a copy of {@code map} with its entries ordered by increasing value. The sort is
     * stable, so ties keep the source map's iteration order.
     *
     * @param map a map whose values are mutually {@link Comparable}
     * @return a new {@link LinkedHashMap} in ascending value order
     */
    public static Map sortByValue(Map map) {
        List<Map.Entry> entries = new ArrayList<Map.Entry>(map.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry>() {

            @Override
            public int compare(Map.Entry e1, Map.Entry e2) {
                return ((Comparable)e1.getValue()).compareTo(e2.getValue());
            }
        });
        LinkedHashMap result = new LinkedHashMap();
        for (Map.Entry entry : entries) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }

    /**
     * Rebuilds, starting from {@code current}, the sequence of partial coalescent states that
     * reconstructs a non-clock tree.
     *
     * <p>Internal nodes are visited in the string order of their taxon names. Each one coalesces
     * its two children with a zero time increment and the children's branch lengths. The paired
     * weight is {@code logLikelihoodRatio() - log(nNonTrivialRoots())}.
     *
     * <p>NOTE(review): this needs children to be created before parents. That holds only if
     * internal taxon names sort that way — confirm the naming scheme. Nodes are assumed binary.
     * The list returned by {@code rt.topology().nodes()} is sorted in place.
     */
    public static List<Pair<PartialCoalescentState, Double>> restoreSequence4NonClockTree(PartialCoalescentState current, RootedTree rt) {
        ArrayList result = CollUtils.list();
        List childrenList = rt.topology().nodes();
        class ArbreComparator
        implements Comparator<Arbre<Taxon>> {
            ArbreComparator() {
            }

            @Override
            public int compare(Arbre<Taxon> arbre1, Arbre<Taxon> arbre2) {
                return ((Taxon)arbre1.getContents()).toString().compareTo(((Taxon)arbre2.getContents()).toString());
            }
        }
        Collections.sort(childrenList, new ArbreComparator());
        Map branchLengths = rt.branchLengths();
        for (Object node : childrenList) {
            Arbre currentArbre = (Arbre)node;
            if (currentArbre.isLeaf()) continue;
            Taxon first = (Taxon)((Arbre)currentArbre.getChildren().get(0)).getContents();
            Taxon second = (Taxon)((Arbre)currentArbre.getChildren().get(1)).getContents();
            PartialCoalescentState coalesceResult = current.coalesce(current.indexOf(first), current.indexOf(second), 0.0, ((Double)branchLengths.get(first)).doubleValue(), ((Double)branchLengths.get(second)).doubleValue(), (Taxon)currentArbre.getContents());
            double logWeight = coalesceResult.logLikelihoodRatio() - Math.log(coalesceResult.nNonTrivialRoots());
            current = coalesceResult;
            result.add(Pair.makePair((Object)coalesceResult, (Object)logWeight));
        }
        return result;
    }

    /**
     * Rebuilds the sequence of partial coalescent states, starting from the kernel's initial
     * state, that reconstructs {@code rt}. Each state is paired with its log weight.
     *
     * <p>For clock trees, internal nodes are ordered by increasing {@link #height}. Each node
     * coalesces its two children using the height increment since the previous coalescence as
     * the time delta, zero branch lengths, and {@code logLikelihoodRatio()} as the weight.
     * Non-clock trees are delegated to {@link #restoreSequence4NonClockTree}.
     *
     * <p>NOTE(review): nodes with equal height keep HashMap iteration order. A parent tied with
     * its child (zero-length branch) could be visited first — confirm this cannot occur.
     */
    public static List<Pair<PartialCoalescentState, Double>> restoreSequence(ParticleKernel<PartialCoalescentState> kernel, RootedTree rt, boolean isClock) {
        if (!isClock) {
            return PGS4K2P.restoreSequence4NonClockTree((PartialCoalescentState)kernel.getInitial(), rt);
        }
        ArrayList result = CollUtils.list();
        PartialCoalescentState current = (PartialCoalescentState)kernel.getInitial();
        List childrenList = rt.topology().nodes();
        Map branchLengths = rt.branchLengths();
        HashMap heightMap = CollUtils.map();
        for (Object node : childrenList) {
            Arbre arbre = (Arbre)node;
            if (arbre.isLeaf()) continue;
            heightMap.put(arbre, PGS4K2P.height(branchLengths, (Arbre<Taxon>)arbre));
        }
        Map sortHeightMap = PGS4K2P.sortByValue(heightMap);
        double previousHeight = 0.0;
        for (Object o : sortHeightMap.entrySet()) {
            Map.Entry entry = (Map.Entry)o;
            Arbre currentArbre = (Arbre)entry.getKey();
            Taxon first = (Taxon)((Arbre)currentArbre.getChildren().get(0)).getContents();
            Taxon second = (Taxon)((Arbre)currentArbre.getChildren().get(1)).getContents();
            double currentHeight = (Double)entry.getValue();
            double currentDelta = currentHeight - previousHeight;
            previousHeight = currentHeight;
            PartialCoalescentState coalesceResult = current.coalesce(current.indexOf(first), current.indexOf(second), currentDelta, 0.0, 0.0, (Taxon)currentArbre.getContents());
            double logWeight = coalesceResult.logLikelihoodRatio();
            current = coalesceResult;
            result.add(Pair.makePair((Object)coalesceResult, (Object)logWeight));
        }
        return result;
    }

    public boolean isSampleTrans2tranv() {
        return this.sampleTrans2tranv;
    }

    public void setSampleTrans2tranv(boolean sampleTrans2tranv) {
        this.sampleTrans2tranv = sampleTrans2tranv;
    }

    /**
     * Multiplicative ("scale") proposal for a positive ratio.
     *
     * <p>It draws {@code m = exp(lambda * (u - 0.5))}, where {@code lambda = 2 ln a} and {@code u}
     * comes from {@code Sampling.nextDouble(rand, 0, 1)}, so {@code m} lies between {@code 1/a}
     * and {@code a}. The proposed value is {@code m * current}. {@code log m} is returned as the
     * log Hastings ratio of this move.
     */
    public static class Trans2tranvProposal {
        private final double a;
        private final Random rand;

        /**
         * @param a scale tuning constant
         * @param rand random source
         * @throws IllegalArgumentException if {@code a <= 1}
         */
        public Trans2tranvProposal(double a, Random rand) {
            if (a <= 1.0) {
                throw new IllegalArgumentException("Proposal tuning constant a must be > 1, got " + a);
            }
            this.a = a;
            this.rand = rand;
        }

        /**
         * @param currentTrans2tranv the current (positive) ratio
         * @return pair of (proposed ratio, log Hastings ratio {@code log m})
         */
        public Pair<Double, Double> propose(double currentTrans2tranv) {
            double lambda = 2.0 * Math.log(this.a);
            double rvUnif = Sampling.nextDouble(this.rand, 0.0, 1.0);
            double m = Math.exp(lambda * (rvUnif - 0.5));
            double newTrans2tranv = m * currentTrans2tranv;
            return Pair.makePair(newTrans2tranv, Math.log(m));
        }
    }
}

