#pragma once

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "midi.pb.h"
#include "parser.h"
#include "token_types.h"
#include "constants.h"
#include "encoder_config.h"

using namespace std;

class TOKEN {
public:
  TOKEN(vector<pair<TOKEN_TYPE,int>> spec) {
    cprod.push_back( 1 );
    for (size_t i=0; i<spec.size(); i++) {
      tt_2_index[get<0>(spec[i])] = i;
      domain.push_back( get<1>(spec[i]) );
      cprod.push_back( cprod.back() * get<1>(spec[i]) );
    }
  }
  int encode(map<TOKEN_TYPE,int> values) {
    // notifies if you try to add an unknown token
    int value;
    TOKEN_TYPE tt;
    int token = 0;
    for (const auto kv : values) {
      tt = kv.first;
      value = kv.second;
      auto it = tt_2_index.find(tt);
      if (it == tt_2_index.end()) {
        //cout << "WARNING : TRYING TO ADD UNKNOWN TOKEN TYPE" << endl;
      }
      else {
        if ((value < 0) || (value >= domain[it->second])) {
          cout << "ERROR : TRYING TO ENCODE VALUE (" << value << ") OUTSIDE OF DOMAIN (" << toString(tt) << ")" << endl;
          throw(2);
        }
        token += cprod[it->second] * value;
      }
    }
    return token;
  }
  int decode(int token, TOKEN_TYPE tt) {
    auto it = tt_2_index.find(tt);
    assert(it != tt_2_index.end());
    return (token / cprod[it->second]) % domain[it->second];
  }
  int shift(int token, TOKEN_TYPE tt, int shift) {
    auto it = tt_2_index.find(tt);
    assert(it != tt_2_index.end());
    int shifted = token + (shift * cprod[it->second]);
    int dec = decode(token, tt) + shift;
    if ((dec<0)||(dec>=domain[it->second])||(shifted<0)||(shifted>=cprod.back())) {
      throw 2;
    }
    return shifted;
  }
  int max_token() {
    return cprod.back();
  }
  int get_domain(TOKEN_TYPE tt) {
    auto it = tt_2_index.find(tt);
    assert(it != tt_2_index.end());
    return domain[it->second];
  }
  vector<int> domain;
  vector<int> cprod;
  map<TOKEN_TYPE,int> tt_2_index;
};

class REPRESENTATION {
public:
  REPRESENTATION(vector<vector<pair<TOKEN_TYPE,int>>> spec, const char *vname) {
    int count = 0;
    for (size_t i=0; i<spec.size(); i++) {
      for (size_t j=0; j<spec[i].size(); j++) {
        tt_2_index[get<0>(spec[i][j])] = i;
      }
      TOKEN tok(spec[i]);
      toks.push_back(tok);
      starts.push_back(count);
      count += tok.max_token();
      ends.push_back(count);
    }
    velocity_map_name = vname;
  }
  int decode(int token, TOKEN_TYPE tt) {
    auto it = tt_2_index.find(tt);
    assert(it != tt_2_index.end());
    if ((token >= starts[it->second]) && (token < ends[it->second])) {
      int value = toks[it->second].decode(token - starts[it->second], tt);
      if (tt == VELOCITY_LEVEL) {
        //cout << value << endl;
        return velocity_rev_maps[velocity_map_name][value];
      }
      else if (tt == VELOCITY) {
        return velocity_rev_maps["no_velocity"][value];
      }
      return value;
    }
    return -1;
  }
  bool is_token_type(int token, TOKEN_TYPE tt) {
    auto it = tt_2_index.find(tt);
    if ((it != tt_2_index.end()) && (token >= starts[it->second]) && (token < ends[it->second])) {
      return true;
    }
    return false;
  }
  int encode(map<TOKEN_TYPE,int> values) {
    // make sure that tts all belong to same token
    auto it = tt_2_index.find(values.begin()->first);
    assert(it != tt_2_index.end());
    // what is going on here ???
    //auto vit = values.find(VELOCITY);
    //if (vit != values.end()) {
    //  values[VELOCITY] = velocity_maps[velocity_map_name][vit->second];
    //}
    return toks[it->second].encode(values) + starts[it->second];
  }
  vector<int> encode_to_one_hot(map<TOKEN_TYPE,vector<int>> values) {
    auto it = tt_2_index.find(values.begin()->first);
    assert(it != tt_2_index.end());
    vector<int> oh(max_token(),0);
    if (values.begin()->second[0] == -1) {
      for (int i = starts[it->second]; i<ends[it->second]; i++) {
        oh[i] = 1;
      }
    }
    else {
      for (const auto x : values.begin()->second) {
        oh[toks[it->second].encode({{values.begin()->first,x}}) + starts[it->second]] = 1;
      }
    }
    return oh;
  }
  int shift(int token, TOKEN_TYPE tt, int shift) {
    auto it = tt_2_index.find(tt);
    assert(it != tt_2_index.end());
    if ((token >= starts[it->second]) && (token < ends[it->second])) {
      int shifted = toks[it->second].shift(token-starts[it->second], tt, shift);
      if (shifted == -1) { return -1; } // if it fails make sure to return -1
      return shifted + starts[it->second];
    }
    return token; // don't modify
  }
  int get_domain(TOKEN_TYPE tt) {
    auto it = tt_2_index.find(tt);
    assert(it != tt_2_index.end());
    return toks[it->second].get_domain(tt);
  }
  void show(vector<int> tokens) {
    int value;
    for (const auto token : tokens) {
      for (const auto kv : tt_2_index) {
        value = decode(token, kv.first);
        if (value >= 0) {
          cout << toString(kv.first) << "=" << value << " ";
        }
      }
      cout << endl;
    }
  }
  vector<string> pretty(vector<int> tokens) {
    int value;
    vector<string> output;
    for (const auto token : tokens) {
      string ts;
      for (const auto kv : tt_2_index) {
        value = decode(token, kv.first);
        if (value >= 0) {
          string tmp(toString(kv.first));
          ts += tmp + "=" + to_string(value) + " ";
        }
      }
      output.push_back(ts);
    }
    return output;
  }
  vector<int> where(vector<int> tokens, TOKEN_TYPE tt) {
    vector<int> indices;
    for (int i=0; i<tokens.size(); i++) {
      if (is_token_type(tokens[i], tt)) {
        indices.push_back(i);
      }
    }
    return indices;
  }
  vector<int> where_values(vector<int> tokens, TOKEN_TYPE tt) {
    vector<int> values;
    for (const auto i : where(tokens, tt)) {
      values.push_back(decode(tokens[i], tt));
    }
    return values;
  }
  int max_token() {
    return ends.back();
  }
  vector<TOKEN> toks;
  vector<int> starts;
  vector<int> ends;
  map<TOKEN_TYPE,int> tt_2_index;
  string velocity_map_name;
};

// ================================================================
// ================================================================
// ================================================================
// ENCODING HELPERS

// Encodes the notes of one bar as a performance-style token sequence.
//   bar       : its events() are indices into p->events(); they must be in
//               non-decreasing qtime order (throws int 2 otherwise).
//   transpose : added to every pitch (encode throws int 2 if the result is
//               outside the PITCH domain).
//   ec        : force_instrument suppresses inline INSTRUMENT tokens;
//               use_velocity_levels emits VELOCITY_LEVEL tokens on change.
// Per event the output is: [INSTRUMENT] [TIME_DELTA...] [VELOCITY_LEVEL]
// PITCH+VELOCITY. A TIME_DELTA value v advances time by v + 1 steps.
vector<int> to_performance(midi::Bar *bar, midi::Piece *p, REPRESENTATION *rep, int transpose, EncoderConfig *ec) {
  vector<int> tokens;
  int current_step = 0;
  int current_velocity = -1;
  int current_instrument = -1;
  const int N_TIME_TOKENS = rep->get_domain(TIME_DELTA);
  for (const auto i : bar->events()) {
    // Bind by reference: p->events(i) already returns a const reference, and
    // copying the message for every note is needless work.
    const midi::Event &event = p->events(i);
    int qvel = velocity_maps[rep->velocity_map_name][event.velocity()];

    // Instrument changes go before the time shift, unless the instrument is
    // fixed once per track (force_instrument).
    if (!ec->force_instrument && (event.instrument() != current_instrument)) {
      tokens.push_back( rep->encode({{INSTRUMENT, event.instrument()}}) );
      current_instrument = event.instrument();
    }

    if (event.qtime() > current_step) {
      // Emit maximal time shifts until the remaining gap fits in one token.
      while (event.qtime() > current_step + N_TIME_TOKENS) {
        tokens.push_back( rep->encode({{TIME_DELTA,N_TIME_TOKENS-1}}) );
        current_step += N_TIME_TOKENS;
      }
      if (event.qtime() > current_step) {
        tokens.push_back( rep->encode(
          {{TIME_DELTA,event.qtime()-current_step-1}}) );
      }
      current_step = event.qtime();
    }
    else if (event.qtime() < current_step) {
      cout << "ERROR : events are not sorted." << endl;
      throw(2);
    }

    if (ec->use_velocity_levels) {
      // Only non-zero velocities set the level; the note itself then carries
      // a binary velocity.
      if ((qvel > 0) && (qvel != current_velocity)) {
        tokens.push_back( rep->encode({{VELOCITY_LEVEL,qvel}}) );
        current_velocity = qvel;
      }
      qvel = min(1,qvel);
    }
    tokens.push_back( rep->encode(
      {{PITCH,event.pitch() + transpose},{VELOCITY,qvel}}) );
  }
  return tokens;
}

// Functor for random_shuffle that draws from std::rand(), so shuffles follow
// the same srand() seed as every other random choice in this file.
// operator()(n) returns an integer uniformly spread over [0, n).
struct RNG {
    int operator() (int n) {
        // Scale rand() into [0, 1) first, then stretch to [0, n).
        const double unit = std::rand() / (static_cast<double>(RAND_MAX) + 1.0);
        return static_cast<int>(unit * n);
    }
};

