/*
 * Decompiled with CFR 0.152.
 */
package conifer.pip;

import conifer.pip.LinearizedAlignment;
import conifer.pip.PIPLikelihoodCalculator;
import conifer.pip.simple.PIPProcess;
import ev.poi.PoissonParameters;
import fig.basic.LogInfo;
import fig.basic.Pair;
import fig.basic.UnorderedPair;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.GreedyDecoder;
import ma.MSAPoset;
import ma.RateMatrixLoader;
import nuts.math.Sampling;
import nuts.util.Counter;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import org.junit.Assert;
import org.junit.Test;
import pty.RandomRootedTrees;
import pty.RandomUnrootedTrees;
import pty.RootedTree;
import pty.UnrootedTree;

public class PIPLikelihoodTests {
    private static final int MAX_BOOTSTRAP_REJ = 10000;

    /**
     * Command-line entry point: runs both statistical checks and then the speed benchmark,
     * without going through a JUnit runner.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final int benchmarkColumns = 10000;
        final int benchmarkTaxa = 100;

        PIPLikelihoodTests suite = new PIPLikelihoodTests();
        suite.testReversibility();
        suite.testMonteCarlo();

        Random benchmarkRand = new Random(1L);
        testSpeed(benchmarkColumns, benchmarkTaxa, benchmarkRand);
    }

    /**
     * Checks that the data log probability is invariant under re-rooting of the tree.
     *
     * <p>Ten random unrooted trees over ten taxa are sampled, each paired with a random
     * alignment and randomly drawn Poisson parameters. Every tree is re-rooted on ten randomly
     * picked edges (at a random position along the edge); the log probability obtained for the
     * first rooting serves as the reference and each later rooting must agree with it up to
     * {@code MathUtils.threshold}.
     */
    @Test
    public void testReversibility() {
        final int nTaxa = 10;
        final int nTrees = 10;
        final int nRootings = 10;

        LogInfo.track("Test: reversibility (invariance under rerooting)");
        Random rand = new Random(1L);
        Indexer<Character> indexer = RateMatrixLoader.rnaIndexer();
        for (int treeIdx = 0; treeIdx < nTrees; treeIdx++) {
            LogInfo.track("\t" + (treeIdx + 1) + "/" + nTrees);

            // The order of the calls below fixes how the seeded random stream is consumed.
            UnrootedTree ut = RandomUnrootedTrees.sampleExponentialBranchesUniformTopology(rand, nTaxa, 10.0);
            ArrayList<Taxon> taxa = new ArrayList<Taxon>(ut.leavesSet());
            PoissonParameters pip = new PoissonParameters(RateMatrixLoader.rnaIndexer(), RateMatrixLoader.k2p(), rand.nextDouble() * 10.0, rand.nextDouble() * 10.0);
            MSAPoset msa = PIPLikelihoodTests.randomMSA(taxa, indexer, rand, 5, 10);
            LinearizedAlignment lia = new LinearizedAlignment(msa);
            List<UnorderedPair<Taxon, Taxon>> edges = ut.edges();

            boolean haveReference = false;
            double reference = Double.POSITIVE_INFINITY;
            for (int rooting = 0; rooting < nRootings; rooting++) {
                UnorderedPair<Taxon, Taxon> edge = edges.get(rand.nextInt(edges.size()));
                RootedTree.RootingInfo rootingInfo = new RootedTree.RootingInfo(edge.getFirst(), edge.getSecond(), new Taxon("___ROOT___"), rand.nextDouble());
                RootedTree rooted = ut.reRoot(rootingInfo);
                PIPLikelihoodCalculator calculator = new PIPLikelihoodCalculator(pip, lia, rooted);
                double logPr = calculator.computeDataLogProbabilityGivenTree();
                if (!haveReference) {
                    // The first rooting only establishes the value the others are compared to.
                    reference = logPr;
                    haveReference = true;
                } else {
                    LogInfo.logs("assertClose(" + logPr + "," + reference + ")");
                    Assert.assertEquals(logPr, reference, MathUtils.threshold);
                }
            }
            if (!haveReference) {
                throw new RuntimeException();
            }
            LogInfo.end_track();
        }
        LogInfo.end_track();
    }

