Skip to content

Commit

Permalink
Add option to supply external ssl context
Browse files Browse the repository at this point in the history
  • Loading branch information
Ola Lidholm committed Nov 8, 2024
1 parent b207c6e commit ba061c3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
4 changes: 4 additions & 0 deletions src/pysignalr/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import uuid
import ssl
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Awaitable, Callable
Expand Down Expand Up @@ -100,11 +101,13 @@ def __init__(
connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT,
max_size: int | None = DEFAULT_MAX_SIZE,
access_token_factory: Callable[[], str] | None = None,
ssl: ssl.SSLContext | None = None,
) -> None:
self._url = url
self._protocol = protocol or JSONProtocol()
self._headers = headers or {}
self._access_token_factory = access_token_factory
self._ssl = ssl

self._message_handlers: defaultdict[str, list[MessageCallback | None]] = defaultdict(list)
self._stream_handlers: dict[
Expand All @@ -121,6 +124,7 @@ def __init__(
connection_timeout=connection_timeout,
max_size=max_size,
access_token_factory=access_token_factory,
ssl=ssl,
)
self._error_callback: CompletionMessageCallback | None = None

Expand Down
32 changes: 24 additions & 8 deletions src/pysignalr/transport/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import ssl
from contextlib import suppress
from http import HTTPStatus
from typing import Awaitable, Callable
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT,
max_size: int | None = DEFAULT_MAX_SIZE,
access_token_factory: Callable[[], str] | None = None,
ssl: ssl.SSLContext | None = None,
):
"""
Initializes the WebSocket transport with the provided parameters.
Expand All @@ -79,6 +81,7 @@ def __init__(
self._connection_timeout = connection_timeout
self._max_size = max_size
self._access_token_factory = access_token_factory
self._ssl = ssl

self._state = ConnectionState.disconnected
self._connected = asyncio.Event()
Expand Down Expand Up @@ -144,14 +147,27 @@ async def _loop(self) -> None:
except ServerConnectionError as e:
raise NegotiationTimeout from e

connection_loop = connect(
self._url,
extra_headers=self._headers,
ping_interval=self._ping_interval,
open_timeout=self._connection_timeout,
max_size=self._max_size,
logger=_logger,
)
# Since websockets interprets the presence of the ssl option as something different than providing None,
# the call needs to be made with or without ssl option to work properly
if self._ssl is None:
connection_loop = connect(
self._url,
extra_headers=self._headers,
ping_interval=self._ping_interval,
open_timeout=self._connection_timeout,
max_size=self._max_size,
logger=_logger,
)
else:
connection_loop = connect(
self._url,
extra_headers=self._headers,
ping_interval=self._ping_interval,
open_timeout=self._connection_timeout,
max_size=self._max_size,
logger=_logger,
ssl=self._ssl,
)

async for conn in connection_loop:
try:
Expand Down

0 comments on commit ba061c3

Please sign in to comment.