/*
 * Decompiled with CFR 0.152.
 */
package ev.poi;

import ev.hmm.SimpleAligner;
import ev.poi.MSAMarginalLikelihoodCalculator;
import fig.basic.Option;
import fig.basic.StrUtils;
import goblin.DerivationTree;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import ma.BalibaseCorpus;
import ma.GreedyDecoder;
import ma.MSAPoset;
import ma.MultiAlignment;
import ma.RateMatrixLoader;
import nuts.lang.ArrayUtils;
import nuts.math.Sampling;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Indexer;
import pty.RootedTree;

/**
 * Large-step resampling moves over the homology structure of an {@link MSAPoset}.
 *
 * <p>A move splits the taxa into two groups ("top" = the taxa in the supplied group, "bottom" = all
 * other taxa), cuts the alignment links between the two groups, asks a {@link SimpleAligner} for a new
 * pairwise alignment between the two resulting column sequences, and then either keeps that proposal or
 * restores the links that were cut.
 *
 * <p>The second half of the class (from {@link #__old_implementation_largeHomologySamplingStep} on)
 * holds helpers of an older single-taxon implementation; its entry point is a stub that always throws.
 */
public class LargeStepHomologySampler {
    /** Option field; within this class it is only assigned (in {@link #main}) and never read. */
    @Option
    public static int nThreads = 1;

    /**
     * Performs one large resampling step of the links between {@code group1} and the remaining taxa.
     *
     * <p>With {@code maximize} set, the proposal is {@code aligner.viterbi()} and it is kept only when it
     * strictly increases the marginal log-likelihood. Otherwise the proposal is drawn with
     * {@code aligner.sample(rand)} and kept with probability {@code Sampling.min1exp(logRatio)}
     * (presumably {@code min(1, exp(logRatio))} -- TODO confirm against {@code nuts.math.Sampling}).
     * A rejected proposal is undone by cutting the links again and re-adding the original edges.
     *
     * @param group1 taxa forming the "top" side of the split
     * @param msa alignment to modify in place
     * @param calculator source of the marginal and per-column log-likelihoods
     * @param rand randomness used for sampling and for the accept/reject draw
     * @param maximize {@code true} for the greedy (Viterbi) variant, {@code false} for the sampling variant
     * @return {@code true} if the proposal was kept, or if no move was attempted because {@code group1}
     *         is empty or equals the full taxon set of {@code msa}
     * @throws RuntimeException if an edge cannot be (re-)added to {@code msa}
     */
    public static boolean largeHomologySamplingStep(Set<Taxon> group1, MSAPoset msa, MSAMarginalLikelihoodCalculator calculator, Random rand, boolean maximize) {
        // One side of the split would be empty: there is nothing to realign.
        if (group1.isEmpty() || group1.equals(msa.sequences().keySet())) {
            return true;
        }
        double oldLL = calculator.marginalLogLikelihood(msa);
        List<Map<Taxon, Integer>> topLinearizedTaxon2Pos = CollUtils.list();
        List<Map<Taxon, Integer>> botLinearizedTaxon2Pos = CollUtils.list();
        // Fills both lists as a side effect and returns the derivation describing the current links.
        DerivationTree.Derivation oldDeriv = originalDerivation(msa, group1, topLinearizedTaxon2Pos, botLinearizedTaxon2Pos);
        SimpleAligner aligner = createAligner(calculator, msa.sequences(), topLinearizedTaxon2Pos, botLinearizedTaxon2Pos);
        DerivationTree.Derivation newDeriv = maximize ? aligner.viterbi() : aligner.sample(rand);

        // Apply the proposal, remembering the edges that were cut so they can be restored on rejection.
        Collection<GreedyDecoder.Edge> removed = removeConnections(msa, group1);
        addConnections(msa, newDeriv, topLinearizedTaxon2Pos, botLinearizedTaxon2Pos);
        double newLL = calculator.marginalLogLikelihood(msa);

        boolean accept;
        if (maximize) {
            accept = newLL > oldLL;
        } else {
            // Likelihood ratio corrected by the ratio of proposal path probabilities.
            double logRatio = newLL + aligner.pathLogProbability(oldDeriv) - oldLL - aligner.pathLogProbability(newDeriv);
            double acceptPr = Sampling.min1exp(logRatio);
            accept = rand.nextDouble() < acceptPr;
        }

        if (!accept) {
            // Roll back: cut the proposed links and put the original ones back.
            removeConnections(msa, group1);
            for (GreedyDecoder.Edge e : removed) {
                if (!msa.tryAdding(e)) {
                    throw new RuntimeException("Inconsistent state: could not restore edge:" + e);
                }
            }
        }
        return accept;
    }