// Chooses a valid segment of p and appends the tracks that are valid in it to
// vtracks (shuffled if e->do_track_shuffle, then truncated to e->max_tracks).
// The segment is e->segment_idx (mod the segment count), or random when it is
// -1. Re-seeds rand() when e->seed >= 0.
// Returns the segment index, or -1 if p has no valid segments.
int select_section(midi::Piece *p, vector<int> &vtracks, EncoderConfig *e) {
  if (e->seed >= 0) {
    srand(e->seed);
  }

  const int n_segments = p->valid_segments_size();
  if (n_segments == 0) {
    cout << "WARNING : no valid segments" << endl;
    return -1;
  }

  int segment = (e->segment_idx == -1) ? rand() : e->segment_idx;
  segment %= n_segments;

  // valid_tracks(segment) is a bitmask: bit k set means track k is usable.
  const auto track_mask = p->valid_tracks(segment);
  for (int bit = 0; bit < 32; ++bit) {
    if (track_mask & (1 << bit)) {
      vtracks.push_back(bit);
    }
  }

  if (e->do_track_shuffle) {
    random_shuffle(vtracks.begin(), vtracks.end(), RNG());
  }
  vtracks.resize(min((int)vtracks.size(), e->max_tracks));

  return segment;
}

// Classifies track i of p: 1 for drum tracks, 2 when e->mark_polyphony is on
// and the track's average polyphony is at least 1.1, otherwise 0.
int get_track_type(midi::Piece *p, EncoderConfig *e, int i) {
  const auto &track = p->tracks(i);
  if (track.is_drum()) {
    return 1;
  }
  const bool polyphonic = (e->mark_polyphony) && (track.av_polyphony() >= 1.1);
  return polyphonic ? 2 : 0;
}

// Improved encoding: helpers that build both bar-major and track-major
// token sequences for the representations.
// =======================================================================

// Picks a valid segment of p (e->segment_idx mod the segment count, or random
// when it is -1; rand() is re-seeded when e->seed >= 0) and appends:
//   track_nums : tracks valid in that segment, shuffled if
//                e->do_track_shuffle, then truncated to e->max_tracks;
//   bar_nums   : the e->num_bars consecutive bars starting at the segment.
// Both stay untouched if p has no valid segments.
void get_random_segment(midi::Piece *p, vector<int> &track_nums, vector<int> &bar_nums, EncoderConfig *e) {
  if (e->seed >= 0) {
    srand(e->seed);
  }

  const int n_segments = p->valid_segments_size();
  if (n_segments == 0) {
    cout << "WARNING : no valid segments" << endl;
    return; // empty outputs make callers encode nothing
  }

  int segment = e->segment_idx;
  if (segment == -1) {
    segment = rand();
  }
  segment %= n_segments;

  // valid_tracks(segment) is a bitmask: bit k set means track k is usable.
  const auto track_mask = p->valid_tracks(segment);
  for (int bit = 0; bit < 32; ++bit) {
    if (track_mask & (1 << bit)) {
      track_nums.push_back(bit);
    }
  }

  const int first_bar = p->valid_segments(segment);
  for (int offset = 0; offset < e->num_bars; ++offset) {
    bar_nums.push_back(first_bar + offset);
  }

  if (e->do_track_shuffle) {
    random_shuffle(track_nums.begin(), track_nums.end(), RNG());
  }
  track_nums.resize(min((int)track_nums.size(), e->max_tracks));
}

// In multi-fill mode, and only if no fill cells were chosen yet, marks a
// random subset of the (track_num, bar_num) cells as fill-in targets by
// inserting them into e->multi_fill. The subset size is drawn from
// [0, round(cells * e->fill_percentage)) and capped at the number of cells.
// p is currently unused.
void get_fill(midi::Piece *p, vector<int> &track_nums, vector<int> &bar_nums, EncoderConfig *e) {
  if (!(e->do_multi_fill) || (e->multi_fill.size() != 0)) {
    return; // fills disabled, or already chosen by the caller
  }
  assert(e->fill_percentage > 0);

  const size_t n_cells = track_nums.size() * bar_nums.size();
  if (n_cells == 0) {
    return; // nothing to fill (and `rand() % 0` below would be undefined)
  }

  // At least 1: a small segment or percentage that rounds to 0 used to cause
  // a modulo by zero; now it simply yields n_fill == 0.
  const int fill_bound = max(1, (int)round(n_cells * e->fill_percentage));
  int n_fill = rand() % fill_bound;

  vector<tuple<int,int>> choices;
  choices.reserve(n_cells);
  for (const auto track_num : track_nums) {
    for (const auto bar_num : bar_nums) {
      choices.push_back(make_pair(track_num,bar_num));
    }
  }
  random_shuffle(choices.begin(), choices.end(), RNG());

  // fill_percentage > 1 could request more cells than exist.
  n_fill = min(n_fill, (int)choices.size());
  for (int i=0; i<n_fill; i++) {
    e->multi_fill.insert(choices[i]);
  }
}

