#pragma once

#include <algorithm>
#include <iterator>
#include <stdexcept>
#include <vector>

#include <google/protobuf/util/json_util.h>

#include "midi.pb.h"
#include "../enum/density.h"
#include "../enum/constants.h"
#include "../enum/te.h"
#include "../enum/gm.h"
#include "../enum/encoder_config.h"
#include "../random.h"

// Returns the half-open arithmetic sequence start, start+step, ... that stays
// strictly below `stop` (a C++ analogue of numpy.arange for increasing ranges).
//
// When start >= stop the result is empty for any step. A non-positive step
// with start < stop can never reach `stop`, so that case is rejected instead
// of looping forever.
//
// throws std::invalid_argument if step <= 0 while start < stop.
template<typename T>
std::vector<T> arange(T start, T stop, T step = 1) {
  if ((step <= T(0)) && (start < stop)) {
    throw std::invalid_argument("arange: step must be positive when start < stop");
  }
  std::vector<T> values;
  for (T value = start; value < stop; value += step) {
    values.push_back(value);
  }
  return values;
}

// Shorthand for arange(0, stop, 1): the sequence 0, 1, ..., up to but not
// including `stop`.
//
// The bounds are written as T(0) and T(1) with an explicit template argument,
// so type deduction cannot conflict between int literals and `stop`. This
// makes e.g. arange(3.0) compile as well as arange(3).
template<typename T>
std::vector<T> arange(T stop) {
  return arange<T>(T(0), stop, T(1));
}

// Callable that maps std::rand() onto an integer in [0, n) for n > 0.
// It scales rand() into [0, 1) and multiplies by n, truncating toward zero.
// The result depends on the global std::srand seed.
struct RNG {
  int operator() (int n) {
    const double unit = static_cast<double>(std::rand()) / (static_cast<double>(RAND_MAX) + 1.0);
    return static_cast<int>(unit * n);
  }
};

// select tracks for te rep

// ========================================================================
// MAX POLYPHONY

// Pairs the onset/offset events of one track into midi::Note objects with
// absolute tick positions.
//
// Each bar's event times are shifted by the ticks of all preceding bars
// (resolution * beat_length per bar). An event with velocity > 0 opens a note
// for its pitch; a later event with velocity <= 0 of the same pitch closes it.
// Offsets with no open onset are ignored. A second onset for a pitch that is
// already open replaces the first, and onsets still open at the end of the
// track produce no note.
//
// p               piece holding the track and the shared event pool.
// track_num       index of the track to convert.
// max_tick        optional out-param. It is raised (never lowered) so that it
//                 covers every event time and every emitted note end, which
//                 makes it safe to use as the size of a per-tick roll.
// no_drum_offsets if true and the track is a drum track, each onset becomes a
//                 1-tick note directly.
// returns the notes, in the order the event completing each note was seen.
vector<midi::Note> track_events_to_notes(midi::Piece *p, int track_num, int *max_tick=nullptr, bool no_drum_offsets=false) {
  const midi::Track &track = p->tracks(track_num);
  // Loop-invariant: whether drum onsets are turned straight into 1-tick notes.
  const bool drum_onsets_only = no_drum_offsets && is_drum_track(track.type());

  map<int,int> onsets;  // pitch -> absolute tick of the currently open onset
  vector<midi::Note> notes;
  int bar_start = 0;
  for (const auto &bar : track.bars()) {
    for (const auto event_id : bar.events()) {
      const midi::Event &e = p->events(event_id);
      const int tick = bar_start + e.time();
      int last_tick = tick;  // furthest tick touched by this event
      if (e.velocity() > 0) {
        if (drum_onsets_only) {
          midi::Note note;
          note.set_start( tick );
          note.set_end( tick + 1 );
          note.set_pitch( e.pitch() );
          notes.push_back(note);
          // the emitted note ends one tick after its onset; max_tick must
          // cover it so a roll of max_tick entries can hold the whole note
          last_tick = tick + 1;
        }
        else {
          onsets[e.pitch()] = tick;
        }
      }
      else {
        auto it = onsets.find(e.pitch());
        if (it != onsets.end()) {
          midi::Note note;
          note.set_start( it->second );
          note.set_end( tick );
          note.set_pitch( it->first );
          notes.push_back(note);
          onsets.erase(it);  // the pitch is no longer sounding
        }
      }

      if (max_tick) {
        *max_tick = std::max(*max_tick, last_tick);
      }
    }
    // advance to the start of the next bar
    bar_start += p->resolution() * bar.beat_length();
  }
  return notes;
}

// Reports whether note `a` begins while note `b` is sounding, i.e. whether
// a's start lies in b's half-open interval [start, end).
//
// NOTE(review): the test is one-sided. A note `a` that starts before `b` and
// runs into it is not reported. Confirm callers rely on this before treating
// it as a symmetric overlap check.
bool notes_overlap(midi::Note *a, midi::Note *b) {
  const auto onset = a->start();
  const bool before_b = onset < b->start();
  const bool at_or_after_b_end = onset >= b->end();
  return !(before_b || at_or_after_b_end);
}

// Returns the largest number of notes sounding at any single tick.
//
// Notes are treated as half-open intervals [start, end) and accumulated into
// a per-tick counter with `max_tick` entries. Ticks outside [0, max_tick) are
// skipped rather than written out of bounds, so a note that reaches past
// `max_tick` only counts for the part inside the window.
//
// notes    notes to measure (not modified).
// max_tick number of ticks covered by the counter; a value <= 0 yields 0.
// throws runtime_error if max_tick exceeds 100000 (bounds the allocation).
int max_polyphony(vector<midi::Note> &notes, int max_tick) {
  if (max_tick > 100000) {
    throw runtime_error("MAX TICK TO LARGE!");
  }
  if (max_tick <= 0) {
    // empty window: nothing can sound, and a negative size is invalid for vector
    return 0;
  }
  int peak = 0;
  vector<int> flat_roll(max_tick, 0);
  for (const auto &note : notes) {
    const int first = std::max(0, static_cast<int>(note.start()));
    const int last = std::min(max_tick, static_cast<int>(note.end()));
    for (int t = first; t < last; t++) {
      flat_roll[t]++;
      peak = std::max(flat_roll[t], peak);
    }
  }
  return peak;
}

// Stores, on every track of `p`, the maximum number of simultaneously
// sounding notes computed from that track's events.
void update_max_polyphony(midi::Piece *p) {
  const int num_tracks = p->tracks_size();
  for (int track_index = 0; track_index < num_tracks; ++track_index) {
    int last_tick = 0;
    vector<midi::Note> track_notes = track_events_to_notes(p, track_index, &last_tick);
    const int polyphony = max_polyphony(track_notes, last_tick);
    p->mutable_tracks(track_index)->set_max_polyphony( polyphony );
  }
}

// ========================================================================
// FORCE MONOPHONIC

// Reduces a stream of onset/offset events to a monophonic line.
//
// At most one note is "sounding" at a time; each new onset (velocity > 0)
// takes over. If the previous note started strictly earlier, it is emitted
// as its onset plus a synthetic offset (velocity 0) at the new onset's time.
// An onset at the same time as the sounding one simply replaces it. An offset
// matching the sounding pitch emits that note unchanged. Other offsets are
// ignored, as is a note still sounding when the input ends.
//
// NOTE(review): assumes `events` is ordered by time -- confirm at call sites.
vector<midi::Event> force_monophonic(vector<midi::Event> &events) {
  bool note_sounding = false;
  midi::Event last_onset;
  vector<midi::Event> mono_events;

  for (const auto &event : events) {
    if (event.velocity() > 0) {
      if (note_sounding && (event.time() > last_onset.time())) {
        // cut the sounding note at this onset (only when it has nonzero length)
        mono_events.push_back( last_onset );
        midi::Event offset;
        offset.CopyFrom(last_onset);
        offset.set_time( event.time() );
        offset.set_velocity( 0 );
        mono_events.push_back( offset );
      }

      // the new onset becomes the sounding note
      last_onset.CopyFrom( event );
      note_sounding = true;
    }
    else if (note_sounding && (last_onset.pitch() == event.pitch())) {
      // offset reached before the next onset: the note is kept unchanged
      mono_events.push_back( last_onset );
      mono_events.push_back( event );
      note_sounding = false;
    }
  }

  return mono_events;
}

// ========================================================================
// NOTE DENSITY

// Returns one (track type, instrument, onset count) entry per bar of every
// track, in track order and then bar order. An onset is an event with nonzero
// velocity.
vector<tuple<int,int,int>> calculate_note_density(midi::Piece *x) {
  vector<tuple<int,int,int>> density;
  // iterate by reference: copying each track/bar would deep-copy the messages
  for (const auto &track : x->tracks()) {
    for (const auto &bar : track.bars()) {
      int num_notes = 0;
      for (const auto event_index : bar.events()) {
        if (x->events(event_index).velocity()) {
          num_notes++;
        }
      }
      density.push_back(
        make_tuple(track.type(), track.instrument(), num_notes));
    }
  }
  return density;
}

// Sets note_density_v2 on every track to a quantile bin of its average onset
// count.
//
// The average is taken over the bars that contain at least one onset (an event
// with nonzero velocity) and rounded to the nearest integer. A track with no
// onsets has average 0. The bin is the first index whose quantile value is
// >= that average, using the row for the track's instrument (row 128 for drum
// tracks). It is clamped to the row's last index so the lookup can never run
// past the table.
//
// NOTE(review): assumes each DENSITY_QUANTILES row is an ascending list of bin
// upper bounds (a sized container or C array) -- confirm against density.h.
void update_note_density(midi::Piece *x) {
  int track_num = 0;
  for (const auto &track : x->tracks()) {

    // count onsets and the bars that contain at least one onset
    int num_notes = 0;
    int non_empty_bars = 0;
    for (const auto &bar : track.bars()) {
      bool bar_has_onset = false;
      for (const auto event_index : bar.events()) {
        if (x->events(event_index).velocity()) {
          bar_has_onset = true;
          num_notes++;
        }
      }
      if (bar_has_onset) {
        non_empty_bars++;
      }
    }
    // guard 0/0: a track without onsets would otherwise convert NaN to int
    int av_notes = 0;
    if (non_empty_bars > 0) {
      av_notes = static_cast<int>(round((double)num_notes / non_empty_bars));
    }

    // calculate the density bin
    int qindex = track.instrument();
    if (track.is_drum()) {
      qindex = 128;
    }
    const auto &quantiles = DENSITY_QUANTILES[qindex];
    const int last_bin = static_cast<int>(std::size(quantiles)) - 1;
    int bin = 0;
    while ((bin < last_bin) && (av_notes > quantiles[bin])) {
      bin++;
    }

    // update protobuf
    x->mutable_tracks(track_num)->set_note_density_v2(bin);
    track_num++;
  }
}

// ========================================================================
// EMPTY BARS

// Sets has_notes on every bar of every track: true when the bar references at
// least one event with velocity > 0.
void update_has_notes(midi::Piece *x) {
  // Work through mutable pointers instead of iterating copies of each track
  // and bar, which deep-copied every message just to read its events.
  for (int track_num = 0; track_num < x->tracks_size(); track_num++) {
    midi::Track *track = x->mutable_tracks(track_num);
    for (int bar_num = 0; bar_num < track->bars_size(); bar_num++) {
      midi::Bar *bar = track->mutable_bars(bar_num);
      bool has_notes = false;
      for (const auto event_index : bar->events()) {
        if (x->events(event_index).velocity() > 0) {
          has_notes = true;
          break;
        }
      }
      bar->set_has_notes(has_notes);
    }
  }
}

// Returns the number of bars shared by all tracks of `x` (0 when there are no
// tracks).
//
// throws std::runtime_error if the tracks do not all have the same number of
// bars.
int get_num_bars(midi::Piece *x) {
  if (x->tracks_size() == 0) {
    return 0;
  }
  const int num_bars = x->tracks(0).bars_size();
  // compare by reference; copying each track only to read bars_size was costly
  for (const auto &track : x->tracks()) {
    if (track.bars_size() != num_bars) {
      throw std::runtime_error("Each track must have the same number of bars!");
    }
  }
  return num_bars;
}

// Reorders the tracks of `x` in place. track_order[i] is written as the
// `order` key of the track currently at index i, and the tracks are then
// sorted by that key in ascending order.
//
// throws runtime_error if track_order does not have one entry per track (the
// two sizes are printed to stdout first).
void reorder_tracks(midi::Piece *x, vector<int> track_order) {
  const int track_count = x->tracks_size();
  if (static_cast<size_t>(track_count) != track_order.size()) {
    cout << track_count << " " << track_order.size() << endl;
    throw runtime_error("Track order does not match midi::Piece.");
  }
  // tag each track with its target position ...
  int position = 0;
  for (const int target : track_order) {
    x->mutable_tracks(position++)->set_order(target);
  }
  // ... then sort by the tag
  auto by_order = [](const midi::Track &lhs, const midi::Track &rhs) {
    return lhs.order() < rhs.order();
  };
  sort(x->mutable_tracks()->begin(), x->mutable_tracks()->end(), by_order);
}

// Rebuilds `x` so it holds only the selected tracks (in the order given) and,
// optionally, only the selected bars, with a freshly compacted event pool.
//
// tracks  indices of tracks to keep; out-of-range entries are skipped.
// bars    indices of bars to keep in every kept track; out-of-range entries
//         are skipped. An empty list keeps every bar.
//
// The event pool is always rebuilt from the kept bars, so every bar's event
// indices refer to events that exist in the result. Pieces with no tracks are
// left untouched.
//
// throws std::runtime_error (via get_num_bars) if the tracks have differing bar
// counts; `x` is unchanged in that case.
void prune_tracks_dev2(midi::Piece *x, vector<int> tracks, vector<int> bars) {

  if (x->tracks_size() == 0) {
    return;
  }

  // snapshot: x's tracks and events are rebuilt from this copy
  midi::Piece tmp(*x);

  const int num_bars = get_num_bars(x);
  const bool remove_bars = bars.size() > 0;
  x->clear_tracks();
  x->clear_events();

  vector<int> tracks_to_keep;
  for (const auto track_num : tracks) {
    if ((track_num >= 0) && (track_num < tmp.tracks_size())) {
      tracks_to_keep.push_back(track_num);
    }
  }

  // Bars to copy: the valid entries of `bars`, or every bar when no bar
  // selection was requested. The events must be re-copied in both cases:
  // the pool was cleared above, so keeping the old indices would leave the
  // bars pointing at events that no longer exist.
  vector<int> bars_to_keep;
  if (remove_bars) {
    for (const auto bar_num : bars) {
      if ((bar_num >= 0) && (bar_num < num_bars)) {
        bars_to_keep.push_back(bar_num);
      }
    }
  }
  else {
    for (int bar_num = 0; bar_num < num_bars; bar_num++) {
      bars_to_keep.push_back(bar_num);
    }
  }

  for (const auto track_num : tracks_to_keep) {
    const midi::Track &track = tmp.tracks(track_num);
    midi::Track *t = x->add_tracks();
    t->CopyFrom( track );
    t->clear_bars();
    for (const auto bar_num : bars_to_keep) {
      const midi::Bar &bar = track.bars(bar_num);
      midi::Bar *b = t->add_bars();
      b->CopyFrom( bar );
      b->clear_events();
      for (const auto event_index : bar.events()) {
        // the new index is the position the copied event is about to take
        b->add_events( x->events_size() );
        x->add_events()->CopyFrom( tmp.events(event_index) );
      }
    }
  }
}

// Returns true if `track` has at least one event with velocity > 0 in any of
// the bars listed in `bars_to_keep`.
//
// Bar indices outside [0, track.bars_size()) are skipped rather than read out
// of range (prune_tracks_dev2 also ignores such indices).
bool track_has_notes(midi::Piece *x, const midi::Track &track, vector<int> &bars_to_keep) {
  for (const auto bar_num : bars_to_keep) {
    if ((bar_num < 0) || (bar_num >= track.bars_size())) {
      continue;
    }
    for (const auto event_index : track.bars(bar_num).events()) {
      if (x->events(event_index).velocity() > 0) {
        return true;
      }
    }
  }
  return false;
}

// Keeps only the tracks that have at least one onset (velocity > 0) within
// `bars_to_keep`, then prunes `x` to those tracks and bars via
// prune_tracks_dev2.
void prune_empty_tracks(midi::Piece *x, vector<int> &bars_to_keep) {
  vector<int> tracks_to_keep;
  // index loop over const references: copying each track was a deep copy
  for (int track_num = 0; track_num < x->tracks_size(); track_num++) {
    if (track_has_notes(x, x->tracks(track_num), bars_to_keep)) {
      tracks_to_keep.push_back(track_num);
    }
  }
  prune_tracks_dev2(x, tracks_to_keep, bars_to_keep);
}

// Randomly permutes the tracks of `x` with `engine` and rebuilds the piece in
// that order through prune_tracks_dev2, passing no bar selection.
void shuffle_tracks_dev(midi::Piece *x, mt19937 *engine) {
  const int track_count = x->tracks_size();
  vector<int> order = arange(0, track_count, 1);
  shuffle(order.begin(), order.end(), *engine);
  const vector<int> no_bar_selection;
  prune_tracks_dev2(x, order, no_bar_selection);
}

// ========================================================================
// RANDOM SEGMENT SELECTION FOR TRAINING
// 
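// We pick a random valid segment (a window of consecutive bars), keep a random
// subset of the tracks that are active in it, and prune the piece to that selection.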


// Recomputes which windows of `x` can be sampled for training.
//
// A window is `seglen` consecutive bars starting at `start`. It is recorded
// (start in valid_segments, the qualifying track indices in the parallel
// valid_tracks_v2 entry) when every bar of the window has beat_length 4
// (read from track 0) and enough tracks have notes in at least
// round(0.75 * seglen) of the window's bars.
//
// x          piece to update; its has_notes flags are refreshed first.
// seglen     window length in bars.
// min_tracks minimum number of qualifying tracks.
// opz        if true, qualifying tracks are also grouped by the product of
//            their train_types. The window additionally needs at least
//            min_tracks distinct groups, unless the arp*lead or the
//            arp*lead*chord group occurs more than once.
void update_valid_segments(midi::Piece *x, int seglen, int min_tracks, bool opz) {
  update_has_notes(x);
  x->clear_valid_segments();
  x->clear_valid_tracks();
  // valid_tracks_v2 is appended to below and is read with the same index as
  // valid_segments (see select_random_segment). Entries left over from an
  // earlier call must therefore be dropped too, or the two lists drift apart.
  x->clear_valid_tracks_v2();

  if (x->tracks_size() < min_tracks) { return; } // no valid tracks

  const int min_non_empty_bars = round(seglen * .75);
  const int num_bars = get_num_bars(x);

  for (int start=0; start<num_bars-seglen+1; start++) {

    // only windows made entirely of 4/4 bars are accepted
    bool is_four_four = true;
    for (int k=0; k<seglen; k++) {
      int beat_length = x->tracks(0).bars(start+k).beat_length();
      is_four_four &= (beat_length == 4);
    }

    // check which tracks are valid
    midi::ValidTrack vtracks;
    map<int,int> used_track_types;  // product of train_types -> track count (OPZ only)
    for (int track_num=0; track_num<x->tracks_size(); track_num++) {
      const midi::Track &track = x->tracks(track_num);
      int non_empty_bars = 0;
      for (int k=0; k<seglen; k++) {
        if (track.bars(start+k).has_notes()) {
          non_empty_bars++;
        }
      }
      if (non_empty_bars >= min_non_empty_bars) {
        vtracks.add_tracks( track_num );
        if (opz) {
          // product of train types should be different
          int combined_train_type = 1;
          for (const auto train_type : track.train_types()) {
            combined_train_type *= train_type;
          }
          used_track_types[combined_train_type]++;
        }
      }
    }

    // check if there are enough tracks
    bool enough_tracks = vtracks.tracks_size() >= min_tracks;
    if (opz) {
      // For OPZ we can't count repeated track types, as we train on only one
      // track per track type. Compare as int so a negative min_tracks is not
      // converted to a huge unsigned value.
      bool opz_valid = static_cast<int>(used_track_types.size()) >= min_tracks;

      // also valid if we have more than one multi possibility track
      auto it = used_track_types.find(OPZ_ARP_TRACK * OPZ_LEAD_TRACK);
      opz_valid |= ((it != used_track_types.end()) && (it->second > 1));

      it = used_track_types.find(
        OPZ_ARP_TRACK * OPZ_LEAD_TRACK * OPZ_CHORD_TRACK);
      opz_valid |= ((it != used_track_types.end()) && (it->second > 1));

      enough_tracks &= opz_valid;
    }

    if (enough_tracks && is_four_four) {
      x->add_valid_tracks_v2()->CopyFrom(vtracks);
      x->add_valid_segments(start);
    }
  }
}

// Prunes `x` in place to one randomly chosen valid window of `num_bars` bars
// and a random subset of the tracks that qualify in it.
//
// Without `opz`, the qualifying tracks are shuffled and truncated to at most
// `max_tracks`. With `opz`, each shuffled track is given a randomly chosen
// train type from its train_types. Only types in [0, NUM_TRACK_TYPES) and
// <= OPZ_CHORD_TRACK that are not yet taken are used, and the track's `type`
// is overwritten with the chosen one. Tracks that get no type are dropped.
//
// throws std::runtime_error "NO VALID SEGMENTS" if no window qualifies.
// throws runtime_error "LESS THAN MIN TRACKS" if, with opz, fewer than
// min_tracks tracks remain.
void select_random_segment(midi::Piece *x, int num_bars, int min_tracks, int max_tracks, bool opz, std::mt19937 *engine) {
  update_valid_segments(x, num_bars, min_tracks, opz);

  if (x->valid_segments_size() == 0) {
    throw std::runtime_error("NO VALID SEGMENTS");
  }

  int index = random_on_range(x->valid_segments_size(), engine);
  int start = x->valid_segments(index);
  const auto &segment_tracks = x->valid_tracks_v2(index).tracks();
  vector<int> valid_tracks(segment_tracks.begin(), segment_tracks.end());
  shuffle(valid_tracks.begin(), valid_tracks.end(), *engine);
  vector<int> bars = arange(start,start+num_bars,1);

  if (opz) {
    // filter out duplicate OPZ tracks: randomly pick one unused track type
    // per track from its train_types
    vector<int> pruned_tracks;
    vector<int> used(NUM_TRACK_TYPES,0);
    for (const auto track_num : valid_tracks) {
      const auto &train_types = x->tracks(track_num).train_types();
      vector<int> track_options(train_types.begin(), train_types.end());
      shuffle(track_options.begin(), track_options.end(), *engine);
      for (const auto track_type : track_options) {
        if ((track_type >= 0) && (track_type < NUM_TRACK_TYPES)) {
          if ((used[track_type] == 0) && (track_type <= OPZ_CHORD_TRACK)) {
            pruned_tracks.push_back( track_num );
            // set the track type to the one randomly selected
            x->mutable_tracks(track_num)->set_type( track_type );
            used[track_type] = 1;
            break;
          }
        }
      }
    }
    valid_tracks = pruned_tracks;

    // It is possible to end up with fewer than min_tracks here. Compare as int
    // so a negative min_tracks is not converted to a huge unsigned value.
    if (static_cast<int>(valid_tracks.size()) < min_tracks) {
      throw runtime_error("LESS THAN MIN TRACKS");
    }
  }
  else {
    // limit the tracks
    int ntracks = min((int)valid_tracks.size(), max_tracks);
    valid_tracks.resize(ntracks);
  }

  prune_tracks_dev2(x, valid_tracks, bars);
}



// other helpers for training ...

// Returns (lowest pitch, highest pitch) over all events of the non-drum tracks.
// Every event referenced by a bar counts, whatever its velocity. With no such
// events the result is (INT_MAX, 0).
tuple<int,int> get_pitch_extents(midi::Piece *x) {
  int min_pitch = INT_MAX;
  int max_pitch = 0;
  // iterate by reference: copying tracks/bars would deep-copy each message
  for (const auto &track : x->tracks()) {
    if (is_drum_track(track.type())) {
      continue;  // drum tracks are excluded
    }
    for (const auto &bar : track.bars()) {
      for (const auto event_index : bar.events()) {
        const int pitch = x->events(event_index).pitch();
        min_pitch = std::min(pitch, min_pitch);
        max_pitch = std::max(pitch, max_pitch);
      }
    }
  }
  return make_tuple(min_pitch, max_pitch);
}

// Picks a random set of (track, bar) cells of `x` to mask.
//
// The number of cells is drawn with random_on_range from a bound of
// round(num_tracks * num_bars * proportion). That many distinct cells are then
// taken from a shuffled list of all cells. The count is clamped to
// [0, num_tracks * num_bars], so a proportion above 1 can no longer read past
// the candidate list.
//
// NOTE(review): random_on_range lives in ../random.h. Its result range, and
// its behavior for a bound of 0 (e.g. no tracks or a tiny proportion), are not
// visible here -- confirm.
set<tuple<int,int>> make_bar_mask(midi::Piece *x, float proportion, std::mt19937 *engine) {
  const int num_tracks = x->tracks_size();
  const int num_bars = get_num_bars(x);
  const int max_filled_bars = (int)round(num_tracks * num_bars * proportion);
  int n_fill = random_on_range(max_filled_bars, engine);

  vector<tuple<int,int>> choices;
  choices.reserve(static_cast<size_t>(num_tracks) * num_bars);
  for (int track_num=0; track_num<num_tracks; track_num++) {
    for (int bar_num=0; bar_num<num_bars; bar_num++) {
      choices.push_back(make_tuple(track_num, bar_num));
    }
  }
  shuffle(choices.begin(), choices.end(), *engine);

  n_fill = std::max(0, std::min(n_fill, static_cast<int>(choices.size())));
  return set<tuple<int,int>>(choices.begin(), choices.begin() + n_fill);
}

// Returns the General MIDI name for `instrument` from GM_REV. Drum tracks use
// the entries shifted by 128 (offset = is_drum_track(track_type) * 128).
string gm_inst_to_string(int track_type, int instrument) {
  const int drum_offset = is_drum_track(track_type) * 128;
  const int gm_index = drum_offset + instrument;
  return GM_REV[gm_index];
}

// JSON-string wrappers around the functions above, so they are accessible from Python
// ===================================================================
// ===================================================================
// ===================================================================
// ===================================================================
// ===================================================================

// Parses a JSON-encoded midi::Piece.
//
// The status returned by JsonStringToMessage is checked. Malformed input used
// to yield a silently partial or empty piece; it now raises an error.
//
// throws std::runtime_error (with the protobuf status text) if parsing fails.
midi::Piece string_to_piece(string json_string) {
  midi::Piece x;
  // pass the string itself (not c_str()) so the full length is used
  const auto status = google::protobuf::util::JsonStringToMessage(json_string, &x);
  if (!status.ok()) {
    throw std::runtime_error("string_to_piece: invalid midi::Piece JSON: " + status.ToString());
  }
  return x;
}

// Serializes a midi::Piece to JSON.
//
// throws std::runtime_error (with the protobuf status text) if serialization
// fails, instead of silently returning a partial string.
string piece_to_string(midi::Piece x) {
  string json_string;
  const auto status = google::protobuf::util::MessageToJsonString(x, &json_string);
  if (!status.ok()) {
    throw std::runtime_error("piece_to_string: could not serialize midi::Piece: " + status.ToString());
  }
  return json_string;
}

// Python wrapper: decodes the JSON piece, runs update_valid_segments on it and
// returns the updated piece re-encoded as JSON.
string update_valid_segments_py(string json_string, int num_bars, int min_tracks, bool opz) {
  auto piece = string_to_piece(json_string);
  update_valid_segments(&piece, num_bars, min_tracks, opz);
  return piece_to_string(piece);
}

// Python wrapper: decodes the JSON piece, runs update_note_density on it and
// returns the updated piece re-encoded as JSON.
string update_note_density_py(string json_string) {
  auto piece = string_to_piece(json_string);
  update_note_density(&piece);
  return piece_to_string(piece);
}

// Python wrapper: decodes the JSON piece, runs prune_empty_tracks on it with
// `bars` and returns the pruned piece re-encoded as JSON.
string prune_empty_tracks_py(string json_string, vector<int> bars) {
  auto piece = string_to_piece(json_string);
  prune_empty_tracks(&piece, bars);
  return piece_to_string(piece);
}

// Python wrapper: decodes the JSON piece, runs prune_tracks_dev2 on it with the
// given track and bar selections and returns the result re-encoded as JSON.
string prune_tracks_py(string json_string, vector<int> tracks, vector<int> bars) {
  auto piece = string_to_piece(json_string);
  prune_tracks_dev2(&piece, tracks, bars);
  return piece_to_string(piece);
}

// Python wrapper: decodes the JSON piece, runs select_random_segment on it
// with an std::mt19937 seeded from `seed` and returns the result as JSON.
// The same seed and input give the same selection.
string select_random_segment_py(string json_string, int num_bars, int min_tracks, int max_tracks, bool opz, int seed) {
  auto piece = string_to_piece(json_string);
  std::mt19937 rng_engine(seed);
  select_random_segment(&piece, num_bars, min_tracks, max_tracks, opz, &rng_engine);
  return piece_to_string(piece);
}

// Python wrapper: decodes the JSON piece, runs reorder_tracks on it with
// `track_order` and returns the reordered piece re-encoded as JSON.
string reorder_tracks_py(string json_string, vector<int> &track_order) {
  auto piece = string_to_piece(json_string);
  reorder_tracks(&piece, track_order);
  return piece_to_string(piece);
}

