#include "fft2d.h"

/*
 * Create a plan for a 2D FFT of an nx-by-ny array distributed over the
 * processes of communicator comm.
 *
 *   comm   communicator whose processes share the transform
 *   nx,ny  global array dimensions; get_mpi_data() rejects values that are
 *          not multiples of the number of processes in comm
 *   dir    transform direction, passed on to fftw_create_plan()
 *   flags  FFTW planner flags; FFTW_IN_PLACE is always added
 *
 * Returns the new plan, or NULL if any allocation, either 1D FFTW plan or
 * the MPI set-up fails.  Everything acquired up to the point of failure is
 * released again.  Release a successful plan with fft2d_mpi_destroy_plan().
 */
fft2d_mpi_plan fft2d_mpi_create_plan(MPI_Comm comm, int nx, int ny,
                                     fftw_direction dir, int flags){
fft2d_mpi_plan p;

    /* fft2d_mpi_destroy_plan() releases the plan with free(), so it has to
       be obtained from malloc(), like work and mpi_data. */
    p = (fft2d_mpi_plan) malloc(sizeof(fft2d_mpi_plan_data));
    if (p == NULL) {
       return NULL;
    }

    /* Clear every member first so that fft2d_mpi_destroy_plan() can be
       called safely on a partially constructed plan. */
    p->p_fft_x = NULL;
    p->p_fft_y = NULL;
    p->work = NULL;
    p->mpi_data = NULL;

    p->p_fft_x = fftw_create_plan(nx, dir, flags | FFTW_IN_PLACE);
    p->p_fft_y = fftw_create_plan(ny, dir, flags | FFTW_IN_PLACE);
    if (p->p_fft_x == NULL || p->p_fft_y == NULL) {
       fft2d_mpi_destroy_plan(p);
       return NULL;
    }

    /* Three rows of ny elements: fft2d_mpi() uses the first as FFTW scratch
       space, the second as send buffer and the third as receive buffer.
       The size is computed in size_t to avoid int overflow in 3*ny. */
    p->work = (fftw_complex *)malloc(3 * (size_t)ny * sizeof(fftw_complex));
    if (p->work == NULL) {
       fft2d_mpi_destroy_plan(p);
       return NULL;
    }

    p->mpi_data = get_mpi_data(nx,ny,comm);
    if (p->mpi_data == NULL) {
       fft2d_mpi_destroy_plan(p);
       return NULL;
    }
    return p;
}
    
/*
 * Allocate and fill the MPI bookkeeping data for an nx-by-ny transform
 * distributed over communicator comm.
 *
 * Stores the communicator, the rank and size of the calling process,
 * the per-process extents nx_proc = nx/numprocs and ny_proc = ny/numprocs,
 * a committed MPI datatype describing one block of ny_proc complex values,
 * and request/status arrays with 2*(numprocs-1) entries each.  The first
 * numprocs-1 entries are addressed through s_req/s_status, the second
 * numprocs-1 through r_req/r_status.
 *
 * Returns NULL (after printing a message to stderr) if nx or ny is not a
 * multiple of the number of processes or if memory cannot be allocated.
 * Nothing is leaked on failure.  Release a successful result with
 * destroy_mpi_data().
 */
fft_mpi_data get_mpi_data(int nx, int ny, MPI_Comm comm){
fft_mpi_data mpi_data;
int k;
int nreq;
size_t nalloc;

    mpi_data = (fft_mpi_data) malloc(sizeof(fft_mpi_data_struct));
    if (mpi_data == NULL) {
       fprintf(stderr,"fft2d_mpi_create_plan: could not allocate memory"
                      " for MPI data\n");
       return NULL;
    }
    mpi_data->s_req = NULL;
    mpi_data->s_status = NULL;
    mpi_data->r_req = NULL;
    mpi_data->r_status = NULL;

    mpi_data->comm = comm;
    MPI_Comm_rank(comm,&(mpi_data->myid));
    MPI_Comm_size(comm,&(mpi_data->numprocs));
    mpi_data->nx_proc=nx/mpi_data->numprocs;
    if (mpi_data->nx_proc * mpi_data->numprocs != nx) {
       fprintf(stderr,"fft2d_mpi_create_plan: nx must be a multiple of the"
                      " # of processors\n");
       free(mpi_data);
       return NULL;
    }
    mpi_data->ny_proc = ny/mpi_data->numprocs;
    if (mpi_data->ny_proc * mpi_data->numprocs != ny) {
       fprintf(stderr,"fft2d_mpi_create_plan: ny must be a multiple of the"
                      " # of processors\n");
       free(mpi_data);
       return NULL;
    }

/* define datatype of a block to be sent: ny_proc elements of type
   fftw_complex (assumed to be equivalent to 2 elements of MPI_DOUBLE each)
   stored contiguously in memory. */

    MPI_Type_contiguous(2*(mpi_data->ny_proc),MPI_DOUBLE,&mpi_data->block);
    MPI_Type_commit(&mpi_data->block);

    /* One send and one receive request per remote process.  With a single
       process this count is zero; at least one element is allocated in
       that case because malloc(0) is allowed to return NULL, which would
       otherwise be reported as an out-of-memory error. */
    nreq = 2*(mpi_data->numprocs-1);
    nalloc = (nreq > 0) ? (size_t)nreq : 1;

    mpi_data->s_req = (MPI_Request *)malloc(nalloc*sizeof(MPI_Request));
    mpi_data->s_status = (MPI_Status *)malloc(nalloc*sizeof(MPI_Status));
    if (mpi_data->s_req == NULL || mpi_data->s_status == NULL) {
       fprintf(stderr,"fft2d_mpi_create_plan: could not allocate memory"
                      " for MPI request and status variables\n");
       /* undo everything acquired so far, including the committed type */
       free(mpi_data->s_req);
       free(mpi_data->s_status);
       MPI_Type_free(&mpi_data->block);
       free(mpi_data);
       return NULL;
    }

    /* the receive requests/statuses live in the second half of the arrays */
    mpi_data->r_req = &mpi_data->s_req[mpi_data->numprocs-1];
    mpi_data->r_status = &mpi_data->s_status[mpi_data->numprocs-1];
    for (k = 0; k < nreq; k++){
       mpi_data->s_req[k] = MPI_REQUEST_NULL;
    }
    return mpi_data;
}

