#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include "structN.h"
#include "memoryN.h"

/*#define BUFCHECK*/
extern long int totalY[5], totalU[5], totalV[5], totalmap[5], totalMV[5]; /*initiated in initN.c*/
extern float avgpsnr[3][16], avgvar[3][16];
extern long int *unconnectedL;

void read_frame(YUVimage_ptr frame, videoinfo info,  char *inname, int index, enum FORMAT format);

/****************************************************************************/
/*                               check_buffer()                             */
/****************************************************************************/
/*
 * Update the encoder buffer model after one coding unit and log its state.
 *
 *   goal    - target number of bits for the unit (logged as Rt)
 *   channel - number of bits drained by the channel (logged as Rc)
 *   info    - reads info->generated (Rg), info->buff_max, info->verbose and
 *             info->statname; updates info->diff and info->buff_full
 *
 * info->diff accumulates (goal - generated), i.e. the running difference
 * between target and generated bits; info->buff_full accumulates
 * (generated - channel), the current encoder buffer fullness.
 *
 * The state (underflow / overflow / normal) is printed to stdout when
 * info->verbose is set and appended to info->statname. When BUFCHECK is
 * defined, the fullness is afterwards clipped back into [0, buff_max] and the
 * clipped amount is folded into info->diff. If the stat file cannot be opened
 * an error is printed, the buffer model is still updated, and only the
 * file output is skipped.
 */
void check_buffer(int goal, int channel, videoinfo_ptr info)
{
  FILE       *fpstat;
  const char *state;

  fpstat = fopen(info->statname, "at+");
  if(!fpstat){
    printf("can not open %s\n", info->statname);
  }

  info->diff += goal - info->generated;
  info->buff_full += info->generated - channel;

  if(info->buff_full < 0)
    state = "underflow";
  else if(info->buff_full > info->buff_max)
    state = "overflow";
  else
    state = "normal";

  /* The state is reported with the unclipped fullness. */
  if(info->verbose){
    printf("<buffer> state: %s %d Bmax %d\n",
           state, info->buff_full, info->buff_max);
  }
  if(fpstat){
    fprintf(fpstat, "\n <buffer> state: %s %d Bmax %d\n",
            state, info->buff_full, info->buff_max);
  }

#ifdef BUFCHECK
  if(info->buff_full < 0){                     /* underflow */
    info->diff += info->buff_full;             /* fold the deficit into diff */
    info->buff_full = 0;
  }
  else if(info->buff_full > info->buff_max){   /* overflow */
    info->diff += info->buff_full - info->buff_max; /* fold the excess into diff */
    info->buff_full = info->buff_max;
  }
#endif

  if(info->verbose){
    printf("diff = %d Rt = %d Rg = %d B = %d Rg = %d Rc = %d\n", 
           info->diff, goal, info->generated, info->buff_full, info->generated,
           channel);
  }
  if(fpstat){
    fprintf(fpstat, "diff = %d Rt = %d Rg = %d\n", 
            info->diff, goal, info->generated);
    fprintf(fpstat, "B = %d Rg = %d Rc = %d\n", 
            info->buff_full, info->generated, channel);
    fclose(fpstat);
  }
}


/*
 * Accumulate and log the bit allocation of one GOP, then update the buffer
 * model through check_buffer().
 *
 *   alloc - per-frame bit counts (ybit, ubit, vbit, map, mv) and per-frame
 *           variances (yvar, uvar, vvar) for the GOP
 *   info  - reads GOPsz, tPyrLev, statname; info->generated is set to the
 *           total number of bits of the numfr frames
 *   numfr - number of frames whose statistics are logged and summed
 *
 * The per-level totals (totalY/U/V/map/MV) are accumulated as follows. Level i
 * (0-based) takes the upper half of the current span, which starts at GOPsz/2
 * and halves per level. The Y/U/V bits of frame 0 are added at index tPyrLev.
 * Those global arrays have 5 entries, so the totals are only updated when
 * 0 <= tPyrLev < 5.
 *
 * Returns 0.
 */
