/*
 * Copyright (c) 1997,1998 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 */

/*
 * test_main.c: driver for test programs (linked with fftw_test.c/rfftw_test.c)
 */

/* $Id: test_main.c,v 1.13 1998/09/28 21:09:50 stevenj Exp $ */
#include <fftw-int.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <time.h>

#include "test_main.h"

#ifdef HAVE_GETOPT
#    ifdef HAVE_GETOPT_H
#        include <getopt.h>
#    elif defined(HAVE_UNISTD_H)
#        include <unistd.h>
#    endif
#endif

/* Pseudo-random double in [-0.5, 0.5], derived from rand(). */
double mydrand(void)
{
     const double scaled = (double) rand() / (double) RAND_MAX;

     return scaled - .5;
}

/*
 * Return a pseudo-random value that is either 0 or non-zero (the
 * non-zero value is 8192, i.e. bit 13 of rand()).  A higher-order bit
 * is used because the low bits of many rand() implementations are
 * less random.
 */
int coinflip(void)
{
     const int bit13 = 1 << 13;

     return rand() & bit13;
}

/*******************
 * global variables
 *******************/
/*
 * Global test settings.  main() initializes most of them and
 * handle_option() updates them from the command-line options.
 * NOTE(review): presumably also read by the linked fftw_test.c /
 * rfftw_test.c through test_main.h -- confirm there.
 */
int verbose;                    /* verbosity level: main() sets 1, each -v adds 1 */
int wisdom_flag, measure_flag;  /* wisdom_flag: FFTW_USE_WISDOM after -w, else 0;
                                   measure_flag: FFTW_ESTIMATE by default,
                                   FFTW_MEASURE after -m (correctness tests) */
int speed_flag = FFTW_MEASURE;  /* planner flag for speed tests; -e sets FFTW_ESTIMATE */
int dimensions, dimensions_specified = 0;  /* rank from -d, and whether -d was given */
int chk_mem_leak;               /* nonzero: call fftw_check_memory_leaks() after
                                   tests when wisdom is not in use */
int paranoid;                   /* set to 1 by -P (enable paranoid tests) */
int howmany_fields = 1;         /* number of fields (howmany), set by -f */
int max_iterations = 0; /* maximum number of iterations to perform
			   in "infinite" tests--default (0) means no limit */

/*******************
 * procedures
 *******************/

/*
 * Format a duration x, given in seconds, with a readable unit: "ns"
 * below 1 microsecond, "us" below 1 millisecond, "ms" below 1 second,
 * and "s" otherwise.  The result lives in a static buffer that the next
 * call overwrites, so the function is not reentrant.
 */
char *smart_sprint_time(double x)
{
     static char buf[128];
     double scale = 1.0;
     const char *unit = "s";

     if (x < 1.0E-6) {
          scale = 1.0E9;
          unit = "ns";
     } else if (x < 1.0E-3) {
          scale = 1.0E6;
          unit = "us";
     } else if (x < 1.0) {
          scale = 1.0E3;
          unit = "ms";
     }

     sprintf(buf, "%f %s", x * scale, unit);
     return buf;
}

/*
 * Greet the user with a randomly chosen "please wait" message
 * (consumes exactly one rand() value).
 * Jokes stolen from http://whereis.mit.edu/bin/map
 */
void please_wait(void)
{
     static const char *const jokes[] = {
	  "(while a large software vendor in Seattle takes over the world)",
	  "(and remember, this is faster than Java)",
	  "(and dream of faster computers)",
	  "(checking the gravitational constant in your locale)",
	  "(at least you are not on hold)",
	  "(while X11 grows by another kilobyte)",
	  "(while Windows NT reboots)",
	  "(correcting for the phase of the moon)",
     };
     const int njokes = (int) (sizeof(jokes) / sizeof(jokes[0]));
     int pick = rand() % njokes;

     if (pick < 0)    /* defensive only: rand() never returns a negative value */
	  pick = -pick;
     printf("Please wait %s.\n", jokes[pick]);
}

/*
 * Announce a long-running test.  If max_iterations is nonzero, report
 * the iteration limit and fall back to please_wait().  Otherwise warn
 * that the test does not terminate, with a random excuse.
 */
