diff --git a/getstream/video/rtc/connection_manager.py b/getstream/video/rtc/connection_manager.py index e781d917..25f59a8f 100644 --- a/getstream/video/rtc/connection_manager.py +++ b/getstream/video/rtc/connection_manager.py @@ -19,11 +19,13 @@ from getstream.video.rtc.connection_utils import ( ConnectionState, SfuConnectionError, + SfuJoinError, ConnectionOptions, connect_websocket, join_call, watch_call, ) +from getstream.video.rtc.coordinator.backoff import exp_backoff from getstream.video.rtc.track_util import ( fix_sdp_msid_semantic, fix_sdp_rtcp_fb, @@ -55,6 +57,7 @@ def __init__( user_id: Optional[str] = None, create: bool = True, subscription_config: Optional[SubscriptionConfig] = None, + max_join_retries: int = 3, **kwargs: Any, ): super().__init__() @@ -68,6 +71,9 @@ def __init__( self.session_id: str = str(uuid.uuid4()) self.join_response: Optional[JoinCallResponse] = None self.local_sfu: bool = False # Local SFU flag for development + if max_join_retries < 0: + raise ValueError("max_join_retries must be >= 0") + self._max_join_retries: int = max_join_retries # Private attributes self._connection_state: ConnectionState = ConnectionState.IDLE @@ -282,6 +288,7 @@ async def _connect_internal( ws_url: Optional[str] = None, token: Optional[str] = None, session_id: Optional[str] = None, + migrating_from_list: Optional[list] = None, ) -> None: """ Internal connection method that handles the core connection logic. @@ -318,12 +325,15 @@ async def _connect_internal( if not (ws_url or token): if self.user_id is None: raise ValueError("user_id is required for joining a call") + last_failed = migrating_from_list[-1] if migrating_from_list else None join_response = await join_call( self.call, self.user_id, "auto", self.create, self.local_sfu, + migrating_from=last_failed, + migrating_from_list=migrating_from_list, **self.kwargs, ) ws_url = join_response.data.credentials.server.ws_endpoint @@ -395,6 +405,8 @@ async def _connect_internal( logger.exception(f"No join response from WebSocket: {sfu_event}") logger.debug(f"WebSocket connected successfully to {ws_url}") + except SfuJoinError: + raise except Exception as e: logger.exception(f"Failed to connect WebSocket to {ws_url}: {e}") raise SfuConnectionError(f"WebSocket connection failed: {e}") from e @@ -427,7 +439,8 @@ async def connect(self): Connect to SFU. This method automatically handles retry logic for transient errors - like "server is full" and network issues. + like "server is full" by requesting a different SFU from the + coordinator. """ logger.info("Connecting to SFU") # Fire-and-forget the coordinator WS connection so we don't block here @@ -445,7 +458,54 @@ def _on_coordinator_task_done(task: asyncio.Task): logger.exception("Coordinator WS task failed") self._coordinator_task.add_done_callback(_on_coordinator_task_done) - await self._connect_internal() + + await self._connect_with_sfu_reassignment() + + async def _connect_with_sfu_reassignment(self) -> None: + """Try connecting to SFU, reassigning to a different one on failure.""" + failed_sfus: list[str] = [] + + # First attempt without delay + attempt = 0 + try: + await self._connect_internal() + return + except SfuJoinError as e: + self._handle_join_failure(e, attempt, failed_sfus) + if self._max_join_retries == 0: + raise + + # Retries with exponential backoff, requesting a different SFU + async for delay in exp_backoff(max_retries=self._max_join_retries, base=0.5): + attempt += 1 + logger.info(f"Retrying in {delay}s with different SFU...") + await asyncio.sleep(delay) + try: + await self._connect_internal( + migrating_from_list=failed_sfus if failed_sfus else None, + ) + return + except SfuJoinError as e: + self._handle_join_failure(e, attempt, failed_sfus) + if attempt >= self._max_join_retries: + raise + + def _handle_join_failure( + self, error: SfuJoinError, attempt: int, failed_sfus: list[str] + ) -> None: + """Track a failed SFU and clean up partial connection state.""" + if self.join_response and self.join_response.credentials: + edge = self.join_response.credentials.server.edge_name + if edge and edge not in failed_sfus: + failed_sfus.append(edge) + logger.warning( + f"SFU join failed (attempt {attempt + 1}/{1 + self._max_join_retries}, " + f"code={error.error_code}). Failed SFUs: {failed_sfus}" + ) + if self._ws_client: + self._ws_client.close() + self._ws_client = None + self.connection_state = ConnectionState.IDLE async def wait(self): """ diff --git a/getstream/video/rtc/connection_utils.py b/getstream/video/rtc/connection_utils.py index bdc8d02f..75036a14 100644 --- a/getstream/video/rtc/connection_utils.py +++ b/getstream/video/rtc/connection_utils.py @@ -58,20 +58,6 @@ "connect_websocket", ] -# Private constants - internal use only -_RETRYABLE_ERROR_PATTERNS = [ - "server is full", - "server overloaded", - "capacity exceeded", - "try again later", - "service unavailable", - "connection timeout", - "network error", - "temporary failure", - "connection refused", - "connection reset", -] - # Public classes and exceptions class ConnectionState(Enum): @@ -94,6 +80,22 @@ class SfuConnectionError(Exception): pass +class SfuJoinError(SfuConnectionError): + """Raised when SFU join fails with a retryable error code.""" + + def __init__(self, message: str, error_code: int = 0, should_retry: bool = False): + super().__init__(message) + self.error_code = error_code + self.should_retry = should_retry + + +_RETRYABLE_SFU_ERROR_CODES = { + 700, # ERROR_CODE_SFU_FULL + 600, # ERROR_CODE_SFU_SHUTTING_DOWN + 301, # ERROR_CODE_CALL_PARTICIPANT_LIMIT_REACHED +} + + @dataclass class ConnectionOptions: """Options for the connection process.""" @@ -175,6 +177,8 @@ async def join_call_coordinator_request( notify: Optional[bool] = None, video: Optional[bool] = None, location: Optional[str] = None, + migrating_from: Optional[str] = None, + migrating_from_list: Optional[list] = None, ) -> StreamResponse[JoinCallResponse]: """Make a request to join a call via the coordinator. @@ -208,6 +212,10 @@ async def join_call_coordinator_request( video=video, data=data, ) + if migrating_from: + json_body["migrating_from"] = migrating_from + if migrating_from_list: + json_body["migrating_from_list"] = migrating_from_list # Make the POST request to join the call return await client.post( @@ -423,6 +431,8 @@ async def connect_websocket( """ logger.info(f"Connecting to WebSocket at {ws_url}") + ws_client = None + success = False try: # Create JoinRequest for WebSocket connection join_request = await create_join_request(token, session_id) @@ -448,34 +458,24 @@ async def connect_websocket( sfu_event = await ws_client.connect() logger.debug("WebSocket connection established") + success = True return ws_client, sfu_event + except SignalingError as e: + if ( + e.error + and hasattr(e.error, "code") + and e.error.code in _RETRYABLE_SFU_ERROR_CODES + ): + raise SfuJoinError( + str(e), + error_code=e.error.code, + should_retry=True, + ) from e + raise except Exception as e: logger.error(f"Failed to connect WebSocket to {ws_url}: {e}") raise SignalingError(f"WebSocket connection failed: {e}") - - -# Private functions -def _is_retryable(retry_state: Any) -> bool: - """Check if an error should be retried. - - Args: - retry_state: The retry state object from tenacity - - Returns: - True if the error should be retried, False otherwise - """ - # Extract the actual exception from the retry state - if hasattr(retry_state, "outcome") and retry_state.outcome.failed: - error = retry_state.outcome.exception() - else: - return False - - # Import here to avoid circular imports - from getstream.video.rtc.signaling import SignalingError - - if not isinstance(error, (SignalingError, SfuConnectionError)): - return False - - error_message = str(error).lower() - return any(pattern in error_message for pattern in _RETRYABLE_ERROR_PATTERNS) + finally: + if ws_client and not success: + ws_client.close() diff --git a/getstream/video/rtc/signaling.py b/getstream/video/rtc/signaling.py index 675dc7aa..38ecabcf 100644 --- a/getstream/video/rtc/signaling.py +++ b/getstream/video/rtc/signaling.py @@ -26,7 +26,9 @@ class SignalingError(Exception): """Exception raised for errors in the signaling process.""" - pass + def __init__(self, message: str, error=None): + super().__init__(message) + self.error = error class WebSocketClient(StreamAsyncIOEventEmitter): @@ -111,8 +113,10 @@ async def connect(self): # Check if the first message is an error if self.first_message and self.first_message.HasField("error"): - error_msg = self.first_message.error.error.message - raise SignalingError(f"Connection failed: {error_msg}") + sfu_error = self.first_message.error.error + raise SignalingError( + f"Connection failed: {sfu_error.message}", error=sfu_error + ) # Check if we got join_response if self.first_message and self.first_message.HasField("join_response"): diff --git a/scripts/test_sfu_connect.py b/scripts/test_sfu_connect.py new file mode 100644 index 00000000..cff1f690 --- /dev/null +++ b/scripts/test_sfu_connect.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +""" +Utility script for testing SFU connection and retry behavior. + +Connects to a call as a given user and logs each step of the connection +process — useful for verifying SFU assignment, retry on transient errors +(e.g. SFU_FULL), and reassignment via the coordinator. + +Environment variables +--------------------- +STREAM_API_KEY — Stream API key (required) +STREAM_API_SECRET — Stream API secret (required) +STREAM_BASE_URL — Coordinator URL (default: Stream cloud). + Set to http://127.0.0.1:3030 for a local coordinator. +USER_ID — User ID to join as (default: "test-user"). +CALL_TYPE — Call type (default: "default"). +CALL_ID — Call ID. If not set, a random UUID is generated. + +Usage +----- + # Connect via cloud coordinator + STREAM_API_KEY=... STREAM_API_SECRET=... \\ + uv run --extra webrtc python scripts/test_sfu_connect.py + + # Connect via local coordinator + STREAM_BASE_URL=http://127.0.0.1:3030 \\ + uv run --extra webrtc python scripts/test_sfu_connect.py +""" + +import asyncio +import logging +import os +import uuid + +from dotenv import load_dotenv + +from getstream import AsyncStream +from getstream.models import CallRequest +from getstream.video.rtc import ConnectionManager + +load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +async def run(): + base_url = os.getenv("STREAM_BASE_URL") + user_id = os.getenv("USER_ID", "test-user") + call_type = os.getenv("CALL_TYPE", "default") + call_id = os.getenv("CALL_ID", str(uuid.uuid4())) + + logger.info("Configuration:") + logger.info(f" Coordinator: {base_url or 'cloud (default)'}") + logger.info(f" User: {user_id}") + logger.info(f" Call: {call_type}:{call_id}") + + client_kwargs = {} + if base_url: + client_kwargs["base_url"] = base_url + + client = AsyncStream(timeout=10.0, **client_kwargs) + + call = client.video.call(call_type, call_id) + logger.info("Creating call...") + await call.get_or_create(data=CallRequest(created_by_id=user_id)) + logger.info("Call created") + + cm = ConnectionManager( + call=call, + user_id=user_id, + create=False, + ) + + logger.info("Connecting to SFU...") + + async with cm: + join = cm.join_response + if join and join.credentials: + logger.info(f"Connected to SFU: {join.credentials.server.edge_name}") + logger.info(f" WS endpoint: {join.credentials.server.ws_endpoint}") + logger.info(f" Session ID: {cm.session_id}") + + logger.info("Holding connection for 3s...") + await asyncio.sleep(3) + + logger.info("Leaving call") + + logger.info("Done") + + +if __name__ == "__main__": + asyncio.run(run()) diff --git a/tests/test_connection_manager.py b/tests/test_connection_manager.py new file mode 100644 index 00000000..91e8bbfe --- /dev/null +++ b/tests/test_connection_manager.py @@ -0,0 +1,172 @@ +import contextlib + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from getstream.video.rtc.connection_manager import ConnectionManager +from getstream.video.rtc.connection_utils import SfuJoinError, SfuConnectionError +from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2 + + +@contextlib.contextmanager +def patched_dependencies(): + """Patch heavy ConnectionManager dependencies for unit testing.""" + with ( + patch("getstream.video.rtc.connection_manager.PeerConnectionManager"), + patch("getstream.video.rtc.connection_manager.NetworkMonitor"), + patch("getstream.video.rtc.connection_manager.ReconnectionManager"), + patch("getstream.video.rtc.connection_manager.RecordingManager"), + patch("getstream.video.rtc.connection_manager.SubscriptionManager"), + patch("getstream.video.rtc.connection_manager.ParticipantsState"), + patch("getstream.video.rtc.connection_manager.Tracer"), + patch( + "getstream.video.rtc.connection_manager.asyncio.sleep", + new_callable=AsyncMock, + ), + ): + yield + + +@pytest.fixture +def connection_manager(request): + """Create a ConnectionManager with mocked heavy dependencies. + + Accepts max_join_retries via indirect parametrize, defaults to 3. + """ + max_join_retries = getattr(request, "param", 3) + with patched_dependencies(): + mock_call = MagicMock() + mock_call.call_type = "default" + mock_call.id = "test_call" + cm = ConnectionManager( + call=mock_call, user_id="user1", max_join_retries=max_join_retries + ) + cm._connect_coordinator_ws = AsyncMock() + yield cm + + +class TestConnectRetry: + """Tests for connect() retry logic when SFU is full.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize("connection_manager", [2], indirect=True) + async def test_retries_on_sfu_join_error_and_passes_failed_sfus( + self, connection_manager + ): + """When SFU is full, connect() should retry with migrating_from_list.""" + cm = connection_manager + call_count = 0 + received_migrating_from_list = [] + + async def mock_connect_internal(migrating_from_list=None, **kwargs): + nonlocal call_count + call_count += 1 + received_migrating_from_list.append( + list(migrating_from_list) if migrating_from_list else None + ) + + if call_count <= 2: + mock_join_response = MagicMock() + mock_join_response.credentials.server.edge_name = ( + f"sfu-node-{call_count}" + ) + cm.join_response = mock_join_response + raise SfuJoinError( + "server is full", + error_code=models_pb2.ERROR_CODE_SFU_FULL, + should_retry=True, + ) + cm.running = True + + cm._connect_internal = mock_connect_internal + + await cm.connect() + + assert call_count == 3 + assert received_migrating_from_list[0] is None + assert received_migrating_from_list[1] == ["sfu-node-1"] + assert received_migrating_from_list[2] == ["sfu-node-1", "sfu-node-2"] + + @pytest.mark.asyncio + @pytest.mark.parametrize("connection_manager", [1], indirect=True) + async def test_raises_after_all_retries_exhausted(self, connection_manager): + """When all retries are exhausted, connect() should raise SfuJoinError.""" + cm = connection_manager + call_count = 0 + + async def always_fail(migrating_from_list=None, **kwargs): + nonlocal call_count + call_count += 1 + mock_join_response = MagicMock() + mock_join_response.credentials.server.edge_name = "sfu-node-1" + cm.join_response = mock_join_response + raise SfuJoinError( + "server is full", + error_code=models_pb2.ERROR_CODE_SFU_FULL, + should_retry=True, + ) + + cm._connect_internal = always_fail + + with pytest.raises(SfuJoinError): + await cm.connect() + + assert call_count == 2 # 1 initial + 1 retry + + @pytest.mark.asyncio + async def test_non_retryable_error_propagates_immediately(self, connection_manager): + """Non-retryable errors should not trigger retry.""" + cm = connection_manager + call_count = 0 + + async def fail_with_generic_error(migrating_from_list=None, **kwargs): + nonlocal call_count + call_count += 1 + raise SfuConnectionError("something went wrong") + + cm._connect_internal = fail_with_generic_error + + with pytest.raises(SfuConnectionError): + await cm.connect() + + assert call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize("connection_manager", [1], indirect=True) + async def test_cleans_up_ws_client_between_retries(self, connection_manager): + """Partial WS state should be cleaned up before retry.""" + cm = connection_manager + call_count = 0 + + first_ws_client = MagicMock() + + async def mock_connect_internal(migrating_from_list=None, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + cm._ws_client = first_ws_client + mock_join_response = MagicMock() + mock_join_response.credentials.server.edge_name = "sfu-node-1" + cm.join_response = mock_join_response + raise SfuJoinError( + "server is full", + error_code=models_pb2.ERROR_CODE_SFU_FULL, + should_retry=True, + ) + cm.running = True + + cm._connect_internal = mock_connect_internal + + await cm.connect() + + assert call_count == 2 + first_ws_client.close.assert_called_once() + assert cm._ws_client is None + + def test_rejects_negative_max_join_retries(self): + """max_join_retries must be >= 0.""" + with ( + patched_dependencies(), + pytest.raises(ValueError, match="max_join_retries must be >= 0"), + ): + ConnectionManager(call=MagicMock(), user_id="user1", max_join_retries=-1) diff --git a/tests/test_connection_utils.py b/tests/test_connection_utils.py new file mode 100644 index 00000000..ff04a0ca --- /dev/null +++ b/tests/test_connection_utils.py @@ -0,0 +1,131 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from getstream.video.rtc.connection_utils import ( + connect_websocket, + ConnectionOptions, + SfuConnectionError, + SfuJoinError, + join_call_coordinator_request, +) +from getstream.video.rtc.signaling import SignalingError +from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2 + + +@pytest.fixture +def mock_ws_client(): + """Patch WebSocketClient and yield the mock instance.""" + with patch("getstream.video.rtc.connection_utils.WebSocketClient") as mock_ws_cls: + mock_ws = AsyncMock() + mock_ws_cls.return_value = mock_ws + yield mock_ws + + +@pytest.fixture +def coordinator_request(): + """Set up a mock coordinator client that captures the request body.""" + mock_call = AsyncMock() + mock_call.call_type = "default" + mock_call.id = "test_call" + mock_call.client.stream.api_key = "key" + mock_call.client.stream.api_secret = "secret" + mock_call.client.stream.base_url = "https://test.url" + + captured_body = {} + + with patch("getstream.video.rtc.connection_utils.user_client") as mock_user_client: + mock_client = AsyncMock() + + async def capture_post(*args, **kwargs): + captured_body.update(kwargs.get("json", {})) + return AsyncMock() + + mock_client.post = capture_post + mock_user_client.return_value = mock_client + yield mock_call, captured_body + + +class TestConnectWebsocket: + @pytest.mark.asyncio + async def test_raises_sfu_join_error_on_sfu_full(self, mock_ws_client): + """connect_websocket should raise SfuJoinError when SFU is full.""" + sfu_error = models_pb2.Error( + code=models_pb2.ERROR_CODE_SFU_FULL, + message="server is full", + should_retry=True, + ) + mock_ws_client.connect = AsyncMock( + side_effect=SignalingError( + "Connection failed: server is full", error=sfu_error + ) + ) + + with pytest.raises(SfuJoinError) as exc_info: + await connect_websocket( + token="test_token", + ws_url="wss://test.url", + session_id="test_session", + options=ConnectionOptions(), + ) + + assert exc_info.value.error_code == models_pb2.ERROR_CODE_SFU_FULL + assert exc_info.value.should_retry is True + assert isinstance(exc_info.value, SfuConnectionError) + + @pytest.mark.asyncio + async def test_non_retryable_error_propagates_as_signaling_error( + self, mock_ws_client + ): + """Non-retryable SignalingError should not become SfuJoinError.""" + sfu_error = models_pb2.Error( + code=models_pb2.ERROR_CODE_PERMISSION_DENIED, + message="permission denied", + should_retry=False, + ) + mock_ws_client.connect = AsyncMock( + side_effect=SignalingError( + "Connection failed: permission denied", error=sfu_error + ) + ) + + with pytest.raises(SignalingError) as exc_info: + await connect_websocket( + token="test_token", + ws_url="wss://test.url", + session_id="test_session", + options=ConnectionOptions(), + ) + + assert not isinstance(exc_info.value, SfuJoinError) + + +class TestJoinCallCoordinatorRequest: + @pytest.mark.asyncio + async def test_includes_migrating_from_in_body(self, coordinator_request): + """migrating_from and migrating_from_list should be included in the request body.""" + mock_call, captured_body = coordinator_request + + await join_call_coordinator_request( + call=mock_call, + user_id="user1", + location="auto", + migrating_from="sfu-london-1", + migrating_from_list=["sfu-london-1", "sfu-paris-2"], + ) + + assert captured_body["migrating_from"] == "sfu-london-1" + assert captured_body["migrating_from_list"] == ["sfu-london-1", "sfu-paris-2"] + + @pytest.mark.asyncio + async def test_omits_migrating_from_when_not_provided(self, coordinator_request): + """migrating_from should not appear in body when not provided.""" + mock_call, captured_body = coordinator_request + + await join_call_coordinator_request( + call=mock_call, + user_id="user1", + location="auto", + ) + + assert "migrating_from" not in captured_body + assert "migrating_from_list" not in captured_body diff --git a/tests/test_signaling.py b/tests/test_signaling.py index 724e0f50..2241f9f0 100644 --- a/tests/test_signaling.py +++ b/tests/test_signaling.py @@ -5,6 +5,7 @@ from getstream.video.rtc.signaling import WebSocketClient, SignalingError from getstream.video.rtc.pb.stream.video.sfu.event import events_pb2 +from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2 class TestWebSocketClient: @@ -129,6 +130,38 @@ async def test_connect_error(self, join_request, mock_websocket): # Clean up client.close() + @pytest.mark.asyncio + async def test_connect_error_preserves_error_code( + self, join_request, mock_websocket + ): + """Test that SignalingError preserves the SFU error code.""" + client = WebSocketClient( + "wss://test.url", join_request, asyncio.get_running_loop() + ) + + # Prepare an SFU FULL error response + error_response = events_pb2.SfuEvent() + error_response.error.error.code = models_pb2.ERROR_CODE_SFU_FULL + error_response.error.error.message = "server is full" + error_response_bytes = error_response.SerializeToString() + + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + + on_open_callback = mock_websocket.call_args[1]["on_open"] + on_open_callback(mock_websocket.return_value) + + on_message_callback = mock_websocket.call_args[1]["on_message"] + on_message_callback(mock_websocket.return_value, error_response_bytes) + + with pytest.raises(SignalingError) as exc_info: + await connect_task + + assert exc_info.value.error is not None + assert exc_info.value.error.code == models_pb2.ERROR_CODE_SFU_FULL + + client.close() + @pytest.mark.asyncio async def test_websocket_error_during_connect(self, join_request, mock_websocket): """Test WebSocket error during connection.""" diff --git a/uv.lock b/uv.lock index 08c32e71..3bc04b64 100644 --- a/uv.lock +++ b/uv.lock @@ -955,7 +955,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiohttp", marker = "extra == 'webrtc'", specifier = ">=3.13.2,<4" }, - { name = "aiortc", marker = "extra == 'webrtc'", specifier = ">=1.14.0,<2" }, + { name = "aiortc", marker = "extra == 'webrtc'", specifier = ">=1.14.0,<1.15.0" }, { name = "av", marker = "extra == 'webrtc'", specifier = ">=14.2.0,<17" }, { name = "dataclasses-json", specifier = ">=0.6.0,<0.7" }, { name = "httpx", specifier = ">=0.28.1" },