int pstat(Rate_ptr alloc, videoinfo_ptr info, int numfr)
{
  FILE   *fp;
  int    i, j, sum, Tsum;
  int    *ybit, *ubit, *vbit, *map, *mv;
  float  *yvar, *uvar, *vvar;
  int    half_len, goal, channel;
  const int nlev = (int)(sizeof totalY / sizeof totalY[0]);

  ybit  = alloc->ybit;  ubit  = alloc->ubit;  vbit  = alloc->vbit;
  yvar  = alloc->yvar;  uvar  = alloc->uvar;  vvar  = alloc->vvar;
  map = alloc->map; mv = alloc->mv;

  fp = fopen(info->statname, "at+");
  if(!fp){
    printf("can not open %s\n", info->statname);
  }
  else{
    fprintf(fp, "\n <coding>\n");
  }

  if(info->tPyrLev >= 0 && info->tPyrLev < nlev){
    half_len = info->GOPsz / 2;
    for(i = 0; i < info->tPyrLev; i++){
      for(j = 0; j < half_len; j++){
        totalY[i]   += ybit[half_len+j];
        totalU[i]   += ubit[half_len+j];
        totalV[i]   += vbit[half_len+j];
        totalmap[i] += map[half_len+j];
        totalMV[i]  += mv[half_len+j];
      }
      half_len /= 2;
    }

    totalY[info->tPyrLev] += ybit[0];
    totalU[info->tPyrLev] += ubit[0];
    totalV[info->tPyrLev] += vbit[0];
  }
  else{
    printf("pstat: tPyrLev %d out of range for level totals (0..%d)\n",
           info->tPyrLev, nlev - 1);
  }

  Tsum = 0;
  for(i = 0; i < numfr; i++){
    sum = ybit[i] + ubit[i] + vbit[i] + map[i] + mv[i];
    Tsum += sum;

    if(fp){
      fprintf(fp, "FR%d ybit = %d\t", i, ybit[i]);
      fprintf(fp, "map = %d\t mv = %d\t sum = %d\t", map[i], mv[i], sum);
      fprintf(fp, "yvar = %.1f\t uvar = %.1f\t vvar = %.1f\n",
              yvar[i], uvar[i], vvar[i]);
    }
  }

  if(fp){
    fclose(fp);
  }

  /* Buffer. The rate targets used to be
   *   goal = info->GOPbit;  channel = numfr * info->prate + 0.5;
   * but that computation is disabled. Previously goal/channel were then passed
   * to check_buffer() uninitialized (undefined behavior); they are now
   * explicitly zero. TODO: restore real targets if buffer tracking is needed. */
  goal = 0;
  channel = 0;
  info->generated = Tsum;
  check_buffer(goal, channel, info);

  return(0);
}




/****************************************************************************/
/*                                 snr_frame()                              */
/****************************************************************************/
/*
 * Per-plane PSNR (in dB) of a coded frame against its original.
 *
 *   ysnr, usnr, vsnr - outputs: PSNR of the Y, U and V planes
 *   codeframe        - reconstructed frame
 *   inframe          - reference (original) frame
 *   info             - plane sizes (ywidth*yheight luma, cwidth*cheight chroma)
 *                      and the sample format, which selects the peak value
 *
 * PSNR = 20*log10(peak / sqrt(SSE / N)). A plane with N == 0 reports 0.
 * An unknown format prints an error and terminates the program.
 */
void snr_frame(float *ysnr, float *usnr, float *vsnr, YUVimage_ptr codeframe, YUVimage_ptr inframe, videoinfo info)
{
  int    k, nluma, nchroma;
  double err, sse, peak;

  nluma   = info.ywidth * info.yheight;
  nchroma = info.cwidth * info.cheight;

  if (info.format == YUV || info.format == RAS) {
    peak = 255.;
  } else if (info.format == DPX) {
    peak = 1023.;
  } else if (info.format == FLT) {
    /* IVB 2004/4/7: peak for float data; the author was unsure whether
       255.0 would be the right value instead. */
    peak = 255.0 + 255.0;
  } else {
    printf("image format error format = %d(pstatN.c)\n", info.format);
    exit(1);
  }

  /* Luma */
  sse = 0.;
  k = 0;
  while (k < nluma) {
    err = inframe->Y[k] - codeframe->Y[k];
    sse += err * err;
    ++k;
  }
  if (nluma)
    *ysnr = 20. * log10(peak / sqrt(sse / nluma));
  else
    *ysnr = 0.;

  /* Chroma U */
  sse = 0.;
  k = 0;
  while (k < nchroma) {
    err = inframe->U[k] - codeframe->U[k];
    sse += err * err;
    ++k;
  }
  if (nchroma)
    *usnr = 20. * log10(peak / sqrt(sse / nchroma));
  else
    *usnr = 0.;

  /* Chroma V */
  sse = 0.;
  k = 0;
  while (k < nchroma) {
    err = inframe->V[k] - codeframe->V[k];
    sse += err * err;
    ++k;
  }
  if (nchroma)
    *vsnr = 20. * log10(peak / sqrt(sse / nchroma));
  else
    *vsnr = 0.;
}