void please_wait_forever(void)
{
     static const char *const excuses[] = {
	  "(but it won't crash, either)",
	  "(at least in theory)",
	  "(please be patient)",
	  "(our next release will complete it more quickly)",
#if defined(__WIN32__) || defined(WIN32) || defined(_WINDOWS)
	  "(by the way, Linux executes infinite loops faster)",
#endif
     };
     int pick;

     if (max_iterations) {
	  printf("This test will run for %d iterations.\n", max_iterations);
	  please_wait();
	  return;
     }

     pick = rand() % (int) (sizeof(excuses) / sizeof(excuses[0]));
     printf("This test does not terminate %s.\n",
	    excuses[pick < 0 ? -pick : pick]);
}

/*************************************************
 * Speed tests
 *************************************************/

double mflops(double t, int N)
{
     return(5.0 * N * log((double) N) / (log(2.0) * t * 1.0e6));
}

/*
 * Print a cube of side n in the current number of dimensions, e.g.
 * "8x8x8" when dimensions == 3.  No trailing newline is printed.
 */
void print_dims(int n)
{
     int d;

     printf("%d", n);
     for (d = dimensions; d > 1; --d)
	  printf("x%d", n);
}
void test_speed(int n)
{
     int specific;

     please_wait();

     if (howmany_fields > 1)
	  WHEN_VERBOSE(1,printf("TIMING MULTIPLE-FIELD FFT: "
				"howmany=%d, stride=%d, dist=%d\n\n",
				howmany_fields,howmany_fields,1));

     for (specific = 0; specific <= 1; ++specific) {
	  WHEN_VERBOSE(1, 
		printf("SPEED TEST: n = %d, FFTW_FORWARD, out of place, %s\n",
		       n, SPECIFICP(specific)));
	  test_speed_aux(n, FFTW_FORWARD, 0, specific);
	  
	  WHEN_VERBOSE(1, 
            printf("SPEED TEST: n = %d, FFTW_FORWARD, in place, %s\n",
		       n, SPECIFICP(specific)));
	  test_speed_aux(n, FFTW_FORWARD, FFTW_IN_PLACE, specific);
	  
	  WHEN_VERBOSE(1, 
            printf("SPEED TEST: n = %d, FFTW_BACKWARD, out of place, %s\n",
		       n, SPECIFICP(specific)));
	  test_speed_aux(n, FFTW_BACKWARD, 0, specific);
	  
	  WHEN_VERBOSE(1, 
            printf("SPEED TEST: n = %d, FFTW_BACKWARD, in place, %s\n",
		       n, SPECIFICP(specific)));
	  test_speed_aux(n, FFTW_BACKWARD, FFTW_IN_PLACE, specific);
     }
}

void test_speed_nd(int n)
{
     int specific;

     please_wait();

     if (howmany_fields > 1)
	  WHEN_VERBOSE(1,printf("TIMING MULTIPLE-FIELD FFT: "
				"howmany=%d, stride=%d, dist=%d\n\n",
				howmany_fields,howmany_fields,1));

     for (specific = 0; specific <= 1; ++specific) {
	  printf("SPEED TEST: ");
	  WHEN_VERBOSE(1, print_dims(n));
	  WHEN_VERBOSE(1, printf(", FFTW_FORWARD, in place, %s\n",
				 SPECIFICP(specific)));
	  test_speed_nd_aux(n, FFTW_FORWARD, FFTW_IN_PLACE, specific);
	  
	  WHEN_VERBOSE(1, printf("SPEED TEST: "));
	  print_dims(n);
	  WHEN_VERBOSE(1, printf(", FFTW_BACKWARD, in place, %s\n",
				 SPECIFICP(specific)));
	  test_speed_nd_aux(n, FFTW_BACKWARD, FFTW_IN_PLACE, specific);
     }
}

/*************************************************
 * correctness tests
 *************************************************/     

/*
 * Return the largest pointwise relative error between the n complex
 * values A[k*astride] and B[k*bstride].  At each point the error is
 * |A - B| divided by the mean of |A| and |B| plus TOLERANCE (the
 * TOLERANCE term presumably keeps zero-magnitude points from dividing
 * by zero).  When HAVE_ISNAN is defined, a NaN error is reported
 * through CHECK.
 */
