Skip to content

Commit 2637486

Browse files
authored
Reconnect fix (#22)
1 parent 234e8ec commit 2637486

File tree

3 files changed

+35
-13
lines changed

3 files changed

+35
-13
lines changed

src/pysignalr/__init__.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212
from websockets.exceptions import InvalidStatusCode
1313

1414

15-
class NegotiationTimeout(Exception):
15+
class NegotiationFailure(Exception):
1616
"""
17-
Exception raised when the connection URL generated during negotiation is no longer valid.
17+
Exception raised when the connection fails.
1818
"""
1919
pass
2020

21-
2221
async def __aiter__(
2322
self: websockets.legacy.client.Connect,
2423
) -> AsyncIterator[websockets.legacy.client.WebSocketClientProtocol]:
@@ -43,12 +42,9 @@ async def __aiter__(
4342
async with self as protocol:
4443
yield protocol
4544

46-
# Handle expired connection URLs by raising a NegotiationTimeout exception.
47-
except InvalidStatusCode as e:
48-
if e.status_code == HTTPStatus.NOT_FOUND:
49-
raise NegotiationTimeout from e
50-
except asyncio.TimeoutError as e:
51-
raise NegotiationTimeout from e
45+
# Handle expired connection URLs by raising a NegotiationFailure exception.
46+
except (InvalidStatusCode, asyncio.TimeoutError) as e:
47+
raise NegotiationFailure from e
5248

5349
except Exception:
5450
# Add a random initial delay between 0 and 5 seconds.

src/pysignalr/client.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
DEFAULT_CONNECTION_TIMEOUT,
2828
DEFAULT_MAX_SIZE,
2929
DEFAULT_PING_INTERVAL,
30+
DEFAULT_RETRY_SLEEP,
31+
DEFAULT_RETRY_MULTIPLIER,
32+
DEFAULT_RETRY_COUNT,
3033
WebsocketTransport,
3134
)
3235

@@ -35,7 +38,6 @@
3538
MessageCallback = Callable[[Message], Awaitable[None]]
3639
CompletionMessageCallback = Callable[[CompletionMessage], Awaitable[None]]
3740

38-
3941
class ClientStream:
4042
"""
4143
Client to server streaming implementation.
@@ -100,6 +102,9 @@ def __init__(
100102
ping_interval: int = DEFAULT_PING_INTERVAL,
101103
connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT,
102104
max_size: int | None = DEFAULT_MAX_SIZE,
105+
retry_sleep: float = DEFAULT_RETRY_SLEEP,
106+
retry_multiplier: float = DEFAULT_RETRY_MULTIPLIER,
107+
retry_count: int = DEFAULT_RETRY_COUNT,
103108
access_token_factory: Callable[[], str] | None = None,
104109
ssl: ssl.SSLContext | None = None,
105110
) -> None:
@@ -121,6 +126,9 @@ def __init__(
121126
callback=self._on_message,
122127
headers=self._headers,
123128
ping_interval=ping_interval,
129+
retry_sleep=retry_sleep,
130+
retry_multiplier=retry_multiplier,
131+
retry_count=retry_count,
124132
connection_timeout=connection_timeout,
125133
max_size=max_size,
126134
access_token_factory=access_token_factory,

src/pysignalr/transport/websocket.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from websockets.protocol import State
1515

1616
import pysignalr.exceptions as exceptions
17-
from pysignalr import NegotiationTimeout
17+
from pysignalr import NegotiationFailure
1818
from pysignalr.messages import CompletionMessage, Message, PingMessage
1919
from pysignalr.protocol.abstract import Protocol
2020
from pysignalr.transport.abstract import ConnectionState, Transport
@@ -24,6 +24,10 @@
2424
DEFAULT_PING_INTERVAL = 10
2525
DEFAULT_CONNECTION_TIMEOUT = 10
2626

27+
DEFAULT_RETRY_SLEEP = 1
28+
DEFAULT_RETRY_MULTIPLIER = 1.1
29+
DEFAULT_RETRY_COUNT = 10
30+
2731
_logger = logging.getLogger('pysignalr.transport')
2832

2933

@@ -53,6 +57,9 @@ def __init__(
5357
skip_negotiation: bool = False,
5458
ping_interval: int = DEFAULT_PING_INTERVAL,
5559
connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT,
60+
retry_sleep: float = DEFAULT_RETRY_SLEEP,
61+
retry_multiplier: float = DEFAULT_RETRY_MULTIPLIER,
62+
retry_count: int = DEFAULT_RETRY_COUNT,
5663
max_size: int | None = DEFAULT_MAX_SIZE,
5764
access_token_factory: Callable[[], str] | None = None,
5865
ssl: ssl.SSLContext | None = None,
@@ -81,6 +88,9 @@ def __init__(
8188
self._connection_timeout = connection_timeout
8289
self._max_size = max_size
8390
self._access_token_factory = access_token_factory
91+
self._retry_sleep = retry_sleep
92+
self._retry_multiplier = retry_multiplier
93+
self._retry_count = retry_count
8494
self._ssl = ssl
8595

8696
self._state = ConnectionState.disconnected
@@ -121,9 +131,17 @@ async def run(self) -> None:
121131
Runs the WebSocket transport, managing the connection lifecycle.
122132
"""
123133
while True:
124-
with suppress(NegotiationTimeout):
134+
try:
125135
await self._loop()
126-
await self._set_state(ConnectionState.disconnected)
136+
except NegotiationFailure as e:
137+
await self._set_state(ConnectionState.disconnected)
138+
self._retry_count -= 1
139+
if self._retry_count <= 0:
140+
raise e
141+
self._retry_sleep *= self._retry_multiplier
142+
await asyncio.sleep(self._retry_sleep)
143+
else:
144+
await self._set_state(ConnectionState.disconnected)
127145

128146
async def send(self, message: Message) -> None:
129147
"""

0 commit comments

Comments
 (0)