/****************************************************************************/
/*                                 snr_crop()                               */
/****************************************************************************/
/*
 * Per-plane PSNR (in dB) over a fixed rectangular region of interest.
 *
 *   ysnr, usnr, vsnr - outputs: PSNR of the Y, U and V regions
 *   codeframe        - reconstructed frame
 *   inframe          - reference (original) frame
 *   info             - plane strides/sizes and the sample format (peak value)
 *
 * The luma region is hard-wired to columns [65, 65+59) and rows [0, 240).
 * The chroma region is half of it in each direction (start/2, size/2), which
 * assumes chroma subsampled by 2 horizontally and vertically (TODO confirm
 * for non-4:2:0 input). Both regions are clamped to the plane dimensions in
 * info, so small frames are not read out of bounds. Each plane is normalized
 * by the number of samples actually summed, and an empty region reports 0,
 * as in snr_frame(). An unknown format prints an error and exits.
 */
void snr_crop(float *ysnr, float *usnr, float *vsnr, YUVimage_ptr codeframe, YUVimage_ptr inframe, videoinfo info)
{
  /* Region of interest in luma coordinates. */
  const int start_x = 65;
  const int start_y = 0;
  int    width = 59, height = 240;
  int    cstart_x, cstart_y, cw, ch;
  int    i, j, pos;
  long   ypix, cpix;
  double sum, diff, peak;

  switch(info.format){
  case YUV:
  case RAS:
	  peak=255.;
	  break;
  case DPX:
	  peak=1023.;
	  break;
  /* IVB 2004/4/7 ---------------- */
  case FLT:
	  peak = 255.0 + 255.0; /* or 255.0 ??? (author unsure) */
	  break;
  /* IVB 2004/4/7 ---- end ------- */
  default:
	  printf("image format error format = %d(pstatN.c)\n", info.format);
	  exit(1);
  }

  /* Clamp the luma region to the luma plane. */
  if(width > info.ywidth - start_x)   width  = info.ywidth - start_x;
  if(height > info.yheight - start_y) height = info.yheight - start_y;
  if(width < 0)  width  = 0;
  if(height < 0) height = 0;

  /* Chroma region: half the luma region in each direction, clamped likewise. */
  cstart_x = start_x / 2;
  cstart_y = start_y / 2;
  cw = width / 2;
  ch = height / 2;
  if(cw > info.cwidth - cstart_x)  cw = info.cwidth - cstart_x;
  if(ch > info.cheight - cstart_y) ch = info.cheight - cstart_y;
  if(cw < 0) cw = 0;
  if(ch < 0) ch = 0;

  ypix = (long)width * height;
  /* Normalize chroma by the samples actually summed. The former ypix/2
     counted about twice as many, which under-reported chroma MSE. */
  cpix = (long)cw * ch;

  sum = 0.;
  for (i = 0; i < height; i++) {
    for(j = 0; j < width; j++){
      pos = (start_y + i) * info.ywidth + start_x + j;
      diff = inframe->Y[pos] - codeframe->Y[pos];
      sum += diff * diff;
    }
  }
  *ysnr = (ypix)? 20.*log10(peak/sqrt(sum/ypix)) : 0.;

  sum = 0.;
  for (i = 0; i < ch; i++) {
    for(j = 0; j < cw; j++){
      pos = (cstart_y + i) * info.cwidth + cstart_x + j;
      diff = inframe->U[pos] - codeframe->U[pos];
      sum += diff * diff;
    }
  }
  *usnr = (cpix)? 20.*log10(peak/sqrt(sum/cpix)) : 0.;

  sum = 0.;
  for (i = 0; i < ch; i++) {
    for(j = 0; j < cw; j++){
      pos = (cstart_y + i) * info.cwidth + cstart_x + j;
      diff = inframe->V[pos] - codeframe->V[pos];
      sum += diff * diff;
    }
  }
  *vsnr = (cpix)? 20.*log10(peak/sqrt(sum/cpix)) : 0.;
}



