#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "basic.h"
#include "structN.h"
#include "bme_pruneN.h"

#define opt 2


/*****************************************************************************/
/*                                write_mv_pr()                              */
/*****************************************************************************/
/* Write the (already pruned) motion field to the file "<mvname>PR" for frame
 * `index` via write_mv(). mv_prune() is called immediately after
 * block_matching() in mctfN.c, so no pruning happens here; `large` and
 * `subpel` are unused and kept only for interface compatibility.
 * Exits the program if "<mvname>PR" does not fit the name buffer. */
void write_mv_pr(mvnode_ptr mvtop, vector_ptr fmv, videoinfo info, char *mvname, int index, int large, int subpel)
{
  char ctmp[80];
  int  len;

  (void)large;
  (void)subpel;

  /* Bounded write: sprintf() would overflow ctmp on a long mvname, and a
   * silently truncated name would write to the wrong file. */
  len = snprintf(ctmp, sizeof ctmp, "%s%s", mvname, "PR");
  if (len < 0 || (size_t)len >= sizeof ctmp) {
    printf("error in write_mv_pr(): mv file name too long\n");
    exit(1);
  }
  write_mv(mvtop, fmv, info, ctmp, index);
}

/*****************************************************************************/
/*                                read_mv_pr()                               */
/*****************************************************************************/
/* Read the pruned motion field for frame `index` from the file "<mvname>PR"
 * via read_mv(). No pruning is done here; `large` and `subpel` are unused and
 * kept only for interface compatibility.
 * Exits the program if "<mvname>PR" does not fit the name buffer. */
void read_mv_pr(mvnode_ptr mvtop, vector_ptr fmv, videoinfo info, char *mvname, int index, int large, int subpel)
{
  char ctmp[80];
  int  len;

  (void)large;
  (void)subpel;

  /* Bounded write: sprintf() would overflow ctmp on a long mvname, and a
   * silently truncated name would read the wrong file. */
  len = snprintf(ctmp, sizeof ctmp, "%s%s", mvname, "PR");
  if (len < 0 || (size_t)len >= sizeof ctmp) {
    printf("error in read_mv_pr(): mv file name too long\n");
    exit(1);
  }
  read_mv(mvtop, fmv, info, ctmp, index);
}


/****************************************************************************/
/*                             fill_mv_child()                              */
/****************************************************************************/
/* Bottom-up pass over one quad-tree.
 *
 * For every internal node, estimate the whole field (fmv_root) with this
 * node's children hidden. Record the difference against the unmerged field
 * statistics in mvtop: leaf count (dL), map bits (dMAP), mv bits (dMV) and
 * added distortion (dD). Also record slope = dD / (dMAP + dMV), or HUGE_VAL
 * when no bits change.
 *
 * mslope is the minimum of the node's own slope and its children's mslope.
 * Leaves get zero deltas and HUGE_VAL slopes. */
