From e07c8e699f7bb9b37931ccb006b38bee40c90647 Mon Sep 17 00:00:00 2001 From: James Kirk Date: Tue, 1 Apr 2025 20:56:49 -0400 Subject: [PATCH] feat: add bearer token pass-through for SSE MCP --- letta/functions/mcp_client/sse_client.py | 3 ++- letta/functions/mcp_client/types.py | 3 +++ letta/server/server.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/letta/functions/mcp_client/sse_client.py b/letta/functions/mcp_client/sse_client.py index d06f955ddd..26fbf7e585 100644 --- a/letta/functions/mcp_client/sse_client.py +++ b/letta/functions/mcp_client/sse_client.py @@ -16,7 +16,8 @@ class SSEMCPClient(BaseMCPClient): def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool: try: - sse_cm = sse_client(url=server_config.server_url) + sse_headers = {'Authorization': f'Bearer {server_config.token}'} if server_config.token else None + sse_cm = sse_client(url=server_config.server_url, headers=sse_headers) sse_transport = self.loop.run_until_complete(asyncio.wait_for(sse_cm.__aenter__(), timeout=timeout)) self.stdio, self.write = sse_transport self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None))) diff --git a/letta/functions/mcp_client/types.py b/letta/functions/mcp_client/types.py index 2d8b7af6d3..4f40f58aca 100644 --- a/letta/functions/mcp_client/types.py +++ b/letta/functions/mcp_client/types.py @@ -22,12 +22,15 @@ class BaseServerConfig(BaseModel): class SSEServerConfig(BaseServerConfig): type: MCPServerType = MCPServerType.SSE server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)") + token: Optional[str] = Field(None, description="The bearer token to use for authentication (if required)") def to_dict(self) -> dict: values = { "transport": "sse", "url": self.server_url, } + if self.token is not None: + values["token"] = self.token return values diff --git a/letta/server/server.py b/letta/server/server.py index 636a528993..1fa6098aef 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1314,6 +1314,7 @@ def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig] server_params = SSEServerConfig( server_name=server_name, server_url=server_params_raw["url"], + token=server_params_raw.get("token"), ) mcp_server_list[server_name] = server_params except Exception as e: