/*
 * Decompiled with CFR 0.152.
 */
package goblin;

import fig.basic.LogInfo;
import fig.basic.Option;
import fig.prob.SampleUtils;
import goblin.CognateId;
import goblin.CognateSet;
import goblin.DataLoader;
import goblin.DataLoaderInterface;
import goblin.DataPrepUtils;
import goblin.DerivationTree;
import goblin.HLParams;
import goblin.ObservationsTracker;
import goblin.Taxon;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import ma.BalibaseCorpus;
import ma.MultiAlignment;
import ma.SequenceType;
import ma.YAM;
import nuts.util.Arbre;
import pepper.Encodings;

/**
 * Adapts a {@link BalibaseCorpus} of multiple sequence alignments to the
 * {@link DataLoaderInterface} contract.
 *
 * <p>Usage: call {@link #setBioCorpusOptions} first, then {@link #load()} (or
 * {@link #getCognateSet()}, which loads lazily). Accessors that delegate to the
 * corpus throw {@link IllegalStateException} until {@link #load()} has run.
 */
public class BioDataLoaderAdaptor
implements DataLoaderInterface {
    /** Branch-length discretization step passed to {@code YAM.hobGoblinTransform} when {@link #discretizeTree} is set. */
    @Option
    public double branchDiscretLength = 0.2;
    /** Source of the random order in which alignments are visited by {@link #load()}. */
    @Option
    public Random permutationRandom = new Random(1L);
    /** Maximum number of included alignments that {@link #load()} adds to the cognate set. */
    @Option
    public int maxNumberOfMSA = Integer.MAX_VALUE;
    /** When non-empty, only cognate ids whose {@code toString()} appears in this list are loaded. */
    @Option
    public ArrayList<String> restrict = new ArrayList<String>();
    /** Whether to apply {@code YAM.hobGoblinTransform} to each tree. */
    @Option
    public boolean discretizeTree = true;
    /** Whether to replace each tree by its root with copies of its leaves as direct children. */
    @Option
    public boolean flattenTree = false;
    @Option(gloss="used only when holding seqns out")
    public boolean stemTree = true;
    /** Whether to hold out one sequence per alignment (see {@link #languageToHoldout}). */
    @Option
    public boolean holdoutSequences = false;
    private BalibaseCorpus bc = null;
    private BalibaseCorpus.BalibaseCorpusOptions opt;
    private List<DataLoader.HeldoutEntry> heldout = new ArrayList<DataLoader.HeldoutEntry>();
    private CognateSet cognateSet = null;

    /** Returns the sequence type reported by the loaded corpus. */
    public SequenceType getSequenceType() {
        return this.corpus().getType();
    }

    /**
     * Sets the options used to build the corpus in {@link #load()}.
     *
     * @throws IllegalStateException if the corpus has already been loaded
     */
    public void setBioCorpusOptions(BalibaseCorpus.BalibaseCorpusOptions opt) {
        if (this.bc != null) {
            throw new IllegalStateException("Corpus options cannot be changed after load()");
        }
        this.opt = opt;
    }

    /** Returns the corpus branch lengths for the given alignment. */
    public Map<Taxon, Double> getBranchLengths(CognateId id) {
        return this.corpus().getBranchLengths(id);
    }

    @Override
    public boolean generated() {
        return false;
    }

    /** Returns the loaded cognate set, calling {@link #load()} first if nothing has been loaded yet. */
    @Override
    public CognateSet getCognateSet() {
        if (this.cognateSet == null) {
            this.load();
        }
        return this.cognateSet;
    }

    /** Registers the encodings of the configured sequence type globally and logs their size. */
    private void loadEncodings() {
        Encodings result = this.opt.sequenceType.getEncodings();
        Encodings.registerEncodings(result);
        LogInfo.logs("Encodings size: " + result.getNumberOfPhonemes());
    }

    /** Returns a read-only view of the heldout entries produced by the last {@link #load()}. */
    @Override
    public List<DataLoader.HeldoutEntry> getHeldout() {
        return Collections.unmodifiableList(this.heldout);
    }

    @Override
    public boolean hasHeldout() {
        return !this.heldout.isEmpty();
    }

    /**
     * Builds the corpus and the cognate set.
     *
     * <p>The corpus ids are visited in a random permutation drawn from
     * {@link #permutationRandom}. Ids rejected by {@link #isIncluded} are skipped, and
     * loading stops once {@link #maxNumberOfMSA} included alignments have been added.
     * When {@link #holdoutSequences} is set, one sequence per alignment is hidden: it is
     * removed from the observations, recorded as a heldout entry, and removed from the
     * corpus alignment.
     *
     * @throws IllegalStateException if {@link #setBioCorpusOptions} was never called
     */
    @Override
    public void load() {
        if (this.opt == null) {
            throw new IllegalStateException("setBioCorpusOptions() must be called before load()");
        }
        this.loadEncodings();
        this.bc = new BalibaseCorpus(this.opt);
        // Build into locals and publish at the end: a failed load leaves no half-built state,
        // and a repeated load does not accumulate duplicate heldout entries.
        CognateSet loaded = new CognateSet();
        List<DataLoader.HeldoutEntry> loadedHeldout = new ArrayList<DataLoader.HeldoutEntry>();
        List<CognateId> ids = new ArrayList<CognateId>(this.bc.intersectedIds());
        int currentWord = 0;
        for (int i : SampleUtils.samplePermutation(this.permutationRandom, ids.size())) {
            CognateId id = ids.get(i);
            if (!this.isIncluded(id)) continue;
            if (++currentWord > this.maxNumberOfMSA) break;
            Arbre<DerivationTree.DerivationNode> current = this.buildTree(id);
            ObservationsTracker obs = ObservationsTracker.modernObservationsTracker(current);
            if (this.holdoutSequences) {
                Taxon toHeldout = this.languageToHoldout(id);
                HashSet<Taxon> langs = new HashSet<Taxon>(obs.observedLanguages());
                langs.remove(toHeldout);
                obs = new ObservationsTracker(langs);
                // Capture the true word before forgetUnobserved mutates the tree.
                String truth = DerivationTree.findNodeByLangName(current, toHeldout).getContents().getWord();
                DataPrepUtils.forgetUnobserved(current, obs);
                if (this.stemTree) {
                    DataPrepUtils.stem(current, toHeldout);
                }
                loadedHeldout.add(new DataLoader.HeldoutEntry(id, toHeldout, truth));
                this.bc.restricMultiAlignment(id, langs);
            }
            loaded.addCognate(id, current, obs);
        }
        this.heldout = loadedHeldout;
        this.cognateSet = loaded;
    }

    /** Converts the corpus topology and sequences of {@code id} into a tree, then applies the configured transforms. */
    private Arbre<DerivationTree.DerivationNode> buildTree(CognateId id) {
        MultiAlignment ma = this.bc.getMultiAlignment(id);
        Arbre<DerivationTree.DerivationNode> tree = DataPrepUtils.tree2arbre2(this.bc.getTopology(id), ma.getSequences());
        if (this.discretizeTree) {
            tree = YAM.hobGoblinTransform(tree, this.branchDiscretLength, this.bc.getBranchLengths(id));
        }
        if (this.flattenTree) {
            tree = this.flattenTree(tree);
        }
        return tree;
    }

    /**
     * Picks the language to hold out for {@code id}.
     *
     * <p>The choice is restricted to languages that have a branch length and also a
     * sequence in the alignment; among these, the one with the strictly smallest finite
     * branch length is returned.
     *
     * @throws IllegalStateException if no language qualifies
     */
    private Taxon languageToHoldout(CognateId id) {
        Map<Taxon, Double> branchLengths = this.bc.getBranchLengths(id);
        Map<?, ?> sequences = this.bc.getMultiAlignment(id).getSequences();
        Taxon argmin = null;
        double minBranchLength = Double.POSITIVE_INFINITY;
        for (Map.Entry<Taxon, Double> entry : branchLengths.entrySet()) {
            Taxon lang = entry.getKey();
            if (!sequences.containsKey(lang)) continue;
            double length = entry.getValue();
            if (length < minBranchLength) {
                argmin = lang;
                minBranchLength = length;
            }
        }
        if (argmin == null) {
            throw new IllegalStateException("No language with a sequence and a finite branch length to hold out for " + id);
        }
        return argmin;
    }

    /** Returns a new tree whose root holds the original root's contents and whose children are copies of the original leaves. */
    private Arbre<DerivationTree.DerivationNode> flattenTree(Arbre<DerivationTree.DerivationNode> current) {
        Arbre<DerivationTree.DerivationNode> root = current.root();
        List<Arbre<DerivationTree.DerivationNode>> children = new ArrayList<Arbre<DerivationTree.DerivationNode>>();
        for (Arbre<DerivationTree.DerivationNode> subtree : root.nodes()) {
            if (!subtree.isLeaf()) continue;
            DerivationTree.DerivationNode leaf = subtree.getContents();
            children.add(Arbre.arbre(new DerivationTree.DerivationNode(leaf.getLanguage(), leaf.getWord())));
        }
        return Arbre.arbre(root.getContents(), children);
    }

    /** Returns true when {@link #restrict} is empty or lists {@code cognateId.toString()}. */
    private boolean isIncluded(CognateId cognateId) {
        return this.restrict.isEmpty() || this.restrict.contains(cognateId.toString());
    }

    @Override
    public List<Random> randomness() {
        return Collections.singletonList(this.permutationRandom);
    }

    @Override
    public boolean hasReferenceAlignments() {
        return true;
    }

    /** Returns the corpus reference alignments; requires {@link #load()} to have run. */
    @Override
    public Map<CognateId, MultiAlignment> referenceAlignments() {
        return this.corpus().getMultiAlignments();
    }

    /** Not supported: this loader reads real data rather than generating it. */
    @Override
    public CognateSet getGeneratingCognateSet() {
        throw new UnsupportedOperationException("BioDataLoaderAdaptor data is not generated");
    }

    /** Not supported: this loader reads real data rather than generating it. */
    @Override
    public HLParams getGeneratingParams() {
        throw new UnsupportedOperationException("BioDataLoaderAdaptor data is not generated");
    }

    /** Returns the ClustalW alignments from the corpus; requires {@link #load()} to have run. */
    public Map<CognateId, MultiAlignment> getClustalwAlignments() {
        return this.corpus().getClustalwAlignments();
    }

    /** Returns the loaded corpus, failing with a clear message if {@link #load()} has not run. */
    private BalibaseCorpus corpus() {
        if (this.bc == null) {
            throw new IllegalStateException("load() must be called before accessing the corpus");
        }
        return this.bc;
    }
}