    /**
     * Single-taxon convenience overload: resamples the links between {@code lang} and all other taxa.
     *
     * @return the result of the set-based overload called with the singleton set of {@code lang}
     */
    public static boolean largeHomologySamplingStep(Taxon lang, MSAPoset msa, MSAMarginalLikelihoodCalculator calculator, Random rand, boolean maximize) {
        return largeHomologySamplingStep(Collections.singleton(lang), msa, calculator, rand, maximize);
    }

    /**
     * Builds the pairwise aligner used to propose new links between the top and bottom column sequences.
     *
     * <p>Note the argument order handed to {@link SimpleAligner}: the bottom unaligned weights come first,
     * then the top unaligned weights, then the {@code [top][bottom]} matrix of merged-column weights.
     *
     * @param calculator source of the per-column log-likelihoods
     * @param sequences sequences of all taxa, passed through to {@code calculator.convertToIndices}
     * @param topLinearizedTaxon2Pos top-group columns (taxon to position)
     * @param botLinearizedTaxon2Pos bottom-group columns (taxon to position)
     */
    public static SimpleAligner createAligner(MSAMarginalLikelihoodCalculator calculator, Map<Taxon, String> sequences, List<Map<Taxon, Integer>> topLinearizedTaxon2Pos, List<Map<Taxon, Integer>> botLinearizedTaxon2Pos) {
        double[] topUnalignedLogWeights = unalignedLogWeights(sequences, topLinearizedTaxon2Pos, calculator);
        double[] botUnalignedLogWeights = unalignedLogWeights(sequences, botLinearizedTaxon2Pos, calculator);
        double[][] alignedLogWeights = alignedLogWeights(sequences, topLinearizedTaxon2Pos, botLinearizedTaxon2Pos, calculator);
        return new SimpleAligner(botUnalignedLogWeights, topUnalignedLogWeights, alignedLogWeights);
    }

    /**
     * Splits every column for which {@code MSAPoset.isValidSplit} accepts the part belonging to
     * {@code group1}, and records one edge per split column so the link can be re-created later.
     *
     * @param msa alignment to modify in place
     * @param group1 taxa whose links to the other taxa are cut
     * @return one edge (a taxon inside the group, a taxon outside it) for each column that was split
     * @throws RuntimeException if a column accepted for splitting lacks a taxon on either side
     */
    public static Collection<GreedyDecoder.Edge> removeConnections(MSAPoset msa, Set<Taxon> group1) {
        List<GreedyDecoder.Edge> edges = CollUtils.list();
        // Iterate over a snapshot: msa.split changes the column collection.
        List<MSAPoset.Column> columns = CollUtils.list(msa.columns());
        for (MSAPoset.Column c : columns) {
            Map<Taxon, Integer> points = c.getPoints();
            HashSet<Taxon> intersected = CollUtils.set(group1);
            intersected.retainAll(points.keySet());
            if (!MSAPoset.isValidSplit(c, intersected)) {
                continue;
            }
            // Any representative from each side will do (the last one encountered is kept).
            Taxon inGp = null;
            Taxon notInGp = null;
            for (Taxon l : points.keySet()) {
                if (group1.contains(l)) {
                    inGp = l;
                } else {
                    notInGp = l;
                }
            }
            if (inGp == null || notInGp == null) {
                throw new RuntimeException("Inconsistent state: split column has no taxon on one side: " + points);
            }
            edges.add(new GreedyDecoder.Edge(points.get(inGp), points.get(notInGp), inGp, notInGp));
            msa.split(c, intersected);
        }
        return edges;
    }

