/*
 * _mssql module - low level Python module for communicating with MS SQL servers
 *
 * Initial Developer:
 *      Joon-cheol Park <jooncheol@gmail.com>, http://www.exman.pe.kr
 *
 * Active Developer:
 *      Andrzej Kukula <akukula@gmail.com>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 * MA  02110-1301  USA
 ****************************************************************************
 * CREDITS:
 * List ref count patch 2004.04.09 by Hans Roh <hans@lufex.com>
 * Significant contributions by Mark Pettit (thanks)
 * Multithreading patch by John-Peter Lee (thanks)
 ***************************************************************************/

#include <Python.h>
#include <structmember.h>
#include <datetime.h>

#ifdef MS_WINDOWS
  #define DBNTWIN32     // must identify operating system environment
  #define NOCRYPT       // must be defined under Visual C++ .NET 2003
  #include <windows.h>
  #include <lmerr.h>
  #include <sqlfront.h>
  #include <sqldb.h>    // DB-LIB header file (should always be included)
#else
  #define MSDBLIB       // we need FreeTDS to provide MSSQL API,
                        // not Sybase API. See README.freetds for details
  #include <sqlfront.h>
  #include <sqldb.h>

  #define SQLNUMERIC	SYBNUMERIC
  #define SQLDECIMAL	SYBDECIMAL
  #define SQLBIT	SYBBIT
  #define SQLINT1	SYBINT1
  #define SQLINT2	SYBINT2
  #define SQLINT4	SYBINT4
  #define SQLINTN	SYBINTN
  #define SQLFLT4	SYBREAL
  #define SQLFLT8	SYBFLT8
  #define SQLFLTN	SYBFLTN
  #define SQLDATETIME	SYBDATETIME
  #define SQLDATETIM4	SYBDATETIME4
  #define SQLDATETIMN	SYBDATETIMN
  #define SQLMONEY	SYBMONEY
  #define SQLMONEY4	SYBMONEY4
  #define SQLMONEYN	SYBMONEYN
  #define SQLBINARY	SYBBINARY
  #define SQLVARBINARY	SYBVARBINARY
  #define SQLIMAGE	SYBIMAGE
  #define SQLVARCHAR	SYBVARCHAR
  #define SQLCHAR	SYBCHAR
  #define SQLTEXT	SYBTEXT

  #define BYTE		unsigned char
  typedef unsigned char *LPBYTE;
#endif

/* Column type codes reported in the header tuples built by GetHeaders()
   and exported by init_mssql() as _mssql.STRING, BINARY, NUMBER,
   DATETIME and DECIMAL. */
#define TYPE_STRING	1
#define TYPE_BINARY	2
#define TYPE_NUMBER	3
#define TYPE_DATETIME	4
#define TYPE_DECIMAL	5

#include <stdio.h>
#include <string.h>	// include for string functions
#include <stdlib.h>	// include for malloc and free

/* Select the per-connection message state when a connection object is
   given, or the module-level state when self is NULL (module functions). */
#define MSSQL_SEVERITY(self) (((self) != NULL) ? (self)->mssql_severity : _mssql_severity)
#define MSSQL_ERROR_STR(self) (((self) != NULL) ? (self)->mssql_error_str : _mssql_error_str)

/* For functions returning PyObject * with a (possibly NULL) `self` in scope.
   On FAIL: cancel pending results when there is a connection, make sure a
   Python exception is set, and return NULL.  Otherwise, when DB-Lib has
   accumulated messages, raise if their severity reaches
   _mssql.min_error_severity (see maybe_raise()).
   No trailing semicolon after while (0): the caller supplies it, so the
   macro is a single statement and safe in an unbraced if/else. */
#define check_cancel_and_raise(rtc)                                   \
  do {                                                                \
    if ((rtc) == FAIL) {                                              \
      if (self != NULL) {                                             \
        Py_BEGIN_ALLOW_THREADS                                        \
        dbcancel(self->dbproc);                                       \
        Py_END_ALLOW_THREADS                                          \
      }                                                               \
      if (!maybe_raise(self) && !PyErr_Occurred())                    \
        PyErr_SetString(_mssql_error, *MSSQL_ERROR_STR(self) ?        \
                        MSSQL_ERROR_STR(self) : "Unknown error");     \
      return NULL;                                                    \
    } else if (*MSSQL_ERROR_STR(self)) {                              \
      if (maybe_raise(self))                                          \
        return NULL;                                                  \
    }                                                                 \
  } while (0)

/* Return NULL from the enclosing function if a Python exception is already
   pending (e.g. left by PyString_AsString()/PyInt_AsLong() on a bad
   argument). */
#define maybe_pyerr_occurred()                                        \
  do {                                                                \
    if (PyErr_Occurred())                                             \
      return NULL;                                                    \
  } while (0)

static PyObject *_mssql_error;  // the _mssql.error exception class, created in init_mssql()
static PyObject *decmod;  // "decimal" module handle (GetRow() builds Decimal values with it)
static PyObject *_mssql_module;  // this module; maybe_raise() reads min_error_severity from it

// Connection object (instances of _mssql.Connection)
typedef struct {
  PyObject_HEAD
  DBPROCESS *dbproc;	// PDBPROCESS dbproc;
  int connected;	// non-zero while the DB-Lib connection is open
  char *mssql_error_str;	// the error message buffer (PYMSSQL_MSGSIZE bytes, malloc'ed by connect())
  /* we'll be calculating max message severity returned by multiple calls
     to the handlers in a row, and if that severity is higher than
     minimum required, we'll raise exception. */
  int mssql_severity;
} _mssql_ConnectionObj;

// Prototypes for internal functions.
PyObject *GetRow(DBPROCESS *dbproc, int);	// current row -> tuple of Python values
PyObject *GetHeaders(DBPROCESS *dbproc);	// current result set -> ((name, type), ...)
void clr_err(_mssql_ConnectionObj *self);	// reset messages and severity (self may be NULL)
int maybe_raise(_mssql_ConnectionObj *);	// set _mssql.error if severity >= min_error_severity
int rmv_lcl(char *, char *, size_t);	// strip locale formatting from a number's text

static PyTypeObject _mssql_ConnectionObj_Type;

/* Size of every buffer that accumulates DB-Lib / server messages: the
   module-level one below and each connection's mssql_error_str. */
#define MSSQLDB_MSGSIZE 1024
#define PYMSSQL_MSGSIZE (MSSQLDB_MSGSIZE*8)
/* Module-level message buffer: receives messages for a dbproc that is not
   found in conn_obj_list (e.g. while connect() is still logging in) and is
   what module-level errmsg()/stdmsg() report. */
