diff --git a/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/McpHttpUtility.cs b/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/McpHttpUtility.cs index 0bc5219..0a2a84b 100644 --- a/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/McpHttpUtility.cs +++ b/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/McpHttpUtility.cs @@ -5,6 +5,9 @@ using Microsoft.AspNetCore.Http.Features; using Microsoft.Extensions.Primitives; using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization; +using ModelContextProtocol.Protocol; using static Microsoft.Azure.Functions.Extensions.Mcp.McpConstants; namespace Microsoft.Azure.Functions.Extensions.Mcp.Http; @@ -39,4 +42,82 @@ internal static void SetSseContext(HttpContext context) context.Features.GetRequiredFeature().DisableBuffering(); } + + internal static async ValueTask ProcessJsonRpcPayloadAsync(HttpRequest request, JsonSerializerOptions options, CancellationToken cancellationToken, bool unwrapOnly = false) + { + // Process the incoming request body as JSON. Support both raw JSON-RPC messages and + // wrapped payloads with the shape: { "isFunctionsMcpResult": true, "content": } + // + // When unwrapOnly is false: If the wrapper is present, deserialize the inner "content" as the JsonRpcMessage. + // Otherwise, deserialize the root object directly as a JsonRpcMessage and return it. + // + // When unwrapOnly is true: If the wrapper is present, replace the request.Body stream with a memory stream + // containing only the inner content. If no wrapper is present, leave the original body intact. + + // If the body is empty, return null. + if (request.ContentLength == null || request.ContentLength == 0) + { + return null; + } + + // Read the request body into a JsonDocument for inspection. + request.EnableBuffering(); + try + { + using var doc = await JsonDocument.ParseAsync(request.Body, cancellationToken: cancellationToken); + var root = doc.RootElement; + + JsonElement messageElement = root; + bool isWrapped = false; + + if (root.ValueKind == JsonValueKind.Object && + root.TryGetProperty("isFunctionsMcpResult", out var marker) && + marker.ValueKind == JsonValueKind.True && + root.TryGetProperty("content", out var content)) + { + messageElement = content; + isWrapped = true; + } + + if (unwrapOnly) + { + if (isWrapped) + { + var inner = messageElement.GetRawText(); + var bytes = System.Text.Encoding.UTF8.GetBytes(inner); + request.Body = new MemoryStream(bytes); + request.ContentLength = bytes.Length; + } + else + { + // Reset position so downstream readers can consume the original body. + request.Body.Seek(0, SeekOrigin.Begin); + } + return null; + } + else + { + var raw = messageElement.GetRawText(); + return JsonSerializer.Deserialize(raw, options); + } + } + finally + { + if (!unwrapOnly) + { + // Reset the request body so it can be read later by other components if necessary. + request.Body.Seek(0, SeekOrigin.Begin); + } + } + } + + internal static async ValueTask ExtractJsonRpcMessageSseAsync(HttpRequest request, JsonSerializerOptions options, CancellationToken cancellationToken) + { + return await ProcessJsonRpcPayloadAsync(request, options, cancellationToken, unwrapOnly: false); + } + + internal static async ValueTask ExtractJsonRpcMessageHttpStreamableAsync(HttpRequest request, CancellationToken cancellationToken) + { + await ProcessJsonRpcPayloadAsync(request, JsonSerializerOptions.Default, cancellationToken, unwrapOnly: true); + } } diff --git a/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/SseRequestHandler.cs b/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/SseRequestHandler.cs index 882e648..37546c3 100644 --- a/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/SseRequestHandler.cs +++ b/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/SseRequestHandler.cs @@ -105,7 +105,7 @@ public async Task HandleMessageRequestAsync(HttpContext context, McpOptions mcpO return; } - var message = await context.Request.ReadFromJsonAsync(McpJsonSerializerOptions.DefaultOptions, context.RequestAborted); + var message = await McpHttpUtility.ExtractJsonRpcMessageSseAsync(context.Request, McpJsonSerializerOptions.DefaultOptions, context.RequestAborted); if (message is null) { diff --git a/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/StreamableHttpRequestHandler.cs b/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/StreamableHttpRequestHandler.cs index f7745ef..6d53c4e 100644 --- a/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/StreamableHttpRequestHandler.cs +++ b/src/Microsoft.Azure.Functions.Extensions.Mcp/Http/StreamableHttpRequestHandler.cs @@ -45,6 +45,10 @@ private async Task HandlePostRequestAsync(HttpContext context) return; } + // If the worker wrapped the JSON-RPC message in the { isFunctionsMcpResult: true, content: ... } + // envelope, unwrap it so the underlying transport receives the raw JSON-RPC payload it expects. + await McpHttpUtility.ExtractJsonRpcMessageHttpStreamableAsync(context.Request, context.RequestAborted); + var session = await GetOrCreateSessionAsync(context, mcpOptions.Value); if (session is null) diff --git a/test/Extensions.Mcp.Tests/McpHttpUtilityTests.cs b/test/Extensions.Mcp.Tests/McpHttpUtilityTests.cs index 739d97d..6fb7926 100644 --- a/test/Extensions.Mcp.Tests/McpHttpUtilityTests.cs +++ b/test/Extensions.Mcp.Tests/McpHttpUtilityTests.cs @@ -3,9 +3,12 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; +using Microsoft.Azure.Functions.Extensions.Mcp.Http; +using Microsoft.Azure.Functions.Extensions.Mcp.Serialization; using Microsoft.Extensions.Primitives; +using ModelContextProtocol.Protocol; using Moq; -using Microsoft.Azure.Functions.Extensions.Mcp.Http; +using System.Text; namespace Microsoft.Azure.Functions.Extensions.Mcp.Tests; @@ -65,4 +68,134 @@ public void SetSseContext_SetsCorrectHeaders() responseFeatureMock.Verify(feature => feature.DisableBuffering(), Times.Once); } + + [Fact] + public async Task ExtractJsonRpcMessageAsync_UnwrapsWrapper() + { + var wrapper = "{ \"isFunctionsMcpResult\": true, \"content\": { \"jsonrpc\": \"2.0\", \"method\": \"test\" } }"; + + var context = new DefaultHttpContext(); + var bytes = Encoding.UTF8.GetBytes(wrapper); + context.Request.Body = new MemoryStream(bytes); + context.Request.ContentLength = bytes.Length; + + var message = await McpHttpUtility.ExtractJsonRpcMessageSseAsync(context.Request, McpJsonSerializerOptions.DefaultOptions, default); + + Assert.NotNull(message); + Assert.IsType(message); + var notification = (JsonRpcNotification)message!; + Assert.Equal("test", notification.Method); + } + + [Fact] + public async Task ExtractJsonRpcMessageAsync_ParsesRawJsonRpc() + { + var raw = "{ \"jsonrpc\": \"2.0\", \"method\": \"raw\" }"; + + var context = new DefaultHttpContext(); + var bytes = Encoding.UTF8.GetBytes(raw); + context.Request.Body = new MemoryStream(bytes); + context.Request.ContentLength = bytes.Length; + + var message = await McpHttpUtility.ExtractJsonRpcMessageSseAsync(context.Request, McpJsonSerializerOptions.DefaultOptions, default); + + Assert.NotNull(message); + Assert.IsType(message); + var notification = (JsonRpcNotification)message!; + Assert.Equal("raw", notification.Method); + } + + [Fact] + public async Task UnwrapFunctionsMcpPayloadAsync_ReplacesBodyWithInnerContent() + { + var wrapper = "{ \"isFunctionsMcpResult\": true, \"content\": { \"jsonrpc\": \"2.0\", \"method\": \"inner\" } }"; + + var context = new DefaultHttpContext(); + var bytes = Encoding.UTF8.GetBytes(wrapper); + context.Request.Body = new MemoryStream(bytes); + context.Request.ContentLength = bytes.Length; + + await McpHttpUtility.ExtractJsonRpcMessageHttpStreamableAsync(context.Request, default); + + using var sr = new StreamReader(context.Request.Body, Encoding.UTF8); + context.Request.Body.Position = 0; + var bodyText = await sr.ReadToEndAsync(); + + Assert.Contains("\"method\": \"inner\"", bodyText); + Assert.DoesNotContain("isFunctionsMcpResult", bodyText); + } + + [Fact] + public async Task ProcessJsonRpcPayloadAsync_ExtractMode_UnwrapsWrapper() + { + var wrapper = "{ \"isFunctionsMcpResult\": true, \"content\": { \"jsonrpc\": \"2.0\", \"method\": \"test\" } }"; + + var context = new DefaultHttpContext(); + var bytes = Encoding.UTF8.GetBytes(wrapper); + context.Request.Body = new MemoryStream(bytes); + context.Request.ContentLength = bytes.Length; + + var message = await McpHttpUtility.ProcessJsonRpcPayloadAsync(context.Request, McpJsonSerializerOptions.DefaultOptions, default, unwrapOnly: false); + + Assert.NotNull(message); + Assert.IsType(message); + var notification = (JsonRpcNotification)message!; + Assert.Equal("test", notification.Method); + } + + [Fact] + public async Task ProcessJsonRpcPayloadAsync_UnwrapMode_ReplacesBodyWithInnerContent() + { + var wrapper = "{ \"isFunctionsMcpResult\": true, \"content\": { \"jsonrpc\": \"2.0\", \"method\": \"inner\" } }"; + + var context = new DefaultHttpContext(); + var bytes = Encoding.UTF8.GetBytes(wrapper); + context.Request.Body = new MemoryStream(bytes); + context.Request.ContentLength = bytes.Length; + + var result = await McpHttpUtility.ProcessJsonRpcPayloadAsync(context.Request, McpJsonSerializerOptions.DefaultOptions, default, unwrapOnly: true); + + Assert.Null(result); // unwrapOnly mode returns null + + using var sr = new StreamReader(context.Request.Body, Encoding.UTF8); + context.Request.Body.Position = 0; + var bodyText = await sr.ReadToEndAsync(); + + Assert.Contains("\"method\": \"inner\"", bodyText); + Assert.DoesNotContain("isFunctionsMcpResult", bodyText); + } + + [Fact] + public async Task ProcessJsonRpcPayloadAsync_EmptyBody_ReturnsNull() + { + var context = new DefaultHttpContext(); + context.Request.Body = new MemoryStream(); + context.Request.ContentLength = 0; + + var result = await McpHttpUtility.ProcessJsonRpcPayloadAsync(context.Request, McpJsonSerializerOptions.DefaultOptions, default, unwrapOnly: false); + + Assert.Null(result); + } + + [Fact] + public async Task ProcessJsonRpcPayloadAsync_UnwrapMode_NoWrapper_LeavesBodyIntact() + { + var raw = "{ \"jsonrpc\": \"2.0\", \"method\": \"raw\" }"; + + var context = new DefaultHttpContext(); + var bytes = Encoding.UTF8.GetBytes(raw); + context.Request.Body = new MemoryStream(bytes); + context.Request.ContentLength = bytes.Length; + + var result = await McpHttpUtility.ProcessJsonRpcPayloadAsync(context.Request, McpJsonSerializerOptions.DefaultOptions, default, unwrapOnly: true); + + Assert.Null(result); // unwrapOnly mode returns null + + using var sr = new StreamReader(context.Request.Body, Encoding.UTF8); + context.Request.Body.Position = 0; + var bodyText = await sr.ReadToEndAsync(); + + Assert.Equal(raw, bodyText); + Assert.DoesNotContain("isFunctionsMcpResult", bodyText); + } }