void fill_mv_child(mvnode_ptr mvtop, vector_ptr fmv_root, vector_ptr fmv, videoinfo info, int large, int subpel)
{
  mvnode merged;        /* field statistics with this node's children hidden */
  int    leaf_gain, map_gain, mv_gain, dist_cost, bit_gain;
  float  slope;

  if (!fmv->child) {
    /* leaf node: nothing below it to merge */
    fmv->dL  = fmv->dMAP = fmv->dMV = fmv->dD = 0;
    fmv->slope = fmv->mslope = HUGE_VAL;
    return;
  }

  /* children first, so their mslope values are available below */
  fill_mv_child(mvtop, fmv_root, fmv->child0, info, large, subpel);
  fill_mv_child(mvtop, fmv_root, fmv->child1, info, large, subpel);
  fill_mv_child(mvtop, fmv_root, fmv->child2, info, large, subpel);
  fill_mv_child(mvtop, fmv_root, fmv->child3, info, large, subpel);

  /* temporarily treat this node as a leaf and re-estimate the whole field */
  fmv->child = 0;
  est_mv_bit(&merged, fmv_root, info, large, subpel);
  fmv->child = 1;

  leaf_gain = mvtop->leaf   - merged.leaf;
  map_gain  = mvtop->mapbit - merged.mapbit;
  mv_gain   = mvtop->mvbit  - merged.mvbit;
  dist_cost = merged.mvD    - mvtop->mvD;
  bit_gain  = map_gain + mv_gain;
  slope     = bit_gain ? (float)dist_cost / bit_gain : HUGE_VAL;

  if (slope < -10000) printf("small\n");   /* debug trace */

  fmv->dL     = leaf_gain;
  fmv->dMAP   = map_gain;
  fmv->dMV    = mv_gain;
  fmv->dD     = dist_cost;
  fmv->slope  = slope;
  fmv->mslope = get_min(fmv->slope,
                        fmv->child0->mslope, fmv->child1->mslope,
                        fmv->child2->mslope, fmv->child3->mslope);
}
/*****************************************************************************/
/*                              fill_mv_node()                               */
/*****************************************************************************/
/* Fill dL/dMAP/dMV/dD/slope/mslope for every node of every macroblock
 * quad-tree in fmv.
 *
 * mvtop first receives the statistics of the field as it currently stands
 * (est_mv_bit). It is the reference that each node's merge is compared
 * against in fill_mv_child(). */
void fill_mv_node(mvnode_ptr mvtop, vector_ptr fmv, videoinfo info, int large, int subpel)
{
  int row, col;

  est_mv_bit(mvtop, fmv, info, large, subpel);

  /* visit the macroblock roots in raster order */
  for (row = 0; row < info.ynum; row++)
    for (col = 0; col < info.xnum; col++)
      fill_mv_child(mvtop, fmv, &fmv[row * info.xnum + col], info, large, subpel);
}

/*****************************************************************************/
/*                              update_parent()                              */
/*****************************************************************************/
/* After a node has been pruned (mvdel holds the deltas it had), walk up its
 * chain of ancestors.
 *
 * For each ancestor whose merge flag is not NO: subtract the pruned node's
 * deltas, clamp a negative dMV to 0, and recompute slope as dD/(dMAP+dMV),
 * or HUGE_VAL when that sum is 0.
 *
 * For every ancestor: refresh mslope from its own slope and its four
 * children's mslope. */
void update_parent(mvnode mvdel, vector_ptr fmv)
{
  vector_ptr anc;

  for (anc = fmv->parent; anc; anc = anc->parent) {
    if (anc->merge != NO) {
      int bits, dist;

      anc->dL   -= mvdel.leaf;
      anc->dMAP -= mvdel.mapbit;
      anc->dMV  -= mvdel.mvbit;
      anc->dD   -= mvdel.mvD;

      if (anc->dMV < 0)
        anc->dMV = 0;

      bits = anc->dMAP + anc->dMV;
      dist = anc->dD;
      anc->slope = bits ? (float)dist / bits : HUGE_VAL;
    }

    anc->mslope = get_min(anc->slope,
                          anc->child0->mslope, anc->child1->mslope,
                          anc->child2->mslope, anc->child3->mslope);
  }
}



/*****************************************************************************/
/*                               prune_child()                               */
/*****************************************************************************/
/* Depth-first search (child order 0..3) of one quad-tree for the first
 * internal node whose slope equals mslope.
 *
 * That node's deltas are saved in *mvdel and it is reset to leaf values.
 * Its ancestors are then updated via update_parent(), its children are
 * released with free_child(), and *quit is set so every further call
 * returns immediately. */
