/*
 * Decompiled with CFR 0.152.
 */
package phyloPMCMC;

import ev.poi.processors.TreeDistancesProcessor;
import ev.poi.processors.TreeTopologyProcessor;
import fig.basic.ListUtils;
import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.exec.Execution;
import fig.prob.Dirichlet;
import gep.util.OutputManager;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.io.IO;
import nuts.math.Sampling;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.io.Dataset;
import pty.mcmc.UnrootedTreeState;
import pty.smc.LazyParticleFilter;
import pty.smc.NCPriorPriorKernel;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.ParticleKernel;
import pty.smc.PriorPriorKernel;
import pty.smc.models.CTMC;

public class InteractingParticleGibbs4GTRIGamma {
    // Sequence data handed to the CTMC and to the state initialisers.
    private final Dataset dataset;
    // Particle filter options; within this class only nParticles is read.
    LazyParticleFilter.ParticleFilterOptions options = null;
    // Receives every resampled tree when processTree is true.
    private final TreeDistancesProcessor tdp;
    // When true, every resampled tree is also handed to trTopo.
    private boolean useTopologyProcessor = false;
    // Topology processor; the main constructor stores null here when useTopologyProcessor is false.
    private final TreeTopologyProcessor trTopo;
    // Cached log-likelihood of each tree in currentSample; refreshed by updateLogLikelihood().
    private double[] previousLogLLEstimate;
    // Current trees, one per conditional SMC node (nCSMC entries).
    private List<RootedTree> currentSample = null;
    // Substitution rates; next() logs indices 0..5.
    private double[] subsRates;
    // Stationary frequencies; next() logs indices 0..3.
    private double[] statFreqs;
    // Gamma shape parameter passed to GTRIGammaCTMC.
    private double alpha = 0.0;
    // Invariant-site parameter passed to GTRIGammaCTMC.
    private double pInv = 0.0;
    // proposeAlpha draws its multiplier between 1/a_alpha and a_alpha.
    private double a_alpha = 1.3;
    // Half-width of the window proposePInv draws from around the current pInv.
    private double a_pInv = 0.4;
    // Scale passed to logProposal for the stationary frequencies.
    private double a_statFreqs = 300.0;
    // Scale passed to logProposal for the substitution rates.
    private double a_subsRates = 200.0;
    // Shared sink for the per-iteration "PGS" records written by next().
    public static OutputManager outMan = new OutputManager();
    // Number of next() calls so far.
    private int iter = 0;
    // Category count passed to GTRIGammaCTMC.
    private int nCategories = 4;
    // Number of tree-resampling sweeps performed; used to label saved trees.
    private int treeCount = 0;
    // Passed as the File argument of IO.call when saving trees (presumably the working directory - confirm).
    File output = new File(Execution.getFile((String)"results"));
    // File that the shell commands in next() append saved trees to.
    private String nameOfAllTrees = "allTrees.trees";
    // When true, the first tree of each resampling sweep is appended to nameOfAllTrees.
    private boolean saveTreesFromPMCMC = false;
    // Trees are resampled only when iter is a multiple of this value.
    private int sampleTreeEveryNIter = 100;
    // When true, resampled trees are passed to tdp.
    private boolean processTree = false;
    // Selects PriorPriorKernel (true) or NCPriorPriorKernel (false) and the matching path reconstruction.
    private boolean isGS4Clock = true;
    // Number of conditional SMC runs per sweep (also the size of currentSample).
    private int nCSMC = 4;
    // Number of unconditional SMC runs per sweep.
    private int nUCSMC = 4;

    public InteractingParticleGibbs4GTRIGamma(Dataset dataset0, LazyParticleFilter.ParticleFilterOptions options, TreeDistancesProcessor tdp, boolean useTopologyProcessor, TreeTopologyProcessor trTopo, RootedTree initrt, double[] subsRates, double[] statFreqs, double alpha, double pInv, double a_alpha, double a_pInv, double a_statFreqs, double a_subsRates, int nCategories, boolean processTree, boolean isGS4Clock, int sampleTreeEveryNIter, int nCSMC, int nUCSMC) {
        int i;
        this.nCSMC = nCSMC;
        this.nUCSMC = nUCSMC;
        this.dataset = dataset0;
        this.options = options;
        this.tdp = tdp;
        this.useTopologyProcessor = useTopologyProcessor;
        this.trTopo = useTopologyProcessor ? trTopo : null;
        this.subsRates = subsRates;
        this.statFreqs = statFreqs;
        this.alpha = alpha;
        this.pInv = pInv;
        this.currentSample = new ArrayList<RootedTree>(this.nCSMC);
        for (i = 0; i < this.nCSMC; ++i) {
            this.currentSample.add(initrt);
        }
        this.a_alpha = a_alpha;
        this.a_pInv = a_pInv;
        this.a_statFreqs = a_statFreqs;
        this.a_subsRates = a_subsRates;
        this.nCategories = nCategories;
        this.processTree = processTree;
        this.isGS4Clock = isGS4Clock;
        this.sampleTreeEveryNIter = sampleTreeEveryNIter;
        this.previousLogLLEstimate = new double[this.nCSMC];
        for (i = 0; i < this.nCSMC; ++i) {
            this.previousLogLLEstimate[i] = Double.NEGATIVE_INFINITY;
        }
    }