static char _mssql_error_str[PYMSSQL_MSGSIZE]={0,};
/* we'll be calculating max message severity returned by multiple calls
   to the handlers in a row, and if that severity is higher than
   minimum required, we'll raise exception. */
static int _mssql_severity = 0;

/* to be able to route error messages to the right connection instance,
   we must maintain a list of allocated connection objects */
struct conn_obj_list_node {
  struct conn_obj_list_node *next;
  _mssql_ConnectionObj *obj;
};

/* Head of the singly linked list: connect() prepends, close() unlinks;
   err_handler()/msg_handler() scan it by dbproc. */
static struct conn_obj_list_node *conn_obj_list = (struct conn_obj_list_node *) NULL;

/* _mssql.Connection class methods *******************************************/

static char _mssql_fetch_array_doc[] =
"fetch_array() -- get the whole result set.\n\n\
This method returns the whole result set from your queries:\n\
column names and types, then number of rows and rows themselves.\n\
";

/* Checks a dbresults()/dbnextrow() return code while fetching.
   On FAIL the pending results are cancelled and a Python exception is
   guaranteed to be set.  Otherwise, accumulated messages are raised when
   their severity reaches _mssql.min_error_severity (see maybe_raise()).
   Returns 1 when the caller must give up and return NULL, 0 to go on. */
static int _mssql_fetch_failed(_mssql_ConnectionObj *self, int rtc)
{
  if (rtc == FAIL) {
    Py_BEGIN_ALLOW_THREADS
    dbcancel(self->dbproc);
    Py_END_ALLOW_THREADS

    if (!maybe_raise(self) && !PyErr_Occurred())
      PyErr_SetString(_mssql_error, *self->mssql_error_str ?
                      self->mssql_error_str : "Unknown error");
    return 1;
  }

  if (*self->mssql_error_str && maybe_raise(self))
    return 1;

  return 0;
}

// fetch_array
/* Returns a list with one (header, affected_rows, records) tuple per result
   set.  header comes from GetHeaders() and records is a list of GetRow()
   tuples.  If the batch produced no result set at all, a single
   ((), 0, []) entry is returned so callers always see the same shape.
   On error every partially built object is released and NULL is returned
   with an exception set.  Pending results are cancelled where the failure
   happened mid-fetch. */
static PyObject *_mssql_fetch_array(_mssql_ConnectionObj *self, PyObject *args)
{
  PyObject *resultSet = NULL, *result = NULL;
  PyObject *header, *record, *o;
  int rtc, rows, appended;

  if (!self->connected) {
    PyErr_SetString(_mssql_error,"Not connected to any MS SQL server");
    return NULL;
  }

  // resultSet[ result( header(...), affected_rows, records[...] ), ... ]
  if (!(resultSet = PyList_New(0))) {
    PyErr_SetString(_mssql_error,"Could not create fetch tuple");
    return NULL;
  }

  clr_err(self);

  // command executed correctly, get results information
  Py_BEGIN_ALLOW_THREADS
  rtc = dbresults(self->dbproc);
  Py_END_ALLOW_THREADS

  while (rtc != NO_MORE_RESULTS) {
    if (_mssql_fetch_failed(self, rtc))
      goto error;

    if (!(result = PyTuple_New(3))) {
      PyErr_SetString(_mssql_error,"Could not create result tuple");
      goto cancel;
    }

    if (!(header = GetHeaders(self->dbproc)))
      goto cancel;                           // exception set by GetHeaders()
    PyTuple_SET_ITEM(result, 0, header);     // steals the reference

    if (!(record = PyList_New(0))) {
      PyErr_SetString(_mssql_error,"Could not create record tuple");
      goto cancel;
    }
    // result owns record from now on; we keep using the borrowed pointer
    PyTuple_SET_ITEM(result, 2, record);

    clr_err(self);

    // loop on each row, until all read
    Py_BEGIN_ALLOW_THREADS
    rtc = dbnextrow(self->dbproc);
    Py_END_ALLOW_THREADS

    while (rtc != NO_MORE_ROWS) {
      if (_mssql_fetch_failed(self, rtc))
        goto error;

      clr_err(self);
      // pass REG_ROW or the compute id straight through to GetRow()
      if (!(o = GetRow(self->dbproc, rtc)))
        goto cancel;
      appended = PyList_Append(record, o);
      Py_DECREF(o);
      if (appended == -1)
        goto cancel;

      // Get the next row
      Py_BEGIN_ALLOW_THREADS
      rtc = dbnextrow(self->dbproc);
      Py_END_ALLOW_THREADS
    }

    Py_BEGIN_ALLOW_THREADS
    rows = DBCOUNT(self->dbproc);
    Py_END_ALLOW_THREADS

    // affected_rows
    if (!(o = PyInt_FromLong((long) rows)))
      goto cancel;
    PyTuple_SET_ITEM(result, 1, o);

    appended = PyList_Append(resultSet, result);
    Py_CLEAR(result);
    if (appended == -1)
      goto cancel;

    // Setup condition for next iteration
    Py_BEGIN_ALLOW_THREADS
    rtc = dbresults(self->dbproc);
    Py_END_ALLOW_THREADS
  }   // end while(dbresults())

  // if there were no results ensure we return proper data model
  if (PyList_GET_SIZE(resultSet) == 0) {
    if (!(result = PyTuple_New(3))) {
      PyErr_SetString(_mssql_error,"Could not create result tuple");
      goto error;
    }
    if (!(record = PyList_New(0))) {
      PyErr_SetString(_mssql_error,"Could not create record list");
      goto error;
    }
    PyTuple_SET_ITEM(result, 2, record);
    if (!(header = PyTuple_New(0))) {
      PyErr_SetString(_mssql_error,"Could not create header tuple");
      goto error;
    }
    PyTuple_SET_ITEM(result, 0, header);
    if (!(o = PyInt_FromLong(0L)))
      goto error;
    PyTuple_SET_ITEM(result, 1, o);
    if (PyList_Append(resultSet, result) == -1)
      goto error;
    Py_DECREF(result);
  }

  return resultSet;

cancel:
  // abandon whatever the server still has queued for this batch
  Py_BEGIN_ALLOW_THREADS
  dbcancel(self->dbproc);
  Py_END_ALLOW_THREADS
error:
  Py_XDECREF(result);    // also releases header/record stored in it
  Py_DECREF(resultSet);
  return NULL;
}

static char _mssql_query_doc[] =
"query(query_string) -- Send a query to the MS SQL Server.\n\n\
This method queries MS SQL Server and returns 1 if the query\n\
succeeded. An exception is thrown otherwise.\n\
";

/* Sends the METH_O argument (converted with PyString_AsString()) to the
   server with dbcmd() + dbsqlexec().  Results are left pending for
   fetch_array().  Returns 1 on success, or NULL with an exception set. */
