From 08ce745f98807482849f4174fe67eeb8ffdb5cd9 Mon Sep 17 00:00:00 2001 From: Mathias Fussenegger Date: Mon, 9 Sep 2024 10:27:59 +0200 Subject: [PATCH] fixup! clean up --- cr8/clients.py | 47 +++++++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/cr8/clients.py b/cr8/clients.py index 452f246..3b3c4b8 100644 --- a/cr8/clients.py +++ b/cr8/clients.py @@ -7,7 +7,7 @@ import time from urllib.parse import urlparse, parse_qs, urlunparse from datetime import datetime, date -from typing import List, Union, Iterable +from typing import List, Union, Iterable, Dict from decimal import Decimal from cr8.aio import asyncio # import via aio for uvloop setup @@ -321,42 +321,45 @@ def __init__(self, hosts, conn_pool_limit=25, session_settings=None): self.urls = itertools.cycle(list(map(_append_sql, hosts))) self.conn_pool_limit = conn_pool_limit self.is_cratedb = True - self._pool = [] + self._pools = {} self.session_settings = session_settings or {} - @property - async def _session(self): - if not self._pool: - self._connector_params = { + async def _session(self, url): + pool = self._pools.get(url) + if not pool: + pool = asyncio.Queue(maxsize=self.conn_pool_limit) + self._pools[url] = pool + _connector_params = { 'limit': 1, 'verify_ssl': _verify_ssl_from_first(self.hosts) } for n in range(0, self.conn_pool_limit): - tcp_connector = aiohttp.TCPConnector(**self._connector_params) + tcp_connector = aiohttp.TCPConnector(**_connector_params) session = aiohttp.ClientSession(connector=tcp_connector) for setting, value in self.session_settings.items(): payload = {'stmt': f'set {setting}={value}'} await _exec( session, - next(self.urls), + url, dumps(payload, cls=CrateJsonEncoder) ) - self._pool.append(session) + pool.put_nowait(session) - return self._pool.pop() + return await pool.get() async def execute(self, stmt, args=None): payload = {'stmt': _plain_or_callable(stmt)} if args: payload['args'] = _plain_or_callable(args) - session = await self._session + url = next(self.urls) + session = await self._session(url) result = await _exec( session, - next(self.urls), + url, dumps(payload, cls=CrateJsonEncoder) ) - self._pool.append(session) + await self._pools[url].put(session) return result async def execute_many(self, stmt, bulk_args): @@ -364,15 +367,16 @@ async def execute_many(self, stmt, bulk_args): stmt=_plain_or_callable(stmt), bulk_args=_plain_or_callable(bulk_args) ), cls=CrateJsonEncoder) - session = await self._session - result = await _exec(session, next(self.urls), data) - self._pool.append(session) + url = next(self.urls) + session = await self._session(url) + result = await _exec(session, url, data) + await self._pools[url].put(session) return result async def get_server_version(self): - session = await self._session urlparts = urlparse(self.hosts[0]) url = urlunparse((urlparts.scheme, urlparts.netloc, '/', '', '', '')) + session = await self._session(url) async with session.get(url) as resp: r = await resp.json() version = r['version'] @@ -381,12 +385,15 @@ async def get_server_version(self): 'number': version['number'], 'date': _date_or_none(version['build_timestamp'][:10]) } - self._pool.append(session) + await self._pools[url].put(session) return result async def _close(self): - for session in self._pool: - await session.close() + for url, pool in self._pools.items(): + while not pool.empty(): + session = await pool.get() + await session.close() + self._pools = {} def close(self): asyncio.get_event_loop().run_until_complete(self._close())