    /** Returns, for each column, the log-likelihood of that column on its own (after index conversion). */
    private static double[] unalignedLogWeights(Map<Taxon, String> sequences, List<Map<Taxon, Integer>> linearizedTaxon2Pos, MSAMarginalLikelihoodCalculator calculator) {
        double[] result = new double[linearizedTaxon2Pos.size()];
        for (int i = 0; i < result.length; ++i) {
            result[i] = calculator.columnLogLikelihood(calculator.convertToIndices(sequences, linearizedTaxon2Pos.get(i)));
        }
        return result;
    }

    /** Old implementation helper: log-likelihood of each already-converted column. */
    private static double[] columnWithoutResampledLanguageLogWeights(List<Map<Taxon, Integer>> convertedColumns, MSAMarginalLikelihoodCalculator calculator) {
        double[] result = new double[convertedColumns.size()];
        for (int i = 0; i < result.length; ++i) {
            result[i] = calculator.columnLogLikelihood(convertedColumns.get(i));
        }
        return result;
    }

    /**
     * Projects the linearized columns of {@code msa} onto the two groups and encodes the current links
     * between them as a derivation.
     *
     * <p>Both list arguments are output parameters: the non-empty top and bottom projections of each
     * column are appended to them in linearization order. For every bottom column the ancestor array
     * holds the index of the top column it shares an alignment column with, or {@code -1} if none.
     *
     * @param msa alignment to read
     * @param group1 taxa forming the top group
     * @param topLinearizedTaxon2Pos receives the top-group columns
     * @param botLinearizedTaxon2Pos receives the bottom-group columns
     */
    public static DerivationTree.Derivation originalDerivation(MSAPoset msa, Set<Taxon> group1, List<Map<Taxon, Integer>> topLinearizedTaxon2Pos, List<Map<Taxon, Integer>> botLinearizedTaxon2Pos) {
        ArrayList<Integer> ancestorsList = CollUtils.list();
        int currentReindexedTopPos = 0;
        for (MSAPoset.Column c : msa.linearizedColumns()) {
            Map<Taxon, Integer> top = extract(c.getPoints(), group1, true);
            Map<Taxon, Integer> bot = extract(c.getPoints(), group1, false);
            if (!bot.isEmpty()) {
                // One ancestor entry per bottom column: its top partner, or -1 when unlinked.
                ancestorsList.add(top.isEmpty() ? -1 : currentReindexedTopPos);
                botLinearizedTaxon2Pos.add(bot);
            }
            if (!top.isEmpty()) {
                topLinearizedTaxon2Pos.add(top);
                ++currentReindexedTopPos;
            }
        }
        return new DerivationTree.Derivation(ArrayUtils.integerCollection2Array(ancestorsList), StrUtils.repeat("*", topLinearizedTaxon2Pos.size()), StrUtils.repeat("*", botLinearizedTaxon2Pos.size()));
    }

