/*
 * Decompiled with CFR 0.152.
 */
package gep;

import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.StrUtils;
import fig.exec.Execution;
import gep.comparisons.EMForCTMC;
import gep.comparisons.FlatReconstructionMethod;
import gep.comparisons.HeldoutReconstructionMethod;
import gep.comparisons.IncrementalReconstructionMethod;
import gep.data.DataSource;
import gep.data.GeneratedData;
import gep.data.LoadFlatDataFormat;
import gep.data.MSData;
import gep.data.RnaDataset;
import gep.pmcmc.PMCMC;
import gep.timeseries.Measurements;
import gep.timeseries.Series;
import gep.util.OutputManager;
import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import nuts.io.CSV;
import nuts.io.IO;
import nuts.lang.StringUtils;
import nuts.util.CollUtils;
import nuts.util.Indexer;
import pty.smc.ParticleFilter;

/**
 * Command-line entry point for the {@code gep} reconstruction experiments.
 *
 * <p>{@link #run()} reads up to {@link #maxNSeries} time series from the configured
 * {@link DataSource}, holds out some points of each series (at random, or as listed in the file
 * given by {@link #pathToHeldoutSpec}), scores a flat baseline on the held-out points, then runs
 * the selected {@link IncrementalReconstructionMethod} for {@link #maxMCMCIters} iterations and
 * records the held-out error rate after each iteration in the {@code errors} output.
 *
 * <p>The static collaborators (data sources, PMCMC, EM) are only initialized in {@link #main};
 * {@link #run()} dereferences them.
 */
public class GEPMain implements Runnable {
    /** Maximum number of series read from the data source. */
    @Option
    public int maxNSeries = 100;
    /** Selects which data source {@link #dataSource()} returns. */
    @Option
    public DataType dataType = DataType.GENERATED;
    /** Per-point probability of holding a point out; only used when no heldout spec file is given. */
    @Option
    public double heldOutPr = 0.1;
    /** Randomness for the per-point holdout decisions; unused when a heldout spec file is given. */
    @Option
    public Random heldoutRandom = new Random(1L);
    /** Randomness handed to the reconstruction method's {@code init} and {@code iterate}. */
    @Option
    public Random samplingRand = new Random(1L);
    /** Number of iterations the reconstruction method is run for. */
    @Option
    public int maxMCMCIters = 2500;
    /** When true, per-point reconstructions are written as CSV files under {@code reconstructions}. */
    @Option
    public boolean saveReconstruction = false;
    /** Reconstruction method to run. */
    @Option
    public Method method = Method.SHGEP;
    /**
     * Path to a heldout spec file in the format written to {@code heldout.txt}; an empty string
     * (or null) means points are held out at random instead.
     */
    @Option
    public String pathToHeldoutSpec = "";
    public static OutputManager outputManager = new OutputManager();
    /** Maps observation strings to integer indices; set by {@link #run()}. */
    public static Indexer<String> observedIndexer = null;
    public static GeneratedData generatedData = null;
    public static RnaDataset rnaData = null;
    private static MSData msData;
    private static LoadFlatDataFormat loadPlainTxtData;
    private static PMCMC pmcmc;
    private static EMForCTMC em;
    /** Held-out ground truth, one entry per loaded series. */
    private List<Measurements> allAnswers = CollUtils.list();
    /** Non-held-out observations, one entry per loaded series. */
    private List<Measurements> rawObservations = CollUtils.list();
    private IncrementalReconstructionMethod system;
    /** Held-out times per series index, parsed from the heldout spec file (null when not used). */
    private List<Set<Double>> loadedHeldoutTimes = null;

    /** Returns the number of distinct observation symbols; requires {@link #observedIndexer} to be set. */
    public static int nChars() {
        return observedIndexer.size();
    }

    /** Creates the static collaborators and hands them to the option parser before running. */
    public static void main(String[] args) {
        pmcmc = new PMCMC();
        em = new EMForCTMC();
        PMCMC.pf = new ParticleFilter();
        generatedData = new GeneratedData();
        rnaData = new RnaDataset();
        msData = new MSData();
        loadPlainTxtData = new LoadFlatDataFormat();
        IO.run(args, new GEPMain(), "pmcmc", pmcmc, "em", em, "pf", PMCMC.pf, "gen", generatedData, "rna", rnaData, "msdata", msData, "paintxt", loadPlainTxtData);
    }