static PyObject *_mssql_query(_mssql_ConnectionObj *self, PyObject *args)
{
  RETCODE rtc;
  char *Query = PyString_AsString(args);

  maybe_pyerr_occurred();   // the argument could not be read as a string

  if (!self->connected) {
    PyErr_SetString(_mssql_error,"Not connected to any MS SQL server");
    return NULL;
  }

  clr_err(self);

  // Execute the query; never run dbsqlexec() on a command buffer that
  // dbcmd() failed to fill
  Py_BEGIN_ALLOW_THREADS
  rtc = dbcmd(self->dbproc,Query);
  if (rtc != FAIL)
    rtc = dbsqlexec(self->dbproc);
  Py_END_ALLOW_THREADS

  check_cancel_and_raise(rtc);

  return PyInt_FromLong(1L);
}

static char _mssql_select_db_doc[] =
"select_db(dbname) -- Select a database and make it the current database.\n\n\
This function selects given database as the current one.\n\
It returns 1 if the action succeeded. An exception is thrown otherwise.\n\
";

// select_db
/* Makes the METH_O argument the connection's current database via dbuse().
   Returns 1 on success, or NULL with an exception set. */
static PyObject *_mssql_select_db(_mssql_ConnectionObj *self, PyObject *args)
{
  char *database;
  RETCODE status;

  database = PyString_AsString(args);
  maybe_pyerr_occurred();            // argument could not be read as a string

  if (self->connected == 0) {
    PyErr_SetString(_mssql_error,"Not connected to any MS SQL server");
    return NULL;
  }

  clr_err(self);                     // collect messages for this call only

  Py_BEGIN_ALLOW_THREADS
  status = dbuse(self->dbproc, database);
  Py_END_ALLOW_THREADS

  check_cancel_and_raise(status);

  return PyInt_FromLong(1L);
}

static char _mssql_close_doc[] =
"close() -- close connection to an MS SQL Server.\n\n\
This function tries to close the connection and free all memory\n\
it used. If it fails, an exception is thrown. You can call\n\
stdmsg() to get reason for failing.\n\
";

/* Writable, permanently empty message buffer handed to closed connections,
   so errmsg()/stdmsg()/clr_err() on a closed object never touch the freed
   per-connection buffer.  Only a NUL is ever stored here: close() unlinks
   the object from conn_obj_list, so the DB-Lib handlers no longer route
   text to it. */
static char _mssql_closed_error_str[1] = "";

// close
/* Frees the DB-Lib command buffer, closes the DBPROCESS, unlinks the object
   from conn_obj_list (freeing its node) and releases its message buffer.
   Returns None.  Raises _mssql.error when not connected or, on Windows,
   when dbclose() reports FAIL.  Also called by the destructor. */
static PyObject *_mssql_close(_mssql_ConnectionObj *self, PyObject *args)
{
#ifdef MS_WINDOWS
  RETCODE rtc;
#endif
  struct conn_obj_list_node *p, *n;

  /* No PyErr_Occurred() early exit here: this also runs from the
     destructor, possibly while an unrelated exception is propagating, and
     bailing out would leave a soon-to-be-freed object in conn_obj_list. */

  if (!self->connected) {
    PyErr_SetString(_mssql_error,"Not connected to any MS SQL server");
    return NULL;
  }

  clr_err(self);

  // this doesn't need to release GIL - essentially just a call to free()
  dbfreebuf(self->dbproc);
#ifdef MS_WINDOWS
  Py_BEGIN_ALLOW_THREADS
  rtc = dbclose(self->dbproc);
  Py_END_ALLOW_THREADS

  if (rtc == FAIL) {
    PyErr_SetString(_mssql_error,"Failed to close and free the DB connection");
    return NULL;
  }
#else
  Py_BEGIN_ALLOW_THREADS
  dbclose(self->dbproc);
  Py_END_ALLOW_THREADS
#endif
  self->connected = 0;

  //+++ find and remove the node from the list
  for (p = NULL, n = conn_obj_list; n != NULL; p = n, n = n->next) {
    if (n->obj == self) { // found
      if (p != NULL)
        p->next = n->next;
      else
        conn_obj_list = n->next;
      free(n);
      break;
    }
  }

  // release the message buffer but keep a valid empty string in its place
  if (self->mssql_error_str != _mssql_closed_error_str) {
    free(self->mssql_error_str);
    self->mssql_error_str = _mssql_closed_error_str;
  }

  Py_INCREF(Py_None);
  return Py_None;
}


static char _mssql_errmsg_doc[] =
"errmsg() -- display last error message.\n\n\
This method returns the error message that occurred during the\n\
latest SQL operation. ``None'' is returned if there was no error.\n\
It's both module method (this way you can find the reason for\n\
not being able to log in) and _mssql.Connection object\n\
instance method (errors occurring after connecting to MS SQL.)\n\
If you do care about all error messages, you should check it\n\
after each SQL operation; some of them may not raise an exception\n\
if the severity is too low.\n\
";

/* Returns the accumulated message text of this connection (module-level
   text when self is NULL) as a str, or None when nothing was recorded. */
static PyObject *_mssql_errmsg(_mssql_ConnectionObj *self, PyObject *args)
{
  char *text = MSSQL_ERROR_STR(self);

  if (text[0] == '\0') {
    Py_INCREF(Py_None);
    return Py_None;
  }

  return PyString_FromString(text);
}

static char _mssql_stdmsg_doc[] =
"stdmsg() -- display MS SQL Server error message.\n\n\
This method returns the error message from MS SQL Server.\n\
You can safely ignore many of them, especially with severity < 5.\n\
If your query fails, this method can be a way to get info about\n\
what went wrong.\n\n\
It's both _mssql module method and _mssql.Connection object\n\
instance method.\n\n\
This method is deprecated. You can use errmsg() to gather all\n\
error messages.\n\
";

/* Deprecated alias of errmsg(): returns exactly what _mssql_errmsg()
   returns for the same self. */
static PyObject *_mssql_stdmsg(_mssql_ConnectionObj *self, PyObject *args)
{
  return _mssql_errmsg(self, args);
}


/* _mssql module methods *****************************************************/

//#define ERRDEBUG 1

/* DB-Lib error callback, installed with dberrhandle() in init_mssql().
   Appends "DB-Lib error message ..." to a message buffer, plus the
   operating-system / Net-Lib error text when oserr is set.  The buffer is
   that of the connection owning dbproc, or the module-level one when no
   listed connection matches.  Also raises that buffer's running maximum
   severity.
   NOTE: DB-Lib calls this from inside library calls, most of which this
   module makes with the GIL released, so it must not touch the Python API
   (Py_BEGIN/END_ALLOW_THREADS included).
   Returns INT_CANCEL. */
