/*
 * Decompiled with CFR 0.152.
 */
package goblin;

import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.prob.SampleUtils;
import goblin.CognateId;
import goblin.CognateSet;
import goblin.DataLoaderInterface;
import goblin.DataPrepUtils;
import goblin.DerivationTree;
import goblin.HLParams;
import goblin.HLParamsLoader;
import goblin.ObservationsTracker;
import goblin.PrepareEMData;
import goblin.Taxon;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.MultiAlignment;
import nuts.io.IO;
import nuts.util.Arbre;
import nuts.util.Tree;
import pepper.Corpus;
import pepper.Encodings;
import pepper.editmodel.Utils;

/**
 * Loads, or synthesises, the data consumed by the sampler: the language
 * topology, the corpus of word forms, the phoneme encodings and, optionally,
 * a set of held-out word forms in a single language.
 *
 * <p>The public accessors lazily call {@link #load()} the first time loaded
 * state is needed; {@link #load()} itself only does its work once.
 */
public class DataLoader
implements DataLoaderInterface {
    @Option(gloss="Specification of a tree topology in lisp format, as a convenience for large tree specifications, if the string has the form \"!xyz\" then it reads from the file xyz")
    public String topoSpec;
    @Option(gloss="Source(s) for the word forms, see Corpus.java, or !generate to generate the data")
    public String wordsPath;
    public static final String GENERATE = "!generate".toLowerCase();
    @Option(gloss="Max number of words in training + test (must be specified when generating data)")
    public int maxNumberOfWords = Integer.MAX_VALUE;
    @Option(gloss="Language where word forms will be heldout or empty-string if no heldout desired")
    public String heldoutLang = "";
    @Option(gloss="Proportion in [0,1] of the words that should be held out in the specified held out lang if any")
    public double heldoutProp = 1.0;
    @Option(gloss="List of words to heldout---in which case heldout size is ignored!")
    public String heldoutListPath = "";
    @Option(gloss="Seed for generating data if applicable")
    public Random dataGenerationRand = new Random(1L);
    @Option(gloss="Seed for a shuffle of data performed before picking the forms to heldout if applicable")
    public Random dataSplitRand = new Random(1L);
    @Option(gloss="If we are generating data, should we forget the contents of nodes that neither modern nor heldout")
    public boolean removeInternalGeneratedNodes = true;
    @Option(gloss="Trees with strictly less than that much observed stuff will be discarted")
    public int minObservedLanguagePerTree = 1;
    @Option
    public String langRestrFile = "";
    @Option
    public int langRestrN = Integer.MAX_VALUE;
    @Option
    public boolean forceDNA = false;
    @Option
    public boolean forceRNA = false;
    public HLParamsLoader generationParamLoader = new HLParamsLoader();
    // Everything below is populated by load().
    private Corpus rawCorpus;
    private Encodings encodings;
    /** Generating parameters; only set when the data is generated. */
    private HLParams params;
    /** Number of cognates to hold out when no explicit heldout list is given. */
    private int heldoutsize;
    private boolean loaded = false;
    private Tree<String> rawTopology;
    private List<HeldoutEntry> heldout = new ArrayList<HeldoutEntry>();
    private CognateSet cognateSet = new CognateSet();
    /** Explicit words to hold out, or null to hold out by size/proportion. */
    private Set<String> heldoutList = null;
    /** Languages the corpus is restricted to, or null for no extra restriction. */
    private Set<String> langsToRestrict;
    /** Cognates exactly as generated (before corpus restriction); only set when generating. */
    private CognateSet generatingCognateSet = null;

    /** Runs {@link #load()} if it has not completed yet. */
    private void ensureLoaded() {
        if (!this.loaded) {
            this.load();
        }
    }

    /** @return a copy of the list of held-out entries (loads data if needed). */
    @Override
    public List<HeldoutEntry> getHeldout() {
        this.ensureLoaded();
        return new ArrayList<HeldoutEntry>(this.heldout);
    }

    /** @return the random sources used for data generation, parameter generation and the data split. */
    @Override
    public List<Random> randomness() {
        ArrayList<Random> result = new ArrayList<Random>();
        result.add(this.dataGenerationRand);
        result.add(this.generationParamLoader.paramGenerationRand);
        result.add(this.dataSplitRand);
        return result;
    }

    /** @return a deep copy of the topology (loads data if needed). */
    public Tree<String> getTopology() {
        this.ensureLoaded();
        return this.rawTopology.deepCopy();
    }

    /** @return the prepared cognate set itself, not a copy (loads data if needed). */
    @Override
    public CognateSet getCognateSet() {
        this.ensureLoaded();
        return this.cognateSet;
    }

    /**
     * @return the parameters used to generate the data
     * @throws RuntimeException if the data is not generated
     */
    @Override
    public HLParams getGeneratingParams() {
        if (!this.generated()) {
            throw new RuntimeException("generated() not defined");
        }
        this.ensureLoaded();
        return this.params;
    }

    /** @return the taxon at the root of the topology (loads data if needed). */
    public Taxon root() {
        this.ensureLoaded();
        return new Taxon(this.rawTopology.getLabel());
    }

    /**
     * Reads the topology, encodings and corpus (or generates them), then builds
     * the cognate set and the held-out entries. Subsequent calls are no-ops.
     */
    @Override
    public void load() {
        if (this.loaded) {
            // load() appends to cognateSet and heldout, so a second pass would duplicate entries.
            return;
        }
        // heldoutProp only makes sense for a proportion-based heldout in a named language.
        if (this.heldoutLang.equals("") && this.heldoutProp != 1.0) {
            throw new RuntimeException("Bad heldout specs");
        }
        if (!this.heldoutListPath.equals("") && this.heldoutProp != 1.0) {
            throw new RuntimeException("Bad heldout specs");
        }
        this.rawTopology = DataLoader.createRawTopology(this.topoSpec);
        this.langsToRestrict = DataLoader.readLangsToRestrict(this.heldoutLang, this.langRestrFile, this.langRestrN, this.rawTopology);
        this.heldoutList = this.createHeldoutList();
        LogInfo.logs("Number of languages: " + this.rawTopology.getPreOrderTraversal().size());
        if (this.generated()) {
            // Encodings come from the generating params, so these must be loaded first.
            this.params = this.loadParams();
        }
        this.encodings = this.loadEncodings();
        this.rawCorpus = this.createRawCorpus();
        LogInfo.logs("Number of languages with observed word forms: " + this.rawCorpus.getNLangs());
        this.heldoutsize = this.computeHeldoutSize();
        int numberOfWordsObserved = this.prepareCognateSet();
        this.loaded = true;
        LogInfo.logss("Actual number of cognates prepared: " + this.cognateSet.size());
        LogInfo.logss("Number of words observed by sampler: " + numberOfWordsObserved);
        if (this.hasHeldout()) {
            LogInfo.logss("Actual number of cognates heldout: " + this.heldout.size());
        }
    }

    /**
     * Converts corpus rows, visited in a random order drawn from dataSplitRand,
     * into cognate trees and adds them to the cognate set, holding out words
     * where requested. At most maxNumberOfWords cognates are added.
     *
     * @return the total number of observed languages summed over added cognates
     */
    private int prepareCognateSet() {
        int currentHeldout = 0;
        int currentWord = 0;
        int numberOfWordsObserved = 0;
        LogInfo.track("Preparing cognate set");
        try {
            for (int row : SampleUtils.samplePermutation(this.dataSplitRand, this.rawCorpus.getNWords())) {
                // Check the cap before adding so exactly maxNumberOfWords cognates are kept.
                if (currentWord >= this.maxNumberOfWords) break;
                LogInfo.logs("Row " + currentWord + "/" + Math.min(this.maxNumberOfWords, this.rawCorpus.getNWords()));
                CognateId id = this.cognateId(row);
                Arbre<DerivationTree.DerivationNode> current = DataPrepUtils.tree2arbre(this.rawTopology, this.rawCorpus.getWords(row));
                current = DataPrepUtils.trim(current);
                if (current.nLeaves() < this.minObservedLanguagePerTree) continue;
                if (this.hasHeldout() && DataPrepUtils.isValidForHeldout(this.rawCorpus, row, this.heldoutLang) && this.shouldHeldout(currentHeldout, this.rawCorpus.getWord(row, this.heldoutLang))) {
                    // Done before building the observation tracker, so it sees the forgotten word.
                    this.holdout(id, current);
                    ++currentHeldout;
                }
                ObservationsTracker obs = DataPrepUtils.observations(current);
                numberOfWordsObserved += obs.observedLanguages().size();
                this.cognateSet.addCognate(id, current, obs);
                ++currentWord;
            }
        } finally {
            LogInfo.end_track();
        }
        return numberOfWordsObserved;
    }

    /**
     * Reads the set of languages to restrict the corpus to.
     *
     * @param heldoutLang held-out language, always included if non-empty
     * @param langRestrFile file whose non-empty lines start with a language name
     *     (tab-separated); empty string disables the restriction
     * @param langRestrN number of languages (present in {@code t}) to take from
     *     the file; {@link Integer#MAX_VALUE} means all of them
     * @param t full topology; languages not labelling one of its nodes are skipped
     * @return the restriction set, or null when {@code langRestrFile} is empty
     * @throws RuntimeException if an explicit {@code langRestrN} cannot be satisfied
     */
    public static Set<String> readLangsToRestrict(String heldoutLang, String langRestrFile, int langRestrN, Tree<String> t) {
        if (langRestrFile.equals("")) {
            return null;
        }
        HashSet<String> langsInFullTree = new HashSet<String>();
        for (Tree<String> st : t.getPostOrderTraversal()) {
            langsInFullTree.add(st.getLabel());
        }
        HashSet<String> result = new HashSet<String>();
        if (!heldoutLang.equals("")) {
            result.add(heldoutLang);
        }
        int i = 0;
        for (String line : IO.i(langRestrFile)) {
            if (i >= langRestrN) break;
            if (line.equals("")) continue;
            String cLang = line.split("\t")[0];
            if (!langsInFullTree.contains(cLang)) continue;
            result.add(cLang);
            ++i;
        }
        // The default (MAX_VALUE) means "take every listed language", so a
        // shortfall is only an error when an explicit count was requested.
        if (langRestrN != Integer.MAX_VALUE && i < langRestrN) {
            throw new RuntimeException("WARNING: less than the requested number of languages could be extracted");
        }
        return result;
    }

    /** Decides whether the current row's held-out-language word should be held out. */
    private boolean shouldHeldout(int currentHeldout, String word) {
        if (this.heldoutList == null) {
            return currentHeldout < this.heldoutsize;
        }
        return this.heldoutList.contains(word);
    }

    /** @return the non-empty lines of heldoutListPath, or null if no path is set. */
    private Set<String> createHeldoutList() {
        if (this.heldoutListPath.equals("")) {
            return null;
        }
        HashSet<String> result = new HashSet<String>();
        for (String line : IO.i(this.heldoutListPath)) {
            if (line.equals("")) continue;
            result.add(line);
        }
        return result;
    }

    /** Forgets the held-out language's word in {@code current} and records it as a heldout entry. */
    private void holdout(CognateId id, Arbre<DerivationTree.DerivationNode> current) {
        Taxon lang = new Taxon(this.heldoutLang);
        Arbre<DerivationTree.DerivationNode> toheldOut = DerivationTree.findNodeByLangName(current, lang);
        String trueReconstruction = toheldOut.getContents().getWord();
        toheldOut.setContents(DataPrepUtils.forgetWord(toheldOut.getContents()));
        this.heldout.add(new HeldoutEntry(id, lang, trueReconstruction));
    }

    private CognateId cognateId(int rowIndex) {
        return this.rawCorpus.getCognateId(rowIndex);
    }

    /**
     * @return the heldout list size if a list is given; 0 without heldout;
     *     otherwise heldoutProp times the number of usable rows, capped by the
     *     number of rows valid for heldout
     */
    private int computeHeldoutSize() {
        if (this.heldoutList != null) {
            return this.heldoutList.size();
        }
        if (!this.hasHeldout()) {
            return 0;
        }
        double nValidHeldOut = (double)DataPrepUtils.nValidHeldoutRows(this.rawCorpus, this.heldoutLang, this.encodings.allChars());
        double nValid = DataPrepUtils.nValidRows(this.rawCorpus, this.encodings.allChars());
        double nUsed = Math.min(nValid, (double)this.maxNumberOfWords);
        return (int)Math.min(this.heldoutProp * nUsed, nValidHeldOut);
    }

    /**
     * Picks the encodings: those of the generating params when generating,
     * otherwise the globally registered ones, falling back to DNA, RNA or
     * real encodings depending on the force flags. The result is registered globally.
     */
    private Encodings loadEncodings() {
        if (this.forceRNA && this.forceDNA) {
            throw new RuntimeException("forceDNA and forceRNA are mutually exclusive");
        }
        Encodings result = null;
        if (this.generated()) {
            result = this.params.enc;
        } else {
            result = Encodings.getGlobalEncodings();
            if (result == null) {
                try {
                    if (this.forceDNA) {
                        LogInfo.logs("DNA encodings loaded");
                        result = Encodings.dnaEncodings();
                    } else if (this.forceRNA) {
                        LogInfo.logs("RNA encodings loaded");
                        result = Encodings.rnaEncodings();
                    } else {
                        LogInfo.logs("Real encodings loaded from a library of size: " + Encodings.realEncoding().getNumberOfPhonemes());
                        result = Encodings.realEncoding(this.loadRawCorpus().allChars());
                    }
                }
                catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }
        }
        Encodings.registerEncodings(result);
        LogInfo.logs("Encodings size: " + result.getNumberOfPhonemes());
        return result;
    }

    /** @return true iff wordsPath is the special value {@link #GENERATE} (case-insensitive; false if unset). */
    @Override
    public boolean generated() {
        // equalsIgnoreCase is null-safe on the argument and independent of the default locale.
        return GENERATE.equalsIgnoreCase(this.wordsPath);
    }

    /**
     * Generates the corpus, or loads it and keeps only words over the known
     * characters (excluding the boundary char) in rows with at least 2 known entries.
     */
    private Corpus createRawCorpus() {
        if (this.generated()) {
            return this.generateCorpus();
        }
        try {
            Corpus result = this.loadRawCorpus();
            int initialSize = result.allWords().size();
            // Copy: the set returned by allChars() may be shared with (or owned by) the encodings.
            Set<Character> allChars = new HashSet<Character>(this.encodings.allChars());
            allChars.remove(Character.valueOf(this.encodings.boundChar()));
            result = Corpus.restrictToKnownCharacters(result, allChars);
            result = Corpus.restrictToEntriesWithAtLeastNKnownEntries(result, 2);
            LogInfo.logs("Number of words before and after rstrctToKnwnChrs(), ToLstNKnwn(): " + initialSize + " -> " + result.allWords().size());
            return result;
        }
        catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /** Writes the corpus to "corpus.txt" in the execution directory (loads data if needed). */
    public void saveCorpusToExec() {
        this.ensureLoaded();
        PrintWriter out = IOUtils.openOutHard(Utils.safeGetExecFilePath("corpus.txt"));
        try {
            out.append(this.rawCorpus.toString());
        } finally {
            out.close();
        }
    }

    /** @return true iff a heldout language is named and heldoutProp is positive. */
    @Override
    public boolean hasHeldout() {
        return this.heldoutLang != null && !this.heldoutLang.equals("") && this.heldoutProp > 0.0;
    }

    /** @return the labels of {@code topo} in preorder. */
    public static <T> List<T> preorder(Tree<T> topo) {
        List<Tree<T>> langTreeNames = topo.getPreOrderTraversal();
        ArrayList<T> langNames = new ArrayList<T>();
        for (Tree<T> tree : langTreeNames) {
            langNames.add(tree.getLabel());
        }
        return langNames;
    }

    /** @return the words at the nodes of {@code topo.root()}, in the order given by {@code nodes()}. */
    public static List<String> preorder(Arbre<DerivationTree.DerivationNode> topo) {
        ArrayList<String> result = new ArrayList<String>();
        for (Arbre<DerivationTree.DerivationNode> subtree : topo.root().nodes()) {
            result.add(subtree.getContents().getWord());
        }
        return result;
    }

    /**
     * @return a copy of the cognates as generated
     * @throws RuntimeException if the data is not generated
     */
    @Override
    public CognateSet getGeneratingCognateSet() {
        if (!this.generated()) {
            throw new RuntimeException("generated() not defined");
        }
        this.ensureLoaded();
        return this.generatingCognateSet.copy();
    }

    /**
     * Samples maxNumberOfWords derivations from the generating params. It also
     * records them in generatingCognateSet. It returns a corpus over all
     * topology nodes, optionally restricted to the leaves plus the heldout language.
     */
    private Corpus generateCorpus() {
        this.generatingCognateSet = new CognateSet();
        if (this.maxNumberOfWords == Integer.MAX_VALUE) {
            throw new RuntimeException("Should set maxNumberOfWords when generating");
        }
        ArrayList<List<String>> generatedWords = new ArrayList<List<String>>();
        for (int i = 0; i < this.maxNumberOfWords; ++i) {
            Arbre<DerivationTree.DerivationNode> current = this.params.generate(this.rawTopology, this.dataGenerationRand);
            generatedWords.add(DataLoader.preorder(current));
            this.generatingCognateSet.addCognate(new CognateId("Generated-" + i), current, this.removeInternalGeneratedNodes ? ObservationsTracker.modernObservationsTracker(current) : ObservationsTracker.allObservationsTracker(current));
        }
        Corpus result = new Corpus(DataLoader.preorder(this.rawTopology), generatedWords);
        if (this.removeInternalGeneratedNodes) {
            HashSet<String> set = new HashSet<String>();
            set.addAll(this.rawTopology.getYield());
            if (this.hasHeldout()) {
                set.add(this.heldoutLang);
            }
            result = Corpus.restrict(result, set);
        }
        return result;
    }

    private HLParams loadParams() {
        this.generationParamLoader.setLanguages(this.allLanguages());
        return this.generationParamLoader.getParams();
    }

    /** @return a taxon for every node of the topology (requires the topology to be read). */
    public Set<Taxon> allLanguages() {
        return DataLoader.allLanguages(this.rawTopology);
    }

    private static Set<Taxon> allLanguages(Tree<String> topo) {
        HashSet<Taxon> result = new HashSet<Taxon>();
        for (Tree<String> node : topo.getPreOrderTraversal()) {
            result.add(new Taxon(node.getLabel()));
        }
        return result;
    }

    /** Parses wordsPath and restricts it to the topology's nodes and, if set, langsToRestrict. */
    private Corpus loadRawCorpus() throws IOException {
        Corpus result = Corpus.parse(this.wordsPath);
        result = Corpus.restrict(result, new HashSet<String>(PrepareEMData.nodes(this.rawTopology)));
        if (this.langsToRestrict != null) {
            result = Corpus.restrict(result, this.langsToRestrict);
        }
        return result;
    }

    /** Parses a lisp-format topology; a spec of the form "!xyz" is read from file xyz. */
    public static Tree<String> createRawTopology(String topoSpec) {
        String topologyString = DataPrepUtils.optionallyLoad(topoSpec);
        return DataPrepUtils.lisp2tree(topologyString);
    }

    @Override
    public boolean hasReferenceAlignments() {
        return false;
    }

    /** Not supported: this loader never provides reference alignments. */
    @Override
    public Map<CognateId, MultiAlignment> referenceAlignments() {
        throw new UnsupportedOperationException("DataLoader has no reference alignments (see hasReferenceAlignments())");
    }

    /** A held-out word: the cognate it came from, its language and the true word form. */
    public static class HeldoutEntry
    implements Serializable {
        private static final long serialVersionUID = 1L;
        public final CognateId id;
        public final Taxon node;
        public final String trueReconstruction;

        public HeldoutEntry(CognateId id, Taxon node, String trueReconstruction) {
            this.id = id;
            this.trueReconstruction = trueReconstruction;
            this.node = node;
        }
    }
}

