/*
 * Decompiled with CFR 0.152.
 */
package pty.smc;

import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.basic.UnorderedPair;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Tree;
import pty.RootedTree;
import pty.UnrootedTree;
import pty.smc.NCPriorPriorKernel;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;
import pty.smc.models.LikelihoodModelCalculator;

/**
 * Dissects an unrooted tree along the path ("backbone") joining two leaves.
 *
 * <p>After construction, {@link #sampleAnchor(Random)} picks an interior backbone
 * node (the "anchor") and splits the taxa into two sets:
 * <ul>
 *   <li>{@code sampled}: the subtree hanging off the backbone at the anchor, plus
 *       the anchor itself and {@code Taxon.dummy};</li>
 *   <li>{@code conditioned}: every other node of the tree (the anchor is in both sets).</li>
 * </ul>
 * {@link #currentSampledPath(PartialCoalescentState, Random)} then expresses the
 * current shape of the sampled part as a sequence of coalescent states, and
 * {@link #reconstitute(PartialCoalescentState)} splices a (re-)sampled subtree back
 * into the untouched conditioned part.
 *
 * <p>NOTE(review): the class name and the use in {@link #main(String[])} suggest this
 * supports a conditional SMC move over subtrees; confirm against the callers.
 */
public class CondSMCDissector {
    private final UnrootedTree originalTree;
    private final Taxon leaf1;
    private final Taxon leaf2;
    /** Nodes on the path from {@code leaf2} (index 0) up to {@code leaf1} (last index), in the tree rooted at {@code leaf1}. */
    private List<Arbre<Taxon>> backbone;
    /** Index of the anchor in {@link #backbone}; {@code -1} until an anchor has been sampled. */
    private int anchorIndex = -1;
    private Taxon anchor;
    private Set<Taxon> sampled = null;
    private Set<Taxon> conditioned = null;
    /** The original tree re-rooted at the anchor; set by {@link #sampleAnchor(Random)}. */
    private Arbre<Taxon> rootedAtAnchor;

    /**
     * Creates a dissector and immediately computes the backbone between the two leaves.
     *
     * @param originalTree the tree to dissect
     * @param leaf1 taxon used as the root when computing the backbone
     * @param leaf2 taxon at the other end of the backbone
     * @throws RuntimeException if the two leaves are the same taxon, or {@code leaf2}
     *         cannot be located in the tree
     */
    public CondSMCDissector(UnrootedTree originalTree, Taxon leaf1, Taxon leaf2) {
        this.originalTree = originalTree;
        this.leaf1 = leaf1;
        this.leaf2 = leaf2;
        this.findBackbone();
    }

    /**
     * Roots the tree at {@code leaf1}, locates {@code leaf2}, and records every node
     * met while walking from {@code leaf2} up to the root.
     */
    private void findBackbone() {
        // Compare by value: taxa are matched with equals() everywhere else in this class
        // (set membership, sibling lookup), so two distinct instances may name the same taxon.
        boolean sameLeaf = this.leaf1 == this.leaf2
                || (this.leaf1 != null && this.leaf1.equals(this.leaf2));
        if (sameLeaf) {
            throw new RuntimeException("leaf1 and leaf2 must be distinct taxa, got: " + this.leaf1);
        }
        Tree<Taxon> rootedAtLeaf1 = this.originalTree.toTree(this.leaf1);
        Arbre<Taxon> node = Arbre.findFirstNodeWithContents(Arbre.tree2Arbre(rootedAtLeaf1), this.leaf2);
        if (node == null) {
            // NOTE(review): assumes findFirstNodeWithContents signals "not found" with null — confirm.
            throw new RuntimeException("Taxon not found in tree: " + this.leaf2);
        }
        this.backbone = CollUtils.list();
        this.backbone.add(node);
        while (!node.isRoot()) {
            node = node.getParent();
            this.backbone.add(node);
        }
    }

    /**
     * Returns the first child of backbone node {@code i} that is not the previous
     * backbone node, i.e. a subtree hanging off the backbone at that position.
     *
     * @param i index into the backbone; must be strictly between the two endpoints
     * @throws RuntimeException if {@code i} is an endpoint or no such child exists
     */
    private Arbre<Taxon> anchorNeighbor(int i) {
        int last = this.backbone.size() - 1;
        if (i == 0 || i == last) {
            throw new RuntimeException("Backbone endpoint cannot be an anchor: index " + i);
        }
        Arbre<Taxon> previousOnBackbone = this.backbone.get(i - 1);
        for (Arbre<Taxon> child : this.backbone.get(i).getChildren()) {
            // Identity comparison is intended here: both references are nodes of the same Arbre.
            if (child != previousOnBackbone) {
                return child;
            }
        }
        throw new RuntimeException("Backbone node " + i + " has no off-backbone child");
    }