    /**
     * Compares the analytic alignment probabilities with a Monte Carlo estimate obtained by
     * repeatedly sampling two-taxon alignments from {@link PIPProcess}.
     *
     * <p>The check runs twice. In the first trial the calculator uses the same parameters as
     * the sampler, and a bootstrap p-value below the significance level is reported as a
     * failure. In the second trial the calculator's {@code mu} is deliberately offset by 0.1;
     * there a p-value <em>above</em> the significance level is the failure, because it would
     * mean the check cannot detect a wrong likelihood.
     *
     * @throws RuntimeException if the sampled alignments leave too much (or negative) analytic
     *         probability mass unaccounted for, if fewer than four alignments have enough mass
     *         to form the comparison region, or if a p-value falls on the wrong side of the
     *         significance level for its trial
     */
    @Test
    public void testMonteCarlo() {
        final int nPriorSamples = 10000;
        final int nBootstraps = 10000;
        final double lambda = 1.0;
        final double mu = 0.5;
        final double branchLength = 0.5;
        final double injectedError = 0.1;
        // Three independent thresholds that happen to share the same value.
        final double maxTailMass = 0.05;
        final double minCoreMass = 0.05;
        final double significance = 0.05;
        final int minCoreSize = 4;

        for (int trial = 0; trial < 2; ++trial) {
            boolean testTest = trial == 1;
            if (testTest) {
                LogInfo.track("Test: making sure test is sensitive to errors");
            } else {
                LogInfo.track("Test: data likelihoods should be close to the MC averages of repeated generation from prior");
            }
            Random rand = new Random(1L);
            Taxon ta = new Taxon("A");
            Taxon tb = new Taxon("B");
            double error = testTest ? injectedError : 0.0;
            // Only the calculator sees the perturbed mu; the sampler always uses the true one.
            PoissonParameters pip = PoissonParameters.simplePIP(lambda, mu + error);
            PIPProcess process = new PIPProcess(lambda, mu);

            // Empirical distribution over alignments, keyed by their list of column maps; one
            // representative alignment is kept per key so that it can be scored analytically.
            Map<List<Map<Taxon, Integer>>, MSAPoset> backPtrs = new HashMap<List<Map<Taxon, Integer>>, MSAPoset>();
            Counter<List<Map<Taxon, Integer>>> mcDistribution = new Counter<List<Map<Taxon, Integer>>>();
            for (int i = 0; i < nPriorSamples; ++i) {
                MSAPoset align = PIPProcess.keepOnlyEndPts(process.sample(rand, branchLength), ta, tb);
                List<Map<Taxon, Integer>> key = PIPLikelihoodTests.extractListOfColMaps(align);
                backPtrs.put(key, align);
                mcDistribution.incrementCount(key, 1.0);
            }

            // Two leaves, each half a branch length away from the root.
            String s = "(A:" + branchLength / 2.0 + ",B:" + branchLength / 2.0 + ");";
            RootedTree cur = RootedTree.Util.fromNewickString(s);
            Counter<List<Map<Taxon, Integer>>> analyticValues = new Counter<List<Map<Taxon, Integer>>>();
            LogInfo.track("Estimates");
            LogInfo.logsForce("MC\tExact");
            for (List<Map<Taxon, Integer>> key : mcDistribution) {
                double mc = mcDistribution.getCount(key) / mcDistribution.totalCount();
                LinearizedAlignment lia = new LinearizedAlignment(backPtrs.get(key));
                PIPLikelihoodCalculator mf = new PIPLikelihoodCalculator(pip, lia, cur);
                double mfPR = Math.exp(mf.computeDataLogProbabilityGivenTree());
                analyticValues.setCount(key, mfPR);
                LogInfo.logs("" + mc + "\t" + mfPR);
            }
            LogInfo.end_track();

            // Analytic mass of every alignment that was never sampled.
            double unseenTailMass = 1.0 - analyticValues.totalCount();
            LogInfo.logsForce("Tail mass = " + unseenTailMass);
            if (unseenTailMass < 0.0 || unseenTailMass > maxTailMass) {
                throw new RuntimeException("Run for more iteration");
            }
            // The null key stands for "any alignment that was not observed".
            analyticValues.setCount(null, unseenTailMass);

            // Only alignments with enough analytic mass take part in the distance computation.
            Set<List<Map<Taxon, Integer>>> toKeep = new HashSet<List<Map<Taxon, Integer>>>();
            for (List<Map<Taxon, Integer>> key : analyticValues) {
                if (analyticValues.getCount(key) > minCoreMass) {
                    toKeep.add(key);
                }
            }
            if (toKeep.size() < minCoreSize) {
                throw new RuntimeException("Only " + toKeep.size() + " alignments have analytic probability above " + minCoreMass + "; at least " + minCoreSize + " are needed");
            }
            analyticValues.normalize();

            if (testTest) {
                LogInfo.logsForce("In this phase, we introduced a small error intentionally to make sure we reject the null in this case. [bootstrapping...]");
            } else {
                LogInfo.logsForce("In this phase, low p-value would be of concern (null hyp H0 = analytic and MC estimates coincide) [bootstrapping...]");
            }
            double pValue = PIPLikelihoodTests.bootstrapPValue(analyticValues, mcDistribution, rand, nBootstraps, toKeep);
            LogInfo.logsForce("p-value = " + pValue);
            if (testTest) {
                if (pValue > significance) {
                    throw new RuntimeException("Test seems broken");
                }
            } else if (pValue < significance) {
                throw new RuntimeException("There might be a problem here, investigate---but try re-run with different seed first if value is not tiny");
            }
            LogInfo.end_track();
        }
    }