double compute_error_complex(fftw_complex *A, int astride, 
			     fftw_complex *B, int bstride, int n)
{
     double max_err = 0.0;
     int k;

     for (k = 0; k < n; ++k) {
	  const int ia = k * astride;
	  const int ib = k * bstride;
	  double diff, scale;

	  diff = sqrt(SQR(c_re(A[ia]) - c_re(B[ib])) +
		      SQR(c_im(A[ia]) - c_im(B[ib])));
	  scale = 0.5 * (sqrt(SQR(c_re(A[ia])) + SQR(c_im(A[ia]))) +
			 sqrt(SQR(c_re(B[ib])) + SQR(c_im(B[ib])))) + TOLERANCE;

	  diff /= scale;
	  if (diff > max_err)
	       max_err = diff;

#ifdef HAVE_ISNAN
	  CHECK(!isnan(diff), "NaN in answer");
#endif
     }
     return max_err;
}

/*
 * Run test_correctness() on n = 1, 2, 3, ...  This runs forever, or
 * stops after max_iterations sizes when that is nonzero.  After each
 * size, check for memory leaks if leak checking is on and wisdom is
 * not in use.
 */
void test_all(void)
{
     int n = 0;

     please_wait_forever();
     while (max_iterations == 0 || n < max_iterations) {
	  ++n;
	  test_correctness(n);
	  if (chk_mem_leak && !(wisdom_flag & FFTW_USE_WISDOM))
	       fftw_check_memory_leaks();
     }
}

#define MAX_FACTOR 13

/*
 * Return a random integer n with 1 <= n <= N (n is 1 when N < 1).  n is
 * a product of factors drawn uniformly from 1..MAX_FACTOR, so every
 * prime factor of n is at most MAX_FACTOR.  Factors are multiplied in
 * until the next one drawn would push the product past N.
 */
int rand_small_factors(int N)
{
     int f, n = 1;

     f = rand() % MAX_FACTOR + 1;

     /* "f <= N / n" is equivalent to "n * f <= N" for n >= 1, but it
	cannot overflow int when N is close to INT_MAX. */
     while (f <= N / n) {
	  n *= f;
	  f = rand() % MAX_FACTOR + 1;
     }
     
     return n;
}

#define MAX_N 16384

/*
 * Fill n[0..rank-1] with random sizes made of small factors (see
 * rand_small_factors).  Each size is at most the rank-th root of the
 * budget (MAX_N, or MAX_N / 4 for rank 1), so the product of the sizes
 * does not exceed the budget.  A non-positive rank leaves n untouched.
 */
void random_dims(int rank, int *n)
{
     int maxsize, dim;
     double maxsize_d;

     /* For rank == 0, 1.0 / rank is infinite and converting the
	resulting pow() value to int is undefined behavior. */
     if (rank <= 0)
	  return;

     /* Kept in a separate double variable: the original comment says
	this works around a gcc warning. */
     maxsize_d = pow((double) (rank == 1 ? MAX_N / 4 : MAX_N), 
		     1.0 / (double) rank);
     maxsize = (int) maxsize_d;

     if (maxsize < 1)
	  maxsize = 1;

     for (dim = 0; dim < rank; ++dim)
	  n[dim] = rand_small_factors(maxsize);
}

/*
 * Run test_correctness() on one random 1D size.  Every 16th call uses
 * an arbitrary size in 1..MAX_N/16.  All other calls use a size made
 * only of small factors, from random_dims().
 */
void test_random(void)
{
     static int calls = 0;
     int size;

     ++calls;
     if (calls % 16 != 0)
	  random_dims(1, &size);
     else
	  size = 1 + rand() % (MAX_N / 16);

     test_correctness(size);
}

/*************************************************
 * multi-dimensional correctness tests
 *************************************************/     

/*
 * Check forward and backward rank-dimensional transforms of size
 * n[0] x ... x n[rank-1] with testnd_correctness(), passing through the
 * alt_api, specific and force_buf options.  At verbosity >= 1, print the
 * size and the enabled options before testing and "OK" afterwards.
 */
void testnd_correctness_both(int rank, int *n,
			     int alt_api, int specific, int force_buf)
{
     int d;

     WHEN_VERBOSE(1,
		  printf("Testing %snd correctness for size = %d",
			 fftw_prefix, n[0]);
		  for (d = 1; d < rank; ++d)
		       printf("x%d", n[d]);
		  printf("...");
		  fflush(stdout);
		  if (alt_api)
		       printf("alt. api...");
		  if (specific)
		       printf("specific...");
		  if (force_buf)
		       printf("force buf..."));

     testnd_correctness(rank, n, FFTW_FORWARD, alt_api, specific, force_buf);
     testnd_correctness(rank, n, FFTW_BACKWARD, alt_api, specific, force_buf);

     WHEN_VERBOSE(1, printf("OK\n"));
}

