#include "mex.h"
#include <math.h>

/*
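Puts a lower triangular or symmetric matrix into vech form, i.e. stacks its
lower triangle column by column into a vector. If the optional second argument
equals 1, the upper triangle is read row by row instead (same result for a
symmetric matrix). For n = 4 the output position of each element is: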
     1  
     2 5 
     3 6 8
     4 7 9 10
*/

/*
 * vech: copy one triangle of an n-by-n column-major matrix A into B in vech
 * order (lower triangle stacked column by column).
 *
 *   A      - n*n doubles, column-major: element (i,j) is A[i + j*n].
 *   B      - output buffer of at least n*(n+1)/2 doubles.
 *   n      - matrix order; nothing is copied when n <= 0.
 *   UPPER  - false: copy the lower triangle, column by column, each column
 *                   starting at the diagonal.
 *            true:  copy the upper triangle, row by row, each row starting
 *                   at the diagonal; for a symmetric A this gives the same
 *                   vector as the lower-triangle mode.
 *
 * Returns B.
 *
 * In lower mode A and B may be the same buffer: the write index k never
 * exceeds the index currently being read, so no unread element is
 * overwritten. In upper mode A == B is rejected with mexErrMsgTxt.
 *
 * Indices are computed in size_t so that n*n cannot overflow int, and no
 * pointer is ever formed beyond the end of A (the previous pointer-walking
 * version stepped past one-past-the-end on its final iteration).
 */
double *vech(const double *A, double *B, int n, bool UPPER)
{
  size_t row, col, order, k = 0;

  if (UPPER && A == B) mexErrMsgTxt("Cannot overwrite with this function");
  if (n <= 0) return(B);
  order = (size_t)n;

  if (UPPER)
  {
    /* Row `row` of the upper triangle: (row,row), (row,row+1), ... */
    for (row = 0; row < order; row++)
    {
      for (col = row; col < order; col++)
      {
        B[k++] = A[row + col*order];
      }
    }
  }
  else
  {
    /* Column `col` of the lower triangle: (col,col), (col+1,col), ... */
    for (col = 0; col < order; col++)
    {
      for (row = col; row < order; row++)
      {
        B[k++] = A[row + col*order];
      }
    }
  }
  return(B);
}


/*
 * MATLAB gateway:  v = vech(A)  or  v = vech(A, upper)
 *
 *   A      - square, full, double matrix (real or complex).
 *   upper  - optional numeric or logical scalar. If it equals 1, the upper
 *            triangle of A is read instead of the lower one. Any other value,
 *            or an empty array, selects the lower triangle.
 *
 * Returns an n*(n+1)/2-by-1 double column vector (complex if A is complex).
 * Real and imaginary parts are copied independently. For complex input with
 * upper == 1, the result is therefore taken from the upper triangle exactly as
 * stored; for a Hermitian A that is the conjugate of the lower-triangle result.
 */
void mexFunction(
   int nlhs, mxArray *plhs[],
   int nrhs, const mxArray *prhs[])
{
  mwSize n, n1;
  bool upper = false;

  if (nrhs<1) mexErrMsgTxt("No input argument passed.");

  /* mxGetPr/mxGetPi expose n*n doubles only for full double arrays. Any other
     class, or a sparse matrix, would make vech read past the end of the data. */
  if (!mxIsDouble(prhs[0]) || mxIsSparse(prhs[0]))
    mexErrMsgTxt("Input must be a full double matrix");

  n=mxGetM(prhs[0]);
  if (mxGetN(prhs[0])!=n) mexErrMsgTxt("Matrix must be square");

  /* Computed in mwSize: an int product overflows for n > 46340. */
  n1=n*(n+1)/2;

  if (nrhs>1 && !mxIsEmpty(prhs[1]))
  {
    if (!mxIsNumeric(prhs[1]) && !mxIsLogical(prhs[1]))
      mexErrMsgTxt("Second argument must be numeric or logical");
    /* mxGetScalar converts any numeric/logical class to double. Dereferencing
       mxGetPr assumed double data and crashed on empty input. */
    upper = (mxGetScalar(prhs[1])==1);
  }

  /* A square array with n > INT_MAX would need more than 4.6e18 elements,
     so the narrowing cast to vech's int parameter is safe. */
  if (mxIsComplex(prhs[0]))
  {
    plhs[0]=mxCreateDoubleMatrix(n1,1,mxCOMPLEX);
    vech(mxGetPr(prhs[0]),mxGetPr(plhs[0]),(int)n,upper);
    vech(mxGetPi(prhs[0]),mxGetPi(plhs[0]),(int)n,upper);
  }
  else
  {
    plhs[0]=mxCreateDoubleMatrix(n1,1,mxREAL);
    vech(mxGetPr(prhs[0]),mxGetPr(plhs[0]),(int)n,upper);
  }
}