#pragma once

#include <stdexcept>

#include "encoder_base.h"
#include "util.h"
#include "../enum/te.h"
#include "../protobuf/util.h"
#include "../protobuf/validate.h"

// Track-structured encoder whose vocabulary includes a DENSITY_LEVEL token
// and whose configuration enables mark_density.
//
// `rep` and `config` are allocated in the constructor and released in the
// destructor.
class TrackDensityEncoder : public ENCODER {
public:
  // Builds the token vocabulary, then the encoder configuration.
  // The order of the vocabulary entries is kept exactly as listed.
  TrackDensityEncoder() {
    rep = new REPRESENTATION({
      {PIECE_START,   TOKEN_DOMAIN(1)},
      {BAR,           TOKEN_DOMAIN(1)},
      {BAR_END,       TOKEN_DOMAIN(1)},
      {TRACK,         TOKEN_DOMAIN({STANDARD_TRACK, STANDARD_DRUM_TRACK,})},
      {TRACK_END,     TOKEN_DOMAIN(1)},
      {INSTRUMENT,    TOKEN_DOMAIN(128)},
      {NOTE_OFFSET,   TOKEN_DOMAIN(128)},
      {NOTE_ONSET,    TOKEN_DOMAIN(128)},
      {TIME_DELTA,    TOKEN_DOMAIN(48)},
      {DENSITY_LEVEL, TOKEN_DOMAIN(10)}
    });
    config = get_encoder_config();
  }

  // Frees the vocabulary and configuration created by the constructor.
  ~TrackDensityEncoder() {
    delete rep;
    delete config;
  }

  // Allocates a new EncoderConfig populated with this encoder's settings.
  // Every call returns a distinct heap object; the caller owns it.
  EncoderConfig *get_encoder_config() {
    auto *cfg = new EncoderConfig();
    cfg->te = false;
    cfg->force_instrument = true;
    cfg->mark_density = true;
    cfg->min_tracks = 1;  // unclear whether min_tracks is read anywhere
    cfg->resolution = 12;
    return cfg;
  }

  // Calls update_note_density(p) on the piece, then converts it to tokens
  // with to_performance_w_tracks_dev.
  vector<int> encode(midi::Piece *p) {
    update_note_density(p);
    return to_performance_w_tracks_dev(p, rep, config);
  }

  // Decodes `tokens` into `p` through decode_track_dev.
  void decode(vector<int> &tokens, midi::Piece *p) {
    decode_track_dev(tokens, p, rep, config);
  }
};


// Second version of the track/density encoder. Compared with
// TrackDensityEncoder, PIECE_START has two values (value 1 marks bar infill),
// a FILL_IN token type is part of the vocabulary, and the configuration sets
// both_in_one and disables use_drum_offsets.
//
// `rep` and `config` are allocated in the constructor and released in the
// destructor.
class TrackDensityEncoderV2 : public ENCODER {
public:
  // Builds the token vocabulary, then the encoder configuration.
  // The order of the vocabulary entries is kept exactly as listed.
  TrackDensityEncoderV2() {
    rep = new REPRESENTATION({
      {PIECE_START,   TOKEN_DOMAIN(2)},  // value 1 is bar infill
      {BAR,           TOKEN_DOMAIN(1)},
      {BAR_END,       TOKEN_DOMAIN(1)},
      {TRACK,         TOKEN_DOMAIN({STANDARD_TRACK, STANDARD_DRUM_TRACK,})},
      {TRACK_END,     TOKEN_DOMAIN(1)},
      {INSTRUMENT,    TOKEN_DOMAIN(128)},
      {NOTE_OFFSET,   TOKEN_DOMAIN(128)},
      {NOTE_ONSET,    TOKEN_DOMAIN(128)},
      {TIME_DELTA,    TOKEN_DOMAIN(48)},
      {FILL_IN,       TOKEN_DOMAIN(3)},
      {DENSITY_LEVEL, TOKEN_DOMAIN(10)}
    });
    config = get_encoder_config();
  }

  // Frees the vocabulary and configuration created by the constructor.
  ~TrackDensityEncoderV2() {
    delete rep;
    delete config;
  }