    /**
     * Computes half the summed absolute difference between two distributions, restricted to
     * the given keys.
     *
     * <p>Both counters are first checked (via {@code MathUtils.checkClose}) to have a total
     * count close to one. Keys outside {@code sensitiveRegion} do not contribute, so the result
     * is the total variation distance only when the region covers the whole support.
     *
     * @param c1 first normalized distribution
     * @param c2 second normalized distribution
     * @param sensitiveRegion keys over which the difference is accumulated
     * @return {@code 0.5 * sum_{k in sensitiveRegion} |c1(k) - c2(k)|}
     */
    public static <S> double totalVariationDistance(Counter<S> c1, Counter<S> c2, Set<S> sensitiveRegion) {
        MathUtils.checkClose(1.0, c1.totalCount());
        MathUtils.checkClose(1.0, c2.totalCount());

        double absoluteGap = 0.0;
        for (S outcome : sensitiveRegion) {
            double first = c1.getCount(outcome);
            double second = c2.getCount(outcome);
            absoluteGap += Math.abs(first - second);
        }
        return absoluteGap / 2.0;
    }

    /**
     * Flattens a counter into a key list and a parallel array of counts.
     *
     * @param dist the counter to flatten
     * @return a pair whose first element lists the keys of {@code dist} (in the iteration
     *         order of its key set) and whose second element holds, at the same index, the
     *         count stored for each key
     */
    public static <S> Pair<List<S>, double[]> convert(Counter<S> dist) {
        List<S> orderedKeys = new ArrayList<S>(dist.keySet());
        double[] weights = new double[orderedKeys.size()];
        int slot = 0;
        for (S key : orderedKeys) {
            weights[slot] = dist.getCount(key);
            slot++;
        }
        return Pair.makePair(orderedKeys, weights);
    }

    /**
     * Draws one bootstrap replicate of an empirical distribution, by rejection.
     *
     * <p>A multinomial sample of {@code nObservations} draws is taken according to
     * {@code prs}. The sample is rejected and redrawn unless every index in {@code core}
     * appears among its keys; at most {@code MAX_BOOTSTRAP_REJ} samples are drawn. The accepted
     * sample is normalized and re-keyed from indices to the corresponding entries of
     * {@code keys}.
     *
     * @param keys outcomes, parallel to {@code prs}
     * @param prs sampling weight of each outcome
     * @param rand source of randomness
     * @param nObservations number of multinomial draws per sample
     * @param core indices into {@code keys} that an accepted sample must contain
     * @return the normalized replicate, with an entry for every element of {@code keys}
     * @throws RuntimeException if no sample containing all core indices was found within
     *         {@code MAX_BOOTSTRAP_REJ} attempts
     */
    public static <S> Counter<S> bootstrap(List<S> keys, double[] prs, Random rand, int nObservations, Set<Integer> core) {
        Counter<Integer> accepted = null;
        for (int attempt = 0; attempt < MAX_BOOTSTRAP_REJ && accepted == null; ++attempt) {
            Counter<Integer> candidate = Sampling.efficientMultinomialSampling(rand, prs, nObservations);
            if (candidate.keySet().containsAll(core)) {
                accepted = candidate;
            }
        }
        if (accepted == null) {
            throw new RuntimeException("No bootstrap sample contained all " + core.size() + " core outcomes after " + MAX_BOOTSTRAP_REJ + " attempts");
        }
        accepted.normalize();

        // Translate index-keyed frequencies back to the caller's outcome type.
        Counter<S> transformed = new Counter<S>();
        for (int i = 0; i < keys.size(); ++i) {
            transformed.setCount(keys.get(i), accepted.getCount(i));
        }
        return transformed;
    }