    /**
     * Samples the anchor among the interior backbone nodes and initialises the
     * {@code sampled}/{@code conditioned} partition.
     *
     * <p>Interior position {@code i} gets weight (number of nodes in its off-backbone
     * subtree) minus one; the two endpoints get weight zero.
     *
     * @param rand source of randomness
     * @return {@code false}, leaving this object untouched, when the weights could not
     *         be normalized; {@code true} once an anchor has been chosen
     * @throws RuntimeException if an anchor was already sampled
     */
    public boolean sampleAnchor(Random rand) {
        if (this.anchorIndex != -1) {
            throw new RuntimeException("An anchor was already sampled (index " + this.anchorIndex + ")");
        }
        int len = this.backbone.size();
        double[] prs = new double[len];
        for (int i = 1; i < len - 1; ++i) {
            prs[i] = this.anchorNeighbor(i).nodes().size() - 1;
        }
        if (!NumUtils.normalize(prs)) {
            return false;
        }
        this.anchorIndex = SampleUtils.sampleMultinomial(rand, prs);
        this.anchor = this.backbone.get(this.anchorIndex).getContents();
        this.sampled = CollUtils.set(this.anchorNeighbor(this.anchorIndex).nodeContents());
        // The last backbone node is the root, so its node contents are the whole tree.
        this.conditioned = CollUtils.set(this.backbone.get(len - 1).nodeContents());
        this.conditioned.removeAll(this.sampled);
        // The anchor was not in the neighbor subtree, so it stays in `conditioned` too.
        this.sampled.add(this.anchor);
        this.sampled.add(Taxon.dummy);
        this.rootedAtAnchor = Arbre.tree2Arbre(this.originalTree.toTree(this.anchor));
        return true;
    }

    /**
     * Picks uniformly at random an edge whose two endpoints are both in {@code sampled}
     * and re-roots the original tree at the middle of that edge, using
     * {@code Taxon.dummy} as the new root.
     *
     * @throws RuntimeException if no edge lies entirely inside the sampled set
     */
    private Arbre<Taxon> sampleRooting(Random rand) {
        List<UnorderedPair<Taxon, Taxon>> candidates = CollUtils.list();
        for (UnorderedPair<Taxon, Taxon> edge : this.originalTree.edges()) {
            if (this.sampled.contains(edge.getFirst()) && this.sampled.contains(edge.getSecond())) {
                candidates.add(edge);
            }
        }
        if (candidates.isEmpty()) {
            // Random.nextInt(0) would throw an uninformative IllegalArgumentException.
            throw new RuntimeException("No edge of the original tree lies inside the sampled set");
        }
        UnorderedPair<Taxon, Taxon> chosen = candidates.get(rand.nextInt(candidates.size()));
        RootedTree.RootingInfo rooting =
                new RootedTree.RootingInfo(chosen.getFirst(), chosen.getSecond(), Taxon.dummy, 0.5);
        RootedTree rt = this.originalTree.reRoot(rooting);
        return rt.topology();
    }

    /**
     * Coalesces the conditioned taxa following the tree rooted at the anchor and returns
     * the likelihood calculator of root 0 of the final state.
     *
     * <p>A fixed seed is used; the guided creator only uses randomness to choose the
     * order in which sibling pairs are merged.
     */
    private LikelihoodModelCalculator fixedModelCalculator(PartialCoalescentState model) {
        GuidedCoalescentCreator gcc = new GuidedCoalescentCreator(
                this.conditioned, this.rootedAtAnchor, model, new Random(1L), this.originalTree, null, null);
        List<PartialCoalescentState> states = gcc.compute();
        return states.get(states.size() - 1).getLikelihoodModelCalculator(0);
    }

    /**
     * Builds the sequence of coalescent states that reproduces the current topology of
     * the sampled part, with the anchor added as an extra leaf that carries the
     * likelihood calculator obtained from the conditioned part.
     *
     * @param model state whose roots are all leaves; provides leaf caches and observations
     * @param rand source of randomness (rooting choice and merge order)
     * @return the states from the initial one (index 0) to the final one
     * @throws IllegalStateException if {@link #sampleAnchor(Random)} has not succeeded yet
     */
    public List<PartialCoalescentState> currentSampledPath(PartialCoalescentState model, Random rand) {
        this.requireAnchorSampled();
        LikelihoodModelCalculator addit = this.fixedModelCalculator(model);
        Arbre<Taxon> rooting = this.sampleRooting(rand);
        return new GuidedCoalescentCreator(
                this.sampled, rooting, model, rand, this.originalTree, this.anchor, addit).compute();
    }