  // Allocates a new EncoderConfig populated with this encoder's settings.
  // Every call returns a distinct heap object; the caller owns it.
  EncoderConfig *get_encoder_config() {
    auto *cfg = new EncoderConfig();
    cfg->both_in_one = true;
    cfg->te = false;
    cfg->force_instrument = true;
    cfg->mark_density = true;
    cfg->min_tracks = 1;  // unclear whether min_tracks is read anywhere
    cfg->resolution = 12;
    cfg->use_drum_offsets = false;
    return cfg;
  }

  // Calls update_note_density(p) on the piece, then converts it to tokens
  // with to_performance_w_tracks_dev.
  vector<int> encode(midi::Piece *p) {
    // TODO: a validation step for the piece is meant to be added here.
    update_note_density(p);
    return to_performance_w_tracks_dev(p, rep, config);
  }

  // Decodes `tokens` into `p`.
  //
  // When config->do_multi_fill is set, `tokens` is first overwritten (the
  // argument is modified in place) with the result of
  // resolve_bar_infill_tokens. get_encoder_config() above does not assign
  // do_multi_fill, so the flag comes from EncoderConfig's default or from a
  // later modification of `config`.
  void decode(vector<int> &tokens, midi::Piece *p) {
    const bool multi_fill = (config->do_multi_fill == true);
    if (multi_fill) {
      tokens = resolve_bar_infill_tokens(tokens, rep);
    }
    decode_track_dev(tokens, p, rep, config);
  }
};

// Track encoder with density tokens and bar fill-in support.
//
// The vocabulary contains a FILL_IN token type with three values; decode()
// treats value 0 as a placeholder, value 1 as the start of a fill section and
// value 2 as the end of a fill section.
//
// `rep` and `config` are allocated in the constructor and released in the
// destructor.
class TrackBarFillDensityEncoder : public ENCODER {
public:
  // Builds the token vocabulary, then the encoder configuration.
  // The order of the vocabulary entries is kept exactly as listed.
  TrackBarFillDensityEncoder() {
    rep = new REPRESENTATION({
      {PIECE_START, TOKEN_DOMAIN(1)},
      {BAR, TOKEN_DOMAIN(1)},
      {BAR_END, TOKEN_DOMAIN(1)},
      {TRACK, TOKEN_DOMAIN({
        STANDARD_TRACK,
        STANDARD_DRUM_TRACK,
      })},
      {TRACK_END, TOKEN_DOMAIN(1)},
      {INSTRUMENT, TOKEN_DOMAIN(128)},
      {NOTE_OFFSET, TOKEN_DOMAIN(128)},
      {NOTE_ONSET, TOKEN_DOMAIN(128)},
      {TIME_DELTA, TOKEN_DOMAIN(48)},
      {FILL_IN, TOKEN_DOMAIN(3)},
      {DENSITY_LEVEL, TOKEN_DOMAIN(10)}
    });
    config = get_encoder_config();
  }

  // Frees the vocabulary and configuration created by the constructor.
  ~TrackBarFillDensityEncoder() {
    delete rep;
    delete config;
  }

  // Allocates a new EncoderConfig populated with this encoder's settings.
  // Every call returns a distinct heap object; the caller owns it.
  EncoderConfig *get_encoder_config() {
    EncoderConfig *e = new EncoderConfig();
    e->do_multi_fill = true;
    e->force_instrument = true;
    e->mark_density = true;
    e->min_tracks = 1; // not sure this is used anywhere
    e->resolution = 12;
    return e;
  }

  // Calls update_note_density(p) on the piece, then converts it to tokens
  // with to_performance_w_tracks_dev.
  vector<int> encode(midi::Piece *p) {
    update_note_density(p);
    return to_performance_w_tracks_dev(p, rep, config);
  }

