/*
 * Decompiled with CFR 0.152.
 */
package conifer.pip.simple;

import conifer.pip.simple.PIPString;
import conifer.ssm.BirthDeathExperiment;
import conifer.ssm.SimplePotentialExperiments;
import ev.poi.PoissonParameters;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import goblin.Taxon;
import java.util.HashMap;
import java.util.Random;
import ma.GreedyDecoder;
import ma.MSAPoset;
import nuts.math.Sampling;
import nuts.util.Counter;

public class PIPProcess
implements SimplePotentialExperiments.SparseProcess<PIPString>,
BirthDeathExperiment.Process<PIPString> {
    public final double lambda;
    public final double mu;
    public static final int MAX_FWD_SAMPLE_STEPS = 100000000;

    public PIPProcess(double lambda, double mu) {
        this.lambda = lambda;
        this.mu = mu;
    }

    @Override
    public double holdRate(PIPString x) {
        return this.rates(x).totalCount();
    }

    @Override
    public double transitionProbability(PIPString x, PIPString y) {
        Counter<PIPString> rates = this.rates(x);
        rates.normalize();
        double result = rates.getCount(y);
        if (result == 0.0) {
            System.out.println("WARNING!");
        }
        return result;
    }

    @Override
    public PIPString sample(Random rand, PIPString x) {
        Counter<PIPString> rates = this.rates(x);
        rates.normalize();
        return Sampling.sampleCounter(rates, rand);
    }

    public MSAPoset sampleStationary(Random rand) {
        return this.createInitMSA(PoissonParameters.repeat(PoissonParameters.star, (int)SampleUtils.samplePoisson(rand, this.lambda / this.mu)));
    }

    public MSAPoset createInitMSA(String str) {
        Taxon firstId = PIPProcess.indexedTaxon(0);
        HashMap<Taxon, String> seqns = new HashMap<Taxon, String>();
        seqns.put(firstId, str);
        return new MSAPoset(seqns);
    }

    public static Taxon indexedTaxon(int idx) {
        String idxStr = "" + idx;
        int initLen = idxStr.length();
        for (int i = 0; i < 4 - initLen; ++i) {
            idxStr = '0' + idxStr;
        }
        return new Taxon("seq-" + idxStr);
    }

    @Override
    public MSAPoset sample(Random rand, double branchLength) {
        MSAPoset init = this.sampleStationary(rand);
        return this.sample(rand, init, branchLength);
    }

    public MSAPoset sample(Random rand, MSAPoset init, double branchLength) {
        Pair<MSAPoset, Double> sampled;
        if (branchLength <= 0.0) {
            throw new RuntimeException();
        }
        MSAPoset curMSA = init;
        double lengthConsumed = 0.0;
        for (int i = 0; i < 100000000 && !((lengthConsumed += (sampled = this.sample(rand, curMSA)).getSecond().doubleValue()) > branchLength); ++i) {
            curMSA = sampled.getFirst();
        }
        return curMSA;
    }

    public static MSAPoset keepOnlyEndPts(MSAPoset msa, Taxon first, Taxon second) {
        Taxon tF = PIPProcess.indexedTaxon(0);
        Taxon tL = PIPProcess.indexedTaxon(msa.sequences().size() - 1);
        HashMap<Taxon, String> newSeqns = new HashMap<Taxon, String>();
        newSeqns.put(first, msa.sequences().get(tF));
        newSeqns.put(second, msa.sequences().get(tL));
        MSAPoset result = new MSAPoset(newSeqns);
        for (MSAPoset.Column c : msa.columns()) {
            if (!c.getPoints().containsKey(tF) || !c.getPoints().containsKey(tL) || result.tryAdding(new GreedyDecoder.Edge(c.getPoints().get(tF), c.getPoints().get(tL), first, second))) continue;
            throw new RuntimeException();
        }
        return result;
    }

    @Override
    public Pair<MSAPoset, Double> sample(Random rand, MSAPoset current) {
        int i;
        Taxon lastId = PIPProcess.indexedTaxon(current.sequences().size() - 1);
        String lastSeq = current.sequences().get(lastId);
        int len = lastSeq.length();
        double rate = this.lambda + this.mu * (double)len;
        double insPr = this.lambda / rate;
        double time = Sampling.sampleExponential(rand, 1.0 / rate);
        boolean isIns = Sampling.sampleBern(insPr, rand);
        HashMap<Taxon, String> seqns = new HashMap<Taxon, String>(current.sequences());
        String newSeq = PoissonParameters.repeat(PoissonParameters.star, len + (isIns ? 1 : -1));
        Taxon newId = PIPProcess.indexedTaxon(current.sequences().size());
        seqns.put(newId, newSeq);
        int pos = rand.nextInt(len + (isIns ? 1 : 0));
        MSAPoset newMSA = new MSAPoset(seqns);
        for (MSAPoset.Column c : current.columns()) {
            if (newMSA.tryAdding(c)) continue;
            throw new RuntimeException();
        }
        for (i = 0; i < pos; ++i) {
            if (newMSA.tryAdding(new GreedyDecoder.Edge(i, i, lastId, newId))) continue;
            throw new RuntimeException();
        }
        if (isIns) {
            for (i = pos; i < len; ++i) {
                if (newMSA.tryAdding(new GreedyDecoder.Edge(i, i + 1, lastId, newId))) continue;
                throw new RuntimeException();
            }
        } else {
            for (i = pos + 1; i < len; ++i) {
                if (newMSA.tryAdding(new GreedyDecoder.Edge(i, i - 1, lastId, newId))) continue;
                throw new RuntimeException();
            }
        }
        return Pair.makePair(newMSA, time);
    }

    @Override
    public Counter<PIPString> rates(PIPString point) {
        Counter<PIPString> rates = new Counter<PIPString>();
        double nInsPoints = point.characters.size() + 1;
        int i = 0;
        while ((double)i < nInsPoints) {
            rates.incrementCount(new PIPString(point.characters.subList(0, i), 1, point.characters.subList(i, point.characters.size())), this.lambda / nInsPoints);
            ++i;
        }
        for (i = 0; i < point.characters.size(); ++i) {
            rates.incrementCount(new PIPString(point.characters.subList(0, i), point.characters.subList(i + 1, point.characters.size())), this.mu);
        }
        return rates;
    }
}