    /**
     * Splices a sampled subtree back into the original tree at the anchor.
     *
     * <p>The sampled state is unrooted and re-rooted at the anchor (which must then have
     * exactly one child). Among the three children of the anchor in the original tree,
     * the one belonging to {@code sampled} is replaced by that child; the two belonging
     * to {@code conditioned} are copied unchanged. Branch lengths come from the original
     * tree for conditioned nodes and from the sampled tree otherwise.
     *
     * @param _sampled a final coalescent state over the sampled taxa
     * @return the recombined unrooted tree
     * @throws IllegalStateException if {@link #sampleAnchor(Random)} has not succeeded yet
     * @throws RuntimeException if a structural expectation listed above is violated
     */
    public UnrootedTree reconstitute(PartialCoalescentState _sampled) {
        this.requireAnchorSampled();
        UnrootedTree sampledTree = UnrootedTree.fromRooted(_sampled.getFullCoalescentState());

        // Vertices with more than one neighbor must not collide with conditioned taxa;
        // vertices with a single neighbor are exempt from this check.
        Set<Taxon> internalSampled = CollUtils.set();
        for (Taxon t : sampledTree.getTopology().vertexSet()) {
            if (sampledTree.getTopology().nbrs(t).size() != 1) {
                internalSampled.add(t);
            }
        }
        if (CollUtils.intersects(internalSampled, this.conditioned)) {
            throw new RuntimeException("Sampled tree has non-leaf vertices that belong to the conditioned set");
        }

        Arbre<Taxon> reRooted = Arbre.tree2Arbre(sampledTree.toTree(this.anchor));
        if (reRooted.getChildren().size() != 1) {
            throw new RuntimeException("Expected the anchor to have exactly one child in the sampled tree, found "
                    + reRooted.getChildren().size());
        }
        Arbre<Taxon> stemmedRerooted = reRooted.getChildren().get(0);

        if (this.rootedAtAnchor.getChildren().size() != 3) {
            throw new RuntimeException("Expected the anchor to have exactly three neighbors in the original tree, found "
                    + this.rootedAtAnchor.getChildren().size());
        }
        List<Arbre<Taxon>> newChildren = CollUtils.list();
        for (Arbre<Taxon> oldChild : this.rootedAtAnchor.getChildren()) {
            Taxon label = oldChild.getContents();
            if (this.sampled.contains(label)) {
                newChildren.add(stemmedRerooted.copy());
            } else if (this.conditioned.contains(label)) {
                newChildren.add(oldChild.copy());
            } else {
                throw new RuntimeException("Neighbor of the anchor is neither sampled nor conditioned: " + label);
            }
        }
        Arbre<Taxon> newTree = Arbre.arbre(this.anchor, newChildren);

        // Branch lengths are keyed by the child end of each edge. The second loop runs last,
        // so on a key collision the sampled tree's value wins, as in the original ordering.
        HashMap<Taxon, Double> bls = CollUtils.map();
        for (Arbre<Taxon> subt : this.rootedAtAnchor.nodes()) {
            if (subt.isRoot() || !this.conditioned.contains(subt.getContents())) {
                continue;
            }
            bls.put(subt.getContents(),
                    this.originalTree.branchLength(subt.getContents(), subt.getParent().getContents()));
        }
        for (Arbre<Taxon> subt : reRooted.nodes()) {
            if (subt.isRoot()) {
                continue;
            }
            bls.put(subt.getContents(),
                    sampledTree.branchLength(subt.getContents(), subt.getParent().getContents()));
        }
        RootedTree resultRT = RootedTree.Util.create(newTree, bls);
        return UnrootedTree.fromRooted(resultRT);
    }

    /** Fails fast when a method that needs the sampled/conditioned partition is called too early. */
    private void requireAnchorSampled() {
        if (this.anchorIndex == -1) {
            throw new IllegalStateException("sampleAnchor(Random) must succeed before calling this method");
        }
    }