    /**
     * Estimates a bootstrap p-value for the hypothesis that {@code guess} was sampled from
     * {@code truth}.
     *
     * <p>The test statistic is {@link #totalVariationDistance} restricted to {@code core}. The
     * observed statistic compares {@code truth} with the normalized {@code guess}; the
     * reference distribution is built from {@code n} replicates drawn from {@code truth}, each
     * with as many observations as the total count of {@code guess}. Both the numerator and
     * the denominator start at one, so the returned value is never zero.
     *
     * @param truth normalized reference distribution
     * @param guess unnormalized observed counts; not modified (a copy is normalized)
     * @param rand source of randomness
     * @param n number of bootstrap replicates
     * @param core outcomes over which the statistic is computed and which every replicate
     *        must contain
     * @return {@code (1 + #replicates with statistic strictly above the observed one) / (1 + n)}
     */
    public static <S> double bootstrapPValue(Counter<S> truth, Counter<S> guess, Random rand, int n, Set<S> core) {
        Counter<S> observed = new Counter<S>(guess);
        // Read the sample size before normalizing wipes it out.
        int nObservations = MathUtils.safeIntValue(observed.totalCount());
        observed.normalize();
        double observedStatistic = PIPLikelihoodTests.totalVariationDistance(truth, observed, core);

        Pair<List<S>, double[]> flattenedTruth = PIPLikelihoodTests.convert(truth);
        List<S> outcomes = flattenedTruth.getFirst();
        double[] weights = flattenedTruth.getSecond();

        // Positions of the core outcomes within the flattened key list.
        HashSet<Integer> coreIndices = new HashSet<Integer>();
        int position = 0;
        for (S outcome : outcomes) {
            if (core.contains(outcome)) {
                coreIndices.add(position);
            }
            position++;
        }

        int replicates = 0;
        int exceedances = 0;
        for (; replicates < n; replicates++) {
            Counter<S> replicate = PIPLikelihoodTests.bootstrap(outcomes, weights, rand, nObservations, coreIndices);
            double statistic = PIPLikelihoodTests.totalVariationDistance(truth, replicate, core);
            if (statistic > observedStatistic) {
                exceedances++;
            }
        }
        return (exceedances + 1.0) / (replicates + 1.0);
    }

    /**
     * Collects the point map of every column of an alignment, in linearized column order.
     *
     * @param msa the alignment to read
     * @return one taxon-to-integer map per column, as returned by each column's
     *         {@code getPoints()}
     */
    public static List<Map<Taxon, Integer>> extractListOfColMaps(MSAPoset msa) {
        List<Map<Taxon, Integer>> columnMaps = new ArrayList<Map<Taxon, Integer>>();
        for (MSAPoset.Column column : msa.linearizedColumns()) {
            Map<Taxon, Integer> points = column.getPoints();
            columnMaps.add(points);
        }
        return columnMaps;
    }