    /**
     * Short constructor that leaves every tuning field (proposal widths,
     * category count, node counts, flags) at its declared default.
     *
     * <p>Every slot of the current sample starts out referring to the same
     * {@code initrt} instance and every cached log-likelihood starts at
     * negative infinity, mirroring the fully parameterised constructor.
     * Without that cache the first call to {@link #next(Random)} would fail
     * with a {@code NullPointerException} when summing the cached values.
     *
     * <p>NOTE(review): {@code useTopologyProcessor} stays {@code false} here, so
     * the supplied {@code trTopo} is stored but never invoked by
     * {@code next}. Confirm whether that is intended.
     *
     * @param dataset0 sequence data
     * @param options particle filter options
     * @param tdp processor for resampled trees (only used when tree processing is enabled)
     * @param initrt initial tree for every conditional node
     * @param subsRates initial substitution rates (stored by reference)
     * @param statFreqs initial stationary frequencies (stored by reference)
     * @param alpha initial gamma shape parameter
     * @param pInv initial invariant-site parameter
     * @param trTopo topology processor
     */
    public InteractingParticleGibbs4GTRIGamma(Dataset dataset0, LazyParticleFilter.ParticleFilterOptions options, TreeDistancesProcessor tdp, RootedTree initrt, double[] subsRates, double[] statFreqs, double alpha, double pInv, TreeTopologyProcessor trTopo) {
        this.dataset = dataset0;
        this.options = options;
        this.tdp = tdp;
        this.subsRates = subsRates;
        this.statFreqs = statFreqs;
        this.alpha = alpha;
        this.pInv = pInv;
        this.currentSample = new ArrayList<RootedTree>(this.nCSMC);
        // The Metropolis-Hastings moves read this cache on the very first iteration.
        this.previousLogLLEstimate = new double[this.nCSMC];
        for (int i = 0; i < this.nCSMC; ++i) {
            this.currentSample.add(initrt);
            this.previousLogLLEstimate[i] = Double.NEGATIVE_INFINITY;
        }
        this.trTopo = trTopo;
    }

    /**
     * Enables or disables appending the first resampled tree of each sweep to
     * the trees file.
     *
     * @param saveTreesFromPMCMC true to save trees
     */
    public void setSaveTreesFromPMCMC(boolean saveTreesFromPMCMC) {
        this.saveTreesFromPMCMC = saveTreesFromPMCMC;
    }

    /**
     * Sets the name of the file that saved trees are appended to.
     *
     * <p>The name is spliced unquoted into a shell command line by
     * {@code next}, so it must not contain shell metacharacters.
     *
     * @param nameOfAllTrees file name
     */
    public void setNameOfAllTrees(String nameOfAllTrees) {
        this.nameOfAllTrees = nameOfAllTrees;
    }

    /**
     * @return the name of the file that saved trees are appended to
     */
    public String getNameOfAllTrees() {
        return this.nameOfAllTrees;
    }

    /**
     * @return whether resampled trees are being appended to the trees file
     */
    public boolean getSaveTreesFromPMCMC() {
        return this.saveTreesFromPMCMC;
    }

    /**
     * Enables or disables passing resampled trees to the tree distances
     * processor.
     *
     * @param processTree true to process trees
     */
    public void setProcessTree(boolean processTree) {
        this.processTree = processTree;
    }

    /**
     * @return the current substitution rates; this is the internal array, not a copy
     */
    public double[] getSubsRates() {
        return this.subsRates;
    }

    /**
     * @return the current stationary frequencies; this is the internal array, not a copy
     */
    public double[] getStateFreqs() {
        return this.statFreqs;
    }

