/*
 * Decompiled with CFR 0.152.
 */
package conifer.ssm;

import conifer.pip.simple.PIPString;
import conifer.ssm.BirthDeathExperiment;
import conifer.ssm.ImportanceSampler;
import fig.basic.Option;
import fig.basic.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.math.ProposalRandom;
import nuts.math.RateMtxUtils;
import nuts.util.Counter;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.junit.Assert;
import org.junit.Test;

/**
 * Experiments with potential-guided proposals used by an {@link ImportanceSampler} to
 * sample jump paths of a sparse process between two fixed end points.
 *
 * <p>A {@link Potential} scores how far a state is from a target state. The proposal
 * ({@link #propose}) walks the jump chain of the process, tilting each transition
 * distribution towards moves that decrease the potential ({@link #potPropDistort}) and
 * offering a "stop here" option whenever the target is reachable ({@link #addTarget}).
 */
public class SimplePotentialExperiments {
    /** Maximum number of jumps drawn by {@link #forwardSample} before giving up. */
    private static final int MAX_SAMPLE_TRIES = 100000;

    /** Potential on integer states: the absolute difference {@code |proposed - target|}. */
    public static final Potential<Integer> INTEGER_POTENTIAL = new Potential<Integer>() {

        @Override
        public double get(Integer proposed, Integer target) {
            return Math.abs(proposed - target);
        }
    };

    /** Default marker added to transition distributions to represent "stop at the target". */
    public static final String SPECIAL_SYMBOL = "___TARGET___";

    /** Maximum number of steps taken by {@link #propose} before giving up. */
    private static final int MAX_PROPOSE_ATTEMPTS = 1000000;

    /**
     * Compares importance-sampling estimates with the exact transition probability of a
     * birth-death process.
     *
     * <p>For each of 100 forward-simulated paths it prints the exact entry of
     * {@code expm(Q * t)} for the path's end points. It then prints the estimates returned by
     * {@link ImportanceSampler#estimateZ} for particle counts 10, 100, ..., 100000, with
     * {@link PotPropOptions#automatic} both on and off. The bracketed flag reports whether the
     * simulated path's jump sequence is among the sampled keys.
     */
    public static void main(String[] args) {
        BirthDeathExperiment.SimpleBD bd = new BirthDeathExperiment.SimpleBD(2.0, 3.0, 100);
        Indexer<Integer> indexer = bd.indexer();
        double t = 1.0;
        double[][] Q = SimplePotentialExperiments.getRateMatrix(bd, indexer);
        double[][] exp = MatrixFunctions.expm((DoubleMatrix) new DoubleMatrix(Q).mul(t)).toArray2();
        System.out.println();
        Random rand = new Random(4234234L);
        for (int i = 0; i < 100; ++i) {
            Path<Integer> truePath = SimplePotentialExperiments.forwardSample(rand, bd, rand.nextInt(10), t);
            BirthDeathExperiment.Datum<Integer> datum = truePath.toEndPointDatum();
            PotProposal<Integer> proposal = new PotProposal<Integer>(bd, INTEGER_POTENTIAL, new PotPropOptions());
            ImportanceSampler<Integer> is = new ImportanceSampler<Integer>(proposal, bd);
            // getRateMatrix lays out rows/columns by indexer position, so translate the
            // end-point states through the same indexer rather than using their raw values.
            int startIdx = indexer.o2i((Integer) datum.startPoint);
            int endIdx = indexer.o2i((Integer) datum.endPoint);
            System.out.println("exact = " + exp[startIdx][endIdx]);
            for (boolean automatic : new boolean[]{true, false}) {
                System.out.println("Automatic = " + automatic);
                proposal.options.automatic = automatic;
                for (int nPart = 10; nPart < 100001; nPart *= 10) {
                    is.nParticles = nPart;
                    Counter samples = is.sample((Integer) datum.startPoint, (Integer) datum.endPoint, datum.length);
                    System.out.print("approx(" + nPart + ") = " + is.estimateZ(samples));
                    System.out.println(" [" + samples.keySet().contains(truePath.jumps()) + "]");
                }
            }
            System.out.println("truePath: " + truePath.jumps());
            System.out.println();
        }
    }