/*
 * Run testnd_correctness_both() once for a random rank-dimensional size
 * from random_dims().  The alt_api, specific and force_buf options are
 * each chosen by coinflip().
 */
void testnd_random(int rank)
{
     int *dims = (int *) fftw_malloc(rank * sizeof(int));

     random_dims(rank, dims);
     testnd_correctness_both(rank, dims, coinflip(), coinflip(), coinflip());
     fftw_free(dims);
}

/*
 * Random correctness tests.  This runs forever, or for max_iterations
 * rounds when that is nonzero.  With -d, every round tests the given
 * rank.  Otherwise even rounds run a 1D test (test_random) and odd
 * rounds test a random rank in 1..MAX_RANK.
 */
void test_all_random(void)
{
     int round;

     please_wait_forever();
     for (round = 0; max_iterations == 0 || round < max_iterations; ++round) {
	  if (dimensions_specified) {
	       testnd_random(dimensions);
	  } else if (round % 2 != 0) {
	       testnd_random(rand() % MAX_RANK + 1);
	  } else {
	       test_random();
	  }
     }
}

/*
 * Return the greatest power of two s with s*s <= n, i.e. the greatest
 * power of two <= sqrt(n).  For n < 4 (including n <= 0) the result
 * is 1.
 */
int pow2sqrt(int n)
{
     int s = 1;

     /* Double s while (2s)^2 = 4*s*s <= n.  The test is written as a
	division so that it cannot overflow int when n >= 2^30. */
     while (s <= n / (4 * s))
	  s *= 2;
     return s;
}

/*
 * Test a few rank-dimensional transforms of at most totalsize points
 * (raised to at least 2^rank).  Every dimension is 2 except one or two
 * larger ones:
 *   - n[0] large;
 *   - if rank > 1: n[rank-1] large;
 *   - if rank > 1: n[0] = n[rank-1] = pow2sqrt() of the remaining budget;
 *   - if rank > 2: n[rank/2] large.
 * Each test uses random alt_api/specific options (coinflip) and
 * force_buf = 0.  A non-positive rank does nothing.
 */
void testnd_correctness_big(int rank, int totalsize)
{
     int *n, dim;

     /* rank < 0 would otherwise reach fftw_malloc with a bogus size */
     if (rank <= 0)
	  return;

     n = (int *) fftw_malloc(sizeof(int) * rank);
     for (dim = 0; dim < rank; ++dim)
          n[dim] = 2;

     if (totalsize < (1 << rank))
	  totalsize = 1 << rank;

     /* large first dimension */
     n[0] = totalsize / (1 << (rank - 1));
     testnd_correctness_both(rank,n,coinflip(),coinflip(),0);

     if (rank > 1) {
	  /* large last dimension */
	  n[0] = 2;
	  n[rank - 1] = totalsize / (1 << (rank - 1));
	  testnd_correctness_both(rank,n,coinflip(),coinflip(),0);

	  /* first and last dimensions equal, splitting the remaining budget */
	  n[0] = totalsize / (1 << (rank - 2));
	  n[0] = pow2sqrt(n[0]);
	  n[rank - 1] = n[0];
	  testnd_correctness_both(rank,n,coinflip(),coinflip(),0);
     }

     if (rank > 2) {
	  /* large middle dimension */
	  n[0] = n[rank - 1] = 2;
	  n[rank/2] = totalsize / (1 << (rank - 1));
	  testnd_correctness_both(rank,n,coinflip(),coinflip(),0);
     }

     fftw_free(n);
}

/*
 * Test a rank-dimensional transform with every dimension equal to size,
 * under all eight combinations of alt_api, specific and force_buf.
 */
void testnd_correctness_square(int rank, int size)
{
     int *dims, d, combo;

     dims = (int *) fftw_malloc(sizeof(int) * rank);
     for (d = 0; d < rank; ++d)
	  dims[d] = size;

     /* combo bits: 4 = alt_api, 2 = specific, 1 = force_buf.  Counting
	0..7 visits alt_api (outermost), specific, force_buf (innermost). */
     for (combo = 0; combo < 8; ++combo)
	  testnd_correctness_both(rank, dims, (combo >> 2) & 1,
				  (combo >> 1) & 1, combo & 1);

     fftw_free(dims);
}