    /**
     * @return the current gamma shape parameter
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * @return the current invariant-site parameter
     */
    public double getpInv() {
        return this.pInv;
    }

    /**
     * @return the current trees, one per conditional SMC node; this is the
     *         internal list, not a copy
     */
    public List<RootedTree> getRootedTree() {
        return this.currentSample;
    }

    /**
     * Proposes a new alpha as {@code multiplier * alpha}, redrawing the
     * multiplier until the product lies in {@code [low, high]}.
     *
     * <p>The multiplier comes from
     * {@code Sampling.nextDouble(rand, 1 / a_alpha, a_alpha)}. The method
     * keeps drawing for as long as no draw lands in the requested range, so
     * callers must make sure the range is reachable from the current alpha.
     *
     * @param rand random source
     * @param low smallest acceptable proposal
     * @param high largest acceptable proposal
     * @return a two-element array: {@code [0]} is the multiplier of the
     *         accepted draw, {@code [1]} the proposed alpha
     */
    public double[] proposeAlpha(Random rand, double low, double high) {
        final double upperMultiplier = this.a_alpha;
        final double lowerMultiplier = 1.0 / upperMultiplier;
        double multiplier = 0.0;
        // Start outside any ordinary range so that at least one draw happens.
        double candidate = Double.MAX_VALUE;
        for (;;) {
            boolean outsideRange = candidate < low || candidate > high;
            if (!outsideRange) {
                break;
            }
            multiplier = Sampling.nextDouble(rand, lowerMultiplier, upperMultiplier);
            candidate = multiplier * this.alpha;
        }
        return new double[]{multiplier, candidate};
    }

    /**
     * Proposes a new pInv from a window of half-width {@code a_pInv} around
     * the current value, clipped to {@code [0, 1]}, redrawing until the draw
     * lies in {@code [low, high]}.
     *
     * @param rand random source
     * @param low smallest acceptable proposal
     * @param high largest acceptable proposal
     * @return the accepted proposal
     */
    public double proposePInv(Random rand, double low, double high) {
        // The window depends only on the current state, so compute it once.
        final double windowStart = Math.max(0.0, this.pInv - this.a_pInv);
        final double windowEnd = Math.min(1.0, this.pInv + this.a_pInv);
        // Start outside any ordinary range so that at least one draw happens.
        double candidate = Double.MAX_VALUE;
        while (candidate < low || candidate > high) {
            candidate = Sampling.nextDouble(rand, windowStart, windowEnd);
        }
        return candidate;
    }

    /**
     * Samples from a Dirichlet whose parameter vector is {@code a * rates}.
     *
     * @param rand random source
     * @param a factor applied to every entry of {@code rates}
     * @param rates vector the Dirichlet parameters are built from
     * @return the vector returned by {@code Dirichlet.sample}
     */
    public double[] proposeFromDirichlet(Random rand, double a, double[] rates) {
        final double[] concentration = new double[rates.length];
        int slot = 0;
        for (double rate : rates) {
            concentration[slot++] = a * rate;
        }
        return Dirichlet.sample(rand, concentration);
    }

    /**
     * Evaluates {@code Dirichlet.logProb} at {@code rates} for the Dirichlet
     * whose parameter vector is {@code scale * rates}.
     *
     * @param scale factor applied to every entry of {@code rates}
     * @param rates vector used both to build the parameters and as the
     *              evaluation point
     * @return the value returned by {@code Dirichlet.logProb}
     */
    public double logProposal(double scale, double[] rates) {
        final double[] concentration = new double[rates.length];
        int slot = 0;
        for (double rate : rates) {
            concentration[slot++] = scale * rate;
        }
        final double concentrationTotal = ListUtils.sum(concentration);
        return Dirichlet.logProb(concentration, concentrationTotal, rates);
    }