    /**
     * Adds to {@code msa} one edge for every bottom column that has an ancestor in {@code newDeriv},
     * linking an arbitrary taxon of the top column to an arbitrary taxon of the bottom column.
     *
     * @throws RuntimeException if {@code msa.tryAdding} rejects one of the edges
     */
    public static void addConnections(MSAPoset msa, DerivationTree.Derivation newDeriv, List<Map<Taxon, Integer>> topLinearizedTaxon2Pos, List<Map<Taxon, Integer>> botLinearizedTaxon2Pos) {
        int nBottom = newDeriv.getCurrentWord().length();
        for (int botReindexedPos = 0; botReindexedPos < nBottom; ++botReindexedPos) {
            if (!newDeriv.hasAncestor(botReindexedPos)) {
                continue;
            }
            int topReindexedPos = newDeriv.ancestor(botReindexedPos);
            Map<Taxon, Integer> topPoints = topLinearizedTaxon2Pos.get(topReindexedPos);
            Map<Taxon, Integer> botPoints = botLinearizedTaxon2Pos.get(botReindexedPos);
            Taxon topLang = CollUtils.pick(topPoints.keySet());
            Taxon botLang = CollUtils.pick(botPoints.keySet());
            GreedyDecoder.Edge current = new GreedyDecoder.Edge(topPoints.get(topLang), botPoints.get(botLang), topLang, botLang);
            if (!msa.tryAdding(current)) {
                throw new RuntimeException("Inconsistent state: could not add edge:" + current);
            }
        }
    }

    /**
     * Returns the entries of {@code allPoints} whose taxon is in {@code topGp} (when {@code isTop}) or
     * not in {@code topGp} (otherwise). The result is built from {@code CollUtils.map(allPoints)}.
     */
    private static Map<Taxon, Integer> extract(Map<Taxon, Integer> allPoints, Set<Taxon> topGp, boolean isTop) {
        Map<Taxon, Integer> result = CollUtils.map(allPoints);
        if (isTop) {
            result.keySet().retainAll(topGp);
        } else {
            result.keySet().removeAll(topGp);
        }
        return result;
    }

    /**
     * Computes the {@code [top][bottom]} matrix of log-likelihoods of the column obtained by merging a
     * top column with a bottom column.
     */
    private static double[][] alignedLogWeights(Map<Taxon, String> sequences, List<Map<Taxon, Integer>> topLinearizedTaxon2Pos, List<Map<Taxon, Integer>> botLinearizedTaxon2Pos, MSAMarginalLikelihoodCalculator calculator) {
        int T = topLinearizedTaxon2Pos.size();
        int B = botLinearizedTaxon2Pos.size();
        // Convert every bottom column once instead of once per (top, bottom) pair.
        List<Map<Taxon, Integer>> botIndices = new ArrayList<Map<Taxon, Integer>>(B);
        for (int botPos = 0; botPos < B; ++botPos) {
            botIndices.add(calculator.convertToIndices(sequences, botLinearizedTaxon2Pos.get(botPos)));
        }
        double[][] result = new double[T][B];
        for (int topPos = 0; topPos < T; ++topPos) {
            Map<Taxon, Integer> topIndices = calculator.convertToIndices(sequences, topLinearizedTaxon2Pos.get(topPos));
            for (int botPos = 0; botPos < B; ++botPos) {
                Map<Taxon, Integer> merged = CollUtils.map();
                merged.putAll(topIndices);
                merged.putAll(botIndices.get(botPos));
                result[topPos][botPos] = calculator.columnLogLikelihood(merged);
            }
        }
        return result;
    }

