#pragma once

#include <array>
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

using namespace std;

#include "encoder/encoder_all.h"
#include "enum/gm.h"
#include "enum/model_type.h"
#include "enum/encoder_types.h"

#include "protobuf/validate.h"
#include "protobuf/util.h"

// Compile-time layer count constant (value 6).
static constexpr int NUM_LAYERS = 6;

// Positional schedule of sampling constraints.
//
// `controls` is an ordered list of (trigger, mask) pairs. Both halves are
// vocabulary-sized vectors produced by encode(): either the result of
// encoder->rep->encode_to_one_hot(), or null_mask (all ones) for NONE.
// `pos` holds one cursor into `controls` per sequence; start(n) allocates n
// cursors at position 0. For a cursor, check_trigger() tests whether a token
// is marked in the current trigger, get_mask() returns the current mask, and
// increment() advances to the next pair. How these calls are combined during
// sampling is up to the caller (not visible in this file).
class Control {
public:
  Control() {
    initialize(nullptr);
  }
  Control(ENCODER *e) {
    initialize(e);
  }

  // Resets the schedule. When `e` is non-null, it also binds the encoder and
  // sizes null_trigger / null_mask to its vocabulary (rep->max_token()).
  // Clearing `controls` and `pos` means a repeated initialize()+add() sequence
  // (e.g. a reused SampleConfig) starts from an empty schedule instead of
  // appending a second copy of the constraints.
  void initialize(ENCODER *e) {
    controls.clear();
    pos.clear();
    if (e) {
      encoder = e;
      int vocab_size = encoder->rep->max_token();
      null_trigger = std::vector<int>(vocab_size, 0);
      null_mask = std::vector<int>(vocab_size, 1);
    }
  }

  // Encodes (token type, values) as a vocabulary-sized vector.
  // NONE yields null_mask (all ones); any other type requires an encoder.
  // Throws runtime_error if called for a non-NONE type before an encoder
  // was bound via initialize().
  std::vector<int> encode(TOKEN_TYPE tt, vector<int> values) {
    if (tt == NONE) {
      return null_mask;
    }
    if (!encoder) {
      throw runtime_error("CONTROL HAS NO ENCODER");
    }
    return encoder->rep->encode_to_one_hot(tt, values);
  }

  // Appends a (trigger, mask) pair to the end of the schedule.
  void add(TOKEN_TYPE trigger_tt, vector<int> trigger_values, TOKEN_TYPE mask_tt, vector<int> mask_values) {
    controls.emplace_back(
      encode(trigger_tt, trigger_values), encode(mask_tt, mask_values));
  }

  // Allocates `n` cursors, all positioned at the first control.
  void start(int n) {
    pos = std::vector<int>(n, 0);
  }

  // Prints a vector as one line: '-' for 0 entries, 'x' for anything else.
  void show_mask(const vector<int> &mask) {
    for (const auto v : mask) {
      if (v == 0) { cout << "-"; }
      else { cout << "x"; }
    }
    cout << endl;
  }

  // Prints every (trigger, mask) pair of the schedule.
  void show() {
    for (const auto &control : controls) {
      cout << "TRIGGER : ";
      show_mask(get<0>(control));
      cout << "MASK : ";
      show_mask(get<1>(control));
    }
  }

  // True once cursor `index` has moved past the last control.
  // Throws runtime_error if `index` is not an allocated cursor.
  bool is_finished(int index) {
    check_index(index);
    return static_cast<size_t>(pos[index]) >= controls.size();
  }

  // True when every allocated cursor is finished (vacuously true if none).
  bool all_finished() {
    for (size_t i = 0; i < pos.size(); i++) {
      if (!is_finished(static_cast<int>(i))) {
        return false;
      }
    }
    return true;
  }

