/*
 * Decompiled with CFR 0.152.
 */
package times;

import fig.basic.Pair;
import file.FileUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.StringTokenizer;
import nuts.util.Arbre;

/**
 * Node-time state for a fixed tree: a time for every internal node (leaves are pinned at 0 by
 * {@link #resampleAllTimes}), an ordering of the internal nodes, and the calibration priors read
 * from a times file.
 *
 * <p>Orders produced by {@link #randomOrder} list descendants before ancestors with the root last,
 * and {@link #getLogDensity} requires times to be non-decreasing along the order.
 */
public class TimeState {
    /** The tree whose internal-node times are tracked. */
    private Arbre<String> ar;
    /** Maps a clade (set of labels as written in the times file) to its internal-node label. */
    private HashMap<Set<String>, String> cladesMap;
    /** Maps node labels to tree nodes. */
    private HashMap<String, Arbre<String>> arbreMap;
    /** Per node: label set from {@code Arbre.nadMap} (presumably non-ancestor/descendant nodes — confirm). */
    private Map<Arbre<String>, Set<String>> nadMap;
    /** Per node: descendant labels from {@code Arbre.descMap}. */
    private Map<Arbre<String>, Set<String>> descMap;
    /** Per node: leaf labels below it, from {@code Arbre.leavesMap}. */
    private Map<Arbre<String>, Set<String>> leavesMap;
    /** Current ordering of internal-node labels. */
    private List<String> order;
    /** Calibration lower bound (or the fixed value for DELTA) per calibrated node. */
    private HashMap<String, Double> timesLBMap;
    /** Calibration upper bound (or the fixed value for DELTA) per calibrated node. */
    private HashMap<String, Double> timesUBMap;
    /** Calibration type per calibrated node; absent for uncalibrated nodes. */
    private HashMap<String, PriorType> timesFlagMap;
    /** Current time per node label. */
    private HashMap<String, Double> timesMap;
    /** Rate of the exp(-lambda * t) term used for node times in {@link #getLogDensity}. */
    private double lambda;

    public List<String> getOrder() {
        return this.order;
    }

    public HashMap<String, Double> getTimes() {
        return this.timesMap;
    }

    public HashMap<String, PriorType> getTimesFlag() {
        return this.timesFlagMap;
    }

    /**
     * Log prior of time {@code t} for node {@code s} under its calibration.
     *
     * @return for DELTA: 0 if {@code t} equals the fixed value, else -infinity; for BOUNDED: the
     *     soft bounded prior; for uncalibrated nodes: 0
     */
    public double logCalibratedPrior(double t, String s) {
        PriorType type = this.timesFlagMap.get(s);
        if (type == PriorType.DELTA) {
            return t == this.timesLBMap.get(s) ? 0.0 : Double.NEGATIVE_INFINITY;
        }
        if (type == PriorType.BOUNDED) {
            return this.logBoundedPrior(t, this.timesLBMap.get(s), this.timesUBMap.get(s));
        }
        return 0.0;
    }

    /**
     * Log density of a "soft" bounded prior with eps = 0.05. It puts mass 1 - eps uniformly on
     * [lb, ub], mass eps/2 in a power-law tail on [0, lb) and mass eps/2 in an exponential tail
     * above ub. The three pieces meet continuously at lb and ub.
     *
     * @throws IllegalArgumentException if either bound is infinite
     */
    public double logBoundedPrior(double x, double lb, double ub) {
        if (Double.isInfinite(lb) || Double.isInfinite(ub)) {
            throw new IllegalArgumentException("Bounds must be finite");
        }
        double eps = 0.05;
        double halfEps = 0.5 * eps;
        double scale = (1.0 - eps) / halfEps;
        double alpha = lb / (ub - lb) * scale;
        double rate = 1.0 / (ub - lb) * scale;
        if (x < lb) {
            return Math.log(halfEps) + Math.log(alpha) - Math.log(lb) + (alpha - 1.0) * (Math.log(x) - Math.log(lb));
        }
        if (x > ub) {
            return Math.log(halfEps) + Math.log(rate) - rate * (x - ub);
        }
        return Math.log(1.0 - eps) - Math.log(ub - lb);
    }