    /**
     * Builds a dense rate matrix for the states enumerated by {@code index}.
     *
     * <p>Row {@code i} holds the off-diagonal rates {@code p.rates(index.i2o(i))}, placed in the
     * columns given by {@code index.o2i}. The diagonal is then filled by
     * {@link RateMtxUtils#fillRateMatrixDiagonalEntries}.
     *
     * @param p the process supplying the outgoing rates of each state
     * @param index the state enumeration defining row/column positions
     * @return a {@code size x size} rate matrix, where {@code size = index.size()}
     * @throws RuntimeException if the process reports a self-transition rate
     */
    public static <S> double[][] getRateMatrix(SparseProcess<S> p, Indexer<S> index) {
        int size = index.size();
        double[][] result = new double[size][size];
        for (int i = 0; i < size; ++i) {
            S x = index.i2o(i);
            Counter<S> rates = p.rates(x);
            for (S y : rates.keySet()) {
                if (x.equals(y)) {
                    // Diagonal entries are computed below; the process must not supply them.
                    throw new RuntimeException("Process reported a self-transition rate for state " + x);
                }
                result[i][index.o2i(y)] = rates.getCount(y);
            }
        }
        RateMtxUtils.fillRateMatrixDiagonalEntries(result);
        return result;
    }

    /**
     * Forward-simulates process {@code p} from {@code startPoint} for a time span {@code deltaT}.
     *
     * <p>Each path entry pairs a visited state with its holding time. The last holding time is
     * truncated so that the holding times add up to {@code deltaT}.
     *
     * @param rand randomness source passed to {@link BirthDeathExperiment#sampleJumpWaitingTime}
     * @param p the process to simulate
     * @param startPoint the initial state
     * @param deltaT the total simulated time
     * @return the simulated path (always contains at least one entry)
     * @throws IllegalStateException if more than {@code MAX_SAMPLE_TRIES} jumps are needed
     */
    public static <S> Path<S> forwardSample(Random rand, BirthDeathExperiment.Process<S> p, S startPoint, double deltaT) {
        List<Pair<S, Double>> result = new ArrayList<Pair<S, Double>>();
        double elapsed = 0.0;
        S current = startPoint;
        for (int i = 0; i < MAX_SAMPLE_TRIES; ++i) {
            Pair<S, Double> sample = BirthDeathExperiment.sampleJumpWaitingTime(rand, current, p);
            double waiting = sample.getSecond();
            elapsed += waiting;
            if (elapsed > deltaT) {
                // Clip the final holding time so the path ends exactly at deltaT.
                result.add(Pair.makePair(current, waiting - (elapsed - deltaT)));
                return new Path<S>(result);
            }
            result.add(Pair.makePair(current, waiting));
            current = sample.getFirst();
        }
        throw new IllegalStateException("forwardSample did not reach time " + deltaT + " within " + MAX_SAMPLE_TRIES + " jumps");
    }

    /** Checks {@link PIPPotential} on matching and mismatching numbers of zeroes. */
    @Test
    public void testPIPotential() {
        PIPPotential pot = new PIPPotential();
        PIPString current = new PIPString("0 -1 -1 0 -1 0");
        PIPString target = new PIPString("0       0  1 0");
        Assert.assertEquals(4.0, pot.get(current, target), 0.0);
        // Fewer zeroes in the proposal than in the target: incompatible.
        target = new PIPString("0 0 1 0 0");
        Assert.assertTrue(Double.isInfinite(pot.get(current, target)));
        current = new PIPString("0 -1 -1 0 -1 0 -1");
        target = new PIPString("0       0  1 0 1 1");
        Assert.assertEquals(7.0, pot.get(current, target), 0.0);
        // More zeroes in the proposal than in the target: also incompatible (used to throw).
        current = new PIPString("0 0 1 0 0");
        target = new PIPString("0 0 1 0");
        Assert.assertTrue(Double.isInfinite(pot.get(current, target)));
    }

