#pragma once

#include <google/protobuf/util/json_util.h>

#include <vector>
#include <map>
#include <tuple>
#include <array>
#include <utility>

#include "../protobuf/midi.pb.h"
#include "../enum/token_types.h"
#include "../enum/constants.h"
#include "../enum/encoder_config.h"

// development versions

// Encodes the events of a single bar as a flat token sequence.
//
// Tokens are emitted in event order:
//  - TIME_DELTA: a token with value k advances time by k+1 steps; gaps longer
//    than the TIME_DELTA domain are split into several maximal tokens first.
//  - VELOCITY_LEVEL: only when ec->use_velocity_levels is set and the level
//    changes. The level is currently binary (1 for any onset); a velocity map
//    is left commented out below.
//  - NOTE_ONSET (velocity > 0) / NOTE_OFFSET (velocity == 0) for
//    pitch + transpose.
// Offsets (velocity 0) on drum tracks are dropped unless ec->use_drum_offsets.
//
// bar:       bar whose events() are indices into p->events()
// p:         piece that owns the events
// rep:       token vocabulary used for encoding
// transpose: offset added to every encoded pitch
// is_drum:   whether the bar belongs to a drum track
// ec:        encoder options read here: use_drum_offsets, use_velocity_levels
//
// Throws std::runtime_error("Events are not sorted!") if event times decrease.
vector<int> to_performance_dev(midi::Bar *bar, midi::Piece *p, REPRESENTATION *rep, int transpose, bool is_drum, EncoderConfig *ec) {
  vector<int> tokens;
  int current_step = 0;
  int current_velocity = -1;
  const int N_TIME_TOKENS = rep->get_domain_size(TIME_DELTA);
  for (const auto i : bar->events()) {
    // bind by reference: copying the protobuf message per event is wasted work
    const midi::Event &event = p->events(i);
    if ((!is_drum) || (event.velocity()>0) || (ec->use_drum_offsets)) {
      //int qvel = velocity_maps[rep->velocity_map_name][event.velocity()];
      int qvel = event.velocity() > 0;

      if (event.time() > current_step) {
        // emit maximal deltas until the remaining gap fits in one token
        while (event.time() > current_step + N_TIME_TOKENS) {
          tokens.push_back( rep->encode(TIME_DELTA, N_TIME_TOKENS-1) );
          current_step += N_TIME_TOKENS;
        }
        if (event.time() > current_step) {
          tokens.push_back( rep->encode(
            TIME_DELTA, event.time()-current_step-1) );
        }
        current_step = event.time();
      }
      else if (event.time() < current_step) {
        throw std::runtime_error("Events are not sorted!");
      }
      // if the rep contains velocity levels
      if (ec->use_velocity_levels) {
        if ((qvel > 0) && (qvel != current_velocity)) {
          tokens.push_back( rep->encode(VELOCITY_LEVEL, qvel) );
          current_velocity = qvel;
        }
        qvel = min(1,qvel); // flatten down to binary for note
      }
      if (qvel==0) {
        tokens.push_back( rep->encode(NOTE_OFFSET, event.pitch() + transpose) );
      }
      else {
        tokens.push_back( rep->encode(NOTE_ONSET, event.pitch() + transpose) );
      }
    }
  }
  return tokens;
}