  // Whether token `value` is marked (== 1) in the current trigger of cursor
  // `index`. Returns false once the cursor is finished, and for a `value`
  // outside the trigger vector (previously an out-of-bounds read). An
  // out-of-range value cannot be marked in the trigger.
  bool check_trigger(int index, int value) {
    if (is_finished(index)) {
      return false; // no more triggers
    }
    const std::vector<int> &trigger = get<0>(controls[pos[index]]);
    if (value < 0 || static_cast<size_t>(value) >= trigger.size()) {
      return false;
    }
    return trigger[value] == 1;
  }

  // Mask of the current control for cursor `index`. Once the cursor is
  // finished there is no current control, so null_mask (all ones) is returned
  // instead of reading past the end of `controls`.
  std::vector<int> get_mask(int index) {
    if (is_finished(index)) {
      return null_mask;
    }
    return get<1>(controls[pos[index]]);
  }

  // Advances cursor `index` to the next control.
  // Throws runtime_error if `index` is not an allocated cursor.
  void increment(int index) {
    check_index(index);
    pos[index]++;
  }

  std::vector<std::tuple<std::vector<int>,std::vector<int>>> controls;
  ENCODER *encoder = nullptr; // non-owning; stays null until initialize() receives one
  vector<int> pos;
  vector<int> null_trigger; // all zeros
  vector<int> null_mask; // all ones

private:
  // Throws runtime_error if `index` does not address an allocated cursor.
  void check_index(int index) const {
    if (index < 0 || static_cast<size_t>(index) >= pos.size()) {
      throw runtime_error("CONTROL INDEX OUT OF RANGE");
    }
  }
};

// sample config for the common preferences
class SampleConfig {
public:
  SampleConfig() {
    batch_size = 1;
    temperature = 1.;
    verbose = false;
    max_steps = 0; // no limit
    num_bars = 4;
    encoder_type = NO_ENCODER;
    encoder = NULL;
    ckpt_path = "";
    model_type = TRACK_MODEL;
  }
  int batch_size;
  float temperature;
  bool verbose;
  int max_steps;
  int num_bars;
  MODEL_TYPE model_type;
  ENCODER_TYPE encoder_type;
  ENCODER *encoder;
  string ckpt_path;
  vector<int> order;
  vector<int> prompt;
  Control control;
};

// Prepares sc->control and sc->prompt for the interleaved (header-based)
// encoder.
//
// Schedule, in order:
//   PIECE_START 0  -> HEADER 0
//   HEADER 0       -> instruments of the first track
//   INSTRUMENT -1  -> instruments of each following track
//   INSTRUMENT -1  -> HEADER 1
//   HEADER 1       -> BAR 0
//   BAR 0          -> NONE   (repeated sc->num_bars times)
//   BAR_END 0      -> NONE
// NOTE(review): {-1} presumably means "any value of this token type" inside
// encode_to_one_hot — confirm against the representation code.
//
// tracks : (track_type, instrument name, density); only the instrument name
//          is used here, looked up in GM.
// sc     : must already hold a valid encoder. sc->prompt is replaced, not
//          appended to, by a single PIECE_START token, matching the other
//          prepare_* helpers.
void prepare_generate_interleaved(std::vector<tuple<int,string,int>> &tracks, SampleConfig *sc) {

  if (sc->verbose) {
    cout << "GENERATING " << sc->num_bars << " BARS (INTERLEAVED)" << endl;
  }

  sc->control.initialize(sc->encoder);

  sc->control.add(PIECE_START, {0}, HEADER, {0});
  bool first_track = true;
  for (const auto &track : tracks) {
    vector<int> insts = GM[get<1>(track)];
    if (first_track) {
      sc->control.add(HEADER, {0}, INSTRUMENT, insts);
      first_track = false;
    }
    else {
      sc->control.add(INSTRUMENT, {-1}, INSTRUMENT, insts);
    }
  }
  sc->control.add(INSTRUMENT, {-1}, HEADER, {1});
  sc->control.add(HEADER, {1}, BAR, {0}); // start with bar_start

  // make sure we generate the right number of bars
  for (int i = 0; i < sc->num_bars; i++) {
    sc->control.add(BAR, {0}, NONE, {});
  }
  sc->control.add(BAR_END, {0}, NONE, {});

  if (sc->verbose) {
    sc->control.show();
  }

  // Replace (not append to) any prompt left over from a previous preparation.
  sc->prompt.clear();
  sc->prompt.push_back( sc->encoder->rep->encode(PIECE_START, 0) );
}