    /**
     * Proposes a jump sequence from {@code firstEndPoint} that ends at {@code lastEndPoint}.
     *
     * <p>If the walk starts at the target, it stops immediately with probability {@code stopPr}.
     * Otherwise it repeatedly draws the next state from the potential-tilted jump distribution
     * built by {@link #buildPotProposal}. When {@code specialSymbol} is drawn, the walk appends
     * {@code lastEndPoint} and stops.
     *
     * <p>All random draws go through {@code pRand}. Presumably it accumulates the log-probability
     * of the proposal, which {@link PotProposal} reads via {@code getLogProbability()}; confirm
     * against {@code ProposalRandom}.
     *
     * @return the visited states, starting with {@code firstEndPoint} and ending with {@code lastEndPoint}
     * @throws IllegalStateException if the target is not reached within {@code MAX_PROPOSE_ATTEMPTS} steps
     */
    public static <S> List<S> propose(SparseProcess<S> process, Potential<S> pot, ProposalRandom pRand, S firstEndPoint, S lastEndPoint, Object specialSymbol, double greed, double stopPr) {
        List<S> result = new ArrayList<S>();
        S current = firstEndPoint;
        result.add(current);
        if (current.equals(lastEndPoint) && pRand.sampleBern(stopPr)) {
            return result;
        }
        for (int step = 0; step < MAX_PROPOSE_ATTEMPTS; ++step) {
            // NOTE(review): the Counter returned by rates() is normalized and edited in place;
            // this assumes the process hands out a fresh Counter on every call — verify.
            Counter<S> trPr = process.rates(current);
            trPr.normalize();
            SimplePotentialExperiments.buildPotProposal(pot, trPr, current, lastEndPoint, pRand.rand, specialSymbol, greed, stopPr);
            S next = pRand.sampleMultinomial(trPr);
            if (next.equals(specialSymbol)) {
                result.add(lastEndPoint);
                return result;
            }
            result.add(next);
            current = next;
        }
        throw new IllegalStateException("propose did not reach " + lastEndPoint + " within " + MAX_PROPOSE_ATTEMPTS + " steps");
    }

    /**
     * Turns normalized transition probabilities into the proposal distribution, in place.
     * First tilts them towards the target ({@link #potPropDistort}), then adds the stop
     * symbol ({@link #addTarget}).
     */
    public static <S> void buildPotProposal(Potential<S> pot, Counter<S> transitionPrs, S current, S target, Random rand, Object specialSymbol, double greed, double stopPr) {
        SimplePotentialExperiments.potPropDistort(transitionPrs, pot, current, target, greed);
        SimplePotentialExperiments.addTarget(transitionPrs, target, specialSymbol, stopPr);
    }

    /**
     * Splits the probability mass currently on {@code target} between {@code specialSymbol}
     * (fraction {@code stopPr}) and {@code target} itself (fraction {@code 1 - stopPr}).
     * Does nothing when {@code target} has zero mass. The total mass is unchanged.
     *
     * @throws IllegalArgumentException if {@code stopPr} is rejected by {@code MathUtils.isProb}
     * @throws IllegalStateException if {@code specialSymbol} is already a key of {@code transitionPrs}
     */
    public static <S> void addTarget(Counter transitionPrs, S target, Object specialSymbol, double stopPr) {
        if (!MathUtils.isProb(stopPr)) {
            throw new IllegalArgumentException("stopPr must be a probability, got " + stopPr);
        }
        if (transitionPrs.keySet().contains(specialSymbol)) {
            throw new IllegalStateException("Transition distribution already contains the special symbol " + specialSymbol);
        }
        double curPr = transitionPrs.getCount(target);
        if (curPr == 0.0) {
            return;
        }
        transitionPrs.setCount(specialSymbol, curPr * stopPr);
        transitionPrs.setCount(target, curPr * (1.0 - stopPr));
    }