    /**
     * Experimental driver: loads a fixed alignment and tree, keeps the first ten taxa, and repeatedly
     * applies single-taxon sampling steps while printing scores and a decoded alignment.
     *
     * <p>NOTE(review): the likelihood calculator is never instantiated in this version of the code, so
     * the driver stops with an {@link IllegalStateException} before any sampling happens. The unused
     * locals below (rate matrix, indexer, insert/delete rates) look like the inputs its construction
     * needs -- confirm against {@link MSAMarginalLikelihoodCalculator} and restore the construction.
     *
     * @param args paths added to the Balibase reference-alignment option
     */
    public static void main(String[] args) {
        Random rand = new Random(1L);
        nThreads = 20;
        BalibaseCorpus.BalibaseCorpusOptions baliopt = new BalibaseCorpus.BalibaseCorpusOptions();
        baliopt.referenceAlignmentsPath.clear();
        for (String arg : args) {
            baliopt.referenceAlignmentsPath.add(arg);
        }
        BalibaseCorpus bc = new BalibaseCorpus(baliopt);
        double[][] subRates = RateMatrixLoader.hky85();
        Indexer<Character> indexer = RateMatrixLoader.rnaIndexer();
        double _insertRate = 1.0;
        double _delRate = 1.0;
        double nIterFactor = 10.0;
        double nBurnFactor = 2.0;
        try {
            Counter<GreedyDecoder.Edge> edgeCounter = new Counter<GreedyDecoder.Edge>();
            MultiAlignment _ma = MultiAlignment.parse("data/gutell/processed/vertebrata/align.msf");
            // Restrict the experiment to the first ten taxa of the parsed alignment.
            ArrayList<Taxon> restr = CollUtils.list();
            for (Taxon lang : _ma.getSequences().keySet()) {
                if (restr.size() >= 10) {
                    break;
                }
                restr.add(lang);
            }
            _ma = _ma.restrict(restr);
            System.out.println("Parsed align");
            MSAPoset msa = MSAPoset.fromMultiAlignmentObject(_ma);
            RootedTree rt = RootedTree.Util.incrementSmallBranches(RootedTree.Util.load(new File("data/gutell/processed/vertebrata-mltree/tree.newick")), 0.005);
            rt = RootedTree.Util.restrict(rt, CollUtils.set(restr));
            MSAMarginalLikelihoodCalculator calc = null;
            System.out.println(_ma.sumOfPairsScore(msa.toMultiAlignmentObject()));
            System.out.println(msa);
            System.out.println(StrUtils.repeat("-", 50));
            // Fail with an explanation instead of a bare NullPointerException on the next line.
            if (calc == null) {
                throw new IllegalStateException("No MSAMarginalLikelihoodCalculator has been constructed; cannot run the sampler");
            }
            double init = calc.marginalLogLikelihood(msa);
            List<Taxon> langs = CollUtils.list(msa.sequences().keySet());
            for (int i = 0; i < nIterFactor; ++i) {
                Collections.shuffle(langs, rand);
                for (Taxon aLang : langs) {
                    Set<Taxon> proposed = Collections.singleton(aLang);
                    System.out.println("Current proposal set:" + proposed);
                    largeHomologySamplingStep(proposed, msa, calc, rand, false);
                    System.out.println(_ma.sumOfPairsScore(msa.toMultiAlignmentObject()) + "\t" + (calc.marginalLogLikelihood(msa) - init));
                    System.out.println(msa);
                    // Burn-in: edge statistics are collected only once i exceeds nBurnFactor.
                    if (i <= nBurnFactor) {
                        continue;
                    }
                    edgeCounter.incrementAll(msa.edges(), 1.0);
                    MSAPoset decoded = new MSAPoset(msa.sequences());
                    // Greedy decoding: edges that cannot be added are skipped on purpose.
                    // Presumably the counter iterates by decreasing count -- TODO confirm.
                    for (GreedyDecoder.Edge e : edgeCounter) {
                        decoded.tryAdding(e);
                    }
                    System.out.println("MBR decoded:" + _ma.sumOfPairsScore(decoded.toMultiAlignmentObject()));
                    System.out.println(decoded);
                }
            }
        } catch (NoSuchElementException e) {
            // Previously swallowed silently; report it so an aborted run is visible.
            System.err.println("Sampling aborted: " + e);
        }
    }

    /**
     * Entry point of the previous single-taxon implementation; no longer implemented.
     *
     * @throws UnsupportedOperationException always
     */
    public static boolean __old_implementation_largeHomologySamplingStep(Taxon resampledTaxon, MSAPoset msa, MSAMarginalLikelihoodCalculator calculator, Random rand) {
        throw new UnsupportedOperationException("The old implementation is no longer available; use largeHomologySamplingStep");
    }

