/*
 * Decompiled with CFR 0.152.
 */
package conifer.ssm;

import conifer.ssm.Edit;
import conifer.ssm.ForwardSimulator;
import conifer.ssm.StringMutationModel;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.basic.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.math.ProposalRandom;
import nuts.util.Counter;
import pepper.editmodel.BayesRiskMinimizer;

/**
 * Proposes a sequence of string edits leading from one string to another, biased toward
 * edits that decrease {@code BayesRiskMinimizer.computeDist} to the target, and scores the
 * resulting path with a log weight (log prior ratio + path log-likelihood - log proposal
 * probability).
 *
 * <p>{@link #setModel(StringMutationModel)} must be called before proposing.
 */
public class InformedProposal2 {
    private StringMutationModel stringModel;
    /** Scale applied to the distance improvement in the exponent for {@link PropMethod#SIMPLE_GREEDY}. */
    @Option
    public double proposalGreed = 50.0;
    /** Probability of stopping each time the current string equals the target. */
    @Option
    public double proposalStopProbability = 0.9;
    /** Maximum number of edits proposed before the proposal bails out (returns {@code null}). */
    @Option
    public int maxSteps = 50;
    /** Strategy used to turn distance improvements into proposal weights. */
    @Option
    public PropMethod propMethod = PropMethod.SIMPLE_GREEDY;
    /** Log-weight magnitude for an improvement of exactly 1 under {@link PropMethod#TWO_PARAMETERS}. */
    @Option
    public double alpha = 3.0;
    /** Extra log-weight magnitude for improvements larger than 1 under {@link PropMethod#TWO_PARAMETERS}. */
    @Option
    public double beta = 1.0;
    /** If true, the two endpoints are ordered by a fair coin; otherwise by length, then lexicographically. */
    @Option
    public boolean useRandomFlip = true;

    public static void main(String[] args) {
        InformedProposal2.test2();
    }

    /**
     * Formats each key of {@code samples} as {@code [count : pid-<|hashCode| digits>]}.
     */
    static String printWeights(Counter<String> samples) {
        StringBuilder result = new StringBuilder();
        for (String key : samples) {
            // Drop the sign so negative hash codes print as plain digits.
            String pid = String.valueOf(key.hashCode()).replace("-", "");
            result.append("[" + samples.getCount(key) + " : pid-" + pid + "] ");
        }
        return result.toString();
    }

    /**
     * Manual experiment: forward-simulates K mutations from an approximately stationary
     * string, then runs the proposal from the start to the end string for a range of
     * {@link #proposalGreed} values, printing each proposed path.
     */
    private static void test2() {
        StringMutationModel model = new StringMutationModel();
        Random rand = new Random(3667L);
        String start = ForwardSimulator.approxStationarySampling(model, rand);
        int K = 5;
        String current = start;
        System.out.println("K=" + K);
        System.out.println();
        System.out.println(current);
        for (int k = 0; k < K; ++k) {
            Pair<Edit, Double> sampled = ForwardSimulator.next(current, model, rand);
            current = sampled.getFirst().newSeq;
            System.out.println(current);
        }
        System.out.println();
        String end = current;
        InformedProposal2 p = new InformedProposal2();
        p.setModel(model);
        p.proposalGreed = 3.0;
        while (p.proposalGreed < 15.0) {
            ProposalRandom pRand = new ProposalRandom(rand);
            System.out.println("Proposal greed:" + p.proposalGreed);
            List<Edit> edits = p.proposeEdits(pRand, start, end);
            if (edits != null) {
                // The proposal always begins at `start`; printing it directly avoids
                // indexing into an empty edit list when no edit was needed.
                System.out.println(start);
                for (Edit e : edits) {
                    System.out.println(e.newSeq);
                }
                System.out.println("nSteps=" + edits.size());
                System.out.println("logProp=" + pRand.getLogProbability());
            } else {
                System.out.println("[BAILED OUT]");
            }
            System.out.println();
            p.proposalGreed += 1.0;
        }
    }

    /** Sets the mutation model used for candidate edits, priors and likelihoods. */
    public void setModel(StringMutationModel stringModel) {
        this.stringModel = stringModel;
    }

    /**
     * Proposes edits one at a time from {@code start} until {@link #proposeEdit} signals a stop.
     *
     * @return the proposed edits in order (possibly empty), or {@code null} if
     *     {@link #maxSteps} edits were proposed without stopping
     */
    private List<Edit> proposeEdits(ProposalRandom pRand, String start, String target) {
        List<Edit> result = new ArrayList<Edit>();
        String currentStr = start;
        for (int i = 0; i < this.maxSteps; ++i) {
            Edit currentEdit = this.proposeEdit(pRand, currentStr, target);
            if (currentEdit == null) {
                return result;
            }
            result.add(currentEdit);
            currentStr = currentEdit.newSeq;
        }
        return null;
    }

    /**
     * Proposes a path between the two endpoints and splits it into two branches of lengths
     * {@code edge1} and {@code edge2}, meeting at a common "top" string.
     *
     * <p>The endpoints are first ordered (see {@link #useRandomFlip}). Edits are proposed from
     * the first to the second endpoint, and the total length {@code edge1 + edge2} is
     * partitioned among the intermediate strings. The path is then cut at the first edge's
     * length.
     *
     * @return the top string and the log weight
     *     {@code logPrior(top) - logPrior(bot1) - logPrior(bot2) + logLL(branch1) + logLL(branch2) - logProposal};
     *     {@code (null, -Infinity)} if the proposal bailed out after {@link #maxSteps} edits
     */
    public Pair<String, Double> proposeAndGetLogProb(Random rand, String end1, double edge1, String end2, double edge2) {
        ProposalRandom pRand = new ProposalRandom(rand);
        Pair<Pair<String, Double>, Pair<String, Double>> ordered = this.order(pRand, end1, edge1, end2, edge2);
        String first = ordered.getFirst().getFirst();
        double firstEdge = ordered.getFirst().getSecond();
        String second = ordered.getSecond().getFirst();
        double secondEdge = ordered.getSecond().getSecond();

        List<Edit> edits = this.proposeEdits(pRand, first, second);
        if (edits == null) {
            return Pair.<String, Double>makePair(null, Double.NEGATIVE_INFINITY);
        }
        List<Double> times = pRand.samplePartitions(edits.size(), firstEdge + secondEdge);
        List<Pair<String, Double>> series = this.series(second, times, edits);
        List<Pair<String, Double>> subSeries1 = new ArrayList<Pair<String, Double>>();
        List<Pair<String, Double>> subSeries2 = new ArrayList<Pair<String, Double>>();
        this.processTimes(series, firstEdge, subSeries1, subSeries2);

        // The split interval is duplicated into both branches, hence the "+ 1". Check sizes
        // before indexing so a degenerate split reports a clear error.
        if (subSeries1.isEmpty() || subSeries2.isEmpty()
                || subSeries1.size() + subSeries2.size() != 1 + series.size()) {
            throw new IllegalStateException("Inconsistent split of proposed path: branch sizes "
                    + subSeries1.size() + " + " + subSeries2.size() + " for " + series.size() + " intervals");
        }
        String topStr = subSeries1.get(0).getFirst();
        if (!topStr.equals(subSeries2.get(0).getFirst())) {
            throw new IllegalStateException("Branches of the proposed path do not share a top string");
        }
        String botStr1 = subSeries1.get(subSeries1.size() - 1).getFirst();
        String botStr2 = subSeries2.get(subSeries2.size() - 1).getFirst();
        double logPriorRatio = this.stringModel.logPrior(topStr)
                - this.stringModel.logPrior(botStr1)
                - this.stringModel.logPrior(botStr2);
        double logLL = this.stringModel.logLikelihood(subSeries1) + this.stringModel.logLikelihood(subSeries2);
        double w = logPriorRatio + logLL - pRand.getLogProbability();
        return Pair.makePair(topStr, w);
    }