    /**
     * Draws from the prior whose log density is {@link #logBoundedPrior}. A first uniform picks
     * the component (lower tail, body, upper tail). A second, fresh uniform in [0, 1) is then
     * pushed through that component's inverse CDF.
     */
    public double sampleBoundedPrior(double lb, double ub, Random rand) {
        double eps = 0.05;
        double halfEps = 0.5 * eps;
        double scale = (1.0 - eps) / halfEps;
        double alpha = lb / (ub - lb) * scale;
        double rate = 1.0 / (ub - lb) * scale;
        double u = rand.nextDouble();
        if (u < halfEps) {
            // Lower tail on [0, lb): CDF is (x / lb)^alpha, so x = lb * v^(1 / alpha).
            double v = rand.nextDouble();
            return lb * Math.pow(v, 1.0 / alpha);
        }
        if (u < 1.0 - halfEps) {
            double v = rand.nextDouble();
            return lb * (1.0 - v) + ub * v;
        }
        // Upper tail: ub + Exponential(rate); 1 - v is in (0, 1], so the log is finite and <= 0.
        double v = rand.nextDouble();
        return ub - Math.log(1.0 - v) / rate;
    }

    /** Log bounded prior of root time {@code x}, using the root's calibration bounds. */
    public double logRootPrior(double x) {
        String s = this.ar.root().getContents();
        double ub = this.timesUBMap.get(s);
        double lb = this.timesLBMap.get(s);
        return this.logBoundedPrior(x, lb, ub);
    }

    /**
     * Samples a root time uniformly on [max(lb, root lower bound), root upper bound].
     *
     * @throws RuntimeException if the root has no upper bound or the interval is empty
     */
    public double sampleRootPrior(double lb, Random rand) {
        String s = this.ar.root().getContents();
        // Check for the bound before unboxing, so a missing bound gives the intended message.
        if (!this.timesUBMap.containsKey(s)) {
            throw new RuntimeException("Need an upper bound on the time of the root");
        }
        lb = Math.max(lb, this.timesLBMap.get(s));
        double ub = this.timesUBMap.get(s);
        if (lb > ub) {
            throw new RuntimeException("Cannot sample time at root");
        }
        return rand.nextDouble() * (ub - lb) + lb;
    }

    /**
     * Log density of the current times given the current order.
     *
     * <p>Returns -infinity if times decrease along the order or a node is older than its parent.
     * Otherwise it returns: the sum of -lambda * t over non-root nodes, minus normalising terms,
     * plus the calibration log priors. For each run of uncalibrated nodes, the normalising terms
     * integrate exp(-lambda * x) between the most recent calibrated node (or 0) and the node that
     * ends the run.
     */
    public double getLogDensity() {
        double logmarginal = 0.0;
        double logjoint = 0.0;
        double logcalibrated = 0.0;
        double maxtime = 0.0;
        int count = 0; // uncalibrated nodes seen since the last calibrated one
        double lb = 0.0; // time of the most recent calibrated node
        for (String s : this.order) {
            Arbre<String> node = this.arbreMap.get(s);
            double t = this.timesMap.get(s);
            if (t < maxtime) {
                return Double.NEGATIVE_INFINITY;
            }
            maxtime = t;
            if (node.isRoot()) {
                if (count > 0) {
                    logmarginal += count * this.logIntegralOfExp(lb, t);
                }
                continue;
            }
            if (this.timesMap.get(node.getParent().getContents()) > t) {
                return Double.NEGATIVE_INFINITY;
            }
            logjoint += -this.lambda * t;
            if (!this.timesFlagMap.containsKey(s)) {
                ++count;
                continue;
            }
            logcalibrated += this.logCalibratedPrior(t, s);
            if (count == 0) {
                logmarginal += -this.lambda * t;
            } else {
                logmarginal += count * this.logIntegralOfExp(lb, t);
                count = 0;
            }
            // Later uncalibrated nodes are bounded below by this calibrated node, whether or not
            // a run just ended here.
            lb = t;
        }
        return logjoint - logmarginal + logcalibrated;
    }

    /** log of the integral of exp(-lambda * x) for x from lo to hi. */
    private double logIntegralOfExp(double lo, double hi) {
        return -this.lambda * lo + Math.log(1.0 - Math.exp(-this.lambda * (hi - lo))) - Math.log(this.lambda);
    }

    /**
     * Returns a copy that keeps the tree, calibrations and lambda. The order list and the times
     * map are copied, so changes to the copy's collections do not affect this state.
     */
    public TimeState copy() {
        return new TimeState(this);
    }