  // Splices the fill sections back into the main sequence and decodes the
  // result into `p`. Works with any number of fills.
  //
  // Layout handled by the loop below:
  //   <first token> main ... <placeholder> ... main ...
  //   <fill start> content <fill end>  <fill start> content <fill end> ...
  // The k-th placeholder is replaced by the content of the k-th fill section.
  // Everything from the first fill-start token onwards is dropped from the
  // tail of the main sequence. The first token (PIECE_START) is not copied
  // into the spliced sequence. `raw_tokens` itself is not modified.
  //
  // Throws std::runtime_error if the sequence is malformed in a way that
  // would otherwise produce an invalid iterator range:
  //   - a placeholder has no matching fill-start / fill-end pair,
  //   - a fill-end token precedes its fill-start token,
  //   - the first fill-start token lies before the end of the main sequence
  //     (e.g. the placeholder is the last token).
  void decode(vector<int> &raw_tokens, midi::Piece *p) {
    const int fill_pholder = rep->encode(FILL_IN, 0);
    const int fill_start = rep->encode(FILL_IN, 1);
    const int fill_end = rep->encode(FILL_IN, 2);

    vector<int> tokens;
    tokens.reserve(raw_tokens.size());

    const auto seq_end = raw_tokens.end();
    auto start_pholder = raw_tokens.begin();
    auto start_fill = raw_tokens.begin();
    auto end_fill = raw_tokens.begin();

    // An empty input skips the loop entirely and decodes an empty sequence.
    while (start_pholder != seq_end) {
      // First pass: skip the leading PIECE_START token.
      // Later passes: skip the placeholder that was just replaced.
      const auto chunk_begin = next(start_pholder);
      start_pholder = find(chunk_begin, seq_end, fill_pholder);

      if (start_pholder != seq_end) {
        // start_fill / end_fill are never end() here: the check below throws
        // as soon as either search fails, so next() stays in range.
        start_fill = find(next(start_fill), seq_end, fill_start);
        end_fill = find(next(end_fill), seq_end, fill_end);
        if (start_fill == seq_end || end_fill == seq_end) {
          throw std::runtime_error(
            "TrackBarFillDensityEncoder::decode: fill placeholder without a "
            "matching fill section");
        }
        if (!(start_fill < end_fill)) {
          throw std::runtime_error(
            "TrackBarFillDensityEncoder::decode: fill end precedes fill start");
        }

        // main sequence up to the placeholder, then the fill content
        tokens.insert(tokens.end(), chunk_begin, start_pholder);
        tokens.insert(tokens.end(), next(start_fill), end_fill);
      }
      else {
        // No further placeholder: copy the rest of the main sequence, which
        // ends at the first fill-start token (or at the end of the input
        // when there are no fill sections at all).
        const auto main_end = find(raw_tokens.begin(), seq_end, fill_start);
        if (main_end < chunk_begin) {
          throw std::runtime_error(
            "TrackBarFillDensityEncoder::decode: fill section starts before "
            "the end of the main sequence");
        }
        tokens.insert(tokens.end(), chunk_begin, main_end);
      }
    }
    decode_track_dev(tokens, p, rep, config);
  }
};

//==========================================================================


// Encoder that tokenizes a piece with to_interleaved_performance. Its
// vocabulary has no TRACK / TRACK_END tokens and uses an INSTRUMENT domain
// of 256 values.
//
// `rep` and `config` are allocated in the constructor and released in the
// destructor.
class TrackInterleavedEncoder : public ENCODER {
public:
  // Builds the token vocabulary, then the encoder configuration.
  // The order of the vocabulary entries is kept exactly as listed.
  TrackInterleavedEncoder() {
    rep = new REPRESENTATION({
      {PIECE_START, TOKEN_DOMAIN(1)},
      {BAR,         TOKEN_DOMAIN(1)},
      {BAR_END,     TOKEN_DOMAIN(1)},
      {INSTRUMENT,  TOKEN_DOMAIN(256)},
      {NOTE_OFFSET, TOKEN_DOMAIN(128)},
      {NOTE_ONSET,  TOKEN_DOMAIN(128)},
      {TIME_DELTA,  TOKEN_DOMAIN(48)},
    });
    config = get_encoder_config();
  }

  // Frees the vocabulary and configuration created by the constructor.
  ~TrackInterleavedEncoder() {
    delete rep;
    delete config;
  }

  // Allocates a new EncoderConfig populated with this encoder's settings.
  // Every call returns a distinct heap object; the caller owns it.
  EncoderConfig *get_encoder_config() {
    auto *cfg = new EncoderConfig();
    cfg->te = false;
    cfg->min_tracks = 1;  // unclear whether min_tracks is read anywhere
    cfg->resolution = 12;
    return cfg;
  }

