Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def get_fast_api_app(
register_processors: Callable[[TracerProvider], None] = lambda o: None,
otel_to_cloud: bool = False,
with_ui: bool = False,
ws_allowed_origins: Optional[set[str]] = None,
):
"""Creates a FastAPI app for the ADK web server.

Expand Down Expand Up @@ -1802,6 +1803,13 @@ async def run_agent_live(
enable_affective_dialog: bool | None = Query(default=None),
enable_session_resumption: bool | None = Query(default=None),
) -> None:
# Validate Origin header to prevent cross-origin WebSocket hijacking.
if ws_allowed_origins is not None:
origin = websocket.headers.get("origin")
if origin and origin not in ws_allowed_origins:
await websocket.close(code=1008, reason="Origin not allowed")
return

await websocket.accept()

session = await self.session_service.get_session(
Expand Down
16 changes: 16 additions & 0 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,26 @@ def tear_down_observer(observer: Observer, _: AdkWebServer):
web_assets_dir=ANGULAR_DIST_PATH,
)

# Build allowed origins for WebSocket CSRF protection.
ws_allowed_origins: Optional[set[str]] = {f"http://{host}:{port}"}
if host in ("0.0.0.0", "127.0.0.1", "localhost", "::1", "::"):
ws_allowed_origins.update({
f"http://localhost:{port}",
f"http://127.0.0.1:{port}",
})
if allow_origins:
for origin in allow_origins:
if origin == "*":
ws_allowed_origins = None
break
if not origin.startswith("regex:"):
ws_allowed_origins.add(origin)

app = adk_web_server.get_fast_api_app(
lifespan=lifespan,
allow_origins=allow_origins,
otel_to_cloud=otel_to_cloud,
ws_allowed_origins=ws_allowed_origins,
**extra_fast_api_args,
)

Expand Down
49 changes: 49 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from unittest.mock import patch

from fastapi.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.run_config import RunConfig
from google.adk.apps.app import App
Expand Down Expand Up @@ -1658,6 +1659,54 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path):
assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists()


@pytest.fixture
def csrf_test_client(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
):
"""TestClient with WebSocket origin checking enabled (no allow_origins='*')."""
return _create_test_client(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
allow_origins=None,
)


def test_ws_rejects_cross_origin(csrf_test_client, create_test_session):
"""WebSocket connections with a foreign Origin must be rejected."""
info = create_test_session
with pytest.raises(WebSocketDisconnect) as exc_info:
with csrf_test_client.websocket_connect(
f"/run_live?app_name={info['app_name']}&user_id={info['user_id']}&session_id={info['session_id']}",
headers={"Origin": "http://evil.com"},
):
pass
assert exc_info.value.code == 1008


def test_ws_allows_same_origin(csrf_test_client, create_test_session):
"""WebSocket connections from the server's own origin must not be rejected."""
info = create_test_session
try:
with csrf_test_client.websocket_connect(
f"/run_live?app_name={info['app_name']}&user_id={info['user_id']}&session_id={info['session_id']}",
headers={"Origin": "http://127.0.0.1:8000"},
):
pass
except WebSocketDisconnect as e:
# Must not be rejected with 1008 (origin check).
# Other close codes (e.g. 1011 from dummy runner) are fine.
assert e.code != 1008


def test_agent_run_resume_without_message_success(
test_app, create_test_session
):
Expand Down