/*
 * Decompiled with CFR 0.152.
 */
package conifer.fastmetrics;

import conifer.Phylogeny;
import conifer.ml.data.PhylogeneticHeldoutDataset;
import fenchel.algo.FactorGraphSumProduct;
import fenchel.factor.BinaryFactor;
import fenchel.factor.FactorGraph;
import fenchel.factor.IdentityFactor;
import fenchel.factor.UnaryFactor;
import fig.basic.UnorderedPair;
import goblin.Taxon;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import nuts.math.Graph;
import nuts.math.Graphs;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Indexer;
import pty.RootedTree;
import pty.UnrootedTree;

public class CladeMetrics {
    /**
     * Binary factor that forwards its single incoming unary factor unchanged. This lets
     * clade messages cross every edge of a {@link CladeBuilderFactorGraph} without modification.
     */
    public static final BinaryFactor identityFactor = new BinaryFactor(){

        /**
         * Returns the only factor in the list.
         *
         * @throws IllegalArgumentException if the list does not contain exactly one factor
         */
        @Override
        public UnaryFactor marginalize(List<UnaryFactor> factorsOnFirstNode) {
            if (factorsOnFirstNode.size() != 1) {
                throw new IllegalArgumentException(
                    "identityFactor expects exactly one factor, got " + factorsOnFirstNode.size());
            }
            return factorsOnFirstNode.get(0);
        }

        @Override
        public int maxNumberOfFactorsSupported() {
            return 1;
        }
    };

    /**
     * Compares two trees by the total branch length attached to each bipartition (split).
     *
     * <p>Both trees are indexed with the tips of {@code truth}. For every bipartition seen in
     * either tree, let b1 and b2 be the summed lengths of the edges inducing it in
     * {@code truth} and {@code guess}. The metrics are:
     * <ul>
     *   <li>{@code l0}: number of bipartitions with positive length in exactly one tree;</li>
     *   <li>{@code l1}: sum of |b1 - b2|;</li>
     *   <li>{@code l2}: square root of the sum of (b1 - b2)^2.</li>
     * </ul>
     * A bipartition absent from one tree contributes length 0 there. This assumes
     * {@code Counter.getCount} returns 0 for missing keys.
     *
     * @param truth reference tree; its leaves define the tip indexing
     * @param guess tree to compare against {@code truth}
     * @return a map holding one value per {@link TreeMetric}
     */
    public static Map<TreeMetric, Double> computeTreeMetrics(Phylogeny truth, Phylogeny guess) {
        Indexer<Taxon> tipsIndexer = CladeMetrics.tipIndexer(truth);
        Counter<UnorderedPair<CladeFactor, CladeFactor>> clades1 =
            CladeMetrics.cladeSet(truth.getUnrooted(), tipsIndexer).bipartitions();
        Counter<UnorderedPair<CladeFactor, CladeFactor>> clades2 =
            CladeMetrics.cladeSet(guess.getUnrooted(), tipsIndexer).bipartitions();

        // Union of the bipartitions appearing in either tree.
        Set<UnorderedPair<CladeFactor, CladeFactor>> allBipartitions =
            new HashSet<UnorderedPair<CladeFactor, CladeFactor>>(clades1.keySet());
        allBipartitions.addAll(clades2.keySet());

        double l0 = 0.0;
        double l1 = 0.0;
        double l2 = 0.0;
        for (UnorderedPair<CladeFactor, CladeFactor> bipart : allBipartitions) {
            double b1 = clades1.getCount(bipart);
            double b2 = clades2.getCount(bipart);
            l0 += CladeMetrics.l0(b1, b2);
            l1 += CladeMetrics.l1(b1, b2);
            l2 += CladeMetrics.l2(b1, b2);
        }

        // EnumMap keeps a deterministic l0, l1, l2 iteration order (e.g. when printed).
        Map<TreeMetric, Double> result = new EnumMap<TreeMetric, Double>(TreeMetric.class);
        result.put(TreeMetric.l0, l0);
        result.put(TreeMetric.l1, l1);
        result.put(TreeMetric.l2, Math.sqrt(l2));
        return result;
    }

