From 7944a53f694f0fbc2f88baebe674042dce91566f Mon Sep 17 00:00:00 2001 From: Natt Piyapramote Date: Mon, 3 Jul 2017 16:43:02 +0700 Subject: [PATCH] workaround issue #94 --- sql_server/pyodbc/base.py | 64 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/sql_server/pyodbc/base.py b/sql_server/pyodbc/base.py index d4d2692a..514869c7 100644 --- a/sql_server/pyodbc/base.py +++ b/sql_server/pyodbc/base.py @@ -479,6 +479,10 @@ def enable_constraint_checking(self): if not self.needs_rollback: self.check_constraints() +import datetime +from decimal import Decimal +from uuid import UUID + class CursorWrapper(object): """ @@ -493,6 +497,36 @@ def __init__(self, cursor, connection): self.last_sql = '' self.last_params = () + def _pytype_to_sqltype(self, typ, value): + if value is None: + return 'INT' + elif isinstance(value, str): + length = len(value) + if length == 0: + return 'NVARCHAR' + return 'NVARCHAR(%s)' % len(value) + elif typ == int: + if value < 0x7FFFFFFF and value > -0x7FFFFFFF: + return 'INT' + else: + return 'BIGINT' + elif typ == float: + return 'FLOAT' + elif typ == bool: + return 'BIT' + elif isinstance(value, Decimal): + return 'NUMERIC' + elif isinstance(value, datetime.date): + return 'DATE' + elif isinstance(value, datetime.time): + return 'TIME' + elif isinstance(value, datetime.datetime): + return 'TIMESTAMP' + elif isinstance(value, UUID): + return 'uniqueidentifier' + else: + raise NotImplementedError('not support type %s (%s)' % (type(value), repr(value))) + def close(self): if self.active: self.active = False @@ -536,8 +570,38 @@ def format_params(self, params): return tuple(fp) + def _fix_for_params(self, query, params, unify_by_values=False): + if params is None: + params = [] + query = query + elif unify_by_values and len(params) > 0: + # Same workaround from django/db/backends/oracle/base.py + + # Handle params as a dict with unified query parameters by their + # values. It can be used only in single query execute() because + # executemany() shares the formatted query with each of the params + # list. e.g. for input params = [0.75, 2, 0.75, 'sth', 0.75] + # params_dict = {0.75: ':arg0', 2: ':arg1', 'sth': ':arg2'} + # args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0'] + # params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'} + params = [(param, type(param)) for param in params] + params_dict = {param: '@arg%d' % i for i, param in enumerate(set(params))} + args = [params_dict[param] for param in params] + + variables = [] + params = [] + for key, value in params_dict.items(): + datatype = self._pytype_to_sqltype(key[1], key[0]) + variables.append("%s %s = %%s " % (value, datatype)) + params.append(key[0]) + query = ('DECLARE %s \n' % ','.join(variables)) + (query % tuple(args)) + params = tuple(params) + return query, params + def execute(self, sql, params=None): self.last_sql = sql + if 'GROUP BY' in sql: + sql, params = self._fix_for_params(sql, params, unify_by_values=True) sql = self.format_sql(sql, params) params = self.format_params(params) self.last_params = params