    /**
     * Tilts a normalized transition distribution towards states closer to {@code target}, in place.
     *
     * <p>Each successor's potential change {@code delta = pot(key, target) - pot(current, target)}
     * is classified as follows:
     * <ul>
     *   <li>infinite: the entry is set to zero;</li>
     *   <li>{@code -1}: a "good" move;</li>
     *   <li>{@code 0} or {@code +1}: a "bad" move;</li>
     *   <li>anything else: an error.</li>
     * </ul>
     * If both good and bad moves carry mass, the good moves jointly receive
     * {@code max(greed, pGood)} and the bad moves the remainder, each group keeping its internal
     * proportions. Otherwise the distribution is simply renormalized.
     *
     * @throws IllegalArgumentException if {@code current}/{@code target} is null or {@code greed}
     *     is outside {@code [0.5, 1]}
     * @throws RuntimeException if a finite potential change other than -1, 0 or +1 is seen
     */
    public static <S> void potPropDistort(Counter<S> transitionPrs, Potential<S> pot, S current, S target, double greed) {
        // Validate arguments before they are passed to the potential.
        if (current == null || target == null) {
            throw new IllegalArgumentException("current and target must be non-null");
        }
        if (greed < 0.5 || greed > 1.0) {
            throw new IllegalArgumentException("greed must be in [0.5, 1], got " + greed);
        }
        MathUtils.checkClose(1.0, transitionPrs.totalCount());
        double initialPotential = pot.get(current, target);
        double pGood = 0.0;
        double pBad = 0.0;
        // NOTE(review): setCount on existing keys during keySet() iteration assumes Counter
        // does not structurally modify its map when overwriting a key — verify.
        for (S key : transitionPrs.keySet()) {
            double delta = pot.get(key, target) - initialPotential;
            double pr = transitionPrs.getCount(key);
            if (Double.isInfinite(delta)) {
                transitionPrs.setCount(key, 0.0);
            } else if (delta == -1.0) {
                pGood += pr;
            } else if (delta == 1.0 || delta == 0.0) {
                pBad += pr;
            } else {
                throw new RuntimeException("\nd(" + key + "," + target + ") = " + pot.get(key, target) + "\nd(" + current + "," + target + ") = " + pot.get(current, target));
            }
        }
        if (pGood == 0.0 || pBad == 0.0) {
            // Nothing to rebalance. NOTE(review): if every entry was zeroed above, this
            // normalizes an all-zero Counter; behavior depends on Counter.normalize.
            transitionPrs.normalize();
            return;
        }
        double alpha = Math.max(greed, pGood);
        for (S key : transitionPrs.keySet()) {
            double delta = pot.get(key, target) - initialPotential;
            double pr = transitionPrs.getCount(key);
            double newValue = Double.isInfinite(delta) ? 0.0 : (delta == -1.0 ? alpha / pGood : (1.0 - alpha) / pBad) * pr;
            transitionPrs.setCount(key, newValue);
        }
        transitionPrs.normalize();
    }

    /**
     * A {@link BirthDeathExperiment.Proposal} backed by {@link SimplePotentialExperiments#propose}.
     * The returned weight is {@code exp(pRand.getLogProbability())} after the proposal.
     */
    public static class PotProposal<S>
    implements BirthDeathExperiment.Proposal<S> {
        private final SparseProcess<S> process;
        private final Potential<S> potential;
        private final PotPropOptions options;

        public PotProposal(SparseProcess<S> process, Potential<S> potential, PotPropOptions options) {
            this.process = process;
            this.potential = potential;
            this.options = options;
        }

        /**
         * Proposes a path from {@code x} to {@code y}. The argument {@code t} is not used by
         * this proposal.
         *
         * <p>When {@code options.automatic} is set, {@code greed} and {@code stopPr} are redrawn
         * from {@code rand} via {@link PotPropOptions#randPr} for each proposal. Those draws are
         * made directly on {@code rand}, not through the {@code ProposalRandom} wrapper.
         */
        @Override
        public Pair<List<S>, Double> propose(Random rand, S x, S y, double t) {
            ProposalRandom pRand = new ProposalRandom(rand);
            double greed = this.options.greed;
            double stopPr = this.options.stopPr;
            if (this.options.automatic) {
                greed = this.options.randPr(rand);
                stopPr = this.options.randPr(rand);
            }
            List<S> proposed = SimplePotentialExperiments.propose(this.process, this.potential, pRand, x, y, this.options.specialSymbol, greed, stopPr);
            return Pair.makePair(proposed, Math.exp(pRand.getLogProbability()));
        }
    }

