From baeee52b17dcdc0221aea06a1dca167f3ea46327 Mon Sep 17 00:00:00 2001 From: laodouya Date: Thu, 4 May 2023 01:54:55 +0800 Subject: [PATCH 01/10] Add db api 2.0: connections --- dbapi/__init__.py | 84 ++++++++++ dbapi/connections.py | 208 ++++++++++++++++++++++++ dbapi/constants/FIELD_TYPE.py | 32 ++++ dbapi/constants/__init__.py | 0 dbapi/converters.py | 292 ++++++++++++++++++++++++++++++++++ dbapi/err.py | 61 +++++++ dbapi/times.py | 21 +++ 7 files changed, 698 insertions(+) create mode 100644 dbapi/__init__.py create mode 100644 dbapi/connections.py create mode 100644 dbapi/constants/FIELD_TYPE.py create mode 100644 dbapi/constants/__init__.py create mode 100644 dbapi/converters.py create mode 100644 dbapi/err.py create mode 100644 dbapi/times.py diff --git a/dbapi/__init__.py b/dbapi/__init__.py new file mode 100644 index 00000000000..4c15da98bd3 --- /dev/null +++ b/dbapi/__init__.py @@ -0,0 +1,84 @@ +from .converters import escape_dict, escape_sequence, escape_string +from .constants import FIELD_TYPE +from .err import ( + Warning, Error, InterfaceError, DataError, + DatabaseError, OperationalError, IntegrityError, InternalError, + NotSupportedError, ProgrammingError) +from . import connections as _orig_conn + +VERSION = (0, 1, 0, None) +if VERSION[3] is not None: + VERSION_STRING = "%d.%d.%d_%s" % VERSION +else: + VERSION_STRING = "%d.%d.%d" % VERSION[:3] + +threadsafety = 1 +apilevel = "2.0" +paramstyle = "format" + + +class DBAPISet(frozenset): + + def __ne__(self, other): + if isinstance(other, set): + return frozenset.__ne__(self, other) + else: + return other not in self + + def __eq__(self, other): + if isinstance(other, frozenset): + return frozenset.__eq__(self, other) + else: + return other in self + + def __hash__(self): + return frozenset.__hash__(self) + + +# TODO it's in pep249 find out meaning and usage of this +# https://www.python.org/dev/peps/pep-0249/#string +STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, + FIELD_TYPE.VAR_STRING]) +BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, + FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB]) +NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT, + FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG, + FIELD_TYPE.TINY, FIELD_TYPE.YEAR]) +DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE]) +TIME = DBAPISet([FIELD_TYPE.TIME]) +TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME]) +DATETIME = TIMESTAMP +ROWID = DBAPISet() + + +def Binary(x): + """Return x as a binary type.""" + return bytes(x) + + +def Connect(*args, **kwargs): + """ + Connect to the database; see connections.Connection.__init__() for + more information. + """ + from .connections import Connection + return Connection(*args, **kwargs) + + +if _orig_conn.Connection.__init__.__doc__ is not None: + Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ +del _orig_conn + + +def get_client_info(): # for MySQLdb compatibility + version = VERSION + if VERSION[3] is None: + version = VERSION[:3] + return '.'.join(map(str, version)) + + +connect = Connection = Connect + +NULL = "NULL" + +__version__ = get_client_info() diff --git a/dbapi/connections.py b/dbapi/connections.py new file mode 100644 index 00000000000..93d366c0b0e --- /dev/null +++ b/dbapi/connections.py @@ -0,0 +1,208 @@ +import chdb +import json +from . import err +from .cursors import Cursor +from . import converters + +DEBUG = False +VERBOSE = False + + +class Connection(object): + """ + Representation of a connection with chdb. + + The proper way to get an instance of this class is to call + connect(). + + Accepts several arguments: + + :param cursorclass: Custom cursor class to use. + + See `Connection `_ in the + specification. + """ + + _closed = False + + def __init__(self, cursorclass=Cursor): + + self._resp = None + + # 1. pre-process params in init + self.encoding = 'utf8' + + self.cursorclass = cursorclass + + self._result = None + self._affected_rows = 0 + + self.connect() + + def connect(self): + self._closed = False + self._execute_command("select 1;") + self._read_query_result() + + def close(self): + """ + Send the quit message and close the socket. + + See `Connection.close() `_ + in the specification. + + :raise Error: If the connection is already closed. + """ + if self._closed: + raise err.Error("Already closed") + self._closed = True + + @property + def open(self): + """Return True if the connection is open""" + return not self._closed + + def commit(self): + """ + Commit changes to stable storage. + + See `Connection.commit() `_ + in the specification. + """ + return + + def rollback(self): + """ + Roll back the current transaction. + + See `Connection.rollback() `_ + in the specification. + """ + return + + def cursor(self, cursor=None): + """ + Create a new cursor to execute queries with. + + :param cursor: The type of cursor to create; current only :py:class:`Cursor` + None means use Cursor. + """ + if cursor: + return cursor(self) + return self.cursorclass(self) + + # The following methods are INTERNAL USE ONLY (called from Cursor) + def query(self, sql): + if isinstance(sql, str): + sql = sql.encode(self.encoding, 'surrogateescape') + self._execute_command(sql) + self._affected_rows = self._read_query_result() + return self._affected_rows + + def _execute_command(self, sql): + """ + :raise InterfaceError: If the connection is closed. + :raise ValueError: If no username was specified. + """ + if self._closed: + raise err.InterfaceError("Connection closed") + + if isinstance(sql, str): + sql = sql.encode(self.encoding) + + if isinstance(sql, bytearray): + sql = bytes(sql) + + # drop last command return + if self._resp is not None: + self._resp = None + + if DEBUG: + print("DEBUG: query:", sql) + try: + self._resp = chdb.query(sql, output_format="JSON").data() + except Exception as error: + raise err.InterfaceError("query err: %s" % error) + + def escape(self, obj, mapping=None): + """Escape whatever value you pass to it. + + Non-standard, for internal use; do not use this in your applications. + """ + if isinstance(obj, str): + return "'" + self.escape_string(obj) + "'" + if isinstance(obj, (bytes, bytearray)): + ret = self._quote_bytes(obj) + return ret + return converters.escape_item(obj, mapping=mapping) + + def escape_string(self, s): + return converters.escape_string(s) + + def _quote_bytes(self, s): + return converters.escape_bytes(s) + + def _read_query_result(self): + self._result = None + result = CHDBResult(self) + result.read() + self._result = result + return result.affected_rows + + def __enter__(self): + """Context manager that returns a Cursor""" + return self.cursor() + + def __exit__(self, exc, value, traceback): + """On successful exit, commit. On exception, rollback""" + if exc: + self.rollback() + else: + self.commit() + + @property + def resp(self): + return self._resp + + +class CHDBResult(object): + def __init__(self, connection): + """ + :type connection: Connection + """ + self.connection = connection + self.affected_rows = 0 + self.insert_id = None + self.warning_count = 0 + self.message = None + self.field_count = 0 + self.description = None + self.rows = None + self.has_next = None + + def read(self): + try: + data = json.loads(self.connection.resp) + except Exception as error: + raise err.InterfaceError("Unsupported response format:" % error) + + try: + self.field_count = len(data["meta"]) + description = [] + for i in range(data["meta"]): + fields = [] + fields.append(data['meta'][i]["name"]) + fields.append(data['meta'][i]["type"]) + description.append(tuple(fields)) + self.description = tuple(description) + + rows = [] + for line in data["data"]: + row = [] + for i in range(self.field_count): + column_data = converters.convert_column_data(self.description[i][1], line[i]) + row.append(column_data) + rows.append(tuple(row)) + self.rows = tuple(rows) + except Exception as error: + raise err.InterfaceError("Read return data err:" % error) diff --git a/dbapi/constants/FIELD_TYPE.py b/dbapi/constants/FIELD_TYPE.py new file mode 100644 index 00000000000..2bc7713424a --- /dev/null +++ b/dbapi/constants/FIELD_TYPE.py @@ -0,0 +1,32 @@ +DECIMAL = 0 +TINY = 1 +SHORT = 2 +LONG = 3 +FLOAT = 4 +DOUBLE = 5 +NULL = 6 +TIMESTAMP = 7 +LONGLONG = 8 +INT24 = 9 +DATE = 10 +TIME = 11 +DATETIME = 12 +YEAR = 13 +NEWDATE = 14 +VARCHAR = 15 +BIT = 16 +JSON = 245 +NEWDECIMAL = 246 +ENUM = 247 +SET = 248 +TINY_BLOB = 249 +MEDIUM_BLOB = 250 +LONG_BLOB = 251 +BLOB = 252 +VAR_STRING = 253 +STRING = 254 +GEOMETRY = 255 + +CHAR = TINY +INTERVAL = ENUM + diff --git a/dbapi/constants/__init__.py b/dbapi/constants/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dbapi/converters.py b/dbapi/converters.py new file mode 100644 index 00000000000..17a210f219b --- /dev/null +++ b/dbapi/converters.py @@ -0,0 +1,292 @@ +import datetime +from decimal import Decimal +from .err import DataError +import re +import time +import arrow + + +def escape_item(val, mapping=None): + if mapping is None: + mapping = encoders + encoder = mapping.get(type(val)) + + # Fallback to default when no encoder found + if not encoder: + try: + encoder = mapping[str] + except KeyError: + raise TypeError("no default type converter defined") + + val = encoder(val, mapping) + return val + + +def escape_dict(val, mapping=None): + n = {} + for k, v in val.items(): + quoted = escape_item(v, mapping) + n[k] = quoted + return n + + +def escape_sequence(val, mapping=None): + n = [] + for item in val: + quoted = escape_item(item, mapping) + n.append(quoted) + return "(" + ",".join(n) + ")" + + +def escape_set(val, mapping=None): + return ','.join([escape_item(x, mapping) for x in val]) + + +def escape_bool(value, mapping=None): + return str(int(value)) + + +def escape_object(value, mapping=None): + return str(value) + + +def escape_int(value, mapping=None): + return str(value) + + +def escape_float(value, mapping=None): + return '%.15g' % value + + +_escape_table = [chr(x) for x in range(128)] +_escape_table[ord("'")] = u"''" + + +def _escape_unicode(value, mapping=None): + """escapes *value* with adding single quote. + + Value should be unicode + """ + return value.translate(_escape_table) + + +escape_string = _escape_unicode + +# On Python ~3.5, str.decode('ascii', 'surrogateescape') is slow. +# (fixed in Python 3.6, http://bugs.python.org/issue24870) +# Workaround is str.decode('latin1') then translate 0x80-0xff into 0udc80-0udcff. +# We can escape special chars and surrogateescape at once. +_escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)] + + +def escape_bytes(value, mapping=None): + return "'%s'" % value.decode('latin1').translate(_escape_bytes_table) + + +def escape_unicode(value, mapping=None): + return u"'%s'" % _escape_unicode(value) + + +def escape_str(value, mapping=None): + return "'%s'" % escape_string(str(value), mapping) + + +def escape_None(value, mapping=None): + return 'NULL' + + +def escape_timedelta(obj, mapping=None): + seconds = int(obj.seconds) % 60 + minutes = int(obj.seconds // 60) % 60 + hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24 + if obj.microseconds: + fmt = "'{0:02d}:{1:02d}:{2:02d}.{3:06d}'" + else: + fmt = "'{0:02d}:{1:02d}:{2:02d}'" + return fmt.format(hours, minutes, seconds, obj.microseconds) + + +def escape_time(obj, mapping=None): + return "'{}'".format(obj.isoformat(timespec='microseconds')) + + +def escape_datetime(obj, mapping=None): + return "'{}'".format(obj.isoformat(sep=' ', timespec='microseconds')) + # if obj.microsecond: + # fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'" + # else: + # fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'" + # return fmt.format(obj) + + +def escape_date(obj, mapping=None): + return "'{}'".format(obj.isoformat()) + + +def escape_struct_time(obj, mapping=None): + return escape_datetime(datetime.datetime(*obj[:6])) + + +def _convert_second_fraction(s): + if not s: + return 0 + # Pad zeros to ensure the fraction length in microseconds + s = s.ljust(6, '0') + return int(s[:6]) + + +def convert_datetime(obj): + """Returns a DATETIME or TIMESTAMP column value as a datetime object: + + >>> datetime_or_None('2007-02-25 23:06:20') + datetime.datetime(2007, 2, 25, 23, 6, 20) + >>> datetime_or_None('2007-02-25T23:06:20') + datetime.datetime(2007, 2, 25, 23, 6, 20) + + Illegal values are raise DataError + + """ + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode('ascii') + + try: + return arrow.get(obj).datetime + except Exception as err: + raise DataError("Not valid datetime struct: %s" % err) + + +TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") + + +def convert_timedelta(obj): + """Returns a TIME column as a timedelta object: + + >>> timedelta_or_None('25:06:17') + datetime.timedelta(1, 3977) + >>> timedelta_or_None('-25:06:17') + datetime.timedelta(-2, 83177) + + Illegal values are returned as None: + + >>> timedelta_or_None('random crap') is None + True + + Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but + can accept values as (+|-)DD HH:MM:SS. The latter format will not + be parsed correctly by this function. + """ + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode('ascii') + + m = TIMEDELTA_RE.match(obj) + if not m: + return obj + + try: + groups = list(m.groups()) + groups[-1] = _convert_second_fraction(groups[-1]) + negate = -1 if groups[0] else 1 + hours, minutes, seconds, microseconds = groups[1:] + + tdelta = datetime.timedelta( + hours=int(hours), + minutes=int(minutes), + seconds=int(seconds), + microseconds=int(microseconds) + ) * negate + return tdelta + except ValueError as err: + raise DataError("Not valid time or timedelta struct: %s" % err) + + +def convert_time(obj): + """Returns a TIME column as a time object: + + >>> time_or_None('15:06:17') + datetime.time(15, 6, 17) + + Illegal values are returned DataError: + + """ + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode('ascii') + + try: + return arrow.get("1970-01-01T" + obj).time() + except Exception: + return convert_timedelta(obj) + + +def convert_date(obj): + """Returns a DATE column as a date object: + + >>> date_or_None('2007-02-26') + datetime.date(2007, 2, 26) + + Illegal values are returned as None: + + >>> date_or_None('2007-02-31') is None + True + >>> date_or_None('0000-00-00') is None + True + + """ + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode('ascii') + try: + return arrow.get(obj).date() + except Exception as err: + raise DataError("Not valid date struct: %s" % err) + + +def convert_set(s): + if isinstance(s, (bytes, bytearray)): + return set(s.split(b",")) + return set(s.split(",")) + + +def convert_characters(connection, data): + if connection.use_unicode: + data = data.decode("utf8") + return data + + +def convert_column_data(column_type, column_data): + data = column_data + + # Null + if data is None: + return data + + if not isinstance(column_type, str): + return data + + column_type = column_type.lower().strip() + if column_type == 'time': + data = convert_time(column_data) + elif column_type == 'date': + data = convert_date(column_data) + elif column_type == 'datetime': + data = convert_datetime(column_data) + + return data + + +encoders = { + bool: escape_bool, + int: escape_int, + float: escape_float, + str: escape_unicode, + tuple: escape_sequence, + list: escape_sequence, + set: escape_sequence, + frozenset: escape_sequence, + dict: escape_dict, + type(None): escape_None, + datetime.date: escape_date, + datetime.datetime: escape_datetime, + datetime.timedelta: escape_timedelta, + datetime.time: escape_time, + time.struct_time: escape_struct_time, + Decimal: escape_object, +} diff --git a/dbapi/err.py b/dbapi/err.py new file mode 100644 index 00000000000..df97a15e108 --- /dev/null +++ b/dbapi/err.py @@ -0,0 +1,61 @@ +class StandardError(Exception): + """Exception related to operation with chdb.""" + + +class Warning(StandardError): + """Exception raised for important warnings like data truncations + while inserting, etc.""" + + +class Error(StandardError): + """Exception that is the base class of all other error exceptions + (not Warning).""" + + +class InterfaceError(Error): + """Exception raised for errors that are related to the database + interface rather than the database itself.""" + + +class DatabaseError(Error): + """Exception raised for errors that are related to the + database.""" + + +class DataError(DatabaseError): + """Exception raised for errors that are due to problems with the + processed data like division by zero, numeric value out of range, + etc.""" + + +class OperationalError(DatabaseError): + """Exception raised for errors that are related to the database's + operation and not necessarily under the control of the programmer, + e.g. an unexpected disconnect occurs, the data source name is not + found, a transaction could not be processed, a memory allocation + error occurred during processing, etc.""" + + +class IntegrityError(DatabaseError): + """Exception raised when the relational integrity of the database + is affected, e.g. a foreign key check fails, duplicate key, + etc.""" + + +class InternalError(DatabaseError): + """Exception raised when the database encounters an internal + error, e.g. the cursor is not valid anymore, the transaction is + out of sync, etc.""" + + +class ProgrammingError(DatabaseError): + """Exception raised for programming errors, e.g. table not found + or already exists, syntax error in the SQL statement, wrong number + of parameters specified, etc.""" + + +class NotSupportedError(DatabaseError): + """Exception raised in case a method or database API was used + which is not supported by the database, e.g. requesting a + .rollback() on a connection that does not support transaction or + has transactions turned off.""" diff --git a/dbapi/times.py b/dbapi/times.py new file mode 100644 index 00000000000..9afa599677a --- /dev/null +++ b/dbapi/times.py @@ -0,0 +1,21 @@ +from time import localtime +from datetime import date, datetime, time, timedelta + + +Date = date +Time = time +TimeDelta = timedelta +Timestamp = datetime + + +def DateFromTicks(ticks): + return date(*localtime(ticks)[:3]) + + +def TimeFromTicks(ticks): + return time(*localtime(ticks)[3:6]) + + +def TimestampFromTicks(ticks): + return datetime(*localtime(ticks)[:6]) + From 4762357b126508beb2fa89596b1ef45fb10e72f4 Mon Sep 17 00:00:00 2001 From: laodouya Date: Mon, 8 May 2023 11:42:57 +0800 Subject: [PATCH 02/10] Add cursors for db api --- dbapi/cursors.py | 301 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 301 insertions(+) create mode 100644 dbapi/cursors.py diff --git a/dbapi/cursors.py b/dbapi/cursors.py new file mode 100644 index 00000000000..d5467b922f9 --- /dev/null +++ b/dbapi/cursors.py @@ -0,0 +1,301 @@ +from . import err +from functools import partial +import re + +# Regular expression for :meth:`Cursor.executemany`. +# executemany only supports simple bulk insert. +# You can use it to load large dataset. +RE_INSERT_VALUES = re.compile( + r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" + + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", + re.IGNORECASE | re.DOTALL) + + +class Cursor(object): + """ + This is the object you use to interact with the database. + + Do not create an instance of a Cursor yourself. Call + connections.Connection.cursor(). + + See `Cursor `_ in + the specification. + """ + + #: Max statement size which :meth:`executemany` generates. + #: + #: Default value is 1024000. + max_stmt_length = 1024000 + + def __init__(self, connection): + self.connection = connection + self.description = None + self.rowcount = -1 + self.rownumber = 0 + self.arraysize = 1 + self.lastrowid = None + self._result = None + self._rows = None + self._executed = None + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + del exc_info + self.close() + + def __iter__(self): + return iter(self.fetchone, None) + + def callproc(self, procname, args=()): + """Execute stored procedure procname with args + + procname -- string, name of procedure to execute on server + + args -- Sequence of parameters to use with procedure + + Returns the original args. + + Compatibility warning: PEP-249 specifies that any modified + parameters must be returned. This is currently impossible + as they are only available by storing them in a server + variable and then retrieved by a query. Since stored + procedures return zero or more result sets, there is no + reliable way to get at OUT or INOUT parameters via callproc. + The server variables are named @_procname_n, where procname + is the parameter above and n is the position of the parameter + (from zero). Once all result sets generated by the procedure + have been fetched, you can issue a SELECT @_procname_0, ... + query using .execute() to get any OUT or INOUT values. + + Compatibility warning: The act of calling a stored procedure + itself creates an empty result set. This appears after any + result sets generated by the procedure. This is non-standard + behavior with respect to the DB-API. Be sure to use nextset() + to advance through all result sets; otherwise you may get + disconnected. + """ + + return args + + def close(self): + """ + Closing a cursor just exhausts all remaining data. + """ + conn = self.connection + if conn is None: + return + try: + while self.nextset(): + pass + finally: + self.connection = None + + def _get_db(self): + if not self.connection: + raise err.ProgrammingError("Cursor closed") + return self.connection + + def _escape_args(self, args, conn): + if isinstance(args, (tuple, list)): + return tuple(conn.escape(arg) for arg in args) + elif isinstance(args, dict): + return {key: conn.escape(val) for (key, val) in args.items()} + else: + # If it's not a dictionary let's try escaping it anyway. + # Worst case it will throw a Value error + return conn.escape(args) + + def mogrify(self, query, args=None): + """ + Returns the exact string that is sent to the database by calling the + execute() method. + + This method follows the extension to the DB API 2.0 followed by Psycopg. + """ + conn = self._get_db() + + if args is not None: + query = query % self._escape_args(args, conn) + + return query + + def _clear_result(self): + self.rownumber = 0 + self._result = None + + self.rowcount = 0 + self.description = None + self.lastrowid = None + self._rows = None + + def _do_get_result(self): + conn = self._get_db() + + self._result = result = conn._result + + self.rowcount = result.affected_rows + self.description = result.description + self.lastrowid = result.insert_id + self._rows = result.rows + + def _query(self, q): + conn = self._get_db() + self._last_executed = q + self._clear_result() + conn.query(q) + self._do_get_result() + return self.rowcount + + def execute(self, query, args=None): + """Execute a query + + :param str query: Query to execute. + + :param args: parameters used with query. (optional) + :type args: tuple, list or dict + + :return: Number of affected rows + :rtype: int + + If args is a list or tuple, %s can be used as a placeholder in the query. + If args is a dict, %(name)s can be used as a placeholder in the query. + """ + while self.nextset(): + pass + + query = self.mogrify(query, args) + + result = self._query(query) + self._executed = query + return result + + def executemany(self, query, args): + # type: (str, list) -> int + """Run several data against one query + + :param query: query to execute on server + :param args: Sequence of sequences or mappings. It is used as parameter. + :return: Number of rows affected, if any. + + This method improves performance on multiple-row INSERT and + REPLACE. Otherwise, it is equivalent to looping over args with + execute(). + """ + if not args: + return 0 + + m = RE_INSERT_VALUES.match(query) + if m: + q_prefix = m.group(1) % () + q_values = m.group(2).rstrip() + q_postfix = m.group(3) or '' + assert q_values[0] == '(' and q_values[-1] == ')' + return self._do_execute_many(q_prefix, q_values, q_postfix, args, + self.max_stmt_length, + self._get_db().encoding) + + self.rowcount = sum(self.execute(query, arg) for arg in args) + return self.rowcount + + def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding): + conn = self._get_db() + escape = self._escape_args + if isinstance(prefix, str): + prefix = prefix.encode(encoding) + if isinstance(postfix, str): + postfix = postfix.encode(encoding) + sql = str(prefix) + args = iter(args) + v = values % escape(next(args), conn) + if isinstance(v, str): + v = v.encode(encoding, 'surrogateescape') + sql += v + rows = 0 + for arg in args: + v = values % escape(arg, conn) + if isinstance(v, str): + v = v.encode(encoding, 'surrogateescape') + if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: + rows += self.execute(sql + postfix) + sql = str(prefix) + else: + sql += ',' + sql += v + rows += self.execute(sql + postfix) + self.rowcount = rows + return rows + + def _check_executed(self): + if not self._executed: + raise err.ProgrammingError("execute() first") + + def fetchone(self): + """Fetch the next row""" + self._check_executed() + if self._rows is None or self.rownumber >= len(self._rows): + return None + result = self._rows[self.rownumber] + self.rownumber += 1 + return result + + def fetchmany(self, size=None): + """Fetch several rows""" + self._check_executed() + if self._rows is None: + return () + end = self.rownumber + (size or self.arraysize) + result = self._rows[self.rownumber:end] + self.rownumber = min(end, len(self._rows)) + return result + + def fetchall(self): + """Fetch all the rows""" + self._check_executed() + if self._rows is None: + return () + if self.rownumber: + result = self._rows[self.rownumber:] + else: + result = self._rows + self.rownumber = len(self._rows) + return result + + def nextset(self): + """Get the next query set""" + # Not support for now + return None + + def setinputsizes(self, *args): + """Does nothing, required by DB API.""" + + def setoutputsizes(self, *args): + """Does nothing, required by DB API.""" + + +class DictCursor(Cursor): + """A cursor which returns results as a dictionary""" + # You can override this to use OrderedDict or other dict-like types. + dict_type = dict + + def _do_get_result(self): + super(self)._do_get_result() + fields = [] + if self.description: + for f in self._result.fields: + name = f.name + if name in fields: + name = f.table_name + '.' + name + fields.append(name) + self._fields = fields + + if fields and self._rows: + self._rows = [self._conv_row(r) for r in self._rows] + + def _conv_row(self, row): + if row is None: + return None + return self.dict_type(zip(self._fields, row)) + From 7c7bc606a790338a6ab9fa6d2ae28cec671d392d Mon Sep 17 00:00:00 2001 From: laodouya Date: Mon, 8 May 2023 21:00:22 +0800 Subject: [PATCH 03/10] Fix db driver column decode --- dbapi/connections.py | 8 +++----- dbapi/cursors.py | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/dbapi/connections.py b/dbapi/connections.py index 93d366c0b0e..8315862a63b 100644 --- a/dbapi/connections.py +++ b/dbapi/connections.py @@ -189,10 +189,8 @@ def read(self): try: self.field_count = len(data["meta"]) description = [] - for i in range(data["meta"]): - fields = [] - fields.append(data['meta'][i]["name"]) - fields.append(data['meta'][i]["type"]) + for meta in data["meta"]: + fields = [meta["name"], meta["type"]] description.append(tuple(fields)) self.description = tuple(description) @@ -200,7 +198,7 @@ def read(self): for line in data["data"]: row = [] for i in range(self.field_count): - column_data = converters.convert_column_data(self.description[i][1], line[i]) + column_data = converters.convert_column_data(self.description[i][1], line[self.description[i][0]]) row.append(column_data) rows.append(tuple(row)) self.rows = tuple(rows) diff --git a/dbapi/cursors.py b/dbapi/cursors.py index d5467b922f9..72819f00656 100644 --- a/dbapi/cursors.py +++ b/dbapi/cursors.py @@ -281,13 +281,11 @@ class DictCursor(Cursor): dict_type = dict def _do_get_result(self): - super(self)._do_get_result() + super()._do_get_result() fields = [] if self.description: - for f in self._result.fields: - name = f.name - if name in fields: - name = f.table_name + '.' + name + for f in self.description: + name = f[0] fields.append(name) self._fields = fields From cf51fa4a027464bbb8172092c672061c931bab4f Mon Sep 17 00:00:00 2001 From: laodouya Date: Thu, 11 May 2023 21:08:08 +0800 Subject: [PATCH 04/10] Add arm64 for darwin clang --- dbapi/connections.py | 2 +- setup.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/dbapi/connections.py b/dbapi/connections.py index 8315862a63b..5fca50dc209 100644 --- a/dbapi/connections.py +++ b/dbapi/connections.py @@ -1,4 +1,3 @@ -import chdb import json from . import err from .cursors import Cursor @@ -120,6 +119,7 @@ def _execute_command(self, sql): if DEBUG: print("DEBUG: query:", sql) try: + import chdb self._resp = chdb.query(sql, output_format="JSON").data() except Exception as error: raise err.InterfaceError("query err: %s" % error) diff --git a/setup.py b/setup.py index 5a0e57e66e0..7b1b9b3f3f7 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import os +import platform import sys import re import subprocess @@ -72,7 +73,7 @@ def fix_version_init(version): f.seek(0) f.write(init_content) f.truncate() - + # As of Python 3.6, CCompiler has a `has_flag` method. # cf http://bugs.python.org/issue26689 @@ -117,12 +118,15 @@ def build_extensions(self): print("CC: " + os.environ.get('CC')) print("CXX: " + os.environ.get('CXX')) if sys.platform == 'darwin': - if os.system('which /usr/local/opt/llvm/bin/clang++ > /dev/null') == 0: - os.environ['CC'] = '/usr/local/opt/llvm/bin/clang' - os.environ['CXX'] = '/usr/local/opt/llvm/bin/clang++' - elif os.system('which /usr/local/opt/llvm@15/bin/clang++ > /dev/null') == 0: - os.environ['CC'] = '/usr/local/opt/llvm@15/bin/clang' - os.environ['CXX'] = '/usr/local/opt/llvm@15/bin/clang++' + brew_prefix = '/usr/local/opt' + if platform.machine() == 'arm64': + brew_prefix = '/opt/homebrew/opt' + if os.system('which '+brew_prefix+'/llvm/bin/clang++ > /dev/null') == 0: + os.environ['CC'] = brew_prefix + '/llvm/bin/clang' + os.environ['CXX'] = brew_prefix + '/llvm/bin/clang++' + elif os.system('which '+brew_prefix+'/llvm@15/bin/clang++ > /dev/null') == 0: + os.environ['CC'] = brew_prefix + '/llvm@15/bin/clang' + os.environ['CXX'] = brew_prefix + '/llvm@15/bin/clang++' else: raise RuntimeError("Must use brew clang++") elif sys.platform == 'linux': From 8d881550ab90fb8453897aa7bbfe7ee90e1a7805 Mon Sep 17 00:00:00 2001 From: laodouya Date: Fri, 12 May 2023 16:06:36 +0800 Subject: [PATCH 05/10] Move dbapi location --- {dbapi => chdb/dbapi}/__init__.py | 0 {dbapi => chdb/dbapi}/connections.py | 0 {dbapi => chdb/dbapi}/constants/FIELD_TYPE.py | 0 {dbapi => chdb/dbapi}/constants/__init__.py | 0 {dbapi => chdb/dbapi}/converters.py | 0 {dbapi => chdb/dbapi}/cursors.py | 1 - {dbapi => chdb/dbapi}/err.py | 0 {dbapi => chdb/dbapi}/times.py | 0 8 files changed, 1 deletion(-) rename {dbapi => chdb/dbapi}/__init__.py (100%) rename {dbapi => chdb/dbapi}/connections.py (100%) rename {dbapi => chdb/dbapi}/constants/FIELD_TYPE.py (100%) rename {dbapi => chdb/dbapi}/constants/__init__.py (100%) rename {dbapi => chdb/dbapi}/converters.py (100%) rename {dbapi => chdb/dbapi}/cursors.py (99%) rename {dbapi => chdb/dbapi}/err.py (100%) rename {dbapi => chdb/dbapi}/times.py (100%) diff --git a/dbapi/__init__.py b/chdb/dbapi/__init__.py similarity index 100% rename from dbapi/__init__.py rename to chdb/dbapi/__init__.py diff --git a/dbapi/connections.py b/chdb/dbapi/connections.py similarity index 100% rename from dbapi/connections.py rename to chdb/dbapi/connections.py diff --git a/dbapi/constants/FIELD_TYPE.py b/chdb/dbapi/constants/FIELD_TYPE.py similarity index 100% rename from dbapi/constants/FIELD_TYPE.py rename to chdb/dbapi/constants/FIELD_TYPE.py diff --git a/dbapi/constants/__init__.py b/chdb/dbapi/constants/__init__.py similarity index 100% rename from dbapi/constants/__init__.py rename to chdb/dbapi/constants/__init__.py diff --git a/dbapi/converters.py b/chdb/dbapi/converters.py similarity index 100% rename from dbapi/converters.py rename to chdb/dbapi/converters.py diff --git a/dbapi/cursors.py b/chdb/dbapi/cursors.py similarity index 99% rename from dbapi/cursors.py rename to chdb/dbapi/cursors.py index 72819f00656..9fa762b30bc 100644 --- a/dbapi/cursors.py +++ b/chdb/dbapi/cursors.py @@ -1,5 +1,4 @@ from . import err -from functools import partial import re # Regular expression for :meth:`Cursor.executemany`. diff --git a/dbapi/err.py b/chdb/dbapi/err.py similarity index 100% rename from dbapi/err.py rename to chdb/dbapi/err.py diff --git a/dbapi/times.py b/chdb/dbapi/times.py similarity index 100% rename from dbapi/times.py rename to chdb/dbapi/times.py From 200819593a6d1fdfa0f6daa2890d9896b699c9da Mon Sep 17 00:00:00 2001 From: laodouya Date: Fri, 12 May 2023 16:27:37 +0800 Subject: [PATCH 06/10] Use brew --prefix for brew install location --- setup.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 7b1b9b3f3f7..20cfb9d0e09 100644 --- a/setup.py +++ b/setup.py @@ -118,9 +118,11 @@ def build_extensions(self): print("CC: " + os.environ.get('CC')) print("CXX: " + os.environ.get('CXX')) if sys.platform == 'darwin': - brew_prefix = '/usr/local/opt' - if platform.machine() == 'arm64': - brew_prefix = '/opt/homebrew/opt' + try: + import subprocess + brew_prefix = subprocess.check_output('brew --prefix', shell=True).decode("utf-8").strip("\n") + except Exception: + raise RuntimeError("Must install brew") if os.system('which '+brew_prefix+'/llvm/bin/clang++ > /dev/null') == 0: os.environ['CC'] = brew_prefix + '/llvm/bin/clang' os.environ['CXX'] = brew_prefix + '/llvm/bin/clang++' From a4e50fb0b10ba1273aa86bdc1ae7c873c1b2e493 Mon Sep 17 00:00:00 2001 From: laodouya Date: Fri, 12 May 2023 16:33:39 +0800 Subject: [PATCH 07/10] Add example for dbapi --- examples/dbapi.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 examples/dbapi.py diff --git a/examples/dbapi.py b/examples/dbapi.py new file mode 100644 index 00000000000..82baa6f6f37 --- /dev/null +++ b/examples/dbapi.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +from chdb import dbapi +from chdb.dbapi.cursors import DictCursor + +print("chdb driver version: {0}".format(dbapi.get_client_info())) + +conn1 = dbapi.connect() +cur1 = conn1.cursor() +cur1.execute('select version()') +print("description: ", cur1.description) +print("data: ", cur1.fetchone()) +cur1.close() +conn1.close() + +conn2 = dbapi.connect(cursorclass=DictCursor) +cur2 = conn2.cursor() +cur2.execute(''' +SELECT + town, + district, + count() AS c, + round(avg(price)) AS price +FROM url('https://datasets-documentation.s3.eu-west-3.amazonaws.com/house_parquet/house_0.parquet') +GROUP BY + town, + district +LIMIT 10 +''') +print("description", cur2.description) +for row in cur2: + print(row) + +cur2.close() +conn2.close() From b7dc44df01ae530829cb1b5e062c82eae35eb560 Mon Sep 17 00:00:00 2001 From: laodouya Date: Fri, 12 May 2023 16:42:31 +0800 Subject: [PATCH 08/10] Fix brew_prefix for llvm --- setup.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 20cfb9d0e09..cf736810489 100644 --- a/setup.py +++ b/setup.py @@ -123,12 +123,12 @@ def build_extensions(self): brew_prefix = subprocess.check_output('brew --prefix', shell=True).decode("utf-8").strip("\n") except Exception: raise RuntimeError("Must install brew") - if os.system('which '+brew_prefix+'/llvm/bin/clang++ > /dev/null') == 0: - os.environ['CC'] = brew_prefix + '/llvm/bin/clang' - os.environ['CXX'] = brew_prefix + '/llvm/bin/clang++' - elif os.system('which '+brew_prefix+'/llvm@15/bin/clang++ > /dev/null') == 0: - os.environ['CC'] = brew_prefix + '/llvm@15/bin/clang' - os.environ['CXX'] = brew_prefix + '/llvm@15/bin/clang++' + if os.system('which '+brew_prefix+'/opt/llvm/bin/clang++ > /dev/null') == 0: + os.environ['CC'] = brew_prefix + '/opt/llvm/bin/clang' + os.environ['CXX'] = brew_prefix + '/opt/llvm/bin/clang++' + elif os.system('which '+brew_prefix+'/opt/llvm@15/bin/clang++ > /dev/null') == 0: + os.environ['CC'] = brew_prefix + '/opt/llvm@15/bin/clang' + os.environ['CXX'] = brew_prefix + '/opt/llvm@15/bin/clang++' else: raise RuntimeError("Must use brew clang++") elif sys.platform == 'linux': From 56a6fa04d3869abe26290a1181cd41aad2aa29f6 Mon Sep 17 00:00:00 2001 From: laodouya Date: Fri, 12 May 2023 16:47:04 +0800 Subject: [PATCH 09/10] Remove unused import in setup.py --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index cf736810489..0b950c1f988 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ import os -import platform import sys import re import subprocess From 75353c66a55a5c0017d82e07b8707357966a3922 Mon Sep 17 00:00:00 2001 From: laodouya Date: Fri, 12 May 2023 17:05:59 +0800 Subject: [PATCH 10/10] Fix setup.py --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 0b950c1f988..147c012c64f 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,6 @@ def build_extensions(self): print("CXX: " + os.environ.get('CXX')) if sys.platform == 'darwin': try: - import subprocess brew_prefix = subprocess.check_output('brew --prefix', shell=True).decode("utf-8").strip("\n") except Exception: raise RuntimeError("Must install brew")