diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index afedb7387a..6b775ed38f 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -692,6 +692,7 @@ def get_fast_api_app( register_processors: Callable[[TracerProvider], None] = lambda o: None, otel_to_cloud: bool = False, with_ui: bool = False, + ws_allowed_origins: Optional[set[str]] = None, ): """Creates a FastAPI app for the ADK web server. @@ -1802,6 +1803,13 @@ async def run_agent_live( enable_affective_dialog: bool | None = Query(default=None), enable_session_resumption: bool | None = Query(default=None), ) -> None: + # Validate Origin header to prevent cross-origin WebSocket hijacking. + if ws_allowed_origins is not None: + origin = websocket.headers.get("origin") + if origin and origin not in ws_allowed_origins: + await websocket.close(code=1008, reason="Origin not allowed") + return + await websocket.accept() session = await self.session_service.get_session( diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 8f78c15f9b..8a4f3d2e4d 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -259,10 +259,26 @@ def tear_down_observer(observer: Observer, _: AdkWebServer): web_assets_dir=ANGULAR_DIST_PATH, ) + # Build allowed origins for WebSocket CSRF protection. + ws_allowed_origins: Optional[set[str]] = {f"http://{host}:{port}"} + if host in ("0.0.0.0", "127.0.0.1", "localhost", "::1", "::"): + ws_allowed_origins.update({ + f"http://localhost:{port}", + f"http://127.0.0.1:{port}", + }) + if allow_origins: + for origin in allow_origins: + if origin == "*": + ws_allowed_origins = None + break + if not origin.startswith("regex:"): + ws_allowed_origins.add(origin) + app = adk_web_server.get_fast_api_app( lifespan=lifespan, allow_origins=allow_origins, otel_to_cloud=otel_to_cloud, + ws_allowed_origins=ws_allowed_origins, **extra_fast_api_args, ) diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 0ea28e6683..070f1f445e 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -25,6 +25,7 @@ from unittest.mock import patch from fastapi.testclient import TestClient +from starlette.websockets import WebSocketDisconnect from google.adk.agents.base_agent import BaseAgent from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App @@ -1658,6 +1659,54 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path): assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists() +@pytest.fixture +def csrf_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """TestClient with WebSocket origin checking enabled (no allow_origins='*').""" + return _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + allow_origins=None, + ) + + +def test_ws_rejects_cross_origin(csrf_test_client, create_test_session): + """WebSocket connections with a foreign Origin must be rejected.""" + info = create_test_session + with pytest.raises(WebSocketDisconnect) as exc_info: + with csrf_test_client.websocket_connect( + f"/run_live?app_name={info['app_name']}&user_id={info['user_id']}&session_id={info['session_id']}", + headers={"Origin": "http://evil.com"}, + ): + pass + assert exc_info.value.code == 1008 + + +def test_ws_allows_same_origin(csrf_test_client, create_test_session): + """WebSocket connections from the server's own origin must not be rejected.""" + info = create_test_session + try: + with csrf_test_client.websocket_connect( + f"/run_live?app_name={info['app_name']}&user_id={info['user_id']}&session_id={info['session_id']}", + headers={"Origin": "http://127.0.0.1:8000"}, + ): + pass + except WebSocketDisconnect as e: + # Must not be rejected with 1008 (origin check). + # Other close codes (e.g. 1011 from dummy runner) are fine. + assert e.code != 1008 + + def test_agent_run_resume_without_message_success( test_app, create_test_session ):