    /**
     * Performs one sampler iteration: four Metropolis-Hastings parameter moves
     * (pInv, alpha, substitution rates, stationary frequencies), then, every
     * {@code sampleTreeEveryNIter} iterations, a tree-resampling sweep, and
     * finally one "PGS" record written to {@link #outMan}.
     *
     * <p>The order of the calls that consume {@code rand} determines the random
     * stream, so statements here must not be reordered casually.
     *
     * <p>NOTE(review): the cached per-tree log-likelihoods are not refreshed
     * after the trees are resampled; they change only when one of the MH moves
     * accepts. Confirm this is intended.
     *
     * @param rand random source
     */
    public void next(Random rand) {
        ++this.iter;
        List<RootedTree> previousSample = this.currentSample;
        // pInv is only proposed when the current value is positive; otherwise 0.0 is submitted.
        double proposedPInv = 0.0;
        if (this.pInv > 0.0) {
            proposedPInv = this.proposePInv(rand, 0.0, 1.0);
        }
        double acceptpInv = this.MHpInv(proposedPInv, rand);
        // NOTE(review): only the multiplier (propAlpha[0]) is used below; the proposed
        // value propAlpha[1] is discarded and the current alpha is submitted instead.
        // Confirm whether alpha is meant to stay fixed.
        double[] propAlpha = this.proposeAlpha(rand, 0.05, 50.0);
        double scale = propAlpha[0];
        double proposedAlpha = this.alpha;
        double acceptPralpha = this.MHalpha(proposedAlpha, scale, rand);
        // NOTE(review): the "proposals" for the rates and frequencies are the current
        // arrays themselves; proposeFromDirichlet is never called here.
        double[] proposedsubsRates = this.subsRates;
        double acceptPrsubsRates = this.MHsubsRates(proposedsubsRates, this.a_subsRates, rand);
        double[] proposedstatFreqs = this.statFreqs;
        double acceptPrstatFreqs = this.MHstatFreqs(proposedstatFreqs, this.a_statFreqs, rand);
        // Not referenced again in this method.
        ParticleFilter.StoreProcessor pro = new ParticleFilter.StoreProcessor();
        if (this.iter % this.sampleTreeEveryNIter == 0) {
            int k;
            // Resample all nCSMC trees, conditioning on the previous ones.
            this.currentSample = this.iPmcmcStep(rand, previousSample);
            if (this.processTree) {
                for (k = 0; k < this.nCSMC; ++k) {
                    this.tdp.process(this.currentSample.get(k));
                }
            }
            if (this.useTopologyProcessor) {
                for (k = 0; k < this.nCSMC; ++k) {
                    this.trTopo.process(this.currentSample.get(k));
                }
            }
            ++this.treeCount;
            if (this.saveTreesFromPMCMC) {
                // Only the first tree is saved. Two shell commands append
                // "TREE gsTree_<n>=" and then the Newick string with the
                // internal_*_* and leaf_ label prefixes stripped by sed.
                String stringOfTree = RootedTree.Util.toNewick((RootedTree)this.currentSample.get(0));
                String cmdStr = "echo -n 'TREE gsTree_" + this.treeCount + "=' >>" + this.nameOfAllTrees;
                IO.call((String)"bash -s", (String)cmdStr, (File)this.output);
                cmdStr = "echo '" + stringOfTree + "' | sed 's/internal_[0-9]*_[0-9]*//g' | sed 's/leaf_//g' >> " + this.nameOfAllTrees;
                IO.call((String)"bash -s", (String)cmdStr, (File)this.output);
            }
        }
        int tSize = this.currentSample.get(0).topology().nLeaves();
        // Assumes statFreqs has at least 4 entries and subsRates at least 6.
        outMan.write("PGS", new Object[]{"Iter", this.iter, "treeSize", tSize, "acceptPralpha", acceptPralpha, "acceptpInv", acceptpInv, "acceptPrstatFreqs", acceptPrstatFreqs, "acceptPrsubsRates", acceptPrsubsRates, "statFreqs1", this.statFreqs[0], "statFreqs2", this.statFreqs[1], "statFreqs3", this.statFreqs[2], "statFreqs4", this.statFreqs[3], "subsRates1", this.subsRates[0], "subsRates2", this.subsRates[1], "subsRates3", this.subsRates[2], "subsRates4", this.subsRates[3], "subsRates5", this.subsRates[4], "subsRates6", this.subsRates[5], "alpha", this.alpha, "pInv", this.pInv, "LogLikelihood", this.previousLogLLEstimate});
    }