/*
 * Run testnd_correctness_square(rank, side) for side = 1, 2, 3, ...
 * This runs forever, or up to max_iterations when that is nonzero.
 */
void testnd_all(int rank)
{
     int side = 1;

     please_wait_forever();
     while (!max_iterations || side <= max_iterations) {
	  testnd_correctness_square(rank, side);
	  ++side;
     }
}

/* Pick FFTW_FORWARD or FFTW_BACKWARD at random, using one coinflip(). */
fftw_direction random_dir(void)
{
     return coinflip() ? FFTW_FORWARD : FFTW_BACKWARD;
}

/*************************************************
 * timer tests
 *************************************************/     

/*
 * Sinks for the busy-work loops in the timer tests: hack_sum receives
 * the running sum in test_timer(), and hack_sum_i receives the loop
 * counter in test_timer_paranoid().  Presumably they exist so the
 * compiler does not discard those loops as dead code.
 */
static double hack_sum;
static int hack_sum_i;

/*
 * Print a warning on stderr that the timer reported a negative time
 * interval, and suggest checking the timer definition.
 */
void negative_time(void)
{
     fputs("* PROBLEM: I measured a negative time interval.\n"
	   "* Please make sure you defined the timer correctly\n"
	   "* or contact fftw@theory.lcs.mit.edu for help.\n", stderr);
}

/*
 * Paranoid test to see if time is monotonic (if not, you are really in
 * trouble).  Times a short busy loop and calls negative_time() if the
 * measured interval is negative.
 */
void test_timer_paranoid(void)
{
     fftw_time start_t, end_t;
     double sec;
     int i;

     start_t = fftw_get_time();

     /* Waste some time.  The stores go through a volatile lvalue so the
	compiler cannot fold the loop into one assignment of the final
	value; hack_sum_i itself is a plain static int. */
     for (i = 0; i < 10000; ++i)
	  *(volatile int *) &hack_sum_i = i;

     end_t = fftw_get_time();
     sec = fftw_time_to_sec(fftw_time_diff(end_t,start_t));
     if (sec < 0.0) 
	  negative_time();
}

/*
 * Measure the resolution of the FFTW timer.
 *
 * For i = 0, 1, ... (at most 31), time a loop of 2^i iterations of
 * dummy arithmetic.  Keep the minimum of up to FFTW_TIME_REPEAT tries;
 * the repeats also stop once FFTW_TIME_LIMIT seconds have passed since
 * the start.  Stop growing i once a measurement exceeds 10 seconds.
 *
 * Then report, at verbosity >= 1:
 *   - the smallest positive measured time;
 *   - for 10%, 1% and 0.1% accuracy, the shortest measured time from
 *     which every longer run has the same per-iteration cost as the
 *     longest run, to within that accuracy;
 *   - FFTW_TIME_MIN, for comparison.
 */