// Prepares sc->control and sc->prompt for bar-infill generation.
//
// Configures the encoder for multi-fill over the (track, bar) cells listed in
// `bars` and encodes `p`. The prompt is truncated just after the first
// FILL_IN 1 token; if that token is absent, the full encoding is kept. The
// schedule gets one FILL_IN 2 -> NONE control per selected cell.
//
// bars : (track index, bar index) cells to infill.
// p    : piece to infill; required.
// sc   : must already hold a valid encoder; sc->num_bars is forwarded to the
//        encoder config.
// Throws runtime_error if `p` is null.
void prepare_generate_bars(std::vector<std::tuple<int,int>> &bars, midi::Piece *p, SampleConfig *sc) {

  if (!p) {
    throw runtime_error("MUST PROVIDE midi::Piece FOR BAR INFILL MODE");
  }

  std::set<std::tuple<int,int>> barset(bars.begin(), bars.end());
  sc->encoder->config->do_multi_fill = true;
  sc->encoder->config->multi_fill = barset;
  sc->encoder->config->num_bars = sc->num_bars;
  std::vector<int> prompt = sc->encoder->encode(p);

  // remove everything after fill_start token
  int fill_start = sc->encoder->rep->encode(FILL_IN, 1);
  for (size_t index = 0; index < prompt.size(); index++) {
    if (prompt[index] == fill_start) {
      prompt.resize(index + 1);
      break;
    }
  }

  sc->control.initialize(sc->encoder);
  for (size_t i = 0; i < bars.size(); i++) {
    sc->control.add(FILL_IN, {2}, NONE, {});
  }
  if (sc->verbose) {
    sc->control.show();
  }

  // hand the prompt over to the sample config
  sc->prompt = std::move(prompt);
}

// Prepares sc->control and sc->prompt for whole-track generation.
//
// For each requested track the schedule gets:
//   PIECE_START 0 (first track of an empty piece) or TRACK_END 0 -> TRACK track_type
//   TRACK -1        -> INSTRUMENT (GM instruments, reduced modulo 128)
//   INSTRUMENT -1   -> DENSITY_LEVEL density   (only when density >= 0)
// followed by a final TRACK_END 0 -> NONE.
// A track_type of STANDARD_BOTH is mapped to -1 before encoding.
//
// tracks : (track_type, instrument name, density) per track to generate.
// p      : existing piece to continue; a null pointer is treated the same as
//          a piece with no tracks (prompt = PIECE_START).
// sc     : must already hold a valid encoder; prompt is replaced.
void prepare_generate_tracks(std::vector<tuple<int,string,int>> &tracks, midi::Piece *p, SampleConfig *sc) {

  sc->control.initialize(sc->encoder);

  const bool piece_is_empty = (p == nullptr) || (p->tracks_size() == 0);

  bool first_track = true;
  for (const auto &track : tracks) {
    vector<int> insts = GM[get<1>(track)];
    int track_type = get<0>(track);
    int density = get<2>(track);

    if (track_type == STANDARD_BOTH) {
      track_type = -1;
    }

    if (sc->verbose) {
      cout << "GENERATING : " << track_type << " with density " << density << endl;
    }

    for (auto &inst : insts) {
      inst %= 128;
    }
    if (first_track && piece_is_empty) {
      sc->control.add(PIECE_START, {0}, TRACK, {track_type});
    }
    else {
      sc->control.add(TRACK_END, {0}, TRACK, {track_type});
    }
    sc->control.add(TRACK, {-1}, INSTRUMENT, insts);

    // density control
    if (density >= 0) {
      sc->control.add(INSTRUMENT, {-1}, DENSITY_LEVEL, {density});
    }

    first_track = false;
  }
  sc->control.add(TRACK_END, {0}, NONE, {});

  if (sc->verbose) {
    sc->control.show();
  }

  vector<int> prompt;
  if (!piece_is_empty) {
    prompt = sc->encoder->encode(p);
  }
  else {
    prompt.push_back( sc->encoder->rep->encode(PIECE_START, 0) );
  }

  // hand the prompt over to the sample config
  sc->prompt = std::move(prompt);
}