    /**
     * Runs one particle filter under the current model parameters and returns
     * a state drawn from the stored particles together with the filter's
     * normaliser estimate.
     *
     * <p>When {@code isconditional} is true, {@code onesample} is first
     * converted back into a sequence of partial states and weights, which is
     * installed as the filter's conditional path. The tree and the restored
     * weights are echoed to standard output.
     *
     * @param rand random source
     * @param isconditional whether to condition the filter on {@code onesample}
     * @param onesample tree to condition on; may be null for an unconditional run
     * @return pair of (sampled state, value of {@code estimateNormalizer()})
     */
    private Pair<PartialCoalescentState, Double> oneiPmcmcStep(Random rand, boolean isconditional, RootedTree onesample) {
        ParticleFilter.StoreProcessor store = new ParticleFilter.StoreProcessor();
        CTMC model = new CTMC.GTRIGammaCTMC(this.statFreqs, this.subsRates, 4, this.dataset.nSites(), this.alpha, this.nCategories, this.pInv);
        PartialCoalescentState initial = PartialCoalescentState.initFastState(this.dataset, model, this.isGS4Clock);

        // Clock and non-clock trees use different proposal kernels.
        PriorPriorKernel kernel;
        if (this.isGS4Clock) {
            kernel = new PriorPriorKernel(initial);
        } else {
            kernel = new NCPriorPriorKernel(initial);
        }

        ParticleFilter filter = new ParticleFilter();
        filter.N = this.options.nParticles;
        filter.rand = rand;
        filter.resamplingStrategy = ParticleFilter.ResamplingStrategy.ALWAYS;

        System.out.print(onesample);
        if (isconditional) {
            List<Pair<PartialCoalescentState, Double>> restored = InteractingParticleGibbs4GTRIGamma.restoreSequence((ParticleKernel<PartialCoalescentState>)kernel, onesample, this.isGS4Clock);
            List<PartialCoalescentState> conditionalPath = CollUtils.list();
            double[] conditionalWeights = new double[restored.size()];
            int step = 0;
            for (Pair<PartialCoalescentState, Double> stateAndWeight : restored) {
                conditionalPath.add(stateAndWeight.getFirst());
                conditionalWeights[step] = stateAndWeight.getSecond();
                System.out.print(conditionalWeights[step] + "\t");
                step++;
            }
            System.out.println();
            filter.setConditional((List)conditionalPath, conditionalWeights);
        }

        filter.sample((ParticleKernel)kernel, (ParticleFilter.ParticleProcessor)store);
        double logNormalizer = filter.estimateNormalizer();
        PartialCoalescentState drawn = (PartialCoalescentState)store.sample(rand);
        return Pair.makePair(drawn, logNormalizer);
    }

    /**
     * Resamples all {@code nCSMC} trees.
     *
     * <p>First {@code nUCSMC} unconditional particle filters are run once and
     * shared. Then, for each conditioned tree in turn, one conditional filter
     * is run, its normaliser estimate is placed next to the unconditional
     * ones, the {@code nUCSMC + 1} values are exp-normalised, and one
     * candidate is drawn with those probabilities. Diagnostics are printed to
     * standard output along the way.
     *
     * @param rand random source
     * @param conditionedSamples trees to condition on; must have at least {@code nCSMC} entries
     * @return the {@code nCSMC} newly selected trees
     */
    private List<RootedTree> iPmcmcStep(Random rand, List<RootedTree> conditionedSamples) {
        final int conditionalSlot = this.nUCSMC;
        final int slotCount = conditionalSlot + 1;
        ArrayList<RootedTree> selectedTrees = new ArrayList<RootedTree>();
        List<PartialCoalescentState> candidates = new ArrayList<PartialCoalescentState>(slotCount);
        double[] logNormalizers = new double[slotCount];

        // Unconditional runs fill slots 0 .. nUCSMC-1.
        int u = 0;
        while (u < this.nUCSMC) {
            Pair<PartialCoalescentState, Double> run = this.oneiPmcmcStep(rand, false, null);
            candidates.add(run.getFirst());
            logNormalizers[u] = run.getSecond();
            u++;
        }

        for (int c = 0; c < this.nCSMC; c++) {
            System.out.println("Number " + c + "CSMC");
            Pair<PartialCoalescentState, Double> run = this.oneiPmcmcStep(rand, true, conditionedSamples.get(c));
            logNormalizers[conditionalSlot] = run.getSecond();
            // Inserting at the conditional slot pushes earlier conditional results
            // beyond the range that can be selected below.
            candidates.add(conditionalSlot, run.getFirst());

            double[] probabilities = logNormalizers.clone();
            for (int k = 0; k < slotCount; k++) {
                System.out.print(probabilities[k] + "\t\t");
            }
            System.out.println();
            NumUtils.expNormalize(probabilities);

            ArrayList<Double> boxedProbabilities = new ArrayList<Double>(slotCount);
            for (int k = 0; k < slotCount; k++) {
                boxedProbabilities.add(probabilities[k]);
                System.out.print(boxedProbabilities.get(k) + "\t\t");
            }
            System.out.println();

            int idx = Sampling.sample(rand, boxedProbabilities);
            System.out.println("idx is " + idx);
            selectedTrees.add(candidates.get(idx).getFullCoalescentState());
        }
        return selectedTrees;
    }

