#pragma once

#include <torch/script.h> // One-stop header.
#include <torch/nn/functional/activation.h>

#include <array>
#include <cstdint>
#include <iostream>
#include <limits>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include "encoder/encoder_all.h"
#include "enum/model_type.h"
#include "sample_fix.h"

using namespace std;

// Loads the TorchScript module stored at `ckpt_path` into `*m`.
//
//   ckpt_path : path handed to torch::jit::load.
//   m         : destination module; it is dereferenced unconditionally, so it
//               must not be null.
//
// Throws std::runtime_error when torch::jit::load raises c10::Error. The
// message keeps the "ERROR LOADING MODEL." prefix and appends the checkpoint
// path and the underlying reason, so a failure can be diagnosed from the
// message alone instead of being reduced to a fixed string.
//
// Marked inline because this definition lives in a header (#pragma once);
// without it, two translation units including the header would each emit a
// definition of the same external function.
inline void load_model(string &ckpt_path, torch::jit::Module *m) {
  try {
    *m = torch::jit::load(ckpt_path);
  }
  catch (const c10::Error& e) {
    throw runtime_error("ERROR LOADING MODEL. (" + ckpt_path + ") " + e.what());
  }
}

// Performs one sampling step for every sequence in the batch.
//
//   ctrl             : per-sequence control object; queried for trigger/mask/
//                      finished state and advanced via increment().
//   seqs             : token sequences, one per batch row; the sampled token is
//                      appended to every row that is not finished.
//   temperature      : divisor applied to the logits before the softmax.
//   model            : TorchScript module; forward() must return a tuple whose
//                      element 0 is a tensor and whose element 1 is fed back
//                      as the second input of the next step.
//   inputs           : in/out. Consumed by forward(), then replaced with
//                      { sampled tokens, element 1 of the model output }.
//   encoder          : used only to test whether a token is a BAR_END token.
//   num_bars_sampled : incremented when the token appended to sequence 0 in
//                      this step is a BAR_END token.
//
// Element 0 of the output is indexed with (:, -1, :), i.e. the last index
// along dimension 1 is kept — presumably (batch, position, vocab); TODO confirm
// against the exported model.
inline void sample_inner(Control *ctrl, vector<vector<int>> &seqs, float temperature, torch::jit::Module *model, vector<torch::jit::IValue> &inputs, ENCODER *encoder, int *num_bars_sampled) {
  // Nothing here is ever backpropagated, so run the step with gradient
  // recording disabled rather than tracking autograd history for every call.
  torch::NoGradGuard no_grad;

  auto outputs = model->forward(inputs).toTuple();
  auto logits = outputs->elements()[0].toTensor().index({torch::indexing::Slice(),-1,torch::indexing::Slice()});
  auto past_key_values = outputs->elements()[1];

  const int num_seqs = static_cast<int>(seqs.size());

  // Apply the control mask: every token id whose mask entry is 0 gets the
  // lowest finite float as its logit so that it is (practically) never drawn.
  const float lowest_logit = std::numeric_limits<float>::lowest();
  for (int i = 0; i < num_seqs; ++i) {
    if (seqs[i].empty() || !ctrl->check_trigger(i, seqs[i].back())) {
      continue;
    }
    const vector<int> &mask = ctrl->get_mask(i);
    vector<int64_t> banned;
    for (size_t j = 0; j < mask.size(); ++j) {
      if (mask[j] == 0) {
        banned.push_back(static_cast<int64_t>(j));
      }
    }
    if (!banned.empty()) {
      // One in-place fill per row instead of one tensor write per token id.
      auto banned_idx = torch::tensor(banned, torch::kInt64).to(logits.device());
      logits[i].index_fill_(0, banned_idx, lowest_logit);
    }
    ctrl->increment(i);
  }

  auto probs = (logits / temperature).softmax(1);
  auto next_tokens = probs.multinomial(1);

  // The next step is fed only the freshly sampled tokens plus the state
  // returned by the model.
  inputs.clear();
  inputs.push_back( next_tokens );
  inputs.push_back( past_key_values );

  // Append the sampled token to every sequence that is still running.
  bool first_seq_advanced = false;
  for (int i = 0; i < num_seqs; ++i) {
    if (ctrl->is_finished(i)) {
      continue;
    }
    seqs[i].push_back( static_cast<int>(next_tokens[i][0].item<int64_t>()) );
    if (i == 0) {
      first_seq_advanced = true;
    }
  }

  // Count a bar only when sequence 0 actually received a token this step.
  // Reading seqs[0].back() unconditionally would be undefined for an empty
  // batch and would re-count a stale BAR_END on every step after sequence 0
  // has finished.
  if (first_seq_advanced && encoder->rep->is_token_type(seqs[0].back(),BAR_END)) {
    (*num_bars_sampled)++;
  }
}