/*
 * Release a plan and everything it owns: the two 1D FFTW plans, the
 * work buffer and the MPI data.  Members that are still NULL are skipped,
 * and a NULL plan is ignored.
 */
void fft2d_mpi_destroy_plan(fft2d_mpi_plan p)
{
    if (p == NULL) {
        return;
    }

    if (p->p_fft_x != NULL) {
        fftw_destroy_plan(p->p_fft_x);
    }
    if (p->p_fft_y != NULL) {
        fftw_destroy_plan(p->p_fft_y);
    }

    /* free() accepts NULL, so no test is needed for the work buffer */
    free(p->work);

    if (p->mpi_data != NULL) {
        destroy_mpi_data(p->mpi_data);
    }

    free(p);
}

/*
 * Release the MPI bookkeeping data: the request array, the status array,
 * the block datatype and finally the structure itself.  A NULL argument
 * is ignored.
 */
void destroy_mpi_data(fft_mpi_data mpi_data){
   if (mpi_data == NULL) {
      return;
   }

   /* free() accepts NULL, so the arrays need no individual test */
   free(mpi_data->s_req);
   free(mpi_data->s_status);

   MPI_Type_free(&mpi_data->block);
   free(mpi_data);
}

/*
 * Local part of the distributed transpose.
 *
 * local_data holds this process's own block, and work holds the blocks
 * belonging to the other processes, each at the column offset of the
 * process it belongs to.  No MPI calls are made here; fft2d_mpi() waits
 * for all of its requests before calling this function.
 *
 * On entry both arrays are indexed as nx_proc rows of ny elements; on
 * return local_data is indexed as ny_proc rows of nx elements, where
 * nx = nx_proc*numprocs and ny = ny_proc*numprocs.
 */
void matrix_transpose_mpi(fft_mpi_data mpi_data,
                          fftw_complex *local_data, fftw_complex *work){
  const int nprocs = mpi_data->numprocs;
  const int me = mpi_data->myid;
  const int rows = mpi_data->nx_proc;
  const int cols = mpi_data->ny_proc;
  const int nx = rows * nprocs;
  const int ny = cols * nprocs;
  int r, c, step;

  if (nx != ny) {
     /* General (non-square) case.  The own block cannot be transposed in
        place, so it is first copied unchanged into the work space ... */
     const int own = me * cols;

     for (r = 0; r < rows; r++) {
        for (c = 0; c < cols; c++) {
           work[own + r*ny + c] = local_data[own + r*ny + c];
        }
     }

     /* ... and then every block, starting with the own one (step 0), is
        copied back from the work space and transposed on the way. */
     for (step = 0; step < nprocs; step++) {
        const int peer = (me + step) % nprocs;
        const int dst = peer * rows;
        const int src = peer * cols;

        for (r = 0; r < rows; r++) {
           for (c = 0; c < cols; c++) {
              local_data[dst + c*nx + r] = work[src + r*ny + c];
           }
        }
     }
     return;
  }

  /* Square case: the own block is square as well and is transposed in
     place by swapping the elements above and below its diagonal. */
  {
     const int own = me * rows;

     for (r = 0; r < rows; r++) {
        for (c = r + 1; c < rows; c++) {
           fftw_complex *upper = &local_data[own + r*nx + c];
           fftw_complex *lower = &local_data[own + c*nx + r];
           fftw_complex tmp = *upper;

           *upper = *lower;
           *lower = tmp;
        }
     }
  }

  /* The blocks of the other processes are copied back from the work
     space and transposed on the way. */
  for (step = 1; step < nprocs; step++) {
     const int off = ((me + step) % nprocs) * rows;

     for (r = 0; r < rows; r++) {
        for (c = 0; c < rows; c++) {
           local_data[off + c*nx + r] = work[off + r*nx + c];
        }
     }
  }
}

