/*
 * Decompiled with CFR 0.152.
 */
package phyloPMCMC;

import ev.poi.processors.TreeDistancesProcessor;
import ev.poi.processors.TreeTopologyProcessor;
import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.exec.Execution;
import gep.util.OutputManager;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.io.IO;
import nuts.math.Sampling;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.io.Dataset;
import pty.mcmc.UnrootedTreeState;
import pty.smc.LazyParticleFilter;
import pty.smc.NCPriorPriorKernel;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;
import pty.smc.PriorPriorKernel;
import pty.smc.models.CTMC;

/**
 * Interacting particle Gibbs sampler over rooted trees and a transition/transversion ratio.
 *
 * <p>Each call to {@link #next(Random)} optionally performs a Metropolis-Hastings update of the
 * ratio ({@code trans2tranv}). Every {@code sampleTreeEveryNIter} iterations it then refreshes the
 * {@code nCSMC} current trees. The refresh runs {@code nUCSMC} unconditional particle filters and,
 * for each current tree, one particle filter conditioned on that tree. It then picks among those
 * candidates in proportion to their estimated marginal likelihoods.
 *
 * <p>The substitution model is {@code CTMC.SimpleCTMC.dnaCTMC(nSites, trans2tranv)}. Judging by the
 * class name this is a Kimura two-parameter (K2P) model — NOTE(review): confirm.
 */
public class InteractingParticleGibbs4K2P {
    private final Dataset dataset;
    LazyParticleFilter.ParticleFilterOptions options = null;
    private final TreeDistancesProcessor tdp;
    private boolean useTopologyProcessor = false;
    private final TreeTopologyProcessor trTopo;
    // Not read by any public path in this class; kept for the private likelihood helpers.
    private double[] previousLogLLEstimate;
    // Log-likelihood of a randomly chosen current tree under the pre-proposal ratio (see MHTrans2tranv).
    private double LogLL;
    private List<RootedTree> currentSample = null;
    private double trans2tranv = 2.0;
    // Tuning parameter of the multiplier proposal; must be > 1 (see Trans2tranvProposal).
    public double a = 1.25;
    private boolean sampleTrans2tranv = true;
    public static OutputManager outMan = new OutputManager();
    private int iter = 0;
    // Not read anywhere in this class.
    private int nCategories = 4;
    private int treeCount = 0;
    File output = new File(Execution.getFile("results"));
    private String nameOfAllTrees = "allTrees.trees";
    private boolean saveTreesFromPMCMC = false;
    // Must be positive: next() computes iter % sampleTreeEveryNIter.
    private int sampleTreeEveryNIter = 100;
    private boolean processTree = false;
    private boolean isGS4Clock = true;
    // Number of conditional SMC runs (= number of trees kept in the current sample).
    private int nCSMC = 4;
    // Number of unconditional SMC runs shared by all conditional selections in one sweep.
    private int nUCSMC = 4;

    public boolean isSampleTrans2tranv() {
        return this.sampleTrans2tranv;
    }

    public void setSampleTrans2tranv(boolean sampleTrans2tranv) {
        this.sampleTrans2tranv = sampleTrans2tranv;
    }

    /**
     * Creates a fully configured sampler whose current sample is {@code nCSMC} copies of {@code initrt}.
     *
     * @param dataset0 data the likelihoods are computed on
     * @param options particle filter options (only {@code nParticles} is read here)
     * @param tdp processor fed every refreshed tree when {@code processTree} is true
     * @param useTopologyProcessor whether refreshed trees are also fed to {@code trTopo}
     * @param trTopo topology processor; stored only when {@code useTopologyProcessor} is true
     * @param initrt initial tree used for every slot of the current sample
     * @param processTree whether refreshed trees are passed to {@code tdp}
     * @param isGS4Clock selects the clock ({@link PriorPriorKernel}) or non-clock
     *     ({@link NCPriorPriorKernel}) proposal kernel and tree-restoration scheme
     * @param sampleTreeEveryNIter trees are refreshed when the iteration counter is a multiple of
     *     this value; must be positive
     * @param nCSMC number of conditional SMC runs / trees kept
     * @param nUCSMC number of unconditional SMC runs per refresh
     */
    public InteractingParticleGibbs4K2P(Dataset dataset0, LazyParticleFilter.ParticleFilterOptions options, TreeDistancesProcessor tdp, boolean useTopologyProcessor, TreeTopologyProcessor trTopo, RootedTree initrt, boolean processTree, boolean isGS4Clock, int sampleTreeEveryNIter, int nCSMC, int nUCSMC) {
        this.nCSMC = nCSMC;
        this.nUCSMC = nUCSMC;
        this.dataset = dataset0;
        this.options = options;
        this.tdp = tdp;
        this.useTopologyProcessor = useTopologyProcessor;
        this.trTopo = useTopologyProcessor ? trTopo : null;
        this.currentSample = InteractingParticleGibbs4K2P.replicate(initrt, this.nCSMC);
        this.processTree = processTree;
        this.isGS4Clock = isGS4Clock;
        this.sampleTreeEveryNIter = sampleTreeEveryNIter;
        this.previousLogLLEstimate = InteractingParticleGibbs4K2P.negativeInfinities(this.nCSMC);
    }

