Skip to content

Commit

Permalink
fix: reuse session (#6196)
Browse files Browse the repository at this point in the history
Co-authored-by: Jina Dev Bot <[email protected]>
  • Loading branch information
JoanFM and jina-bot authored Sep 15, 2024
1 parent a62b85e commit fbdde03
Show file tree
Hide file tree
Showing 14 changed files with 190 additions and 102 deletions.
3 changes: 3 additions & 0 deletions jina/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def Client(
prefetch: Optional[int] = 1000,
protocol: Optional[Union[str, List[str]]] = 'GRPC',
proxy: Optional[bool] = False,
reuse_session: Optional[bool] = False,
suppress_root_logging: Optional[bool] = False,
tls: Optional[bool] = False,
traces_exporter_host: Optional[str] = None,
Expand Down Expand Up @@ -59,6 +60,7 @@ def Client(
Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
:param protocol: Communication protocol between server and client.
:param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
:param reuse_session: True if HTTPClient should reuse its ClientSession across requests. If true, the user is responsible for closing it.
:param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
:param tls: If set, connect to gateway using tls encryption
:param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
Expand Down Expand Up @@ -113,6 +115,7 @@ def Client(args: Optional['argparse.Namespace'] = None, **kwargs) -> Union[
Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
:param protocol: Communication protocol between server and client.
:param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
:param reuse_session: True if HTTPClient should reuse its ClientSession across requests. If true, the user is responsible for closing it.
:param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
:param tls: If set, connect to gateway using tls encryption
:param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
Expand Down
29 changes: 18 additions & 11 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class BaseClient(InstrumentationMixin, ABC):
"""

def __init__(
self,
args: Optional['argparse.Namespace'] = None,
**kwargs,
self,
args: Optional['argparse.Namespace'] = None,
**kwargs,
):
if args and isinstance(args, argparse.Namespace):
self.args = args
Expand Down Expand Up @@ -63,6 +63,12 @@ def __init__(
)
send_telemetry_event(event='start', obj_cls_name=self.__class__.__name__)

async def close(self):
    """Close the resources held by the Client.

    This is a coroutine and must be awaited. In this base implementation the
    only cleanup performed is tearing down the instrumentation.

    :return: whatever :meth:`teardown_instrumentation` returns
    """
    return self.teardown_instrumentation()

def teardown_instrumentation(self):
"""Shut down the OpenTelemetry tracer and meter if available. This ensures that the daemon threads for
exporting metrics data are properly cleaned up.
Expand Down Expand Up @@ -118,7 +124,7 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None:
raise BadClientInput from ex

def _get_requests(
self, **kwargs
self, **kwargs
) -> Union[Iterator['Request'], AsyncIterator['Request']]:
"""
Get request in generator.
Expand Down Expand Up @@ -177,13 +183,14 @@ def inputs(self, bytes_gen: 'InputType') -> None:

@abc.abstractmethod
async def _get_results(
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
**kwargs,
): ...
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
**kwargs,
):
...

@abc.abstractmethod
def _is_flow_ready(self, **kwargs) -> bool:
Expand Down
31 changes: 16 additions & 15 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class AioHttpClientlet(ABC):

def __init__(
self,
url: str,
logger: 'JinaLogger',
max_attempts: int = 1,
initial_backoff: float = 0.5,
Expand All @@ -59,7 +58,6 @@ def __init__(
) -> None:
"""HTTP Client to be used with the streamer
:param url: url to send http/websocket request to
:param logger: jina logger
:param max_attempts: Number of sending attempts, including the original request.
:param initial_backoff: The first retry will happen with a delay of random(0, initial_backoff)
Expand All @@ -68,7 +66,6 @@ def __init__(
:param tracer_provider: Optional tracer_provider that will be used to configure aiohttp tracing.
:param kwargs: kwargs which will be forwarded to the `aiohttp.Session` instance. Used to pass headers to requests
"""
self.url = url
self.logger = logger
self.msg_recv = 0
self.msg_sent = 0
Expand Down Expand Up @@ -131,7 +128,6 @@ async def start(self):
"""
with ImportExtensions(required=True):
import aiohttp

self.session = aiohttp.ClientSession(
**self._session_kwargs, trace_configs=self._trace_config
)
Expand All @@ -154,9 +150,10 @@ class HTTPClientlet(AioHttpClientlet):

UPDATE_EVENT_PREFIX = 14 # the update event has the following format: "event: update: {document_json}"

async def send_message(self, request: 'Request'):
async def send_message(self, url, request: 'Request'):
"""Sends a POST request to the server
:param url: the URL where to send the message
:param request: request as dict
:return: send post message
"""
Expand All @@ -166,23 +163,24 @@ async def send_message(self, request: 'Request'):
req_dict['target_executor'] = req_dict['header']['target_executor']
for attempt in range(1, self.max_attempts + 1):
try:
request_kwargs = {'url': self.url}
request_kwargs = {'url': url}
if not docarray_v2:
request_kwargs['json'] = req_dict
else:
from docarray.base_doc.io.json import orjson_dumps

request_kwargs['data'] = JinaJsonPayload(value=req_dict)

async with self.session.post(**request_kwargs) as response:
try:
r_str = await response.json()
except aiohttp.ContentTypeError:
r_str = await response.text()
r_status = response.status
handle_response_status(response.status, r_str, self.url)
return r_status, r_str
handle_response_status(r_status, r_str, url)
return r_status, r_str
except (ValueError, ConnectionError, BadClient, aiohttp.ClientError, aiohttp.ClientConnectionError) as err:
self.logger.debug(f'Got an error: {err} sending POST to {self.url} in attempt {attempt}/{self.max_attempts}')
self.logger.debug(f'Got an error: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}')
await retry.wait_or_raise_err(
attempt=attempt,
err=err,
Expand All @@ -193,19 +191,20 @@ async def send_message(self, request: 'Request'):
)
except Exception as exc:
self.logger.debug(
f'Got a non-retried error: {exc} sending POST to {self.url}')
f'Got a non-retried error: {exc} sending POST to {url}')
raise exc

async def send_streaming_message(self, doc: 'Document', on: str):
async def send_streaming_message(self, url, doc: 'Document', on: str):
"""Sends a GET SSE request to the server
:param url: the URL where to send the message
:param doc: Request Document
:param on: Request endpoint
:yields: responses
"""
req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict()
request_kwargs = {
'url': self.url,
'url': url,
'headers': {'Accept': 'text/event-stream'},
'json': req_dict,
}
Expand All @@ -219,13 +218,14 @@ async def send_streaming_message(self, doc: 'Document', on: str):
elif event.startswith(b'end'):
pass

async def send_dry_run(self, **kwargs):
async def send_dry_run(self, url, **kwargs):
"""Query the dry_run endpoint from Gateway
:param url: the URL where to send the message
:param kwargs: keyword arguments to make sure compatible API with other clients
:return: send get message
"""
return await self.session.get(
url=self.url, timeout=kwargs.get('timeout', None)
url=url, timeout=kwargs.get('timeout', None)
).__aenter__()

async def recv_message(self):
Expand Down Expand Up @@ -267,8 +267,9 @@ async def __anext__(self):
class WebsocketClientlet(AioHttpClientlet):
"""Websocket Client to be used with the streamer"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, url, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.url = url
self.websocket = None
self.response_iter = None

Expand Down
Loading

0 comments on commit fbdde03

Please sign in to comment.