#pragma once

#include <iostream>
#include <vector>
#include <tuple>
#include <map>
#include <set>

#include <iostream>
#include <fstream>
#include <sstream>
#include "../../midifile/include/Binasc.h"

#include "../../midifile/include/MidiFile.h"
#include "midi.pb.h"
#include "constants.h"
#include "encoder_config.h"
#include "density.h"

// Return a copy of every event in `p`, sorted by time. At equal time,
// events with velocity <= 0 come before onsets (velocity clamped with
// min(v, 1)), then lower pitch, then lower track index -- the same order
// sort_events uses for event pointers.
vector<midi::Event> get_sorted_events(midi::Piece *p) {
  vector<midi::Event> events(p->events().begin(), p->events().end());
  // Compare by const reference: taking midi::Event by value copied two
  // protobuf messages on every comparison.
  sort(events.begin(), events.end(),
       [](const midi::Event &a, const midi::Event &b) {
    if (a.time() != b.time()) {
      return a.time() < b.time();
    }
    if (min(a.velocity(),1) != min(b.velocity(),1)) {
      return min(a.velocity(),1) < min(b.velocity(),1);
    }
    if (a.pitch() != b.pitch()) {
      return a.pitch() < b.pitch();
    }
    return a.track() < b.track();
  });
  return events;
}

// Recompute pitch ranges from the events referenced by each track's bars.
// Ranges are stored per track (min_pitch / max_pitch on the track) and for
// the whole piece. A track without events gets min INT_MAX and max 0. The
// piece gets the same sentinels when no track has any events.
void update_pitch_limits(midi::Piece *p) {
  int min_pitch = INT_MAX;
  int max_pitch = 0;
  for (int track_num = 0; track_num < p->tracks_size(); track_num++) {
    // Iterate by reference: the previous by-value loops deep-copied every
    // track and bar message.
    const midi::Track &track = p->tracks(track_num);
    int track_min_pitch = INT_MAX;
    int track_max_pitch = 0;
    for (const auto &bar : track.bars()) {
      for (const auto event_id : bar.events()) {
        const int pitch = p->events(event_id).pitch();
        track_min_pitch = min(pitch, track_min_pitch);
        track_max_pitch = max(pitch, track_max_pitch);
        min_pitch = min(pitch, min_pitch);
        max_pitch = max(pitch, max_pitch);
      }
    }
    midi::Track *out = p->mutable_tracks(track_num);
    out->set_min_pitch( track_min_pitch );
    out->set_max_pitch( track_max_pitch );
  }
  p->set_min_pitch( min_pitch );
  p->set_max_pitch( max_pitch );
}

