/*
 * Decompiled with CFR 0.152.
 */
package gep.data;

import fig.basic.LogInfo;
import fig.basic.Option;
import gep.data.DataSource;
import gep.timeseries.Measurements;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import ma.MSAParser;
import ma.MSAPoset;
import ma.RateMatrixLoader;
import ma.SequenceType;
import nuts.math.GMFct;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Indexer;
import pepper.Encodings;
import pty.RootedTree;
import pty.io.Dataset;
import pty.learn.DiscreteBP;
import pty.smc.models.CTMC;

/**
 * {@link DataSource} that turns an RNA alignment plus a rooted tree into discrete-state
 * time series.
 *
 * <p>For every alignment column (site) and every leaf of the tree, one {@link Measurements}
 * series is built. The walk starts at the leaf and moves towards the root. At each node it
 * records the state with the largest value in the per-site {@code GMFct} returned by
 * {@link DiscreteBP#posteriorMarginalTransitions}. Series with fewer than {@link #minNTrans}
 * state changes are discarded. Loading happens lazily on the first call to {@link #next}.
 *
 * <p>NOTE(review): instances keep a mutable read cursor and do no synchronization.
 */
public class RnaDataset implements DataSource {
    /** Path of the alignment file (parsed by {@link MSAParser} and {@code Dataset.DatasetUtils}). */
    @Option
    public String alignmentFile;
    /** Path of the tree file (loaded by {@code RootedTree.Util.load}). */
    @Option
    public String treeFile;
    /** Minimum number of state changes a series needs in order to be kept. */
    @Option
    public int minNTrans = 2;
    /** Loaded series; {@code null} until {@link #ensureLoaded()} has completed successfully. */
    private List<Measurements> _data = null;
    /** Index of the series the next call to {@link #next} will return. */
    private int pointer = 0;

    /**
     * Loads the tree and the alignment and builds every series, unless that was already done.
     *
     * <p>The result is published to {@code _data} only after the whole build succeeds. A
     * failure part-way through therefore leaves the instance unloaded, and the next call
     * retries instead of silently serving a partial list.
     */
    private void ensureLoaded() {
        if (this._data != null) {
            return;
        }
        RootedTree tree = RootedTree.Util.load(new File(this.treeFile));
        // The alignment is read twice: once as an MSAPoset (only used here for the column
        // count) and once as the Dataset consumed by DiscreteBP.
        MSAPoset msa = MSAParser.parseMSA(new File(this.alignmentFile));
        Dataset observations = Dataset.DatasetUtils.fromAlignment(new File(this.alignmentFile), SequenceType.RNA);
        int nSites = msa.columns().size();
        // NOTE(review): k2p() presumably is the Kimura two-parameter rate matrix — confirm.
        CTMC.SimpleCTMC ctmc = new CTMC.SimpleCTMC(RateMatrixLoader.k2p(), nSites);
        List<Measurements> loaded = CollUtils.list();
        LogInfo.track("Loading RNA data");
        try {
            for (int s = 0; s < nSites; ++s) {
                LogInfo.logs("Site " + (s + 1) + "/" + nSites);
                GMFct<Taxon> post = DiscreteBP.posteriorMarginalTransitions(tree, ctmc, observations, s);
                for (Arbre<Taxon> leaf : tree.topology().leaves()) {
                    Measurements current = this.createTimeSeries(leaf, post, tree);
                    if (this.nTrans(current) >= this.minNTrans) {
                        loaded.add(current);
                    }
                }
            }
        } finally {
            // Close the log section even if loading fails.
            LogInfo.end_track();
        }
        this._data = loaded;
    }

    /**
     * Counts the state changes in a series, i.e. the number of consecutive index pairs
     * {@code (i, i + 1)} whose values differ.
     *
     * @param current the series to inspect
     * @return the number of changes; 0 for a series with fewer than two points
     */
    private int nTrans(Measurements current) {
        int result = 0;
        for (int i = 0; i + 1 < current.size(); ++i) {
            // NOTE(review): '!=' is value comparison only if getValue returns a primitive;
            // if it returns a boxed type this compares references — confirm.
            if (current.getValue(i) != current.getValue(i + 1)) {
                ++result;
            }
        }
        return result;
    }

    /**
     * Builds the series observed along the path from {@code node} up towards the root.
     *
     * <p>The first point is {@code node} itself at time 0. Each following point is the parent
     * of the previous one, at a time increased by the branch length of the node just left.
     * The root itself is never recorded, so a child of the root yields a single-point series.
     *
     * <p>NOTE(review): {@code node} must not be the root; the loop dereferences its parent.
     *
     * @param node starting node (a leaf when called from {@link #ensureLoaded()})
     * @param post per-node state values for one site
     * @param tree tree supplying the branch lengths
     * @return the recorded (time, state) series
     */
    private Measurements createTimeSeries(Arbre<Taxon> node, GMFct<Taxon> post, RootedTree tree) {
        ArrayList<Double> times = CollUtils.list();
        ArrayList<Integer> states = CollUtils.list();
        double curTime = 0.0;
        do {
            int state = argmaxState(post, node.getContents());
            times.add(curTime);
            states.add(state);
            curTime += tree.branchLengths().get(node.getContents()).doubleValue();
        } while (!(node = node.getParent()).isRoot());
        return new Measurements(times, states);
    }

    /**
     * Returns the state with the largest value in {@code post} for {@code taxon}. On ties the
     * lowest state index wins. Returns -1 when no state has a value greater than negative
     * infinity (for example all values are NaN or -infinity, or there are no states).
     */
    private static int argmaxState(GMFct<Taxon> post, Taxon taxon) {
        int argmax = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int state = 0; state < post.nStates(taxon); ++state) {
            double cur = post.get(taxon, state);
            if (cur > max) {
                max = cur;
                argmax = state;
            }
        }
        return argmax;
    }

    /**
     * Returns the observation alphabet: one single-character string per non-gap character
     * of the RNA encoding.
     *
     * @return a new, mutable set on every call
     */
    @Override
    public Set<String> possibleObservations() {
        Indexer<Character> charIndex = Encodings.rnaEncodings().nonGapCharactersIndexer();
        HashSet<String> strIndex = CollUtils.set();
        for (int i = 0; i < charIndex.size(); ++i) {
            strIndex.add(String.valueOf(charIndex.i2o(i)));
        }
        return strIndex;
    }

    /**
     * Returns the next loaded series, loading the data on first use.
     *
     * @param indexer unused by this implementation
     * @return the next series, or {@code null} once every series has been returned
     */
    @Override
    public Measurements next(Indexer<String> indexer) {
        this.ensureLoaded();
        if (this.pointer >= this._data.size()) {
            // Exhausted: leave the cursor in place instead of incrementing it forever.
            return null;
        }
        return this._data.get(this.pointer++);
    }
}