    /**
     * Maps each character of {@code str} to its index in {@code index}.
     *
     * @return an array of the same length as {@code str}; entries for characters not contained in
     *         {@code index} are {@code null}
     */
    public static Integer[] indexString(String str, Indexer<Character> index) {
        Integer[] result = new Integer[str.length()];
        for (int i = 0; i < result.length; ++i) {
            Character symbol = Character.valueOf(str.charAt(i));
            result[i] = index.containsObject(symbol) ? Integer.valueOf(index.o2i(symbol)) : null;
        }
        return result;
    }

    /**
     * Old implementation helper: for each column, the position of {@code resampledTaxon} in that
     * column, or {@code -1} if the column does not contain it.
     */
    private static DerivationTree.Derivation originalDerivation(List<MSAPoset.Column> linearizedColumns, int length, Taxon resampledTaxon) {
        int[] ancestors = new int[linearizedColumns.size()];
        for (int c = 0; c < ancestors.length; ++c) {
            Map<Taxon, Integer> points = linearizedColumns.get(c).getPoints();
            ancestors[c] = points.containsKey(resampledTaxon) ? points.get(resampledTaxon) : -1;
        }
        String top = StrUtils.repeat("*", length);
        String bot = StrUtils.repeat("*", ancestors.length);
        return new DerivationTree.Derivation(ancestors, top, bot);
    }

    /**
     * Old implementation helper: splits {@code resampledTaxon} out of every column it shares with at
     * least one other taxon, returning one edge per split column.
     */
    private static Collection<GreedyDecoder.Edge> removeConnections(MSAPoset msa, Taxon resampledTaxon) {
        Set<Taxon> split = Collections.singleton(resampledTaxon);
        // Iterate over a snapshot: msa.split changes the column collection.
        List<MSAPoset.Column> initialColumns = CollUtils.list(msa.columns());
        List<GreedyDecoder.Edge> edges = CollUtils.list();
        for (MSAPoset.Column c : initialColumns) {
            Map<Taxon, Integer> points = c.getPoints();
            if (!points.containsKey(resampledTaxon) || points.size() <= 1) {
                continue;
            }
            int pos = points.get(resampledTaxon);
            Taxon other = pickDifferentThan(points.keySet(), resampledTaxon);
            int otherPos = points.get(other);
            edges.add(new GreedyDecoder.Edge(pos, otherPos, resampledTaxon, other));
            msa.split(c, split);
        }
        return edges;
    }