void prune_child(float mslope, mvnode_ptr mvdel, vector_ptr fmv, int *quit)
{
  if (*quit || !fmv->child)
    return;

  if (fmv->slope != mslope) {
    /* not this node: keep searching below it */
    prune_child(mslope, mvdel, fmv->child0, quit);
    prune_child(mslope, mvdel, fmv->child1, quit);
    prune_child(mslope, mvdel, fmv->child2, quit);
    prune_child(mslope, mvdel, fmv->child3, quit);
    return;
  }

  *quit = 1;

  /* remember what pruning this node changes */
  mvdel->leaf   = fmv->dL;
  mvdel->mapbit = fmv->dMAP;
  mvdel->mvbit  = fmv->dMV;
  mvdel->mvD    = fmv->dD;
  mvdel->slope  = fmv->slope;

  /* the pruned node now carries leaf values */
  fmv->dL  = fmv->dMAP = fmv->dMV = fmv->dD = 0;
  fmv->slope = fmv->mslope = HUGE_VAL;

  update_parent(*mvdel, fmv);
  free_child(fmv);
  fmv->child = 0;   /* must be cleared after free_child(), not before */
}


/*****************************************************************************/
/*                             prune_smallest()                              */
/*****************************************************************************/
/* Scan the macroblock quad-trees in raster order and prune the first internal
 * node whose slope equals mslope (see prune_child). The scan stops once a
 * node has been pruned; if no node matches, *mvdel is left untouched. */
void prune_smallest(float mslope, mvnode_ptr mvdel, vector_ptr fmv, videoinfo info)
{
  int row, col, done = 0;

  for (row = 0; row < info.ynum && !done; row++)
    for (col = 0; col < info.xnum && !done; col++)
      prune_child(mslope, mvdel, &fmv[row * info.xnum + col], &done);
}


/*****************************************************************************/
/*                                mv_prune()                                 */
/* pruning the quad-tree structured motion vector                            */
/*  using D + lambda*H                                                       */
/* D is the prediction error                                                 */
/* H is the bit for representing the motion vectors                          */
/* no pruning for the FSBM                                                   */
/*****************************************************************************/
/*void mv_prune(vector_ptr fmv, videoinfo info, int large, int subpel, int bi_flag, int curr, int noframe)
{
	mvnode_ptr mvtop;
	mvnode mvdel;

    mvtop = (mvnode_ptr)getarray(1, sizeof(mvnode), "mvtop");  
  fill_mv_node(mvtop, fmv, info, large, subpel);

  if(info.level==1) print_mvnode(mvtop); 
  if(info.level==1) return; 

  while(mvtop->mslope < info.lambda){
    prune_smallest(mvtop->mslope, &mvdel, fmv, info);
    mvtop->leaf   -= mvdel.leaf;    mvtop->mapbit -= mvdel.mapbit; 
    mvtop->mvbit  -= mvdel.mvbit;   mvtop->mvD    += mvdel.mvD;  
    mvtop->slope  =  mvdel.slope;   mvtop->mslope = get_mslope(fmv, info);
  }
  if(info.verbose) print_mvnode(mvtop);
  write_mv_pr(mvtop, fmv, info, info.mvname, curr+(2*noframe+1)*large, large, subpel);
  free(mvtop);
}
*/
/* Prune the quad-tree structured motion field fmv by sweeping the R-D curve.
 *
 * While the smallest node slope (get_mslope) is below info.lambda, the node
 * with that slope is merged (prune_smallest). No pruning is done when
 * info.level == 1.
 *
 * `opt` selects how node statistics are filled: fill_mv_node() with a
 * temporary mvtop when opt == 1, otherwise fill_mv_node2().
 * `curr` and `noframe` are unused and kept for interface compatibility. */