  // Converts the piece to tokens with to_interleaved_performance.
  vector<int> encode(midi::Piece *p) {
    return to_interleaved_performance(p, rep, config);
  }

  // Decodes `tokens` into `p` through decode_track_dev.
  void decode(vector<int> &tokens, midi::Piece *p) {
    decode_track_dev(tokens, p, rep, config);
  }
};

// Variant of the interleaved encoder whose vocabulary additionally contains
// a two-valued HEADER token and whose configuration sets `interleaved`.
//
// `rep` and `config` are allocated in the constructor and released in the
// destructor.
class TrackInterleavedWHeaderEncoder : public ENCODER {
public:
  // Builds the token vocabulary, then the encoder configuration.
  // The order of the vocabulary entries is kept exactly as listed.
  TrackInterleavedWHeaderEncoder() {
    rep = new REPRESENTATION({
      {PIECE_START, TOKEN_DOMAIN(1)},
      {BAR,         TOKEN_DOMAIN(1)},
      {BAR_END,     TOKEN_DOMAIN(1)},
      {INSTRUMENT,  TOKEN_DOMAIN(256)},
      {NOTE_OFFSET, TOKEN_DOMAIN(128)},
      {NOTE_ONSET,  TOKEN_DOMAIN(128)},
      {TIME_DELTA,  TOKEN_DOMAIN(48)},
      {HEADER,      TOKEN_DOMAIN(2)}
    });
    config = get_encoder_config();
  }

  // Frees the vocabulary and configuration created by the constructor.
  ~TrackInterleavedWHeaderEncoder() {
    delete rep;
    delete config;
  }

  // Allocates a new EncoderConfig populated with this encoder's settings.
  // Every call returns a distinct heap object; the caller owns it.
  EncoderConfig *get_encoder_config() {
    auto *cfg = new EncoderConfig();
    cfg->te = false;
    cfg->min_tracks = 1;  // unclear whether min_tracks is read anywhere
    cfg->resolution = 12;
    cfg->interleaved = true;
    return cfg;
  }

  // Converts the piece to tokens with to_interleaved_performance.
  vector<int> encode(midi::Piece *p) {
    return to_interleaved_performance(p, rep, config);
  }

  // Overwrites `tokens` (the argument is modified in place) with the result
  // of strip_header(tokens, rep), then decodes it into `p`.
  void decode(vector<int> &tokens, midi::Piece *p) {
    tokens = strip_header(tokens, rep);
    decode_track_dev(tokens, p, rep, config);
  }
};


// Plain track-structured encoder: TRACK / TRACK_END delimited tracks with
// instrument, note onset/offset and time-delta tokens, and no density or
// fill-in tokens.
//
// `rep` and `config` are allocated in the constructor and released in the
// destructor.
class TrackEncoder : public ENCODER {
public:
  // Builds the token vocabulary, then the encoder configuration.
  // The order of the vocabulary entries is kept exactly as listed.
  TrackEncoder() {
    rep = new REPRESENTATION({
      {PIECE_START, TOKEN_DOMAIN(1)},
      {BAR,         TOKEN_DOMAIN(1)},
      {BAR_END,     TOKEN_DOMAIN(1)},
      {TRACK,       TOKEN_DOMAIN({STANDARD_TRACK, STANDARD_DRUM_TRACK,})},
      {TRACK_END,   TOKEN_DOMAIN(1)},
      {INSTRUMENT,  TOKEN_DOMAIN(128)},
      {NOTE_OFFSET, TOKEN_DOMAIN(128)},
      {NOTE_ONSET,  TOKEN_DOMAIN(128)},
      {TIME_DELTA,  TOKEN_DOMAIN(48)},
    });
    config = get_encoder_config();
  }

  // Frees the vocabulary and configuration created by the constructor.
  ~TrackEncoder() {
    delete rep;
    delete config;
  }

  // Allocates a new EncoderConfig populated with this encoder's settings.
  // Every call returns a distinct heap object; the caller owns it.
  EncoderConfig *get_encoder_config() {
    auto *cfg = new EncoderConfig();
    cfg->force_instrument = true;
    cfg->min_tracks = 1;  // unclear whether min_tracks is read anywhere
    cfg->resolution = 12;
    return cfg;
  }

