/*
 * Decompiled with CFR 0.152.
 */
package conifer.ml.data;

import conifer.ml.data.HeldoutData;
import conifer.ml.tests.TestRealData;
import ev.io.CreatePairwiseData;
import fig.basic.LogInfo;
import fig.basic.Option;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.MSAParser;
import ma.MSAPoset;
import ma.RateMatrixLoader;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.Indexer;
import pty.RootedTree;
import pty.io.Dataset;

public class PhylogeneticHeldoutDataset {
    /** Tree loaded from the options' tree file, with small branches incremented by 0.001. */
    public RootedTree rootedTree;
    /** Observations built from the site-filtered alignment. */
    public Dataset obs;
    /** Character indexer; set to {@code RateMatrixLoader.rnaIndexer()}. */
    public Indexer<Character> indexer;
    /** Held-out data; stays {@code null} when hold-out is skipped. */
    public HeldoutData heldOut;

    /**
     * Loads the tree and alignment described by {@code options}, performing hold-out unless
     * {@code options.holdOutFre} is zero.
     *
     * @param options paths and filtering parameters
     * @return the loaded dataset
     */
    public static PhylogeneticHeldoutDataset loadData(PhylogeneticHeldoutDatasetOptions options) {
        return PhylogeneticHeldoutDataset.loadData(options, false);
    }

    /**
     * Loads the tree and alignment described by {@code options}, filters alignment sites, builds
     * RNA tip observations and, optionally, holds out part of the observations.
     *
     * <p>Site filtering happens in this order: sites observed (non-gap) in fewer than
     * {@code minFractionObserved} of the sequences are dropped, then at most {@code maxNSites}
     * of the remaining sites are kept (chosen with {@code options.rand}).
     *
     * @param options paths and filtering parameters
     * @param skipHoldout when true, no hold-out is performed and {@link #heldOut} stays null;
     *     hold-out is also skipped when {@code options.holdOutFre == 0.0}
     * @return the loaded dataset
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static PhylogeneticHeldoutDataset loadData(PhylogeneticHeldoutDatasetOptions options, boolean skipHoldout) {
        // A zero hold-out fraction means there is nothing to hold out.
        boolean doHoldout = !skipHoldout && options.holdOutFre != 0.0;
        PhylogeneticHeldoutDataset result = new PhylogeneticHeldoutDataset();
        result.indexer = RateMatrixLoader.rnaIndexer();
        result.rootedTree = RootedTree.Util.incrementSmallBranches(RootedTree.Util.load(new File(options.treeFile)), 0.001);
        // Raw type kept: the two parsers' map types are declared outside this file.
        Map data = loadAlignment(new File(options.alignmentFile));
        Set<Integer> indices = minFractionIndices(data, options.minFractionObserved);
        indices = restrictToMaxSize(indices, options.maxNSites, options.rand);
        data = restrictSites(data, indices);
        result.obs = buildRNAObservations(data);
        if (doHoldout) {
            result.heldOut = new HeldoutData(result.indexer);
            // NOTE(review): hold-out uses a fixed seed (1), independent of options.rand — confirm intended.
            Random rand = new Random(1L);
            result.heldOut.holdout(result.obs.observations(), options.holdOutFre, rand);
        }
        return result;
    }

    /**
     * Parses the alignment file as strict FASTA and falls back to {@code MSAParser} (e.g. for
     * msf files) when strict parsing fails for any reason.
     */
    @SuppressWarnings("rawtypes")
    private static Map loadAlignment(File alignmentFile) {
        try {
            CreatePairwiseData.LightMSA lmsa = CreatePairwiseData.LightMSA.parseStrictFASTA(alignmentFile);
            return lmsa.sequences;
        }
        catch (Exception e) {
            // Keep a trace of why strict parsing failed, so a later fallback failure is diagnosable.
            LogInfo.logs("Strict FASTA parsing of " + alignmentFile + " failed (" + e + "); falling back to MSAParser");
            MSAPoset align = MSAParser.parseMSA(alignmentFile);
            return align.gapPaddedSequences();
        }
    }

    /**
     * Builds one RNA tip per taxon via {@code RateMatrixLoader.buildRNATip}, logging progress.
     *
     * @param data sequence per taxon
     * @return the tips wrapped in a {@code TestRealData.SimpleObservations}
     */
    public static Dataset buildRNAObservations(Map<Taxon, ? extends CharSequence> data) {
        HashMap<Taxon, double[][]> result = CollUtils.map();
        int nTaxa = data.size();
        LogInfo.track("Building RNA tips");
        try {
            int i = 1;
            for (Map.Entry<Taxon, ? extends CharSequence> entry : data.entrySet()) {
                LogInfo.logs("" + i++ + "/" + nTaxa);
                result.put(entry.getKey(), RateMatrixLoader.buildRNATip(entry.getValue()));
            }
        }
        finally {
            // Always close the track, even if building a tip throws.
            LogInfo.end_track();
        }
        return new TestRealData.SimpleObservations(result);
    }