void test_timer(void)
{
     double times[32], acc, min_time = 10000.00;
     unsigned long iters, iter;
     fftw_time begin, end, start;
     double t, tmax, tmin;
     int last = 0, i, repeat;

     please_wait();
     test_timer_paranoid();

     start = fftw_get_time();

     for (i = 0; i < 32; i++) {
          double sum = 0.0, x = 1.0;
          double sum1 = 0.0, x1 = 1.0;

	  /* 1UL rather than 1: shifting int 1 left by 31 overflows,
	     which is undefined behavior. */
          iters = 1UL << i;
	  tmin = 1.0E10;
	  tmax = -1.0E10;

	  for (repeat = 0; repeat < FFTW_TIME_REPEAT; ++repeat) {
	       begin = fftw_get_time();
	       for (iter = 0; iter < iters; ++iter) {
		    /* some random calculations for timing... */
		    sum += x; x = .5*x + 0.2*x1; sum1 += x+x1; 
		    x1 = .4*x1 + 0.1*x;
		    sum += x; x = .5*x + 0.2*x1; sum1 += x+x1; 
		    x1 = .4*x1 + 0.1*x;
		    sum += x; x = .5*x + 0.2*x1; sum1 += x+x1; 
		    x1 = .4*x1 + 0.1*x;
		    sum += x; x = .5*x + 0.2*x1; sum1 += x+x1; 
		    x1 = .4*x1 + 0.1*x;
	       }
	       end = fftw_get_time();

	       hack_sum = sum;    /* keep the loop's result observable */
	       t = fftw_time_to_sec(fftw_time_diff(end, begin));
	       if (t < tmin)
		    tmin = t;
	       if (t > tmax)
		    tmax = t;

	       /* do not run for too long */
	       t = fftw_time_to_sec(fftw_time_diff(end, start));
	       if (t > FFTW_TIME_LIMIT)
		    break;
	  }

	  if (tmin < 0.0) 
	       negative_time();
	       
          times[i] = tmin;

          WHEN_VERBOSE(2, 
		       printf("Number of iterations = 2^%d = %lu, time = %g, "
			      "time/iter = %g\n",
			      i, iters, times[i], 
			      times[i] / iters));
          WHEN_VERBOSE(2,
		       printf("   (out of %d tries, tmin = %g, tmax = %g)\n",
			      FFTW_TIME_REPEAT,tmin,tmax));

	  last = i;
	  if (times[i] > 10.0) 
	       break;
     }

     /*
      * at this point, `last' is the last valid element in the
      * `times' array.
      */

     for (i = 0; i <= last; ++i)
	  if (times[i] > 0.0 && times[i] < min_time)
	       min_time = times[i];

     WHEN_VERBOSE(1, printf("\nMinimum resolvable time interval = %g seconds.\n\n",
	    min_time));
     
     for (acc = 0.1; acc > 0.0005; acc *= 0.1) {
          double t_final;
	  t_final = times[last] / (double) (1UL << last);
	  
	  /* walk down from the longest run until the per-iteration time
	     deviates from t_final by more than acc */
          for (i = last; i >= 0; --i) {
               double t_cur, error;
               iters = 1UL << i;
               t_cur = times[i] / iters;
               error = (t_cur - t_final) / t_final;
               if (error < 0.0)
		    error = -error;
               if (error > acc)
		    break;
	  }

	  ++i;    /* first index that was still within acc */
	  
          WHEN_VERBOSE(1,
		 printf("Minimum time for %g%% consistency = %g seconds.\n",
                 acc * 100.0, times[i]));
     }
     WHEN_VERBOSE(1, 
		  printf("\nMinimum time used in FFTW timing (FFTW_TIME_MIN)"
			 " = %g seconds.\n", FFTW_TIME_MIN));
}

/*************************************************
 * help
 *************************************************/     

/*
 * Print the command-line help.  It lists every option accepted by the
 * getopt string in main(), including -f and -x, which were previously
 * missing from this text.
 */
void usage(void)
{
     printf("Usage:  %s_test [options]\n", fftw_prefix);
     printf("  -r        : test correctness for random sizes "
	    "(does not terminate)\n");
     printf("  -d <n>    : -s/-c/-a/-p/-r/-b applies to"
	    " n-dim. transforms (default n=1)\n");
     printf("  -s <n>    : test speed for size n\n");
     printf("  -c <n>    : test correctness for size n\n");
     printf("  -a        : test correctness for all sizes "
	    "(does not terminate)\n");
     printf("  -b        : test very large transforms\n");
     printf("  -p        : test planner\n");
     printf("  -m        : use FFTW_MEASURE in correctness tests\n");
     printf("  -e        : use FFTW_ESTIMATE in speed tests\n");
     printf("  -f <n>    : number of fields to transform at once "
	    "(howmany=n)\n");
     printf("  -x <n>    : stop non-terminating tests (-a/-r) after "
	    "n iterations\n");
     printf("  -w <file> : use wisdom & read/write it from/to file\n");
     printf("  -t        : test timer resolution\n");
     printf("  -v        : verbose output for subsequent options\n");
     printf("  -P        : enable paranoid tests\n");
     printf("  -h        : this help\n");
#ifndef HAVE_GETOPT
     printf("(When run with no arguments, an interactive mode is used.)\n");
#endif
}

char wfname[128];

#define TEST_BIG_N (1<<19)

/*
 * Carry out one command-line (or interactive) option.  opt is the option
 * letter and optarg its argument string; the argument is unused by
 * options that take none.  Test options run their test immediately;
 * setting options (-d, -f, -m, -e, -w, -v, -x, -P) update the global
 * settings used by later options.  Unknown options and -h print the
 * usage text.  After every option, FFTW memory leaks are checked when
 * leak checking is enabled and wisdom is not in use.
 */
