diff --git a/contributing/samples/live_agent_api_server_example/live_agent_example.py b/contributing/samples/live_agent_api_server_example/live_agent_example.py index da06be086a..66302a9e12 100644 --- a/contributing/samples/live_agent_api_server_example/live_agent_example.py +++ b/contributing/samples/live_agent_api_server_example/live_agent_example.py @@ -168,6 +168,7 @@ def init_pyaudio_playback(): try: pya_interface_instance.terminate() except: + # TODO: be more specific about exception type pass pya_interface_instance = None pya_output_stream_instance = None diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index afedb7387a..933b21fb15 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio +import re from contextlib import asynccontextmanager import importlib import json @@ -757,16 +758,25 @@ async def internal_lifespan(app: FastAPI): # Run the FastAPI server. app = FastAPI(lifespan=internal_lifespan) + # Store parsed allow_origins for WebSocket Origin validation + # (CORS middleware doesn't apply to WebSocket upgrades) + _ws_allowed_origins: tuple[list[str], Optional[re.Pattern[str]], bool] = ( + [], + None, + False, + ) if allow_origins: - literal_origins, combined_regex = _parse_cors_origins(allow_origins) + literal_origins, combined_regex_str = _parse_cors_origins(allow_origins) + compiled_regex = re.compile(combined_regex_str) if combined_regex_str else None app.add_middleware( CORSMiddleware, allow_origins=literal_origins, - allow_origin_regex=combined_regex, + allow_origin_regex=combined_regex_str, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) + _ws_allowed_origins = (literal_origins, compiled_regex, True) @app.get("/health") async def health() -> dict[str, str]: @@ -1802,6 +1812,24 @@ async def run_agent_live( enable_affective_dialog: bool | None = Query(default=None), enable_session_resumption: bool | None = Query(default=None), ) -> None: + # Validate Origin header to prevent cross-origin WebSocket hijacking. + # WebSocket connections are not protected by CORS, so we must validate + # the Origin ourselves. See https://github.com/google/adk-python/issues/4947 + origin = websocket.headers.get("origin") + literal_origins, compiled_regex, origins_configured = _ws_allowed_origins + if origins_configured: + # CORS origins were configured: allow only listed origins + allowed = origin in literal_origins or ( + compiled_regex and origin and compiled_regex.match(origin) + ) + elif origin: + # No CORS config: only allow same-origin requests + allowed = False + else: + allowed = True # No Origin header (non-browser client) + if not allowed: + await websocket.close(code=1008, reason="Origin not allowed") + return await websocket.accept() session = await self.session_service.get_session( diff --git a/tests/unittests/cli/test_adk_web_server_run_live.py b/tests/unittests/cli/test_adk_web_server_run_live.py index 1c3c42593c..9a0da8159e 100644 --- a/tests/unittests/cli/test_adk_web_server_run_live.py +++ b/tests/unittests/cli/test_adk_web_server_run_live.py @@ -203,3 +203,90 @@ async def _get_runner_async(_self, _app_name: str): run_config.session_resumption.transparent is expected_session_resumption_transparent ) + + +def _make_app(allow_origins=None): + """Helper to create a test FastAPI app with optional allow_origins.""" + session_service = InMemorySessionService() + asyncio.run( + session_service.create_session( + app_name="test_app", + user_id="user", + session_id="session", + state={}, + ) + ) + + runner = _CapturingRunner() + adk_web_server = AdkWebServer( + agent_loader=_DummyAgentLoader(), + session_service=session_service, + memory_service=types.SimpleNamespace(), + artifact_service=types.SimpleNamespace(), + credential_service=types.SimpleNamespace(), + eval_sets_manager=types.SimpleNamespace(), + eval_set_results_manager=types.SimpleNamespace(), + agents_dir=".", + ) + + async def _get_runner_async(_self, _app_name: str): + return runner + + adk_web_server.get_runner_async = _get_runner_async.__get__(adk_web_server) # pytype: disable=attribute-error + + fast_api_app = adk_web_server.get_fast_api_app( + setup_observer=lambda _observer, _server: None, + tear_down_observer=lambda _observer, _server: None, + allow_origins=allow_origins, + ) + return TestClient(fast_api_app) + + +def test_websocket_rejects_cross_origin_without_config(): + """WebSocket without allow_origins config rejects cross-origin requests. + + Regression test for https://github.com/google/adk-python/issues/4947 + """ + client = _make_app(allow_origins=None) + url = "/run_live?app_name=test_app&user_id=user&session_id=session&modalities=TEXT" + + # Simulate a cross-origin request by manually providing an Origin header + with pytest.raises(Exception) as exc_info: + client.websocket_connect(url, headers={"origin": "http://evil.com"}) + + # Connection should be rejected (1008 or connection error) + assert "1008" in str(exc_info.value) or "WebSocket" in str(type(exc_info.value).__name__) + + +def test_websocket_accepts_same_origin_without_config(): + """WebSocket without allow_origins accepts requests without Origin header (non-browser clients).""" + client = _make_app(allow_origins=None) + url = "/run_live?app_name=test_app&user_id=user&session_id=session&modalities=TEXT" + + # No Origin header = non-browser client = allowed + with client.websocket_connect(url) as ws: + _ = ws.receive_text() + + +def test_websocket_accepts_configured_origin(): + """WebSocket accepts when origin matches the configured allow_origins list.""" + client = _make_app(allow_origins=["http://localhost:8000"]) + url = "/run_live?app_name=test_app&user_id=user&session_id=session&modalities=TEXT" + + with client.websocket_connect( + url, headers={"origin": "http://localhost:8000"} + ) as ws: + _ = ws.receive_text() + + +def test_websocket_rejects_unlisted_origin(): + """WebSocket rejects when origin is not in the configured allow_origins list.""" + client = _make_app(allow_origins=["http://localhost:8000"]) + url = "/run_live?app_name=test_app&user_id=user&session_id=session&modalities=TEXT" + + with pytest.raises(Exception) as exc_info: + client.websocket_connect( + url, headers={"origin": "http://evil.com"} + ) + + assert "1008" in str(exc_info.value) or "WebSocket" in str(type(exc_info.value).__name__)