    /**
     * Returns the first element of {@code collection}, in iteration order, that is not equal to
     * {@code elt}.
     *
     * @param collection candidates to scan
     * @param elt element to avoid; may be {@code null}, in which case the first non-null element is returned
     * @return the first differing element, or {@code null} if there is none
     */
    public static <T> T pickDifferentThan(Collection<T> collection, T elt) {
        for (T candidate : collection) {
            boolean same = elt == null ? candidate == null : elt.equals(candidate);
            if (!same) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Old implementation helper: links {@code resampledTaxon} to an arbitrary taxon of each column
     * that has an ancestor in {@code sampledIndices}.
     *
     * @throws RuntimeException if {@code msa.tryAdding} rejects one of the edges
     */
    private static void addConnections(MSAPoset msa, DerivationTree.Derivation sampledIndices, List<Map<Taxon, Integer>> cols, Taxon resampledTaxon) {
        int nColumns = sampledIndices.getCurrentWord().length();
        for (int i = 0; i < nColumns; ++i) {
            if (!sampledIndices.hasAncestor(i)) {
                continue;
            }
            Map<Taxon, Integer> column = cols.get(i);
            Taxon anchorLang = CollUtils.pick(column.keySet());
            int anchorIndex = column.get(anchorLang);
            int otherIndex = sampledIndices.ancestor(i);
            GreedyDecoder.Edge e = new GreedyDecoder.Edge(anchorIndex, otherIndex, anchorLang, resampledTaxon);
            if (!msa.tryAdding(e)) {
                throw new RuntimeException("Inconsistent state: could not add edge:" + e);
            }
        }
    }

    /** Old implementation helper: copies each column's points with {@code resampledTaxon} removed. */
    private static List<Map<Taxon, Integer>> removeLinksColumns(List<MSAPoset.Column> linearizedColumns, Taxon resampledTaxon) {
        List<Map<Taxon, Integer>> result = CollUtils.list();
        for (MSAPoset.Column c : linearizedColumns) {
            Map<Taxon, Integer> copy = CollUtils.map(c.getPoints());
            copy.remove(resampledTaxon);
            result.add(copy);
        }
        return result;
    }

    /** Old implementation helper: applies the static {@code convertToIndices} to every column. */
    private static List<Map<Taxon, Integer>> convertLinksColumns(List<Map<Taxon, Integer>> removed, MSAPoset msa, Indexer<Character> indexer) {
        List<Map<Taxon, Integer>> result = CollUtils.list();
        for (Map<Taxon, Integer> cur : removed) {
            result.add(MSAMarginalLikelihoodCalculator.convertToIndices(msa.sequences(), cur, indexer));
        }
        return result;
    }

    /**
     * Old implementation helper; no longer implemented.
     *
     * @throws UnsupportedOperationException always
     */
    private static List<MSAPoset.Column> linearizeAndRemoveEmptyProjectedColumns(MSAPoset msa, Taxon resampledTaxon) {
        throw new UnsupportedOperationException("linearizeAndRemoveEmptyProjectedColumns is not implemented");
    }

    /** Old implementation helper: builds the aligner between the remaining columns and the resampled sequence. */
    private static SimpleAligner getAligner(Taxon resampledTaxon, List<Map<Taxon, Integer>> convertedColumns, Integer[] convertedSequence, MSAMarginalLikelihoodCalculator calculator) {
        double[] columnWithoutResampledLanguageLogWeights = columnWithoutResampledLanguageLogWeights(convertedColumns, calculator);
        double[] columnSingletonLogWeights = columnSingletonLogWeights(resampledTaxon, convertedSequence, calculator);
        double[][] mergedLogWeights = mergedLogWeights(resampledTaxon, convertedColumns, convertedSequence, calculator);
        return new SimpleAligner(columnWithoutResampledLanguageLogWeights, columnSingletonLogWeights, mergedLogWeights);
    }

    /**
     * Old implementation helper: {@code [sequence position][column]} log-likelihood of each column
     * extended with the corresponding symbol of the resampled taxon.
     */
    private static double[][] mergedLogWeights(Taxon resampledTaxon, List<Map<Taxon, Integer>> cols, Integer[] seq, MSAMarginalLikelihoodCalculator calculator) {
        int S = seq.length;
        int C = cols.size();
        double[][] result = new double[S][C];
        for (int s = 0; s < S; ++s) {
            for (int c = 0; c < C; ++c) {
                // Score a copy so the caller's column maps are never modified, even on failure.
                Map<Taxon, Integer> merged = CollUtils.map(cols.get(c));
                merged.put(resampledTaxon, seq[s]);
                result[s][c] = calculator.columnLogLikelihood(merged);
            }
        }
        return result;
    }

    /** Old implementation helper: log-likelihood of each symbol of the resampled taxon as a lone column. */
    private static double[] columnSingletonLogWeights(Taxon resampledTaxon, Integer[] convertedSequence, MSAMarginalLikelihoodCalculator calculator) {
        double[] result = new double[convertedSequence.length];
        for (int i = 0; i < result.length; ++i) {
            Map<Taxon, Integer> singleton = CollUtils.map();
            singleton.put(resampledTaxon, convertedSequence[i]);
            result[i] = calculator.columnLogLikelihood(singleton);
        }
        return result;
    }
}