    /**
     * Creates a sampler with default settings (4 conditional / 4 unconditional runs, clock kernel,
     * tree refresh every 100 iterations).
     *
     * <p>{@code trTopo} is stored, but {@code useTopologyProcessor} stays {@code false}, so it is not
     * invoked by {@link #next(Random)}.
     */
    public InteractingParticleGibbs4K2P(Dataset dataset0, LazyParticleFilter.ParticleFilterOptions options, TreeDistancesProcessor tdp, RootedTree initrt, TreeTopologyProcessor trTopo) {
        this.dataset = dataset0;
        this.options = options;
        this.tdp = tdp;
        this.currentSample = InteractingParticleGibbs4K2P.replicate(initrt, this.nCSMC);
        this.trTopo = trTopo;
        // Previously left null, which made the log-likelihood helpers throw NullPointerException.
        this.previousLogLLEstimate = InteractingParticleGibbs4K2P.negativeInfinities(this.nCSMC);
    }

    /** Returns a new mutable list holding {@code n} references to {@code tree}. */
    private static List<RootedTree> replicate(RootedTree tree, int n) {
        List<RootedTree> copies = new ArrayList<RootedTree>(n);
        for (int i = 0; i < n; ++i) {
            copies.add(tree);
        }
        return copies;
    }

    /** Returns an array of length {@code n} filled with {@link Double#NEGATIVE_INFINITY}. */
    private static double[] negativeInfinities(int n) {
        double[] values = new double[n];
        for (int i = 0; i < n; ++i) {
            values[i] = Double.NEGATIVE_INFINITY;
        }
        return values;
    }

    public void setSaveTreesFromPMCMC(boolean saveTreesFromPMCMC) {
        this.saveTreesFromPMCMC = saveTreesFromPMCMC;
    }

    public void setNameOfAllTrees(String nameOfAllTrees) {
        this.nameOfAllTrees = nameOfAllTrees;
    }

    public String getNameOfAllTrees() {
        return this.nameOfAllTrees;
    }

    public boolean getSaveTreesFromPMCMC() {
        return this.saveTreesFromPMCMC;
    }

    public void setProcessTree(boolean processTree) {
        this.processTree = processTree;
    }

    /** Returns the current sample of trees (the internal list itself, not a copy). */
    public List<RootedTree> getRootedTree() {
        return this.currentSample;
    }

    /**
     * Performs one sampler iteration and writes a "PGS" record to {@link #outMan}.
     *
     * <ol>
     *   <li>If enabled, does one Metropolis-Hastings update of {@code trans2tranv}.</li>
     *   <li>If the iteration counter is a multiple of {@code sampleTreeEveryNIter}, refreshes the
     *       current trees. It then feeds them to the configured processors and, when enabled,
     *       appends the first tree to {@code nameOfAllTrees}.</li>
     * </ol>
     */
    public void next(Random rand) {
        ++this.iter;
        List<RootedTree> previousSample = this.currentSample;
        if (this.sampleTrans2tranv) {
            this.MHTrans2tranv(this.trans2tranv, rand);
        }
        if (this.iter % this.sampleTreeEveryNIter == 0) {
            this.currentSample = this.iPmcmcStep(rand, previousSample);
            if (this.processTree) {
                for (int k = 0; k < this.nCSMC; ++k) {
                    this.tdp.process(this.currentSample.get(k));
                }
            }
            if (this.useTopologyProcessor) {
                for (int k = 0; k < this.nCSMC; ++k) {
                    this.trTopo.process(this.currentSample.get(k));
                }
            }
            ++this.treeCount;
            if (this.saveTreesFromPMCMC) {
                this.appendTreeToFile(this.currentSample.get(0));
            }
        }
        int tSize = this.currentSample.get(0).topology().nLeaves();
        outMan.write("PGS", new Object[]{"Iter", this.iter, "treeSize", tSize, "trans2tranv", this.trans2tranv, "LogLikelihood", this.LogLL});
    }

