/*
 * Decompiled with CFR 0.152.
 */
package pty.eval;

import fig.basic.IOUtils;
import goblin.BayesRiskMinimizer;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import ma.newick.NewickParser;
import nuts.math.Fct;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Tree;
import pty.smc.MapLeaves;
import pty.smc.PartialCoalescentState;
import pty.smc.ParticleFilter;

/**
 * Clade-based symmetric-difference utilities for comparing trees, plus a loss
 * function ({@link #loss}) that counts the elements two sets do not share.
 *
 * <p>A "clade" here is the set of leaf contents found below a node of an
 * {@link Arbre} (see {@link #clades(Arbre)}); every node visited, leaves and
 * root included, contributes one clade.
 *
 * @param <T> element type of the sets compared by {@link #loss}
 */
public class SymmetricDiff<T>
implements BayesRiskMinimizer.LossFct<Set<T>> {
    /** Shared loss instance comparing two sets of {@link Taxon} clades. */
    public static final SymmetricDiff<Set<Taxon>> CLADE_SYMMETRIC_DIFFERENCE = new SymmetricDiff<Set<Taxon>>();

    /**
     * Returns the keys of {@code counter} whose count is strictly greater than
     * {@code threshold}.
     *
     * @param counter   counts to filter
     * @param threshold exclusive lower bound a key's count must exceed
     * @param <S>       key type
     * @return a new mutable set holding the surviving keys
     */
    public static <S> Set<S> filterLowCounts(Counter<S> counter, double threshold) {
        HashSet<S> result = new HashSet<S>();
        for (S key : counter.keySet()) {
            if (counter.getCount(key) > threshold) {
                result.add(key);
            }
        }
        return result;
    }

    /**
     * Computes the clades of an unrooted tree given with an arbitrary rooting:
     * the rooted clades together with their complements, minus the empty set.
     *
     * @param unrootedTreeWithArbitraryRooting tree whose root placement is arbitrary
     * @param <S> leaf content type
     * @return clades and their complements (never containing the empty set)
     */
    public static <S> Set<Set<S>> cladesFromUnrooted(Arbre<S> unrootedTreeWithArbitraryRooting) {
        Set<Set<S>> clades = SymmetricDiff.addComplements(SymmetricDiff.clades(unrootedTreeWithArbitraryRooting));
        // The complement of the full leaf set is empty and carries no split information.
        clades.remove(new HashSet<S>());
        return clades;
    }

    /**
     * Returns a new set holding every clade of {@code original} plus, for each
     * one, its complement relative to the union of all clades.
     */
    private static <S> Set<Set<S>> addComplements(Set<Set<S>> original) {
        HashSet<Set<S>> result = new HashSet<Set<S>>(original);
        Set<S> all = SymmetricDiff.allLeaves(original);
        for (Set<S> clade : original) {
            HashSet<S> complement = new HashSet<S>(all);
            complement.removeAll(clade);
            result.add(complement);
        }
        return result;
    }

    /**
     * Returns the union of all the given clades.
     *
     * @param clades clades to merge
     * @param <S>    element type
     * @return a new mutable set containing every element of every clade
     */
    public static <S> Set<S> allLeaves(Set<Set<S>> clades) {
        HashSet<S> result = new HashSet<S>();
        for (Set<S> clade : clades) {
            result.addAll(clade);
        }
        return result;
    }

    /**
     * Loss between two sets: the size of their symmetric difference.
     */
    @Override
    public double loss(Set<T> s1, Set<T> s2) {
        return SymmetricDiff.symmetricDifferenceSize(s1, s2);
    }

    /**
     * Counts the elements that belong to exactly one of {@code s1} and {@code s2}.
     * Membership is decided by each set's own {@code contains}.
     *
     * @return |s1 \ s2| + |s2 \ s1|
     */
    public static <S> int symmetricDifferenceSize(Set<S> s1, Set<S> s2) {
        int result = 0;
        for (S elt : s1) {
            if (!s2.contains(elt)) {
                ++result;
            }
        }
        for (S elt : s2) {
            if (!s1.contains(elt)) {
                ++result;
            }
        }
        return result;
    }

    /**
     * Clade symmetric difference between two taxon trees, after the clades of
     * {@code s1} are mapped and those of {@code s2} are filtered by {@code bijection}.
     */
    public static int cladesSymmetricDifferenceSize(Arbre<Taxon> s1, Arbre<Taxon> s2, MapLeaves bijection) {
        return SymmetricDiff.cladesSymmetricDifferenceSize(SymmetricDiff.clades(s1), SymmetricDiff.clades(s2), bijection);
    }

    /**
     * Symmetric difference between {@code bijection.mapClades(s1)} and
     * {@code bijection.filterClades(s2)}.
     */
    public static int cladesSymmetricDifferenceSize(Set<Set<Taxon>> s1, Set<Set<Taxon>> s2, MapLeaves bijection) {
        Set<Set<Taxon>> mapped = bijection.mapClades(s1);
        Set<Set<Taxon>> filtered = bijection.filterClades(s2);
        return SymmetricDiff.symmetricDifferenceSize(mapped, filtered);
    }

    /**
     * Creates a particle processor that maps each state to the clades of its
     * unlabeled tree ({@code getUnlabeledArbre()}).
     */
    public static ParticleFilter.ParticleMapperProcessor<PartialCoalescentState, Set<Set<Taxon>>> createCladeProcessor() {
        return new ParticleFilter.ParticleMapperProcessor<PartialCoalescentState, Set<Set<Taxon>>>(new Fct<PartialCoalescentState, Set<Set<Taxon>>>(){

            @Override
            public Set<Set<Taxon>> evalAt(PartialCoalescentState x) {
                return SymmetricDiff.clades(x.getUnlabeledArbre());
            }
        });
    }

    /**
     * Returns true when two sets overlap but neither contains the other, i.e.
     * they cannot both be clades of the same tree.
     */
    public static <T> boolean violates(Set<T> s1, Set<T> s2) {
        return CollUtils.intersects(s1, s2) && !s1.containsAll(s2) && !s2.containsAll(s1);
    }

    /**
     * Returns true when {@code s} {@link #violates violates} at least one member
     * of {@code sets}; members are checked in iteration order and the search
     * stops at the first violation.
     */
    public static <T> boolean violatesOne(Set<T> s, Set<Set<T>> sets) {
        for (Set<T> s2 : sets) {
            if (SymmetricDiff.violates(s, s2)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Equivalent to {@link #deltaSymmetricDiff(PartialCoalescentState, int, int, Set, MapLeaves)}
     * with no leaf bijection.
     */
    public static int deltaSymmetricDiff(PartialCoalescentState state, int i, int j, Set<Set<Taxon>> refs) {
        return SymmetricDiff.deltaSymmetricDiff(state, i, j, refs, (MapLeaves) null);
    }

    /**
     * Increment in clade symmetric difference caused by merging roots {@code i}
     * and {@code j} of {@code state}.
     *
     * <p>Adds 1 if the merged clade (restricted through {@code bijection} when it
     * is non-null) is not among {@code preMappedRefs}, plus 1 for every reference
     * clade that the merged clade violates and that no clade in
     * {@code allCladesInState} already violates.
     *
     * @param state            current partial state
     * @param allCladesInState clades already present in {@code state} (filtered when a bijection is used)
     * @param i                index of the first root to merge
     * @param j                index of the second root to merge
     * @param preMappedRefs    reference clades, already mapped through {@code bijection}
     * @param bijection        optional leaf mapping; may be {@code null}
     * @return 0 when {@code state} has exactly two roots or the (restricted) merged clade is empty
     */
    public static int deltaSymmetricDiff(PartialCoalescentState state, Set<Set<Taxon>> allCladesInState, int i, int j, Set<Set<Taxon>> preMappedRefs, MapLeaves bijection) {
        // With two roots left the merge produces the full root clade;
        // presumably it is shared by every tree and so never contributes.
        if (state.nRoots() == 2) {
            return 0;
        }
        Set<Taxon> newClade = state.mergedClade(i, j);
        if (bijection != null) {
            newClade = bijection.restrict(newClade);
        }
        if (newClade.isEmpty()) {
            return 0;
        }
        int result = preMappedRefs.contains(newClade) ? 0 : 1;
        for (Set<Taxon> ref : preMappedRefs) {
            // Only count references newly ruled out by this merge; those already
            // violated by an existing clade are skipped.
            if (SymmetricDiff.violates(ref, newClade) && !SymmetricDiff.violatesOne(ref, allCladesInState)) {
                ++result;
            }
        }
        return result;
    }

    /**
     * Same as the six-argument overload, but computes the state's clades itself.
     * When {@code bijection} is non-null, {@code refs} are mapped and the state's
     * clades filtered through it first.
     */
    public static int deltaSymmetricDiff(PartialCoalescentState state, int i, int j, Set<Set<Taxon>> refs, MapLeaves bijection) {
        Set<Set<Taxon>> allCladesInState = state.allClades();
        if (bijection != null) {
            refs = bijection.mapClades(refs);
            allCladesInState = bijection.filterClades(allCladesInState);
        }
        return SymmetricDiff.deltaSymmetricDiff(state, allCladesInState, i, j, refs, bijection);
    }

    /**
     * Prints the normalized clade symmetric difference between two Newick trees.
     *
     * @param args optional {@code <tree1.newick> <tree2.newick>}; when fewer than
     *             two arguments are given, the original hard-coded experiment paths are used
     */
    public static void main(String[] args) {
        String tree1Path = args.length >= 2 ? args[0] : "data/generatedExperiment/tree.newick";
        String tree2Path = args.length >= 2 ? args[1] : "data/generatedExperiment/contml50k.newick";
        try {
            Arbre<String> ar = SymmetricDiff.readNewick(tree1Path);
            Arbre<String> ar2 = SymmetricDiff.readNewick(tree2Path);
            System.out.println(SymmetricDiff.normalizedSymmetricCladeDiff(ar, ar2));
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Parses the Newick file at {@code path} and converts it to an {@link Arbre}.
     * NOTE(review): the reader returned by {@code IOUtils.openIn} is not closed here.
     */
    private static Arbre<String> readNewick(String path) throws Exception {
        NewickParser parser = new NewickParser(IOUtils.openIn(path));
        Tree<String> tree = parser.parse();
        return Arbre.tree2Arbre(tree);
    }

    /**
     * Number of clades present in exactly one of the two trees.
     */
    public static <S> int symmetricCladeDiff(Arbre<S> t1, Arbre<S> t2) {
        Set<Set<S>> clades1 = SymmetricDiff.clades(t1);
        Set<Set<S>> clades2 = SymmetricDiff.clades(t2);
        return SymmetricDiff.symmetricDifferenceSize(clades1, clades2);
    }

    /**
     * Sum of the two trees' node counts. Since each node yields at most one
     * clade, this bounds {@link #symmetricCladeDiff} from above (assuming
     * {@code nodes()} lists every node).
     */
    public static <S> int maxSymmetricCladeDiff(Arbre<S> t1, Arbre<S> t2) {
        return t1.nodes().size() + t2.nodes().size();
    }

    /**
     * {@link #symmetricCladeDiff} divided by {@link #maxSymmetricCladeDiff}.
     */
    public static <S> double normalizedSymmetricCladeDiff(Arbre<S> t1, Arbre<S> t2) {
        return (double)SymmetricDiff.symmetricCladeDiff(t1, t2) / (double)SymmetricDiff.maxSymmetricCladeDiff(t1, t2);
    }

    /**
     * Returns the set of clades of {@code t}: one leaf-content set per node.
     */
    public static <S> Set<Set<S>> clades(Arbre<S> t) {
        HashSet<Set<S>> clades = new HashSet<Set<S>>();
        SymmetricDiff.clades(SymmetricDiff.leafSetArbre(t), clades);
        return clades;
    }

    /** Adds the contents of {@code t1} and of all its descendants to {@code clades}. */
    private static <S> void clades(Arbre<Set<S>> t1, Set<Set<S>> clades) {
        clades.add(t1.getContents());
        for (Arbre<Set<S>> child : t1.getChildren()) {
            SymmetricDiff.clades(child, clades);
        }
    }

    /**
     * Returns a copy of {@code clades} extended with the union of all clades and
     * with one singleton clade per element of that union.
     *
     * @param clades clades to complete
     * @param <S>    element type
     * @return a new mutable set of clades
     */
    public static <S> Set<Set<S>> complete(Set<Set<S>> clades) {
        HashSet<Set<S>> result = new HashSet<Set<S>>(clades);
        Set<S> all = SymmetricDiff.allLeaves(clades);
        result.add(all);
        for (S leaf : all) {
            result.add(Collections.singleton(leaf));
        }
        return result;
    }

    /**
     * Builds a tree whose leaves hold the elements of the given clades. Internal
     * nodes have {@code null} contents. The input is first {@link #complete completed}.
     * NOTE(review): assumes the clades are pairwise compatible (no two
     * {@link #violates violate} each other); otherwise leaves may be dropped.
     */
    public static <S> Arbre<S> clades2arbre(Set<Set<S>> clades) {
        return SymmetricDiff._clades2arbre(SymmetricDiff.complete(clades));
    }

    /**
     * Recursive worker for {@link #clades2arbre}: drops the largest clade (this
     * node), then repeatedly peels off the largest remaining clade together with
     * all its subsets as one child subtree.
     *
     * @throws RuntimeException if a one-clade set does not hold exactly one element
     */
    private static <S> Arbre<S> _clades2arbre(Set<Set<S>> clades) {
        if (clades.size() == 1) {
            Set<S> singleton = clades.iterator().next();
            if (singleton.size() != 1) {
                throw new RuntimeException("Expected a single-leaf clade but found " + singleton);
            }
            return Arbre.arbre(singleton.iterator().next());
        }
        Set<Set<S>> remainder = SymmetricDiff.removeLargest(clades);
        ArrayList<Arbre<S>> children = new ArrayList<Arbre<S>>();
        while (!remainder.isEmpty()) {
            Set<Set<S>> child = SymmetricDiff.extractChild(remainder);
            children.add(SymmetricDiff._clades2arbre(child));
        }
        return Arbre.arbre(null, children);
    }

    /**
     * Returns the first clade, in iteration order, of maximal size, or
     * {@code null} when {@code clades} is empty.
     */
    private static <S> Set<S> largestClade(Set<Set<S>> clades) {
        int max = Integer.MIN_VALUE;
        Set<S> argmax = null;
        for (Set<S> clade : clades) {
            if (clade.size() > max) {
                max = clade.size();
                argmax = clade;
            }
        }
        return argmax;
    }

    /** Returns a copy of {@code rem} without its {@link #largestClade largest clade}. */
    private static <S> Set<Set<S>> removeLargest(Set<Set<S>> rem) {
        HashSet<Set<S>> result = new HashSet<Set<S>>(rem);
        result.remove(SymmetricDiff.largestClade(rem));
        return result;
    }

    /**
     * Removes from {@code rem} its largest clade and every clade contained in it,
     * and returns the removed clades. Mutates {@code rem}; it must be non-empty.
     */
    private static <S> Set<Set<S>> extractChild(Set<Set<S>> rem) {
        Set<S> argmax = SymmetricDiff.largestClade(rem);
        HashSet<Set<S>> result = new HashSet<Set<S>>();
        Iterator<Set<S>> iter = rem.iterator();
        while (iter.hasNext()) {
            Set<S> current = iter.next();
            if (argmax.containsAll(current)) {
                result.add(current);
                iter.remove();
            }
        }
        return result;
    }

    /**
     * Maps {@code a} via {@code postOrderMap} so that each node holds the set of
     * leaf contents below it (a leaf holds the singleton of its own contents).
     */
    public static <T> Arbre<Set<T>> leafSetArbre(Arbre<T> a) {
        return a.postOrderMap(new Arbre.ArbreMap<T, Set<T>>(){

            @Override
            public Set<T> map(Arbre<T> currentDomainNode) {
                HashSet<T> result = new HashSet<T>(currentDomainNode.nLeaves());
                if (currentDomainNode.isLeaf()) {
                    result.add(currentDomainNode.getContents());
                } else {
                    for (Set<T> child : this.getChildImage()) {
                        result.addAll(child);
                    }
                }
                return result;
            }
        });
    }

    /**
     * Particle processor that accumulates, across processed states, the weight
     * of each clade. Clades whose accumulated weight exceeds a threshold form the
     * consensus.
     */
    public static class ConsensusProcessor
    implements ParticleFilter.ParticleProcessor<PartialCoalescentState> {
        /** Total weight seen so far for each clade. */
        private final Counter<Set<Taxon>> flatClades = new Counter<Set<Taxon>>();

        /** Adds {@code weight} to the count of every clade of {@code state}. */
        @Override
        public void process(PartialCoalescentState state, double weight) {
            this.flatClades.incrementAll(state.allClades(), weight);
        }

        /**
         * Returns the clades whose accumulated weight is strictly greater than
         * {@code threshold}.
         * NOTE(review): the 0.5 lower bound presumably assumes normalized weights
         * (majority-rule consensus) — confirm against callers.
         *
         * @throws IllegalArgumentException if {@code threshold} is below 0.5 or NaN
         */
        public Set<Set<Taxon>> consensus(double threshold) {
            // Written as !(>=) so that NaN is rejected too.
            if (!(threshold >= 0.5)) {
                throw new IllegalArgumentException("Consensus threshold must be >= 0.5, got " + threshold);
            }
            return SymmetricDiff.filterLowCounts(this.flatClades, threshold);
        }
    }
}