// Appends the piece-level header: PIECE_START, then (if e->genre_header) one
// GENRE token per tag, taken from e->genre_tags when non-empty, otherwise
// from p->msd_cd1().
void piece_header(midi::Piece *p, REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {

  tokens.push_back( rep->encode({{PIECE_START,0}}) );

  // GENRE HEADER
  // NOTE : data must be formatted to always have two fields
  if (e->genre_header) {
    // NOTE(review): tags from p->msd_cd1() are also looked up in the
    // "msd_cd2" genre map; confirm this is intended.
    if (e->genre_tags.size()) {
      for (const auto &tag : e->genre_tags) {
        tokens.push_back( rep->encode({{GENRE,genre_maps["msd_cd2"][tag]}}) );
      }
    }
    else {
      for (const auto &tag : p->msd_cd1()) {
        tokens.push_back( rep->encode({{GENRE,genre_maps["msd_cd2"][tag]}}) );
      }
    }
  }

  // INSTRUMENT HEADER
  // Disabled: it needs the selected track list (vtracks), which this
  // function does not receive.
  /*
  if (e->instrument_header) {
    tokens.push_back( rep->encode({{HEADER,0}}) );
    for (int i=0; i<(int)vtracks.size(); i++) {
      int track = get_track_type(p,e,vtracks[i]);
      tokens.push_back( rep->encode({{TRACK,track}}) );
      int inst = p->tracks(vtracks[i]).instrument();
      tokens.push_back( rep->encode({{INSTRUMENT,inst}}) );
    }
    tokens.push_back( rep->encode({{HEADER,1}}) );
  }
  */
}

// Opens a segment by appending a SEGMENT token (value 0). e is unused.
void segment_header(REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {
  const int marker = rep->encode({{SEGMENT, 0}});
  tokens.push_back(marker);
}

// Closes a segment by appending a SEGMENT_END token (value 0). e is unused.
void segment_footer(REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {
  const int marker = rep->encode({{SEGMENT_END, 0}});
  tokens.push_back(marker);
}

// Appends the tokens that open a track:
//   TRACK <type>       1 = drum; with e->mark_polyphony, 2 when the average
//                      polyphony is not below 1.1; otherwise 0.
//   INSTRUMENT <inst>  only if e->force_instrument.
//   DENSITY_LEVEL <d>  only if e->mark_density.
void track_header(const midi::Track *track, REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {
  int track_type = 0;
  if (track->is_drum()) {
    track_type = 1;
  }
  else if (e->mark_polyphony) {
    track_type = (track->av_polyphony() < 1.1) ? 0 : 2;
  }
  tokens.push_back(rep->encode({{TRACK, track_type}}));

  if (e->force_instrument) {
    const int instrument = track->instrument();
    tokens.push_back(rep->encode({{INSTRUMENT, instrument}}));
  }

  if (e->mark_density) {
    const int density = track->note_density_v2();
    tokens.push_back(rep->encode({{DENSITY_LEVEL, density}}));
  }
}

// Closes a track by appending a TRACK_END token (value 0). track and e are
// unused.
void track_footer(const midi::Track *track, REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {
  const int marker = rep->encode({{TRACK_END, 0}});
  tokens.push_back(marker);
}

// Opens a bar with a BAR token (value 0). When e->mark_time_sigs is set, a
// TIME_SIGNATURE token follows, looked up in time_sig_map by the bar's
// rounded beat length.
void bar_header(const midi::Bar *bar, REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {
  tokens.push_back(rep->encode({{BAR, 0}}));
  if (!e->mark_time_sigs) {
    return;
  }
  const int ts = time_sig_map[round(bar->beat_length())];
  tokens.push_back(rep->encode({{TIME_SIGNATURE, ts}}));
}

// Closes a bar by appending a BAR_END token (value 0). bar and e are unused.
void bar_footer(const midi::Bar *bar, REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {
  const int marker = rep->encode({{BAR_END, 0}});
  tokens.push_back(marker);
}

// Appends the body of one bar: a FILL_IN 0 placeholder if (track_num, bar_num)
// is a multi-fill target, otherwise the bar's notes via to_performance
// (transposed by e->transpose except on drum tracks).
void encode_bar(midi::Piece *p, int track_num, int bar_num, REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {
  const bool is_fill_target = (e->do_multi_fill) &&
    (e->multi_fill.find(make_pair(track_num, bar_num)) != e->multi_fill.end());
  if (is_fill_target) {
    tokens.push_back(rep->encode({{FILL_IN, 0}}));
    return;
  }

  const midi::Track &track = p->tracks(track_num);
  midi::Bar bar = track.bars(bar_num); // copy: to_performance takes a non-const Bar*
  const int cur_transpose = track.is_drum() ? 0 : e->transpose;
  const vector<int> bar_tokens = to_performance(&bar, p, rep, cur_transpose, e);
  tokens.insert(tokens.end(), bar_tokens.begin(), bar_tokens.end());
}

// In multi-fill mode, appends the real content of every cell in e->multi_fill
// (absolute track_num, bar_num), each wrapped as FILL_IN 1 <notes> FILL_IN 2.
// Drum tracks are not transposed.
void fill_footer(midi::Piece *p, REPRESENTATION *rep, EncoderConfig *e, vector<int> &tokens) {
  if (!e->do_multi_fill) {
    return;
  }
  for (const auto &cell : e->multi_fill) {
    const int track_num = get<0>(cell);
    const int bar_num = get<1>(cell);

    const midi::Track &track = p->tracks(track_num);
    midi::Bar bar = track.bars(bar_num); // copy: to_performance takes a non-const Bar*
    const int cur_transpose = track.is_drum() ? 0 : e->transpose;

    tokens.push_back(rep->encode({{FILL_IN, 1}}));
    const vector<int> notes = to_performance(&bar, p, rep, cur_transpose, e);
    tokens.insert(tokens.end(), notes.begin(), notes.end());
    tokens.push_back(rep->encode({{FILL_IN, 2}}));
  }
}

// Encodes a (random) valid segment of p into tokens, optionally preceded by
// the piece header. The layout is chosen by e:
//   segment_mode : the bars are split into 4 sub-segments of 4 bars. Each is
//                  wrapped in SEGMENT ... SEGMENT_END; sub-segment local_idx
//                  is replaced by SEGMENT_FILL_IN 0, and the others show only
//                  the first context_tracks tracks (without bar fills). Then
//                  SEGMENT_FILL_IN 1, all tracks over all bars (with bar
//                  fills), the fill footer, and SEGMENT_FILL_IN 2.
//   !bar_major   : track-major (TRACK, BAR..., TRACK_END per track).
//   bar_major    : bar-major (BAR, per track TRACK ... TRACK_END, BAR_END).
// Returns an empty vector when the segment has no tracks or bars.
// May throw int 2 (from encode) when a value falls outside its domain; in
// that case e->do_multi_fill is still restored.
vector<int> main_enc(midi::Piece *p, REPRESENTATION *rep, EncoderConfig *e) {

  vector<int> tokens;

  // determine the track_nums and bar_nums
  vector<int> track_nums;
  vector<int> bar_nums;
  get_random_segment(p, track_nums, bar_nums, e);
  if (track_nums.empty() || bar_nums.empty()) {
    // Nothing to encode. Returning here also keeps the random draws below
    // (rand() % (n/2), and get_fill's modulo) away from a zero divisor.
    return tokens;
  }
  const int n_tracks = (int)track_nums.size();

  vector<vector<int>> split_bar_nums(4,vector<int>());
  int local_idx = 0;              // sub-segment that is left to be filled in
  int context_tracks = n_tracks;  // tracks shown in the other sub-segments
  if (e->segment_mode) {
    local_idx = rand() % 4;
    // Drop fewer than n/2 tracks from the context. With a single track n/2 is
    // 0, so nothing is dropped (this used to compute rand() % 0).
    const int max_drop = n_tracks / 2;
    context_tracks = n_tracks - ((max_drop > 0) ? (rand() % max_drop) : 0);
    for (size_t i=0; i<bar_nums.size(); i++) {
      // Bars beyond the 16th belong to no sub-segment; indexing them used to
      // write past the end of split_bar_nums.
      if (i/4 < split_bar_nums.size()) {
        split_bar_nums[i/4].push_back(bar_nums[i]);
      }
    }
    if (!split_bar_nums[local_idx].empty()) {
      get_fill(p, track_nums, split_bar_nums[local_idx], e);
    }
  }
  get_fill(p, track_nums, bar_nums, e);

  // Track-major encoding of `bars`, stopping after `max_tracks` tracks (at
  // least one track is written). Tracks and bars are bound by reference to
  // avoid copying protobuf messages.
  auto encode_track_major = [&](const vector<int> &bars, int max_tracks) {
    int written = 0;
    for (const int track_num : track_nums) {
      const midi::Track &track = p->tracks(track_num);
      track_header(&track, rep, e, tokens);
      for (const int bar_num : bars) {
        const midi::Bar &bar = track.bars(bar_num);
        bar_header(&bar, rep, e, tokens);
        encode_bar(p, track_num, bar_num, rep, e, tokens);
        bar_footer(&bar, rep, e, tokens);
      }
      track_footer(&track, rep, e, tokens);
      if (++written >= max_tracks) {
        break; // no more tracks
      }
    }
  };

  if (e->piece_header) {
    piece_header(p, rep, e, tokens); // in either case we need piece header
  }

  if (e->segment_mode) {
    // Context sub-segments are encoded without bar-level fills. Restore the
    // flag even if encoding throws, so the caller's config is not corrupted.
    const auto multi_fill_status = e->do_multi_fill;
    e->do_multi_fill = false;
    try {
      for (int seg_num=0; seg_num<4; seg_num++) {
        segment_header(rep, e, tokens);
        if (seg_num != local_idx) {
          encode_track_major(split_bar_nums[seg_num], context_tracks);
        }
        else {
          tokens.push_back( rep->encode({{SEGMENT_FILL_IN,0}}) );
        }
        segment_footer(rep, e, tokens);
      }
    }
    catch (...) {
      e->do_multi_fill = multi_fill_status;
      throw;
    }
    e->do_multi_fill = multi_fill_status;

    // fill in remaining segment
    // NOTE(review): this re-encodes every bar of the segment, not only
    // split_bar_nums[local_idx]; confirm that is intended.
    tokens.push_back( rep->encode({{SEGMENT_FILL_IN,1}}) );
    encode_track_major(bar_nums, n_tracks);
    fill_footer(p, rep, e, tokens); // for bar fill in ...
    tokens.push_back( rep->encode({{SEGMENT_FILL_IN,2}}) );
  }
  else if (!e->bar_major) {
    encode_track_major(bar_nums, n_tracks);
  }
  else {
    // The first selected track supplies each bar's header (time signature).
    const midi::Track &first_track = p->tracks(track_nums[0]);
    for (const int bar_num : bar_nums) {
      const midi::Bar &bar = first_track.bars(bar_num);
      bar_header(&bar, rep, e, tokens);
      for (const int track_num : track_nums) {
        const midi::Track &track = p->tracks(track_num);
        track_header(&track, rep, e, tokens);
        encode_bar(p, track_num, bar_num, rep, e, tokens);
        track_footer(&track, rep, e, tokens);
      }
      bar_footer(&bar, rep, e, tokens);
    }
  }
  return tokens;
}

// Decodes a token stream (track-major or bar-major, as given by
// ec->bar_major) and writes the resulting tracks, bars and events into p.
// It also sets p's tempo and resolution from ec, and marks the decoded piece
// as one valid segment (start 0) with all tracks valid.
// Malformed input is tolerated. An INSTRUMENT token before any TRACK only
// updates the current instrument, and PITCH tokens before any TRACK are
// dropped.
void main_dec(vector<int> &tokens, midi::Piece *p, REPRESENTATION *rep, EncoderConfig *ec) {
  p->set_tempo(ec->default_tempo);
  p->set_resolution(ec->resolution);

  midi::Track *t = NULL;             // track being written; NULL until the first TRACK
  int curr_time = 0;                 // absolute time of the next event
  int current_instrument = 0;
  int current_track = 0;             // track number stored on events (9 for drums)
  int current_velocity = 0;          // last decoded VELOCITY_LEVEL
  bool have_velocity_level = false;  // whether any VELOCITY_LEVEL was seen yet
  int beat_length = 4;               // bar length in beats (TIME_SIGNATURE overrides)
  int bar_time = 0;                  // absolute start time of the current bar
  int track_count = 0;
  int bar_count = 0;

  // this is flexible enough to handle bar-major or track-major

  for (const auto token : tokens) {
    if (rep->is_token_type(token, TRACK)) {
      if (ec->bar_major) {
        curr_time = bar_time;
      }
      else {
        // track major: every track restarts at time 0 and bar 0
        bar_time = 0;
        bar_count = 0;
        curr_time = 0;
      }
      current_instrument = 0; // reset instrument once per track
      current_track = SAFE_TRACK_MAP[track_count];

      // add or get track
      if (track_count >= p->tracks_size()) {
        t = p->add_tracks();
      }
      else {
        t = p->mutable_tracks(track_count);
      }
      if (rep->decode(token,TRACK)==1) {
        current_track = 9; // its a drum channel
        t->set_is_drum(true);
      }
      else {
        t->set_is_drum(false);
      }
    }
    else if (rep->is_token_type(token, TRACK_END)) {
      track_count++;
    }
    else if (rep->is_token_type(token, BAR)) {
      if (ec->bar_major) {
        track_count = 0; // reset track_count each bar
      }
      curr_time = bar_time;
      beat_length = 4; // default value optionally overridden with TIME_SIGNATURE
    }
    else if (rep->is_token_type(token, TIME_SIGNATURE)) {
      beat_length = rev_time_sig_map[rep->decode(token, TIME_SIGNATURE)];
    }
    else if (rep->is_token_type(token, BAR_END)) {
      bar_time += (ec->resolution * beat_length);
      bar_count++;
    }
    else if (rep->is_token_type(token, TIME_DELTA)) {
      curr_time += (rep->decode(token, TIME_DELTA) + 1);
    }
    else if (rep->is_token_type(token, INSTRUMENT)) {
      current_instrument = rep->decode(token, INSTRUMENT);
      if (t != NULL) {
        t->set_instrument( current_instrument );
      }
    }
    else if (rep->is_token_type(token, VELOCITY_LEVEL)) {
      current_velocity = rep->decode(token, VELOCITY_LEVEL);
      have_velocity_level = true;
    }
    else if (rep->is_token_type(token, PITCH)) {
      if (t == NULL) {
        continue; // malformed stream: note before any TRACK token
      }

      const int note_velocity = rep->decode(token, VELOCITY);
      const int current_note_index = p->events_size();
      midi::Event *e = p->add_events();
      e->set_pitch( rep->decode(token, PITCH) );
      // With velocity levels, zero-velocity notes keep 0 and other notes use
      // the last VELOCITY_LEVEL. Before any level has been seen, they fall
      // back to the plain VELOCITY value (this used to read an uninitialized
      // variable).
      if ((!ec->use_velocity_levels) || (note_velocity == 0) || (!have_velocity_level)) {
        e->set_velocity( note_velocity );
      }
      else {
        e->set_velocity( current_velocity );
      }
      e->set_time( curr_time );
      e->set_qtime( curr_time - bar_time );
      e->set_instrument( current_instrument );
      e->set_track( current_track );
      e->set_is_drum( current_track == 9 );

      // get or add bar
      midi::Bar *b = NULL;
      if (bar_count >= t->bars_size()) {
        b = t->add_bars();
      }
      else {
        b = t->mutable_bars(bar_count);
      }
      b->set_is_four_four(true);
      b->set_time( bar_time );
      b->add_events(current_note_index);
      b->set_has_notes(true);
    }
  }
  p->add_valid_segments(0);
  // Bitmask with one bit per decoded track. Computed unsigned so that 31 or
  // more tracks do not overflow (the old `(1<<n)-1` was undefined for n >= 31).
  const int n_tracks = p->tracks_size();
  const unsigned int track_mask = (n_tracks >= 32) ? ~0u : ((1u << n_tracks) - 1u);
  p->add_valid_tracks(track_mask);
}

// ====================================================================


// Encodes a randomly selected segment of p track by track. It can also add a
// single fill-in bar (e->do_fill) and/or several (e->do_multi_fill). Layout:
//   PIECE_START [GENRE...] [HEADER 0 (TRACK INSTRUMENT)... HEADER 1]
//   per track: TRACK [INSTRUMENT] [DENSITY_LEVEL]
//     per bar: BAR [TIME_SIGNATURE] (<notes> | FILL_IN 0) BAR_END
//   TRACK_END
//   then, per filled bar: FILL_IN 1 <notes> FILL_IN 2
// Fill positions are (index into the selected tracks, bar offset within the
// segment). Returns an empty vector when there is no valid segment, track
// or bar. May throw int 2 (from encode) on out-of-domain values.
vector<int> to_performance_w_tracks(midi::Piece *p, REPRESENTATION *rep, EncoderConfig *e) {
  vector<int> tokens;
  vector<int> vtracks;
  // select_section also re-seeds rand() when e->seed >= 0.
  const int segment_idx = select_section(p, vtracks, e);
  if (segment_idx == -1) return tokens;
  // Degenerate selection: nothing to encode, and the `rand() % ...` draws
  // below would divide by zero.
  if (vtracks.empty() || (e->num_bars <= 0)) return tokens;
  const int start = p->valid_segments(segment_idx);
  const int n_tracks = (int)vtracks.size();

  // Single fill target. It is drawn unconditionally, as before, so the
  // rand() sequence seen by later draws is unchanged.
  int fill_track = e->fill_track;
  int fill_bar = e->fill_bar;
  if (fill_track == -1) fill_track = rand() % n_tracks;
  if (fill_bar == -1) fill_bar = rand() % e->num_bars;

  // if multi fill bars not selected ... select some
  if ((e->do_multi_fill) && (e->multi_fill.size() == 0)) {
    if (e->fill_percentage > 0) {
      // Pick a random count up to the fill percentage. The bound is kept at 1
      // or more so that a small product does not give a modulo by zero.
      const int fill_bound = max(1, (int)round(n_tracks * e->num_bars * e->fill_percentage));
      int n_fill = rand() % fill_bound;
      vector<tuple<int,int>> choices;
      choices.reserve(n_tracks * e->num_bars);
      for (int i=0; i<n_tracks; i++) {
        for (int j=0; j<e->num_bars; j++) {
          choices.push_back(make_pair(i,j));
        }
      }
      random_shuffle(choices.begin(), choices.end(), RNG());
      n_fill = min(n_fill, (int)choices.size()); // fill_percentage > 1
      for (int i=0; i<n_fill; i++) {
        e->multi_fill.insert(choices[i]);
      }
    }
    else {
      // Fill a random mask of bars in one random track.
      int track_num = rand() % n_tracks;
      for (const auto bar_num : TRACK_MASKS[(rand() % 15) + 1]) {
        e->multi_fill.insert(make_pair(track_num, bar_num));
      }
    }
  }

  tokens.push_back( rep->encode({{PIECE_START,0}}) );

  // GENRE HEADER
  // data must be formatted to always have two fields
  if (e->genre_header) {
    if (e->genre_tags.size()) {
      for (const auto &tag : e->genre_tags) {
        tokens.push_back( rep->encode({{GENRE,genre_maps["msd_cd2"][tag]}}) );
      }
    }
    else {
      for (const auto &tag : p->msd_cd1()) {
        tokens.push_back( rep->encode({{GENRE,genre_maps["msd_cd2"][tag]}}) );
      }
    }
  }

  // INSTRUMENT HEADER
  if (e->instrument_header) {
    tokens.push_back( rep->encode({{HEADER,0}}) );
    for (const int track_idx : vtracks) {
      int track = get_track_type(p,e,track_idx);
      tokens.push_back( rep->encode({{TRACK,track}}) );
      int inst = p->tracks(track_idx).instrument();
      tokens.push_back( rep->encode({{INSTRUMENT,inst}}) );
    }
    tokens.push_back( rep->encode({{HEADER,1}}) );
  }

  // Appends FILL_IN 1 <notes of bar `bar_offset` of vtracks[track_idx]> FILL_IN 2.
  auto emit_fill = [&](int track_idx, int bar_offset) {
    const midi::Track &track = p->tracks(vtracks[track_idx]);
    const int transpose = track.is_drum() ? 0 : e->transpose; // never transpose drums
    tokens.push_back( rep->encode({{FILL_IN,1}}) ); // begin fill-in
    midi::Bar b = track.bars(start + bar_offset);  // copy: to_performance takes Bar*
    const vector<int> bar = to_performance(&b, p, rep, transpose, e);
    tokens.insert(tokens.end(), bar.begin(), bar.end());
    tokens.push_back( rep->encode({{FILL_IN,2}}) ); // end fill-in
  };

  for (int i=0; i<n_tracks; i++) {
    // Bound by reference: the track was previously copied for every track.
    const midi::Track &track = p->tracks(vtracks[i]);
    const int cur_transpose = track.is_drum() ? 0 : e->transpose; // AVOID DRUM TRANSPOSITION

    int track_type = 0;
    if (track.is_drum()) {
      track_type = 1;
    }
    else if (e->mark_polyphony) {
      track_type = (track.av_polyphony() < 1.1) ? 0 : 2;
    }
    tokens.push_back( rep->encode({{TRACK,track_type}}) );
    // add instrument at start of track ...
    if (e->force_instrument) {
      int inst = track.instrument();
      tokens.push_back( rep->encode({{INSTRUMENT,inst}}) );
    }
    if (e->mark_density) {
      int density_level = track.note_density_v2();
      tokens.push_back( rep->encode({{DENSITY_LEVEL,density_level}}) );
    }

    for (int j=0; j<e->num_bars; j++) {
      tokens.push_back( rep->encode({{BAR,0}}) );
      if (e->mark_time_sigs) {
        // assume that beat_length is in the map
        int ts = time_sig_map[round(track.bars(start+j).beat_length())];
        tokens.push_back( rep->encode({{TIME_SIGNATURE,ts}}) );
      }
      const bool single_fill = (e->do_fill) && (fill_track==i) && (fill_bar==j);
      const bool in_multi_fill = (e->do_multi_fill) &&
        (e->multi_fill.find(make_pair(i,j)) != e->multi_fill.end());
      if (single_fill || in_multi_fill) {
        tokens.push_back( rep->encode({{FILL_IN,0}}) );
      }
      else {
        midi::Bar b = track.bars(start+j);
        const vector<int> bar = to_performance(&b, p, rep, cur_transpose, e);
        tokens.insert(tokens.end(), bar.begin(), bar.end());
      }
      tokens.push_back( rep->encode({{BAR_END,0}}) );
    }
    tokens.push_back( rep->encode({{TRACK_END,0}}) );
  }

  // add the fill in bar at the end
  if (e->do_fill) {
    emit_fill(fill_track, fill_bar);
  }

  // do multiple fills at the end
  if (e->do_multi_fill) {
    for (const auto &cell : e->multi_fill) {
      emit_fill(get<0>(cell), get<1>(cell));
    }
  }

  return tokens;
}

/*
vector<int> to_performance_w_tracks_and_bars(midi::Piece *p, REPRESENTATION *rep, int transpose, bool ordered, int n_to_fill) {

  // allow for random order of tracks
  // and random order of bars
  int n_bars = 4;
  //srand(time(NULL));

  vector<int> tokens;
  if (p->valid_segments_size() == 0) {
    cout << "WARNING : no valid segments" << endl;
    return tokens; // return empty sequence
  }

  int ii = rand() % p->valid_segments_size();
  int start = p->valid_segments(ii);

  // select subset of valid tracks
  vector<int> vtracks;
  for (int i=0; i<32; i++) {
    if (p->valid_tracks(ii) & (1<<i)) {
      vtracks.push_back(i);
    }
  }

  random_shuffle(vtracks.begin(), vtracks.end());
  int ntracks = min((int)vtracks.size(), 12); // no more than 12 tracks
  vtracks.resize(ntracks); // remove extra tracks

  // random order of tracks and bars ...
  vector<pair<int,int>> track_bar_order;
  for (int i=0; i<(int)vtracks.size(); i++) {
    for (int j=0; j<n_bars; j++) {
      track_bar_order.push_back( make_pair(i,j) );
    }
  }
  random_shuffle(track_bar_order.begin(), track_bar_order.end());

  int split = 0;
  if (ordered) {
    if (n_to_fill == 0) {
      split = rand() % track_bar_order.size();
    }
    else {
      split = (int)track_bar_order.size() - n_to_fill;
    }
    sort(track_bar_order.begin(), track_bar_order.begin() + split);
    sort(track_bar_order.begin() + split, track_bar_order.end());
  }

  int count = 0;
  int track, bar, cur_transpose;
  tokens.push_back( rep->encode({{PIECE_START,0}}) );
  for (const auto track_bar : track_bar_order) {
    track = get<0>(track_bar);
    bar = get<1>(track_bar);
    cur_transpose = transpose;

    if ((ordered) && (count == split)) {
      tokens.push_back( rep->encode({{FILL_IN,0}}) );
    }

    tokens.push_back( rep->encode({{TRACK,track}}) );
    if (p->tracks(vtracks[track]).is_drum()) { 
      tokens.push_back( rep->encode({{DRUM_TRACK,0}}) );
      cur_transpose = 0; // // AVOID DRUM TRANSPOSTION
    }
    tokens.push_back( rep->encode({{BAR,bar}}) );
    midi::Bar b = p->tracks(vtracks[track]).bars(start+bar);
    vector<int> bar = to_performance(&b, p, rep, cur_transpose);
    tokens.insert(tokens.end(), bar.begin(), bar.end());
    tokens.push_back( rep->encode({{BAR_END,0}}) );
    count++;
  }
  return tokens;
}
*/

// ================================================================
// ================================================================
// ================================================================
// DECODING HELPERS

// Decode a track-major token sequence into `p`.
//
// The encoder config is passed in so tempo and resolution can be set from it.
// Each TRACK token opens a new midi::Track (TRACK value 1 marks a drum track on
// channel 9). Each BAR token opens a bar lasting `resolution * beat_length`
// ticks; beat_length defaults to 4 and may be overridden by a TIME_SIGNATURE
// token. Each PITCH token appends a midi::Event to the piece and records its
// index in the current bar. Piece metadata (pitch limits, note density,
// polyphony) is recomputed at the end.
//
// Tokens that arrive before the structure they need are skipped instead of
// dereferencing a null pointer: BAR or INSTRUMENT before any TRACK, or PITCH
// before a BAR of the current track. Unconstrained model output can contain
// such orderings.
void decode_track(vector<int> &tokens, midi::Piece *p, REPRESENTATION *rep, EncoderConfig *ec) {
  p->set_tempo(ec->default_tempo);
  p->set_resolution(ec->resolution);

  midi::Track *t = NULL;
  midi::Bar *b = NULL;
  // All decoder state starts at a defined value. Previously these were read
  // uninitialized whenever a sequence omitted the token that sets them.
  int current_time = 0;
  int current_instrument = 0;
  int bar_start_time = 0;
  int current_track = 0;
  int current_velocity = 0; // NOTE(review): used only if no VELOCITY_LEVEL token precedes a note
  int beat_length = 4;
  int track_count = 0;
  for (const auto token : tokens) {
    if (rep->is_token_type(token, TRACK)) {
      current_time = 0; // restart the time
      current_instrument = 0; // reset instrument
      // NOTE(review): SAFE_TRACK_MAP is indexed without a bounds check
      current_track = SAFE_TRACK_MAP[track_count];
      t = p->add_tracks();
      b = NULL; // notes of this track must not land in the previous track's bars
      if (rep->decode(token,TRACK)==1) {
        current_track = 9; // its a drum channel
        t->set_is_drum(true); // by default its false
      }
      else {
        t->set_is_drum(false);
      }
    }
    else if (rep->is_token_type(token, TRACK_END)) {
      track_count++;
    }
    else if (rep->is_token_type(token, BAR)) {
      if (t == NULL) {
        continue; // no track to attach the bar to
      }
      beat_length = 4; // default value optionally overidden with TIME_SIGNATURE
      bar_start_time = current_time;
      b = t->add_bars();
      b->set_is_four_four(true);
      b->set_has_notes(false); // no notes yet
      b->set_time( bar_start_time );
    }
    else if (rep->is_token_type(token, TIME_SIGNATURE)) {
      beat_length = rev_time_sig_map[rep->decode(token, TIME_SIGNATURE)];
    }
    else if (rep->is_token_type(token, BAR_END)) {
      // jump to the end of the bar regardless of how much time was consumed
      current_time += (ec->resolution * beat_length) - (current_time - bar_start_time);
    }
    else if (rep->is_token_type(token, TIME_DELTA)) {
      current_time += (rep->decode(token, TIME_DELTA) + 1);
    }
    else if (rep->is_token_type(token, INSTRUMENT)) {
      current_instrument = rep->decode(token, INSTRUMENT);
      if (t != NULL) {
        t->set_instrument( current_instrument );
      }
    }
    else if (rep->is_token_type(token, VELOCITY_LEVEL)) {
      current_velocity = rep->decode(token, VELOCITY_LEVEL);
    }
    else if (rep->is_token_type(token, PITCH)) {
      if (b == NULL) {
        continue; // no open bar in the current track to attach the note to
      }
      int current_note_index = p->events_size();
      midi::Event *e = p->add_events();
      e->set_pitch( rep->decode(token, PITCH) );
      // a VELOCITY field of 0 is always kept as-is; otherwise the velocity
      // level (when enabled) supplies the value
      if ((!ec->use_velocity_levels) || (rep->decode(token, VELOCITY)==0)) {
        e->set_velocity( rep->decode(token, VELOCITY) );
      }
      else {
        e->set_velocity( current_velocity );
      }
      e->set_time( current_time );
      e->set_qtime( current_time - bar_start_time );
      e->set_instrument( current_instrument );
      e->set_track( current_track );
      e->set_is_drum( current_track == 9 );
      b->add_events(current_note_index);
      b->set_has_notes(true);
    }
  }
  p->add_valid_segments(0);
  // NOTE(review): 1<<n overflows int once the piece has 31+ tracks
  p->add_valid_tracks((1<<p->tracks_size())-1);

  // update the meta-data
  update_pitch_limits(p);
  update_note_density(p);
  update_polyphony(p);
}

// Decode a sequence of TRACK / DRUM_TRACK / BAR / TIME_DELTA / INSTRUMENT /
// PITCH tokens into a newly allocated midi::Piece. The caller owns the
// returned pointer.
//
// Tempo is fixed at 104 and resolution at 12, so a bar spans 48 ticks. TRACK
// values map onto MIDI channels while skipping the drum channel (9). A
// DRUM_TRACK token moves the current track to channel 9.
//
// Throws 2 (after freeing the piece) if a TRACK value has no channel mapping.
midi::Piece *decode_track_bar(REPRESENTATION *rep, vector<int> &tokens) {
  midi::Piece *p = new midi::Piece;
  p->set_tempo(104);
  p->set_resolution(12);

  // channels 0-15 without the drum channel 9, indexed by TRACK value
  const vector<int> track_map = {0,1,2,3,4,5,6,7,8,10,11,12,13,14,15};
  const int ticks_per_bar = 48; // resolution (12) * 4 beats

  int current_time = 0;
  int current_instrument = 0;
  int current_track = 0;

  for (const auto token : tokens) {
    if (rep->is_token_type(token, TRACK)) {
      current_time = 0; // restart the time
      current_instrument = 0; // reset instrument
      int track_index = rep->decode(token, TRACK);
      if ((track_index < 0) || (track_index >= (int)track_map.size())) {
        delete p;
        throw(2);
      }
      // make sure to avoid drum channel
      current_track = track_map[track_index];
    }
    else if (rep->is_token_type(token, DRUM_TRACK)) {
      current_track = 9; // switch to drum channel
    }
    else if (rep->is_token_type(token, BAR)) {
      current_time = ticks_per_bar * rep->decode(token, BAR);
    }
    else if (rep->is_token_type(token, TIME_DELTA)) {
      current_time += (rep->decode(token, TIME_DELTA) + 1);
    }
    else if (rep->is_token_type(token, INSTRUMENT)) {
      current_instrument = rep->decode(token, INSTRUMENT);
    }
    else if (rep->is_token_type(token, PITCH)) {
      midi::Event *e = p->add_events();
      e->set_pitch( rep->decode(token, PITCH) );
      e->set_velocity( rep->decode(token, VELOCITY) );
      e->set_time( current_time );
      e->set_instrument( current_instrument );
      e->set_track( current_track );
      e->set_is_drum( current_track == 9 );
    }
  }
  return p;
}


// BASE CLASS FOR AN ENCODER
// SHOULD HAVE a method that converts midi::Piece to tokens
// AND a method that converts tokens back to midi::Piece
// velocity map
// instrument map

class ENCODER {
public:

  // Virtual so that deleting a concrete encoder through an ENCODER* runs the
  // subclass destructor, which frees `rep`. Without it, that delete is
  // undefined behavior.
  virtual ~ENCODER() {}

  // Convert a piece into tokens. Subclasses override this. The base version
  // returns an empty sequence; it used to have no return statement, which is
  // undefined behavior for a non-void function.
  virtual vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    return vector<int>();
  }
  // Convert tokens back into `p`. Subclasses override this; the base version
  // does nothing.
  virtual void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {}

  // Parse the MIDI file at `filepath` and return the piece as JSON.
  string midi_to_json(string &filepath, EncoderConfig *e) {
    string json_string;
    midi::Piece p;
    parse_new(filepath, &p, e);
    google::protobuf::util::MessageToJsonString(p, &json_string);
    return json_string;
  }

  // Parse the MIDI file at `filepath` and encode it with this encoder.
  vector<int> midi_to_tokens(string &filepath, EncoderConfig *e) {
    midi::Piece p;
    parse_new(filepath, &p, e);
    return encode(&p, e);
  }

  // Write the JSON-serialized piece to a MIDI file.
  // NOTE(review): the status of JsonStringToMessage is ignored, so malformed
  // JSON produces an empty/partial piece rather than an error.
  void json_to_midi(string &json_string, string &filepath, EncoderConfig *e) {
    midi::Piece p;
    google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p);
    write_midi(&p, filepath);
  }

  // Encode a JSON-serialized piece (parse status ignored, as above).
  vector<int> json_to_tokens(string &json_string, EncoderConfig *e) {
    midi::Piece p;
    google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p);
    return encode(&p, e);
  }

  // Decode tokens with this encoder and return the piece as JSON.
  string tokens_to_json(vector<int> &tokens, EncoderConfig *e) {
    midi::Piece p;
    decode(tokens, &p, e);
    string json_string;
    google::protobuf::util::MessageToJsonString(p, &json_string);
    return json_string;
  }

  // Decode tokens with this encoder and write the result as a MIDI file.
  void tokens_to_midi(vector<int> &tokens, string &filepath, EncoderConfig *e) {
    midi::Piece p;
    decode(tokens, &p, e);
    write_midi(&p, filepath);
  }

  // Token vocabulary. It is allocated and deleted by the concrete subclass.
  // NOTE(review): subclasses are implicitly copyable, so copying one would
  // double-delete `rep`.
  REPRESENTATION *rep = NULL;
};

// ================================================================
// ================================================================
// ================================================================
// ENCODING CLASSES

// Track-major encoder: notes carry a two-valued VELOCITY field under the
// "no_velocity" map; decoding goes through decode_track().
class TrackEncoder : public ENCODER  {
public:
  TrackEncoder() {
    // vocabulary: structure, instrument, pitch x velocity, time shift
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,2}}, {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}}
    }, "no_velocity");
  }
  ~TrackEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // careful: this overwrites the caller's config
    e->force_instrument = true;
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};