void mv_prune(vector_ptr fmv, videoinfo info, int large, int subpel, int curr, int noframe)
{
  mvnode_ptr mvtop = NULL;   /* allocated only when opt == 1 */
  mvnode     mvdel;          /* deltas of the most recently pruned node */
  float      mslope;

  (void)curr;
  (void)noframe;

  if (opt == 1) {
    mvtop = (mvnode_ptr)getarray(1, sizeof(mvnode), "mvtop");
    fill_mv_node(mvtop, fmv, info, large, subpel);
    /* save the min slope at the top level */
    mslope = mvtop->mslope = get_mslope(fmv, info);
  }
  else {
    fill_mv_node2(fmv, info, large, subpel);
    mslope = get_mslope(fmv, info);
  }

  /* No pruning at level 1. This path now falls through to the single exit
   * instead of returning early, so mvtop is released here too (the early
   * return leaked it when opt == 1). */
  if (info.level != 1) {
    /* prune the node with the smallest slope while that slope is below lambda */
    while (mslope < info.lambda) {
      prune_smallest(mslope, &mvdel, fmv, info);
      mslope = get_mslope(fmv, info);
    }
  }

  free(mvtop);   /* no-op when mvtop is NULL (opt != 1) */
}







/****************************************************************************/
/*                              est_child_map()                             */
/*                           for a macroblock                               */
/****************************************************************************/
/* Count the quad-tree map bits of one macroblock subtree into *mapbit:
 *  - internal node: 1 tree-structure bit, then recurse into the four quadrants;
 *  - leaf at or beyond the frame edge (x >= hor or y >= ver): nothing;
 *  - other leaf: 1 tree-structure bit when xblk > small, plus, when
 *    adapt_flag is YES, 1 mode bit for DEFAULT mode or 2 otherwise. */
void est_child_map(vector_ptr fmv, int *mapbit, int x, int y, int xblk, int yblk, int hor, int ver, int small,enum FLAG adapt_flag)
{
  int half_w = xblk / 2;
  int half_h = yblk / 2;

  if (!fmv->child) {
    if (x >= hor || y >= ver)
      return;

    if (xblk > small)
      (*mapbit)++;                                  /* tree-structure bit */

    if (adapt_flag == YES)
      *mapbit += (fmv->mode == DEFAULT) ? 1 : 2;    /* mode bits */
    return;
  }

  (*mapbit)++;                                      /* tree-structure bit */

  est_child_map(fmv->child0, mapbit, x,          y,          half_w, half_h, hor, ver, small, adapt_flag);
  est_child_map(fmv->child1, mapbit, x + half_w, y,          half_w, half_h, hor, ver, small, adapt_flag);
  est_child_map(fmv->child2, mapbit, x,          y + half_h, half_w, half_h, hor, ver, small, adapt_flag);
  est_child_map(fmv->child3, mapbit, x + half_w, y + half_h, half_w, half_h, hor, ver, small, adapt_flag);
}

/****************************************************************************/
/*                             est_child_mv()                               */
/* estimate bits for motion vectors and distortion in a macroblock          */
/* the number of leaves (leaf) does not count intrablock                  */
/****************************************************************************/
/* Walk the leaves of one macroblock subtree (child order 0..3) and
 * accumulate the following for each leaf inside the frame:
 *  - *mvD += nint(mad * block area), with the block clipped to hor x ver;
 *  - unless it is an INTRABLOCK leaf while INTRA_CODED_BLOCK is defined:
 *    *leaf += 1, and one pmf[] count for each of the x and y differential
 *    mv symbols.
 *
 * Each difference is taken against (*pmvx, *pmvy), scaled by 2^subpel and
 * folded into [-num_symbol/2, num_symbol/2]. (*pmvx, *pmvy) is then set to
 * this leaf's vector. The program exits if a symbol is out of range. */