// Turns a generation request into a ready-to-sample SampleConfig.
//
// Each status track is classified with infer_track_type() as CONDITION,
// RESAMPLE or INFILL. A checkpoint is then looked up in `ckpt_map` under
// (bar count of track 0, BAR_INFILL_MODEL if any track is INFILL else
// TRACK_MODEL). Work is delegated to prepare_generate_interleaved,
// prepare_generate_bars or prepare_generate_tracks, which fill
// sc->control and sc->prompt.
//
// status   : per-track request; validated against `piece` first.
// piece    : existing content; has_notes is updated, and the piece is pruned
//            in place (prune_tracks_dev2) in the infill and track modes.
// sc       : in/out config. Reads verbose. Writes num_bars, model_type,
//            encoder_type, encoder, ckpt_path, and (track mode) order.
// ckpt_map : (num_bars, MODEL_TYPE) -> (encoder type, checkpoint path); only
//            read, never modified.
// Throws runtime_error if status has no tracks, no checkpoint matches the
// request, or getEncoder() returns null.
// NOTE(review): ownership of the encoder returned by getEncoder() is not
// visible here — confirm whether it must be freed.
void prepare_generate(midi::Status *status, midi::Piece *piece, SampleConfig *sc, map<tuple<int,MODEL_TYPE>,tuple<ENCODER_TYPE,string>> &ckpt_map) {

  validate_status(status, piece);
  if (status->tracks_size() == 0) {
    throw runtime_error("STATUS MUST CONTAIN AT LEAST ONE TRACK");
  }
  update_has_notes(piece);

  vector<tuple<int,string,int>> tracks; // (track_type, instrument, density) of RESAMPLE tracks
  vector<tuple<int,int>> bars;          // (track, bar) cells selected for infill
  int num_cond_tracks = 0;
  int num_resample_tracks = 0;
  int num_infill_tracks = 0;
  vector<STATUS_TRACK_TYPE> track_types;
  vector<int> order;                    // per-kind running index, fixed up below
  vector<int> cond_tracks;              // track_id of every CONDITION track

  int track_num = 0;
  for (const auto &track : status->tracks()) {
    const STATUS_TRACK_TYPE tt = infer_track_type(track);
    switch (tt) {
      case CONDITION :
        order.push_back( num_cond_tracks );
        cond_tracks.push_back( track.track_id() );
        num_cond_tracks++;
        break;
      case RESAMPLE :
        order.push_back( num_resample_tracks );
        tracks.push_back(make_tuple(
          track.track_type(), track.instrument(), track.density()));
        num_resample_tracks++;
        break;
      case INFILL :
        num_infill_tracks++;
        break;
    }
    track_types.push_back( tt );
    int bar_num = 0;
    for (const auto selected : track.selected_bars()) {
      if (selected) {
        bars.emplace_back(track_num, bar_num);
      }
      bar_num++;
    }
    track_num++;
  }

  // provide overview of tracks for sampling
  if (sc->verbose) {
    int shown = 0;
    for (const auto track_type : track_types) {
      cout << "TRACK " << shown << " -> " << track_type << endl;
      shown++;
    }
  }

  const int num_bars = status->tracks(0).selected_bars_size();
  const MODEL_TYPE model_type =
    (num_infill_tracks > 0) ? BAR_INFILL_MODEL : TRACK_MODEL;

  // find() instead of operator[]: a missing key must not silently insert a
  // default (encoder type, "") entry and continue with a bogus checkpoint.
  const auto ckpt_it = ckpt_map.find(make_tuple(num_bars, model_type));
  if (ckpt_it == ckpt_map.end()) {
    throw runtime_error("NO CHECKPOINT AVAILABLE FOR REQUESTED MODEL");
  }
  const tuple<ENCODER_TYPE,string> &ckpt_info = ckpt_it->second;

  sc->num_bars = num_bars;
  sc->model_type = model_type;
  sc->encoder_type = get<0>(ckpt_info);
  sc->encoder = getEncoder(sc->encoder_type);
  if (!sc->encoder) {
    throw runtime_error("COULD NOT CREATE ENCODER FOR CHECKPOINT");
  }
  sc->ckpt_path = get<1>(ckpt_info);

  if (sc->encoder_type == TRACK_INTERLEAVED_W_HEADER_ENCODER) {
    prepare_generate_interleaved(tracks, sc);
    return;
  }

  if (num_infill_tracks > 0) {

    if (sc->verbose) {
      cout << "GENERATING " << bars.size() << " BARS" << endl;
    }

    sc->encoder->config->do_multi_fill = true;

    // remove excess bars if any
    prune_tracks_dev2(
      piece, arange(0,piece->tracks_size(),1), arange(0,sc->num_bars,1));

    prepare_generate_bars(bars, piece, sc);
    return;
  }

  if (sc->verbose) {
    cout << "GENERATING " << num_resample_tracks << " TRACKS" << endl;
  }

  sc->encoder->config->do_multi_fill = false;

  // order[i] becomes the output position of status track i: CONDITION tracks
  // occupy 0..num_cond_tracks-1 (in encounter order), RESAMPLE tracks follow.
  for (size_t i = 0; i < track_types.size(); i++) {
    if (track_types[i] == RESAMPLE) {
      order[i] += num_cond_tracks;
    }
  }
  vector<int> inverse_order(order.size(), 0);
  for (size_t i = 0; i < order.size(); i++) {
    inverse_order[order[i]] = static_cast<int>(i);
  }

  // Reordering of the generated tracks is left to the caller via sc->order.
  sc->order = inverse_order;

  // prune unneeded tracks
  prune_tracks_dev2(piece, cond_tracks, arange(0,sc->num_bars,1));

  prepare_generate_tracks(tracks, piece, sc);
}

