diff --git a/cr8/clients.py b/cr8/clients.py index b301d38..08430d3 100644 --- a/cr8/clients.py +++ b/cr8/clients.py @@ -5,6 +5,7 @@ import calendar import types import time +import contextlib from urllib.parse import urlparse, parse_qs, urlunparse from datetime import datetime, date from typing import List, Union, Iterable, Dict, Optional, Any @@ -327,6 +328,7 @@ def __init__(self, self._pools: Dict[str, asyncio.Queue] = {} self.session_settings = session_settings or {} + @contextlib.asynccontextmanager async def _session(self, url): pool = self._pools.get(url) if not pool: @@ -348,20 +350,23 @@ async def _session(self, url): ) pool.put_nowait(session) - return await pool.get() + session = await pool.get() + try: + yield session + finally: + await pool.put(session) async def execute(self, stmt, args=None): payload = {'stmt': _plain_or_callable(stmt)} if args: payload['args'] = _plain_or_callable(args) url = next(self.urls) - session = await self._session(url) - result = await _exec( - session, - url, - dumps(payload, cls=CrateJsonEncoder) - ) - await self._pools[url].put(session) + async with self._session(url) as session: + result = await _exec( + session, + url, + dumps(payload, cls=CrateJsonEncoder) + ) return result async def execute_many(self, stmt, bulk_args): @@ -370,32 +375,32 @@ async def execute_many(self, stmt, bulk_args): bulk_args=_plain_or_callable(bulk_args) ), cls=CrateJsonEncoder) url = next(self.urls) - session = await self._session(url) - result = await _exec(session, url, data) - await self._pools[url].put(session) + async with self._session(url) as session: + result = await _exec(session, url, data) return result async def get_server_version(self): urlparts = urlparse(self.hosts[0]) url = urlunparse((urlparts.scheme, urlparts.netloc, '/', '', '', '')) - session = await self._session(url) - async with session.get(url) as resp: - r = await resp.json() - version = r['version'] - result = { - 'hash': version['build_hash'], - 'number': version['number'], - 'date': _date_or_none(version['build_timestamp'][:10]) - } - await self._pools[url].put(session) - return result + async with self._session(url) as session: + async with session.get(url) as resp: + r = await resp.json() + version = r['version'] + result = { + 'hash': version['build_hash'], + 'number': version['number'], + 'date': _date_or_none(version['build_timestamp'][:10]) + } + return result async def _close(self): - for url, pool in self._pools.items(): - while not pool.empty(): + pools = self._pools + self._pools = {} + for url, pool in pools.items(): + for i in range(0, self.conn_pool_limit): session = await pool.get() await session.close() - self._pools = {} + pools.clear() def close(self): asyncio.get_event_loop().run_until_complete(self._close())