// Recompute which windows of ec->num_bars consecutive bars are usable, and
// which tracks are active in each window.
//
// A window starting at bar i is kept when both of these hold:
//  - every bar's time signature is accepted. With ec->mark_time_sigs set,
//    the bar's beat_length must be a key of time_sig_map. Otherwise every
//    bar must have beat_length == 4.
//  - at least ec->min_tracks tracks have notes in >= 75% of the window's
//    bars (rounded).
// Each kept start bar is appended to p->valid_segments. The matching track
// bitmask (bit j set when track j is active) is appended to p->valid_tracks.
// Bar metadata (beat_length) is read from track 0.
void update_valid_segments(midi::Piece *p, EncoderConfig *ec) {
  p->clear_valid_segments();
  p->clear_valid_tracks();

  // segment parameters come from the encoder config
  const int seglen = ec->num_bars;
  const int min_non_empty_bars = round(seglen * .75);
  const int min_tracks = ec->min_tracks;

  // Without any track there is nothing to scan. tracks(0) used to be read
  // even then whenever min_tracks <= 0.
  if ((p->tracks_size() == 0) || (p->tracks_size() < min_tracks)) { return; }

  // Active tracks are stored as a 32-bit mask. Shifting by >= 32 is
  // undefined behaviour, so only the first 32 tracks can be marked active.
  const int maskable_tracks = min(p->tracks_size(), 32);
  const int num_bars = p->tracks(0).bars_size();

  // NOTE(review): with `<`, the last possible start (num_bars - seglen) is
  // never tried. The original author reported crashes with a `+ 1` bound;
  // the cause is not visible here, so the bound is kept unchanged.
  for (int i=0; i<num_bars-seglen; i++) {

    // check that all time sigs are supported
    bool supported_ts = true;
    bool is_four_four = true;
    for (int k=0; k<seglen; k++) {
      const int beat_length = p->tracks(0).bars(i+k).beat_length();
      supported_ts &= (time_sig_map.find(beat_length) != time_sig_map.end());
      is_four_four &= (beat_length == 4);
    }
    if (!ec->mark_time_sigs) {
      supported_ts = is_four_four;
    }

    // check which tracks are valid
    uint32_t vtracks = 0;
    for (int j=0; j<maskable_tracks; j++) {
      int non_empty_bars = 0;
      for (int k=0; k<seglen; k++) {
        if (p->tracks(j).bars(i+k).has_notes()) {
          non_empty_bars++;
        }
      }
      if (non_empty_bars >= min_non_empty_bars) {
        vtracks |= ((uint32_t)1 << j);
      }
    }

    // add to list if it meets the requirements
    if ((__builtin_popcount(vtracks) >= min_tracks) && supported_ts) {
      p->add_valid_tracks(vtracks);
      p->add_valid_segments(i);
    }
  }
}

// =======================================================================
// average polyphony calculation

// Build notes for track `track_num` by pairing each onset (velocity > 0)
// with the next event of the same pitch whose velocity is <= 0. The track's
// bars and their events are walked in stored order. A second onset before
// the offset replaces the pending onset for that pitch. Offsets without a
// pending onset, and onsets never closed, produce no note. Note fields
// (start/qstart, velocity, instrument, track, bar, is_drum) come from the
// onset; end/qend come from the offset.
vector<midi::Note> track_events_to_notes(midi::Piece *p, int track_num) {
  map<int,midi::Event> onsets;  // pending onset per pitch
  vector<midi::Note> notes;
  // Bars and events are read by reference; the old loops copied every bar
  // and every event message.
  for (const auto &bar : p->tracks(track_num).bars()) {
    for (const auto event_id : bar.events()) {
      const midi::Event &e = p->events(event_id);
      if (e.velocity() > 0) {
        onsets[e.pitch()] = e;
        continue;
      }
      auto it = onsets.find(e.pitch());
      if (it == onsets.end()) {
        continue;  // offset with no matching onset
      }
      const midi::Event &onset = it->second;
      midi::Note note;
      note.set_start( onset.time() );
      note.set_qstart( onset.qtime() );
      note.set_end( e.time() );
      note.set_qend( e.qtime() );
      note.set_velocity( onset.velocity() );
      note.set_pitch( onset.pitch() );
      note.set_instrument( onset.instrument() );
      note.set_track( onset.track() );
      note.set_bar( onset.bar() );
      note.set_is_drum( onset.is_drum() );
      notes.push_back(note);

      onsets.erase(it); // `onset` is not used past this point
    }
  }
  return notes;
}

// Average number of simultaneously sounding notes, taken over the ticks on
// which at least one note sounds. Each note covers the ticks [start, end),
// clipped to [0, 1000000) to bound memory. Returns 0 when no tick is
// covered (e.g. no notes). The old code returned the NaN produced by 0 / 0.
float average_polyphony(vector<midi::Note> &notes) {
  int max_tick = 0;
  for (const auto &note : notes) {
    max_tick = max(max_tick, note.end());
  }
  max_tick = min(max_tick, 1000000); // don't be too large
  vector<int> counts(max_tick,0);
  for (const auto &note : notes) {
    const int s = max(0,note.start());
    const int e = min(max_tick,note.end());
    for (int i=s; i<e; i++) {
      counts[i]++;
    }
  }
  float sum = 0;
  int total = 0;
  for (const int count : counts) {
    total += (int)(count > 0);
    sum += count;
  }
  if (total == 0) {
    return 0.f;
  }
  return sum / total;
}