// Binding-friendly wrapper around prepare_generate().
//
// Parses `piece_str` (midi::Piece) and `status_str` (midi::Status) from JSON
// and builds a SampleConfig from temp / batch_size / verbose. It then runs
// prepare_generate() and returns
// (prompt, control, encoder_type, ckpt_path, order).
// Throws runtime_error if either JSON string fails to parse. Previously a
// parse failure was ignored and generation ran on an empty message.
tuple<vector<int>,Control,ENCODER_TYPE,string,vector<int>> prepare_generate_py(string &status_str, string &piece_str, float temp, int batch_size, bool verbose, map<tuple<int,MODEL_TYPE>,tuple<ENCODER_TYPE,string>> &ckpt_map) {
  midi::Piece piece;
  const auto piece_parse =
    google::protobuf::util::JsonStringToMessage(piece_str, &piece);
  if (!piece_parse.ok()) {
    throw runtime_error("FAILED TO PARSE midi::Piece JSON : " + piece_parse.ToString());
  }
  midi::Status status;
  const auto status_parse =
    google::protobuf::util::JsonStringToMessage(status_str, &status);
  if (!status_parse.ok()) {
    throw runtime_error("FAILED TO PARSE midi::Status JSON : " + status_parse.ToString());
  }

  // make a sample config
  SampleConfig sc;
  sc.temperature = temp;
  sc.batch_size = batch_size;
  sc.verbose = verbose;
  prepare_generate(&status, &piece, &sc, ckpt_map);

  // sc is local, so its members can be moved into the result.
  return make_tuple(std::move(sc.prompt), std::move(sc.control), sc.encoder_type,
                    std::move(sc.ckpt_path), std::move(sc.order));
}