void est_child_mv(vector_ptr fmv, float *pmvx, float *pmvy, int num_symbol, int subpel, int x, int y,
			 int xblk, int yblk, int hor, int ver, int *pmf, int *leaf, int *mvD)
{
  int half_w, half_h, w, h, k, sym, half_range;
  int diff[2];

  if (fmv->child) {
    half_w = xblk / 2;
    half_h = yblk / 2;
    est_child_mv(fmv->child0, pmvx, pmvy, num_symbol, subpel,
                 x,          y,          half_w, half_h, hor, ver, pmf, leaf, mvD);
    est_child_mv(fmv->child1, pmvx, pmvy, num_symbol, subpel,
                 x + half_w, y,          half_w, half_h, hor, ver, pmf, leaf, mvD);
    est_child_mv(fmv->child2, pmvx, pmvy, num_symbol, subpel,
                 x,          y + half_h, half_w, half_h, hor, ver, pmf, leaf, mvD);
    est_child_mv(fmv->child3, pmvx, pmvy, num_symbol, subpel,
                 x + half_w, y + half_h, half_w, half_h, hor, ver, pmf, leaf, mvD);
    return;
  }

  if (x >= hor || y >= ver)
    return;

  /* distortion of this leaf, with the block clipped to the frame */
  w = (x + xblk <= hor) ? xblk : hor - x;
  h = (y + yblk <= ver) ? yblk : ver - y;
  (*mvD) += nint((fmv->mad) * (w * h));

#ifdef INTRA_CODED_BLOCK
  if (fmv->mode == INTRABLOCK)
    return;   /* intra blocks: distortion only, no mv symbols */
#endif

  (*leaf)++;

  diff[0] = (int)((1 << subpel) * (fmv->mvx - *pmvx));
  diff[1] = (int)((1 << subpel) * (fmv->mvy - *pmvy));

  half_range = num_symbol / 2;
  for (k = 0; k < 2; k++) {
    /* fold the difference into the symbol alphabet */
    if (diff[k] > half_range)
      diff[k] -= num_symbol;
    else if (diff[k] < -half_range)
      diff[k] += num_symbol;

    sym = diff[k] + half_range;   /* offset so negative differences index pmf */
    if (sym < 0 || sym >= num_symbol) {
      printf("error in est_child_mv()\n");
      exit(1);
    }
    pmf[sym]++;
  }

  /* this leaf's vector predicts the next one */
  *pmvx = fmv->mvx;
  *pmvy = fmv->mvy;
}


/*****************************************************************************/
/*                              est_mv_bit()                                 */
/*                            for the whole motion field                     */
/*****************************************************************************/
/*est_mv_bit(mvtop, fmv, info, large, subpel, bi_flag)
     mvnode_ptr mvtop;
     vector_ptr fmv;
     videoinfo  info;
     int        large, subpel, bi_flag;*/
/* Estimate statistics for the whole motion field as it currently stands and
 * store them in *mvtop (slope is reset to 0; mslope is not written):
 *  - mapbit: quad-tree map bits (only when info.level != 1);
 *  - mvbit:  zero-order entropy bits of the differential mv symbols (x and y
 *            per counted leaf);
 *  - leaf:   number of counted leaves;
 *  - mvD:    total distortion.
 * The mv predictor restarts at (0,0) at the beginning of each macroblock row. */