// Bar-major variant of the track encoder: 16 bars, no piece header, encoded
// with main_enc() and decoded with main_dec().
class TrackBarMajorEncoder : public ENCODER  {
public:
  TrackBarMajorEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,2}}, {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}}
    }, "no_velocity");
  }
  ~TrackBarMajorEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // overwrite the caller's config with this encoder's settings
    e->force_instrument = true;
    e->bar_major = true;
    e->num_bars = 16;
    e->piece_header = false;
    vector<int> tokens = main_enc(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    main_dec(tokens, p, rep, e);
  }
};

// Segment-mode encoder. encode() configures a 16-bar, >=2-track, multi-fill
// (fill_percentage .5) segment encoding and hands it to main_enc().
// decode() is UNFINISHED: it regroups tokens but never writes anything to `p`
// (see the notes on decode below).
class SegmentEncoder : public ENCODER  {
public:
  SegmentEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}},
      {{BAR,1}}, 
      {{BAR_END,1}}, 
      {{TRACK,2}},
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{FILL_IN,3}},
      {{SEGMENT,1}},
      {{SEGMENT_END,1}},
      {{SEGMENT_FILL_IN,3}}
      },
      "no_velocity");
  }
  ~SegmentEncoder() {
    delete rep;
  }
  // Overwrites the caller's config with segment-mode settings. Throws 1 when
  // update_valid_segments() leaves the piece without any valid segment.
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    e->force_instrument = true;
    e->segment_mode = true;
    e->num_bars = 16;
    e->min_tracks = 2;
    e->piece_header = false;
    e->do_multi_fill = true;
    e->fill_percentage = .5;
    update_valid_segments(p, e);
    if (p->valid_segments_size() == 0) {
      throw(1); // need to start over!
    }
    return main_enc(p, rep, e);
  }
  // NOTE(review): unfinished. The tokens are regrouped into
  // holder[segment][track][bar] and flattened into `ftokens`. `ftokens` is
  // never used, and the main_dec call at the end is commented out, so `p` is
  // left untouched.
  // Indexing hazards (no bounds checks):
  //  - holder[seg_num] requires a SEGMENT token before any TRACK/BAR token;
  //  - fill_holder[fill_num] requires a FILL_IN(0) for every FILL_IN(1);
  //  - the flattening loop assumes at least 4 segments (holder[seg_num] for
  //    seg_num < 4 is indexed unconditionally).
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    //throw(1); // it's not implemented yet ...
    // split at each bar and create [segment][track][bar] of vectors
    int seg_num = 0;
    int track_num = 0;
    int bar_num = 0;
    int fill_num = 0;

    // also need to capture track headers
    // (segment, track, bar) position of each FILL_IN(0) placeholder, in order
    vector<tuple<int,int,int>> fill_holder;
    vector<vector<vector<vector<int>>>> holder;
    // while true, every token is appended to the current bar (see end of loop)
    bool bar_started = false;
    set<int> track_nums;

    for (const auto token : tokens) {
      if (rep->is_token_type(token, BAR)) {
        bar_started = true;
        holder[seg_num][track_num].push_back( vector<int>() );  
      }
      else if (rep->is_token_type(token, TRACK)) {
        holder[seg_num].push_back( vector<vector<int>>() );
        track_nums.insert( track_num );
      }
      else if (rep->is_token_type(token, SEGMENT)) {
        holder.push_back( vector<vector<vector<int>>>() );
      }
      else if (rep->is_token_type(token, BAR_END)) {
        // appended explicitly because bar_started is cleared before the
        // generic append at the end of the loop
        holder[seg_num][track_num][bar_num].push_back(token);
        bar_started = false;
        bar_num++;
      }
      else if (rep->is_token_type(token, TRACK_END)) { 
        track_num++;
        bar_num = 0;
      }
      else if (rep->is_token_type(token, SEGMENT_END)) {
        seg_num++;
        track_num = 0;
        bar_num = 0;
      }
      else if (rep->is_token_type(token, FILL_IN)) {
        int value = rep->decode(token, FILL_IN);
        if (value == 0) {
          // placeholder: remember where the fill belongs
          fill_holder.push_back( make_tuple(seg_num,track_num,bar_num) );
        }
        if (value == 1) {
          // reset seg_num, track_num and bar_num when fill segment starts
          // (this FILL_IN(1) token itself is then appended by the check below)
          tuple<int,int,int> loc = fill_holder[fill_num];
          seg_num = get<0>(loc);
          track_num = get<1>(loc);
          bar_num = get<2>(loc);
          bar_started = true;
        }
        else if (value == 2) {
          // want to push back BAR_END instead
          holder[seg_num][track_num][bar_num].push_back(token);
          bar_started = false;
          fill_num++;
        }
      }

      if (bar_started) {
        holder[seg_num][track_num][bar_num].push_back(token);
      }
    }

    // iterate over the bars and put them in normal order
    // (hard-coded to 4 segments x 4 bars per track)
    vector<int> ftokens;
    for (const auto track_num : track_nums) {
      // add track header
      for (int seg_num=0; seg_num<4; seg_num++) {
        for (int bar_num=0; bar_num<4; bar_num++) {
          if ((holder[seg_num].size() > track_num) && (holder[seg_num][track_num].size() > bar_num)) {
            for (const auto token : holder[seg_num][track_num][bar_num]) {
              ftokens.push_back( token );
            }
          }
        }
      }
      // add track footer

    }
    
    //return main_dec(tokens, p, rep, e);
  }
};