void handle_option(char opt, char *optarg)
{
     FILE *wf;
     int n;

     switch (opt) {
	 case 'd':
	      n = atoi(optarg);
	      CHECK(n > 0, "-d requires a positive integer argument");
	      dimensions = n;
	      dimensions_specified = 1;
	      break;
	      
	 case 's':
	      n = atoi(optarg);
	      CHECK(n > 0, "-s requires a positive integer argument");
	      if (dimensions == 1 && !dimensions_specified)
		   test_speed(n);
	      else
		   test_speed_nd(n);
	      break;
	      
	 case 'c':
	      n = atoi(optarg);
	      CHECK(n > 0, "-c requires a positive integer argument");
	      if (dimensions == 1 && !dimensions_specified)
		   test_correctness(n);
	      else
		   testnd_correctness_square(dimensions,n);
	      break;
	      
	 case 'b':
	      if (!dimensions_specified) {
		   int rank;
		   test_correctness(TEST_BIG_N);
		   for (rank = 1; rank <= 5; ++rank)
			testnd_correctness_big(rank, TEST_BIG_N);
	      }
	      else
		   testnd_correctness_big(dimensions, TEST_BIG_N);
	      break;
	      
	 case 'p':
	      test_planner();
	      break;

	 case 'P':
	      paranoid = 1;
	      break;

	 case 'r':
	      test_all_random();
	      break;
	      
	 case 'a':
	      if (dimensions == 1 && !dimensions_specified)
		   test_all();
	      else
		   testnd_all(dimensions);
	      break;
	      
	 case 't':
	      test_timer();
	      break;
	      
	 case 'f':
	      n = atoi(optarg);
	      CHECK(n > 0, "-f requires a positive integer argument");
	      howmany_fields = n;
	      break;
	      
	 case 'm':
	      measure_flag = FFTW_MEASURE;
	      break;

	 case 'e':
	      speed_flag = FFTW_ESTIMATE;
	      break;

	 case 'w':
	      wisdom_flag = FFTW_USE_WISDOM;
	      {
		   /* wfname is a fixed 128-byte buffer: reject (and, if
		      CHECK returns, truncate) names that would overflow it
		      instead of writing past its end with strcpy. */
		   size_t len = strlen(optarg);

		   CHECK(len < sizeof(wfname), "wisdom file name too long");
		   if (len >= sizeof(wfname))
			len = sizeof(wfname) - 1;
		   memcpy(wfname, optarg, len);
		   wfname[len] = 0;
	      }
	      wf = fopen(wfname,"r");
	      if (wf == 0) {
		   printf("Couldn't open wisdom file \"%s\".\n",wfname);
		   printf("This file will be created upon completion.\n");
	      }
	      else {
		   CHECK(FFTW_SUCCESS == fftw_import_wisdom_from_file(wf),
			 "invalid wisdom file format");
		   fclose(wf);
	      }
	      break;
	      
	 case 'v':
	      verbose++;
	      break;

	 case 'x':
	      n = atoi(optarg);
	      CHECK(n > 0, "-x requires a positive integer argument");
	      max_iterations = n;
	      break;
	      
	 case 'h':
	 default:
	      usage();
     }
     
     /* every test must free all the used FFTW memory */
     if (!(wisdom_flag & FFTW_USE_WISDOM) && chk_mem_leak)
	  fftw_check_memory_leaks();
}

/*
 * Ask a yes/no question s on stdout and read the answer from stdin.
 * Blank lines and leading spaces/tabs are skipped.  An answer starting
 * with anything other than y/Y/n/N re-prompts.
 *
 * Returns 1 for yes and 0 for no.  End of input (or a read error) also
 * returns 0, where it previously looped forever.  Each attempt now
 * reads a fresh line; previously an invalid answer made the function
 * loop forever on the same stale line.
 */
short askuser(const char *s)
{
     char line[200], c;
     int i, count = 0;

     do {
          if (count++ > 0)
               printf("Invalid response.  Please enter \"y\" or \"n\".\n");
          printf("%s (y/n) ",s);
          fflush(stdout);    /* prompt has no newline; show it before reading */

          line[0] = 0;       /* force a fresh read on every attempt */
          while (line[0] == 0 || line[0] == '\n') {  /* skip blank lines */
               if (fgets(line, (int) sizeof(line), stdin) == NULL)
                    return 0;    /* EOF or read error: treat as "no" */
          }
          for (i = 0; line[i] && (line[i] == ' ' || line[i] == '\t'); ++i)
               ;
          c = line[i];
     } while (c != 'n' && c != 'N' && c != 'y' && c != 'Y');

     return(c == 'y' || c == 'Y');
}