    /** Copy constructor used by {@link #copy}; reuses the read-only tree-derived maps. */
    private TimeState(TimeState other) {
        this.ar = other.ar;
        this.cladesMap = other.cladesMap;
        this.arbreMap = other.arbreMap;
        this.nadMap = other.nadMap;
        this.descMap = other.descMap;
        this.leavesMap = other.leavesMap;
        this.timesLBMap = other.timesLBMap;
        this.timesUBMap = other.timesUBMap;
        this.timesFlagMap = other.timesFlagMap;
        this.lambda = other.lambda;
        this.timesMap = other.timesMap == null ? null : new HashMap<String, Double>(other.timesMap);
        this.order = other.order == null ? null : new ArrayList<String>(other.order);
    }

    /** Creates a state with the given times and no calibrations. */
    public TimeState(Arbre<String> ar, HashMap<Set<String>, String> cladesMap, HashMap<String, Arbre<String>> arbreMap, HashMap<String, Double> timesMap) {
        this.ar = ar;
        this.cladesMap = cladesMap;
        this.arbreMap = arbreMap;
        this.lambda = 1.0;
        this.timesMap = timesMap;
        this.timesFlagMap = new HashMap<String, PriorType>();
        this.timesLBMap = new HashMap<String, Double>();
        this.timesUBMap = new HashMap<String, Double>();
        this.nadMap = Arbre.nadMap(ar);
        this.descMap = Arbre.descMap(ar);
        this.leavesMap = Arbre.leavesMap(ar);
    }

    /** Creates a state whose calibrations are read from {@code timesFile} (see {@link #readTimes}). */
    public TimeState(Arbre<String> ar, HashMap<Set<String>, String> cladesMap, HashMap<String, Arbre<String>> arbreMap, String timesFile) {
        this.ar = ar;
        this.cladesMap = cladesMap;
        this.arbreMap = arbreMap;
        this.lambda = 1.0;
        this.readTimes(timesFile);
        this.nadMap = Arbre.nadMap(ar);
        this.descMap = Arbre.descMap(ar);
        this.leavesMap = Arbre.leavesMap(ar);
    }

    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    public void setOrder(List<String> order) {
        this.order = order;
    }

    public void setTimes(HashMap<String, Double> timesMap) {
        this.timesMap = timesMap;
    }

    public double getTime(String s) {
        return this.timesMap.get(s);
    }

    /**
     * Reads calibrations and resets the times map to empty. Each entry of the file (as parsed by
     * {@code FileUtils.readMap}) maps a comma-separated clade to either one value (DELTA) or
     * "lb,ub" (BOUNDED). Clades with no matching internal node and values with any other number
     * of fields are skipped with a warning.
     *
     * @throws RuntimeException if the file cannot be read or parsed, or if the root has no
     *     BOUNDED calibration
     */
    public void readTimes(String timesFile) {
        this.timesFlagMap = new HashMap<String, PriorType>();
        this.timesLBMap = new HashMap<String, Double>();
        this.timesUBMap = new HashMap<String, Double>();
        this.timesMap = new HashMap<String, Double>();
        try {
            HashMap<String, String> tmpMap = FileUtils.readMap(timesFile);
            for (Map.Entry<String, String> entry : tmpMap.entrySet()) {
                HashSet<String> keys = new HashSet<String>();
                StringTokenizer keyTok = new StringTokenizer(entry.getKey(), ",");
                while (keyTok.hasMoreTokens()) {
                    keys.add(keyTok.nextToken());
                }
                String internalLabel = this.cladesMap.get(keys);
                if (internalLabel == null) {
                    System.err.println("Warning: no internal node for clade " + entry.getKey() + "; ignoring");
                    continue;
                }
                StringTokenizer tok = new StringTokenizer(entry.getValue(), ",");
                int fields = tok.countTokens();
                if (fields == 1) {
                    double val = Double.parseDouble(tok.nextToken());
                    this.timesFlagMap.put(internalLabel, PriorType.DELTA);
                    this.timesUBMap.put(internalLabel, val);
                    this.timesLBMap.put(internalLabel, val);
                } else if (fields == 2) {
                    double val1 = Double.parseDouble(tok.nextToken());
                    double val2 = Double.parseDouble(tok.nextToken());
                    this.timesFlagMap.put(internalLabel, PriorType.BOUNDED);
                    this.timesUBMap.put(internalLabel, val2);
                    this.timesLBMap.put(internalLabel, val1);
                } else {
                    System.err.println("Warning: expected 1 or 2 values for clade " + entry.getKey() + "; ignoring");
                }
            }
        }
        catch (Exception e) {
            System.err.println("Could not read from file " + timesFile);
            throw new RuntimeException("Could not read from file " + timesFile, e);
        }
        // Kept outside the try so this error is no longer swallowed by the catch above.
        if (this.timesFlagMap.get(this.ar.root().getContents()) != PriorType.BOUNDED) {
            System.err.println("Provide bounds on the age of the root");
            throw new RuntimeException("Provide bounds on the age of the root");
        }
    }