/****************************************************************************/
/*                                calsnr()                                  */
/****************************************************************************/
/*
 * Compute and log the PSNR of frames start..last (inclusive).
 *
 * For each frame index, the original is read from info.inname and the decoded
 * frame from info.decname with read_frame(), and snr_frame() compares them.
 * Per-frame Y/U/V PSNR values and their average are appended to
 * info.statname. info.coding_domain is forced to LOG on this function's copy
 * of info (passed by value, so the caller's copy is unchanged). The original
 * comment says this makes the PSNR computed in the LOG domain, presumably via
 * frame_alloc()/read_frame() (TODO confirm).
 *
 * If the stat file cannot be opened, an error is printed and nothing is
 * computed. The average line is skipped when the range is empty
 * (last < start).
 */
void calsnr(int start, int last, videoinfo info)
{
  int i, num;
  float ysnr, usnr, vsnr, mean1, mean2, mean3;
  YUVimage fr0, fr1;
  FILE *fpstat; 

  info.coding_domain = LOG; /* calculate PSNR in LOG domain */

  fpstat = fopen(info.statname, "at+");
  if(!fpstat){
    printf("can not open %s\n", info.statname);
    return;
  }

  frame_alloc(&fr0, info);
  frame_alloc(&fr1, info);
  fprintf(fpstat, "\n <psnr>\n");
  fprintf(fpstat, " ysnr      usnr    vsnr\n"); 
  printf("calculate PSNR frame %d ~ frame %d\n", start, last);

  mean1=0.; mean2=0.; mean3=0.;
  for(i=start ; i<=last ; i++){
    read_frame(&fr1, info, info.inname, i, info.format);     /* original frame */
    read_frame(&fr0, info, info.decname, i, info.format);    /* coded frame */
    snr_frame(&ysnr, &usnr, &vsnr, &fr0, &fr1, info);
    mean1 += ysnr; mean2 += usnr; mean3 += vsnr;
    fprintf(fpstat, "%.2f\t  %.2f\t  %.2f\n", ysnr, usnr, vsnr);
  }

  fprintf(fpstat, "=================================================\n");
  num = last-start+1;
  if(num > 0){  /* avoid dividing by zero for an empty range */
    fprintf(fpstat, "%6s(%03d) ysnr = %.2f usnr = %.2f vsnr = %.2f\n","avg",num,mean1/num,mean2/num,mean3/num);
  }
  fclose(fpstat);

  free_frame_interior(fr0);
  free_frame_interior(fr1);
}



/*
 * calsnr_seq()
 * calculate psnr of the whole sequence
 */
/*
 * Same as calsnr(), but uses caller-provided frame buffers instead of
 * allocating them, and does not change info.coding_domain.
 *
 *   fr0, fr1    - scratch frames: fr1 receives the original (info.inname),
 *                 fr0 the decoded frame (info.decname)
 *   start, last - inclusive frame range
 *   info        - file names, format and plane sizes
 *
 * Per-frame Y/U/V PSNR values and their average are appended to
 * info.statname. If that file cannot be opened, an error is printed and
 * nothing is computed. The average line is skipped when the range is empty.
 */