// Genre-conditioned track encoder (adds a GENRE header token). The CD2 genre
// labels are used because they have 14138 matches.
class TrackGenreEncoder : public ENCODER  {
public:
  TrackGenreEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,2}}, {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{GENRE,16}}
    }, "no_velocity");
  }
  ~TrackGenreEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // overwrite the caller's config with this encoder's settings
    e->force_instrument = true;
    e->genre_header = true;
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};

// Track encoder with 32 VELOCITY_LEVEL tokens ("magenta" velocity map) in
// addition to the two-valued per-note VELOCITY field.
class TrackVelocityLevelEncoder : public ENCODER {
public:
  TrackVelocityLevelEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,2}}, {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{VELOCITY_LEVEL,32}},
      {{TIME_DELTA,48}}
    }, "magenta");
  }
  ~TrackVelocityLevelEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // overwrite the caller's config with this encoder's settings
    e->force_instrument = true;
    e->use_velocity_levels = true;
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};

// Track encoder whose per-note VELOCITY field has 32 values ("magenta" map).
class TrackVelocityEncoder : public ENCODER {
public:
  TrackVelocityEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,2}}, {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,32}},
      {{TIME_DELTA,48}}
    }, "magenta");
  }
  ~TrackVelocityEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // careful: this overwrites the caller's config
    e->force_instrument = true;
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};