    /**
     * Builds a random string over the symbols of {@code indexer}.
     *
     * <p>The length is drawn uniformly from {@code [min, max)}; each character is then drawn
     * uniformly from the indexer's entries.
     *
     * @param rand source of randomness
     * @param indexer alphabet to draw symbols from
     * @param min smallest possible length (inclusive)
     * @param max largest possible length (exclusive); must be greater than {@code min}
     * @return the generated string
     * @throws IllegalArgumentException from {@link Random#nextInt(int)} if {@code max <= min}
     */
    private static String randomStr(Random rand, Indexer<Character> indexer, int min, int max) {
        int len = rand.nextInt(max - min) + min;
        // A builder avoids the quadratic copying of repeated string concatenation.
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < len; ++i) {
            result.append(indexer.i2o(rand.nextInt(indexer.size())));
        }
        return result.toString();
    }

    /**
     * Benchmark: repeatedly evaluates the likelihood of a fully aligned random data set.
     *
     * <p>A coalescent tree over {@code nTax} taxa is sampled, every leaf receives a random
     * sequence of {@code nColumns} symbols, and position {@code i} of every sequence is aligned
     * into one column. The alignment is then linearized and scored 10000 times, logging
     * progress after each evaluation. Nothing is asserted about the computed values.
     *
     * @param nColumns sequence length, and therefore number of alignment columns
     * @param nTax number of taxa in the sampled tree
     * @param rand source of randomness
     * @throws RuntimeException if one of the columns cannot be added to the alignment
     */
    public static void testSpeed(int nColumns, int nTax, Random rand) {
        final int nChars = 4;
        final int nTests = 10000;
        Indexer<Character> indexer = RateMatrixLoader.rnaIndexer();
        RootedTree tree = RandomRootedTrees.sampleCoalescent(rand, nTax, 0.1);
        PoissonParameters pip = PoissonParameters.createFromAdditiveLengthIntensityParameterization(indexer, RateMatrixLoader.k2p(), nColumns, 10.0);

        // One random sequence per leaf, built from the first nChars indexer entries.
        HashMap<Taxon, String> seqns = new HashMap<Taxon, String>();
        List<Taxon> leaves = tree.topology().leaveContents();
        for (Taxon t : leaves) {
            StringBuilder current = new StringBuilder();
            for (int pos = 0; pos < nColumns; ++pos) {
                current.append(indexer.i2o(rand.nextInt(nChars)));
            }
            seqns.put(t, current.toString());
        }

        // Align position col of every sequence into a single column.
        MSAPoset msa = new MSAPoset(seqns);
        for (int col = 0; col < nColumns; ++col) {
            HashMap<Taxon, Integer> points = new HashMap<Taxon, Integer>();
            for (Taxon t : leaves) {
                points.put(t, col);
            }
            if (!msa.tryAdding(points)) {
                throw new RuntimeException("Could not add alignment column " + col);
            }
        }

        LogInfo.track("Speed test");
        for (int run = 0; run < nTests; ++run) {
            LinearizedAlignment lia = new LinearizedAlignment(msa);
            PIPLikelihoodCalculator pipLC = new PIPLikelihoodCalculator(pip, lia, tree);
            pipLC.computeDataLogProbabilityGivenTree();
            LogInfo.logs("" + (run + 1) + "/" + nTests);
        }
        LogInfo.end_track();
    }

    /**
     * Builds a random alignment over the given taxa.
     *
     * <p>Each taxon gets a random sequence (see {@code randomStr}) whose length lies in
     * {@code [minStrLen, maxStrLen)}. Then 100 times, two distinct taxa and one position in
     * each of their sequences are drawn and the alignment is asked to link them; the outcome
     * of each attempt is ignored, so the alignment may end up with fewer than 100 links.
     *
     * @param taxa taxa to generate sequences for
     * @param indexer alphabet for the sequences
     * @param rand source of randomness
     * @param minStrLen smallest sequence length (inclusive)
     * @param maxStrLen largest sequence length (exclusive)
     * @return the resulting alignment
     */
    public static MSAPoset randomMSA(List<Taxon> taxa, Indexer<Character> indexer, Random rand, int minStrLen, int maxStrLen) {
        final int nLinkAttempts = 100;

        HashMap<Taxon, String> sequences = new HashMap<Taxon, String>();
        for (Taxon taxon : taxa) {
            String sequence = PIPLikelihoodTests.randomStr(rand, indexer, minStrLen, maxStrLen);
            sequences.put(taxon, sequence);
        }

        MSAPoset msa = new MSAPoset(sequences);
        int attempt = 0;
        while (attempt < nLinkAttempts) {
            List<Integer> pair = Sampling.sampleWithoutReplacement(rand, taxa.size(), 2);
            Taxon first = taxa.get(pair.get(0));
            Taxon second = taxa.get(pair.get(1));
            int posInFirst = rand.nextInt(sequences.get(first).length());
            int posInSecond = rand.nextInt(sequences.get(second).length());
            // Whether the link was accepted is deliberately not checked.
            msa.tryAdding(new GreedyDecoder.Edge(posInFirst, posInSecond, first, second));
            attempt++;
        }
        return msa;
    }
}