// Encodes one bar of already-merged, time-sorted events from all tracks.
//
// Same token scheme as to_performance_dev, with two differences:
//  - an INSTRUMENT token is emitted whenever event.instrument() differs from
//    the previous encoded event;
//  - the transpose (ec->transpose) is applied per event, and only to events
//    whose track_type is not a drum track.
// Drum offsets (velocity 0) are dropped unless ec->use_drum_offsets.
//
// events: events of one bar with instrument() and track_type() already set
//         by the caller
// rep:    token vocabulary used for encoding
// ec:     encoder options read here: transpose, use_drum_offsets,
//         use_velocity_levels
//
// Throws std::runtime_error("Events are not sorted!") if event times decrease.
vector<int> to_interleaved_performance_inner(vector<midi::Event> &events, REPRESENTATION *rep, EncoderConfig *ec) {
  vector<int> tokens;
  int current_step = 0;
  int current_velocity = -1;
  int current_instrument = -1;
  const int N_TIME_TOKENS = rep->get_domain_size(TIME_DELTA);
  // iterate by reference: copying each protobuf message is wasted work
  for (const auto &event : events) {
    bool is_drum = is_drum_track(event.track_type());
    int current_transpose = ec->transpose * (!is_drum);
    if ((!is_drum) || (event.velocity()>0) || (ec->use_drum_offsets)) {
      int qvel = event.velocity() > 0;
      if (event.time() > current_step) {
        // emit maximal deltas until the remaining gap fits in one token
        while (event.time() > current_step + N_TIME_TOKENS) {
          tokens.push_back( rep->encode(TIME_DELTA, N_TIME_TOKENS-1) );
          current_step += N_TIME_TOKENS;
        }
        if (event.time() > current_step) {
          tokens.push_back( rep->encode(
            TIME_DELTA, event.time()-current_step-1) );
        }
        current_step = event.time();
      }
      else if (event.time() < current_step) {
        throw std::runtime_error("Events are not sorted!");
      }
      // switch instrument only when it changes between consecutive events
      if (event.instrument() != current_instrument) {
        tokens.push_back( rep->encode(INSTRUMENT, event.instrument()) );
        current_instrument = event.instrument();
      }
      if (ec->use_velocity_levels) {
        if ((qvel > 0) && (qvel != current_velocity)) {
          tokens.push_back( rep->encode(VELOCITY_LEVEL, qvel) );
          current_velocity = qvel;
        }
        qvel = min(1,qvel); // flatten down to binary for note
      }
      if (qvel==0) {
        tokens.push_back( 
          rep->encode(NOTE_OFFSET, event.pitch() + current_transpose) );
      }
      else {
        tokens.push_back( 
          rep->encode(NOTE_ONSET, event.pitch() + current_transpose) );
      }
    }
  }
  return tokens;
}

bool sort_events_winst(const midi::Event a, const midi::Event b) { 
  if (a.time() != b.time()) {
    return a.time() < b.time();
  }
  if (min(a.velocity(),1) != min(b.velocity(),1)) {
    return min(a.velocity(),1) < min(b.velocity(),1);
  }
  if (a.instrument() != b.instrument()) {
    return a.instrument() < b.instrument();
  }
  return a.pitch() < b.pitch();
}

// Encodes a whole piece with all tracks interleaved bar by bar.
//
// Layout:
//   PIECE_START
//   [HEADER(0) INSTRUMENT... HEADER(1)]   only when the HEADER domain has 2 values
//   { BAR <events of all tracks in this bar> BAR_END } for every bar
//
// Instrument ids are track.instrument() + 128 for drum tracks. The header
// lists each distinct id once, in ascending order. Within a bar, events are
// sorted with sort_events_winst and encoded by
// to_interleaved_performance_inner.
//
// p:   piece to encode
// rep: token vocabulary used for encoding
// e:   encoder options, forwarded to to_interleaved_performance_inner
vector<int> to_interleaved_performance(midi::Piece *p, REPRESENTATION *rep, EncoderConfig *e) {

  vector<int> tokens;

  tokens.push_back( rep->encode(PIECE_START, 0) );

  // optional header listing the instruments present in the piece
  // (meant to allow controlling instruments when sampling)
  if (rep->get_domain_size(HEADER) == 2) {
    tokens.push_back( rep->encode(HEADER,0) );
    set<int> insts;
    for (const auto &track : p->tracks()) {
      insts.insert( track.instrument() + 128*is_drum_track(track.type()) );
    }
    for (const auto inst : insts) {
      tokens.push_back( rep->encode(INSTRUMENT,inst) );
    }
    tokens.push_back( rep->encode(HEADER,1) );
  }

  int num_bars = get_num_bars(p);
  vector<midi::Event> events; // reused across bars; cleared each iteration
  for (int bar_num=0; bar_num<num_bars; bar_num++) {
    events.clear();
    for (int track_num=0; track_num<p->tracks_size(); track_num++) {
      const midi::Track &track = p->tracks(track_num);
      int track_type = track.type();
      int inst = track.instrument() + 128*is_drum_track(track_type);
      for (const auto index : track.bars(bar_num).events()) {
        // tag each event with its track's instrument id and type so the
        // merged stream can be sorted and encoded without the tracks
        midi::Event ev;
        ev.CopyFrom( p->events(index) );
        ev.set_instrument( inst );
        ev.set_track_type( track_type );
        events.push_back( std::move(ev) );
      }
    }

    // sort the events by time, offsets before onsets, instrument id, pitch
    sort(events.begin(), events.end(), sort_events_winst);

    // make bar
    tokens.push_back( rep->encode(BAR, 0) );
    vector<int> bar_tokens = to_interleaved_performance_inner(events, rep, e);
    tokens.insert(tokens.end(), bar_tokens.begin(), bar_tokens.end());
    tokens.push_back( rep->encode(BAR_END, 0) );
  }
  return tokens;
}