    /**
     * Appends {@code TREE gsTree_<treeCount>=<newick>} to {@code nameOfAllTrees} via bash, run in
     * the {@code output} directory. Internal node names matching {@code internal_[0-9]*_[0-9]*} and
     * the {@code leaf_} prefix are stripped from the Newick string by sed.
     */
    private void appendTreeToFile(RootedTree tree) {
        String stringOfTree = RootedTree.Util.toNewick(tree);
        String cmdStr = "echo -n 'TREE gsTree_" + this.treeCount + "=' >>" + this.nameOfAllTrees;
        IO.call("bash -s", cmdStr, this.output);
        // The tree string is shell-quoted so that single quotes inside it (e.g. quoted Newick labels)
        // cannot terminate the quoted argument; output is unchanged when no quote is present.
        cmdStr = "echo " + InteractingParticleGibbs4K2P.singleQuoteForShell(stringOfTree) + " | sed 's/internal_[0-9]*_[0-9]*//g' | sed 's/leaf_//g' >> " + this.nameOfAllTrees;
        IO.call("bash -s", cmdStr, this.output);
    }

    /** Wraps {@code s} in single quotes for a POSIX shell, escaping embedded single quotes. */
    private static String singleQuoteForShell(String s) {
        return "'" + s.replace("'", "'\\''") + "'";
    }

    /**
     * Runs one particle filter with the current {@code trans2tranv} and samples one final state.
     *
     * @param isconditional whether to condition the filter on the coalescence path of {@code onesample}
     * @param onesample tree to condition on (ignored when {@code isconditional} is false)
     * @return the sampled state paired with the filter's log marginal-likelihood estimate
     */
    private Pair<PartialCoalescentState, Double> oneiPmcmcStep(Random rand, boolean isconditional, RootedTree onesample) {
        ParticleFilter.StoreProcessor pro = new ParticleFilter.StoreProcessor();
        CTMC.SimpleCTMC currentctmc = CTMC.SimpleCTMC.dnaCTMC((int)this.dataset.nSites(), this.trans2tranv);
        PartialCoalescentState init = PartialCoalescentState.initFastState(this.dataset, (CTMC)currentctmc, this.isGS4Clock);
        PriorPriorKernel kernel = this.isGS4Clock ? new PriorPriorKernel(init) : new NCPriorPriorKernel(init);
        ParticleFilter pf = new ParticleFilter();
        pf.N = this.options.nParticles;
        pf.rand = rand;
        pf.resamplingStrategy = ParticleFilter.ResamplingStrategy.ALWAYS;
        if (isconditional) {
            List<Pair<PartialCoalescentState, Double>> restorePCS = InteractingParticleGibbs4K2P.restoreSequence((ParticleKernel<PartialCoalescentState>)kernel, onesample, this.isGS4Clock);
            List<PartialCoalescentState> path = new ArrayList<PartialCoalescentState>(restorePCS.size());
            double[] weights = new double[restorePCS.size()];
            for (int i = 0; i < restorePCS.size(); ++i) {
                Pair<PartialCoalescentState, Double> step = restorePCS.get(i);
                path.add(step.getFirst());
                weights[i] = step.getSecond();
            }
            pf.setConditional((List)path, weights);
        }
        pf.sample((ParticleKernel)kernel, (ParticleFilter.ParticleProcessor)pro);
        double logMarginalLike = pf.estimateNormalizer();
        PartialCoalescentState sampled = (PartialCoalescentState)pro.sample(rand);
        // Typed pair (the decompiled (Object) casts produced Pair<Object,Object>, which does not compile).
        return Pair.makePair(sampled, logMarginalLike);
    }

    /**
     * Refreshes the current sample.
     *
     * <p>Runs {@code nUCSMC} unconditional filters once. Then, for each of the {@code nCSMC}
     * conditioned trees, runs one conditional filter and picks one of the {@code nUCSMC + 1}
     * candidates with probability proportional to exp(log marginal likelihood).
     *
     * @return a new list of {@code nCSMC} trees
     */
    private List<RootedTree> iPmcmcStep(Random rand, List<RootedTree> conditionedSamples) {
        int conditionalSlot = this.nUCSMC;
        int nCandidates = this.nUCSMC + 1;
        ArrayList<RootedTree> result = new ArrayList<RootedTree>(this.nCSMC);
        List<PartialCoalescentState> candidates = new ArrayList<PartialCoalescentState>(nCandidates);
        double[] logMargLikeVec = new double[nCandidates];
        for (int i = 0; i < this.nUCSMC; ++i) {
            Pair<PartialCoalescentState, Double> re = this.oneiPmcmcStep(rand, false, null);
            candidates.add(re.getFirst());
            logMargLikeVec[i] = re.getSecond();
        }
        // Reserve the last slot for the conditional candidate. It is overwritten (not inserted)
        // below so the list stays at nUCSMC + 1 entries; the original insert grew it on every pass.
        candidates.add(null);
        for (int i = 0; i < this.nCSMC; ++i) {
            Pair<PartialCoalescentState, Double> re = this.oneiPmcmcStep(rand, true, conditionedSamples.get(i));
            logMargLikeVec[conditionalSlot] = re.getSecond();
            candidates.set(conditionalSlot, re.getFirst());
            double[] normalizedWeights = logMargLikeVec.clone();
            NumUtils.expNormalize(normalizedWeights);
            ArrayList<Double> w = new ArrayList<Double>(nCandidates);
            for (int k = 0; k < nCandidates; ++k) {
                w.add(normalizedWeights[k]);
            }
            int idx = Sampling.sample(rand, w);
            result.add(candidates.get(idx).getFullCoalescentState());
        }
        return result;
    }

    /** Stores, per current tree, its log-likelihood under the current {@code trans2tranv}. */
    private void updateLogLikelihood() {
        CTMC.SimpleCTMC ctmc = CTMC.SimpleCTMC.dnaCTMC((int)this.dataset.nSites(), this.trans2tranv);
        for (int i = 0; i < this.nCSMC; ++i) {
            UnrootedTreeState previousncs = UnrootedTreeState.initFastState((UnrootedTree)this.currentSample.get(i).getUnrooted(), this.dataset, (CTMC)ctmc);
            this.previousLogLLEstimate[i] = previousncs.logLikelihood();
        }
    }

    /** Returns the sum of the current trees' log-likelihoods under {@code ctmc}. */
    private double sumLogLikelihood(CTMC ctmc) {
        double sumLogLikelihood = 0.0;
        for (int i = 0; i < this.nCSMC; ++i) {
            sumLogLikelihood += UnrootedTreeState.initFastState((UnrootedTree)this.currentSample.get(i).getUnrooted(), this.dataset, ctmc).logLikelihood();
        }
        return sumLogLikelihood;
    }

    /** Returns the sum of the stored per-tree log-likelihood estimates. */
    private double sumPreviousLogLLEstimate() {
        double sum = 0.0;
        for (int i = 0; i < this.previousLogLLEstimate.length; ++i) {
            sum += this.previousLogLLEstimate[i];
        }
        return sum;
    }