    /**
     * Potential between PIP strings made of the characters {@code 0}, {@code 1} and {@code -1}.
     *
     * <p>The potential is infinite unless both strings contain the same number of zeroes.
     * Otherwise it adds:
     * <ul>
     *   <li>one for every {@code -1} in {@code proposed};</li>
     *   <li>for every zero-delimited segment, the absolute difference between the number of
     *       {@code 1}s in {@code proposed} and the corresponding entry of {@code target.plusses()}.</li>
     * </ul>
     */
    public static class PIPPotential
    implements Potential<PIPString> {
        @Override
        public double get(PIPString proposed, PIPString target) {
            int targetNZero = target.zeroes();
            int[] targetPlusses = target.plusses();
            int curNPlus = 0;
            int interZeroIdx = 0;
            int pot = 0;
            for (int curChar : proposed.characters) {
                if (curChar == 0) {
                    // Once proposed has more zeroes than the target there is no matching
                    // segment; skip the lookup and keep scanning so invalid characters are
                    // still rejected. The result is infinite (checked after the loop).
                    if (interZeroIdx < targetPlusses.length) {
                        pot += Math.abs(targetPlusses[interZeroIdx] - curNPlus);
                    }
                    curNPlus = 0;
                    ++interZeroIdx;
                } else if (curChar == 1) {
                    ++curNPlus;
                } else if (curChar == -1) {
                    ++pot;
                } else {
                    throw new RuntimeException("Unexpected PIP character: " + curChar);
                }
            }
            // Check compatibility before the trailing-segment lookup, which would otherwise
            // index past targetPlusses when proposed has extra zeroes.
            if (interZeroIdx != targetNZero) {
                return Double.POSITIVE_INFINITY;
            }
            pot += Math.abs(targetPlusses[interZeroIdx] - curNPlus);
            return pot;
        }
    }

    /** Tuning options for {@link PotProposal}. */
    public static class PotPropOptions {
        /** When true, {@link #greed} and {@link #stopPr} are redrawn via {@link #randPr} per proposal. */
        @Option
        public boolean automatic = true;
        /** Fraction of the target's mass moved onto the stop symbol (see {@code addTarget}). */
        @Option
        public double stopPr = 0.95;
        /** Minimum total probability given to potential-decreasing moves; must lie in [0.5, 1]. */
        @Option
        public double greed = 0.6666666666666666;
        /** Marker used to represent "stop at the target" in transition distributions. */
        @Option
        public String specialSymbol = SPECIAL_SYMBOL;

        /**
         * Returns {@code 1 - 0.5^k} with {@code k} drawn uniformly from {2, 3, 4, 5}, i.e. one of
         * 0.75, 0.875, 0.9375 or 0.96875.
         */
        public double randPr(Random rand) {
            double rInt = rand.nextInt(4) + 2;
            return 1.0 - Math.pow(0.5, rInt);
        }
    }

    /** Scores the distance from a proposed state to a target state; may be infinite. */
    public static interface Potential<S> {
        public double get(S var1, S var2);
    }

    /** A sampled path: a list of (state, holding time) pairs plus its total duration. */
    public static class Path<S> {
        private final List<Pair<S, Double>> path;
        private final double length;

        private Path(List<Pair<S, Double>> path) {
            this.path = path;
            double total = 0.0;
            for (Pair<S, Double> item : path) {
                total += item.getSecond().doubleValue();
            }
            this.length = total;
        }

        /** Returns the first state, the last state and the total duration of this path. */
        public BirthDeathExperiment.Datum<S> toEndPointDatum() {
            return new BirthDeathExperiment.Datum<S>(this.path.get(0).getFirst(), this.path.get(this.path.size() - 1).getFirst(), this.length);
        }

        @Override
        public String toString() {
            return "Path [path=" + this.path + ", length=" + this.length + "]";
        }

        /** Returns the sequence of visited states, without holding times. */
        public List<S> jumps() {
            List<S> result = new ArrayList<S>();
            for (Pair<S, Double> pair : this.path) {
                result.add(pair.getFirst());
            }
            return result;
        }
    }

    /** A process given by the sparse outgoing jump rates of each state. */
    public static interface SparseProcess<S> {
        public Counter<S> rates(S var1);
    }
}

