Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging

from pydantic import BaseModel, Field
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.responses import Response

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,7 +99,7 @@ def _validate_content_type(self, content_type: str | None) -> bool: # pragma: n

return True

async def validate_request(self, request: Request, is_post: bool = False) -> Response | None:
async def validate_request(self, request: HTTPConnection, is_post: bool = False) -> Response | None:
"""Validate request headers for DNS rebinding protection.

Returns None if validation passes, or an error Response if validation fails.
Expand Down
21 changes: 20 additions & 1 deletion src/mcp/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import deprecated

import mcp.types as types
from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)
Expand All @@ -19,16 +20,34 @@
" the MCP specification; use the streamable HTTP transport instead."
)
@asynccontextmanager
async def websocket_server(scope: Scope, receive: Receive, send: Send):
async def websocket_server(
scope: Scope,
receive: Receive,
send: Send,
security_settings: TransportSecuritySettings | None = None,
):
"""
WebSocket server transport for MCP. This is an ASGI application, suitable to be
used with a framework like Starlette and a server like Hypercorn.

Set `security_settings` to enable Host/Origin header validation before the
handshake is accepted (same settings type as the SSE and Streamable HTTP
transports). When validation fails this raises `ValueError` after rejecting
the handshake.

Deprecated: this transport will be removed in mcp 2.0. WebSocket was never
part of the MCP specification; use the streamable HTTP transport instead.
"""

websocket = WebSocket(scope, receive, send)

security = TransportSecurityMiddleware(security_settings)
error_response = await security.validate_request(websocket, is_post=False)
if error_response is not None:
# Reject the handshake; the ASGI server maps a pre-accept close to HTTP 403.
await websocket.close()
raise ValueError("Request validation failed")

await websocket.accept(subprotocol="mcp")

read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
Expand Down
172 changes: 172 additions & 0 deletions tests/server/test_websocket_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""Tests for WebSocket server request validation."""

# pyright: reportDeprecated=false

import logging
import multiprocessing
import socket
import warnings

import pytest
import uvicorn
from starlette.applications import Starlette
from starlette.routing import WebSocketRoute
from starlette.types import Message, Scope
from starlette.websockets import WebSocket
from websockets.asyncio.client import connect
from websockets.exceptions import InvalidStatus
from websockets.typing import Subprotocol

from mcp.server import Server
from mcp.server.transport_security import TransportSecuritySettings
from mcp.server.websocket import websocket_server
from tests.test_helpers import wait_for_server

logger = logging.getLogger(__name__)
SERVER_NAME = "test_ws_security_server"

# This suite intentionally exercises the deprecated WebSocket transport.
pytestmark = pytest.mark.filterwarnings(
"ignore:The WebSocket (client|server) transport is deprecated:DeprecationWarning"
)


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover
"""Run a WebSocket MCP server with the given security settings."""
warnings.filterwarnings("ignore", category=DeprecationWarning)
server = Server(SERVER_NAME)

async def handle_ws(websocket: WebSocket) -> None:
try:
async with websocket_server(
websocket.scope, websocket.receive, websocket.send, security_settings=security_settings
) as streams:
await server.run(streams[0], streams[1], server.create_initialization_options())
except ValueError as exc:
logger.debug(f"WebSocket connection failed validation: {exc}")

app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)])
uvicorn.run(app, host="127.0.0.1", port=port, log_level="error")


def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
"""Start the server in a subprocess and wait until it accepts connections."""
process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
process.start()
wait_for_server(port)
return process


@pytest.mark.anyio
async def test_ws_security_default_settings(server_port: int) -> None:
"""With no security settings the WebSocket transport accepts any Origin (matches SSE/StreamableHTTP default)."""
process = start_server_process(server_port)
try:
async with connect(
f"ws://127.0.0.1:{server_port}/ws",
subprotocols=[Subprotocol("mcp")],
additional_headers={"Origin": "http://evil.com"},
) as ws:
assert ws.subprotocol == "mcp"
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_ws_security_invalid_origin_header(server_port: int) -> None:
"""An Origin not in allowed_origins is rejected before the handshake completes."""
settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
)
process = start_server_process(server_port, settings)
try:
with pytest.raises(InvalidStatus) as exc_info:
async with connect(
f"ws://127.0.0.1:{server_port}/ws",
subprotocols=[Subprotocol("mcp")],
additional_headers={"Origin": "http://evil.com"},
):
pytest.fail("handshake should have been rejected") # pragma: no cover
assert exc_info.value.response.status_code == 403
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_ws_security_invalid_host_header(server_port: int) -> None:
"""A Host not in allowed_hosts is rejected before the handshake completes."""
settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
process = start_server_process(server_port, settings)
try:
with pytest.raises(InvalidStatus) as exc_info:
async with connect(f"ws://127.0.0.1:{server_port}/ws", subprotocols=[Subprotocol("mcp")]):
pytest.fail("handshake should have been rejected") # pragma: no cover
assert exc_info.value.response.status_code == 403
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_ws_security_allowed_origin(server_port: int) -> None:
"""An Origin matching allowed_origins is accepted."""
settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
)
process = start_server_process(server_port, settings)
try:
async with connect(
f"ws://127.0.0.1:{server_port}/ws",
subprotocols=[Subprotocol("mcp")],
additional_headers={"Origin": "http://localhost:8080"},
) as ws:
assert ws.subprotocol == "mcp"
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_ws_security_disabled(server_port: int) -> None:
"""Explicitly disabling protection accepts any Origin."""
settings = TransportSecuritySettings(enable_dns_rebinding_protection=False)
process = start_server_process(server_port, settings)
try:
async with connect(
f"ws://127.0.0.1:{server_port}/ws",
subprotocols=[Subprotocol("mcp")],
additional_headers={"Origin": "http://evil.com"},
) as ws:
assert ws.subprotocol == "mcp"
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_ws_security_rejects_before_accept() -> None:
"""A failing validation closes the connection before the handshake is accepted."""
settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
sent: list[Message] = []

async def receive() -> Message:
raise NotImplementedError

async def send(message: Message) -> None:
sent.append(message)

scope: Scope = {"type": "websocket", "headers": [(b"host", b"evil.com")]}
with pytest.raises(ValueError, match="Request validation failed"):
async with websocket_server(scope, receive, send, security_settings=settings):
pytest.fail("should not yield streams") # pragma: no cover

assert [m["type"] for m in sent] == ["websocket.close"]
Loading