    /**
     * Recomputes the cached log-likelihood of each of the {@code nCSMC}
     * current trees under the current model parameters.
     */
    private void updateLogLikelihood() {
        CTMC model = new CTMC.GTRIGammaCTMC(this.statFreqs, this.subsRates, 4, this.dataset.nSites(), this.alpha, this.nCategories, this.pInv);
        int node = 0;
        while (node < this.nCSMC) {
            UnrootedTree unrooted = this.currentSample.get(node).getUnrooted();
            UnrootedTreeState state = UnrootedTreeState.initFastState(unrooted, this.dataset, model);
            this.previousLogLLEstimate[node] = state.logLikelihood();
            node++;
        }
    }

    /**
     * Adds up the log-likelihoods of the {@code nCSMC} current trees under the
     * given model.
     *
     * @param ctmc model to evaluate the trees under
     * @return the summed log-likelihood
     */
    private double sumLogLikelihood(CTMC ctmc) {
        double total = 0.0;
        int node = 0;
        while (node < this.nCSMC) {
            UnrootedTree unrooted = this.currentSample.get(node).getUnrooted();
            UnrootedTreeState state = UnrootedTreeState.initFastState(unrooted, this.dataset, ctmc);
            total += state.logLikelihood();
            node++;
        }
        return total;
    }

    /**
     * @return the sum of all cached per-tree log-likelihoods
     */
    private double sumPreviousLogLLEstimate() {
        double total = 0.0;
        for (double cached : this.previousLogLLEstimate) {
            total += cached;
        }
        return total;
    }

    /**
     * Metropolis-Hastings move for alpha.
     *
     * <p>The log acceptance ratio is the summed log-likelihood under the
     * proposed alpha, minus the cached summed log-likelihood, plus
     * {@code log(scale)}. On acceptance alpha is replaced and the cache is
     * refreshed.
     *
     * @param proposedAlpha candidate alpha
     * @param scale multiplier reported by the proposal
     * @param rand random source
     * @return the acceptance probability that was used
     */
    private double MHalpha(double proposedAlpha, double scale, Random rand) {
        CTMC candidateModel = new CTMC.GTRIGammaCTMC(this.statFreqs, this.subsRates, 4, this.dataset.nSites(), proposedAlpha, this.nCategories, this.pInv);
        double candidateLogLike = this.sumLogLikelihood(candidateModel);
        double cachedLogLike = this.sumPreviousLogLLEstimate();
        double logRatio = candidateLogLike - cachedLogLike + Math.log(scale);
        double acceptPr = Math.min(1.0, Math.exp(logRatio));
        if (Sampling.sampleBern(acceptPr, rand)) {
            this.alpha = proposedAlpha;
            this.updateLogLikelihood();
        }
        return acceptPr;
    }

    /**
     * Metropolis-Hastings move for pInv.
     *
     * <p>The log acceptance ratio is the summed log-likelihood under the
     * proposed pInv minus the cached summed log-likelihood. On acceptance
     * pInv is replaced and the cache is refreshed.
     *
     * @param proposedpInv candidate pInv
     * @param rand random source
     * @return the acceptance probability that was used
     */
    private double MHpInv(double proposedpInv, Random rand) {
        CTMC candidateModel = new CTMC.GTRIGammaCTMC(this.statFreqs, this.subsRates, 4, this.dataset.nSites(), this.alpha, this.nCategories, proposedpInv);
        double candidateLogLike = this.sumLogLikelihood(candidateModel);
        double cachedLogLike = this.sumPreviousLogLLEstimate();
        double logRatio = candidateLogLike - cachedLogLike;
        double acceptPr = Math.min(1.0, Math.exp(logRatio));
        if (Sampling.sampleBern(acceptPr, rand)) {
            this.pInv = proposedpInv;
            this.updateLogLikelihood();
        }
        return acceptPr;
    }