/*
 * Execute the distributed 2D FFT described by plan p.
 *
 *   p           plan from fft2d_mpi_create_plan()
 *   local_data  on entry, this process's nx_proc rows of ny elements;
 *               on return, ny_proc rows of nx elements (the result is
 *               left in transposed layout)
 *   work        caller-supplied work array; rows 0..nx_proc-1 of ny
 *               elements each are written here, so it must hold at least
 *               nx_proc*ny elements
 *
 * Prints a message to stderr and returns without doing anything if the
 * plan is incomplete or work is NULL.
 *
 * The row transforms are overlapped with communication: while row i is
 * being transformed, the blocks of row i-1 are in flight between the
 * processes.
 */
void fft2d_mpi(fft2d_mpi_plan p,fftw_complex *local_data,fftw_complex *work){

int i,nx,ny,k,is;
int nx_proc,ny_proc,myid,numprocs,id_proc,msendtag,mrecvtag;
MPI_Comm comm;
MPI_Datatype block;
fftw_complex *sendbuf;
fftw_complex *recvbuf;

   if (p == NULL || p->p_fft_x == NULL || p->p_fft_y == NULL ||
       p->mpi_data == NULL || p->work == NULL) {
      fprintf(stderr,"fft2d_mpi requires a valid plan\n");
      return;
   }
   if (work == NULL) {
      fprintf(stderr,"fft2d_mpi requires work array\n");
      return;
   }

/* First, transform second dimension, which is local to this process: */
   nx_proc = p->mpi_data->nx_proc;
   nx = p->p_fft_x->n;
   ny = p->p_fft_y->n;
   ny_proc = p->mpi_data->ny_proc;
   myid = p->mpi_data->myid;
   numprocs = p->mpi_data->numprocs;
   comm = p->mpi_data->comm;
   block = p->mpi_data->block;

   /* p->work holds three rows of ny elements: row 0 is passed to fftw()
      as its work array, row 1 is the send buffer, row 2 the receive
      buffer. */
   sendbuf = &p->work[ny];
   recvbuf = &p->work[2*ny];

   /* Set up one persistent send and one persistent receive per remote
      process.  Each transfers one block of ny_proc elements located at
      the column offset of the partner process. */
   for (k = 1; k < numprocs; k++){
      id_proc = (myid + k) % numprocs;
      is = id_proc*ny_proc;
      msendtag = myid*numprocs + id_proc;
      mrecvtag = id_proc*numprocs + myid;
      MPI_Recv_init(&recvbuf[is],1,block,id_proc,mrecvtag,comm,
                    &p->mpi_data->r_req[k-1]);
      MPI_Send_init(&sendbuf[is],1,block,id_proc,msendtag,
                    comm,&p->mpi_data->s_req[k-1]);
   }

   /* The count 2*(numprocs-1) starting at s_req covers the send and the
      receive requests, because get_mpi_data() places r_req in the second
      half of the same array. */
   for (i = 0; i < nx_proc; i++){
      /* transform row i while the blocks of row i-1 are still in flight */
      fftw(p->p_fft_y, 1, &local_data[i*ny], 1, 1, p->work, 1, 0);
      /* the buffers may only be reused once the previous round is done */
      MPI_Waitall(2*(numprocs-1), p->mpi_data->s_req, p->mpi_data->s_status);
      memcpy(sendbuf, &local_data[i*ny], ny*sizeof(fftw_complex));
      for (k = 1; k < numprocs; k++){
         id_proc = (myid + k) % numprocs;
         is = id_proc*ny_proc;
         /* what has just been received belongs to row i-1 */
         if (i > 0) memcpy(&work[(i-1)*ny+is], &recvbuf[is], 
                           ny_proc*sizeof(fftw_complex));
      }
      MPI_Startall(2*(numprocs-1), p->mpi_data->s_req);
   }

   /* Complete the last round.  It carried the blocks of the last row
      processed above, which is row nx_proc-1 (not ny_proc-1: the two only
      coincide for square arrays). */
   MPI_Waitall(2*(numprocs-1), p->mpi_data->s_req, p->mpi_data->s_status);
   for (k = 1; k < numprocs; k++){
      id_proc = (myid + k) % numprocs;
      is = id_proc*ny_proc;
      memcpy(&work[(nx_proc-1)*ny+is], &recvbuf[is], 
             ny_proc*sizeof(fftw_complex));
      MPI_Request_free(&p->mpi_data->s_req[k-1]);
      MPI_Request_free(&p->mpi_data->r_req[k-1]);
   }

/* Second, transpose the first dimension with the second dimension
   to bring the x dimension local to this process: */

   matrix_transpose_mpi(p->mpi_data, local_data, work);

/* Third, transform the x dimension, which is now local and contiguous: */
   fftw(p->p_fft_x, ny_proc, local_data, 1, nx, work, 1, 0);
}