/*
 * Test driver entry point.
 *
 * Initializes the global settings and seeds rand() (with a fixed seed
 * when DETERMINISTIC is defined).  Options are then processed in order:
 * from the command line via getopt when HAVE_GETOPT is defined,
 * otherwise through an interactive question-and-answer session.
 *
 * When wisdom is in use, the accumulated wisdom is printed, round-tripped
 * through a string as a self-test, and saved to wfname.  Finally the
 * wisdom is forgotten and memory leaks and peak memory usage are
 * reported.
 */
int main(int argc, char *argv[])
{
     verbose = 1;
     wisdom_flag = 0;
     measure_flag = FFTW_ESTIMATE;
     dimensions = 1;
     chk_mem_leak = 1;
     paranoid = 0;

#ifdef DETERMINISTIC
     srand(1123);
#else
     srand((unsigned int) time(NULL));
#endif

     /* To parse the command line, we use getopt, but this
	does not seem to be in the ANSI standard (it is only
	available on UNIX, apparently). */
#ifndef HAVE_GETOPT
     if (argc > 1)
	  printf("Sorry, command-line arguments are not available on\n"
		 "this system.  Run fftw_test with no arguments to\n"
		 "use it in interactive mode.\n");

     if (argc <= 1) {
	  int n = 0;
	  char s[128] = "";

	  usage();

	  printf("\n");

	  if (askuser("Perform random correctness tests (non-terminating)?"))
	       handle_option('r',"");

	  if (askuser("Verbose output?"))
	       handle_option('v',"");
	  if (askuser("Paranoid test?"))
	       handle_option('P',"");

	  if (askuser("Test multi-dimensional transforms?")) {
	       printf("  Enter dimensions: ");
	       scanf("%d",&n);
	       sprintf(s,"%d",n);
	       handle_option('d',s);
	  }

	  if (askuser("Use/test wisdom?")) {
	       printf("  Enter wisdom file name to use: ");
	       fflush(stdout);
	       if (fgets(s, (int) sizeof(s), stdin) == NULL)
		    s[0] = 0;
	       /* fgets keeps the line terminator; it is not part of the
		  file name, so strip it before using s as a path */
	       s[strcspn(s, "\r\n")] = 0;
	       handle_option('w',s);
	  }
	  if (askuser("Test correctness?")) {
	       if (askuser("  -- for all sizes?"))
		    handle_option('a',"");
	       else {
		    printf("  Enter n: ");
		    scanf("%d",&n);
		    sprintf(s,"%d",n);
		    handle_option('c',s);
	       }
	  }
	  if (askuser("Test speed?")) {
	       printf("  Enter n: ");
	       scanf("%d",&n);
	       sprintf(s,"%d",n);
	       handle_option('s',s);
	  }
	  if (askuser("Test planner?"))
	       handle_option('p',"");
	  if (askuser("Test timer?"))
	       handle_option('t',"");
     }

#else /* read command-line args using getopt facility */
     {
	  extern char *optarg;
	  extern int optind;
	  int c;

	  if (argc <= 1) 
	       usage();
	  while ((c = getopt(argc, argv, "s:c:w:d:f:bpPartvmehx:")) != -1)
	       handle_option(c, optarg);
	  if (argc != optind)
	       usage();
     }
#endif

     if (wisdom_flag & FFTW_USE_WISDOM) {
	  char *ws;
	  FILE *wf;

	  ws = fftw_export_wisdom_to_string();
	  CHECK(ws != 0,"error exporting wisdom to string");
	  printf("\nAccumulated wisdom:\n     %s\n",ws);
	  /* round-trip the wisdom through the string as a self-test */
	  fftw_forget_wisdom();
	  CHECK(FFTW_SUCCESS == fftw_import_wisdom_from_string(ws),
		"unexpected error reading in wisdom from string");
	  fftw_free(ws);

	  wf = fopen(wfname,"w");
	  CHECK(wf != 0,"error creating wisdom file");
	  fftw_export_wisdom_to_file(wf);
	  fclose(wf);
     }

     /* make sure to dispose of wisdom before checking for memory leaks */
     fftw_forget_wisdom();

     fftw_check_memory_leaks();
     fftw_print_max_memory_usage();

     return 0;
}
