diff --git a/langserve/client.py b/langserve/client.py index b8c62aa2..df8c0b19 100644 --- a/langserve/client.py +++ b/langserve/client.py @@ -6,6 +6,7 @@ from typing import ( Any, AsyncIterator, + Dict, Iterator, List, Optional, @@ -15,6 +16,7 @@ from urllib.parse import urljoin import httpx +from httpx._types import AuthTypes, CertTypes, CookieTypes, HeaderTypes, VerifyTypes from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch from langchain.load.dump import dumpd from langchain.schema.runnable import Runnable @@ -110,16 +112,48 @@ def __init__( url: str, *, timeout: Optional[float] = None, + auth: Optional[AuthTypes] = None, + headers: Optional[HeaderTypes] = None, + cookies: Optional[CookieTypes] = None, + verify: VerifyTypes = True, + cert: Optional[CertTypes] = None, + client_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the client. Args: url: The url of the server timeout: The timeout for requests + auth: Authentication class for requests + headers: Headers to send with requests + cookies: Cookies to send with requests + verify: Whether to verify SSL certificates + cert: SSL certificate to use for requests + client_kwargs: If provided will be unpacked as kwargs to both the sync + and async httpx clients """ + _client_kwargs = client_kwargs or {} self.url = url - self.sync_client = httpx.Client(base_url=url, timeout=timeout) - self.async_client = httpx.AsyncClient(base_url=url, timeout=timeout) + self.sync_client = httpx.Client( + base_url=url, + timeout=timeout, + auth=auth, + headers=headers, + cookies=cookies, + verify=verify, + cert=cert, + **_client_kwargs, + ) + self.async_client = httpx.AsyncClient( + base_url=url, + timeout=timeout, + auth=auth, + headers=headers, + cookies=cookies, + verify=verify, + cert=cert, + **_client_kwargs, + ) # Register cleanup handler once RemoteRunnable is garbage collected weakref.finalize(self, _close_clients, self.sync_client, self.async_client)