    /**
     * One Metropolis-Hastings update of the transition/transversion ratio.
     *
     * <p>The likelihood is evaluated on a single tree drawn uniformly from the current sample. The
     * log acceptance ratio is (proposed logLL − current logLL + log m); no prior density term on the
     * ratio appears. Records the current-ratio log-likelihood in {@code LogLL}.
     *
     * @return the acceptance probability
     */
    private double MHTrans2tranv(double currentTrans2tranv, Random rand) {
        Trans2tranvProposal kappaProposal = new Trans2tranvProposal(this.a, rand);
        Pair<Double, Double> proposed = kappaProposal.propose(currentTrans2tranv);
        double proposedTrans2tranv = proposed.getFirst();
        RootedTree sampledTree = this.currentSample.get(rand.nextInt(this.nCSMC));
        CTMC.SimpleCTMC currentctmc = CTMC.SimpleCTMC.dnaCTMC((int)this.dataset.nSites(), currentTrans2tranv);
        double previouslogLL = UnrootedTreeState.initFastState((UnrootedTree)sampledTree.getUnrooted(), this.dataset, (CTMC)currentctmc).logLikelihood();
        this.LogLL = previouslogLL;
        CTMC.SimpleCTMC ctmc = CTMC.SimpleCTMC.dnaCTMC((int)this.dataset.nSites(), proposedTrans2tranv);
        double Loglikelihood = UnrootedTreeState.initFastState((UnrootedTree)sampledTree.getUnrooted(), this.dataset, (CTMC)ctmc).logLikelihood();
        double logratio = Loglikelihood - previouslogLL + proposed.getSecond();
        double acceptPr = Math.min(1.0, Math.exp(logratio));
        boolean accept = Sampling.sampleBern(acceptPr, rand);
        if (accept) {
            this.trans2tranv = proposedTrans2tranv;
        }
        return acceptPr;
    }

    /**
     * Sums branch lengths along the path that repeatedly follows the first child down to a leaf.
     * For a clock (ultrametric) tree this is the node's height.
     *
     * @return the summed length, or 0.0 when {@code arbre} is itself a leaf (previously this threw)
     */
    public static double height(Map<Taxon, Double> branchLengths, Arbre<Taxon> arbre) {
        double sum = 0.0;
        while (!arbre.isLeaf()) {
            arbre = (Arbre<Taxon>)arbre.getChildren().get(0);
            sum += branchLengths.get(arbre.getContents()).doubleValue();
        }
        return sum;
    }

    /**
     * Returns a new insertion-ordered map with the entries of {@code map} sorted by ascending value.
     * The sort is stable. Values must be mutually {@link Comparable}, otherwise a
     * {@link ClassCastException} is thrown.
     */
    public static <K, V> Map<K, V> sortByValue(Map<K, V> map) {
        List<Map.Entry<K, V>> entries = new ArrayList<Map.Entry<K, V>>(map.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<K, V>>() {

            @Override
            @SuppressWarnings("unchecked") // values are required to be mutually Comparable
            public int compare(Map.Entry<K, V> e1, Map.Entry<K, V> e2) {
                return ((Comparable<Object>)e1.getValue()).compareTo(e2.getValue());
            }
        });
        Map<K, V> result = new LinkedHashMap<K, V>();
        for (Map.Entry<K, V> entry : entries) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }

    /** Returns the taxon stored at the {@code index}-th child of {@code node}. */
    private static Taxon childTaxon(Arbre<Taxon> node, int index) {
        return ((Arbre<Taxon>)node.getChildren().get(index)).getContents();
    }

    /**
     * Replays the coalescence events of a non-clock tree, starting from {@code current}.
     *
     * <p>Internal nodes are coalesced in lexicographic order of their taxon's {@code toString()},
     * with zero time increment and the children's branch lengths. NOTE(review): this assumes the
     * naming scheme of internal nodes orders children before their parents — confirm.
     *
     * @return one (state, logLikelihoodRatio) pair per internal node
     */
    public static List<Pair<PartialCoalescentState, Double>> restoreSequence4NonClockTree(PartialCoalescentState current, RootedTree rt) {
        List<Pair<PartialCoalescentState, Double>> result = new ArrayList<Pair<PartialCoalescentState, Double>>();
        // Sort a copy so that the list handed back by nodes() is never reordered.
        List<Arbre<Taxon>> nodes = new ArrayList<Arbre<Taxon>>(rt.topology().nodes());
        Collections.sort(nodes, new Comparator<Arbre<Taxon>>() {

            @Override
            public int compare(Arbre<Taxon> arbre1, Arbre<Taxon> arbre2) {
                return arbre1.getContents().toString().compareTo(arbre2.getContents().toString());
            }
        });
        Map<Taxon, Double> branchLengths = rt.branchLengths();
        for (Arbre<Taxon> node : nodes) {
            if (node.isLeaf()) {
                continue;
            }
            Taxon first = InteractingParticleGibbs4K2P.childTaxon(node, 0);
            Taxon second = InteractingParticleGibbs4K2P.childTaxon(node, 1);
            PartialCoalescentState coalesceResult = current.coalesce(current.indexOf(first), current.indexOf(second), 0.0, branchLengths.get(first).doubleValue(), branchLengths.get(second).doubleValue(), node.getContents());
            double logWeight = coalesceResult.logLikelihoodRatio();
            current = coalesceResult;
            result.add(Pair.makePair(coalesceResult, logWeight));
        }
        return result;
    }

