Skip to content

Commit

Permalink
add back forword request handling and update test accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
nytian committed Nov 1, 2024
1 parent d01341d commit ffdd114
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 8 deletions.
57 changes: 49 additions & 8 deletions src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -60,7 +61,8 @@ public static async Task<HttpResponseData> WaitForCompletionOrCreateCheckStatusR
status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
{
var response = request.CreateResponse(HttpStatusCode.OK);
var response = request.CreateResponse(
(status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure)? HttpStatusCode.InternalServerError: HttpStatusCode.OK);
await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId)
{
CreatedAt = status.CreatedAt,
Expand All @@ -69,12 +71,7 @@ await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.In
SerializedInput = status.SerializedInput,
SerializedOutput = status.SerializedOutput,
SerializedCustomStatus = status.SerializedCustomStatus,
});

if (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure)
{
response.StatusCode = HttpStatusCode.InternalServerError;
}
}, statusCode: response.StatusCode);

return response;
}
Expand Down Expand Up @@ -245,7 +242,7 @@ static string BuildUrl(string url, params string?[] queryValues)
// The base URL could be null if:
// 1. The DurableTaskClient isn't a FunctionsDurableTaskClient (which would have the baseUrl from bindings)
// 2. There's no valid HttpRequestData provided
string? baseUrl = ((request != null) ? request.Url.GetLeftPart(UriPartial.Authority) : GetBaseUrl(client));
string? baseUrl = ((request != null) ? GetBaseUrlFromRequest(request) : GetBaseUrl(client));

if (baseUrl == null)
{
Expand Down Expand Up @@ -289,6 +286,50 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response)
?? throw new InvalidOperationException("A serializer is not configured for the worker.");
}

private static string? GetBaseUrlFromRequest(HttpRequestData request)
{
// Default to the scheme from the request URL
string proto = request.Url.Scheme;
string baseUrl;

// Check for "Forwarded" header
if (request.Headers.TryGetValues("Forwarded", out var forwarded))
{
var forwardedDict = (forwarded.FirstOrDefault() ?? "").Split(';')
.Select(pair => pair.Split('='))
.Where(pair => pair.Length == 2) // Ensure valid key-value pairs
.ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim());

if (forwardedDict.TryGetValue("proto", out var forwardedProto))
{
proto = forwardedProto;
}

if (forwardedDict.TryGetValue("host", out var forwardedHost))
{
baseUrl = $"{proto}://{forwardedHost}";
return baseUrl;
}
}

// Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers
if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos))
{
proto = protos.First();
}

if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts))
{
baseUrl = $"{proto}://{hosts.First()}";
return baseUrl;
}

// Fallback to using the request's URL if no forwarding headers are found
baseUrl = $"{proto}://{request.Url.Authority}";
return baseUrl;
}


private static string? GetQueryParams(DurableTaskClient client)
{
return client is FunctionsDurableTaskClient functions ? functions.QueryString : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ public void CreateHttpManagementPayload_WithHttpRequestData()
// Create mock HttpRequestData object.
var mockFunctionContext = new Mock<FunctionContext>();
var mockHttpRequestData = new Mock<HttpRequestData>(mockFunctionContext.Object);
var headers = new HttpHeadersCollection();
mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers);
mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri(requestUrl));

HttpManagementPayload payload = client.CreateHttpManagementPayload(instanceId, mockHttpRequestData.Object);
Expand Down Expand Up @@ -269,7 +271,12 @@ private HttpRequestData MockHttpRequestAndResponseData()

// Set up the URL property.
mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("http://localhost:7075/orchestrators/E1_HelloSequence"));

var headers = new HttpHeadersCollection();

// Setup the Headers property to return the empty headers
mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers);

var mockHttpResponseData = new Mock<HttpResponseData>(mockFunctionContext.Object)
{
DefaultValue = DefaultValue.Mock
Expand Down

0 comments on commit ffdd114

Please sign in to comment.