void calsnr_seq(YUVimage_ptr fr0, YUVimage_ptr fr1, int start, int last, videoinfo info)
{
  int i, num;
  float ysnr, usnr, vsnr, mean1, mean2, mean3;
  FILE *fpstat; 

  mean1=0.; mean2=0.; mean3=0.;
  fpstat = fopen(info.statname, "at+");
  if(!fpstat){
    printf("can not open %s\n", info.statname);
    return;
  }
  fprintf(fpstat, "\n <psnr>\n");
  fprintf(fpstat, " ysnr      usnr    vsnr\n"); 
  printf("start %d last %d\n", start, last);

  for(i=start ; i<=last ; i++){
    read_frame(fr1, info, info.inname, i, info.format);     /* original frame */
    read_frame(fr0, info, info.decname, i, info.format);    /* coded frame */
    snr_frame(&ysnr, &usnr, &vsnr, fr0, fr1, info);
    mean1 += ysnr; mean2 += usnr; mean3 += vsnr;

    fprintf(fpstat, "%.2f\t  %.2f\t  %.2f\n", ysnr, usnr, vsnr);
  }

  fprintf(fpstat, "=================================================\n");

  num = last-start+1;
  if(num > 0){  /* avoid dividing by zero for an empty range */
    fprintf(fpstat, "%6s(%03d) ysnr = %.2f usnr = %.2f vsnr = %.2f\n","avg",num,mean1/num,mean2/num,mean3/num);
  }
  fclose(fpstat);
}
/*
 * calsnr_frame()
 * calculate psnr of a frame
 */

/*
 * Compute and log the PSNR of one frame of the range info.start..info.last.
 *
 *   fr0, fr1 - the two frames passed to snr_frame() as (codeframe, inframe)
 *   curr     - index of this frame; the header is written when it equals
 *              info.start and the range average when it equals info.last
 *   info     - stat file name, range bounds, format and plane sizes
 *
 * Running Y/U/V sums are kept in static storage between calls. They are
 * reset when curr == info.start, so a later pass over another range does
 * not average in the previous one. This means frames must be passed
 * starting at info.start. Not reentrant or thread-safe, because of the
 * static state.
 *
 * If the stat file cannot be opened, an error is printed, and the sums are
 * still updated so that a later average stays consistent.
 */
void calsnr_frame(YUVimage_ptr fr0, YUVimage_ptr fr1, int curr, videoinfo info)
{
  int num;
  float ysnr, usnr, vsnr;
  static float mean1=0., mean2=0., mean3=0.;
  FILE *fpstat; 

  if(curr == info.start){
    mean1 = 0.; mean2 = 0.; mean3 = 0.;
  }

  fpstat = fopen(info.statname, "at+");
  if(!fpstat){
    printf("can not open %s\n", info.statname);
  }
  else if(curr == info.start){
    fprintf(fpstat, "\n <psnr>\n");
    fprintf(fpstat, " ysnr      usnr    vsnr\n"); 
  }

  snr_frame(&ysnr, &usnr, &vsnr, fr0, fr1, info);

  mean1 += ysnr; mean2 += usnr; mean3 += vsnr;

  if(!fpstat){
    return;
  }

  fprintf(fpstat, "%.2f\t  %.2f\t  %.2f\n", ysnr, usnr, vsnr);

  if(curr == info.last){
    fprintf(fpstat, "=================================================\n");
    num = info.last-info.start+1;
    if(num > 0){  /* avoid dividing by zero for an empty range */
      fprintf(fpstat, "%6s(%03d) ysnr = %.2f usnr = %.2f vsnr = %.2f\n","avg",num,mean1/num,mean2/num,mean3/num);
    }
  }
  fclose(fpstat);
}


/*
 * Write one line per GOP to info.mvstatname (the file is truncated first).
 * Each line holds mvbits[g] >> 3, i.e. the GOP's motion-vector bit count
 * expressed in bytes. Terminates the program if the file cannot be created.
 */
void print_mvbits(videoinfo info, int num_of_GOP, int *mvbits)
{
  FILE *out;
  int   g;

  out = fopen(info.mvstatname, "wb");
  if (out == NULL) {
    printf("can not open %s\n", info.mvstatname);
    exit(1);
  }

  for (g = 0; g < num_of_GOP; ++g) {
    int gop_bytes = mvbits[g] >> 3;   /* bits -> bytes */
    fprintf(out, "%d\n", gop_bytes);
  }

  fclose(out);
}