void est_mv_bit(mvnode_ptr mvtop, vector_ptr fmv, videoinfo info, int large, int subpel)
{
  int   i, x, y, X, Y, xnum, ynum, xblk, yblk, hor, ver, pos, small, lev;
  int   num_symbol, *mvpmf, full;
  int   leaf, mapbit, mvD, mvbit;   /* mvbit is an integer bit count: no float rounding */
  float pmvx, pmvy, mvbpp;

  /* miscellany */
  xnum = info.xnum;   ynum = info.ynum;
  xblk = info.xblk;   yblk = info.yblk;
  hor  = info.ywidth; ver  = info.yheight;

  /* smallest block size: xblk halved once per level above 1.
   * `lev > 1` (rather than `!= 1`) stops a non-positive level from looping forever. */
  small = xblk;
  for (lev = info.level; lev > 1; lev--)
    small /= 2;

  /* assign memory and initialize the mv symbol histogram */
  full = (info.maxx != 1) ? 1 : 0;
  num_symbol = get_numsymbol(info.level, info.maxx, large, subpel, full);
  mvpmf = (int *) getarray(num_symbol, sizeof(int), "mvpmf");
  for (i = 0; i < num_symbol; i++) mvpmf[i] = 0;

  /* estimate the size of the map */
  mapbit = 0;
  if (info.level != 1) {
    for (y = 0, Y = 0; Y < ynum; y += yblk, Y++) {
      for (x = 0, X = 0; X < xnum; x += xblk, X++) {
        pos = Y * xnum + X;
        est_child_map(&fmv[pos], &mapbit, x, y, xblk, yblk, hor, ver, small, info.adapt_flag);
      }
    }
  }

  /* histogram of mv symbols and total distortion */
  leaf = 0; mvD = 0;
  for (y = 0, Y = 0; Y < ynum; y += yblk, Y++) {
    pmvx = 0.f; pmvy = 0.f;
    for (x = 0, X = 0; X < xnum; x += xblk, X++) {
      pos = Y * xnum + X;
      est_child_mv(&fmv[pos], &pmvx, &pmvy, num_symbol, subpel,
                   x, y, xblk, yblk, hor, ver, mvpmf, &leaf, &mvD);
    }
  }

  /* mv bits: 2*leaf symbols (mvx and mvy). With no symbols there are no
   * bits; this avoids relying on entropy()'s behaviour for an empty histogram. */
  if (leaf > 0) {
    mvbpp = entropy(num_symbol, 2*leaf, mvpmf);
    mvbit = nint(mvbpp * (2*leaf));
  }
  else {
    mvbit = 0;
  }

  mvtop->leaf   = leaf;
  mvtop->mapbit = mapbit;
  mvtop->mvbit  = mvbit;
  mvtop->mvD    = mvD;
  mvtop->slope  = 0.;

  free(mvpmf);
}

/****************************************************************************/
/*                                get_mslope()                              */
/****************************************************************************/
/* Smallest mslope over all macroblock roots of the field, i.e. the slope of
 * the next node prune_smallest() would merge. Returns HUGE_VAL when the field
 * has no macroblocks or no root's mslope is below HUGE_VAL. */
float get_mslope(vector_ptr fmv, videoinfo info)
{
  float      best = HUGE_VAL;
  int        row, col;
  vector_ptr mb;

  for (row = 0; row < info.ynum; row++) {
    for (col = 0; col < info.xnum; col++) {
      mb = &fmv[row * info.xnum + col];
      if (mb->mslope < best)
        best = mb->mslope;
    }
  }
  return best;
}

/*****************************************************************************/
/*                                  get_min()                                */
/*****************************************************************************/
/* Minimum of five floats. Returns the first argument that is <= every
 * argument after it (the last argument if none qualifies); for ordinary,
 * non-NaN values this is simply the minimum. */
float get_min(float num1, float num2, float num3, float num4, float num5)
{
  const float v[5] = { num1, num2, num3, num4, num5 };
  int i, j;

  for (i = 0; i < 4; i++) {
    /* advance j while v[i] is not larger than the later candidates */
    for (j = i + 1; j < 5 && v[i] <= v[j]; j++)
      ;
    if (j == 5)
      return v[i];
  }
  return v[4];
}
/*float get_min(float num1, float num2, float num3, float num4)
{
  if     (num1<=num2 && num1<=num3 && num1<=num4) return num1;
  else if(num2<=num3 && num2<=num4)               return num2;
  else if(num3<=num4)                             return num3;
  else                                            return num4;
}*/

/*****************************************************************************/
/*                                print_mvnode()                             */
/*****************************************************************************/
/* Print the statistics held in *mvtop on one line of stdout:
 * leaf count, map bits, mv bits, distortion, slope and mslope. */
void print_mvnode(mvnode_ptr mvtop)
{
  /* adjacent literals are concatenated by the compiler into one format */
  printf("leaf = %d "
         "map = %d "
         "mv = %d "
         "mvD = %d "
         "slope = %.3f "
         "mslope = %.3f\n",
         mvtop->leaf, mvtop->mapbit, mvtop->mvbit, mvtop->mvD,
         mvtop->slope, mvtop->mslope);
}
