@@ -216,18 +216,25 @@ def _verify_ssl_from_first(hosts):
216
216
217
217
218
218
class AsyncpgClient :
219
- def __init__ (self , hosts , pool_size = 25 ):
219
+ def __init__ (self , hosts , pool_size = 25 , session_settings = None ):
220
220
self .dsn = _to_dsn (hosts )
221
221
self .pool_size = pool_size
222
222
self ._pool = None
223
223
self .is_cratedb = True
224
+ self .session_settings = session_settings or {}
224
225
225
226
async def _get_pool (self ):
227
+
228
+ async def set_session_settings (conn ):
229
+ for setting , value in self .session_settings .items ():
230
+ await conn .execute (f'set { setting } ={ value } ' )
231
+
226
232
if not self ._pool :
227
233
self ._pool = await asyncpg .create_pool (
228
234
self .dsn ,
229
235
min_size = self .pool_size ,
230
- max_size = self .pool_size
236
+ max_size = self .pool_size ,
237
+ setup = set_session_settings
231
238
)
232
239
return self ._pool
233
240
@@ -308,7 +315,7 @@ def _append_sql(host):
308
315
309
316
310
317
class HttpClient :
311
- def __init__ (self , hosts , conn_pool_limit = 25 ):
318
+ def __init__ (self , hosts , conn_pool_limit = 25 , session_settings = None ):
312
319
self .hosts = hosts
313
320
self .urls = itertools .cycle (list (map (_append_sql , hosts )))
314
321
self ._connector_params = {
@@ -317,13 +324,21 @@ def __init__(self, hosts, conn_pool_limit=25):
317
324
}
318
325
self .__session = None
319
326
self .is_cratedb = True
327
+ self .session_settings = session_settings or {}
320
328
321
329
@property
322
330
async def _session (self ):
323
331
session = self .__session
324
332
if session is None :
325
333
conn = aiohttp .TCPConnector (** self ._connector_params )
326
334
self .__session = session = aiohttp .ClientSession (connector = conn )
335
+ for setting , value in self .session_settings .items ():
336
+ payload = {'stmt' : f'set { setting } ={ value } ' }
337
+ await _exec (
338
+ session ,
339
+ next (self .urls ),
340
+ dumps (payload , cls = CrateJsonEncoder )
341
+ )
327
342
return session
328
343
329
344
async def execute (self , stmt , args = None ):
@@ -372,10 +387,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
372
387
self .close ()
373
388
374
389
375
- def client (hosts , concurrency = 25 ):
390
+ def client (hosts , session_settings = None , concurrency = 25 ):
376
391
hosts = hosts or 'localhost:4200'
377
392
if hosts .startswith ('asyncpg://' ):
378
393
if not asyncpg :
379
394
raise ValueError ('Cannot use "asyncpg" scheme if asyncpg is not available' )
380
- return AsyncpgClient (hosts , pool_size = concurrency )
381
- return HttpClient (_to_http_hosts (hosts ), conn_pool_limit = concurrency )
395
+ return AsyncpgClient (hosts , pool_size = concurrency , session_settings = session_settings )
396
+ return HttpClient (_to_http_hosts (hosts ), conn_pool_limit = concurrency , session_settings = session_settings )
0 commit comments