#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
namespace py = pybind11;
using namespace std;

#include "main.h"
#include "version.h"

PYBIND11_MODULE(dataset_builder_2,m) {
  m.def("encode", &encode);

  m.def("version", &version);

  m.def("getEncoderSize", &getEncoderSize);
  m.def("getEncoderType", &getEncoderType);
  m.def("getEncoder", &getEncoder);

  m.def("sgbf", &sgbf);

  m.def("check_type", &check_type);

  m.def("prune_tracks", &prune_tracks);
  m.def("prune_tracks_and_bars", &prune_tracks_and_bars);
  m.def("remove_track", &remove_track);
  m.def("remove_tracks", &remove_tracks);
  m.def("clear_track", &clear_track);
  m.def("add_track", &add_track);
  m.def("select_segment", &select_segment);
  m.def("update_segments", &update_segments);

  m.def("partial_copy_json", &partial_copy_json);

  py::class_<Jagged>(m, "Jagged")
    .def(py::init<string &>())
    .def("set_seed", &Jagged::set_seed)
    .def("set_num_bars", &Jagged::set_num_bars)
    .def("set_max_tracks", &Jagged::set_max_tracks)
    .def("set_max_seq_len", &Jagged::set_max_seq_len)
    .def("enable_write", &Jagged::enable_write)
    .def("enable_read", &Jagged::enable_read)
    .def("append", &Jagged::append)
    .def("read", &Jagged::read)
    .def("read_bytes", &Jagged::read_bytes)
    .def("read_json", &Jagged::read_json)
    .def("read_batch", &Jagged::read_batch)
    .def("read_batch_w_continue", &Jagged::read_batch_w_continue)
    .def("close", &Jagged::close)
    .def("get_size", &Jagged::get_size)
    .def("get_split_size", &Jagged::get_split_size);

  py::class_<TOKEN>(m, "TOKEN")
    .def(py::init<vector<pair<TOKEN_TYPE,int>>>())
    .def("encode", &TOKEN::encode)
    .def("decode", &TOKEN::decode)
    .def("shift", &TOKEN::shift)
    .def("max_token", &TOKEN::max_token)
    .def("get_domain", &TOKEN::get_domain);

  py::class_<REPRESENTATION>(m, "REPRESENTATION")
    .def(py::init<vector<vector<pair<TOKEN_TYPE,int>>>,const char*>())
    .def("decode", &REPRESENTATION::decode)
    .def("is_token_type", &REPRESENTATION::is_token_type)
    .def("encode", &REPRESENTATION::encode)
    .def("encode_to_one_hot", &REPRESENTATION::encode_to_one_hot)
    .def("shift", &REPRESENTATION::shift)
    .def("get_domain", &REPRESENTATION::get_domain)
    .def("show", &REPRESENTATION::show)
    .def("pretty", &REPRESENTATION::pretty)
    .def("where", &REPRESENTATION::where)
    .def("where_values", &REPRESENTATION::where_values)
    .def("max_token", &REPRESENTATION::max_token);

py::class_<EncoderConfig>(m, "EncoderConfig")
  .def(py::init<>())
  .def_readwrite("do_fill", &EncoderConfig::do_fill)
  .def_readwrite("do_multi_fill", &EncoderConfig::do_multi_fill)
  .def_readwrite("do_track_shuffle", &EncoderConfig::do_track_shuffle)
  .def_readwrite("force_instrument", &EncoderConfig::force_instrument)
  .def_readwrite("mark_polyphony", &EncoderConfig::mark_polyphony)
  .def_readwrite("mark_density", &EncoderConfig::mark_density)
  .def_readwrite("mark_time_sigs", &EncoderConfig::mark_time_sigs)
  .def_readwrite("instrument_header", &EncoderConfig::instrument_header)
  .def_readwrite("use_velocity_levels", &EncoderConfig::use_velocity_levels)
  .def_readwrite("genre_header", &EncoderConfig::genre_header)
  .def_readwrite("piece_header", &EncoderConfig::piece_header)
  .def_readwrite("bar_major", &EncoderConfig::bar_major)
  .def_readwrite("force_four_four", &EncoderConfig::force_four_four)
  .def_readwrite("segment_mode", &EncoderConfig::segment_mode)
  .def_readwrite("force_valid", &EncoderConfig::force_valid)
  .def_readwrite("transpose", &EncoderConfig::transpose)
  .def_readwrite("seed", &EncoderConfig::seed)
  .def_readwrite("segment_idx", &EncoderConfig::segment_idx)
  .def_readwrite("fill_track", &EncoderConfig::fill_track)
  .def_readwrite("fill_bar", &EncoderConfig::fill_bar)
  .def_readwrite("max_tracks", &EncoderConfig::max_tracks)
  .def_readwrite("resolution", &EncoderConfig::resolution)
  .def_readwrite("default_tempo", &EncoderConfig::default_tempo)
  .def_readwrite("num_bars", &EncoderConfig::num_bars)
  .def_readwrite("min_tracks", &EncoderConfig::min_tracks)
  .def_readwrite("fill_percentage", &EncoderConfig::fill_percentage)
  .def_readwrite("multi_fill", &EncoderConfig::multi_fill)
  .def_readwrite("genre_tags", &EncoderConfig::genre_tags);

py::enum_<TOKEN_TYPE>(m, "TOKEN_TYPE", py::arithmetic())
  .value("PIECE_START", TOKEN_TYPE::PIECE_START)
  .value("NOTE_ONSET", TOKEN_TYPE::NOTE_ONSET)
  .value("NOTE_OFFSET", TOKEN_TYPE::NOTE_OFFSET)
  .value("PITCH", TOKEN_TYPE::PITCH)
  .value("NON_PITCH", TOKEN_TYPE::NON_PITCH)
  .value("VELOCITY", TOKEN_TYPE::VELOCITY)
  .value("TIME_DELTA", TOKEN_TYPE::TIME_DELTA)
  .value("INSTRUMENT", TOKEN_TYPE::INSTRUMENT)
  .value("BAR", TOKEN_TYPE::BAR)
  .value("BAR_END", TOKEN_TYPE::BAR_END)
  .value("TRACK", TOKEN_TYPE::TRACK)
  .value("TRACK_END", TOKEN_TYPE::TRACK_END)
  .value("DRUM_TRACK", TOKEN_TYPE::DRUM_TRACK)
  .value("FILL_IN", TOKEN_TYPE::FILL_IN)
  .value("HEADER", TOKEN_TYPE::HEADER)
  .value("VELOCITY_LEVEL", TOKEN_TYPE::VELOCITY_LEVEL)
  .value("GENRE", TOKEN_TYPE::GENRE)
  .value("DENSITY_LEVEL", TOKEN_TYPE::DENSITY_LEVEL)
  .value("TIME_SIGNATURE", TOKEN_TYPE::TIME_SIGNATURE)
  .value("SEGMENT", TOKEN_TYPE::SEGMENT)
  .value("SEGMENT_END", TOKEN_TYPE::SEGMENT_END)
  .value("SEGMENT_FILL_IN", TOKEN_TYPE::SEGMENT_FILL_IN)
  .export_values();

py::enum_<ENCODER_TYPE>(m, "ENCODER_TYPE", py::arithmetic())
  .value("TRACK_ENCODER", ENCODER_TYPE::TRACK_ENCODER)
  .value("TRACK_BAR_MAJOR_ENCODER", ENCODER_TYPE::TRACK_BAR_MAJOR_ENCODER)
  .value("TRACK_GENRE_ENCODER", ENCODER_TYPE::TRACK_GENRE_ENCODER)
  .value("TRACK_VELOCITY_ENCODER", ENCODER_TYPE::TRACK_VELOCITY_ENCODER)
  .value("TRACK_VELOCITY_LEVEL_ENCODER", ENCODER_TYPE::TRACK_VELOCITY_LEVEL_ENCODER)
  .value("TRACK_ONE_BAR_FILL_ENCODER", ENCODER_TYPE::TRACK_ONE_BAR_FILL_ENCODER)
  .value("TRACK_MONO_POLY_ENCODER", ENCODER_TYPE::TRACK_MONO_POLY_ENCODER)
  .value("TRACK_INST_HEADER_ENCODER", ENCODER_TYPE::TRACK_INST_HEADER_ENCODER)
  .value("TRACK_ONE_TWO_THREE_BAR_FILL_ENCODER", ENCODER_TYPE::TRACK_ONE_TWO_THREE_BAR_FILL_ENCODER)
  .value("TRACK_BAR_FILL_ENCODER", ENCODER_TYPE::TRACK_BAR_FILL_ENCODER)
  .value("TRACK_MONO_POLY_DENSITY_ENCODER", ENCODER_TYPE::TRACK_MONO_POLY_DENSITY_ENCODER)
  .value("SEGMENT_ENCODER", ENCODER_TYPE::SEGMENT_ENCODER)
  .value("TRACK_BAR_FILL_SIXTEEN_ENCODER", ENCODER_TYPE::TRACK_BAR_FILL_SIXTEEN_ENCODER)
  .value("TRACK_DENSITY_ENCODER", ENCODER_TYPE::TRACK_DENSITY_ENCODER)
  .value("TRACK_DENSITY_VELOCITY_ENCODER", ENCODER_TYPE::TRACK_DENSITY_VELOCITY_ENCODER)
  .value("TRACK_BAR_FILL_DENSITY_ENCODER", ENCODER_TYPE::TRACK_BAR_FILL_DENSITY_ENCODER)
  .value("TRACK_BAR_FILL_DENSITY_VELOCITY_ENCODER", ENCODER_TYPE::TRACK_BAR_FILL_DENSITY_VELOCITY_ENCODER)
  .value("NO_ENCODER", ENCODER_TYPE::NO_ENCODER)
  .export_values();

py::class_<TrackEncoder>(m, "TrackEncoder")
  .def(py::init<>())
  .def("encode", &TrackEncoder::encode)
  .def("decode", &TrackEncoder::decode)
  .def("midi_to_json", &TrackEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackEncoder::json_to_midi)
  .def("json_to_tokens", &TrackEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackEncoder::rep);

py::class_<TrackBarMajorEncoder>(m, "TrackBarMajorEncoder")
  .def(py::init<>())
  .def("encode", &TrackBarMajorEncoder::encode)
  .def("decode", &TrackBarMajorEncoder::decode)
  .def("midi_to_json", &TrackBarMajorEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackBarMajorEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackBarMajorEncoder::json_to_midi)
  .def("json_to_tokens", &TrackBarMajorEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackBarMajorEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackBarMajorEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackBarMajorEncoder::rep);

py::class_<TrackGenreEncoder>(m, "TrackGenreEncoder")
  .def(py::init<>())
  .def("encode", &TrackGenreEncoder::encode)
  .def("decode", &TrackGenreEncoder::decode)
  .def("midi_to_json", &TrackGenreEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackGenreEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackGenreEncoder::json_to_midi)
  .def("json_to_tokens", &TrackGenreEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackGenreEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackGenreEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackGenreEncoder::rep);

py::class_<TrackVelocityEncoder>(m, "TrackVelocityEncoder")
  .def(py::init<>())
  .def("encode", &TrackVelocityEncoder::encode)
  .def("decode", &TrackVelocityEncoder::decode)
  .def("midi_to_json", &TrackVelocityEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackVelocityEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackVelocityEncoder::json_to_midi)
  .def("json_to_tokens", &TrackVelocityEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackVelocityEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackVelocityEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackVelocityEncoder::rep);

py::class_<TrackVelocityLevelEncoder>(m, "TrackVelocityLevelEncoder")
  .def(py::init<>())
  .def("encode", &TrackVelocityLevelEncoder::encode)
  .def("decode", &TrackVelocityLevelEncoder::decode)
  .def("midi_to_json", &TrackVelocityLevelEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackVelocityLevelEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackVelocityLevelEncoder::json_to_midi)
  .def("json_to_tokens", &TrackVelocityLevelEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackVelocityLevelEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackVelocityLevelEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackVelocityLevelEncoder::rep);

py::class_<TrackOneBarFillEncoder>(m, "TrackOneBarFillEncoder")
  .def(py::init<>())
  .def("encode", &TrackOneBarFillEncoder::encode)
  .def("decode", &TrackOneBarFillEncoder::decode)
  .def("midi_to_json", &TrackOneBarFillEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackOneBarFillEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackOneBarFillEncoder::json_to_midi)
  .def("json_to_tokens", &TrackOneBarFillEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackOneBarFillEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackOneBarFillEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackOneBarFillEncoder::rep);

py::class_<TrackMonoPolyEncoder>(m, "TrackMonoPolyEncoder")
  .def(py::init<>())
  .def("encode", &TrackMonoPolyEncoder::encode)
  .def("decode", &TrackMonoPolyEncoder::decode)
  .def("midi_to_json", &TrackMonoPolyEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackMonoPolyEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackMonoPolyEncoder::json_to_midi)
  .def("json_to_tokens", &TrackMonoPolyEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackMonoPolyEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackMonoPolyEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackMonoPolyEncoder::rep);

py::class_<TrackInstHeaderEncoder>(m, "TrackInstHeaderEncoder")
  .def(py::init<>())
  .def("encode", &TrackInstHeaderEncoder::encode)
  .def("decode", &TrackInstHeaderEncoder::decode)
  .def("midi_to_json", &TrackInstHeaderEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackInstHeaderEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackInstHeaderEncoder::json_to_midi)
  .def("json_to_tokens", &TrackInstHeaderEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackInstHeaderEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackInstHeaderEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackInstHeaderEncoder::rep);

py::class_<TrackOneTwoThreeBarFillEncoder>(m, "TrackOneTwoThreeBarFillEncoder")
  .def(py::init<>())
  .def("encode", &TrackOneTwoThreeBarFillEncoder::encode)
  .def("decode", &TrackOneTwoThreeBarFillEncoder::decode)
  .def("midi_to_json", &TrackOneTwoThreeBarFillEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackOneTwoThreeBarFillEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackOneTwoThreeBarFillEncoder::json_to_midi)
  .def("json_to_tokens", &TrackOneTwoThreeBarFillEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackOneTwoThreeBarFillEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackOneTwoThreeBarFillEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackOneTwoThreeBarFillEncoder::rep);

py::class_<TrackBarFillEncoder>(m, "TrackBarFillEncoder")
  .def(py::init<>())
  .def("encode", &TrackBarFillEncoder::encode)
  .def("decode", &TrackBarFillEncoder::decode)
  .def("midi_to_json", &TrackBarFillEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackBarFillEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackBarFillEncoder::json_to_midi)
  .def("json_to_tokens", &TrackBarFillEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackBarFillEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackBarFillEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackBarFillEncoder::rep);

py::class_<TrackMonoPolyDensityEncoder>(m, "TrackMonoPolyDensityEncoder")
  .def(py::init<>())
  .def("encode", &TrackMonoPolyDensityEncoder::encode)
  .def("decode", &TrackMonoPolyDensityEncoder::decode)
  .def("midi_to_json", &TrackMonoPolyDensityEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackMonoPolyDensityEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackMonoPolyDensityEncoder::json_to_midi)
  .def("json_to_tokens", &TrackMonoPolyDensityEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackMonoPolyDensityEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackMonoPolyDensityEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackMonoPolyDensityEncoder::rep);

py::class_<SegmentEncoder>(m, "SegmentEncoder")
  .def(py::init<>())
  .def("encode", &SegmentEncoder::encode)
  .def("decode", &SegmentEncoder::decode)
  .def("midi_to_json", &SegmentEncoder::midi_to_json)
  .def("midi_to_tokens", &SegmentEncoder::midi_to_tokens)
  .def("json_to_midi", &SegmentEncoder::json_to_midi)
  .def("json_to_tokens", &SegmentEncoder::json_to_tokens)
  .def("tokens_to_json", &SegmentEncoder::tokens_to_json)
  .def("tokens_to_midi", &SegmentEncoder::tokens_to_midi)
  .def_readwrite("rep", &SegmentEncoder::rep);

py::class_<TrackBarFillSixteenEncoder>(m, "TrackBarFillSixteenEncoder")
  .def(py::init<>())
  .def("encode", &TrackBarFillSixteenEncoder::encode)
  .def("decode", &TrackBarFillSixteenEncoder::decode)
  .def("midi_to_json", &TrackBarFillSixteenEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackBarFillSixteenEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackBarFillSixteenEncoder::json_to_midi)
  .def("json_to_tokens", &TrackBarFillSixteenEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackBarFillSixteenEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackBarFillSixteenEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackBarFillSixteenEncoder::rep);

py::class_<TrackDensityEncoder>(m, "TrackDensityEncoder")
  .def(py::init<>())
  .def("encode", &TrackDensityEncoder::encode)
  .def("decode", &TrackDensityEncoder::decode)
  .def("midi_to_json", &TrackDensityEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackDensityEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackDensityEncoder::json_to_midi)
  .def("json_to_tokens", &TrackDensityEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackDensityEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackDensityEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackDensityEncoder::rep);

py::class_<TrackDensityVelocityEncoder>(m, "TrackDensityVelocityEncoder")
  .def(py::init<>())
  .def("encode", &TrackDensityVelocityEncoder::encode)
  .def("decode", &TrackDensityVelocityEncoder::decode)
  .def("midi_to_json", &TrackDensityVelocityEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackDensityVelocityEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackDensityVelocityEncoder::json_to_midi)
  .def("json_to_tokens", &TrackDensityVelocityEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackDensityVelocityEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackDensityVelocityEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackDensityVelocityEncoder::rep);

py::class_<TrackBarFillDensityEncoder>(m, "TrackBarFillDensityEncoder")
  .def(py::init<>())
  .def("encode", &TrackBarFillDensityEncoder::encode)
  .def("decode", &TrackBarFillDensityEncoder::decode)
  .def("midi_to_json", &TrackBarFillDensityEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackBarFillDensityEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackBarFillDensityEncoder::json_to_midi)
  .def("json_to_tokens", &TrackBarFillDensityEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackBarFillDensityEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackBarFillDensityEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackBarFillDensityEncoder::rep);

py::class_<TrackBarFillDensityVelocityEncoder>(m, "TrackBarFillDensityVelocityEncoder")
  .def(py::init<>())
  .def("encode", &TrackBarFillDensityVelocityEncoder::encode)
  .def("decode", &TrackBarFillDensityVelocityEncoder::decode)
  .def("midi_to_json", &TrackBarFillDensityVelocityEncoder::midi_to_json)
  .def("midi_to_tokens", &TrackBarFillDensityVelocityEncoder::midi_to_tokens)
  .def("json_to_midi", &TrackBarFillDensityVelocityEncoder::json_to_midi)
  .def("json_to_tokens", &TrackBarFillDensityVelocityEncoder::json_to_tokens)
  .def("tokens_to_json", &TrackBarFillDensityVelocityEncoder::tokens_to_json)
  .def("tokens_to_midi", &TrackBarFillDensityVelocityEncoder::tokens_to_midi)
  .def_readwrite("rep", &TrackBarFillDensityVelocityEncoder::rep);

}