// Store each track's average polyphony (see average_polyphony) on the
// track, using the notes rebuilt from that track's events.
void update_polyphony(midi::Piece *p) {
  const int track_count = p->tracks_size();
  for (int t = 0; t < track_count; ++t) {
    vector<midi::Note> track_notes = track_events_to_notes(p, t);
    const float polyphony = average_polyphony(track_notes);
    p->mutable_tracks(t)->set_av_polyphony(polyphony);
  }
}

/*
bool notes_overlap(midi::Note *a, midi::Note *b) {
  return (a->start() >= b->start()) && (a->start() < b->end());
}

int max_polyphony(vector<midi::Note> &notes) {
  int max_poly = 0;
  vector<int> overlap_counts(notes.size(), 0);
  for (int i=0; i<notes.size(); i++) {
    for (int j=0; j<notes.size(); j++) {
      if ((i!=j) && notes_overlap(&notes[i],&notes[j])) {
        overlap_counts[i]++;
        max_poly = max(max_poly, overlap_counts[i]);
      }
    }
  }
  return max_poly;
}
*/

// =======================================================================

// Upper bounds of the v1 note-density bins (total onsets in a track), used
// by update_note_density. The MONO table applies to tracks with
// av_polyphony < 1.1, the POLY table to the rest. The trailing INT_MAX
// guarantees the bin search stops.
vector<int> MONO_DENSITY_QNT{7, 10, 12, 14, 16, 18, 21, 26, 32, INT_MAX};
vector<int> POLY_DENSITY_QNT{11, 16, 21, 28, 33, 41, 51, 64, 85, INT_MAX};

// Assign each track two note-density bins. An onset is an event with
// non-zero velocity.
//  - note_density (v1): the total onset count, binned with MONO_DENSITY_QNT
//    when the track's av_polyphony < 1.1 and POLY_DENSITY_QNT otherwise.
//    av_polyphony should therefore be computed before this is called.
//  - note_density_v2: average onsets per bar, counting only bars that
//    contain an onset. It is binned with DENSITY_QUANTILES[instrument], or
//    row 128 for drum tracks.
void update_note_density(midi::Piece *src) {
  for (int i=0; i<src->tracks_size(); i++) {
    // by reference: the old code deep-copied the whole track message
    const midi::Track &track = src->tracks(i);

    // Count onsets and the bars holding them in a single pass; both density
    // versions are derived from these two numbers.
    int num_notes = 0;
    int active_bars = 0;
    for (const auto &bar : track.bars()) {
      int bar_notes = 0;
      for (const auto event_id : bar.events()) {
        if (src->events(event_id).velocity()) {
          bar_notes++;
        }
      }
      num_notes += bar_notes;
      if (bar_notes > 0) {
        active_bars++;
      }
    }

    // v1: bin the total onset count. Both tables end with INT_MAX, so the
    // search always stops.
    const vector<int> &quantiles =
      (track.av_polyphony() < 1.1) ? MONO_DENSITY_QNT : POLY_DENSITY_QNT;
    int index = 0;
    while (num_notes > quantiles[index]) { index++; }
    src->mutable_tracks(i)->set_note_density(index);

    // v2: bin the rounded average onsets per non-empty bar. A track without
    // onsets gets 0 instead of round(0.0 / 0); converting that NaN to int
    // is undefined behaviour.
    int av_notes = 0;
    if (active_bars > 0) {
      av_notes = round((double)num_notes / active_bars);
    }
    const int qindex = track.is_drum() ? 128 : track.instrument();
    int bin = 0;
    // NOTE(review): assumes every DENSITY_QUANTILES row ends with a bound no
    // av_notes can exceed -- confirm in density.h.
    while (av_notes > DENSITY_QUANTILES[qindex][bin]) {
      bin++;
    }
    src->mutable_tracks(i)->set_note_density_v2(bin);
  }
}