    /**
     * Metropolis-Hastings move for the stationary frequencies.
     *
     * <p>The log acceptance ratio is the summed log-likelihood under the
     * proposed frequencies, minus the cached summed log-likelihood, plus
     * {@code logProposal} of the proposed vector, minus {@code logProposal}
     * of the current vector. On acceptance the frequencies are replaced and
     * the cache is refreshed.
     *
     * @param proposedstatFreqs candidate frequencies
     * @param a_statFreqs scale handed to {@code logProposal}
     * @param rand random source
     * @return the acceptance probability that was used
     */
    private double MHstatFreqs(double[] proposedstatFreqs, double a_statFreqs, Random rand) {
        CTMC candidateModel = new CTMC.GTRIGammaCTMC(proposedstatFreqs, this.subsRates, 4, this.dataset.nSites(), this.alpha, this.nCategories, this.pInv);
        double candidateLogLike = this.sumLogLikelihood(candidateModel);
        double cachedLogLike = this.sumPreviousLogLLEstimate();
        double candidateTerm = this.logProposal(a_statFreqs, proposedstatFreqs);
        double currentTerm = this.logProposal(a_statFreqs, this.statFreqs);
        double logRatio = candidateLogLike - cachedLogLike + candidateTerm - currentTerm;
        double acceptPr = Math.min(1.0, Math.exp(logRatio));
        if (Sampling.sampleBern(acceptPr, rand)) {
            this.statFreqs = proposedstatFreqs;
            this.updateLogLikelihood();
        }
        return acceptPr;
    }

    /**
     * Metropolis-Hastings move for the substitution rates.
     *
     * <p>The log acceptance ratio is the summed log-likelihood under the
     * proposed rates, minus the cached summed log-likelihood, plus
     * {@code logProposal} of the proposed vector, minus {@code logProposal}
     * of the current vector. On acceptance the rates are replaced and the
     * cache is refreshed.
     *
     * @param proposedsubsRates candidate rates
     * @param a_subsRates scale handed to {@code logProposal}
     * @param rand random source
     * @return the acceptance probability that was used
     */
    private double MHsubsRates(double[] proposedsubsRates, double a_subsRates, Random rand) {
        CTMC candidateModel = new CTMC.GTRIGammaCTMC(this.statFreqs, proposedsubsRates, 4, this.dataset.nSites(), this.alpha, this.nCategories, this.pInv);
        double candidateLogLike = this.sumLogLikelihood(candidateModel);
        double cachedLogLike = this.sumPreviousLogLLEstimate();
        double candidateTerm = this.logProposal(a_subsRates, proposedsubsRates);
        double currentTerm = this.logProposal(a_subsRates, this.subsRates);
        double logRatio = candidateLogLike - cachedLogLike + candidateTerm - currentTerm;
        double acceptPr = Math.min(1.0, Math.exp(logRatio));
        if (Sampling.sampleBern(acceptPr, rand)) {
            this.subsRates = proposedsubsRates;
            this.updateLogLikelihood();
        }
        return acceptPr;
    }

    /**
     * Computes the height of a node by following the first child at every
     * level down to a leaf and adding up the branch lengths along the way.
     *
     * <p>A leaf has height {@code 0.0}. The earlier do-while form
     * dereferenced the first child before checking {@code isLeaf()}, so it
     * failed when handed a leaf.
     *
     * @param branchLengths branch length above each node, keyed by the node's taxon;
     *                      must contain an entry for every node on the followed path
     * @param arbre node whose height is wanted
     * @return summed branch lengths from {@code arbre} down to the reached leaf
     */
    public static double height(Map<Taxon, Double> branchLengths, Arbre<Taxon> arbre) {
        double sum = 0.0;
        Arbre<Taxon> node = arbre;
        while (!node.isLeaf()) {
            node = node.getChildren().get(0);
            sum += branchLengths.get(node.getContents()).doubleValue();
        }
        return sum;
    }

