diff --git a/eng/ci/code-mirror.yml b/eng/ci/code-mirror.yml index 0a2196b95..40f2d2a0f 100644 --- a/eng/ci/code-mirror.yml +++ b/eng/ci/code-mirror.yml @@ -5,6 +5,7 @@ trigger: # Keep this set limited as appropriate (don't mirror individual user branches). - main - dev + - v3.x resources: repositories: diff --git a/eng/ci/official-build.yml b/eng/ci/official-build.yml index e7a871026..858ca3518 100644 --- a/eng/ci/official-build.yml +++ b/eng/ci/official-build.yml @@ -7,6 +7,7 @@ trigger: include: - main - dev + - v3.x # CI only, does not trigger on PRs. pr: none diff --git a/release_notes.md b/release_notes.md index fa7883271..812793b79 100644 --- a/release_notes.md +++ b/release_notes.md @@ -5,10 +5,12 @@ ### New Features - Fail fast if extendedSessionsEnabled set to 'true' for the worker type that doesn't support extended sessions (https://github.com/Azure/azure-functions-durable-extension/pull/2732). +- Added an `IFunctionsWorkerApplicationBuilder.ConfigureDurableExtension()` extension method for cases where auto-registration does not work (no source gen running). (#2950) ### Bug Fixes - Fix custom connection name not working when using IDurableClientFactory.CreateClient() - contributed by [@hctan](https://github.com/hctan) +- Made durable extension for isolated worker configuration idempotent, allowing multiple calls safely. (#2950) ### Breaking Changes diff --git a/src/WebJobs.Extensions.DurableTask/AzureStorageDurabilityProvider.cs b/src/WebJobs.Extensions.DurableTask/AzureStorageDurabilityProvider.cs index 2920a0f37..ea5350463 100644 --- a/src/WebJobs.Extensions.DurableTask/AzureStorageDurabilityProvider.cs +++ b/src/WebJobs.Extensions.DurableTask/AzureStorageDurabilityProvider.cs @@ -35,6 +35,16 @@ internal class AzureStorageDurabilityProvider : DurabilityProvider private readonly JObject storageOptionsJson; private readonly ILogger logger; + private readonly object initLock = new object(); + +#if !FUNCTIONS_V1 + private DurableTaskScaleMonitor singletonScaleMonitor; +#endif + +#if FUNCTIONS_V3_OR_GREATER + private DurableTaskTargetScaler singletonTargetScaler; +#endif + public AzureStorageDurabilityProvider( AzureStorageOrchestrationService service, IStorageAccountProvider storageAccountProvider, @@ -226,12 +236,11 @@ internal static OrchestrationInstanceStatusQueryCondition ConvertWebjobsDurableC #if !FUNCTIONS_V1 internal DurableTaskMetricsProvider GetMetricsProvider( - string functionName, string hubName, CloudStorageAccount storageAccount, ILogger logger) { - return new DurableTaskMetricsProvider(functionName, hubName, logger, performanceMonitor: null, storageAccount); + return new DurableTaskMetricsProvider(hubName, logger, performanceMonitor: null, storageAccount); } /// @@ -242,16 +251,22 @@ public override bool TryGetScaleMonitor( string connectionName, out IScaleMonitor scaleMonitor) { - CloudStorageAccount storageAccount = this.storageAccountProvider.GetStorageAccountDetails(connectionName).ToCloudStorageAccount(); - DurableTaskMetricsProvider metricsProvider = this.GetMetricsProvider(functionName, hubName, storageAccount, this.logger); - scaleMonitor = new DurableTaskScaleMonitor( - functionId, - functionName, - hubName, - storageAccount, - this.logger, - metricsProvider); - return true; + lock (this.initLock) + { + if (this.singletonScaleMonitor == null) + { + CloudStorageAccount storageAccount = this.storageAccountProvider.GetStorageAccountDetails(connectionName).ToCloudStorageAccount(); + DurableTaskMetricsProvider metricsProvider = this.GetMetricsProvider(hubName, storageAccount, this.logger); + this.singletonScaleMonitor = new DurableTaskScaleMonitor( + hubName, + storageAccount, + this.logger, + metricsProvider); + } + + scaleMonitor = this.singletonScaleMonitor; + return true; + } } #endif @@ -263,11 +278,23 @@ public override bool TryGetTargetScaler( string connectionName, out ITargetScaler targetScaler) { - // This is only called by the ScaleController, it doesn't run in the Functions Host process. - CloudStorageAccount storageAccount = this.storageAccountProvider.GetStorageAccountDetails(connectionName).ToCloudStorageAccount(); - DurableTaskMetricsProvider metricsProvider = this.GetMetricsProvider(functionName, hubName, storageAccount, this.logger); - targetScaler = new DurableTaskTargetScaler(functionId, metricsProvider, this, this.logger); - return true; + lock (this.initLock) + { + if (this.singletonTargetScaler == null) + { + // This is only called by the ScaleController, it doesn't run in the Functions Host process. + CloudStorageAccount storageAccount = this.storageAccountProvider.GetStorageAccountDetails(connectionName).ToCloudStorageAccount(); + DurableTaskMetricsProvider metricsProvider = this.GetMetricsProvider(hubName, storageAccount, this.logger); + + // Scalers in Durable Functions are shared for all functions in the same task hub. + // So instead of using a function ID, we use the task hub name as the basis for the descriptor ID. + string id = $"DurableTask-AzureStorage:{hubName ?? "default"}"; + this.singletonTargetScaler = new DurableTaskTargetScaler(id, metricsProvider, this, this.logger); + } + + targetScaler = this.singletonTargetScaler; + return true; + } } #endif } diff --git a/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskMetricsProvider.cs b/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskMetricsProvider.cs index e3565d169..999821abc 100644 --- a/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskMetricsProvider.cs +++ b/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskMetricsProvider.cs @@ -13,16 +13,18 @@ namespace Microsoft.Azure.WebJobs.Extensions.DurableTask { internal class DurableTaskMetricsProvider { - private readonly string functionName; private readonly string hubName; private readonly ILogger logger; private readonly CloudStorageAccount storageAccount; private DisconnectedPerformanceMonitor performanceMonitor; - public DurableTaskMetricsProvider(string functionName, string hubName, ILogger logger, DisconnectedPerformanceMonitor performanceMonitor, CloudStorageAccount storageAccount) + public DurableTaskMetricsProvider( + string hubName, + ILogger logger, + DisconnectedPerformanceMonitor performanceMonitor, + CloudStorageAccount storageAccount) { - this.functionName = functionName; this.hubName = hubName; this.logger = logger; this.performanceMonitor = performanceMonitor; @@ -42,7 +44,7 @@ public virtual async Task GetMetricsAsync() } catch (StorageException e) { - this.logger.LogWarning("{details}. Function: {functionName}. HubName: {hubName}.", e.ToString(), this.functionName, this.hubName); + this.logger.LogWarning("{details}. HubName: {hubName}.", e.ToString(), this.hubName); } if (heartbeat != null) diff --git a/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskScaleMonitor.cs b/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskScaleMonitor.cs index 4c05b3df0..c24762e06 100644 --- a/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskScaleMonitor.cs +++ b/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskScaleMonitor.cs @@ -16,8 +16,6 @@ namespace Microsoft.Azure.WebJobs.Extensions.DurableTask { internal sealed class DurableTaskScaleMonitor : IScaleMonitor { - private readonly string functionId; - private readonly string functionName; private readonly string hubName; private readonly CloudStorageAccount storageAccount; private readonly ScaleMonitorDescriptor scaleMonitorDescriptor; @@ -27,31 +25,27 @@ internal sealed class DurableTaskScaleMonitor : IScaleMonitor GetScaleResultAsync(TargetScalerContext co // and the ScaleController is injecting it's own custom ILogger implementation that forwards logs to Kusto. var metricsLog = $"Metrics: workItemQueueLength={workItemQueueLength}. controlQueueLengths={serializedControlQueueLengths}. " + $"maxConcurrentOrchestrators={this.MaxConcurrentOrchestrators}. maxConcurrentActivities={this.MaxConcurrentActivities}"; - var scaleControllerLog = $"Target worker count for '{this.functionId}' is '{numWorkersToRequest}'. " + + var scaleControllerLog = $"Target worker count for '{this.scaler}' is '{numWorkersToRequest}'. " + metricsLog; // target worker count should never be negative @@ -85,7 +89,7 @@ public async Task GetScaleResultAsync(TargetScalerContext co // We want to augment the exception with metrics information for investigation purposes var metricsLog = $"Metrics: workItemQueueLength={metrics?.WorkItemQueueLength}. controlQueueLengths={metrics?.ControlQueueLengths}. " + $"maxConcurrentOrchestrators={this.MaxConcurrentOrchestrators}. maxConcurrentActivities={this.MaxConcurrentActivities}"; - var errorLog = $"Error: target worker count for '{this.functionId}' resulted in exception. " + metricsLog; + var errorLog = $"Error: target worker count for '{this.scaler}' resulted in exception. " + metricsLog; throw new Exception(errorLog, ex); } } diff --git a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs index bbd6222a8..251ebb2d7 100644 --- a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs +++ b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using System; +using System.Linq; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -18,6 +19,70 @@ namespace Microsoft.Azure.Functions.Worker; /// public static class DurableTaskClientExtensions { + /// + /// Waits for the completion of the specified orchestration instance with a retry interval, controlled by the cancellation token. + /// If the orchestration does not complete within the required time, returns an HTTP response containing the class to manage instances. + /// + /// The . + /// The HTTP request that this response is for. + /// The ID of the orchestration instance to check. + /// The timeout between checks for output from the durable function. The default value is 1 second. + /// Optional parameter that configures the http response code returned. Defaults to false. + /// Optional parameter that configures whether to get the inputs and outputs of the orchestration. Defaults to false. + /// A token that signals if the wait should be canceled. If canceled, call CreateCheckStatusResponseAsync to return a reponse contains a HttpManagementPayload. + /// + public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync( + this DurableTaskClient client, + HttpRequestData request, + string instanceId, + TimeSpan? retryInterval = null, + bool returnInternalServerErrorOnFailure = false, + bool getInputsAndOutputs = false, + CancellationToken cancellation = default + ) + { + TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1); + try + { + while (true) + { + var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: getInputsAndOutputs); + if (status != null) + { + if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed || +#pragma warning disable CS0618 // Type or member is obsolete + status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled || +#pragma warning restore CS0618 // Type or member is obsolete + status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated || + status.RuntimeStatus == OrchestrationRuntimeStatus.Failed) + { + var response = request.CreateResponse( + (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) ? HttpStatusCode.InternalServerError : HttpStatusCode.OK); + await response.WriteAsJsonAsync(new + { + Name = status.Name, + InstanceId = status.InstanceId, + CreatedAt = status.CreatedAt, + LastUpdatedAt = status.LastUpdatedAt, + RuntimeStatus = status.RuntimeStatus.ToString(), // Convert enum to string + SerializedInput = status.SerializedInput, + SerializedOutput = status.SerializedOutput, + SerializedCustomStatus = status.SerializedCustomStatus + }, statusCode: response.StatusCode); + + return response; + } + } + await Task.Delay(retryIntervalLocal, cancellation); + } + } + // If the task is canceled, call CreateCheckStatusResponseAsync to return a response containing instance management URLs. + catch (OperationCanceledException) + { + return await CreateCheckStatusResponseAsync(client, request, instanceId); + } + } + /// /// Creates an HTTP response that is useful for checking the status of the specified instance. /// @@ -170,13 +235,13 @@ static string BuildUrl(string url, params string?[] queryValues) // The base URL could be null if: // 1. The DurableTaskClient isn't a FunctionsDurableTaskClient (which would have the baseUrl from bindings) // 2. There's no valid HttpRequestData provided - string? baseUrl = ((request != null) ? request.Url.GetLeftPart(UriPartial.Authority) : GetBaseUrl(client)); + string? baseUrl = ((request != null) ? GetBaseUrlFromRequest(request) : GetBaseUrl(client)); if (baseUrl == null) { throw new InvalidOperationException("Failed to create HTTP management payload as base URL is null. Either use Functions bindings or provide an HTTP request to create the HttpPayload."); } - + bool isFromRequest = request != null; string formattedInstanceId = Uri.EscapeDataString(instanceId); @@ -214,6 +279,51 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response) ?? throw new InvalidOperationException("A serializer is not configured for the worker."); } + private static string? GetBaseUrlFromRequest(HttpRequestData request) + { + // Default to the scheme from the request URL + string proto = request.Url.Scheme; + string host = request.Url.Authority; + + // Check for "Forwarded" header + if (request.Headers.TryGetValues("Forwarded", out var forwardedHeaders)) + { + var forwardedDict = forwardedHeaders.FirstOrDefault()?.Split(';') + .Select(pair => pair.Split('=')) + .Where(pair => pair.Length == 2) + .ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim()); + + if (forwardedDict != null) + { + if (forwardedDict.TryGetValue("proto", out var forwardedProto)) + { + proto = forwardedProto; + } + if (forwardedDict.TryGetValue("host", out var forwardedHost)) + { + host = forwardedHost; + // Return if either proto or host (or both) were found in "Forwarded" header + return $"{proto}://{forwardedHost}"; + } + } + } + // Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers if "Forwarded" is not present + if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos)) + { + proto = protos.FirstOrDefault() ?? proto; + } + if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts)) + { + // Return base URL if either "X-Forwarded-Proto" or "X-Forwarded-Host" (or both) are found + host = hosts.FirstOrDefault() ?? host; + return $"{proto}://{host}"; + } + + // Construct and return the base URL from default fallback values + return $"{proto}://{host}"; + } + + private static string? GetQueryParams(DurableTaskClient client) { return client is FunctionsDurableTaskClient functions ? functions.QueryString : null; diff --git a/src/Worker.Extensions.DurableTask/DurableTaskExtensionStartup.cs b/src/Worker.Extensions.DurableTask/DurableTaskExtensionStartup.cs index 626acd6bf..af7bab017 100644 --- a/src/Worker.Extensions.DurableTask/DurableTaskExtensionStartup.cs +++ b/src/Worker.Extensions.DurableTask/DurableTaskExtensionStartup.cs @@ -1,20 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the MIT License. See License.txt in the project root for license information. -using System; -using Azure.Core.Serialization; using Microsoft.Azure.Functions.Worker.Core; using Microsoft.Azure.Functions.Worker.Extensions.DurableTask; -using Microsoft.DurableTask; -using Microsoft.DurableTask.Client; -using Microsoft.DurableTask.Converters; -using Microsoft.DurableTask.Worker; -using Microsoft.DurableTask.Worker.Shims; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyInjection.Extensions; -using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; [assembly: WorkerExtensionStartup(typeof(DurableTaskExtensionStartup))] @@ -28,49 +16,6 @@ public sealed class DurableTaskExtensionStartup : WorkerExtensionStartup /// public override void Configure(IFunctionsWorkerApplicationBuilder applicationBuilder) { - applicationBuilder.Services.AddSingleton(); - applicationBuilder.Services.AddOptions() - .Configure(options => options.EnableEntitySupport = true) - .PostConfigure((opt, sp) => - { - if (GetConverter(sp) is DataConverter converter) - { - opt.DataConverter = converter; - } - }); - - applicationBuilder.Services.AddOptions() - .Configure(options => options.EnableEntitySupport = true) - .PostConfigure((opt, sp) => - { - if (GetConverter(sp) is DataConverter converter) - { - opt.DataConverter = converter; - } - }); - - applicationBuilder.Services.TryAddSingleton(sp => - { - DurableTaskWorkerOptions options = sp.GetRequiredService>().Value; - ILoggerFactory factory = sp.GetRequiredService(); - return new DurableTaskShimFactory(options, factory); // For GrpcOrchestrationRunner - }); - - applicationBuilder.Services.Configure(o => - { - o.InputConverters.Register(); - }); - - applicationBuilder.UseMiddleware(); - } - - private static DataConverter? GetConverter(IServiceProvider services) - { - // We intentionally do not consider a DataConverter in the DI provider, or if one was already set. This is to - // ensure serialization is consistent with the rest of Azure Functions. This is particularly important because - // TaskActivity bindings use ObjectSerializer directly for the time being. Due to this, allowing DataConverter - // to be set separately from ObjectSerializer would give an inconsistent serialization solution. - WorkerOptions? worker = services.GetRequiredService>()?.Value; - return worker?.Serializer is not null ? new ObjectConverterShim(worker.Serializer) : null; + applicationBuilder.ConfigureDurableExtension(); } } diff --git a/src/Worker.Extensions.DurableTask/FunctionsWorkerApplicationBuilderExtensions.cs b/src/Worker.Extensions.DurableTask/FunctionsWorkerApplicationBuilderExtensions.cs new file mode 100644 index 000000000..642446dd4 --- /dev/null +++ b/src/Worker.Extensions.DurableTask/FunctionsWorkerApplicationBuilderExtensions.cs @@ -0,0 +1,125 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Linq; +using Azure.Core.Serialization; +using Microsoft.Azure.Functions.Worker.Core; +using Microsoft.Azure.Functions.Worker.Extensions.DurableTask; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Converters; +using Microsoft.DurableTask.Worker; +using Microsoft.DurableTask.Worker.Shims; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Microsoft.Azure.Functions.Worker; + +/// +/// Extensions for . +/// +public static class FunctionsWorkerApplicationBuilderExtensions +{ + /// + /// Configures the Durable Functions extension for the worker. + /// + /// The builder to configure. + /// The for call chaining. + public static IFunctionsWorkerApplicationBuilder ConfigureDurableExtension(this IFunctionsWorkerApplicationBuilder builder) + { + if (builder is null) + { + throw new ArgumentNullException(nameof(builder)); + } + + builder.Services.TryAddSingleton(); + builder.Services.TryAddEnumerable( + ServiceDescriptor.Singleton, ConfigureClientOptions>()); + builder.Services.TryAddEnumerable( + ServiceDescriptor.Singleton, PostConfigureClientOptions>()); + builder.Services.TryAddEnumerable( + ServiceDescriptor.Singleton, ConfigureWorkerOptions>()); + builder.Services.TryAddEnumerable( + ServiceDescriptor.Singleton, PostConfigureWorkerOptions>()); + + builder.Services.TryAddSingleton(sp => + { + DurableTaskWorkerOptions options = sp.GetRequiredService>().Value; + ILoggerFactory factory = sp.GetRequiredService(); + return new DurableTaskShimFactory(options, factory); // For GrpcOrchestrationRunner + }); + + builder.Services.TryAddEnumerable( + ServiceDescriptor.Singleton, ConfigureInputConverter>()); + if (!builder.Services.Any(d => d.ServiceType == typeof(DurableTaskFunctionsMiddleware))) + { + builder.UseMiddleware(); + } + + return builder; + } + + private class ConfigureInputConverter : IConfigureOptions + { + public void Configure(WorkerOptions options) + { + options.InputConverters.Register(); + } + } + + private class ConfigureClientOptions : IConfigureOptions + { + public void Configure(DurableTaskClientOptions options) + { + options.EnableEntitySupport = true; + } + } + + private class PostConfigureClientOptions : IPostConfigureOptions + { + readonly IOptionsMonitor workerOptions; + + public PostConfigureClientOptions(IOptionsMonitor workerOptions) + { + this.workerOptions = workerOptions; + } + + public void PostConfigure(string name, DurableTaskClientOptions options) + { + if (this.workerOptions.Get(name).Serializer is { } serializer) + { + options.DataConverter = new ObjectConverterShim(serializer); + } + } + } + + private class ConfigureWorkerOptions : IConfigureOptions + { + public void Configure(DurableTaskWorkerOptions options) + { + options.EnableEntitySupport = true; + } + } + + private class PostConfigureWorkerOptions : IPostConfigureOptions + { + readonly IOptionsMonitor workerOptions; + + public PostConfigureWorkerOptions(IOptionsMonitor workerOptions) + { + this.workerOptions = workerOptions; + } + + public void PostConfigure(string name, DurableTaskWorkerOptions options) + { + if (this.workerOptions.Get(name).Serializer is { } serializer) + { + options.DataConverter = new ObjectConverterShim(serializer); + } + } + } +} diff --git a/test/FunctionsV2/DurableTaskListenerTests.cs b/test/FunctionsV2/DurableTaskListenerTests.cs index 3be412bd9..996a336cd 100644 --- a/test/FunctionsV2/DurableTaskListenerTests.cs +++ b/test/FunctionsV2/DurableTaskListenerTests.cs @@ -2,13 +2,10 @@ // Licensed under the MIT License. See LICENSE in the project root for license information. using System; -using System.Linq; -using Microsoft.Azure.WebJobs.Host.Executors; using Microsoft.Azure.WebJobs.Host.Scale; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; -using Moq; using Xunit; namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.Tests @@ -40,9 +37,9 @@ public void GetMonitor_ReturnsExpectedValue() IScaleMonitor scaleMonitor = this.listener.GetMonitor(); Assert.Equal(typeof(DurableTaskScaleMonitor), scaleMonitor.GetType()); - Assert.Equal($"{this.functionId}-DurableTaskTrigger-DurableTaskHub".ToLower(), scaleMonitor.Descriptor.Id); + Assert.Equal($"DurableTaskTrigger-DurableTaskHub".ToLower(), scaleMonitor.Descriptor.Id); - var scaleMonitor2 = this.listener.GetMonitor(); + IScaleMonitor scaleMonitor2 = this.listener.GetMonitor(); Assert.Same(scaleMonitor, scaleMonitor2); } diff --git a/test/FunctionsV2/DurableTaskScaleMonitorTests.cs b/test/FunctionsV2/DurableTaskScaleMonitorTests.cs index 57566da29..330567dda 100644 --- a/test/FunctionsV2/DurableTaskScaleMonitorTests.cs +++ b/test/FunctionsV2/DurableTaskScaleMonitorTests.cs @@ -20,8 +20,6 @@ namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.Tests { public class DurableTaskScaleMonitorTests { - private readonly string functionId = "DurableTaskTriggerFunctionId"; - private readonly FunctionName functionName = new FunctionName("DurableTaskTriggerFunctionName"); private readonly string hubName = "DurableTaskTriggerHubName"; private readonly CloudStorageAccount storageAccount = CloudStorageAccount.Parse(TestHelpers.GetStorageConnectionString()); private readonly ITestOutputHelper output; @@ -41,15 +39,12 @@ public DurableTaskScaleMonitorTests(ITestOutputHelper output) this.traceHelper = new EndToEndTraceHelper(logger, false); this.performanceMonitor = new Mock(MockBehavior.Strict, this.storageAccount, this.hubName, (int?)null); var metricsProvider = new DurableTaskMetricsProvider( - this.functionName.Name, this.hubName, logger, this.performanceMonitor.Object, this.storageAccount); this.scaleMonitor = new DurableTaskScaleMonitor( - this.functionId, - this.functionName.Name, this.hubName, this.storageAccount, logger, @@ -61,7 +56,7 @@ public DurableTaskScaleMonitorTests(ITestOutputHelper output) [Trait("Category", PlatformSpecificHelpers.TestCategory)] public void ScaleMonitorDescriptor_ReturnsExpectedValue() { - Assert.Equal($"{this.functionId}-DurableTaskTrigger-{this.hubName}".ToLower(), this.scaleMonitor.Descriptor.Id); + Assert.Equal($"DurableTaskTrigger-{this.hubName}".ToLower(), this.scaleMonitor.Descriptor.Id); } [Fact] diff --git a/test/FunctionsV2/DurableTaskTargetScalerTests.cs b/test/FunctionsV2/DurableTaskTargetScalerTests.cs index cc2272e16..0fcae7ec6 100644 --- a/test/FunctionsV2/DurableTaskTargetScalerTests.cs +++ b/test/FunctionsV2/DurableTaskTargetScalerTests.cs @@ -47,7 +47,6 @@ public DurableTaskTargetScalerTests(ITestOutputHelper output) CloudStorageAccount nullCloudStorageAccountMock = null; this.metricsProviderMock = new Mock( MockBehavior.Strict, - "FunctionName", "HubName", logger, nullPerformanceMonitorMock, diff --git a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs index 6f975d2c5..1623f4559 100644 --- a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs +++ b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs @@ -1,6 +1,10 @@ +using System.Net; +using Azure.Core.Serialization; using Microsoft.Azure.Functions.Worker.Http; using Microsoft.DurableTask.Client; +using Microsoft.Extensions.Options; using Moq; +using Newtonsoft.Json; namespace Microsoft.Azure.Functions.Worker.Tests { @@ -9,7 +13,7 @@ namespace Microsoft.Azure.Functions.Worker.Tests /// public class FunctionsDurableTaskClientTests { - private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null) + private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null, OrchestrationMetadata? orchestrationMetadata = null) { // construct mock client @@ -21,6 +25,12 @@ private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? bas durableClientMock.Setup(x => x.TerminateInstanceAsync( It.IsAny(), It.IsAny(), It.IsAny())).Returns(completedTask); + if (orchestrationMetadata != null) + { + durableClientMock.Setup(x => x.GetInstancesAsync(orchestrationMetadata.InstanceId, It.IsAny(), It.IsAny())) + .ReturnsAsync(orchestrationMetadata); + } + DurableTaskClient durableClient = durableClientMock.Object; FunctionsDurableTaskClient client = new FunctionsDurableTaskClient(durableClient, queryString: null, httpBaseUrl: baseUrl); return client; @@ -82,6 +92,8 @@ public void CreateHttpManagementPayload_WithHttpRequestData() // Create mock HttpRequestData object. var mockFunctionContext = new Mock(); var mockHttpRequestData = new Mock(mockFunctionContext.Object); + var headers = new HttpHeadersCollection(); + mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers); mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri(requestUrl)); HttpManagementPayload payload = client.CreateHttpManagementPayload(instanceId, mockHttpRequestData.Object); @@ -89,6 +101,153 @@ public void CreateHttpManagementPayload_WithHttpRequestData() AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId); } + /// + /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns the expected response when the orchestration is completed. + /// The expected response should include OrchestrationMetadata in the body with an HttpStatusCode.OK. + /// + [Fact] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenCompleted() + { + string instanceId = "test-instance-id-completed"; + var expectedResult = new OrchestrationMetadata("TestCompleted", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Completed, + SerializedCustomStatus = "TestCustomStatus", + SerializedInput = "TestInput", + SerializedOutput = "TestOutput" + }; + + var client = this.GetTestFunctionsDurableTaskClient( orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + + HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId); + + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body); + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(orchestratorMetadata); + AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata); + } + + /// + /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestrator didn't finish within + /// the timeout period. The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted. + /// + [Fact] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning() + { + string instanceId = "test-instance-id-running"; + var expectedResult = new OrchestrationMetadata("TestRunning", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Running, + }; + + var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + HttpResponseData response; + using (CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(10))) + { + response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, cancellation: cts.Token); + }; + + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + HttpManagementPayload? payload; + using (var reader = new StreamReader(response.Body)) + { + payload = JsonConvert.DeserializeObject(await reader.ReadToEndAsync()); + } + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(payload); + AssertHttpManagementPayload(payload, "https://localhost:7075/runtime/webhooks/durabletask", instanceId); + } + + /// + /// Tests the `WaitForCompletionOrCreateCheckStatusResponseAsync` method to ensure it returns the correct HTTP status code + /// based on the `returnInternalServerErrorOnFailure` parameter when the orchestration has failed. + /// + [Theory] + [InlineData(true, HttpStatusCode.InternalServerError)] + [InlineData(false, HttpStatusCode.OK)] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFailed(bool returnInternalServerErrorOnFailure, HttpStatusCode expected) + { + string instanceId = "test-instance-id-failed"; + var expectedResult = new OrchestrationMetadata("TestFailed", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Failed, + SerializedOutput = "Microsoft.DurableTask.TaskFailedException: Task 'SayHello' (#0) failed with an unhandled exception: Exception while executing function: Functions.SayHello", + SerializedInput = null + }; + + var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + + HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, returnInternalServerErrorOnFailure: returnInternalServerErrorOnFailure); + + Assert.NotNull(response); + Assert.Equal(expected, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body); + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(orchestratorMetadata); + AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata); + } + + /// + /// Tests the `GetBaseUrlFromRequest` can return the right base URL from the HttpRequestData with different forwarding or proxies. + /// This test covers the following scenarios: + /// - Using the "Forwarded" header + /// - Using "X-Forwarded-Proto" and "X-Forwarded-Host" headers + /// - Using only "X-Forwarded-Host" with default protocol + /// - no headers + /// + [Theory] + [InlineData("Forwarded", "proto=https;host=forwarded.example.com","","", "https://forwarded.example.com/runtime/webhooks/durabletask")] + [InlineData("X-Forwarded-Proto", "https", "X-Forwarded-Host", "xforwarded.example.com", "https://xforwarded.example.com/runtime/webhooks/durabletask")] + [InlineData("", "", "X-Forwarded-Host", "test.net", "https://test.net/runtime/webhooks/durabletask")] + [InlineData("", "", "", "", "https://localhost:7075/runtime/webhooks/durabletask")] // Default base URL for empty headers + public void TestHttpRequestDataForwardingHandling(string header1, string? value1, string header2, string value2, string expectedBaseUrl) + { + var headers = new HttpHeadersCollection(); + if (!string.IsNullOrEmpty(header1)) + { + headers.Add(header1, value1); + } + if (!string.IsNullOrEmpty(header2)) + { + headers.Add(header2, value2); + } + + var request = this.MockHttpRequestAndResponseData(headers); + var client = this.GetTestFunctionsDurableTaskClient(); + + var payload = client.CreateHttpManagementPayload("testInstanceId", request); + AssertHttpManagementPayload(payload, expectedBaseUrl, "testInstanceId"); + } + + + private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId) { Assert.Equal(instanceId, payload.Id); @@ -99,5 +258,79 @@ private static void AssertHttpManagementPayload(HttpManagementPayload payload, s Assert.Equal($"{BaseUrl}/instances/{instanceId}/suspend?reason={{{{text}}}}", payload.SuspendPostUri); Assert.Equal($"{BaseUrl}/instances/{instanceId}/resume?reason={{{{text}}}}", payload.ResumePostUri); } + + private static void AssertOrhcestrationMetadata(OrchestrationMetadata expectedResult, dynamic actualResult) + { + Assert.Equal(expectedResult.Name, actualResult.GetProperty("Name").GetString()); + Assert.Equal(expectedResult.InstanceId, actualResult.GetProperty("InstanceId").GetString()); + Assert.Equal(expectedResult.CreatedAt, actualResult.GetProperty("CreatedAt").GetDateTime()); + Assert.Equal(expectedResult.LastUpdatedAt, actualResult.GetProperty("LastUpdatedAt").GetDateTime()); + Assert.Equal(expectedResult.RuntimeStatus.ToString(), actualResult.GetProperty("RuntimeStatus").GetString()); + Assert.Equal(expectedResult.SerializedInput, actualResult.GetProperty("SerializedInput").GetString()); + Assert.Equal(expectedResult.SerializedOutput, actualResult.GetProperty("SerializedOutput").GetString()); + Assert.Equal(expectedResult.SerializedCustomStatus, actualResult.GetProperty("SerializedCustomStatus").GetString()); + } + + // Mocks the required HttpRequestData and HttpResponseData for testing purposes. + // This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body. + // The headers of HttpRequestData can be provided as an optional parameter, otherwise an empty HttpHeadersCollection is used. + private HttpRequestData MockHttpRequestAndResponseData(HttpHeadersCollection? headers = null) + { + var mockObjectSerializer = new Mock(); + + // Setup the SerializeAsync method + mockObjectSerializer.Setup(s => s.SerializeAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(async (stream, value, type, token) => + { + await System.Text.Json.JsonSerializer.SerializeAsync(stream, value, type, cancellationToken: token); + }); + + var workerOptions = new WorkerOptions + { + Serializer = mockObjectSerializer.Object + }; + var mockOptions = new Mock>(); + mockOptions.Setup(o => o.Value).Returns(workerOptions); + + // Mock the service provider + var mockServiceProvider = new Mock(); + + // Set up the service provider to return the mock IOptions + mockServiceProvider.Setup(sp => sp.GetService(typeof(IOptions))) + .Returns(mockOptions.Object); + + // Set up the service provider to return the mock ObjectSerializer + mockServiceProvider.Setup(sp => sp.GetService(typeof(ObjectSerializer))) + .Returns(mockObjectSerializer.Object); + + // Create a mock FunctionContext and assign the service provider + var mockFunctionContext = new Mock(); + mockFunctionContext.SetupGet(c => c.InstanceServices).Returns(mockServiceProvider.Object); + var mockHttpRequestData = new Mock(mockFunctionContext.Object); + + // Set up the URL property. + mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("https://localhost:7075/orchestrators/E1_HelloSequence")); + + // If headers are provided, use them, otherwise create a new empty HttpHeadersCollection + headers ??= new HttpHeadersCollection(); + + // Setup the Headers property to return the empty headers + mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers); + + var mockHttpResponseData = new Mock(mockFunctionContext.Object) + { + DefaultValue = DefaultValue.Mock + }; + + // Enable setting StatusCode and Body as mutable properties + mockHttpResponseData.SetupProperty(r => r.StatusCode, HttpStatusCode.OK); + mockHttpResponseData.SetupProperty(r => r.Body, new MemoryStream()); + + // Setup CreateResponse to return the configured HttpResponseData mock + mockHttpRequestData.Setup(r => r.CreateResponse()) + .Returns(mockHttpResponseData.Object); + + return mockHttpRequestData.Object; + } } }