    /** Squared difference of two bipartition lengths. */
    private static double l2(double b1, double b2) {
        double x = b1 - b2;
        return x * x;
    }

    /** Absolute difference of two bipartition lengths. */
    private static double l1(double b1, double b2) {
        return Math.abs(b1 - b2);
    }

    /** Returns 1 when exactly one of the two lengths is strictly positive, else 0. */
    private static double l0(double b1, double b2) {
        return Math.abs(CladeMetrics.isPositive(b1) - CladeMetrics.isPositive(b2));
    }

    /** Indicator (as a double) of {@code b2 > 0}. */
    private static double isPositive(double b2) {
        return b2 > 0.0 ? 1.0 : 0.0;
    }

    /**
     * Counts bipartitions present in exactly one of the two trees, ignoring branch lengths.
     *
     * <p>Unlike the {@code l0} metric of {@link #computeTreeMetrics}, presence here is key
     * membership in the bipartition counter. A bipartition whose total length is 0 still counts
     * as present.
     *
     * @return size of the symmetric difference of the two bipartition sets
     */
    public static double bipartitionTopologySymmetricDifference(CladeCalculator calc1, CladeCalculator calc2) {
        Set<UnorderedPair<CladeFactor, CladeFactor>> clades1 = calc1.bipartitions().keySet();
        Set<UnorderedPair<CladeFactor, CladeFactor>> clades2 = calc2.bipartitions().keySet();
        double result = 0.0;
        for (UnorderedPair<CladeFactor, CladeFactor> bipart : clades1) {
            if (!clades2.contains(bipart)) {
                result += 1.0;
            }
        }
        for (UnorderedPair<CladeFactor, CladeFactor> bipart : clades2) {
            if (!clades1.contains(bipart)) {
                result += 1.0;
            }
        }
        return result;
    }

    /**
     * Builds an indexer over the leaves of {@code p}'s unrooted tree, in the order the leaves
     * are returned.
     */
    public static Indexer<Taxon> tipIndexer(Phylogeny p) {
        Indexer<Taxon> tipIndexer = new Indexer<Taxon>();
        for (Taxon t : p.getUnrooted().leaves()) {
            tipIndexer.addToIndex(new Taxon[]{t});
        }
        return tipIndexer;
    }

    /**
     * Runs sum-product on a {@link CladeBuilderFactorGraph} over {@code ut} and wraps the result.
     *
     * @param ut tree whose edges define the bipartitions
     * @param tipIndexer indexer shared by all trees that will be compared
     */
    public static CladeCalculator cladeSet(UnrootedTree ut, Indexer<Taxon> tipIndexer) {
        CladeBuilderFactorGraph fg = new CladeBuilderFactorGraph(ut, tipIndexer);
        FactorGraphSumProduct<Taxon> sp = new FactorGraphSumProduct<Taxon>();
        sp.init(fg);
        return new CladeCalculator(sp, ut);
    }

    /**
     * Demo: compares a loaded reference tree against a random tree over the same leaves.
     * It prints the branch-length metrics, then the topological symmetric difference.
     */
    public static void main(String[] args) {
        PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions phyloOptions = new PhylogeneticHeldoutDataset.PhylogeneticHeldoutDatasetOptions();
        phyloOptions.maxNSites = 10;
        PhylogeneticHeldoutDataset phyloData = PhylogeneticHeldoutDataset.loadData(phyloOptions);
        Random rand = new Random(1L);
        RootedTree original = RootedTree.Util.random(rand, phyloData.rootedTree.getUnrooted().leaves());
        RootedTree truth = phyloData.rootedTree;
        System.out.println(CladeMetrics.computeTreeMetrics(truth, original));
        Indexer<Taxon> tipsIndexer = CladeMetrics.tipIndexer(truth);
        CladeCalculator calc1 = CladeMetrics.cladeSet(truth.getUnrooted(), tipsIndexer);
        CladeCalculator calc2 = CladeMetrics.cladeSet(original.getUnrooted(), tipsIndexer);
        System.out.println(CladeMetrics.bipartitionTopologySymmetricDifference(calc1, calc2));
    }

    /**
     * Unary factor describing a clade as a sorted set of tip indices. A {@code null} tip array
     * instead marks the "complement" side.
     *
     * <p>{@link #multiply} unions tip sets. Once the union exceeds half of the indexer's size,
     * the result collapses to the complement marker. So every explicit clade holds at most
     * {@code indexer.size() / 2} tips.
     */
    public static class CladeFactor
    implements UnaryFactor {
        /** Sorted, duplicate-free tip indices, or {@code null} for the complement marker. */
        private final int[] sortedTips;
        private final Indexer<Taxon> indexer;

        /**
         * @param sortedTips tip indices (any order; copied and sorted), or {@code null} for the
         *     complement marker
         * @param indexer tip indexer; factors may only be combined or compared with factors
         *     built from the same indexer instance
         * @throws IllegalArgumentException if {@code sortedTips} contains a duplicate index
         */
        public CladeFactor(int[] sortedTips, Indexer<Taxon> indexer) {
            this.indexer = indexer;
            // Copy so sorting does not mutate the caller's array, and later writes to it cannot
            // silently change this factor's identity (it is used as a hash key).
            this.sortedTips = sortedTips == null ? null : sortedTips.clone();
            if (this.sortedTips != null) {
                Arrays.sort(this.sortedTips);
                for (int i = 0; i + 1 < this.sortedTips.length; ++i) {
                    if (this.sortedTips[i] == this.sortedTips[i + 1]) {
                        throw new IllegalArgumentException("Duplicate tip index " + this.sortedTips[i]);
                    }
                }
            }
        }

        @Override
        public String toString() {
            if (this.sortedTips == null) {
                return "complement";
            }
            return this.taxa().toString();
        }

        /**
         * Resolves the tip indices back to taxa.
         *
         * @throws IllegalStateException if this is the complement marker, which has no
         *     explicit tip set
         */
        public Set<Taxon> taxa() {
            if (this.sortedTips == null) {
                throw new IllegalStateException("The complement clade has no explicit taxa");
            }
            HashSet<Taxon> result = new HashSet<Taxon>();
            for (int i : this.sortedTips) {
                result.add(this.indexer.i2o(i));
            }
            return result;
        }

        /**
         * Unions this clade with {@code otherFactors}, which must all be {@code CladeFactor}s.
         * Returns a complement marker if any input is one, or if the union exceeds half the
         * tips.
         *
         * @throws IllegalArgumentException if a factor uses a different indexer instance
         */
        @Override
        public UnaryFactor multiply(List<UnaryFactor> otherFactors) {
            if (this.sortedTips == null) {
                return this;
            }
            int nItems = this.sortedTips.length;
            for (UnaryFactor f : otherFactors) {
                CladeFactor other = (CladeFactor)f;
                if (other.sortedTips == null) {
                    return other;
                }
                if (other.indexer != this.indexer) {
                    throw new IllegalArgumentException("Cannot multiply CladeFactors built from different indexers");
                }
                nItems += other.sortedTips.length;
            }
            // Keep only the smaller side of a split explicit; the larger side becomes the marker.
            if (nItems > this.indexer.size() / 2) {
                return new CladeFactor(null, this.indexer);
            }
            int[] result = new int[nItems];
            int index = 0;
            for (UnaryFactor f : otherFactors) {
                for (int item : ((CladeFactor)f).sortedTips) {
                    result[index++] = item;
                }
            }
            for (int item : this.sortedTips) {
                result[index++] = item;
            }
            // The constructor sorts the tips and rejects duplicates (overlapping clades).
            return new CladeFactor(result, this.indexer);
        }

        /** Not meaningful for this combinatorial factor; always {@code NaN}. */
        @Override
        public double logNorm() {
            return Double.NaN;
        }

        @Override
        public int hashCode() {
            return 31 + Arrays.hashCode(this.sortedTips);
        }

        /**
         * Two clades are equal when their sorted tip arrays are equal. Two complement markers
         * are also equal.
         *
         * @throws IllegalArgumentException if {@code obj} is a {@code CladeFactor} built from a
         *     different indexer instance (comparing across indexings is a programming error)
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (this.getClass() != obj.getClass()) {
                return false;
            }
            CladeFactor other = (CladeFactor)obj;
            if (other.indexer != this.indexer) {
                throw new IllegalArgumentException("Cannot compare CladeFactors built from different indexers");
            }
            // Arrays.equals treats two null arrays (two complement markers) as equal.
            return Arrays.equals(this.sortedTips, other.sortedTips);
        }
    }

    /**
     * Factor graph over a tree's topology. Each leaf carries a singleton {@link CladeFactor}.
     * Every other node carries {@code IdentityFactor.identity}, and every edge carries
     * {@link #identityFactor}.
     */
    public static final class CladeBuilderFactorGraph
    implements FactorGraph<Taxon> {
        private final Graph<Taxon> topology;
        private final Map<Taxon, UnaryFactor> map;
        private final Indexer<Taxon> tipIndexer;

        public CladeBuilderFactorGraph(UnrootedTree t, Indexer<Taxon> indexer) {
            this.topology = t.getTopology();
            this.map = new HashMap<Taxon, UnaryFactor>();
            this.tipIndexer = indexer;
            this.init(t);
        }

        /** Attaches a singleton clade factor to every leaf of {@code t}. */
        private void init(UnrootedTree t) {
            for (Taxon tip : t.leaves()) {
                int index = this.tipIndexer.o2i(tip);
                this.map.put(tip, new CladeFactor(new int[]{index}, this.tipIndexer));
            }
        }

        @Override
        public UnaryFactor getUnary(Taxon node) {
            UnaryFactor result = this.map.get(node);
            if (result == null) {
                return IdentityFactor.identity;
            }
            return result;
        }

        @Override
        public Graph<Taxon> getTopology() {
            return this.topology;
        }

        @Override
        public BinaryFactor getBinary(Taxon source, Taxon destination) {
            return identityFactor;
        }
    }

    /**
     * Extracts the bipartitions of a tree from an initialized sum-product run on its
     * {@link CladeBuilderFactorGraph}. The result is computed lazily and cached; the cache is
     * not synchronized.
     */
    public static class CladeCalculator {
        private final FactorGraphSumProduct<Taxon> rawClades;
        private final UnrootedTree ut;
        private Counter<UnorderedPair<CladeFactor, CladeFactor>> bipartitions = null;

        private CladeCalculator(FactorGraphSumProduct<Taxon> rawClades, UnrootedTree ut) {
            this.ut = ut;
            this.rawClades = rawClades;
        }

        /**
         * Maps each bipartition to the total length of the edges inducing it. A bipartition is
         * the unordered pair of messages exchanged along an edge.
         *
         * @throws IllegalStateException if both messages on an edge are complement markers
         */
        private Counter<UnorderedPair<CladeFactor, CladeFactor>> bipartitions() {
            if (this.bipartitions != null) {
                return this.bipartitions;
            }
            // Build into a local so a failure midway never leaves a partially filled cache.
            Counter<UnorderedPair<CladeFactor, CladeFactor>> counts =
                new Counter<UnorderedPair<CladeFactor, CladeFactor>>();
            for (UnorderedPair<Taxon, Taxon> edge : Graphs.edgeSet(this.ut.getTopology())) {
                double branch = this.ut.branchLength(edge);
                CladeFactor c1 = (CladeFactor)this.rawClades.getMessage(edge.getFirst(), edge.getSecond());
                CladeFactor c2 = (CladeFactor)this.rawClades.getMessage(edge.getSecond(), edge.getFirst());
                // At most one side of a split can exceed half the tips.
                if (c1.sortedTips == null && c2.sortedTips == null) {
                    throw new IllegalStateException("Both sides of edge " + edge + " are complement clades");
                }
                counts.incrementCount(new UnorderedPair<CladeFactor, CladeFactor>(c1, c2), branch);
            }
            this.bipartitions = counts;
            return counts;
        }
    }

    /** Metrics returned by {@link CladeMetrics#computeTreeMetrics}. */
    public static enum TreeMetric {
        l0,
        l1,
        l2;

    }
}