    /**
     * Replays the coalescence events of {@code rt} from {@code kernel.getInitial()}. The resulting
     * path and weights are used to condition a particle filter.
     *
     * <p>Non-clock trees are delegated to {@link #restoreSequence4NonClockTree}. For clock trees,
     * internal nodes are coalesced in ascending order of {@link #height}. Each event uses the height
     * difference from the previous event as its time increment and is weighted by the resulting
     * log-likelihood increase.
     */
    public static List<Pair<PartialCoalescentState, Double>> restoreSequence(ParticleKernel<PartialCoalescentState> kernel, RootedTree rt, boolean isClock) {
        PartialCoalescentState current = (PartialCoalescentState)kernel.getInitial();
        if (!isClock) {
            return InteractingParticleGibbs4K2P.restoreSequence4NonClockTree(current, rt);
        }
        List<Pair<PartialCoalescentState, Double>> result = new ArrayList<Pair<PartialCoalescentState, Double>>();
        List<Arbre<Taxon>> nodes = rt.topology().nodes();
        Map<Taxon, Double> branchLengths = rt.branchLengths();
        HashMap<Arbre<Taxon>, Double> heightMap = CollUtils.map();
        for (Arbre<Taxon> node : nodes) {
            if (node.isLeaf()) {
                continue;
            }
            heightMap.put(node, InteractingParticleGibbs4K2P.height(branchLengths, node));
        }
        Map<Arbre<Taxon>, Double> sortedHeights = InteractingParticleGibbs4K2P.sortByValue(heightMap);
        double previousHeight = 0.0;
        for (Map.Entry<Arbre<Taxon>, Double> entry : sortedHeights.entrySet()) {
            Arbre<Taxon> node = entry.getKey();
            Taxon first = InteractingParticleGibbs4K2P.childTaxon(node, 0);
            Taxon second = InteractingParticleGibbs4K2P.childTaxon(node, 1);
            double currentHeight = entry.getValue();
            double currentDelta = currentHeight - previousHeight;
            previousHeight = currentHeight;
            PartialCoalescentState coalesceResult = current.coalesce(current.indexOf(first), current.indexOf(second), currentDelta, 0.0, 0.0, node.getContents());
            double logWeight = coalesceResult.logLikelihood() - current.logLikelihood();
            current = coalesceResult;
            result.add(Pair.makePair(coalesceResult, logWeight));
        }
        return result;
    }

    /**
     * Multiplier proposal for the transition/transversion ratio. The new value is {@code m * current},
     * with {@code m = exp(2 ln(a) (u - 1/2))} and {@code u ~ Uniform(0, 1)}, so {@code m} lies in
     * {@code [1/a, a]}.
     */
    public static class Trans2tranvProposal {
        private final double a;
        private final Random rand;

        /**
         * @param a tuning parameter; must be strictly greater than 1
         * @throws IllegalArgumentException if {@code a <= 1}
         */
        public Trans2tranvProposal(double a, Random rand) {
            if (a <= 1.0) {
                throw new IllegalArgumentException("Tuning parameter a must be > 1, got " + a);
            }
            this.a = a;
            this.rand = rand;
        }

        /**
         * Proposes a new ratio.
         *
         * @return (proposed value, log m), where log m is the log Hastings correction that
         *     MHTrans2tranv adds to the acceptance ratio
         */
        public Pair<Double, Double> propose(double currentTrans2tranv) {
            double lambda = 2.0 * Math.log(this.a);
            double rvUnif = Sampling.nextDouble(this.rand, 0.0, 1.0);
            double m = Math.exp(lambda * (rvUnif - 0.5));
            double newTrans2tranv = m * currentTrans2tranv;
            return Pair.makePair(newTrans2tranv, Math.log(m));
        }
    }
}