// =============================================================
// =============================================================
// =============================================================

// Return the Standard MIDI File format type, or -1 on failure. The format
// type is the 2-byte field read after the "MThd" chunk ID and the 4-byte
// header length. Failure means the stream does not start with "MThd" or the
// header length is not 6. If the stream does not begin with 'M', it is
// first run through smf::Binasc::writeToBinary (presumably converting a
// textual MIDI representation -- confirm), and the converted bytes are
// checked instead.
int check_type_inner(std::istream& input) {
  if (input.peek() != 'M') {
    std::stringstream binarydata;
    smf::Binasc binasc;
    binasc.writeToBinary(binarydata, input);
    binarydata.seekg(0, std::ios_base::beg);
    if (binarydata.peek() != 'M') {
      return -1;
    }
    return check_type_inner(binarydata);
  }

  // Header chunk ID. get() returns EOF (-1) on a truncated stream, which
  // never equals one of these characters, so truncation also fails here.
  const char *magic = "MThd";
  for (int i = 0; i < 4; i++) {
    if (input.get() != magic[i]) {
      return -1;
    }
  }

  // header size (allow larger header size?)
  if (smf::MidiFile::readLittleEndian4Bytes(input) != 6) {
    return -1;
  }

  // Header parameter #1: format type
  return smf::MidiFile::readLittleEndian2Bytes(input);
}

// Return the MIDI format type of the file at `filename` (see
// check_type_inner), or -1 when the file cannot be opened. The old code
// returned `false` here, i.e. 0, which is indistinguishable from format
// type 0.
int check_type(const std::string& filename) {
  std::ifstream input(filename, std::ios::binary);
  if (!input.is_open()) { return -1; }
  return check_type_inner(input);
}


// =============================================================
// =============================================================
// =============================================================
// 0-based MIDI channel whose notes are marked as drums (is_drum) in this file.
static const int DRUM_CHANNEL = 9;

// Evaluate `noisy` with std::cout / std::cerr muted (failbit set), then
// clear both streams. If `noisy` throws, the streams are cleared before the
// exception propagates; previously they stayed muted for the rest of the
// process. clear() resets the whole stream state, not the state the
// streams had before the call.
#define QUIET_CALL(noisy) { \
    cout.setstate(ios_base::failbit);\
    cerr.setstate(ios_base::failbit);\
    try {\
      (noisy);\
    }\
    catch (...) {\
      cout.clear();\
      cerr.clear();\
      throw;\
    }\
    cout.clear();\
    cerr.clear();\
}

using Note = tuple<int,int,int,int,int>; // (ONSET,PITCH,DURATION,VELOCITY,INST)
using Event = tuple<int,int,int,int>; // (TIME,VELOCITY,PITCH,INSTRUMENT)

using VOICE_TUPLE = tuple<int,int,int>; // (MIDI TRACK, CHANNEL, PROGRAM): one voice in parse_new

// Snap tick position `x` to a grid of SPQ steps per quarter note (TPQ ticks
// per quarter). Returns the snapped position in ticks, truncated to int. For
// non-negative x, a fractional step position is rounded up once its
// fraction reaches `cut`. The default .5 rounds to the nearest step.
// NOTE(review): when TPQ is not a multiple of SPQ the step->tick product is
// truncated, so converting back with `tick * SPQ / TPQ` may land one step
// early -- confirm whether that matters for the files being parsed.
int quantize_beat(double x, double TPQ, double SPQ, double cut=.5) {
  const int step = (int)((x / TPQ * SPQ) + (1. - cut));
  const double ticks_per_step = TPQ / SPQ;
  return (int)(step * ticks_per_step);
}

// Convert tick position `x` to a step index at `steps_per_second`
// resolution: x / ticks * spq * steps_per_second, truncated after adding
// (1 - cut). The default cut = .5 rounds to the nearest step for
// non-negative values. Presumably `ticks` is ticks per quarter note and
// `spq` seconds per quarter note (inferred from names -- confirm).
int quantize_second(double x, double spq, double ticks, double steps_per_second, double cut=.5) {
  const double steps = x / ticks * spq * steps_per_second;
  return (int)(steps + (1. - cut));
}

