Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ client = PolymarketUS(
> **Note**: WebSocket connections are async-only due to their event-driven nature.
> Use `asyncio.run()` when working with the sync client, or use `AsyncPolymarketUS` directly.

> **Reconnection**: connections automatically reconnect with exponential backoff
> on unexpected drops, re-sign the auth handshake, and replay every active
> subscription. A `reconnect` event fires after a successful reconnect. Reconnect
> stops on fatal auth failures (401/403/429). Disable with `auto_reconnect=False`.
> Note that `order`, `position`, and `trade` streams do not replay history on
> reconnect; resubscribe to `SUBSCRIPTION_TYPE_ORDER_SNAPSHOT` if you need current
> open orders, while market data and account balance snapshots are sent automatically.

```python
import asyncio
import os
Expand Down Expand Up @@ -297,6 +305,7 @@ WebSocket methods (`connect()`, `subscribe()`, `close()`) are async and must be
- `account_balance_snapshot` - Initial balance
- `account_balance_update` - Balance changes
- `heartbeat` - Connection keepalive
- `reconnect` - Reconnected and resubscribed after a drop
- `error` - Error events
- `close` - Connection closed

Expand All @@ -305,6 +314,7 @@ WebSocket methods (`connect()`, `subscribe()`, `close()`) are async and must be
- `market_data_lite` - Lightweight price data
- `trade` - Trade notifications
- `heartbeat` - Connection keepalive
- `reconnect` - Reconnected and resubscribed after a drop
- `error` - Error events
- `close` - Connection closed

Expand Down
183 changes: 156 additions & 27 deletions polymarket_us/websocket/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Base WebSocket class."""
"""Base WebSocket class with automatic reconnect and resubscribe."""

from __future__ import annotations

import asyncio
import contextlib
import json
import random
from collections.abc import Callable
from typing import Any

Expand All @@ -16,9 +17,49 @@

from .types import MarketSubscriptionType, PrivateSubscriptionType

# WebSocket upgrade failures with these statuses are fatal (bad credentials or
# rate limiting) and must not trigger reconnect attempts.
_FATAL_AUTH_STATUSES = frozenset({401, 403, 429})

_RECONNECT_INITIAL_SECONDS = 0.5
_RECONNECT_MAX_SECONDS = 30.0


def _reconnect_delay(attempt: int) -> float:
"""Exponential reconnect backoff with equal jitter (attempt is 0-indexed)."""
capped = min(_RECONNECT_INITIAL_SECONDS * (2**attempt), _RECONNECT_MAX_SECONDS)
return capped / 2 + random.random() * (capped / 2)


def _upgrade_status(exc: Exception) -> int | None:
"""Extract the HTTP status from a failed WebSocket upgrade, if available.

Handles both the modern (``exc.response.status_code``) and legacy
(``exc.status_code``) ``websockets`` exception shapes.
"""
response = getattr(exc, "response", None)
if response is not None:
status = getattr(response, "status_code", None)
if isinstance(status, int):
return status
status = getattr(exc, "status_code", None)
return status if isinstance(status, int) else None


class _Subscription:
"""A subscription the client should replay after a reconnect."""

def __init__(
self,
subscription_type: PrivateSubscriptionType | MarketSubscriptionType,
market_slugs: list[str] | None,
) -> None:
self.subscription_type = subscription_type
self.market_slugs = market_slugs


class BaseWebSocket:
"""Base WebSocket class with event emitter pattern."""
"""Base WebSocket class with an event emitter and resilient connection."""

def __init__(
self,
Expand All @@ -27,6 +68,8 @@ def __init__(
secret_key: str,
base_url: str = "wss://api.polymarket.us",
path: str,
auto_reconnect: bool = True,
reconnect_max_attempts: int | None = None,
) -> None:
"""Initialize WebSocket.

Expand All @@ -35,40 +78,117 @@ def __init__(
secret_key: Base64-encoded Ed25519 secret key
base_url: WebSocket base URL
path: WebSocket endpoint path
auto_reconnect: Reconnect and replay subscriptions on unexpected drops
reconnect_max_attempts: Max reconnect attempts per drop (None = unlimited)
"""
self.key_id = key_id
self.secret_key = secret_key
self.base_url = base_url
self.path = path
self.auto_reconnect = auto_reconnect
self.reconnect_max_attempts = reconnect_max_attempts
self._ws: ClientConnection | None = None
self._listeners: dict[str, list[Callable[..., Any]]] = {}
self._once_listeners: dict[str, list[Callable[..., Any]]] = {}
self._message_task: asyncio.Task[None] | None = None
self._run_task: asyncio.Task[None] | None = None
self._subscriptions: dict[str, _Subscription] = {}
self._closed = False

async def connect(self) -> None:
"""Establish WebSocket connection."""
"""Establish the WebSocket connection and start processing messages."""
self._closed = False
await self._open_socket()
self._emit("open")
self._run_task = asyncio.create_task(self._run())

async def _open_socket(self) -> None:
"""Open a socket with a freshly signed auth handshake."""
url = f"{self.base_url}{self.path}"
# Re-sign on every (re)connect: the timestamp must be within the skew window.
headers = create_auth_headers(self.key_id, self.secret_key, "GET", self.path)

self._ws = await websockets.connect(url, additional_headers=headers)
self._emit("open")

# Start message handler
self._message_task = asyncio.create_task(self._message_loop())

async def _message_loop(self) -> None:
"""Process incoming messages."""
if not self._ws:
return
try:
async for message in self._ws:
if isinstance(message, bytes):
message = message.decode("utf-8")
self._handle_message(message)
except websockets.ConnectionClosed:
self._emit("close")
except Exception as e:
self._emit("error", PolymarketUSError(str(e)))
async def _run(self) -> None:
"""Read messages, reconnecting and resubscribing on unexpected drops."""
while True:
try:
if self._ws is None:
break
async for message in self._ws:
if isinstance(message, bytes):
message = message.decode("utf-8")
self._handle_message(message)
except websockets.ConnectionClosed:
pass
except Exception as e:
self._emit("error", PolymarketUSError(str(e)))

# The loop may have exited on a still-open socket (e.g. a handler
# error rather than a drop). Close it before reconnecting so the old
# connection isn't leaked when _open_socket overwrites self._ws.
with contextlib.suppress(Exception):
if self._ws is not None:
await self._ws.close(1000, "OK")

if self._closed or not self.auto_reconnect:
if not self._closed:
self._emit("close")
return

if not await self._reconnect():
if not self._closed:
self._emit("close")
return
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.

async def _reconnect(self) -> bool:
"""Reconnect with backoff and replay subscriptions. Returns success."""
attempt = 0
while not self._closed and (
self.reconnect_max_attempts is None or attempt < self.reconnect_max_attempts
):
await asyncio.sleep(_reconnect_delay(attempt))
if self._closed:
return False
try:
await self._open_socket()
except Exception as e:
status = _upgrade_status(e)
if status in _FATAL_AUTH_STATUSES:
self._emit("error", PolymarketUSError(f"WebSocket auth failed ({status})"))
return False
attempt += 1
continue
# The user may have called close() while the upgrade was in flight.
if self._closed:
return False
# If the fresh connection drops mid-replay, treat it as another
# failed attempt rather than letting the exception kill the task.
try:
await self._resubscribe()
except Exception:
# Close the just-opened socket before retrying so it isn't
# orphaned when the next attempt overwrites self._ws.
with contextlib.suppress(Exception):
if self._ws:
await self._ws.close(1000, "OK")
attempt += 1
continue
Comment thread
cursor[bot] marked this conversation as resolved.
self._emit("reconnect")
return True
return False

async def _resubscribe(self) -> None:
"""Replay all active subscriptions after a reconnect."""
for request_id, sub in list(self._subscriptions.items()):
request: dict[str, Any] = {
"subscribe": {
"requestId": request_id,
"subscriptionType": sub.subscription_type,
}
}
if sub.market_slugs:
request["subscribe"]["marketSlugs"] = sub.market_slugs
await self.send(request)

def _handle_message(self, data: str) -> None:
"""Handle incoming message (override in subclasses)."""
Expand All @@ -92,6 +212,9 @@ async def subscribe(
) -> None:
"""Subscribe to a data stream.

The subscription is recorded so it can be replayed automatically after a
reconnect.

Args:
request_id: Unique request ID
subscription_type: Type of subscription
Expand All @@ -106,24 +229,30 @@ async def subscribe(
if market_slugs:
request["subscribe"]["marketSlugs"] = market_slugs
await self.send(request)
self._subscriptions[request_id] = _Subscription(subscription_type, market_slugs)

async def unsubscribe(self, request_id: str) -> None:
"""Unsubscribe from a data stream.

Args:
request_id: Request ID of the subscription to cancel
"""
self._subscriptions.pop(request_id, None)
await self.send({"unsubscribe": {"requestId": request_id}})

async def close(self) -> None:
"""Close the WebSocket connection."""
if self._message_task:
self._message_task.cancel()
"""Close the WebSocket connection and stop reconnecting."""
self._closed = True
# Cancel first so an in-flight reconnect (sleeping or mid-handshake) is
# interrupted rather than left to open a socket close() never sees.
if self._run_task:
self._run_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._message_task
await self._run_task
self._run_task = None
if self._ws:
await self._ws.close(1000, "OK")
self._ws = None
self._ws = None
Comment thread
cursor[bot] marked this conversation as resolved.

@property
def is_connected(self) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion polymarket_us/websocket/markets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Markets WebSocket."""

import json
from typing import Any

from polymarket_us.errors import PolymarketUSError, WebSocketError

Expand All @@ -11,7 +12,7 @@
class MarketsWebSocket(BaseWebSocket):
"""WebSocket for market data (order book, trades)."""

def __init__(self, **kwargs: str) -> None:
def __init__(self, **kwargs: Any) -> None:
"""Initialize markets WebSocket."""
super().__init__(path="/v1/ws/markets", **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion polymarket_us/websocket/private.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Private WebSocket."""

import json
from typing import Any

from polymarket_us.errors import PolymarketUSError, WebSocketError

Expand All @@ -11,7 +12,7 @@
class PrivateWebSocket(BaseWebSocket):
"""WebSocket for private data (orders, positions, balances)."""

def __init__(self, **kwargs: str) -> None:
def __init__(self, **kwargs: Any) -> None:
"""Initialize private WebSocket."""
super().__init__(path="/v1/ws/private", **kwargs)

Expand Down
7 changes: 5 additions & 2 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,15 @@ def test_signature_is_base64(self) -> None:

def test_handles_64_byte_key(self) -> None:
"""Should handle 64-byte keys (uses first 32 bytes)."""
import base64

# 64-byte key (seed + public key), base64 encoded
secret_key_64 = "nWGxne/9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A=" * 2
seed = base64.b64decode("nWGxne/9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A=")
secret_key_64 = base64.b64encode(seed + seed).decode()
# Should not raise
headers = create_auth_headers(
key_id="test",
secret_key=secret_key_64[:88], # 64 bytes in base64
secret_key=secret_key_64,
method="GET",
path="/v1/test",
)
Expand Down
Loading
Loading