    /**
     * Splits {@code series} (string, duration) at cumulative time {@code middle}.
     * Intervals ending at or before {@code middle} go to {@code subSeries1}; the interval
     * straddling {@code middle} is cut in two, with one piece going to each list; the rest go
     * to {@code subSeries2}. {@code subSeries1} is reversed at the end so that both lists start
     * at the split string.
     */
    private void processTimes(List<Pair<String, Double>> series, double middle, List<Pair<String, Double>> subSeries1, List<Pair<String, Double>> subSeries2) {
        double currentL = 0.0;
        double prevL = 0.0;
        boolean foundMidPt = false;
        for (Pair<String, Double> interval : series) {
            currentL += interval.getSecond();
            if (currentL > middle) {
                if (!foundMidPt) {
                    String curStr = interval.getFirst();
                    subSeries1.add(Pair.makePair(curStr, middle - prevL));
                    subSeries2.add(Pair.makePair(curStr, currentL - middle));
                    foundMidPt = true;
                } else {
                    subSeries2.add(interval);
                }
            } else {
                subSeries1.add(interval);
            }
            prevL = currentL;
        }
        Collections.reverse(subSeries1);
    }

    /**
     * Pairs each string on the path with its duration. Interval {@code i < edits.size()} holds
     * {@code edits.get(i).oldSeq}, and the final interval holds {@code endPt}.
     *
     * @throws IllegalArgumentException if {@code times.size() != edits.size() + 1}
     */
    private List<Pair<String, Double>> series(String endPt, List<Double> times, List<Edit> edits) {
        if (edits.size() + 1 != times.size()) {
            throw new IllegalArgumentException("Expected " + (edits.size() + 1) + " times for "
                    + edits.size() + " edits, got " + times.size());
        }
        List<Pair<String, Double>> result = new ArrayList<Pair<String, Double>>(times.size());
        for (int i = 0; i < times.size(); ++i) {
            String str = i < edits.size() ? edits.get(i).oldSeq : endPt;
            result.add(Pair.makePair(str, times.get(i)));
        }
        return result;
    }

    /**
     * Decides which endpoint the proposal starts from. With {@link #useRandomFlip}, a fair coin
     * is drawn from {@code pRand}, so the choice is included in its log probability. Otherwise
     * the order is deterministic: the longer string first, with ties broken so that the
     * lexicographically larger string comes first.
     *
     * @return (start endpoint, its edge), (end endpoint, its edge)
     */
    private Pair<Pair<String, Double>, Pair<String, Double>> order(ProposalRandom pRand, String end1, double edge1, String end2, double edge2) {
        Pair<String, Double> a = Pair.makePair(end1, edge1);
        Pair<String, Double> b = Pair.makePair(end2, edge2);
        boolean flip;
        if (this.useRandomFlip) {
            flip = pRand.sampleBern(0.5);
        } else if (end1.length() != end2.length()) {
            flip = end1.length() < end2.length();
        } else {
            flip = end1.compareTo(end2) < 0;
        }
        if (flip) {
            return Pair.makePair(b, a);
        }
        return Pair.makePair(a, b);
    }

    /**
     * Proposes the next edit from {@code current}. When {@code current} equals {@code target},
     * the proposal stops with probability {@link #proposalStopProbability}.
     *
     * @return the sampled candidate edit, or {@code null} to signal a stop
     * @throws IllegalStateException if {@link #setModel(StringMutationModel)} was not called
     */
    private Edit proposeEdit(ProposalRandom proposalRandom, String current, String target) {
        // Fail fast: without a model the resulting path cannot be scored anyway.
        if (this.stringModel == null) {
            throw new IllegalStateException("Need to call InformedProposal2.setModel() first");
        }
        if (current.equals(target) && proposalRandom.sampleBern(this.proposalStopProbability)) {
            return null;
        }
        List<Edit> candidates = this.stringModel.rates(current).getFirst();
        double[] proposalProbabilities = this.proposalProbabilities(current, target, candidates);
        int index = proposalRandom.sampleMultinomial(proposalProbabilities);
        return candidates.get(index);
    }