int err_handler(DBPROCESS *dbproc, int severity, int dberr, int oserr,
                char *dberrstr, char *oserrstr)
{
  struct conn_obj_list_node *n;
  char *mssql_error_str = _mssql_error_str;
  int *mssql_severity = &_mssql_severity;

#ifdef ERRDEBUG
  fprintf(stderr, "\n*** err_handler(dbproc=%p, severity=%d, dberr=%d, \
oserr=%d, dberrstr='%s', oserrstr='%s'); DBDEAD(dbproc)=%d\n", (void *)dbproc,
severity, dberr, oserr, dberrstr, oserrstr, DBDEAD(dbproc));
  fprintf(stderr, "*** current max severity = %d\n\n", _mssql_severity);
#endif

  // try to find out which connection this handler belongs to.
  // do it by scanning the list
  for (n = conn_obj_list; n != NULL; n = n->next) {
    if (n->obj->dbproc == dbproc) { // found
      mssql_error_str = n->obj->mssql_error_str;
      mssql_severity = &n->obj->mssql_severity;
      break;
    }
  }

  // calculate the maximum severity of all messages in a row
  if (*mssql_severity < severity)  *mssql_severity = severity;

  // but get all of them regardless of severity
  snprintf(mssql_error_str + strlen(mssql_error_str),
           PYMSSQL_MSGSIZE - strlen(mssql_error_str),
           "DB-Lib error message %d, severity %d:\n%s\n",
           dberr, severity, (dberrstr != NULL) ? dberrstr : "");

  if ((oserr != DBNOERR) && (oserr != 0)) {
    const char *ostext = NULL;	// textual representation of oserr, if any
#ifdef MS_WINDOWS
    HMODULE hModule = NULL;	// NULL == system message table
    LPSTR msg = NULL;
    DWORD fmtflags = FORMAT_MESSAGE_ALLOCATE_BUFFER |
                     FORMAT_MESSAGE_IGNORE_INSERTS |
                     FORMAT_MESSAGE_FROM_SYSTEM;

    if (oserr > NERR_BASE && oserr <= MAX_NERR) {
      // network error codes are described in netmsg.dll
      hModule = LoadLibraryEx(TEXT("netmsg.dll"), NULL,
                  LOAD_LIBRARY_AS_DATAFILE);
      if (hModule != NULL)  fmtflags |= FORMAT_MESSAGE_FROM_HMODULE;
    }

    if (FormatMessageA(fmtflags,
          hModule,	// module to get message from (NULL == system)
          oserr, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // default language
          (LPSTR) &msg, 0, NULL))
      ostext = msg;
#else
#ifndef EXCOMM
#define EXCOMM 9
#endif
    ostext = strerror(oserr);
#endif

    if (ostext != NULL) {
      snprintf(mssql_error_str + strlen(mssql_error_str),
               PYMSSQL_MSGSIZE - strlen(mssql_error_str),
               "%s error during %s\n",
               (severity == EXCOMM) ? "Net-Lib" : "Operating system",
               (oserrstr != NULL) ? oserrstr : "");
      snprintf(mssql_error_str + strlen(mssql_error_str),
               PYMSSQL_MSGSIZE - strlen(mssql_error_str),
               "Error %d - %s", oserr, ostext);
    }

#ifdef MS_WINDOWS
    if (msg != NULL)  LocalFree(msg);
    // unload netmsg.dll even when FormatMessageA() failed
    if (hModule != NULL)  FreeLibrary(hModule);
#endif
  }

  return INT_CANCEL;  /* sigh FreeTDS sets DBDEAD on incorrect login! */
}

/* gosh! different prototypes! */
#ifdef MS_WINDOWS
#define LINE_T DBUSMALLINT
#else
#define LINE_T int
#endif

/* DB-Lib server-message callback, installed with dbmsghandle() in
   init_mssql().  Appends a formatted copy of every message to a message
   buffer: that of the connection owning dbproc, or the module-level one
   when no listed connection matches.  Also raises that buffer's running
   maximum severity.  Always returns 0. */
int msg_handler(DBPROCESS *dbproc, DBINT msgno, int msgstate, int severity,
                char *msgtext, char *srvname, char *procname, LINE_T line)
{
  struct conn_obj_list_node *node;
  char *errbuf = _mssql_error_str;
  int *maxsev = &_mssql_severity;
  size_t used;

#ifdef ERRDEBUG
  fprintf(stderr, "\n+++ msg_handler(dbproc=%p, msgno=%d, msgstate=%d, \
severity=%d, msgtext='%s', srvname='%s', procname='%s', line=%d\n", (void *)dbproc,
msgno, msgstate, severity, msgtext, srvname, procname, line);
  fprintf(stderr, "+++ current max severity = %d\n\n", _mssql_severity);
#endif

  // route to the owning connection, if it is in the list
  for (node = conn_obj_list; node != NULL; node = node->next) {
    if (node->obj->dbproc != dbproc)
      continue;
    errbuf = node->obj->mssql_error_str;
    maxsev = &node->obj->mssql_severity;
    break;
  }

  // track the highest severity seen since the last clr_err()
  if (severity > *maxsev)
    *maxsev = severity;

  // keep every message, whatever its severity
  used = strlen(errbuf);
  if (procname != NULL && procname[0] != '\0')
    snprintf(errbuf + used, PYMSSQL_MSGSIZE - used,
             "SQL Server message %ld, severity %d, state %d, procedure %s, line %d:\n%s\n",
             (long)msgno, severity, msgstate, procname, line, msgtext);
  else
    snprintf(errbuf + used, PYMSSQL_MSGSIZE - used,
             "SQL Server message %ld, severity %d, state %d, line %d:\n%s\n",
             (long)msgno, severity, msgstate, line, msgtext);

  return 0;
}


static char _mssql_connect_doc[] =
"connect(server,user,password) -- connect to an MS SQL Server.\n\
This method returns an instance of class _mssql.Connection.\n\n\
server   - is an instance of MS SQL server; it can be host name,\n\
           'host:port' syntax or server identifier from freetds.conf.\n\
           Under Windows use the syntax 'hostname\\instancename'\n\
           to connect to a named instance. Be aware that this\n\
		   difference makes the code non-portable!\n\
user     - user name to login as.\n\
password - password to authenticate with.\n\
";

/* connect(server, user, password): log in through DB-Lib and return a new
   _mssql.Connection.
   The object is added to conn_obj_list only once fully set up, so
   messages produced while logging in land in the module-level buffer.  A
   dbopen() failure uses that text as the exception message.
   On any failure everything acquired so far (DB connection, login record,
   message buffer, list node, the object itself) is released and NULL is
   returned with _mssql.error set. */