// Encodes a piece track by track, optionally split into 4-bar segments and
// with bar-level infilling.
//
// Layout:
//   PIECE_START(do_multi_fill, if the PIECE_START domain has 2 values, else 0)
//   for each segment:
//     [SEGMENT]                                 if e->multi_segment
//     for each track:
//       TRACK(type) INSTRUMENT [DENSITY_LEVEL]
//       for each bar: BAR [TIME_SIGNATURE] (FILL_IN(0) | bar tokens) BAR_END
//       TRACK_END
//     [SEGMENT_END]                             if e->multi_segment
//   if e->do_multi_fill, for each (track, bar) in e->multi_fill:
//     FILL_IN(1) bar tokens FILL_IN(2)
//
// With e->multi_segment, only complete groups of 4 bars are encoded; trailing
// bars beyond the last full group are skipped. Drum tracks are never
// transposed.
//
// Throws runtime_error("BAR NUMBER OUT OF RANGE!") if a segment refers to a
// bar >= get_num_bars(p).
vector<int> to_performance_w_tracks_dev(midi::Piece *p, REPRESENTATION *rep, EncoderConfig *e) {

  vector<int> tokens;

  if (rep->get_domain_size(PIECE_START) == 2) {
    // here we are combining bar infill and track infill
    tokens.push_back( rep->encode(PIECE_START, (int)e->do_multi_fill) );
  }
  else {
    tokens.push_back( rep->encode(PIECE_START, 0) );
  }

  int total_bars = get_num_bars(p);
  vector<vector<int>> bar_segments;
  if (e->multi_segment) {
    int NB = 4; // TODO :: allow for control of this
    for (int SB=0; SB<(total_bars/NB)*NB; SB+=NB) {
      bar_segments.push_back( arange(SB,SB+NB) );
    }
  }
  else {
    bar_segments.push_back( arange(0,total_bars) );
  }

  for (const auto &bar_segment : bar_segments) {

    // start each segment with a segment token
    if (e->multi_segment) {
      tokens.push_back( rep->encode(SEGMENT, 0) );
    }

    for (int track_num=0; track_num<p->tracks_size(); track_num++) {

      // bind by reference: copying the track would copy all of its bars
      // once per segment
      const midi::Track &track = p->tracks(track_num);

      bool is_drum = is_drum_track( track.type() );
      int cur_transpose = is_drum ? 0 : e->transpose;

      tokens.push_back( rep->encode(TRACK, track.type()) );
      tokens.push_back( rep->encode(INSTRUMENT,track.instrument()) );
      
      if (e->mark_density) {
        tokens.push_back( rep->encode(DENSITY_LEVEL, track.note_density_v2()) );
      }

      for (const auto bar_num : bar_segment) {

        if (bar_num >= total_bars) {
          throw runtime_error("BAR NUMBER OUT OF RANGE!");
        }

        // copied because to_performance_dev takes a non-const midi::Bar*
        midi::Bar bar = track.bars(bar_num);
        tokens.push_back( rep->encode(BAR, 0) );
        if (e->mark_time_sigs) {
          // assume that beat_length is in the map
          int ts = time_sig_map[round(bar.beat_length())];
          tokens.push_back( rep->encode(TIME_SIGNATURE, ts) );
        }
        // for multi fill i should include segment_num here
        // to make masking easier ...
        if ((e->do_multi_fill) && (e->multi_fill.find(make_pair(track_num,bar_num)) != e->multi_fill.end())) {
          // placeholder; the bar's content is emitted after all segments
          tokens.push_back( rep->encode(FILL_IN, 0) );
        }
        else {
          vector<int> bar_tokens = to_performance_dev(
            &bar, p, rep, cur_transpose, is_drum, e);
          tokens.insert(tokens.end(), bar_tokens.begin(), bar_tokens.end());
        }
        tokens.push_back( rep->encode(BAR_END, 0) );
      }
      tokens.push_back( rep->encode(TRACK_END, 0) );
    }

    // end each segment with a segment token
    if (e->multi_segment) {
      tokens.push_back( rep->encode(SEGMENT_END, 0) );
    }
  }

  if (e->do_multi_fill) {
    for (const auto &track_bar : e->multi_fill) {
      int fill_track = get<0>(track_bar);
      int fill_bar = get<1>(track_bar);
      bool is_drum = is_drum_track( p->tracks(fill_track).type() );
      int cur_transpose = is_drum ? 0 : e->transpose;

      tokens.push_back( rep->encode(FILL_IN, 1) ); // begin fill-in
      midi::Bar bar = p->tracks(fill_track).bars(fill_bar);
      vector<int> bar_tokens = to_performance_dev(
        &bar, p, rep, cur_transpose, is_drum, e);
      tokens.insert(tokens.end(), bar_tokens.begin(), bar_tokens.end());
      tokens.push_back( rep->encode(FILL_IN, 2) ); // end fill-in
    }
  }

  return tokens;
}