    /**
     * Computes normalized proposal weights for the candidate edits. The weights are based on
     * each candidate's decrease in {@code BayesRiskMinimizer.computeDist} to {@code target}.
     */
    private double[] proposalProbabilities(String current, String target, List<Edit> candidates) {
        double currentEditD = BayesRiskMinimizer.computeDist(current, target);
        double[] improvement = new double[candidates.size()];
        for (int i = 0; i < improvement.length; ++i) {
            double updatedEditD = BayesRiskMinimizer.computeDist(candidates.get(i).newSeq, target);
            improvement[i] = currentEditD - updatedEditD;
        }
        return this.proposalWeights(improvement);
    }

    /**
     * Maps distance improvements to weights according to {@link #propMethod}, then normalizes
     * them in place with {@code NumUtils.normalize}.
     */
    private double[] proposalWeights(double[] improvement) {
        double[] result;
        switch (this.propMethod) {
            case SIMPLE_GREEDY:
                result = this.simpleGreedyWeights(improvement);
                break;
            case TWO_PARAMETERS:
                result = this.twoParameterWeights(improvement);
                break;
            case STRATIFIED:
                result = stratifiedWeights(improvement);
                break;
            default:
                throw new IllegalStateException("Unsupported proposal method: " + this.propMethod);
        }
        NumUtils.normalize(result);
        return result;
    }

    /** Weight {@code exp(proposalGreed * improvement)}. */
    private double[] simpleGreedyWeights(double[] improvement) {
        double[] result = new double[improvement.length];
        for (int i = 0; i < improvement.length; ++i) {
            result[i] = Math.exp(this.proposalGreed * improvement[i]);
        }
        return result;
    }

    /**
     * Weight {@code exp(0)} for no change, {@code exp(±alpha)} for a change of magnitude 1, and
     * {@code exp(±(alpha + beta))} for larger changes. The sign follows the improvement.
     *
     * @throws IllegalStateException if alpha or beta is not positive, or if an improvement
     *     has magnitude strictly between 0 and 1
     */
    private double[] twoParameterWeights(double[] improvement) {
        if (this.alpha <= 0.0 || this.beta <= 0.0) {
            throw new IllegalStateException("alpha and beta must be positive (alpha=" + this.alpha
                    + ", beta=" + this.beta + ")");
        }
        double[] result = new double[improvement.length];
        for (int i = 0; i < improvement.length; ++i) {
            double abs = Math.abs(improvement[i]);
            double coef;
            if (abs == 0.0) {
                coef = 0.0;
            } else if (abs == 1.0) {
                coef = this.alpha * Math.signum(improvement[i]);
            } else if (abs > 1.0) {
                coef = (this.alpha + this.beta) * Math.signum(improvement[i]);
            } else {
                throw new IllegalStateException("Unexpected non-integral improvement: " + improvement[i]);
            }
            result[i] = Math.exp(coef);
        }
        return result;
    }

    /**
     * Gives improvement levels 3, 2 and 1 a total mass of 0.1, 0.2 and 0.4 respectively,
     * shared equally among the candidates at each level. Improvements are truncated to int.
     *
     * <p>NOTE(review): every other improvement (0, negative, or above 3) gets weight 0. If no
     * candidate falls in 1..3 (for example when current already equals target), all weights
     * are zero before normalization. Confirm this is intended.
     */
    private static double[] stratifiedWeights(double[] improvement) {
        Counter<Integer> counts = new Counter<Integer>();
        for (double imp : improvement) {
            counts.incrementCount((int) imp, 1.0);
        }
        Counter<Integer> probabs = new Counter<Integer>();
        double cProb = 0.4;
        for (int level = 3; level > 0; --level) {
            double cCount = counts.getCount(level);
            if (cCount != 0.0) {
                probabs.setCount(level, cProb / cCount);
            }
            cProb /= 2.0;
        }
        double[] result = new double[improvement.length];
        for (int i = 0; i < improvement.length; ++i) {
            result[i] = probabs.getCount((int) improvement[i]);
        }
        return result;
    }

    /** Strategies for turning distance improvements into proposal weights. */
    public static enum PropMethod {
        SIMPLE_GREEDY,
        TWO_PARAMETERS,
        STRATIFIED;

    }
}