static PyObject *_mssql_connect(_mssql_ConnectionObj *self, PyObject *args)
{
  _mssql_ConnectionObj *dbprochandle;
  LOGINREC  *login;         // DB-LIB login structure
  char *server, *user, *passwd;
  RETCODE rtc;
  struct conn_obj_list_node *n;

  if(!PyArg_ParseTuple(args, "sss:connect", &server, &user, &passwd)) {
    // PyArg_ParseTuple() raises exception itself
    return NULL;
  }

  clr_err(NULL);

  dbprochandle = PyObject_NEW(_mssql_ConnectionObj, &_mssql_ConnectionObj_Type);
  if (dbprochandle == NULL) {
    PyErr_SetString(_mssql_error,"Could not create _mssql.Connection object");
    return NULL;
  }
  // PyObject_NEW leaves the fields uninitialised: make this a valid
  // "not connected" object so the failure path (and dealloc) can rely on it
  dbprochandle->dbproc = NULL;
  dbprochandle->connected = 0;
  dbprochandle->mssql_severity = 0;

  //+++ create list node and allocate message buffer
  n = malloc(sizeof *n);
  dbprochandle->mssql_error_str = malloc(PYMSSQL_MSGSIZE);

  if ((n == NULL) || (dbprochandle->mssql_error_str == NULL)) {
    PyErr_SetString(_mssql_error,"Out of memory");
    goto fail;
  }
  // errmsg() may be called before any operation has cleared the buffer
  dbprochandle->mssql_error_str[0] = '\0';

  login = dblogin();		// get login record from DB-LIB
  if (login == NULL) {
    PyErr_SetString(_mssql_error,"Could not allocate DB-Lib login record");
    goto fail;
  }

  // these don't need to release GIL
  DBSETLUSER(login, user);
  DBSETLPWD(login, passwd);
  DBSETLAPP(login, "pymssql");
#ifdef MS_WINDOWS
  DBSETLVERSION(login, DBVER60);
#else
  DBSETLHOST(login, server);
#endif

  Py_BEGIN_ALLOW_THREADS
  dbprochandle->dbproc = dbopen(login,server);
  Py_END_ALLOW_THREADS

  // the login record is only needed by dbopen(): free it on every path
#ifdef MS_WINDOWS
  dbfreelogin(login);	// Frees a login record.
#else
  dbloginfree(login);	// Frees a login record.
#endif

  if (dbprochandle->dbproc == NULL) {
    if (*_mssql_error_str)
      PyErr_SetString(_mssql_error, _mssql_error_str);
    else
      PyErr_SetString(_mssql_error,"Could not connect to MS SQL Server");
    goto fail;
  }

  //dbsetopt(dbprochandle->dbproc, DBBUFFER, "2", -1);
  //dbsetopt(dbprochandle->dbproc, DBTEXTLIMIT, "8096", -1);

  // set initial connection properties to some reasonable values
  Py_BEGIN_ALLOW_THREADS
  rtc = dbcmd(dbprochandle->dbproc,
    "SET ARITHABORT ON;"
    "SET CONCAT_NULL_YIELDS_NULL ON;"
    "SET ANSI_NULLS ON;"
    "SET ANSI_NULL_DFLT_ON ON;"
    "SET ANSI_PADDING ON;"
    "SET ANSI_WARNINGS ON;"
    "SET ANSI_NULL_DFLT_ON ON;"
    "SET CURSOR_CLOSE_ON_COMMIT ON;"
    "SET QUOTED_IDENTIFIER ON"
  );
  if (rtc != FAIL)
    rtc = dbsqlexec(dbprochandle->dbproc);
  Py_END_ALLOW_THREADS

  if (rtc == FAIL) {
    PyErr_SetString(_mssql_error,"Could not set connection properties");
    goto fail;
  }

  Py_BEGIN_ALLOW_THREADS
  dbcancel(dbprochandle->dbproc);
  Py_END_ALLOW_THREADS

  dbprochandle->connected = 1;

  //+++ prepend this connection to the list
  n->next = conn_obj_list;
  n->obj = dbprochandle;
  conn_obj_list = n;

  return (PyObject *)dbprochandle;

fail:
  if (dbprochandle->dbproc != NULL) {
    Py_BEGIN_ALLOW_THREADS
    dbclose(dbprochandle->dbproc);
    Py_END_ALLOW_THREADS
  }
  free(dbprochandle->mssql_error_str);
  free(n);
  // connected is still 0, so the destructor will not try to close it
  Py_DECREF((PyObject *) dbprochandle);
  return NULL;
}


static char _mssql_set_login_timeout_doc[] =
"set_login_timeout(seconds) -- set the login timeout. The default\n\
    for SQL Server is 60 seconds. BEWARE that it works on Windows only;\n\
	FreeTDS silently ignores login timeouts.\n\
";

/* Sets the DB-Lib login timeout with dbsetlogintime().  Registered both as
   a Connection method and as a module function; in the latter case self is
   NULL (Py_InitModule3 passes no self).  Returns None, or NULL with
   _mssql.error set. */
static PyObject *_mssql_set_login_timeout(_mssql_ConnectionObj *self, PyObject *args)
{
  long tmout = PyInt_AsLong(args);
  RETCODE rtc;

  maybe_pyerr_occurred();   // the argument was not an integer
  clr_err(self);
  rtc = dbsetlogintime(tmout);

  if (self != NULL) {
    // Connection method: cancel pending work and apply min_error_severity
    check_cancel_and_raise(rtc);
  } else if (rtc == FAIL) {
    // module function: there is no DBPROCESS to cancel, so raise directly
    // rather than through code that dereferences self
    PyErr_SetString(_mssql_error, *_mssql_error_str ?
                    _mssql_error_str : "Could not set login timeout");
    return NULL;
  }

  Py_INCREF(Py_None);
  return Py_None;
}


static char _mssql_set_query_timeout_doc[] =
"set_query_timeout(seconds) -- set the amount of time to wait for results\n\
from the server. The default is 0 which means infinite wait.\n\
";

/* Sets the DB-Lib query timeout with dbsettime().  Registered both as a
   Connection method and as a module function; in the latter case self is
   NULL (Py_InitModule3 passes no self).  Returns None, or NULL with
   _mssql.error set. */
static PyObject *_mssql_set_query_timeout(_mssql_ConnectionObj *self, PyObject *args)
{
  long tmout = PyInt_AsLong(args);
  RETCODE rtc;

  maybe_pyerr_occurred();   // the argument was not an integer
  clr_err(self);
  rtc = dbsettime(tmout);

  if (self != NULL) {
    // Connection method: cancel pending work and apply min_error_severity
    check_cancel_and_raise(rtc);
  } else if (rtc == FAIL) {
    // module function: there is no DBPROCESS to cancel, so raise directly
    // rather than through code that dereferences self
    PyErr_SetString(_mssql_error, *_mssql_error_str ?
                    _mssql_error_str : "Could not set query timeout");
    return NULL;
  }

  Py_INCREF(Py_None);
  return Py_None;
}