    /**
     * Ad-hoc driver: reads a Newick tree, dissects it between {@code leaf_1} and
     * {@code leaf_2}, runs a conditional particle filter on the sampled part and prints
     * the reconstituted tree.
     *
     * @param args optional; {@code args[0]} overrides the default input Newick file
     */
    public static void main(String[] args) {
        NCPriorPriorKernel.deltaProposalRate = 10.0;
        String newickPath = args != null && args.length > 0
                ? args[0]
                : "/Users/bouchard/w/legacy/state/remote/gp1187.seg1197.time1308586425516.exec/output/sim--1731337436.newick";
        UnrootedTree ut = UnrootedTree.fromNewickRemovingBinaryRoot(new File(newickPath));
        System.out.println(ut);
        Taxon l1 = new Taxon("leaf_1");
        Taxon l2 = new Taxon("leaf_2");
        CondSMCDissector cd = new CondSMCDissector(ut, l1, l2);
        Random rand = new Random(1L);
        if (!cd.sampleAnchor(rand)) {
            // Without an anchor there is no sampled/conditioned partition to work with.
            System.err.println("No anchor could be sampled between " + l1 + " and " + l2);
            return;
        }
        PartialCoalescentState pcs = PartialCoalescentState.initState(ut.leaves(), false);
        List<PartialCoalescentState> list = cd.currentSampledPath(pcs, rand);
        ParticleFilter.StoreProcessor pro = new ParticleFilter.StoreProcessor();
        PartialCoalescentState init = list.get(0);
        List<PartialCoalescentState> path = list.subList(1, list.size());
        NCPriorPriorKernel kernel = new NCPriorPriorKernel(init);
        double[] weights = new double[path.size()];
        ParticleFilter<PartialCoalescentState> pf = new ParticleFilter<PartialCoalescentState>();
        pf.setConditional(path, weights);
        pf.sample(kernel, pro);
        PartialCoalescentState sampled = (PartialCoalescentState)pro.sample(rand);
        System.out.println(cd.reconstitute(sampled));
    }

    /**
     * Replays a fixed rooted topology as a sequence of coalescent events restricted to
     * a given set of taxa. At each step one root whose sibling is also currently a root
     * is chosen uniformly at random and merged with that sibling.
     */
    private static class GuidedCoalescentCreator {
        private final List<PartialCoalescentState> list = CollUtils.list();
        private final Set<Taxon> taxaToConsider;
        private final Arbre<Taxon> topology;
        private final PartialCoalescentState model;
        private final Random rand;
        private final UnrootedTree originalTree;
        /** Optional extra leaf (may be {@code null}) appended after the model's leaves. */
        private final Taxon additionalLeaf;
        private final LikelihoodModelCalculator additionalLLC;
        private PartialCoalescentState current = null;
        /** Half the length of the edge under a two-child topology root; {@code -1.0} when not applicable. */
        private double halfTopBranch = -1.0;

        private GuidedCoalescentCreator(Set<Taxon> taxaToConsider, Arbre<Taxon> topology, PartialCoalescentState model, Random rand, UnrootedTree originalTree, Taxon additionalLeaf, LikelihoodModelCalculator additionalLLC) {
            this.taxaToConsider = taxaToConsider;
            this.topology = topology;
            this.model = model;
            this.rand = rand;
            this.originalTree = originalTree;
            this.additionalLeaf = additionalLeaf;
            this.additionalLLC = additionalLLC;
        }

        /**
         * Runs the replay to completion.
         *
         * @return every visited state, starting with the initial one
         */
        private List<PartialCoalescentState> compute() {
            this.init();
            while (!this.current.isFinalState()) {
                this.iterate();
            }
            return this.list;
        }

        /**
         * Builds the initial state from those roots of the model that belong to
         * {@code taxaToConsider}, plus the optional additional leaf.
         *
         * @throws RuntimeException if the model is clock-based or has a non-leaf root
         */
        private void init() {
            if (this.model.isClock()) {
                throw new RuntimeException("Clock models are not supported");
            }
            if (this.topology.getChildren().size() == 2) {
                // The two children of the root are joined by a single edge in the unrooted tree;
                // each side of the root gets half of it.
                Taxon left = this.topology.getChildren().get(0).getContents();
                Taxon right = this.topology.getChildren().get(1).getContents();
                this.halfTopBranch = this.originalTree.branchLength(left, right) / 2.0;
            }
            ArrayList<Taxon> leavesNames = new ArrayList<Taxon>();
            ArrayList<LikelihoodModelCalculator> leaves = new ArrayList<LikelihoodModelCalculator>();
            for (Arbre<PartialCoalescentState.CoalescentNode> root : this.model.roots) {
                if (!root.isLeaf()) {
                    throw new RuntimeException("The model state must contain leaves only");
                }
                Taxon t = root.getContents().nodeIdentifier;
                if (!this.taxaToConsider.contains(t)) {
                    continue;
                }
                leavesNames.add(t);
                leaves.add(root.getContents().likelihoodModelCache);
            }
            if (this.additionalLeaf != null) {
                leavesNames.add(this.additionalLeaf);
                leaves.add(this.additionalLLC);
            }
            this.current = PartialCoalescentState.initialState(leaves, leavesNames, this.model.getObservations(), false);
            this.list.add(this.current);
        }