// Strict-weak-ordering comparator for event pointers (used by parse_new).
// Events are ordered by time; at equal time, events with velocity <= 0 come
// before onsets (velocity clamped with min(v, 1)); then lower pitch; then
// lower track index.
bool sort_events(const midi::Event *a, const midi::Event *b) {
  const auto order_key = [](const midi::Event *ev) {
    return make_tuple(ev->time(), min(ev->velocity(), 1), ev->pitch(), ev->track());
  };
  return order_key(a) < order_key(b);
}

void parse_new(string filepath, midi::Piece *p, EncoderConfig *ec, map<string,vector<string>> *genre_data=NULL) {
  smf::MidiFile midifile;
  try {
    QUIET_CALL(midifile.read(filepath));
    midifile.makeAbsoluteTicks();
    midifile.linkNotePairs();
  }
  catch(int e) {
    //cerr << "MIDI PARSING FAILED ..." << endl;
    return; // nothing to parse
  }


  bool is_offset;
  int pitch, time, velocity, channel, current_tick, bar, voice;
  int start, end, start_bar, end_bar, rel_start, rel_end;
  int track_count = midifile.getTrackCount();
  int TPQ = midifile.getTPQ();
  vector<int> instruments(16,0);

  smf::MidiEvent *mevent;
  midi::Event *e;
  
  // get time signature and track/channel info
  int max_tick = 0;
  map<VOICE_TUPLE,int> track_map;
  map<int,VOICE_TUPLE> rev_track_map;
  map<int,tuple<int,int,int>> timesigs;
  //timesigs[0] = make_tuple(4,4,TPQ * 4);

  // loop over each track and determine ...
  // 1) voice membership
  // 2) time signatures
  // 3) the maximum tick
  for (int track=0; track<track_count; track++) {
    fill(instruments.begin(), instruments.end(), 0); // zero instruments
    for (int event=0; event<midifile[track].size(); event++) { 
      mevent = &(midifile[track][event]);
      if (mevent->isTimeSignature()) {
        int barlength = (double)(TPQ * 4 * (*mevent)[3]) / (1<<(*mevent)[4]);
        timesigs[mevent->tick] = make_tuple(
          (*mevent)[3], 1<<(*mevent)[4], barlength);
      }
      if ((mevent->isNoteOn()) || (mevent->isNoteOff())) {
        channel = mevent->getChannelNibble();
        VOICE_TUPLE vtup = make_tuple(track,channel,instruments[channel]);
        if (track_map.find(vtup) == track_map.end()) {
          int current_size = track_map.size();
          track_map[vtup] = current_size;
          rev_track_map[current_size] = vtup;
        }
        max_tick = max(max_tick, mevent->tick);
      }
      if (mevent->isTempo()) {
        p->set_tempo(mevent->getTempoBPM());
      }
      if (mevent->isPatchChange()) {
        channel = mevent->getChannelNibble();
        instruments[channel] = (int)((*mevent)[1]);
      } 
    }
  }
  
  
  map<int,int> barlines;
  map<int,int> barlengths;

  // over-ride to 4/4
  if (ec->force_four_four) {
    timesigs.clear();
    timesigs[0] = make_tuple(4,4,TPQ * 4);
  }

  // we only consider midi files which have time signature information
  // unless an ignore flag is passed
  // here we want to determine ...
  // 1) the start position of each bar
  // 2) the length of each bar
  if (timesigs.size() > 0) {

    // add TPQ to max_tick to account for possible quantization later
    // only add if max tick is greater than last timesig
    if (max_tick + TPQ > timesigs.rbegin()->first) {
      timesigs[max_tick + TPQ] = make_tuple(0,0,0);
    }

    int barlength;
    int current_bar;
    auto it = timesigs.begin();
    while (it->first < timesigs.rbegin()->first) {
      barlength = get<2>(it->second);
      for (int i=it->first; i<next(it)->first; i=i+barlength) {
        current_bar = barlines.size();
        barlines[i] = current_bar;
        barlengths[current_bar] = barlength;
      }
      it++;
    }
    // add last barline
    current_bar = barlines.size();
    barlines[timesigs.rbegin()->first + barlength] = current_bar;
    barlengths[current_bar] = barlength;
  }
  else {
    //cerr << "NO TIME SIGNATURE DATA ..." << endl;
    return; // don't have accurate time-sigs
  }
  
  int MIN_TICK = timesigs.begin()->first;
  vector<midi::Event*> events;
  //vector<int> min_pitches(track_map.size() + 1,128);
  //vector<int> max_pitches(track_map.size() + 1,0);

  // here we create a list of events
  for (int track=0; track<track_count; track++) {
    fill(instruments.begin(), instruments.end(), 0); // zero instruments
    for (int event=0; event<midifile[track].size(); event++) {
      mevent = &(midifile[track][event]);

      // we ignore notes that do not have an end
      if ((mevent->isNoteOn()) && (mevent->isLinked())) {
        pitch = (*mevent)[1];
        velocity = (int)(*mevent)[2];
        channel = mevent->getChannelNibble();
        VOICE_TUPLE vtup = make_tuple(track,channel,instruments[channel]);
        voice = track_map.find(vtup)->second; // CAN THIS FAIL ?

        // NOTE : this keeps values in tick units but already quantized
        start = quantize_beat(mevent->tick, TPQ, ec->resolution);
        end = quantize_beat(mevent->getLinkedEvent()->tick, TPQ, ec->resolution);

        // a valid note must be on or after the first barline
        // and must have a positive duration
        if ((start >= MIN_TICK) && (end - start > 0)) {

          auto start_bar_it = barlines.lower_bound(start);
          auto end_bar_it = barlines.lower_bound(end);

          // onset should be >= start_bar tick
          if (start_bar_it->first > start) {
            start_bar_it--;
          }
          // offset should be > end_bar tick
          // this pushes offsets on the barline into the previous bar 
          if (end_bar_it->first >= end) {
            end_bar_it--;
          }

          start_bar = start_bar_it->second;
          end_bar = end_bar_it->second;

          // start and end relative to the barline
          rel_start = ((start - start_bar_it->first) * ec->resolution) / TPQ;
          rel_end = ((end - end_bar_it->first) * ec->resolution) / TPQ;

          if ((start_bar >= 0) && (end_bar >= 0)) {

            // add onset
            events.push_back( new midi::Event);
            e = events.back();
            e->set_time((start * ec->resolution) / TPQ);
            e->set_velocity(velocity);
            e->set_pitch(pitch);
            e->set_instrument(instruments[channel]);
            e->set_track(voice);
            e->set_qtime(rel_start);
            e->set_is_drum(channel == DRUM_CHANNEL);
            e->set_bar(start_bar);

            // add offset
            events.push_back( new midi::Event);
            e = events.back();
            e->set_time((end * ec->resolution) / TPQ);
            e->set_velocity(0);
            e->set_pitch(pitch);
            e->set_instrument(instruments[channel]);
            e->set_track(voice);
            e->set_qtime(rel_end);
            e->set_is_drum(channel == DRUM_CHANNEL);
            e->set_bar(end_bar);
          }
        }
      }

      if (midifile[track][event].isPatchChange()) {
        channel = midifile[track][event].getChannelNibble();
        instruments[channel] = (int)midifile[track][event][1];
      }      
    }
  }

  // sort the events
  sort(events.begin(), events.end(), sort_events);

  int ntracks = track_map.size();
  vector<vector<midi::Bar*>> bars(ntracks, vector<midi::Bar*>(barlines.size()));
  for (int i=0; i<ntracks; i++) {
    midi::Track *t = p->add_tracks();
    t->set_is_drum( (get<1>(rev_track_map[i]) == DRUM_CHANNEL) );
    t->set_instrument( get<2>(rev_track_map[i]) ); // inst in rev_track_map
    int j = 0;
    for (const auto kv : barlines) {
      midi::Bar *b = t->add_bars();
      b->set_is_four_four( barlengths[j] == (TPQ*4) );
      b->set_has_notes( false );
      b->set_time( kv.first * ec->resolution / TPQ );
      b->set_beat_length( (double)barlengths[j] / TPQ );
      bars[i][j] = b;
      j++;
    }
  }

  int n = 0;
  for (const auto event : events) {
    midi::Event *e = p->add_events();
    e->CopyFrom(*event);
    bars[event->track()][event->bar()]->add_events(n);
    if (e->velocity() > 0) {
      bars[event->track()][event->bar()]->set_has_notes(true);
    }
    n++;
    delete event; // otherwise memory leak
  }

  // update meta-data about the piece
  update_valid_segments(p, ec);
  update_pitch_limits(p);
  update_note_density(p);
  update_polyphony(p);
  
  p->set_segment_length(ec->num_bars);
  p->set_resolution(ec->resolution);

  // add genre information if available
  if (genre_data) {
    for (const auto genre : (*genre_data)["msd_tagtraum_cd1"]) {
      p->add_msd_cd1( genre );
    } 
    for (const auto genre : (*genre_data)["msd_tagtraum_cd2"]) {
      p->add_msd_cd2( genre );
    }
    for (const auto genre : (*genre_data)["msd_tagtraum_cd2c"]) {
      p->add_msd_cd2c( genre );
    }
  }
}

