diff --git a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs index 6d5923fae..84275f473 100644 --- a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs +++ b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; +using System.Linq; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -60,7 +61,8 @@ public static async Task WaitForCompletionOrCreateCheckStatusR status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated || status.RuntimeStatus == OrchestrationRuntimeStatus.Failed) { - var response = request.CreateResponse(HttpStatusCode.OK); + var response = request.CreateResponse( + (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure)? HttpStatusCode.InternalServerError: HttpStatusCode.OK); await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId) { CreatedAt = status.CreatedAt, @@ -69,12 +71,7 @@ await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.In SerializedInput = status.SerializedInput, SerializedOutput = status.SerializedOutput, SerializedCustomStatus = status.SerializedCustomStatus, - }); - - if (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) - { - response.StatusCode = HttpStatusCode.InternalServerError; - } + }, statusCode: response.StatusCode); return response; } @@ -245,7 +242,7 @@ static string BuildUrl(string url, params string?[] queryValues) // The base URL could be null if: // 1. The DurableTaskClient isn't a FunctionsDurableTaskClient (which would have the baseUrl from bindings) // 2. There's no valid HttpRequestData provided - string? baseUrl = ((request != null) ? request.Url.GetLeftPart(UriPartial.Authority) : GetBaseUrl(client)); + string? baseUrl = ((request != null) ? GetBaseUrlFromRequest(request) : GetBaseUrl(client)); if (baseUrl == null) { @@ -289,6 +286,50 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response) ?? throw new InvalidOperationException("A serializer is not configured for the worker."); } + private static string? GetBaseUrlFromRequest(HttpRequestData request) + { + // Default to the scheme from the request URL + string proto = request.Url.Scheme; + string baseUrl; + + // Check for "Forwarded" header + if (request.Headers.TryGetValues("Forwarded", out var forwarded)) + { + var forwardedDict = (forwarded.FirstOrDefault() ?? "").Split(';') + .Select(pair => pair.Split('=')) + .Where(pair => pair.Length == 2) // Ensure valid key-value pairs + .ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim()); + + if (forwardedDict.TryGetValue("proto", out var forwardedProto)) + { + proto = forwardedProto; + } + + if (forwardedDict.TryGetValue("host", out var forwardedHost)) + { + baseUrl = $"{proto}://{forwardedHost}"; + return baseUrl; + } + } + + // Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers + if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos)) + { + proto = protos.First(); + } + + if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts)) + { + baseUrl = $"{proto}://{hosts.First()}"; + return baseUrl; + } + + // Fallback to using the request's URL if no forwarding headers are found + baseUrl = $"{proto}://{request.Url.Authority}"; + return baseUrl; + } + + private static string? GetQueryParams(DurableTaskClient client) { return client is FunctionsDurableTaskClient functions ? functions.QueryString : null; diff --git a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs index d17a635db..a52e1a713 100644 --- a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs +++ b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs @@ -92,6 +92,8 @@ public void CreateHttpManagementPayload_WithHttpRequestData() // Create mock HttpRequestData object. var mockFunctionContext = new Mock(); var mockHttpRequestData = new Mock(mockFunctionContext.Object); + var headers = new HttpHeadersCollection(); + mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers); mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri(requestUrl)); HttpManagementPayload payload = client.CreateHttpManagementPayload(instanceId, mockHttpRequestData.Object); @@ -269,7 +271,12 @@ private HttpRequestData MockHttpRequestAndResponseData() // Set up the URL property. mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("http://localhost:7075/orchestrators/E1_HelloSequence")); + + var headers = new HttpHeadersCollection(); + // Setup the Headers property to return the empty headers + mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers); + var mockHttpResponseData = new Mock(mockFunctionContext.Object) { DefaultValue = DefaultValue.Mock