    @Override
    public void run() {
        if (this.method == Method.SHGEP) {
            this.system = pmcmc;
        } else if (this.method == Method.EM) {
            this.system = em;
        } else {
            throw new RuntimeException("Unsupported reconstruction method: " + this.method);
        }
        File reconstructionsDir = new File(Execution.getFile("reconstructions"));
        reconstructionsDir.mkdir();
        // Sorting makes the symbol -> index mapping independent of the data source's ordering.
        ArrayList<String> observations = CollUtils.list(this.dataSource().possibleObservations());
        Collections.sort(observations);
        observedIndexer = new Indexer<String>(observations);
        this.initData();
        double baseLineError = this.heldoutError(reconstructionsDir, "FLAT_MODEL", FlatReconstructionMethod.mleFlatRecon(this.rawObservations));
        outputManager.printWrite("errors", "Method", "FLAT_MODEL", "Iter", 0, "Error", baseLineError);
        this.system.init(this.samplingRand);
        for (int mcmcIter = 0; mcmcIter < this.maxMCMCIters; ++mcmcIter) {
            LogInfo.track((Object) ("Iteration " + mcmcIter + "/" + this.maxMCMCIters), true);
            try {
                this.system.iterate(this.samplingRand, mcmcIter);
                double curError = this.heldoutError(reconstructionsDir, this.method + "-" + mcmcIter, this.system);
                outputManager.printWrite("errors", new Object[]{"Method", this.method, "Iter", mcmcIter, "Error", curError});
            } finally {
                // Keep LogInfo's track nesting balanced even if an iteration fails.
                LogInfo.end_track();
            }
        }
        outputManager.close();
    }