    /** Draws a random order and then a full set of times consistent with it, and stores both. */
    public void init(Random rand) {
        this.setOrder(this.randomOrder(rand));
        this.setTimes(this.resampleAllTimes(rand));
    }

    /**
     * Builds a random ordering of the internal nodes. The method repeatedly removes a uniformly
     * chosen node from a frontier that starts at the root, prepends it to the result, and adds its
     * non-leaf children to the frontier. Children therefore always precede their parent, and the
     * root ends up last.
     */
    public List<String> randomOrder(Random rand) {
        ArrayList<String> result = new ArrayList<String>();
        ArrayList<Arbre<String>> frontier = new ArrayList<Arbre<String>>();
        frontier.add(this.ar.root());
        while (!frontier.isEmpty()) {
            Arbre<String> chosen = frontier.remove(rand.nextInt(frontier.size()));
            result.add(0, chosen.getContents());
            for (Arbre<String> child : chosen.getChildren()) {
                if (!child.isLeaf()) {
                    frontier.add(child);
                }
            }
        }
        return result;
    }

    /**
     * Returns a new order in which the node at the later of positions {@code i}/{@code j} moves
     * ahead of the node at the earlier one. The nodes strictly between the two positions are
     * regrouped so that descendants still precede ancestors:
     * prefix, descendants of the moved-up node, moved-up node, others, moved-down node,
     * ancestors of the moved-down node, suffix. Relative order inside each group is preserved.
     */
    private List<String> swapOrder(int i, int j) {
        if (i == j) {
            return new ArrayList<String>(this.order);
        }
        int lo = Math.min(i, j);
        int hi = Math.max(i, j);
        String earlier = this.order.get(lo);
        String later = this.order.get(hi);
        Set<String> laterDesc = this.descendantsOf(later);
        List<String> result = new ArrayList<String>(this.order.size());
        List<String> others = new ArrayList<String>();
        List<String> earlierAncestors = new ArrayList<String>();
        result.addAll(this.order.subList(0, lo));
        for (int k = lo + 1; k < hi; ++k) {
            String s = this.order.get(k);
            if (laterDesc.contains(s)) {
                result.add(s); // must stay before `later`
            } else if (this.descendantsOf(s).contains(earlier)) {
                earlierAncestors.add(s); // must stay after `earlier`
            } else {
                others.add(s);
            }
        }
        result.add(later);
        result.addAll(others);
        result.add(earlier);
        result.addAll(earlierAncestors);
        result.addAll(this.order.subList(hi + 1, this.order.size()));
        return result;
    }

    /** Descendant labels of the node labelled {@code label}, or an empty set if none are recorded. */
    private Set<String> descendantsOf(String label) {
        Set<String> desc = this.descMap.get(this.arbreMap.get(label));
        return desc == null ? new HashSet<String>() : desc;
    }

    /**
     * Picks positions (i, j) in the order such that the node at j is in the nad set of the node at
     * i. i is drawn uniformly among nodes that have at least one such partner in the order, and j
     * uniformly among those partners.
     *
     * @return the pair, or {@code null} if no node has a partner in the order
     */
    private Pair<Integer, Integer> pickRandomNadPair(Random rand) {
        int n = this.order.size();
        Map<String, Integer> position = new HashMap<String, Integer>();
        int[] pool = new int[n];
        for (int k = 0; k < n; ++k) {
            position.put(this.order.get(k), k);
            pool[k] = k;
        }
        // Sample i without replacement until a node with usable partners is found; nodes with
        // no nads (e.g. the root) or only nads outside the order are skipped.
        for (int remaining = n; remaining > 0; --remaining) {
            int pick = rand.nextInt(remaining);
            int i = pool[pick];
            pool[pick] = pool[remaining - 1];
            Set<String> nads = this.nadMap.get(this.arbreMap.get(this.order.get(i)));
            if (nads == null || nads.isEmpty()) {
                continue;
            }
            List<Integer> partners = new ArrayList<Integer>();
            for (String nad : nads) {
                Integer j = position.get(nad);
                if (j != null) {
                    partners.add(j);
                }
            }
            if (!partners.isEmpty()) {
                return Pair.makePair(i, partners.get(rand.nextInt(partners.size())));
            }
        }
        return null;
    }