// Runs batched autoregressive sampling as described by `sc` and returns the
// resulting token sequences (one per batch row, each starting with the prompt).
//
// Fields of `sc` read here: ckpt_path, verbose, prompt, batch_size, encoder,
// control, temperature, max_steps. `sc->control` is started and advanced, so
// the config is modified by the call.
//
// The loop stops when the control reports all sequences finished, or after
// `sc->max_steps` steps when that value is positive.
//
// Errors from load_model and from libtorch propagate to the caller.
inline vector<vector<int>> sample(SampleConfig *sc) {

  torch::jit::Module m;
  load_model(sc->ckpt_path, &m);

  int num_bars_sampled = 0;

  const auto &prompt = sc->prompt;

  if (sc->verbose) {
    cout << "PROMPT BEGIN" << endl;
    cout << "PROMPT LENGTH : " << prompt.size() << endl;
    for (const auto token : prompt) {
      cout << sc->encoder->rep->pretty(token) << endl;
    }
    cout << "PROMPT END" << endl;
  }

  vector<torch::jit::IValue> inputs;

  // Create the first input by tiling the prompt once per batch row.
  const int64_t batch = sc->batch_size;
  const int64_t prompt_len = static_cast<int64_t>(prompt.size());
  auto opts = torch::TensorOptions().dtype(torch::kInt64);
  torch::Tensor x = torch::zeros({batch, prompt_len}, opts);
  // Write through an accessor: plain element stores instead of one tensor
  // operation per (row, position).
  auto x_rows = x.accessor<int64_t, 2>();
  for (int64_t k = 0; k < batch; ++k) {
    for (int64_t i = 0; i < prompt_len; ++i) {
      x_rows[k][i] = prompt[static_cast<size_t>(i)];
    }
  }
  inputs.push_back( x );

  // Create the empty state: NUM_LAYERS tensors of shape {2, batch, 8, 0, 64}.
  // The literal sizes are presumably (key/value, batch, heads, past length,
  // head dim) for the exported model — TODO confirm.
  // TODO :: infer the rest of the state dimensions from the model
  std::vector<torch::jit::IValue> state;
  for (int i = 0; i < NUM_LAYERS; ++i) {
    state.push_back(torch::zeros({2, batch, 8, 0, 64}));
  }
  inputs.push_back( torch::ivalue::Tuple::create(state) );

  // Every output sequence starts as a copy of the prompt.
  // torch::zeros above has already rejected a negative batch size.
  vector<vector<int>> seqs(static_cast<size_t>(batch), prompt);

  sc->control.start(sc->batch_size); // initialize the control
  int num_steps = 0;
  while (!sc->control.all_finished()) {
    sample_inner(&sc->control, seqs, sc->temperature, &m, inputs, sc->encoder, &num_bars_sampled);
    ++num_steps;
    if (sc->verbose) {
      cout << num_steps << " | ";
      for (const auto &seq : seqs) {
        cout << sc->encoder->rep->pretty(seq.back()) << " ";
      }
      cout << endl;
    }
    if ((sc->max_steps > 0) && (num_steps >= sc->max_steps)) {
      break;
    }
  }

  return seqs;
}

// Builds a SampleConfig for `piece`/`status` via prepare_generate, samples
// `batch_size` continuations with sample(), and converts the token sequences
// back into pieces.
//
//   status, piece : forwarded to prepare_generate.
//   temp          : sampling temperature.
//   batch_size    : number of sequences sampled, and number of pieces returned.
//   verbose       : when true, the prompt and every sampled step are printed.
//   ckpt_map      : forwarded to prepare_generate (which fills in the config,
//                   including the checkpoint path, prompt, control and encoder
//                   that sample() reads).
//   max_steps     : when positive, upper bound on the number of sampling steps.
//
// Returns `batch_size` pieces filled by the encoder's tokens_to_json_array.
inline vector<midi::Piece> generate(midi::Status *status, midi::Piece *piece, float temp, int batch_size, bool verbose, map<tuple<int,MODEL_TYPE>,tuple<ENCODER_TYPE,string>> &ckpt_map, int max_steps=0) {

  // returns prompt control, encoder_type, ckpt_path, order
  SampleConfig sc;
  sc.temperature = temp;
  sc.batch_size = batch_size;
  sc.verbose = verbose;
  prepare_generate(status, piece, &sc, ckpt_map);

  // The sampling loop must run with the caller's arguments, so re-apply them
  // in case prepare_generate adjusted the config, and pass the step limit
  // through the config as well.
  sc.temperature = temp;
  sc.batch_size = batch_size;
  sc.verbose = verbose;
  sc.max_steps = max_steps;

  // Model loading, prompt tiling and the step loop are shared with sample().
  vector<vector<int>> seqs = sample(&sc);

  // convert back to piece
  vector<midi::Piece> output(batch_size);
  sc.encoder->tokens_to_json_array(seqs, output);
  return output;
}