    /**
     * Returns the site indices whose fraction of non-gap ({@code '-'}) characters across all
     * sequences is at least {@code minFraction}.
     *
     * <p>Sites are enumerated up to the longest sequence; a sequence shorter than a given site
     * counts as not observed there. For a proper alignment, all sequences have the same length.
     *
     * @param data sequence per taxon
     * @param minFraction minimum observed fraction, compared with {@code >=}
     * @return the qualifying site indices; empty when {@code data} is empty
     */
    public static Set<Integer> minFractionIndices(Map<Taxon, ? extends CharSequence> data, double minFraction) {
        HashSet<Integer> result = CollUtils.set();
        if (data.isEmpty()) {
            return result;
        }
        Counter<Integer> counts = new Counter<Integer>();
        int nSites = 0;
        for (CharSequence charSequence : data.values()) {
            nSites = Math.max(nSites, charSequence.length());
            for (int i = 0; i < charSequence.length(); ++i) {
                if (charSequence.charAt(i) != '-') {
                    counts.incrementCount(i, 1.0);
                }
            }
        }
        double nSeq = data.size();
        for (int i = 0; i < nSites; ++i) {
            if (counts.getCount(i) / nSeq >= minFraction) {
                result.add(i);
            }
        }
        return result;
    }

    /**
     * Returns {@code original} itself when it has at most {@code maxSize} elements. Otherwise it
     * returns a random subset of exactly {@code maxSize} elements, drawn without replacement.
     *
     * <p>The elements are sorted before shuffling, so the subset depends only on the set's
     * contents and {@code rand}, not on hash iteration order.
     *
     * @throws IllegalArgumentException if {@code maxSize} is negative and smaller than the set
     */
    public static Set<Integer> restrictToMaxSize(Set<Integer> original, int maxSize, Random rand) {
        if (maxSize >= original.size()) {
            return original;
        }
        if (maxSize < 0) {
            throw new IllegalArgumentException("maxSize must be non-negative: " + maxSize);
        }
        ArrayList<Integer> items = new ArrayList<Integer>(original);
        Collections.sort(items);
        Collections.shuffle(items, rand);
        return new HashSet<Integer>(items.subList(0, maxSize));
    }

    /**
     * Keeps, for every taxon, only the characters at positions contained in {@code indices}
     * (in their original order), and logs the site count before and after.
     *
     * @param data sequence per taxon
     * @param indices site positions to keep; positions beyond a sequence's length are ignored
     * @return a map from taxon to restricted {@code StringBuilder} (raw type kept for callers)
     */
    @SuppressWarnings("rawtypes")
    public static Map restrictSites(Map<Taxon, ? extends CharSequence> data, Set<Integer> indices) {
        HashMap<Taxon, StringBuilder> result = new HashMap<Taxon, StringBuilder>();
        int originalLength = 0;
        int restrictedLength = 0;
        boolean first = true;
        for (Map.Entry<Taxon, ? extends CharSequence> entry : data.entrySet()) {
            CharSequence originalStr = entry.getValue();
            StringBuilder newStr = new StringBuilder();
            for (int i = 0; i < originalStr.length(); ++i) {
                if (indices.contains(i)) {
                    newStr.append(originalStr.charAt(i));
                }
            }
            result.put(entry.getKey(), newStr);
            if (first) {
                // Report before/after lengths of the same taxon.
                originalLength = originalStr.length();
                restrictedLength = newStr.length();
                first = false;
            }
        }
        LogInfo.logsForce("Restricted nSites: " + originalLength + " -> " + restrictedLength);
        return result;
    }

    /** Command-line options controlling which data is loaded and how it is filtered. */
    public static class PhylogeneticHeldoutDatasetOptions {
        @Option(gloss="Path to fasta or msf file")
        public String alignmentFile = "/Users/bouchard/Documents/data/utcs/16S.B.ALL/R0/cleaned.alignment.fasta";
        @Option(gloss="Path to newick tree")
        public String treeFile = "/Users/bouchard/Documents/data/utcs/16S.B.ALL.raxml.nwk";
        @Option(gloss="Fraction of points of the alignment to hold out (after all other filtering steps are done)")
        public double holdOutFre = 0.0;
        @Option(gloss="Remove sites that do not have at least that fraction observed")
        public double minFractionObserved = 0.0;
        @Option(gloss="Maximum number of sites to load (pick at random without replacement if more are in the dataset; applied after removing sites with not enough data in them)")
        public int maxNSites = Integer.MAX_VALUE;
        @Option(gloss="Seed to pick random subset")
        public Random rand = new Random(1L);
    }
}