    /**
     * Proposes a new order by swapping a random nad pair. Returns an unchanged copy of the
     * current order when no swappable pair exists.
     */
    public List<String> resampleOrder(Random rand) {
        Pair<Integer, Integer> p = this.pickRandomNadPair(rand);
        if (p == null) {
            return new ArrayList<String>(this.order);
        }
        return this.swapOrder(p.getFirst(), p.getSecond());
    }

    /**
     * Log proposal density of moving from {@code oldState} to {@code newState}.
     *
     * <p>If the orders are equal, exactly one node time must differ, and the single-node density is
     * used. Otherwise the full-resample density of the new times is used.
     *
     * @throws RuntimeException if the orders match but zero or several times differ
     */
    public static double logTransitionProbability(TimeState oldState, TimeState newState) {
        List<String> order1 = oldState.getOrder();
        List<String> order2 = newState.getOrder();
        HashMap<String, Double> time1 = oldState.getTimes();
        HashMap<String, Double> time2 = newState.getTimes();
        if (!order1.equals(order2)) {
            return newState.logTransitionProbability(time2);
        }
        int ind = -1;
        for (int i = 0; i < order1.size(); ++i) {
            double t1 = time1.get(order1.get(i));
            double t2 = time2.get(order2.get(i));
            if (!(Math.abs(t1 - t2) > Double.MIN_VALUE)) {
                continue;
            }
            if (ind != -1) {
                throw new RuntimeException("Found a change in more than one time");
            }
            ind = i;
        }
        if (ind == -1) {
            throw new RuntimeException("Times are identical");
        }
        String s = order1.get(ind);
        return newState.logTransitionProbability(s, time2.get(s));
    }

    /**
     * Log density of proposing {@code time} for node {@code s}. For a DELTA node it is 0 if
     * {@code time} equals the stored time, else -infinity. For the last node in the order it is
     * the root prior. Otherwise it is the uniform density between the neighbours' times in the
     * order (0 as the lower bound for the first node).
     *
     * <p>NOTE(review): {@link #resampleTime} proposes the root uniformly below its upper bound,
     * not from {@link #logRootPrior} — confirm the root case is intended.
     *
     * @throws RuntimeException if {@code s} is not in the order
     */
    public double logTransitionProbability(String s, double time) {
        int i = this.order.indexOf(s);
        if (i < 0) {
            throw new RuntimeException("String " + s + " not found in order");
        }
        if (this.timesFlagMap.get(s) == PriorType.DELTA) {
            return Math.abs(this.timesMap.get(s) - time) < Double.MIN_VALUE ? 0.0 : Double.NEGATIVE_INFINITY;
        }
        if (i == this.order.size() - 1) {
            return this.logRootPrior(time);
        }
        double ub = this.timesMap.get(this.order.get(i + 1));
        double lb = i > 0 ? this.timesMap.get(this.order.get(i - 1)) : 0.0;
        return Math.log(1.0 / (ub - lb));
    }

    /**
     * Log density of drawing {@code newTimesMap} by sweeping the order from last to first. The
     * last node contributes the root prior. Every other node contributes a uniform density between
     * the previous node's time and the next node's new time. The result is -infinity if a DELTA
     * node's time differs from its fixed value.
     */
    public double logTransitionProbability(HashMap<String, Double> newTimesMap) {
        double ub = Double.MAX_VALUE;
        double logdensity = 0.0;
        for (int i = this.order.size() - 1; i >= 0; --i) {
            String s = this.order.get(i);
            if (this.timesFlagMap.get(s) == PriorType.DELTA && Math.abs(newTimesMap.get(s) - this.timesLBMap.get(s)) > Double.MIN_VALUE) {
                return Double.NEGATIVE_INFINITY;
            }
            if (i == this.order.size() - 1) {
                logdensity += this.logRootPrior(newTimesMap.get(s));
                ub = newTimesMap.get(s);
                continue;
            }
            double lb = i > 0 ? newTimesMap.get(this.order.get(i - 1)) : 0.0;
            logdensity += Math.log(1.0 / (ub - lb));
            ub = newTimesMap.get(s);
        }
        return logdensity;
    }