  // Converts the piece to tokens with to_performance_w_tracks_dev.
  vector<int> encode(midi::Piece *p) {
    return to_performance_w_tracks_dev(p, rep, config);
  }

  // Decodes `tokens` into `p` through decode_track_dev.
  void decode(vector<int> &tokens, midi::Piece *p) {
    decode_track_dev(tokens, p, rep, config);
  }
};

// Encoder that only supplies a configuration with `unquantized` set.
// Its vocabulary holds nothing but PIECE_START, and both encode() and
// decode() unconditionally throw.
//
// `rep` and `config` are allocated in the constructor and released in the
// destructor.
class TrackUnquantizedEncoder : public ENCODER {
public:
  // Builds the single-entry vocabulary, then the encoder configuration.
  TrackUnquantizedEncoder() {
    rep = new REPRESENTATION({
      {PIECE_START, TOKEN_DOMAIN(1)},
    });
    config = get_encoder_config();
  }

  // Frees the vocabulary and configuration created by the constructor.
  ~TrackUnquantizedEncoder() {
    delete rep;
    delete config;
  }

  // Allocates a new EncoderConfig populated with this encoder's settings.
  // Every call returns a distinct heap object; the caller owns it.
  EncoderConfig *get_encoder_config() {
    auto *cfg = new EncoderConfig();
    cfg->resolution = 12;
    cfg->unquantized = true;
    return cfg;
  }

  // Not supported for this encoder: always throws runtime_error.
  vector<int> encode(midi::Piece *p) {
    throw runtime_error("can't call this function!");
  }

  // Not supported for this encoder: always throws runtime_error.
  void decode(vector<int> &tokens, midi::Piece *p) {
    throw runtime_error("can't call this function!");
  }
};

//==========================================================================
// roll models into one using piece start
// and allow for unlimited n-bar segments to be encoded

// Track-structured encoder whose vocabulary additionally contains SEGMENT and
// SEGMENT_END tokens and whose configuration sets `multi_segment`.
//
// `rep` and `config` are allocated in the constructor and released in the
// destructor.
class TrackSegmentEncoder : public ENCODER {
public:
  // Builds the token vocabulary, then the encoder configuration.
  // The order of the vocabulary entries is kept exactly as listed.
  TrackSegmentEncoder() {
    rep = new REPRESENTATION({
      {PIECE_START, TOKEN_DOMAIN(1)},
      {BAR,         TOKEN_DOMAIN(1)},
      {BAR_END,     TOKEN_DOMAIN(1)},
      {TRACK,       TOKEN_DOMAIN({STANDARD_TRACK, STANDARD_DRUM_TRACK,})},
      {TRACK_END,   TOKEN_DOMAIN(1)},
      {INSTRUMENT,  TOKEN_DOMAIN(128)},
      {NOTE_OFFSET, TOKEN_DOMAIN(128)},
      {NOTE_ONSET,  TOKEN_DOMAIN(128)},
      {TIME_DELTA,  TOKEN_DOMAIN(48)},
      {SEGMENT,     TOKEN_DOMAIN(1)},
      {SEGMENT_END, TOKEN_DOMAIN(1)},
    });
    config = get_encoder_config();
  }

  // Frees the vocabulary and configuration created by the constructor.
  ~TrackSegmentEncoder() {
    delete rep;
    delete config;
  }

  // Allocates a new EncoderConfig populated with this encoder's settings.
  // Every call returns a distinct heap object; the caller owns it.
  EncoderConfig *get_encoder_config() {
    auto *cfg = new EncoderConfig();
    cfg->force_instrument = true;
    cfg->min_tracks = 1;  // unclear whether min_tracks is read anywhere
    cfg->resolution = 12;
    cfg->multi_segment = true;
    return cfg;
  }

  // Converts the piece to tokens with to_performance_w_tracks_dev.
  vector<int> encode(midi::Piece *p) {
    return to_performance_w_tracks_dev(p, rep, config);
  }

  // Decodes `tokens` into `p` through decode_track_dev.
  void decode(vector<int> &tokens, midi::Piece *p) {
    decode_track_dev(tokens, p, rep, config);
  }
};