Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WaitForCompletionOrCreateCheckStatusResponseAsync to Microsoft.Azure.Functions.Worker.DurableTaskClientExtensions #2875

Merged
merged 16 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 112 additions & 2 deletions src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -18,6 +19,70 @@ namespace Microsoft.Azure.Functions.Worker;
/// </summary>
public static class DurableTaskClientExtensions
{
/// <summary>
/// Waits for the completion of the specified orchestration instance with a retry interval, controlled by the cancellation token.
/// If the orchestration does not complete within the required time, returns an HTTP response containing the <see cref="HttpManagementPayload"/> class to manage instances.
/// </summary>
/// <param name="client">The <see cref="DurableTaskClient"/>.</param>
/// <param name="request">The HTTP request that this response is for.</param>
/// <param name="instanceId">The ID of the orchestration instance to check.</param>
/// <param name="retryInterval">The timeout between checks for output from the durable function. The default value is 1 second.</param>
/// <param name="returnInternalServerErrorOnFailure">Optional parameter that configures the http response code returned. Defaults to <c>false</c>.</param>
/// <param name="getInputsAndOutputs">Optional parameter that configures whether to get the inputs and outputs of the orchestration. Defaults to <c>false</c>.</param>
/// <param name="cancellation">A token that signals if the wait should be canceled. If canceled, call CreateCheckStatusResponseAsync to return a reponse contains a HttpManagementPayload.</param>
/// <returns></returns>
public static async Task<HttpResponseData> WaitForCompletionOrCreateCheckStatusResponseAsync(
this DurableTaskClient client,
HttpRequestData request,
string instanceId,
TimeSpan? retryInterval = null,
bool returnInternalServerErrorOnFailure = false,
bool getInputsAndOutputs = false,
CancellationToken cancellation = default
)
{
TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1);
try
{
while (true)
{
var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: getInputsAndOutputs);
if (status != null)
{
if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
#pragma warning disable CS0618 // Type or member is obsolete
status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
#pragma warning restore CS0618 // Type or member is obsolete
status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
{
var response = request.CreateResponse(
(status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) ? HttpStatusCode.InternalServerError : HttpStatusCode.OK);
await response.WriteAsJsonAsync(new
{
Name = status.Name,
InstanceId = status.InstanceId,
CreatedAt = status.CreatedAt,
LastUpdatedAt = status.LastUpdatedAt,
RuntimeStatus = status.RuntimeStatus.ToString(), // Convert enum to string
SerializedInput = status.SerializedInput,
SerializedOutput = status.SerializedOutput,
SerializedCustomStatus = status.SerializedCustomStatus
}, statusCode: response.StatusCode);

return response;
}
}
await Task.Delay(retryIntervalLocal, cancellation);
}
}
// If the task is canceled, call CreateCheckStatusResponseAsync to return a response containing instance management URLs.
catch (OperationCanceledException)
{
return await CreateCheckStatusResponseAsync(client, request, instanceId);
}
}

/// <summary>
/// Creates an HTTP response that is useful for checking the status of the specified instance.
/// </summary>
Expand Down Expand Up @@ -170,13 +235,13 @@ static string BuildUrl(string url, params string?[] queryValues)
// The base URL could be null if:
// 1. The DurableTaskClient isn't a FunctionsDurableTaskClient (which would have the baseUrl from bindings)
// 2. There's no valid HttpRequestData provided
string? baseUrl = ((request != null) ? request.Url.GetLeftPart(UriPartial.Authority) : GetBaseUrl(client));
string? baseUrl = ((request != null) ? GetBaseUrlFromRequest(request) : GetBaseUrl(client));

if (baseUrl == null)
{
throw new InvalidOperationException("Failed to create HTTP management payload as base URL is null. Either use Functions bindings or provide an HTTP request to create the HttpPayload.");
}

bool isFromRequest = request != null;

string formattedInstanceId = Uri.EscapeDataString(instanceId);
Expand Down Expand Up @@ -214,6 +279,51 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response)
?? throw new InvalidOperationException("A serializer is not configured for the worker.");
}

private static string? GetBaseUrlFromRequest(HttpRequestData request)
nytian marked this conversation as resolved.
Show resolved Hide resolved
{
// Default to the scheme from the request URL
string proto = request.Url.Scheme;
string host = request.Url.Authority;

// Check for "Forwarded" header
if (request.Headers.TryGetValues("Forwarded", out var forwardedHeaders))
{
var forwardedDict = forwardedHeaders.FirstOrDefault()?.Split(';')
.Select(pair => pair.Split('='))
.Where(pair => pair.Length == 2)
.ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim());

if (forwardedDict != null)
{
if (forwardedDict.TryGetValue("proto", out var forwardedProto))
{
proto = forwardedProto;
}
if (forwardedDict.TryGetValue("host", out var forwardedHost))
{
host = forwardedHost;
// Return if either proto or host (or both) were found in "Forwarded" header
return $"{proto}://{forwardedHost}";
}
}
}
// Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers if "Forwarded" is not present
if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos))
{
proto = protos.FirstOrDefault() ?? proto;
}
if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts))
{
// Return base URL if either "X-Forwarded-Proto" or "X-Forwarded-Host" (or both) are found
host = hosts.FirstOrDefault() ?? host;
return $"{proto}://{host}";
}

// Construct and return the base URL from default fallback values
return $"{proto}://{host}";
}


private static string? GetQueryParams(DurableTaskClient client)
{
return client is FunctionsDurableTaskClient functions ? functions.QueryString : null;
Expand Down
Loading
Loading