// Track encoder that emits an instrument header (HEADER tokens) before the body.
class TrackInstHeaderEncoder : public ENCODER  {
public:
  TrackInstHeaderEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,2}}, {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{HEADER,2}}
    }, "no_velocity");
  }
  ~TrackInstHeaderEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // overwrite the caller's config with this encoder's settings
    e->force_instrument = true;
    e->instrument_header = true;
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};

// Track encoder with single-bar infilling (encode sets do_fill).
class TrackOneBarFillEncoder : public ENCODER  {
public:
  TrackOneBarFillEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}},
      {{BAR,1}}, 
      {{BAR_END,1}}, 
      {{TRACK,2}},
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{FILL_IN,3}}},
      "no_velocity");
  }
  ~TrackOneBarFillEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    e->do_fill = true;
    e->force_instrument = true;
    return to_performance_w_tracks(p, rep, e);
  }
  // Assumed layout: <body containing one FILL_IN(0)> FILL_IN(1) <fill> <last token>.
  // The fill replaces the placeholder, and the final token is ignored
  // (presumably a fill terminator; TODO confirm).
  //
  // Unlike before, the caller's vector is no longer truncated by pop_back(),
  // and a missing FILL_IN(1) or placeholder no longer leads to next(end())
  // or erase(end()).
  void decode(vector<int> &raw_tokens, midi::Piece *p, EncoderConfig *e) {
    int fill_placeholder = rep->encode({{FILL_IN,0}});
    int fill_start = rep->encode({{FILL_IN,1}});

    vector<int> tokens;
    if (!raw_tokens.empty()) {
      auto raw_end = raw_tokens.end() - 1; // ignore the final token
      auto src = find(raw_tokens.begin(), raw_end, fill_start);
      tokens.assign(raw_tokens.begin(), src);
      if (src != raw_end) {
        auto dst = find(tokens.begin(), tokens.end(), fill_placeholder);
        if (dst != tokens.end()) {
          // replace the placeholder with the fill contents
          dst = tokens.erase(dst);
          tokens.insert(dst, next(src), raw_end);
        }
      }
    }
    return decode_track(tokens, p, rep, e);
  }
};

