#include <array>
#include <cstddef>
#include <iostream>
#include <limits>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <torch/script.h> // One-stop header.
#include <torch/nn/functional/activation.h>

using namespace std;

// Number of per-layer state tensors placed in the initial (empty) state tuple
// that sample_from_model hands to the model.
constexpr int NUM_LAYERS = 6;

/**
 * Tracks, for every sequence of a batch, how far it has progressed through an
 * ordered list of (trigger, mask) controls.
 *
 * Each control is a pair of integer vectors:
 *   - the trigger is indexed by the `value` passed to check_trigger(); an
 *     entry equal to 1 means that value fires the control;
 *   - the mask is handed back unchanged by get_mask().
 *
 * `pos[i]` is the index of the control that sequence `i` is currently waiting
 * on. A sequence is finished once its position has moved past the last
 * control. Call start() before using any of the per-sequence methods.
 */
class ControlWrap {
public:
  /// Takes ownership of the list of (trigger, mask) controls.
  ControlWrap(std::vector<std::tuple<std::vector<int>, std::vector<int>>> x)
      : controls(std::move(x)) {}

  /// Resets the tracker to `n` sequences, all waiting on the first control.
  /// A negative `n` is converted to a huge unsigned size by std::vector and
  /// makes the allocation throw.
  void start(int n) {
    pos = std::vector<int>(n, 0);
  }

  /// Prints one character per entry ("-" for 0, "x" otherwise) and a newline.
  void show_mask(const std::vector<int>& mask) const {
    for (const int v : mask) {
      std::cout << (v == 0 ? "-" : "x");
    }
    std::cout << std::endl;
  }

  /// Prints the trigger and the mask of every control to stdout.
  void show() const {
    // Iterate by reference: each control holds two vectors.
    for (const auto& control : controls) {
      std::cout << "TRIGGER : ";
      show_mask(std::get<0>(control));
      std::cout << "MASK : ";
      show_mask(std::get<1>(control));
    }
  }

  /// Returns true when sequence `index` has no control left to wait on.
  /// @throws std::runtime_error if `index` is not a valid sequence index.
  bool is_finished(int index) const {
    require_valid_index(index);
    const int current = pos[static_cast<std::size_t>(index)];
    // A negative position is reported as finished, as it was when the
    // comparison was done after an implicit conversion to unsigned.
    return current < 0 || static_cast<std::size_t>(current) >= controls.size();
  }

  /// Returns true when every sequence is finished (also true before start()).
  bool all_finished() const {
    const int count = static_cast<int>(pos.size());
    for (int i = 0; i < count; ++i) {
      if (!is_finished(i)) {
        return false;
      }
    }
    return true;
  }

  /// Returns true when `value` fires the control sequence `index` is waiting
  /// on. Finished sequences never fire. A `value` outside the trigger vector
  /// (negative or too large) does not fire either; it used to be read out of
  /// bounds.
  /// @throws std::runtime_error if `index` is not a valid sequence index.
  bool check_trigger(int index, int value) const {
    if (is_finished(index)) {
      return false; // no more triggers
    }
    const std::vector<int>& trigger =
        std::get<0>(controls[static_cast<std::size_t>(pos[static_cast<std::size_t>(index)])]);
    if (value < 0 || static_cast<std::size_t>(value) >= trigger.size()) {
      return false;
    }
    return trigger[static_cast<std::size_t>(value)] == 1;
  }

  /// Returns a copy of the mask of the control sequence `index` is waiting on.
  /// @throws std::runtime_error if `index` is invalid or the sequence is
  ///         already finished (there is no current control to read).
  std::vector<int> get_mask(int index) const {
    if (is_finished(index)) {
      throw std::runtime_error("CONTROL ALREADY FINISHED");
    }
    return std::get<1>(controls[static_cast<std::size_t>(pos[static_cast<std::size_t>(index)])]);
  }

  /// Moves sequence `index` on to its next control.
  /// @throws std::runtime_error if `index` is not a valid sequence index.
  void increment(int index) {
    require_valid_index(index);
    ++pos[static_cast<std::size_t>(index)];
  }

  std::vector<std::tuple<std::vector<int>, std::vector<int>>> controls;
  std::vector<int> pos;

private:
  /// Throws unless `index` addresses an entry of `pos`.
  void require_valid_index(int index) const {
    if (index < 0 || static_cast<std::size_t>(index) >= pos.size()) {
      throw std::runtime_error("CONTROL INDEX OUT OF RANGE");
    }
  }
};