    /**
     * Returns a new map holding the entries of {@code map} ordered by
     * ascending value.
     *
     * <p>The result is a {@link LinkedHashMap}, so iteration follows the sorted
     * order. The sort is stable: entries with equal values keep the relative
     * order in which {@code map} iterates them. The input map is not
     * modified.
     *
     * @param map entries to sort; values must be mutually comparable and non-null
     * @param <K> key type
     * @param <V> value type
     * @return a new insertion-ordered map sorted by value
     */
    public static <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
        List<Map.Entry<K, V>> entries = new ArrayList<Map.Entry<K, V>>(map.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<K, V>>() {
            @Override
            public int compare(Map.Entry<K, V> left, Map.Entry<K, V> right) {
                return left.getValue().compareTo(right.getValue());
            }
        });
        Map<K, V> result = new LinkedHashMap<K, V>();
        for (Map.Entry<K, V> entry : entries) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }

    /**
     * Rebuilds, for a non-clock tree, the sequence of partial states that
     * leads from {@code current} to the tree, with one weight per step.
     *
     * <p>The tree's nodes are sorted by the string form of their taxon.
     * Every internal node, in that order, triggers one coalescence of its
     * first two children, using a height increment of {@code 0.0} and the
     * children's own branch lengths. The weight of a step is the
     * {@code logLikelihoodRatio()} of the resulting state.
     *
     * @param current state to start coalescing from
     * @param rt tree to replay
     * @return one (state, log weight) pair per internal node, in replay order
     */
    public static List<Pair<PartialCoalescentState, Double>> restoreSequence4NonClockTree(PartialCoalescentState current, RootedTree rt) {
        List<Pair<PartialCoalescentState, Double>> steps = CollUtils.list();
        List<Arbre<Taxon>> nodes = rt.topology().nodes();
        Collections.sort(nodes, new Comparator<Arbre<Taxon>>() {
            @Override
            public int compare(Arbre<Taxon> left, Arbre<Taxon> right) {
                String leftName = left.getContents().toString();
                String rightName = right.getContents().toString();
                return leftName.compareTo(rightName);
            }
        });
        Map<Taxon, Double> branchLengths = rt.branchLengths();
        PartialCoalescentState state = current;
        for (Arbre<Taxon> node : nodes) {
            if (node.isLeaf()) {
                continue;
            }
            Taxon first = node.getChildren().get(0).getContents();
            Taxon second = node.getChildren().get(1).getContents();
            PartialCoalescentState merged = state.coalesce(state.indexOf(first), state.indexOf(second), 0.0, branchLengths.get(first).doubleValue(), branchLengths.get(second).doubleValue(), node.getContents());
            double logWeight = merged.logLikelihoodRatio();
            state = merged;
            steps.add(Pair.makePair(merged, logWeight));
        }
        return steps;
    }

    /**
     * Rebuilds the sequence of partial states that leads from the kernel's
     * initial state to {@code rt}, with one log weight per step.
     *
     * <p>For non-clock trees this delegates to
     * {@link #restoreSequence4NonClockTree}. For clock trees the internal
     * nodes are replayed in order of increasing height: each step coalesces
     * the node's first two children with a height increment equal to the
     * difference from the previously replayed node, and its weight is the
     * increase in {@code logLikelihood()} over the preceding state.
     *
     * <p>Compared with the earlier form this drops a list of node names that
     * was built but never read, and walks the sorted map's entries directly
     * rather than iterating an untyped key set with a per-key lookup.
     *
     * @param kernel kernel supplying the initial state
     * @param rt tree to replay
     * @param isClock whether {@code rt} is a clock tree
     * @return one (state, log weight) pair per internal node, in replay order
     */
    public static List<Pair<PartialCoalescentState, Double>> restoreSequence(ParticleKernel<PartialCoalescentState> kernel, RootedTree rt, boolean isClock) {
        if (!isClock) {
            return InteractingParticleGibbs4GTRIGamma.restoreSequence4NonClockTree(kernel.getInitial(), rt);
        }
        List<Pair<PartialCoalescentState, Double>> result = CollUtils.list();
        PartialCoalescentState current = kernel.getInitial();
        Map<Taxon, Double> branchLengths = rt.branchLengths();

        // Height of every internal node, then ordered from lowest to highest.
        Map<Arbre<Taxon>, Double> heightMap = CollUtils.map();
        for (Arbre<Taxon> node : rt.topology().nodes()) {
            if (node.isLeaf()) {
                continue;
            }
            heightMap.put(node, InteractingParticleGibbs4GTRIGamma.height(branchLengths, node));
        }
        Map<Arbre<Taxon>, Double> sortedHeights = InteractingParticleGibbs4GTRIGamma.sortByValue(heightMap);

        double previousHeight = 0.0;
        for (Map.Entry<Arbre<Taxon>, Double> entry : sortedHeights.entrySet()) {
            Arbre<Taxon> node = entry.getKey();
            Taxon first = node.getChildren().get(0).getContents();
            Taxon second = node.getChildren().get(1).getContents();
            double currentHeight = entry.getValue();
            double currentDelta = currentHeight - previousHeight;
            previousHeight = currentHeight;
            PartialCoalescentState coalesceResult = current.coalesce(current.indexOf(first), current.indexOf(second), currentDelta, 0.0, 0.0, node.getContents());
            double logWeight = coalesceResult.logLikelihood() - current.logLikelihood();
            current = coalesceResult;
            result.add(Pair.makePair(coalesceResult, logWeight));
        }
        return result;
    }
}