// Track encoder with multi-bar infilling (encode sets do_multi_fill).
class TrackOneTwoThreeBarFillEncoder : public ENCODER {
public:
  TrackOneTwoThreeBarFillEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}},
      {{BAR,1}}, 
      {{BAR_END,1}}, 
      {{TRACK,2}},
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{FILL_IN,3}}},
      "no_velocity");
  }
  ~TrackOneTwoThreeBarFillEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    e->do_multi_fill = true;
    e->force_instrument = true;
    return to_performance_w_tracks(p, rep, e);
  }
  // Assumed layout: PIECE_START, a body with FILL_IN(0) placeholders, then one
  // FILL_IN(1) <fill> FILL_IN(2) section per placeholder, in order. Each
  // placeholder is replaced by its fill. PIECE_START and the fill sections are
  // dropped, and the result goes to decode_track().
  //
  // All searches are bounded by end(), so truncated model output (a missing
  // FILL_IN(2), or fewer fills than placeholders) no longer steps past the end
  // of the vector. Unmatched placeholders are dropped.
  void decode(vector<int> &raw_tokens, midi::Piece *p, EncoderConfig *e) {
    int fill_pholder = rep->encode({{FILL_IN,0}});
    int fill_start = rep->encode({{FILL_IN,1}});
    int fill_end = rep->encode({{FILL_IN,2}});

    vector<int> tokens;
    if (!raw_tokens.empty()) {
      auto body_begin = next(raw_tokens.begin()); // FIRST TOKEN IS PIECE_START ANYWAYS
      auto body_end = find(body_begin, raw_tokens.end(), fill_start);
      auto next_fill = body_end; // FILL_IN(1) of the next unused fill, or end()
      for (auto it = body_begin; it != body_end; ++it) {
        if (*it != fill_pholder) {
          tokens.push_back(*it);
        }
        else if (next_fill != raw_tokens.end()) {
          auto fill_begin = next(next_fill);
          auto fill_stop = find(fill_begin, raw_tokens.end(), fill_end);
          tokens.insert(tokens.end(), fill_begin, fill_stop);
          next_fill = find(fill_stop, raw_tokens.end(), fill_start);
        }
      }
    }
    return decode_track(tokens, p, rep, e);
  }
};

// Track encoder with multi-bar infilling (do_multi_fill, fill_percentage .5).
class TrackBarFillEncoder : public ENCODER {
public:
  TrackBarFillEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}},
      {{BAR,1}}, 
      {{BAR_END,1}}, 
      {{TRACK,2}},
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{FILL_IN,3}}},
      "no_velocity");
  }
  ~TrackBarFillEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    e->do_multi_fill = true;
    e->fill_percentage = .5;
    e->force_instrument = true;
    return to_performance_w_tracks(p, rep, e);
  }
  // Assumed layout: PIECE_START, a body with FILL_IN(0) placeholders, then one
  // FILL_IN(1) <fill> FILL_IN(2) section per placeholder, in order. Each
  // placeholder is replaced by its fill. PIECE_START and the fill sections are
  // dropped, and the result goes to decode_track().
  //
  // All searches are bounded by end(), so truncated model output (a missing
  // FILL_IN(2), or fewer fills than placeholders) no longer steps past the end
  // of the vector. Unmatched placeholders are dropped.
  void decode(vector<int> &raw_tokens, midi::Piece *p, EncoderConfig *e) {
    int fill_pholder = rep->encode({{FILL_IN,0}});
    int fill_start = rep->encode({{FILL_IN,1}});
    int fill_end = rep->encode({{FILL_IN,2}});

    vector<int> tokens;
    if (!raw_tokens.empty()) {
      auto body_begin = next(raw_tokens.begin()); // FIRST TOKEN IS PIECE_START ANYWAYS
      auto body_end = find(body_begin, raw_tokens.end(), fill_start);
      auto next_fill = body_end; // FILL_IN(1) of the next unused fill, or end()
      for (auto it = body_begin; it != body_end; ++it) {
        if (*it != fill_pholder) {
          tokens.push_back(*it);
        }
        else if (next_fill != raw_tokens.end()) {
          auto fill_begin = next(next_fill);
          auto fill_stop = find(fill_begin, raw_tokens.end(), fill_end);
          tokens.insert(tokens.end(), fill_begin, fill_stop);
          next_fill = find(fill_stop, raw_tokens.end(), fill_start);
        }
      }
    }
    return decode_track(tokens, p, rep, e);
  }
};

// 16-bar track encoder (min 2 tracks) whose decoder handles FILL_IN sections.
// Fill encoding is currently disabled in encode() (see FIX THIS).
class TrackBarFillSixteenEncoder : public ENCODER {
public:
  TrackBarFillSixteenEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}},
      {{BAR,1}}, 
      {{BAR_END,1}}, 
      {{TRACK,2}},
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{FILL_IN,3}}},
      "no_velocity");
  }
  ~TrackBarFillSixteenEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // FIX THIS ....
    //e->do_multi_fill = true;
    //e->fill_percentage = .5;
    e->force_instrument = true;
    e->num_bars = 16;
    e->min_tracks = 2;
    // NOTE(review): unlike SegmentEncoder, update_valid_segments() is not
    // called here, so this relies on `p` already carrying valid segments.
    if (p->valid_segments_size() == 0) {
      throw(1); // need to start over!
    }
    return to_performance_w_tracks(p, rep, e);
  }
  // Assumed layout: PIECE_START, a body with FILL_IN(0) placeholders, then one
  // FILL_IN(1) <fill> FILL_IN(2) section per placeholder, in order. Each
  // placeholder is replaced by its fill. PIECE_START and the fill sections are
  // dropped, and the result goes to decode_track().
  //
  // All searches are bounded by end(), so truncated model output (a missing
  // FILL_IN(2), or fewer fills than placeholders) no longer steps past the end
  // of the vector. Unmatched placeholders are dropped.
  void decode(vector<int> &raw_tokens, midi::Piece *p, EncoderConfig *e) {
    int fill_pholder = rep->encode({{FILL_IN,0}});
    int fill_start = rep->encode({{FILL_IN,1}});
    int fill_end = rep->encode({{FILL_IN,2}});

    vector<int> tokens;
    if (!raw_tokens.empty()) {
      auto body_begin = next(raw_tokens.begin()); // FIRST TOKEN IS PIECE_START ANYWAYS
      auto body_end = find(body_begin, raw_tokens.end(), fill_start);
      auto next_fill = body_end; // FILL_IN(1) of the next unused fill, or end()
      for (auto it = body_begin; it != body_end; ++it) {
        if (*it != fill_pholder) {
          tokens.push_back(*it);
        }
        else if (next_fill != raw_tokens.end()) {
          auto fill_begin = next(next_fill);
          auto fill_stop = find(fill_begin, raw_tokens.end(), fill_end);
          tokens.insert(tokens.end(), fill_begin, fill_stop);
          next_fill = find(fill_stop, raw_tokens.end(), fill_start);
        }
      }
    }
    return decode_track(tokens, p, rep, e);
  }
};

// Track encoder with controls for note density (10 DENSITY_LEVEL tokens),
// monophonic/polyphonic marking (TRACK values: mono, drums, polyphonic) and
// time signature (5 TIME_SIGNATURE values).
class TrackMonoPolyDensityEncoder : public ENCODER  {
public:
  TrackMonoPolyDensityEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,3}}, // {mono, drums, polyphonic}
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{DENSITY_LEVEL,10}},
      {{TIME_SIGNATURE,5}}
    }, "no_velocity");
  }
  ~TrackMonoPolyDensityEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // overwrite the caller's config with this encoder's settings
    e->mark_polyphony = true;
    e->force_instrument = true;
    e->mark_density = true;
    e->mark_time_sigs = true;
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};


// Track encoder marking each track as mono, drums or polyphonic via TRACK.
// Polyphony metadata is expected to be precomputed (done in dataset v4).
class TrackMonoPolyEncoder : public ENCODER {
public:
  TrackMonoPolyEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,3}}, // {mono, drums, polyphonic}
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}}
    }, "no_velocity");
  }
  ~TrackMonoPolyEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // overwrite the caller's config with this encoder's settings
    e->mark_polyphony = true;
    e->force_instrument = true;
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};

// Single-stream performance encoder. Every event becomes one PITCH token (or
// NON_PITCH for drum events) with a two-valued VELOCITY field, and consecutive
// events are separated by TIME_DELTA tokens. Decoding is not implemented.
class PerformanceEncoder : public ENCODER {
public:
  PerformanceEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}},
      {{PITCH,128},{VELOCITY,2}},
      {{NON_PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}}},
      "no_velocity");
  }
  ~PerformanceEncoder() {
    delete rep;
  }
  // Encodes the time-sorted events of `p`.
  // Throws 2 if the events are not in time order or if a velocity falls
  // outside the 128-entry velocity map; previously an out-of-range velocity
  // indexed past the array. Out-of-domain pitches are rejected by rep->encode.
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // get all the events sorted
    vector<midi::Event> events = get_sorted_events(p);
    vector<int> tokens;
    int current_step = 0;
    const int N_TIME_TOKENS = rep->get_domain(TIME_DELTA);
    const array<int,128> VMAP = velocity_maps[rep->velocity_map_name];
    for (const auto &event : events) { // by reference: avoid copying each Event
      if (event.time() < current_step) {
        cout << "ERROR : events are not sorted." << endl;
        throw(2);
      }
      // TIME_DELTA value v advances v+1 steps, so the largest token
      // (N_TIME_TOKENS-1) covers N_TIME_TOKENS steps.
      while (event.time() > current_step + N_TIME_TOKENS) {
        tokens.push_back( rep->encode({{TIME_DELTA,N_TIME_TOKENS-1}}) );
        current_step += N_TIME_TOKENS;
      }
      if (event.time() > current_step) {
        tokens.push_back( rep->encode(
          {{TIME_DELTA,event.time()-current_step-1}}) );
        current_step = event.time();
      }
      int velocity = event.velocity();
      if ((velocity < 0) || (velocity >= (int)VMAP.size())) {
        cout << "ERROR : velocity (" << velocity << ") OUTSIDE OF VELOCITY MAP" << endl;
        throw(2);
      }
      if (event.is_drum()) {
        tokens.push_back( rep->encode({
          {NON_PITCH,event.pitch()},
          {VELOCITY,VMAP[velocity]}
        }));
      }
      else {
        tokens.push_back( rep->encode({
          {PITCH,event.pitch()},
          {VELOCITY,VMAP[velocity]}
        }));
      }
    }
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    throw(1); // not implemented
  }
};