    /**
     * Scores {@code method} on the held-out points; when {@link #saveReconstruction} is set, also
     * writes the per-point reconstruction to {@code reconstructionsDir/fileName}.
     * The file, if opened, is always closed, even when scoring throws.
     */
    private double heldoutError(File reconstructionsDir, String fileName, HeldoutReconstructionMethod method) {
        PrintWriter out = this.saveReconstruction ? IOUtils.openOutHard(new File(reconstructionsDir, fileName)) : null;
        try {
            return GEPMain.reconstructionError(out, method, this.allAnswers);
        } finally {
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Computes the fraction of held-out points that {@code method} reconstructs incorrectly.
     *
     * @param out if non-null, receives a CSV header and one row per held-out point
     *     (series index, query index, true value index, reconstructed value index); not closed here
     * @param method reconstruction to score, queried as {@code method.reconstruct(seriesIdx, queryIdx)}
     * @param allAnswers held-out ground truth, one {@link Measurements} per series
     * @return wrong reconstructions divided by held-out points (NaN when there are none, as 0/0)
     */
    public static double reconstructionError(PrintWriter out, HeldoutReconstructionMethod method, List<Measurements> allAnswers) {
        if (out != null) {
            out.println(CSV.header("seriesIndex", "queryPointIndex", "truthIndex", "reconstructedIndex"));
        }
        double nEval = 0.0;
        double nErrors = 0.0;
        for (int sIdx = 0; sIdx < allAnswers.size(); ++sIdx) {
            Measurements curAns = allAnswers.get(sIdx);
            nEval += (double) curAns.size();
            for (int qIdx = 0; qIdx < curAns.size(); ++qIdx) {
                int answer = curAns.getValue(qIdx);
                int guess = method.reconstruct(sIdx, qIdx);
                if (answer != guess) {
                    nErrors += 1.0;
                }
                if (out != null) {
                    out.println(CSV.body(sIdx, qIdx, answer, guess));
                }
            }
        }
        return nErrors / nEval;
    }

    /**
     * Loads the series, splits each into observed and held-out points, saves both splits to
     * {@code data.txt} / {@code heldout.txt}, and hands the observed part to the reconstruction method.
     * Both output files are closed even if loading fails.
     */
    private void initData() {
        this.loadHeldout();
        PrintWriter output = this.openOut();
        try {
            PrintWriter heldoutOutput = this.heldoutOpenOut();
            try {
                this.loadSeries(output, heldoutOutput);
            } finally {
                heldoutOutput.close();
            }
        } finally {
            output.close();
        }
    }

    /** Reads, splits and saves every series, then passes the observed parts to {@link #system}. */
    private void loadSeries(PrintWriter output, PrintWriter heldoutOutput) {
        int nSeq = 0;
        int nDat = 0;
        int nHeld = 0;
        ArrayList<Series> allObservations = CollUtils.list();
        Measurements currentMeasurement;
        for (int i = 0; i < this.maxNSeries && (currentMeasurement = this.dataSource().next(observedIndexer)) != null; ++i) {
            ArrayList<Double> times = CollUtils.list();
            ArrayList<Integer> values = CollUtils.list();
            ArrayList<Double> queryTimes = CollUtils.list();
            ArrayList<Integer> answers = CollUtils.list();
            this.holdout(currentMeasurement, times, values, queryTimes, answers, i);
            this.saveData(output, heldoutOutput, i, currentMeasurement, queryTimes);
            ++nSeq;
            nDat += currentMeasurement.size();
            nHeld += queryTimes.size();
            Measurements currentO = new Measurements(times, values);
            this.rawObservations.add(currentO);
            allObservations.add(new Series(currentO, queryTimes));
            this.allAnswers.add(new Measurements(queryTimes, answers));
        }
        this.system.loadData(allObservations, GEPMain.nChars());
        LogInfo.logsForce("nSequencesLoaded: " + nSeq);
        LogInfo.logsForce("nDatapointsLoaded: " + nDat);
        LogInfo.logsForce("nPointsHeldout: " + nHeld);
    }

    /** Returns whether held-out points come from a spec file rather than random draws. */
    public boolean usingLoadedHeldoutSpec() {
        return this.pathToHeldoutSpec != null && !this.pathToHeldoutSpec.equals("");
    }

    /**
     * Prepares the holdout strategy: parses the spec file into {@link #loadedHeldoutTimes} when one
     * is configured, otherwise only logs the random-holdout settings.
     */
    private void loadHeldout() {
        if (!this.usingLoadedHeldoutSpec()) {
            // NOTE(review): nextLong() draws from heldoutRandom, so the logged number is the first
            // value of that stream rather than its seed, and the draw shifts every later holdout
            // decision. Kept unchanged so held-out sets stay identical to earlier runs.
            LogInfo.logs("Holding out with probability " + this.heldOutPr + " and using random seed " + this.heldoutRandom.nextLong());
            return;
        }
        LogInfo.logs("Holding out using loaded file:" + this.pathToHeldoutSpec);
        this.loadedHeldoutTimes = CollUtils.list();
        int index = 0;
        for (String line : IO.i(this.pathToHeldoutSpec)) {
            if (LoadFlatDataFormat.IGNORE_LINE_MATCH.matcher(line).matches()) continue;
            this.loadedHeldoutTimes.add(GEPMain.parseHeldoutLine(line, index));
            ++index;
        }
    }

    /**
     * Parses one spec line of the form {@code Sequence <n>: <time> <time> ...} (times optional).
     * Whitespace around the index and the time list is tolerated.
     *
     * @param expectedSeqIdx sequences must be listed in order starting at 0
     * @throws RuntimeException on a malformed line or an out-of-order sequence index
     */
    private static Set<Double> parseHeldoutLine(String line, int expectedSeqIdx) {
        String[] mainSplit = line.split("[:]");
        if (mainSplit.length != 1 && mainSplit.length != 2) {
            throw new RuntimeException("Malformed heldout spec line (expected 'Sequence [n]: times'): " + line);
        }
        String seqIdxString = StringUtils.selectFirstRegex("Sequence\\s+([0-9]+)$", mainSplit[0].trim());
        if (seqIdxString == null) {
            throw new RuntimeException("Malformed heldout spec line (expected 'Sequence [n]: times'): " + line);
        }
        int specifiedSeqIdx = Integer.parseInt(seqIdxString);
        if (specifiedSeqIdx != expectedSeqIdx) {
            throw new RuntimeException("Heldout spec lists sequence " + specifiedSeqIdx + " where sequence " + expectedSeqIdx + " was expected (sequences must be listed in order from 0): " + line);
        }
        HashSet<Double> times = CollUtils.set();
        // Trim first: splitting " 1.0 2.0" on \s+ would yield a leading "" that parseDouble rejects.
        String timesField = mainSplit.length > 1 ? mainSplit[1].trim() : "";
        if (timesField.length() > 0) {
            for (String field : timesField.split("\\s+")) {
                times.add(Double.parseDouble(field));
            }
        }
        return times;
    }

    /** Writes series {@code i} to the data file and its held-out times to the heldout file. */
    private void saveData(PrintWriter output, PrintWriter heldoutOutput, int i, Measurements currentMeasurement, List<Double> queryTimes) {
        output.println("# Sequence " + i);
        output.println(StrUtils.join(currentMeasurement.getTimes(), " "));
        for (int value : currentMeasurement.getValues()) {
            output.append(observedIndexer.i2o(value) + " ");
        }
        output.println();
        heldoutOutput.println("Sequence " + i + ":" + StrUtils.join(queryTimes, " "));
    }

    /** Opens {@code data.txt} in the execution directory and writes its format header. */
    private PrintWriter openOut() {
        PrintWriter out = IOUtils.openOutHard(Execution.getFile("data.txt"));
        out.println("# Time series represented by space-separated line doublets:");
        out.println("#   Line 1: The times of the measurements");
        out.println("#   Line 2: The states at these times (any string that does not contain space/tabs)");
        out.println("# Line starting by # or empty lines are ignored");
        out.println();
        return out;
    }

    /** Opens {@code heldout.txt} in the execution directory and writes its format header. */
    private PrintWriter heldoutOpenOut() {
        PrintWriter out = IOUtils.openOutHard(Execution.getFile("heldout.txt"));
        out.println("# Heldout points for each series:");
        out.println("#   Format for each line: Sequence [n]: a space separated list of times, exactly as shown in the series file (NOT indices) heldout for sequence [n]");
        out.println("# Line starting by # or empty lines are ignored");
        out.println("");
        return out;
    }

    /** Returns the static data source selected by {@link #dataType}. */
    private DataSource dataSource() {
        if (this.dataType == DataType.GENERATED) {
            return generatedData;
        }
        if (this.dataType == DataType.RNA) {
            return rnaData;
        }
        if (this.dataType == DataType.MS) {
            return msData;
        }
        if (this.dataType == DataType.LOAD_PLAIN_TXT) {
            return loadPlainTxtData;
        }
        throw new RuntimeException("Unsupported data type: " + this.dataType);
    }

    /**
     * Splits one series into observed points ({@code times}/{@code values}) and held-out points
     * ({@code queryTimes}/{@code answers}), appending to the given lists in series order.
     */
    private void holdout(Measurements allMeasurements, List<Double> times, List<Integer> values, List<Double> queryTimes, List<Integer> answers, int seqIdx) {
        if (this.usingLoadedHeldoutSpec()) {
            this.checkHeldoutSpecCovers(allMeasurements, seqIdx);
        }
        for (int currentSqIdx = 0; currentSqIdx < allMeasurements.size(); ++currentSqIdx) {
            int currentObsState = allMeasurements.getValue(currentSqIdx);
            double currentTime = allMeasurements.getTime(currentSqIdx);
            if (this.shouldHoldOut(currentSqIdx, allMeasurements.size(), currentTime, seqIdx)) {
                queryTimes.add(currentTime);
                answers.add(currentObsState);
            } else {
                times.add(currentTime);
                values.add(currentObsState);
            }
        }
    }

    /**
     * Verifies that the spec file has an entry for {@code seqIdx} and that every time it lists
     * occurs in the series.
     */
    private void checkHeldoutSpecCovers(Measurements allMeasurements, int seqIdx) {
        if (seqIdx >= this.loadedHeldoutTimes.size()) {
            throw new RuntimeException("Heldout spec " + this.pathToHeldoutSpec + " has no entry for sequence " + seqIdx + " (it lists " + this.loadedHeldoutTimes.size() + " sequences)");
        }
        HashSet<Double> allTimes = CollUtils.set(allMeasurements.getTimes());
        HashSet<Double> missing = new HashSet<Double>(this.loadedHeldoutTimes.get(seqIdx));
        missing.removeAll(allTimes);
        if (!missing.isEmpty()) {
            throw new RuntimeException("Heldout spec for sequence " + seqIdx + " lists times not present in the data: " + missing);
        }
    }

    /**
     * Decides whether the point at {@code currentSqIdx} is held out. The last point of a series is
     * never held out. In random mode one draw is consumed from {@link #heldoutRandom} per point,
     * including the last one.
     */
    private boolean shouldHoldOut(int currentSqIdx, int nMeas, double currentTime, int seqIdx) {
        boolean isLast = currentSqIdx == nMeas - 1;
        if (this.usingLoadedHeldoutSpec()) {
            boolean isInSet = this.loadedHeldoutTimes.get(seqIdx).contains(currentTime);
            if (isLast && isInSet) {
                LogInfo.warning("Currently, holding out the last item in the sequence is not supported for implementation reasons.  Ignoring held out time " + currentTime + " in sequence index " + seqIdx);
                return false;
            }
            return isInSet;
        }
        return this.heldoutRandom.nextDouble() < this.heldOutPr && !isLast;
    }

    /** Available data sources; see {@link #dataSource()}. */
    public static enum DataType {
        GENERATED,
        RNA,
        MS,
        LOAD_PLAIN_TXT;

    }

    /** Available reconstruction methods: {@code SHGEP} runs PMCMC, {@code EM} runs EMForCTMC. */
    public static enum Method {
        SHGEP,
        EM;

    }
}