// fix the deconversion
// Decodes a token sequence into *p, appending tracks, bars and events.
//
// Two layouts are handled, selected by ec->interleaved:
//  - track-wise (false): TRACK ... TRACK_END groups holding BAR ... BAR_END
//    groups. A SEGMENT token restarts the track index, so later segments
//    append bars to the already-created tracks.
//  - interleaved (true): BAR ... BAR_END groups in which an INSTRUMENT token
//    (instrument number, +128 for drums) selects the track that receives the
//    following notes, creating it on first use. All tracks are padded to the
//    same number of bars at the end.
//
// Notes are written with velocity 100 (NOTE_ONSET) or 0 (NOTE_OFFSET).
// VELOCITY_LEVEL is parsed but not applied, and FILL_IN tokens are not
// decoded. Notes outside a bar/track are ignored.
//
// Afterwards it adds valid_segments = 0 and a valid_tracks bitmask covering
// every track, then calls update_note_density(p).
void decode_track_dev(vector<int> &tokens, midi::Piece *p, REPRESENTATION *rep, EncoderConfig *ec) {
  p->set_tempo(ec->default_tempo);
  p->set_resolution(ec->resolution);

  // interleaved mode: instrument id -> index of its track in p
  map<int,int> inst_to_track;
  midi::Track *t = nullptr;  // track currently receiving bars / notes
  midi::Bar *b = nullptr;    // bar currently receiving notes
  // initialized so malformed sequences (e.g. TIME_DELTA or BAR_END before any
  // BAR/TRACK) never read indeterminate values
  int current_time = 0;
  int current_instrument = 0;
  int current_velocity = 0;
  int beat_length = 4;
  int track_count = 0;
  int bar_count = 0;

  for (const auto token : tokens) {

    if (rep->is_token_type(token, SEGMENT)) {
      track_count = 0; // reset track count
      t = nullptr;
      b = nullptr;
    }
    if (rep->is_token_type(token, TRACK)) {
      current_time = 0; // restart the time
      current_instrument = 0; // reset instrument
      if (track_count >= p->tracks_size()) {
        t = p->add_tracks();
      }
      else {
        t = p->mutable_tracks(track_count);
      }
      t->set_is_drum( is_drum_track(rep->decode(token)) );
      t->set_type( rep->decode(token) );
    }
    else if (rep->is_token_type(token, TRACK_END)) {
      track_count++;
      t = nullptr;
    }
    else if (rep->is_token_type(token, BAR)) {
      current_time = 0; // restart the time
      beat_length = 4; // default value optionally overridden with TIME_SIGNATURE
      if (!ec->interleaved) {
        // a BAR outside TRACK ... TRACK_END has no track to attach to: drop its
        // notes instead of dereferencing a null track
        b = t ? t->add_bars() : nullptr;
      }
      bar_count++;
    }
    else if (rep->is_token_type(token, TIME_SIGNATURE)) {
      beat_length = rev_time_sig_map[rep->decode(token)];
    }
    else if (rep->is_token_type(token, BAR_END)) {
      if (b) {
        b->set_beat_length(beat_length);
      }
      b = nullptr;
    }
    else if (rep->is_token_type(token, TIME_DELTA)) {
      current_time += (rep->decode(token) + 1); // value k encodes k+1 steps
    }
    else if (rep->is_token_type(token, INSTRUMENT)) {
        
      // if we are in track interleaved mode
      // we need to retrieve track from instrument
      if (ec->interleaved) {
        current_instrument = rep->decode(token);
        auto it = inst_to_track.find( current_instrument );
        if (it != inst_to_track.end()) {
          t = p->mutable_tracks(it->second);
        }
        else {
          // index the track is about to get (correct even if p already had
          // tracks before decoding)
          inst_to_track[current_instrument] = p->tracks_size();
          t = p->add_tracks();
          t->set_instrument( current_instrument % 128 );
          if (current_instrument >= 128) {
            t->set_is_drum( true );
            t->set_type( STANDARD_DRUM_TRACK );
          }
          else {
            t->set_is_drum( false );
            // NOTE(review): a non-drum track is given type DRUM_TRACK, which
            // looks inconsistent with set_is_drum(false) — confirm the
            // intended track type.
            t->set_type( DRUM_TRACK );
          }
          track_count++;
        }
        if (bar_count > 0) {
          // pad with default bars up to the current bar, then select it
          for (int n=t->bars_size(); n<bar_count; n++) {
            t->add_bars()->set_beat_length( 4 );
          }
          b = t->mutable_bars(bar_count-1);
        }
        else {
          // INSTRUMENT before any BAR (e.g. an unstripped header): the track
          // is created but there is no bar to select yet
          b = nullptr;
        }
      }
      else if (t) {
        current_instrument = rep->decode(token);
        t->set_instrument( current_instrument );
      }
    }
    else if (rep->is_token_type(token, VELOCITY_LEVEL)) {
      // parsed but not yet applied: notes below use a fixed velocity
      current_velocity = rep->decode(token);
    }
    else if (rep->is_token_type(token, NOTE_ONSET) || rep->is_token_type(token, NOTE_OFFSET)) {
      if (b && t) {
        int current_note_index = p->events_size();
        midi::Event *e = p->add_events();
        e->set_pitch( rep->decode(token) );
        // velocity 0 marks a note offset
        e->set_velocity( rep->is_token_type(token, NOTE_OFFSET) ? 0 : 100 );
        e->set_time( current_time );
        b->add_events( current_note_index );
        b->set_has_notes( true );
      }
    }

  }
  p->add_valid_segments(0);
  // NOTE(review): (1<<n)-1 overflows a 32-bit int once n >= 31 tracks.
  p->add_valid_tracks((1<<p->tracks_size())-1);

  // interleaved mode: give every track the same number of bars, using the
  // same default beat_length as the padding bars added while decoding
  if (ec->interleaved) {
    for (int track_num=0; track_num<p->tracks_size(); track_num++) {
      t = p->mutable_tracks(track_num);
      for (int n=t->bars_size(); n<bar_count; n++) {
        t->add_bars()->set_beat_length( 4 );
      }
    }
  }

  // update note density
  update_note_density(p);
}


// ================================================================
// ================================================================
// ================================================================
// DECODING HELPERS

// Returns a copy of tokens with every HEADER-delimited section removed.
//
// HEADER tokens toggle an "inside header" flag. The HEADER tokens themselves
// and every token between an opening and closing HEADER are dropped; all
// other tokens are kept in order. The input is not modified.
//
// The previous version printed the output and input sizes to stdout on every
// call (leftover debug output); that print has been removed.
vector<int> strip_header(vector<int> &tokens, REPRESENTATION *rep) {
  vector<int> fixed;
  fixed.reserve(tokens.size());
  bool in_header = false;
  for (const auto token : tokens) {
    if (rep->is_token_type(token, HEADER)) {
      in_header = !in_header;
    }
    else if (!in_header) {
      fixed.push_back( token );
    }
  }
  return fixed;
}