/*---------------------------------------------------------------------------*/
// Baseline Wavelet Transform Coder Construction Kit
//
// Geoff Davis
// gdavis@cs.dartmouth.edu
// http://www.cs.dartmouth.edu/~gdavis
//
// Copyright 1996 Geoff Davis 9/11/96
//
// Permission is granted to use this software for research purposes as
// long as this notice stays attached to this software.
//
/*---------------------------------------------------------------------------*/
#include <stdio.h>
#include <iostream.h>
#include <math.h>
#include "global.h"
#include "entropy.h"
/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/
// Base coder constructor: remembers the histogram capacity that derived
// coders hand on to every iHistogram they allocate.
EntropyCoder::EntropyCoder (int histoCapacity)
  : histoCapacity (histoCapacity)
{
}

/*---------------------------------------------------------------------------*/
// Single-model coder base: only forwards the histogram capacity to
// EntropyCoder.
MonoLayerCoder::MonoLayerCoder (int histoCapacity)
  : EntropyCoder (histoCapacity)
{
}

/*---------------------------------------------------------------------------*/
// Layered coder base: only forwards the histogram capacity to
// EntropyCoder.
MultiLayerCoder::MultiLayerCoder (int histoCapacity)
  : EntropyCoder (histoCapacity)
{
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/
// Construct an escape coder with no alphabet yet; setNSym() must be called
// before the coder is used.
EscapeCoder::EscapeCoder (int histoCapacity)
  : MonoLayerCoder (histoCapacity)
{
  // Null owned pointers tell the destructor and setNSym() that nothing has
  // been allocated yet.
  seen    = NULL;
  uniform = NULL;
  freq    = NULL;
}

/*---------------------------------------------------------------------------*/

// Construct an escape coder for an alphabet of nSym symbols (0..nSym-1).
EscapeCoder::EscapeCoder (int histoCapacity, int nSym)
  : MonoLayerCoder (histoCapacity)
{
  // setNSym() frees the old tables only when seen is non-null, so mark the
  // tables as "not yet allocated" before delegating to it.
  seen    = NULL;
  freq    = NULL;
  uniform = NULL;
  setNSym (nSym);
}

/*---------------------------------------------------------------------------*/
// Release the models allocated by setNSym().
EscapeCoder::~EscapeCoder ()
{
  // seen, freq and uniform are allocated together in setNSym(); a null seen
  // means none of them exist.
  if (seen == NULL)
    return;

  delete [] seen;
  delete freq;
  delete uniform;
}

/*---------------------------------------------------------------------------*/

// (Re)build the coder for an alphabet of newNSym symbols: allocate the
// adaptive model (with an extra escape slot), the uniform fallback model and
// the per-symbol "seen" flags, then reset all of them.
void EscapeCoder::setNSym (int newNSym)
{
  // Drop the tables belonging to any previous alphabet.
  if (seen != NULL) {
    delete [] seen;
    delete freq;
    delete uniform;
  }

  nSym = newNSym;

  // Index nSym of freq is reserved for the escape code.
  freq    = new iHistogram (nSym + 1, histoCapacity);
  uniform = new iHistogram (nSym, histoCapacity);

  seen = new char [nSym];
  int s = 0;
  while (s < nSym)
    seen[s++] = FALSE;

  reset ();
}

/*---------------------------------------------------------------------------*/
// write symbol to encoder, update histogram if update flag set, return cost
// Encode one symbol (if encoder is non-null) and return its cost in bits.
// A symbol not yet seen is sent as the escape code under the adaptive model
// followed by the symbol under the uniform model. When update is set, the
// symbol is marked seen and its adaptive count is incremented.
Real EscapeCoder::write (Encoder *encoder, int symbol, char update, 
			   int context1, int context2)
{
  const int escapeSym = nSym;
  Real bits;

  // This coder ignores contexts; the assignment only silences warnings.
  context1 = context2 = -1;

  if (!seen[symbol]) {
    // First occurrence: escape, then the symbol itself under the flat model.
    if (encoder != NULL) {
      encoder->writeSymbol (escapeSym, freq);
      encoder->writeSymbol (symbol, uniform);
    }
    bits = freq->Entropy(escapeSym) + uniform->Entropy(symbol);

    if (update)
      seen[symbol] = TRUE;
  } else {
    // Known symbol: code it directly with the adaptive model.
    bits = freq->Entropy(symbol);
    if (encoder != NULL)
      encoder->writeSymbol (symbol, freq);
  }

  if (update)
    freq->IncCount(symbol, FALSE);

  return bits;
}

/*---------------------------------------------------------------------------*/
// read symbol from decoder, update histogram if update flag set
// Decode one symbol: read from the adaptive model and, if the escape code
// comes out, read the actual symbol from the uniform model. When update is
// set, the symbol's adaptive count is incremented.
// NOTE(review): unlike write()/cost(), this does not touch seen[], and the
// `missing` flag is ignored.
int EscapeCoder::read (Decoder *decoder, char update, 
		       int context1, int context2, char missing)
{
  const int escapeSym = nSym;

  context1 = context2 = -1;  // unused; silences compiler warnings

  // The adaptive model covers symbols 0..nSym-1 plus the escape code nSym.
  int symbol = decoder->readSymbol (freq);
  assert (symbol >= 0 && symbol <= nSym);

  if (symbol == escapeSym) {
    symbol = decoder->readSymbol (uniform);
    assert (symbol >= 0 && symbol < nSym);
  }

  if (update)
    freq->IncCount(symbol, FALSE);

  return symbol;
}

/*---------------------------------------------------------------------------*/
// Return the number of bits write() would spend on symbol, without coding
// anything. When update is set, the model state changes exactly as in
// write(): the symbol is marked seen and its adaptive count is incremented.
Real EscapeCoder::cost (int symbol, char update, int context1, int context2)
{
  const int escapeSym = nSym;

  context1 = context2 = -1;  // unused parameters; silences warnings

  const char known = seen[symbol];
  Real bits = known ? freq->Entropy(symbol)
                    : freq->Entropy(escapeSym) + uniform->Entropy(symbol);

  if (update) {
    if (!known)
      seen[symbol] = TRUE;
    freq->IncCount(symbol, FALSE);
  }

  return bits;
}

/*---------------------------------------------------------------------------*/
void EscapeCoder::reset (int lowestResetLayer)
{
  int *temp = new int [nSym+1];
  int i;

  for (i = 0; i < nSym; i++) // no symbols observed -- set all counts to 0
    temp[i] = 0;
  temp[nSym] = 1;                // except for escape symbol
  freq->InitCounts (temp);

  for (i = 0; i < nSym; i++) // uniform histogram -- set all counts to 1
    temp[i] = 1;
  uniform->InitCounts (temp);

  delete [] temp;
}

/*---------------------------------------------------------------------------*/

// ---------------------------- Added -------------------------------------
// Header serialization is not implemented for EscapeCoder: nothing is
// written. (A disabled draft here serialized per-layer histogram arrays,
// which this single-model coder does not have.)
void EscapeCoder::writeHead (Encoder *encoder, int precision)
{
  (void) encoder;
  (void) precision;
}

/*---------------------------------------------------------------------------*/

// Header deserialization is not implemented for EscapeCoder: nothing is
// read, so this stays in step with writeHead(). (A disabled draft here read
// per-layer histogram arrays that this class does not have.)
void EscapeCoder::readHead (Decoder *decoder, int &precision, int subband)
{
  (void) decoder;
  (void) precision;
  (void) subband;
}


// No per-subband template state exists for EscapeCoder; intentionally a
// no-op.
void EscapeCoder::reinitialize (int subband)
{
  (void) subband;
}
// ------------------ End of added --------------------------------------

/*---------------------------------------------------------------------------*/
// Build one 3-symbol histogram per (layer, context) pair.
// Context counts per layer: layer 0 has 3 contexts for a signed coder
// (context values -1, 0, +1) and 2 otherwise. Each subsequent layer has
// twice as many as the previous one, plus one more for a signed coder.
LayerCoder::LayerCoder (int nLayers, int signedSym, int capacity) :
  MultiLayerCoder (capacity), signedSym(signedSym), nLayers(nLayers)
{
  nFreq = new int [nLayers];
  freq = new iHistogram** [nLayers];

  // Layer 0 is handled inside the loop, so nLayers == 0 never touches
  // nFreq[0] / freq[0] (the previous code wrote them unconditionally).
  for (int i = 0; i < nLayers; i++) {
    if (i == 0)
      nFreq[i] = signedSym ? 3 : 2;
    else
      nFreq[i] = 2 * nFreq[i-1] + (signedSym ? 1 : 0);

    freq[i] = new iHistogram* [nFreq[i]];
    // Every context codes the shifted symbols 0/1/2 (incoming -1/0/+1).
    for (int j = 0; j < nFreq[i]; j++)
      freq[i][j] = new iHistogram (3, capacity);
  }

  reset (0);
  // ----------------------- Added --------------------------------------
  // Fallback histogram slot used by read() when data is missing. Slot 1 is
  // what read() returns as symbol 0 (it subtracts 1).
  previousSymbol = 1;
  // ------------------ End of added ------------------------------------
}

/*---------------------------------------------------------------------------*/
// Free every context histogram, the per-layer context tables and the
// top-level arrays allocated by the constructor.
LayerCoder::~LayerCoder ()
{
  for (int layer = 0; layer < nLayers; layer++) {
    iHistogram **contexts = freq[layer];
    for (int ctx = 0; ctx < nFreq[layer]; ctx++)
      delete contexts[ctx];
    delete [] contexts;
  }
  delete [] freq;
  delete [] nFreq;
}

/*---------------------------------------------------------------------------*/
// for signed symbol coder, symbols will be
//   -1, 0, 1  for context = 0
//    0, 1     for context > 0
//    -1, 0    for context < 0
//   1 will be added to all incoming symbols
// for unsigned coder, symbols will be
//    0, 1  -- 1 will be added to all incoming symbols

void LayerCoder::reset (int lowestResetLayer)
{
  int plusCounts[3] =  {0, 1, 1};
  int zeroCounts[3] =  {1, 1, 1};
  int i, j;

  if (signedSym) {
    for (i = lowestResetLayer; i < nLayers; i++) {
      for (j = -nFreq[i]/2; j < 0; j++)
	freq[i][j+nFreq[i]/2]->InitCounts (plusCounts);
	//	freq[i][j+nFreq[i]/2]->InitCounts (minus_counts);
      freq[i][nFreq[i]/2]->InitCounts (zeroCounts);
      for (j = 1; j <= nFreq[i]/2; j++)
	freq[i][j+nFreq[i]/2]->InitCounts (plusCounts);
    }
  } else {
    for (i = lowestResetLayer; i < nLayers; i++) {
      for (j = 0; j < nFreq[i]; j++) {
	freq[i][j]->InitCounts (plusCounts);
      }
    }
  }
}

// --------------------------- Added ---------------------------------------
void LayerCoder::reinitialize (int subband)
{
  int i, j, k;


  if (subband == 0) 
  {
    for (i = 2; i < nLayers; i++) 
      for (j = 0; j < nFreq[i]; j++)
        for (k = 0; k < 3; k++)
	{
	  freq[i][j]->writeCount (cnt[j%4][k], k);
	  freq[i][j]->writeSymToPos (stp[j%4][k], k);
	  freq[i][j]->writePosToSym (pts[j%4][k], k);
	}
    
  } 
  else
  { 
    for (i = 0; i < nLayers; i++) 
      for (j = 0; j < nFreq[i]; j++) 
        for (k = 0; k < 3; k++)
	{
	  if (j%2)
	  {
	    freq[i][j]->writeCount (cnt[1][k], k);
	    freq[i][j]->writeSymToPos (stp[1][k], k);
	    freq[i][j]->writePosToSym (pts[1][k], k);
	  }
	  else
	  {
	    freq[i][j]->writeCount (cnt[j%4][k], k);
	    freq[i][j]->writeSymToPos (stp[j%4][k], k);
	    freq[i][j]->writePosToSym (pts[j%4][k], k);
	  }
        }
    
  }
}
  

// ------------------------- End of added ----------------------------------

/*---------------------------------------------------------------------------*/
// write symbol to encoder, update histogram if update flag set, return cost
// Encode symbol (-1/0/+1) in the given layer and context (if encoder is
// non-null) and return its cost in bits. When update is set, the
// histogram's count for the symbol is incremented.
Real LayerCoder::write (Encoder *encoder, int symbol, char update, 
			  int layer, int context)
{
  const int slot = symbol + 1;   // histogram slots start at 0
  // Signed contexts are centred: context 0 lives at index nFreq/2.
  const int ctx = signedSym ? context + nFreq[layer]/2 : context;
  iHistogram *h = freq[layer][ctx];

  if (encoder != NULL)
    encoder->writeSymbol (slot, h);
  Real bits = h->Entropy(slot);

  if (update)
    h->IncCount(slot, FALSE);
  return bits;
}

/*---------------------------------------------------------------------------*/
// read symbol from decoder, update histogram if update flag set
// Decode one symbol for the given layer and context and return it as
// -1/0/+1. When missing is set, nothing is read from the stream and the
// previously decoded histogram slot is reused. When update is set, the
// histogram count is incremented (with missing passed through).
int LayerCoder::read  (Decoder *decoder, char update, int layer, 
                       int context, char missing)
{
  const int ctx = signedSym ? context + nFreq[layer]/2 : context;
  iHistogram *h = freq[layer][ctx];

  // Histogram slot 0, 1 or 2.
  const int slot = missing ? previousSymbol : decoder->readSymbol (h);

  if (update)
    h->IncCount(slot, missing);

  previousSymbol = slot;
  return slot - 1;   // undo the +1 shift used by write()
}

/*---------------------------------------------------------------------------*/
// Return the bits write() would spend on symbol (-1/0/+1) without coding
// anything. When update is set, the histogram is updated as in write().
Real LayerCoder::cost (int symbol, char update, int layer, int context)
{
  const int slot = symbol + 1;
  const int ctx = signedSym ? context + nFreq[layer]/2 : context;
  iHistogram *h = freq[layer][ctx];

  Real bits = h->Entropy(slot);
  if (update)
    h->IncCount(slot, FALSE);
  return bits;
}
/*---------------------------------------------------------------------------*/


// ---------------------------- Added -------------------------------------
// Write the three counts of every context histogram in the first
// `precision` layers, each via encoder->writeNonneg().
// Only counts are written; symToPos/posToSym are not serialized. An
// earlier variant that also wrote them as 2-bit fields was disabled.
void LayerCoder::writeHead (Encoder *encoder, int precision)
{
  // nFreq/freq have only nLayers entries; clamp so an oversized precision
  // cannot read past them.
  const int layers = (precision < nLayers) ? precision : nLayers;

  for (int i = 0; i < layers; i++)
    for (int j = 0; j < nFreq[i]; j++)
      // Each context histogram has exactly 3 symbols (see the constructor).
      for (int k = 0; k < 3; k++)
        encoder->writeNonneg(freq[i][j]->readCount(k));
}

/*---------------------------------------------------------------------------*/

// Read the three counts of every context histogram in the first
// `precision` layers (mirroring writeHead), store them with writeCount(),
// then call Initialize() on the histogram.
// NOTE(review): Initialize() presumably rebuilds the histogram's derived
// tables from the new counts; confirm in the iHistogram implementation.
// subband is currently unused.
void LayerCoder::readHead (Decoder *decoder, int &precision, int subband)
{
  (void) subband;

  // nFreq/freq have only nLayers entries; clamp so an oversized precision
  // cannot index past them.
  const int layers = (precision < nLayers) ? precision : nLayers;

  for (int i = 0; i < layers; i++) {
    for (int j = 0; j < nFreq[i]; j++) {
      iHistogram *h = freq[i][j];
      for (int k = 0; k < 3; k++)
        h->writeCount(decoder->readNonneg (), k);
      h->Initialize();
    }
  }
}

int LayerCoder::getCnt(int lyr, int cntxt, int num)
{ 
  return freq[lyr][cntxt]->readCount(num); 
}

int LayerCoder::getStp(int lyr, int cntxt, int num)
{ 
  return freq[lyr][cntxt]->readSymToPos(num); 
}

int LayerCoder::getPts(int lyr, int cntxt, int num)
{ 
  return freq[lyr][cntxt]->readPosToSym(num); 
}

void LayerCoder::setCnt(int cntxt, int pos, int value)
{ 
  cnt[cntxt][pos] = value; 
}

void LayerCoder::setStp(int cntxt, int pos, int value)
{ 
  stp[cntxt][pos] = value; 
}

void LayerCoder::setPts(int cntxt, int pos, int value)
{ 
  pts[cntxt][pos] = value; 
}




// ------------------ End of added --------------------------------------

/*---------------------------------------------------------------------------*/