/*
 * Append end-of-run statistics to info.statname:
 *   - avgpsnr / avgvar, divided in place by the number of complete GOPs in
 *     info.start..info.last, printed per GOP position (psnr from 1, var
 *     from 0). Because the division modifies the globals, a second call
 *     divides again;
 *   - per temporal level, the count of unconnected pixels (unconnectedL) and
 *     its ratio to ywidth*yheight*frames / 2^(level);
 *   - per temporal level, the Y/U/V/MV/map bit totals accumulated by pstat(),
 *     followed by the Y/U/V totals that pstat() stores at index tPyrLev
 *     (frame 0 of each GOP).
 *
 * Indices are bounded by the declared sizes of the extern arrays (16 GOP
 * positions, 5 levels), so out-of-range GOPsz / tPyrLev values no longer
 * read past them. If the stat file cannot be opened, an error is printed and
 * nothing is changed.
 */
void print_stat(videoinfo info)
{
  int    i, j, nframes, ngop, ncol, lev;
  double level_pixels;
  FILE   *fpstat;
  const int max_col = (int)(sizeof avgpsnr[0] / sizeof avgpsnr[0][0]); /* avgvar has the same shape */
  const int max_lev = (int)(sizeof totalY / sizeof totalY[0]) - 1;

  fpstat = fopen(info.statname, "at+");
  if(!fpstat){
    printf("can not open %s\n", info.statname);
    return;
  }

  nframes = info.last - info.start + 1;
  /* Guard the integer division: GOPsz == 0 would be undefined behavior. */
  ngop = (info.GOPsz > 0) ? nframes / info.GOPsz : 0;
  ncol = (info.GOPsz < max_col) ? info.GOPsz : max_col;

  /* the psnr is between the original frame and the lowband without coding */
  fprintf(fpstat, "average psnr of low temporal band without coding:\n");
  for(j = 1; j < ncol; j++){
    if(ngop > 0){
      for(i = 0; i < 3; i++)
        avgpsnr[i][j] /= ngop;
    }
    fprintf(fpstat, " %d\t Y %.2f\t  U %.2f\t  V %.2f\n", j, avgpsnr[0][j], avgpsnr[1][j], avgpsnr[2][j]);
  }

  fprintf(fpstat, "average var of low temporal band without coding:\n");
  for(j = 0; j < ncol; j++){
    if(ngop > 0){
      for(i = 0; i < 3; i++)
        avgvar[i][j] /= ngop;
    }
    fprintf(fpstat, " %d\t Y %.2f\t  U %.2f\t  V %.2f\n", j, avgvar[0][j], avgvar[1][j], avgvar[2][j]);
  }

  for(i = 0; i < info.tPyrLev; i++){
    /* Pixel count computed in double: the int product overflowed for long
       or high-resolution sequences. */
    level_pixels = (double)info.ywidth * info.yheight * nframes / pow(2, i+1);
    fprintf(fpstat, "number of unconnected pixels in levle %d is %ld", i+1, unconnectedL[i]);
    fprintf(fpstat, "its percentage in the whole number of pixels in that level %0.4f\n",
            (level_pixels > 0.) ? unconnectedL[i] / level_pixels : 0.);
  }

  /* pstat() accumulates the low band at index tPyrLev; read it from there
     (previously hard-coded to 4), bounded by the array size. */
  lev = info.tPyrLev;
  if(lev > max_lev) lev = max_lev;
  if(lev < 0) lev = 0;

  for(i = 0; i < lev; i++){
    fprintf(fpstat, "level %d\n", i+1);
    fprintf(fpstat, "totalY = %ld\n", totalY[i]);
    fprintf(fpstat, "totalU = %ld\n", totalU[i]);
    fprintf(fpstat, "totalV = %ld\n", totalV[i]);
    fprintf(fpstat, "totalMV = %ld\n", totalMV[i]);
    fprintf(fpstat, "totalmap = %ld\n", totalmap[i]);
  }
  fprintf(fpstat, "level %d\n", lev+1);
  fprintf(fpstat, "totalY = %ld\n", totalY[lev]);
  fprintf(fpstat, "totalU = %ld\n", totalU[lev]);
  fprintf(fpstat, "totalV = %ld\n", totalV[lev]);
  fclose(fpstat);
}