/* Destructor: closes a still-open connection through _mssql_close(), then
   frees the object.  Any Python exception pending at this point is saved
   around the close.  That way it neither stops the cleanup (close() may
   bail out when PyErr_Occurred()) nor gets overwritten.  A failure of the
   close itself is reported as "unraisable", because a destructor cannot
   raise. */
static void _mssql_ConnectionObj_dealloc(_mssql_ConnectionObj *self)
{
  if (self->connected) {
    PyObject *exc_type, *exc_value, *exc_tb, *o;

    PyErr_Fetch(&exc_type, &exc_value, &exc_tb);
    o = _mssql_close(self, NULL);
    if (o == NULL)
      PyErr_WriteUnraisable((PyObject *) self);
    else
      Py_DECREF(o);
    PyErr_Restore(exc_type, exc_value, exc_tb);
  }

  PyObject_Free((char *) self);
}


/* repr(): "<Open mssql connection at 0x...>" or
   "<Closed mssql connection at 0x...>" with the object's address. */
static PyObject * _mssql_ConnectionObj_repr(_mssql_ConnectionObj *self)
{
  // PyString_FromFormat's %p prints the whole pointer on every platform;
  // the former "%lx" with a (long) cast truncated it where long is 32-bit
  return PyString_FromFormat(self->connected ?
                             "<Open mssql connection at %p>" :
                             "<Closed mssql connection at %p>",
                             (void *) self);
}


/* Methods of _mssql.Connection.  select_db, query and set_*_timeout take
   exactly one argument (METH_O); the others take none (METH_NOARGS). */
static PyMethodDef _mssql_ConnectionObj_methods[] = {
  {"select_db",  (PyCFunction) _mssql_select_db,   METH_O,      _mssql_select_db_doc},
  {"query",      (PyCFunction) _mssql_query,       METH_O,      _mssql_query_doc},
  {"fetch_array",(PyCFunction) _mssql_fetch_array, METH_NOARGS, _mssql_fetch_array_doc},
  {"close",      (PyCFunction) _mssql_close,       METH_NOARGS, _mssql_close_doc},
  {"errmsg",     (PyCFunction) _mssql_errmsg,      METH_NOARGS, _mssql_errmsg_doc},
  {"stdmsg",     (PyCFunction) _mssql_stdmsg,      METH_NOARGS, _mssql_stdmsg_doc},
  {"set_login_timeout", (PyCFunction) _mssql_set_login_timeout, METH_O, _mssql_set_login_timeout_doc},
  {"set_query_timeout", (PyCFunction) _mssql_set_query_timeout, METH_O, _mssql_set_query_timeout_doc},
  {NULL,         NULL}  // sentinel
};


/* Data attributes of _mssql.Connection: the read-only `connected` flag. */
static PyMemberDef _mssql_ConnectionObj_members[] = {
  {"connected", T_INT, offsetof(_mssql_ConnectionObj, connected), READONLY,
   "True if the connection is open"},
  {NULL}  // sentinel
};


static char _mssql_ConnectionObj_Type_doc[] = {
"This object represents an MS SQL database connection. You can\n\
make queries and obtain results through a database connection.\n\
"};

/* Type object for _mssql.Connection.  tp_new is NULL, so instances come
   only from _mssql.connect(); tp_getattro is filled in at run time by
   init_mssql() (see the MSVC note there). */
static PyTypeObject _mssql_ConnectionObj_Type = {
  PyObject_HEAD_INIT(NULL)
  0,				/* ob_size           */
  "_mssql.Connection",		/* tp_name           */
  sizeof(_mssql_ConnectionObj),	/* tp_basicsize      */
  0,				/* tp_itemsize       */
  (destructor)_mssql_ConnectionObj_dealloc,	/* tp_dealloc     */
  0,				/* tp_print          */
  0,				/* tp_getattr        */
  0,				/* tp_setattr        */
  0,				/* tp_compare        */
  (reprfunc)_mssql_ConnectionObj_repr,		/* tp_repr        */
  0,				/* tp_as_number      */
  0,				/* tp_as_sequence    */
  0,				/* tp_as_mapping     */
  0,				/* tp_hash           */
  0,				/* tp_call           */
  0,				/* tp_str            */
  0,				/* tp_getattro (set in init_mssql) */
  0,				/* tp_setattro       */
  0,				/* tp_as_buffer      */
  Py_TPFLAGS_DEFAULT,		/* tp_flags          */
  _mssql_ConnectionObj_Type_doc,/* tp_doc            */
  0,				/* tp_traverse       */
  0,				/* tp_clear          */
  0,				/* tp_richcompare    */
  0,				/* tp_weaklistoffset */
  0,				/* tp_iter           */
  0,				/* tp_iternext       */
  _mssql_ConnectionObj_methods,	/* tp_methods        */
  _mssql_ConnectionObj_members,	/* tp_members        */
  0,				/* tp_getset         */
  0,				/* tp_base           */
  0,				/* tp_dict           */
  0,				/* tp_descr_get      */
  0,				/* tp_descr_set      */
  0,				/* tp_dictoffset     */
  0,				/* tp_init           */
  NULL,				/* tp_alloc          */
  NULL,				/* tp_new (not instantiable from Python) */
  NULL,				/* tp_free Low-level free-memory routine */
  0,				/* tp_bases          */
  0,				/* tp_mro method resolution order */
  0,				/* tp_cache          */
};


/* Module-level functions of _mssql.  errmsg, stdmsg and set_*_timeout are
   the same C functions as the Connection methods.  Called this way they
   get a NULL self, so they use the module-level message state. */
static PyMethodDef _mssql_methods[] = {
  {"connect", (PyCFunction) _mssql_connect, METH_VARARGS, _mssql_connect_doc},
  {"errmsg",  (PyCFunction) _mssql_errmsg,  METH_NOARGS,  _mssql_errmsg_doc},
  {"stdmsg",  (PyCFunction) _mssql_stdmsg,  METH_NOARGS,  _mssql_stdmsg_doc},
  {"set_login_timeout", (PyCFunction) _mssql_set_login_timeout, METH_O, _mssql_set_login_timeout_doc},
  {"set_query_timeout", (PyCFunction) _mssql_set_query_timeout, METH_O, _mssql_set_query_timeout_doc},
  {NULL,      NULL}  // sentinel
};