// =====================================================================
// FINAL PAPER ENCODINGS
// track model
// track model w velocity
// barfill model
// barfill model w velocity

// Track encoder conditioned on note density (10 DENSITY_LEVEL tokens).
// Outside of force_valid mode it recomputes valid segments and rejects pieces
// that have none.
class TrackDensityEncoder : public ENCODER {
public:
  TrackDensityEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,2}}, {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{DENSITY_LEVEL,10}}
    }, "no_velocity");
  }
  ~TrackDensityEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // overwrite the caller's config with this encoder's settings
    e->force_instrument = true;
    e->mark_density = true;
    e->min_tracks = 1;
    update_note_density(p);
    const bool must_validate = !e->force_valid;
    if (must_validate) {
      update_valid_segments(p, e);
      const bool has_segments = p->valid_segments_size() != 0;
      if (!has_segments) {
        cout << "NO VALID SEGMENTS" << endl;
        throw(1); // need to start over
      }
    }
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};

// Density-conditioned track encoder that also emits 32 VELOCITY_LEVEL tokens
// ("magenta" velocity map).
class TrackDensityVelocityEncoder : public ENCODER {
public:
  TrackDensityVelocityEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}}, {{BAR,1}}, {{BAR_END,1}},
      {{TRACK,2}}, {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{DENSITY_LEVEL,10}},
      {{VELOCITY_LEVEL,32}}
    }, "magenta");
  }
  ~TrackDensityVelocityEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    // overwrite the caller's config with this encoder's settings
    e->force_instrument = true;
    e->mark_density = true;
    e->min_tracks = 1;
    e->use_velocity_levels = true;
    update_note_density(p);
    const bool must_validate = !e->force_valid;
    if (must_validate) {
      update_valid_segments(p, e);
      const bool has_segments = p->valid_segments_size() != 0;
      if (!has_segments) {
        cout << "NO VALID SEGMENTS" << endl;
        throw(1); // need to start over
      }
    }
    vector<int> tokens = to_performance_w_tracks(p, rep, e);
    return tokens;
  }
  void decode(vector<int> &tokens, midi::Piece *p, EncoderConfig *e) {
    decode_track(tokens, p, rep, e);
  }
};

// Density-conditioned track encoder with multi-bar infilling
// (do_multi_fill, fill_percentage .5, 10 DENSITY_LEVEL tokens).
class TrackBarFillDensityEncoder : public ENCODER {
public:
  TrackBarFillDensityEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}},
      {{BAR,1}}, 
      {{BAR_END,1}}, 
      {{TRACK,2}},
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{FILL_IN,3}},
      {{DENSITY_LEVEL,10}}}, // 10 density levels
      "no_velocity");
  }
  ~TrackBarFillDensityEncoder() {
    delete rep;
  }
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    e->do_multi_fill = true;
    e->fill_percentage = .5;
    e->force_instrument = true;
    e->mark_density = true;
    // multi fill line used to be here
    update_note_density(p);
    if (!e->force_valid) {
      // only need to do these things in train
      e->multi_fill.clear(); // so it chooses new bars to fill
      update_valid_segments(p, e);
      if (p->valid_segments_size() == 0) {
        cout << "NO VALID SEGMENTS" << endl;
        throw(1); // need to start over!
      }
    }
    return to_performance_w_tracks(p, rep, e);
  }
  // Assumed layout: PIECE_START, a body with FILL_IN(0) placeholders, then one
  // FILL_IN(1) <fill> FILL_IN(2) section per placeholder, in order. Each
  // placeholder is replaced by its fill. PIECE_START and the fill sections are
  // dropped, and the result goes to decode_track().
  //
  // All searches are bounded by end(), so truncated model output (a missing
  // FILL_IN(2), or fewer fills than placeholders) no longer steps past the end
  // of the vector. Unmatched placeholders are dropped.
  void decode(vector<int> &raw_tokens, midi::Piece *p, EncoderConfig *e) {
    int fill_pholder = rep->encode({{FILL_IN,0}});
    int fill_start = rep->encode({{FILL_IN,1}});
    int fill_end = rep->encode({{FILL_IN,2}});

    vector<int> tokens;
    if (!raw_tokens.empty()) {
      auto body_begin = next(raw_tokens.begin()); // FIRST TOKEN IS PIECE_START ANYWAYS
      auto body_end = find(body_begin, raw_tokens.end(), fill_start);
      auto next_fill = body_end; // FILL_IN(1) of the next unused fill, or end()
      for (auto it = body_begin; it != body_end; ++it) {
        if (*it != fill_pholder) {
          tokens.push_back(*it);
        }
        else if (next_fill != raw_tokens.end()) {
          auto fill_begin = next(next_fill);
          auto fill_stop = find(fill_begin, raw_tokens.end(), fill_end);
          tokens.insert(tokens.end(), fill_begin, fill_stop);
          next_fill = find(fill_stop, raw_tokens.end(), fill_start);
        }
      }
    }
    return decode_track(tokens, p, rep, e);
  }
};

/// Track/bar encoder with multi-bar infilling, note-density tokens and
/// velocity-level tokens.
///
/// NOTE(review): `rep` is allocated with new in the constructor and deleted
/// in the destructor. If ENCODER's copy operations are implicit, copying an
/// instance would delete `rep` twice. Confirm this, and consider `= delete`
/// on the copy constructor and copy assignment.
class TrackBarFillDensityVelocityEncoder : public ENCODER {
public:
  TrackBarFillDensityVelocityEncoder() {
    rep = new REPRESENTATION({
      {{PIECE_START,1}},
      {{BAR,1}},
      {{BAR_END,1}},
      {{TRACK,2}},
      {{TRACK_END,1}},
      {{INSTRUMENT,128}},
      {{PITCH,128},{VELOCITY,2}},
      {{TIME_DELTA,48}},
      {{FILL_IN,3}},          // 0 = placeholder, 1 = fill start, 2 = fill end (see decode)
      {{DENSITY_LEVEL,10}},   // 10 density levels
      {{VELOCITY_LEVEL,32}}}, // 32 velocity levels
      "magenta");
  }
  ~TrackBarFillDensityVelocityEncoder() {
    delete rep;
  }

  /// Encode a piece with multi-bar infilling, density marking and velocity
  /// levels enabled.
  ///
  /// Before encoding, this overwrites several EncoderConfig fields:
  /// multi-fill on, 50% fill percentage, forced instrument tokens, density
  /// marking and velocity levels. When e->force_valid is false, the previous
  /// multi_fill selection is discarded and valid segments are recomputed.
  /// @throws int 1 if, after recomputation, the piece has no valid segments.
  vector<int> encode(midi::Piece *p, EncoderConfig *e) {
    e->do_multi_fill = true;
    e->fill_percentage = .5;
    e->force_instrument = true;
    e->mark_density = true;
    e->use_velocity_levels = true;
    update_note_density(p);
    if (!e->force_valid) {
      e->multi_fill.clear(); // so it chooses new bars to fill
      update_valid_segments(p, e);
      if (p->valid_segments_size() == 0) {
        cout << "NO VALID SEGMENTS" << endl;
        throw(1); // need to start over!
      }
    }
    return to_performance_w_tracks(p, rep, e);
  }

  /// Rebuild a linear token stream from an infilling sequence, then decode it.
  ///
  /// The k-th FILL_IN=0 placeholder is replaced by the tokens strictly
  /// between the k-th FILL_IN=1 / FILL_IN=2 marker pair. The first token
  /// (assumed to be PIECE_START) is skipped, and the main body stops at the
  /// first fill-start marker after the last placeholder.
  ///
  /// Malformed input never reads past the end of raw_tokens. Examples are
  /// fewer fills than placeholders, a missing end marker, or out-of-order
  /// markers. In those cases the affected placeholder gets whatever fill
  /// content is available, possibly none.
  void decode(vector<int> &raw_tokens, midi::Piece *p, EncoderConfig *e) {
    int fill_pholder = rep->encode({{FILL_IN,0}});
    int fill_start = rep->encode({{FILL_IN,1}});
    int fill_end = rep->encode({{FILL_IN,2}});

    vector<int> tokens;
    tokens.reserve(raw_tokens.size()); // output never exceeds the input length

    auto start_pholder = raw_tokens.begin();
    auto start_fill = raw_tokens.begin();
    auto end_fill = raw_tokens.begin();

    while (start_pholder != raw_tokens.end()) {
      // first pass: skip PIECE_START; later passes: skip the placeholder itself
      start_pholder = next(start_pholder);
      auto last_start_pholder = start_pholder;
      start_pholder = find(start_pholder, raw_tokens.end(), fill_pholder);
      if (start_pholder != raw_tokens.end()) {
        // main body from last_start_pholder up to (not including) this placeholder
        tokens.insert(tokens.end(), last_start_pholder, start_pholder);
        // advance to the next start/end marker pair; once a marker is exhausted
        // keep it at end() rather than calling next() on end() (undefined behaviour)
        if (start_fill != raw_tokens.end()) {
          start_fill = find(next(start_fill), raw_tokens.end(), fill_start);
        }
        if (end_fill != raw_tokens.end()) {
          end_fill = find(next(end_fill), raw_tokens.end(), fill_end);
        }
        // splice in the fill body only when the markers form a valid range
        if ((start_fill != raw_tokens.end()) && (start_fill < end_fill)) {
          tokens.insert(tokens.end(), next(start_fill), end_fill);
        }
      }
      else {
        // remaining main body, stopping at the first fill section after it;
        // searching from last_start_pholder keeps the range non-reversed
        auto fills_begin = find(last_start_pholder, raw_tokens.end(), fill_start);
        tokens.insert(tokens.end(), last_start_pholder, fills_begin);
      }
    }
    return decode_track(tokens, p, rep, e);
  }
};