/**
 * Runs one sampling step for the whole batch.
 *
 * Calls the model with `inputs`, takes the logits at index -1 of the second
 * dimension of the first output, applies the active control mask to every
 * sequence whose last token fires its current trigger, samples one token per
 * row, and prepares `inputs` for the next call.
 *
 * @param ctrl        per-sequence control tracker; is_finished() throws if it
 *                    tracks fewer sequences than `seqs` holds
 * @param seqs        token sequences, one per batch row; the sampled token is
 *                    appended to every sequence that is not finished
 * @param temperature divisor applied to the logits before the softmax
 * @param model       scripted module; its forward() must return a tuple whose
 *                    first element is a tensor
 * @param inputs      in: arguments for forward(); out: replaced by the sampled
 *                    tokens followed by the second element of the model output
 */
void sample_inner(ControlWrap *ctrl, vector<vector<int>> &seqs, float temperature, torch::jit::Module *model, std::vector<torch::jit::IValue> &inputs) {
  auto step_out = model->forward(inputs).toTuple();
  auto logits = step_out->elements()[0].toTensor().index({torch::indexing::Slice(), -1, torch::indexing::Slice()});
  // Second output is fed back unchanged on the next call
  // (presumably the attention cache - confirm against the exported model).
  auto carried_state = step_out->elements()[1];

  // Logit value written for every entry the active mask marks with 0.
  const float suppressed = -1 * std::numeric_limits<float>::max();
  const int batch = static_cast<int>(seqs.size());

  for (int row = 0; row < batch; ++row) {
    const vector<int> &seq = seqs[row];
    // Only sequences whose most recent token fires their current trigger are masked.
    if (seq.empty() || !ctrl->check_trigger(row, seq.back())) {
      continue;
    }

    const vector<int> allowed = ctrl->get_mask(row);
    const int width = static_cast<int>(allowed.size());
    for (int col = 0; col < width; ++col) {
      if (allowed[col] != 0) {
        continue;
      }
      logits[row][col] = suppressed;
    }

    // The control has been consumed; this row now waits on the next one.
    ctrl->increment(row);
  }

  auto scaled = logits / temperature;
  auto distribution = scaled.softmax(1);
  auto next_tokens = distribution.multinomial(1);

  // Next call receives only the freshly sampled tokens plus the carried state.
  inputs.clear();
  inputs.push_back(next_tokens);
  inputs.push_back(carried_state);

  // NOTE(review): a row whose last control fired above is already finished
  // here, so the token sampled for it in this step is not recorded.
  for (int row = 0; row < batch; ++row) {
    if (ctrl->is_finished(row)) {
      continue;
    }
    seqs[row].push_back(next_tokens[row][0].item<int64_t>());
  }
}

/**
 * Loads a TorchScript model and samples `batch_size` sequences from it,
 * starting from `prompt` and steered by `control`.
 *
 * @param ckpt_path   path passed to torch::jit::load
 * @param prompt      token ids used as the start of every sequence
 * @param control     ordered (trigger, mask) pairs handed to ControlWrap
 * @param batch_size  number of sequences sampled in parallel
 * @param temperature forwarded to sample_inner
 *
 * If the model cannot be loaded, a message is written to stderr and the
 * function returns without sampling.
 *
 * NOTE(review): the generated sequences are local to this function and are
 * discarded when it returns.
 * NOTE(review): the sampling loop has no step limit; it ends only once every
 * sequence has gone through all controls.
 */
void sample_from_model(string &ckpt_path, vector<int> &prompt, vector<tuple<vector<int>,vector<int>>> control, int batch_size, float temperature) {

  // `control` is already a private copy (taken by value), so hand it over
  // instead of deep-copying every trigger/mask vector again.
  ControlWrap ctrl(std::move(control));

  torch::jit::Module model;
  try {
    model = torch::jit::load(ckpt_path);
  }
  catch (const c10::Error&) {
    std::cerr << "error loading the model\n";
    return;
  }

  // prepare inputs: the prompt repeated on every batch row
  const int prompt_len = static_cast<int>(prompt.size());
  std::vector<torch::jit::IValue> inputs;
  auto opts = torch::TensorOptions().dtype(torch::kInt64);
  torch::Tensor x = torch::zeros({batch_size, prompt_len}, opts);
  for (int k = 0; k < batch_size; ++k) {
    for (int i = 0; i < prompt_len; ++i) {
      x[k][i] = prompt[i];
    }
  }
  inputs.push_back( x );

  // create empty state: one tensor per layer with a zero-length fourth dimension
  std::vector<torch::jit::IValue> state;
  state.reserve(NUM_LAYERS);
  for (int i = 0; i < NUM_LAYERS; ++i) {
    state.push_back(torch::zeros({2, batch_size, 8, 0, 64}));
  }
  inputs.push_back( torch::ivalue::Tuple::create(state) );

  // every sequence starts as a copy of the prompt
  std::vector<std::vector<int>> seqs(batch_size, prompt);

  ctrl.start(batch_size); // initialize the control
  while (!ctrl.all_finished()) {
    sample_inner(&ctrl, seqs, temperature, &model, inputs);
  }
}