// Turn a piece into a Standard MIDI File written to `path`.
// Ticks per quarter note are set to p->resolution(), and event times are
// written unchanged. Piece track t goes to MIDI track t on channel t,
// except drum events, which use DRUM_CHANNEL. A patch change is emitted
// whenever an event's instrument differs from the last one written for its
// track (initially 0). Every event, including velocity-0 offsets, is
// written as a note-on with its own velocity.
// TODO: should barlines be added somewhere?
void write_midi(midi::Piece *p, string &path) {
  // Size per-track state from the data. Piece tracks are not limited to 16,
  // and the old fixed 16-entry vector was indexed out of bounds.
  int num_tracks = 0;
  for (const auto &event : p->events()) {
    num_tracks = max(num_tracks, event.track() + 1);
  }

  smf::MidiFile outputfile;
  outputfile.absoluteTicks();
  outputfile.setTicksPerQuarterNote(p->resolution());
  outputfile.addTempo(0, 0, p->tempo());
  outputfile.addTrack(16); // ensure drum channel
  // every piece track needs its own MIDI track
  if (outputfile.getTrackCount() < num_tracks) {
    outputfile.addTrack(num_tracks - outputfile.getTrackCount());
  }
  vector<int> current_inst(max(16, num_tracks), 0);

  for (const auto &event : p->events()) {

    // NOTE(review): a non-drum track with index DRUM_CHANNEL also lands on
    // the drum channel. Track indices >= 16 are not valid channel numbers
    // (presumably masked by midifile). Confirm the intended mapping.
    int channel = event.track();
    if (event.is_drum()) {
      channel = DRUM_CHANNEL;
    }

    if (current_inst[event.track()] != event.instrument()) {
      outputfile.addPatchChange(
        event.track(), event.time(), channel, event.instrument());
      current_inst[event.track()] = event.instrument();
    }

    outputfile.addNoteOn(
      event.track(), // track
      event.time(), // time
      channel, // channel
      event.pitch(), // pitch
      event.velocity()); // velocity (need some sort of conversion)
  }
  outputfile.sortTracks();         // make sure data is in correct order
  outputfile.write(path.c_str());  // write the Standard MIDI File
}

// =============================================================
// =============================================================
// =============================================================
// =============================================================
// =============================================================
// =============================================================