        /**
         * Performs one coalescent event: samples a root whose sibling is also a root,
         * merges the two under their parent in the guide topology, and records the
         * resulting state.
         *
         * @throws RuntimeException if no mergeable pair exists or the guide topology is
         *         inconsistent with the current roots
         */
        private void iterate() {
            Set<Taxon> fringe = CollUtils.set();
            for (Arbre<PartialCoalescentState.CoalescentNode> node : this.current.roots) {
                fringe.add(node.getContents().nodeIdentifier);
            }
            // A root is eligible only if its sibling is also a current root.
            double[] prs = new double[this.current.nRoots()];
            for (int i = 0; i < prs.length; ++i) {
                prs[i] = fringe.contains(this.findParentAndSibling(i).getSecond()) ? 1.0 : 0.0;
            }
            if (!NumUtils.normalize(prs)) {
                throw new RuntimeException("No pair of sibling roots is available to coalesce");
            }
            int sampledIndex = SampleUtils.sampleMultinomial(this.rand, prs);
            Taxon sampledTaxon = this.current.roots.get(sampledIndex).getContents().nodeIdentifier;
            Pair<Taxon, Taxon> parentAndSibling = this.findParentAndSibling(sampledIndex);
            Taxon parent = parentAndSibling.getFirst();
            Taxon sibling = parentAndSibling.getSecond();
            int otherIndex = -1;
            for (int i = 0; i < prs.length; ++i) {
                if (this.current.roots.get(i).getContents().nodeIdentifier.equals(sibling)) {
                    otherIndex = i;
                    break;
                }
            }
            if (otherIndex == -1) {
                throw new RuntimeException("Sibling " + sibling + " is not among the current roots");
            }
            // The dummy root is not a node of the original tree, so its two incident branches
            // cannot be looked up there. equals() keeps this test consistent with the
            // set-membership checks that already match the dummy by value.
            boolean parentIsDummyRoot = Taxon.dummy.equals(parent);
            if (parentIsDummyRoot && this.halfTopBranch < 0.0) {
                throw new RuntimeException("Top branch length is undefined: the guide topology root does not have two children");
            }
            double bl1 = parentIsDummyRoot ? this.halfTopBranch : this.originalTree.branchLength(sampledTaxon, parent);
            double bl2 = parentIsDummyRoot ? this.halfTopBranch : this.originalTree.branchLength(sibling, parent);
            this.current = this.current.coalesce(sampledIndex, otherIndex, 0.0, bl1, bl2, parent);
            this.list.add(this.current);
        }

        /**
         * Looks up root {@code i} of the current state in the guide topology.
         *
         * @return the pair (parent, sibling), where the sibling is the first other child
         *         of the parent that belongs to {@code taxaToConsider}
         * @throws RuntimeException if the root, its parent, or a sibling cannot be found
         *         within {@code taxaToConsider}
         */
        private Pair<Taxon, Taxon> findParentAndSibling(int i) {
            Taxon pickedRoot = this.current.roots.get(i).getContents().nodeIdentifier;
            if (!this.taxaToConsider.contains(pickedRoot)) {
                throw new RuntimeException("Root is outside the taxa under consideration: " + pickedRoot);
            }
            Arbre<Taxon> pickedNode = Arbre.findFirstNodeWithContents(this.topology, pickedRoot);
            if (pickedNode == null || pickedNode.isRoot()) {
                throw new RuntimeException("Root has no parent in the guide topology: " + pickedRoot);
            }
            Arbre<Taxon> parent = pickedNode.getParent();
            if (!this.taxaToConsider.contains(parent.getContents())) {
                throw new RuntimeException("Parent is outside the taxa under consideration: " + parent.getContents());
            }
            for (Arbre<Taxon> child : parent.getChildren()) {
                Taxon label = child.getContents();
                if (this.taxaToConsider.contains(label) && !label.equals(pickedNode.getContents())) {
                    return Pair.makePair(parent.getContents(), label);
                }
            }
            throw new RuntimeException("No sibling found for " + pickedRoot + " under " + parent.getContents());
        }
    }
}