    /**
     * Samples a full set of times for the current order, without modifying this state. Leaves
     * get time 0, and DELTA nodes get their fixed value. The root is sampled by
     * {@link #sampleRootPrior}. Every other node is sampled uniformly between the largest DELTA
     * value among its descendants and the time just sampled for the next node in the order.
     *
     * <p>NOTE(review): ub is not lowered after a DELTA node, so nodes earlier in the order can be
     * sampled above that node's fixed time, and {@link #getLogDensity} then returns -infinity.
     * Lowering it also needs {@link #randomOrder} to respect calibrations; otherwise this method
     * can throw for random orders.
     *
     * @throws RuntimeException if a DELTA value is below a descendant's DELTA value, or if no
     *     valid interval exists for a node
     */
    public HashMap<String, Double> resampleAllTimes(Random rand) {
        HashMap<String, Double> result = new HashMap<String, Double>();
        for (String leaf : this.leavesMap.get(this.ar.root())) {
            result.put(leaf, 0.0);
        }
        double ub = Double.MAX_VALUE;
        for (int i = this.order.size() - 1; i >= 0; --i) {
            String s = this.order.get(i);
            double min = 0.0;
            for (String t : this.descMap.get(this.arbreMap.get(s))) {
                if (this.timesFlagMap.get(t) == PriorType.DELTA) {
                    min = Math.max(this.timesLBMap.get(t), min);
                }
            }
            double tmp;
            if (this.timesFlagMap.get(s) == PriorType.DELTA) {
                tmp = this.timesLBMap.get(s);
                if (tmp < min) {
                    throw new RuntimeException("Inconsistent specification of times");
                }
                result.put(s, tmp);
                continue;
            }
            if (this.arbreMap.get(s).isRoot()) {
                tmp = this.sampleRootPrior(min, rand);
                result.put(s, tmp);
                ub = tmp;
                continue;
            }
            if (min > ub) {
                throw new RuntimeException("Cannot sample times for node " + s);
            }
            tmp = rand.nextDouble() * (ub - min) + min;
            result.put(s, tmp);
            ub = tmp;
        }
        return result;
    }

    /**
     * Returns a copy of the current times in which node {@code s} is resampled uniformly between
     * its neighbours' times in the order. The first node uses 0 as its lower bound, and the last
     * node uses the root's upper bound as its upper bound. DELTA nodes are returned unchanged.
     *
     * @throws RuntimeException if the root has no upper bound or {@code s} is not in the order
     */
    public HashMap<String, Double> resampleTime(String s, Random rand) {
        HashMap<String, Double> result = new HashMap<String, Double>(this.timesMap);
        if (this.timesFlagMap.get(s) == PriorType.DELTA) {
            return result;
        }
        String root = this.ar.root().getContents();
        if (!this.timesUBMap.containsKey(root)) {
            throw new RuntimeException("Need an upper bound on the time of the root");
        }
        double ub = this.timesUBMap.get(root);
        double lb = 0.0;
        // indexOf compares with equals(); the previous reference (!=) comparison could miss s.
        int index = this.order.indexOf(s);
        if (index == -1) {
            throw new RuntimeException("String " + s + "not found in tree");
        }
        if (index > 0) {
            lb = this.timesMap.get(this.order.get(index - 1));
        }
        if (index < this.order.size() - 1) {
            ub = this.timesMap.get(this.order.get(index + 1));
        }
        result.put(s, lb + (ub - lb) * rand.nextDouble());
        return result;
    }

    /** One "label&lt;TAB&gt;time" line per entry of the times map. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, Double> e : this.timesMap.entrySet()) {
            sb.append(e.getKey()).append('\t').append(e.getValue()).append('\n');
        }
        return sb.toString();
    }

    /** Kind of calibration attached to a node. */
    public static enum PriorType {
        FLAT,
        DELTA,
        BOUNDED;

    }
}

