diff --git a/src/pysignalr/client.py b/src/pysignalr/client.py index af986b1..a911c99 100644 --- a/src/pysignalr/client.py +++ b/src/pysignalr/client.py @@ -1,6 +1,7 @@ from __future__ import annotations import uuid +import ssl from collections import defaultdict from contextlib import asynccontextmanager from typing import Any, AsyncIterator, Awaitable, Callable @@ -100,11 +101,13 @@ def __init__( connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT, max_size: int | None = DEFAULT_MAX_SIZE, access_token_factory: Callable[[], str] | None = None, + ssl: ssl.SSLContext | None = None, ) -> None: self._url = url self._protocol = protocol or JSONProtocol() self._headers = headers or {} self._access_token_factory = access_token_factory + self._ssl = ssl self._message_handlers: defaultdict[str, list[MessageCallback | None]] = defaultdict(list) self._stream_handlers: dict[ @@ -121,6 +124,7 @@ def __init__( connection_timeout=connection_timeout, max_size=max_size, access_token_factory=access_token_factory, + ssl=ssl, ) self._error_callback: CompletionMessageCallback | None = None diff --git a/src/pysignalr/transport/websocket.py b/src/pysignalr/transport/websocket.py index 2428ae4..e5e4443 100644 --- a/src/pysignalr/transport/websocket.py +++ b/src/pysignalr/transport/websocket.py @@ -2,6 +2,7 @@ import asyncio import logging +import ssl from contextlib import suppress from http import HTTPStatus from typing import Awaitable, Callable @@ -54,6 +55,7 @@ def __init__( connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT, max_size: int | None = DEFAULT_MAX_SIZE, access_token_factory: Callable[[], str] | None = None, + ssl: ssl.SSLContext | None = None, ): """ Initializes the WebSocket transport with the provided parameters. @@ -79,6 +81,7 @@ def __init__( self._connection_timeout = connection_timeout self._max_size = max_size self._access_token_factory = access_token_factory + self._ssl = ssl self._state = ConnectionState.disconnected self._connected = asyncio.Event() @@ -144,14 +147,27 @@ async def _loop(self) -> None: except ServerConnectionError as e: raise NegotiationTimeout from e - connection_loop = connect( - self._url, - extra_headers=self._headers, - ping_interval=self._ping_interval, - open_timeout=self._connection_timeout, - max_size=self._max_size, - logger=_logger, - ) + # Since websockets interprets the presence of the ssl option as something different than providing None, + # the call needs to be made with or without ssl option to work properly + if self._ssl is None: + connection_loop = connect( + self._url, + extra_headers=self._headers, + ping_interval=self._ping_interval, + open_timeout=self._connection_timeout, + max_size=self._max_size, + logger=_logger, + ) + else: + connection_loop = connect( + self._url, + extra_headers=self._headers, + ping_interval=self._ping_interval, + open_timeout=self._connection_timeout, + max_size=self._max_size, + logger=_logger, + ssl=self._ssl, + ) async for conn in connection_loop: try: