21
21
22
22
import logging
23
23
from datetime import datetime , date
24
+ from types import ModuleType
24
25
25
26
from sqlalchemy import types as sqltypes
26
27
from sqlalchemy .engine import default , reflection
@@ -205,6 +206,12 @@ def initialize(self, connection):
205
206
self .default_schema_name = \
206
207
self ._get_default_schema_name (connection )
207
208
209
+ def set_isolation_level (self , dbapi_connection , level ):
210
+ """
211
+ For CrateDB, this is implemented as a noop.
212
+ """
213
+ pass
214
+
208
215
def do_rollback (self , connection ):
209
216
# if any exception is raised by the dbapi, sqlalchemy by default
210
217
# attempts to do a rollback crate doesn't support rollbacks.
@@ -223,7 +230,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
223
230
use_ssl = asbool (kwargs .pop ("ssl" , False ))
224
231
if use_ssl :
225
232
servers = ["https://" + server for server in servers ]
226
- return self .dbapi .connect (servers = servers , ** kwargs )
233
+
234
+ is_module = isinstance (self .dbapi , ModuleType )
235
+ if is_module :
236
+ driver_name = self .dbapi .__name__
237
+ else :
238
+ driver_name = self .dbapi .__class__ .__name__
239
+ if driver_name == "crate.client" :
240
+ if "database" in kwargs :
241
+ del kwargs ["database" ]
242
+ return self .dbapi .connect (servers = servers , ** kwargs )
243
+ elif driver_name in ["psycopg" , "PsycopgAdaptDBAPI" , "AsyncAdapt_asyncpg_dbapi" ]:
244
+ return self .dbapi .connect (host = host , port = port , ** kwargs )
245
+ else :
246
+ raise ValueError (f"Unknown driver variant: { driver_name } " )
247
+
227
248
return self .dbapi .connect (** kwargs )
228
249
229
250
def _get_default_schema_name (self , connection ):
@@ -269,11 +290,11 @@ def get_schema_names(self, connection, **kw):
269
290
def get_table_names (self , connection , schema = None , ** kw ):
270
291
if schema is None :
271
292
schema = self ._get_effective_schema_name (connection )
272
- cursor = connection .exec_driver_sql (
293
+ cursor = connection .exec_driver_sql (self . _format_query (
273
294
"SELECT table_name FROM information_schema.tables "
274
295
"WHERE {0} = ? "
275
296
"AND table_type = 'BASE TABLE' "
276
- "ORDER BY table_name ASC, {0} ASC" .format (self .schema_column ),
297
+ "ORDER BY table_name ASC, {0} ASC" ) .format (self .schema_column ),
277
298
(schema or self .default_schema_name , )
278
299
)
279
300
return [row [0 ] for row in cursor .fetchall ()]
@@ -295,7 +316,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
295
316
"AND column_name !~ ?" \
296
317
.format (self .schema_column )
297
318
cursor = connection .exec_driver_sql (
298
- query ,
319
+ self . _format_query ( query ) ,
299
320
(table_name ,
300
321
schema or self .default_schema_name ,
301
322
r"(.*)\[\'(.*)\'\]" ) # regex to filter subscript
@@ -334,7 +355,7 @@ def result_fun(result):
334
355
return set (rows [0 ] if rows else [])
335
356
336
357
pk_result = engine .exec_driver_sql (
337
- query ,
358
+ self . _format_query ( query ) ,
338
359
(table_name , schema or self .default_schema_name )
339
360
)
340
361
pks = result_fun (pk_result )
@@ -375,6 +396,17 @@ def has_ilike_operator(self):
375
396
server_version_info = self .server_version_info
376
397
return server_version_info is not None and server_version_info >= (4 , 1 , 0 )
377
398
399
+ def _format_query (self , query ):
400
+ """
401
+ When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
402
+ the paramstyle is not `qmark`, but `pyformat`.
403
+
404
+ TODO: Review: Is it legit and sane? Are there alternatives?
405
+ """
406
+ if self .paramstyle == "pyformat" :
407
+ query = query .replace ("= ?" , "= %s" ).replace ("!~ ?" , "!~ %s" )
408
+ return query
409
+
378
410
379
411
class DateTrunc (functions .GenericFunction ):
380
412
name = "date_trunc"
0 commit comments