Skip to content

Commit

Permalink
Ensure HTTP session is always closed
Browse files Browse the repository at this point in the history
If a session was still in use when the http client was closed the
session remained open.
  • Loading branch information
mfussenegger committed Sep 9, 2024
1 parent f5852ce commit 1a6f0c6
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions cr8/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import calendar
import types
import time
import contextlib
from urllib.parse import urlparse, parse_qs, urlunparse
from datetime import datetime, date
from typing import List, Union, Iterable, Dict, Optional, Any
Expand Down Expand Up @@ -327,6 +328,7 @@ def __init__(self,
self._pools: Dict[str, asyncio.Queue] = {}
self.session_settings = session_settings or {}

@contextlib.asynccontextmanager
async def _session(self, url):
pool = self._pools.get(url)
if not pool:
Expand All @@ -348,20 +350,23 @@ async def _session(self, url):
)
pool.put_nowait(session)

return await pool.get()
session = await pool.get()
try:
yield session
finally:
await pool.put(session)

async def execute(self, stmt, args=None):
payload = {'stmt': _plain_or_callable(stmt)}
if args:
payload['args'] = _plain_or_callable(args)
url = next(self.urls)
session = await self._session(url)
result = await _exec(
session,
url,
dumps(payload, cls=CrateJsonEncoder)
)
await self._pools[url].put(session)
async with self._session(url) as session:
result = await _exec(
session,
url,
dumps(payload, cls=CrateJsonEncoder)
)
return result

async def execute_many(self, stmt, bulk_args):
Expand All @@ -370,32 +375,32 @@ async def execute_many(self, stmt, bulk_args):
bulk_args=_plain_or_callable(bulk_args)
), cls=CrateJsonEncoder)
url = next(self.urls)
session = await self._session(url)
result = await _exec(session, url, data)
await self._pools[url].put(session)
async with self._session(url) as session:
result = await _exec(session, url, data)
return result

async def get_server_version(self):
urlparts = urlparse(self.hosts[0])
url = urlunparse((urlparts.scheme, urlparts.netloc, '/', '', '', ''))
session = await self._session(url)
async with session.get(url) as resp:
r = await resp.json()
version = r['version']
result = {
'hash': version['build_hash'],
'number': version['number'],
'date': _date_or_none(version['build_timestamp'][:10])
}
await self._pools[url].put(session)
return result
async with self._session(url) as session:
async with session.get(url) as resp:
r = await resp.json()
version = r['version']
result = {
'hash': version['build_hash'],
'number': version['number'],
'date': _date_or_none(version['build_timestamp'][:10])
}
return result

async def _close(self):
for url, pool in self._pools.items():
while not pool.empty():
pools = self._pools
self._pools = {}
for url, pool in pools.items():
for i in range(0, self.conn_pool_limit):
session = await pool.get()
await session.close()
self._pools = {}
pools.clear()

def close(self):
asyncio.get_event_loop().run_until_complete(self._close())
Expand Down

0 comments on commit 1a6f0c6

Please sign in to comment.