PyMODINIT_FUNC
init_mssql(void)
{
#ifdef MS_WINDOWS
  LPCSTR rtc;
#else
  RETCODE rtc;
#endif

  /* if we initialize this at declaration, MSVC 7 issues the following warn:
   warning C4232: nonstandard extension used : 'tp_getattro': address of
   dllimport 'PyObject_GenericGetAttr' is not static, identity not guaranteed */
  _mssql_ConnectionObj_Type.tp_getattro = PyObject_GenericGetAttr;

  PyDateTime_IMPORT;  // import datetime

  if (PyType_Ready(&_mssql_ConnectionObj_Type) < 0)
    return;

  _mssql_module = Py_InitModule3("_mssql", _mssql_methods,
                "low level Python module for communicating with MS SQL servers");

  if (_mssql_module == NULL)  return;

  _mssql_error = PyErr_NewException("_mssql.error", NULL, NULL);
  if (PyModule_AddObject(_mssql_module, "error", _mssql_error) == -1) return;
  if (PyModule_AddIntConstant(_mssql_module, "STRING", TYPE_STRING) == -1) return;
  if (PyModule_AddIntConstant(_mssql_module, "BINARY", TYPE_BINARY) == -1) return;
  if (PyModule_AddIntConstant(_mssql_module, "NUMBER", TYPE_NUMBER) == -1) return;
  if (PyModule_AddIntConstant(_mssql_module, "DATETIME", TYPE_DATETIME) == -1) return;
  if (PyModule_AddIntConstant(_mssql_module, "DECIMAL", TYPE_DECIMAL) == -1) return;
  if (PyModule_AddObject(_mssql_module,
       "min_error_severity", PyInt_FromLong((long) 5)) == -1) return;

  Py_INCREF(&_mssql_ConnectionObj_Type);

  decmod = PyImport_ImportModule("decimal");
  if (decmod == NULL)  return;

  // DB-Lib initialization moved here to be able to call set_login_timeout()
  // and get sensible error handling

  // this can be lengthy process
  Py_BEGIN_ALLOW_THREADS
  rtc = dbinit();
  Py_END_ALLOW_THREADS

#ifdef MS_WINDOWS
  if (rtc == (char *)NULL) {
#else
  if (rtc == FAIL) {
#endif
    PyErr_SetString(_mssql_error,"Could not initialize the communication layer");
    return;
  }

  // these don't need to release GIL
#ifdef MS_WINDOWS
  dberrhandle((DBERRHANDLE_PROC)err_handler);
  dbmsghandle((DBMSGHANDLE_PROC)msg_handler);
#else
  dberrhandle(err_handler);
  dbmsghandle(msg_handler);
#endif
}

/* internal functions ********************************************************/

PyObject *GetHeaders(DBPROCESS *dbproc)
{
  int x, cols, coltype, apicoltype;
  PyObject *headerSet=NULL;
  char *colname;

  Py_BEGIN_ALLOW_THREADS
  cols = dbnumcols(dbproc);			// get number of columns
  Py_END_ALLOW_THREADS

  if (!(headerSet = PyTuple_New(cols))) {
    PyErr_SetString(_mssql_error,"Could not create column tuple");
    return NULL;
  }

  for(x = 1; x <= cols; x++) {			// loop on all columns
    PyObject *columnSet = NULL;
    if (!(columnSet = PyTuple_New(2))) {
      PyErr_SetString(_mssql_error,"Could not create tuple for column header details");
      return NULL;
    }

    Py_BEGIN_ALLOW_THREADS
    colname = (char *) dbcolname(dbproc,x);
    coltype = dbcoltype(dbproc,x);
    Py_END_ALLOW_THREADS

    switch (coltype) {
      case SQLBIT: case SQLINT1: case SQLINT2: case SQLINT4: case SQLINTN:
	  case SQLFLT4: case SQLFLT8: case SQLFLTN:
        apicoltype = TYPE_NUMBER;
        break;

      case SQLMONEY: case SQLMONEY4: case SQLMONEYN:
      case SQLNUMERIC: case SQLDECIMAL:
        apicoltype = TYPE_DECIMAL;
		break;

      case SQLDATETIME: case SQLDATETIM4: case SQLDATETIMN:
        apicoltype = TYPE_DATETIME;
        break;

      case SQLVARCHAR: case SQLCHAR: case SQLTEXT:
        apicoltype = TYPE_STRING;
        break;

      //case SQLVARBINARY: case SQLBINARY: case SQLIMAGE:
      default:
        apicoltype = TYPE_BINARY;
    }

    PyTuple_SET_ITEM(columnSet,0,Py_BuildValue("s",colname));
    PyTuple_SET_ITEM(columnSet,1,Py_BuildValue("i",apicoltype));
    PyTuple_SET_ITEM(headerSet,x-1,columnSet);
  }

  return headerSet;
}


/* Column accessors that work for both regular rows (rowinfo == REG_ROW) and
   compute rows (rowinfo is the compute id returned by dbnextrow()). */
#define GET_DATA(dbproc,rowinfo,x) ((rowinfo==REG_ROW)?(BYTE*)dbdata(dbproc,x):(BYTE*)dbadata(dbproc,rowinfo,x))
#define GET_TYPE(dbproc,rowinfo,x) ((rowinfo==REG_ROW)?dbcoltype(dbproc,x):dbalttype(dbproc,rowinfo,x))
#define GET_LEN(dbproc,rowinfo,x) ((rowinfo==REG_ROW)?dbdatlen(dbproc,x):dbadlen(dbproc,rowinfo,x))

#define NUMERIC_BUF_SZ 45

/* Converts the current row into a tuple of Python values, one per column.
   rowinfo is dbnextrow()'s return value (REG_ROW or a compute id).
   Mapping:
     NULL -> None
     bit -> bool
     tinyint/smallint/int -> int
     money/smallmoney/numeric/decimal -> decimal.Decimal
     real/float -> float
     (small)datetime -> datetime.datetime
     anything else -> str holding the raw bytes
   Returns a new reference, or NULL with an exception set. */
PyObject *GetRow(DBPROCESS *dbproc, int rowinfo)
{
  int x, cols, coltype, len;
  LPBYTE data;			// column data pointer
  double ddata;
  PyObject *record;
  PyObject *o;			// Python value for the current column
  DBDATEREC di;
  DBDATETIME dt;
  char buf[NUMERIC_BUF_SZ];	// buffer in which we store text rep of big nums

  Py_BEGIN_ALLOW_THREADS
  cols = dbnumcols(dbproc);	// get number of columns
  Py_END_ALLOW_THREADS

  if (!(record = PyTuple_New(cols))) {
    PyErr_SetString(_mssql_error,"Could not create record tuple");
    return NULL;
  }

  for (x=1; x<=cols; x++) {	// do all columns
    Py_BEGIN_ALLOW_THREADS
    data = GET_DATA(dbproc,rowinfo,x);    // get pointer to column's data
    coltype = GET_TYPE(dbproc,rowinfo,x);
    Py_END_ALLOW_THREADS

    if (data == NULL) {		// if NULL, use None
      Py_INCREF(Py_None);
      PyTuple_SET_ITEM(record,x-1,Py_None);
      continue;
    }

    switch (coltype) {
      case SQLBIT:
        o = PyBool_FromLong((long) *(DBBIT *) data);
        break;
      case SQLINT1:
        o = PyInt_FromLong((long) *(DBTINYINT *) data);
        break;
      case SQLINT2:
        o = PyInt_FromLong((long) *(DBSMALLINT *) data);
        break;
      case SQLINT4:
        o = PyInt_FromLong((long) *(DBINT *) data);
        break;
      case SQLMONEY: case SQLMONEY4:
      case SQLNUMERIC: case SQLDECIMAL:
        /* Go through the textual form.  Decimal(str) keeps every digit
           whatever the decimal context says, so the thread's context is
           left untouched (it used to be overwritten with prec 0 or 4). */
        len = dbconvert(dbproc, coltype, data, -1, SQLCHAR,
                        (LPBYTE) buf, NUMERIC_BUF_SZ - 1);  // keep room for NUL
        if (len < 0 || len >= NUMERIC_BUF_SZ) {
          PyErr_SetString(_mssql_error,"Could not convert numeric value");
          goto error;
        }
        buf[len] = 0; // null terminate the string
        if (!rmv_lcl(buf,buf,NUMERIC_BUF_SZ)) {
          PyErr_SetString(_mssql_error,"Could not remove locale formatting");
          goto error;
        }
        o = PyObject_CallMethod(decmod, "Decimal", "s", buf);
        break;
      case SQLFLT4:
        dbconvert(dbproc,coltype,data,-1, SQLFLT8,(LPBYTE)&ddata,-1);
        o = PyFloat_FromDouble(ddata);
        break;
      case SQLFLT8:
        // copy instead of *(double *)data: don't assume double alignment
        memcpy(&ddata, data, sizeof ddata);
        o = PyFloat_FromDouble(ddata);
        break;
      case SQLDATETIM4:
        dbconvert(dbproc,coltype,data,-1, SQLDATETIME,(LPBYTE)&dt,-1);
        data = (LPBYTE) &dt;  // smalldatetime converted to full datetime
        /* fall through */
      case SQLDATETIME:
        dbdatecrack(dbproc, &di, (DBDATETIME*)data);

        // see README.freetds for info about date problem with FreeTDS
        o = PyDateTime_FromDateAndTime(
              di.year,di.month,di.day,di.hour,
              di.minute,di.second,di.millisecond*1000);
        break;
      //case SQLBINARY: case SQLVARBINARY:
      //case SQLIMAGE: case SQLTEXT:
      default:	// return as is (binary string)
        o = Py_BuildValue("s#",data,GET_LEN(dbproc,rowinfo,x));
        break;
    } // end switch

    if (o == NULL)
      goto error;		// exception set by the failing constructor
    PyTuple_SET_ITEM(record,x-1,o);
  } // end for

  return record;             // done

error:
  Py_DECREF(record);
  return NULL;
}

/* rmv_lcl() -- strip off all locale formatting

   Copies the number in s into buf, keeping only digits and '+'/'-' signs.
   The LAST '.' or ',' (taken as the decimal separator) becomes '.'; every
   other character (thousands separators, blanks, currency symbols...) is
   dropped.  Scientific formats are not supported.
   buf can be == s (it can fix numbers in-place); no static state is used.
   buflen is the size of buf in bytes, including the terminating NUL; the
   conversion succeeds whenever the stripped result plus NUL fits.
   return codes: 0 - conversion failed (buf too small or buf or s is null)
                 1 - conversion succeeded

   Idea by Mark Pettit.
 */

int rmv_lcl(char *s, char *buf, size_t buflen)
{
  char c, *lastsep = NULL, *p = s, *b = buf;
  size_t kept = 0;   // number of characters the stripped result will hold

  if (b == (char *)NULL) return 0;

  if (s == (char *)NULL) {
    if (buflen > 0) *b = 0;
    return 0;
  }

  /* find last separator and count the characters that will be kept */
  while ((c = *p)) {
    if ((c == '.') || (c == ','))
      lastsep = p;
    else if (((c >= '0') && (c <= '9')) || (c == '-') || (c == '+'))
      ++kept;
    ++p;
  }
  if (lastsep != NULL) ++kept;   // the decimal separator survives as '.'

  /* the result and its terminating NUL must both fit in buf */
  if (kept >= buflen) return 0;

  /* copy the number skipping all but last separator and all other chars */
  p = s;
  while ((c = *p)) {
    if (((c >= '0') && (c <= '9')) || (c == '-') || (c == '+'))
      *b++ = c;
    else if (p == lastsep)
      *b++ = '.';
    ++p;
  }

  *b = 0;
  return 1;
}

/* clear error condition so we can start accumulating error messages again:
   empties the message buffer and zeroes the running severity of the given
   connection, or of the module-level state when self is NULL. */

void clr_err(_mssql_ConnectionObj *self) {
  if (self == NULL) {
    _mssql_error_str[0] = '\0';
    _mssql_severity = 0;
    return;
  }

  self->mssql_error_str[0] = '\0';
  self->mssql_severity = 0;
}

/* check whether accumulated severity is equal to or higher than
   min_error_severity, and if so, sets exception and returns true;
   else returns false (no need to raise exception).
   self may be NULL (module-level state); only a real connection has
   pending work to cancel.  If _mssql.min_error_severity is missing or not
   an integer, that error is left set and true is returned. */

int maybe_raise(_mssql_ConnectionObj *self) {
  PyObject *o;
  long lo;
  char *errptr;

  o = PyObject_GetAttrString(_mssql_module, "min_error_severity");
  if (o == NULL)
    goto cancel;                // attribute deleted: AttributeError is set
  lo = PyInt_AsLong(o);
  Py_DECREF(o);
  if (lo == -1 && PyErr_Occurred())
    goto cancel;                // not usable as an integer: TypeError is set

  if (MSSQL_SEVERITY(self) < lo)
    return 0;

  // severe enough to raise an error
  errptr = MSSQL_ERROR_STR(self);

  PyErr_SetString(_mssql_error, (*errptr ? errptr : "Unknown error"));

cancel:
  // cancel the instruction (module-level calls have no DBPROCESS)
  if (self != NULL) {
    Py_BEGIN_ALLOW_THREADS
    dbcancel(self->dbproc);
    Py_END_ALLOW_THREADS
  }

  return 1;
}
