diff --git a/eng/ci/code-mirror.yml b/eng/ci/code-mirror.yml
index 0a2196b95..40f2d2a0f 100644
--- a/eng/ci/code-mirror.yml
+++ b/eng/ci/code-mirror.yml
@@ -5,6 +5,7 @@ trigger:
# Keep this set limited as appropriate (don't mirror individual user branches).
- main
- dev
+ - v3.x
resources:
repositories:
diff --git a/eng/ci/official-build.yml b/eng/ci/official-build.yml
index e7a871026..858ca3518 100644
--- a/eng/ci/official-build.yml
+++ b/eng/ci/official-build.yml
@@ -7,6 +7,7 @@ trigger:
include:
- main
- dev
+ - v3.x
# CI only, does not trigger on PRs.
pr: none
diff --git a/release_notes.md b/release_notes.md
index fa7883271..812793b79 100644
--- a/release_notes.md
+++ b/release_notes.md
@@ -5,10 +5,12 @@
### New Features
- Fail fast if extendedSessionsEnabled set to 'true' for the worker type that doesn't support extended sessions (https://github.com/Azure/azure-functions-durable-extension/pull/2732).
+- Added an `IFunctionsWorkerApplicationBuilder.ConfigureDurableExtension()` extension method for cases where auto-registration does not work (no source gen running). (#2950)
### Bug Fixes
- Fix custom connection name not working when using IDurableClientFactory.CreateClient() - contributed by [@hctan](https://github.com/hctan)
+- Made durable extension for isolated worker configuration idempotent, allowing multiple calls safely. (#2950)
### Breaking Changes
diff --git a/src/WebJobs.Extensions.DurableTask/AzureStorageDurabilityProvider.cs b/src/WebJobs.Extensions.DurableTask/AzureStorageDurabilityProvider.cs
index 2920a0f37..ea5350463 100644
--- a/src/WebJobs.Extensions.DurableTask/AzureStorageDurabilityProvider.cs
+++ b/src/WebJobs.Extensions.DurableTask/AzureStorageDurabilityProvider.cs
@@ -35,6 +35,16 @@ internal class AzureStorageDurabilityProvider : DurabilityProvider
private readonly JObject storageOptionsJson;
private readonly ILogger logger;
+ private readonly object initLock = new object();
+
+#if !FUNCTIONS_V1
+ private DurableTaskScaleMonitor singletonScaleMonitor;
+#endif
+
+#if FUNCTIONS_V3_OR_GREATER
+ private DurableTaskTargetScaler singletonTargetScaler;
+#endif
+
public AzureStorageDurabilityProvider(
AzureStorageOrchestrationService service,
IStorageAccountProvider storageAccountProvider,
@@ -226,12 +236,11 @@ internal static OrchestrationInstanceStatusQueryCondition ConvertWebjobsDurableC
#if !FUNCTIONS_V1
internal DurableTaskMetricsProvider GetMetricsProvider(
- string functionName,
string hubName,
CloudStorageAccount storageAccount,
ILogger logger)
{
- return new DurableTaskMetricsProvider(functionName, hubName, logger, performanceMonitor: null, storageAccount);
+ return new DurableTaskMetricsProvider(hubName, logger, performanceMonitor: null, storageAccount);
}
///
@@ -242,16 +251,22 @@ public override bool TryGetScaleMonitor(
string connectionName,
out IScaleMonitor scaleMonitor)
{
- CloudStorageAccount storageAccount = this.storageAccountProvider.GetStorageAccountDetails(connectionName).ToCloudStorageAccount();
- DurableTaskMetricsProvider metricsProvider = this.GetMetricsProvider(functionName, hubName, storageAccount, this.logger);
- scaleMonitor = new DurableTaskScaleMonitor(
- functionId,
- functionName,
- hubName,
- storageAccount,
- this.logger,
- metricsProvider);
- return true;
+ lock (this.initLock)
+ {
+ if (this.singletonScaleMonitor == null)
+ {
+ CloudStorageAccount storageAccount = this.storageAccountProvider.GetStorageAccountDetails(connectionName).ToCloudStorageAccount();
+ DurableTaskMetricsProvider metricsProvider = this.GetMetricsProvider(hubName, storageAccount, this.logger);
+ this.singletonScaleMonitor = new DurableTaskScaleMonitor(
+ hubName,
+ storageAccount,
+ this.logger,
+ metricsProvider);
+ }
+
+ scaleMonitor = this.singletonScaleMonitor;
+ return true;
+ }
}
#endif
@@ -263,11 +278,23 @@ public override bool TryGetTargetScaler(
string connectionName,
out ITargetScaler targetScaler)
{
- // This is only called by the ScaleController, it doesn't run in the Functions Host process.
- CloudStorageAccount storageAccount = this.storageAccountProvider.GetStorageAccountDetails(connectionName).ToCloudStorageAccount();
- DurableTaskMetricsProvider metricsProvider = this.GetMetricsProvider(functionName, hubName, storageAccount, this.logger);
- targetScaler = new DurableTaskTargetScaler(functionId, metricsProvider, this, this.logger);
- return true;
+ lock (this.initLock)
+ {
+ if (this.singletonTargetScaler == null)
+ {
+ // This is only called by the ScaleController, it doesn't run in the Functions Host process.
+ CloudStorageAccount storageAccount = this.storageAccountProvider.GetStorageAccountDetails(connectionName).ToCloudStorageAccount();
+ DurableTaskMetricsProvider metricsProvider = this.GetMetricsProvider(hubName, storageAccount, this.logger);
+
+ // Scalers in Durable Functions are shared for all functions in the same task hub.
+ // So instead of using a function ID, we use the task hub name as the basis for the descriptor ID.
+ string id = $"DurableTask-AzureStorage:{hubName ?? "default"}";
+ this.singletonTargetScaler = new DurableTaskTargetScaler(id, metricsProvider, this, this.logger);
+ }
+
+ targetScaler = this.singletonTargetScaler;
+ return true;
+ }
}
#endif
}
diff --git a/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskMetricsProvider.cs b/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskMetricsProvider.cs
index e3565d169..999821abc 100644
--- a/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskMetricsProvider.cs
+++ b/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskMetricsProvider.cs
@@ -13,16 +13,18 @@ namespace Microsoft.Azure.WebJobs.Extensions.DurableTask
{
internal class DurableTaskMetricsProvider
{
- private readonly string functionName;
private readonly string hubName;
private readonly ILogger logger;
private readonly CloudStorageAccount storageAccount;
private DisconnectedPerformanceMonitor performanceMonitor;
- public DurableTaskMetricsProvider(string functionName, string hubName, ILogger logger, DisconnectedPerformanceMonitor performanceMonitor, CloudStorageAccount storageAccount)
+ public DurableTaskMetricsProvider(
+ string hubName,
+ ILogger logger,
+ DisconnectedPerformanceMonitor performanceMonitor,
+ CloudStorageAccount storageAccount)
{
- this.functionName = functionName;
this.hubName = hubName;
this.logger = logger;
this.performanceMonitor = performanceMonitor;
@@ -42,7 +44,7 @@ public virtual async Task GetMetricsAsync()
}
catch (StorageException e)
{
- this.logger.LogWarning("{details}. Function: {functionName}. HubName: {hubName}.", e.ToString(), this.functionName, this.hubName);
+ this.logger.LogWarning("{details}. HubName: {hubName}.", e.ToString(), this.hubName);
}
if (heartbeat != null)
diff --git a/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskScaleMonitor.cs b/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskScaleMonitor.cs
index 4c05b3df0..c24762e06 100644
--- a/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskScaleMonitor.cs
+++ b/src/WebJobs.Extensions.DurableTask/Listener/DurableTaskScaleMonitor.cs
@@ -16,8 +16,6 @@ namespace Microsoft.Azure.WebJobs.Extensions.DurableTask
{
internal sealed class DurableTaskScaleMonitor : IScaleMonitor
{
- private readonly string functionId;
- private readonly string functionName;
private readonly string hubName;
private readonly CloudStorageAccount storageAccount;
private readonly ScaleMonitorDescriptor scaleMonitorDescriptor;
@@ -27,31 +25,27 @@ internal sealed class DurableTaskScaleMonitor : IScaleMonitor GetScaleResultAsync(TargetScalerContext co
// and the ScaleController is injecting it's own custom ILogger implementation that forwards logs to Kusto.
var metricsLog = $"Metrics: workItemQueueLength={workItemQueueLength}. controlQueueLengths={serializedControlQueueLengths}. " +
$"maxConcurrentOrchestrators={this.MaxConcurrentOrchestrators}. maxConcurrentActivities={this.MaxConcurrentActivities}";
- var scaleControllerLog = $"Target worker count for '{this.functionId}' is '{numWorkersToRequest}'. " +
+ var scaleControllerLog = $"Target worker count for '{this.scaler}' is '{numWorkersToRequest}'. " +
metricsLog;
// target worker count should never be negative
@@ -85,7 +89,7 @@ public async Task GetScaleResultAsync(TargetScalerContext co
// We want to augment the exception with metrics information for investigation purposes
var metricsLog = $"Metrics: workItemQueueLength={metrics?.WorkItemQueueLength}. controlQueueLengths={metrics?.ControlQueueLengths}. " +
$"maxConcurrentOrchestrators={this.MaxConcurrentOrchestrators}. maxConcurrentActivities={this.MaxConcurrentActivities}";
- var errorLog = $"Error: target worker count for '{this.functionId}' resulted in exception. " + metricsLog;
+ var errorLog = $"Error: target worker count for '{this.scaler}' resulted in exception. " + metricsLog;
throw new Exception(errorLog, ex);
}
}
diff --git a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
index bbd6222a8..251ebb2d7 100644
--- a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
+++ b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
@@ -2,6 +2,7 @@
// Licensed under the MIT License. See License.txt in the project root for license information.
using System;
+using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
@@ -18,6 +19,70 @@ namespace Microsoft.Azure.Functions.Worker;
///
public static class DurableTaskClientExtensions
{
+ ///
+ /// Waits for the completion of the specified orchestration instance with a retry interval, controlled by the cancellation token.
+ /// If the orchestration does not complete within the required time, returns an HTTP response containing the class to manage instances.
+ ///
+ /// The .
+ /// The HTTP request that this response is for.
+ /// The ID of the orchestration instance to check.
+ /// The timeout between checks for output from the durable function. The default value is 1 second.
+ /// Optional parameter that configures the http response code returned. Defaults to false.
+ /// Optional parameter that configures whether to get the inputs and outputs of the orchestration. Defaults to false.
+ /// A token that signals if the wait should be canceled. If canceled, call CreateCheckStatusResponseAsync to return a reponse contains a HttpManagementPayload.
+ ///
+ public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync(
+ this DurableTaskClient client,
+ HttpRequestData request,
+ string instanceId,
+ TimeSpan? retryInterval = null,
+ bool returnInternalServerErrorOnFailure = false,
+ bool getInputsAndOutputs = false,
+ CancellationToken cancellation = default
+ )
+ {
+ TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1);
+ try
+ {
+ while (true)
+ {
+ var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: getInputsAndOutputs);
+ if (status != null)
+ {
+ if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
+#pragma warning disable CS0618 // Type or member is obsolete
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
+#pragma warning restore CS0618 // Type or member is obsolete
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
+ {
+ var response = request.CreateResponse(
+ (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) ? HttpStatusCode.InternalServerError : HttpStatusCode.OK);
+ await response.WriteAsJsonAsync(new
+ {
+ Name = status.Name,
+ InstanceId = status.InstanceId,
+ CreatedAt = status.CreatedAt,
+ LastUpdatedAt = status.LastUpdatedAt,
+ RuntimeStatus = status.RuntimeStatus.ToString(), // Convert enum to string
+ SerializedInput = status.SerializedInput,
+ SerializedOutput = status.SerializedOutput,
+ SerializedCustomStatus = status.SerializedCustomStatus
+ }, statusCode: response.StatusCode);
+
+ return response;
+ }
+ }
+ await Task.Delay(retryIntervalLocal, cancellation);
+ }
+ }
+ // If the task is canceled, call CreateCheckStatusResponseAsync to return a response containing instance management URLs.
+ catch (OperationCanceledException)
+ {
+ return await CreateCheckStatusResponseAsync(client, request, instanceId);
+ }
+ }
+
///
/// Creates an HTTP response that is useful for checking the status of the specified instance.
///
@@ -170,13 +235,13 @@ static string BuildUrl(string url, params string?[] queryValues)
// The base URL could be null if:
// 1. The DurableTaskClient isn't a FunctionsDurableTaskClient (which would have the baseUrl from bindings)
// 2. There's no valid HttpRequestData provided
- string? baseUrl = ((request != null) ? request.Url.GetLeftPart(UriPartial.Authority) : GetBaseUrl(client));
+ string? baseUrl = ((request != null) ? GetBaseUrlFromRequest(request) : GetBaseUrl(client));
if (baseUrl == null)
{
throw new InvalidOperationException("Failed to create HTTP management payload as base URL is null. Either use Functions bindings or provide an HTTP request to create the HttpPayload.");
}
-
+
bool isFromRequest = request != null;
string formattedInstanceId = Uri.EscapeDataString(instanceId);
@@ -214,6 +279,51 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response)
?? throw new InvalidOperationException("A serializer is not configured for the worker.");
}
+ private static string? GetBaseUrlFromRequest(HttpRequestData request)
+ {
+ // Default to the scheme from the request URL
+ string proto = request.Url.Scheme;
+ string host = request.Url.Authority;
+
+ // Check for "Forwarded" header
+ if (request.Headers.TryGetValues("Forwarded", out var forwardedHeaders))
+ {
+ var forwardedDict = forwardedHeaders.FirstOrDefault()?.Split(';')
+ .Select(pair => pair.Split('='))
+ .Where(pair => pair.Length == 2)
+ .ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim());
+
+ if (forwardedDict != null)
+ {
+ if (forwardedDict.TryGetValue("proto", out var forwardedProto))
+ {
+ proto = forwardedProto;
+ }
+ if (forwardedDict.TryGetValue("host", out var forwardedHost))
+ {
+ host = forwardedHost;
+ // Return if either proto or host (or both) were found in "Forwarded" header
+ return $"{proto}://{forwardedHost}";
+ }
+ }
+ }
+ // Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers if "Forwarded" is not present
+ if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos))
+ {
+ proto = protos.FirstOrDefault() ?? proto;
+ }
+ if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts))
+ {
+ // Return base URL if either "X-Forwarded-Proto" or "X-Forwarded-Host" (or both) are found
+ host = hosts.FirstOrDefault() ?? host;
+ return $"{proto}://{host}";
+ }
+
+ // Construct and return the base URL from default fallback values
+ return $"{proto}://{host}";
+ }
+
+
private static string? GetQueryParams(DurableTaskClient client)
{
return client is FunctionsDurableTaskClient functions ? functions.QueryString : null;
diff --git a/src/Worker.Extensions.DurableTask/DurableTaskExtensionStartup.cs b/src/Worker.Extensions.DurableTask/DurableTaskExtensionStartup.cs
index 626acd6bf..af7bab017 100644
--- a/src/Worker.Extensions.DurableTask/DurableTaskExtensionStartup.cs
+++ b/src/Worker.Extensions.DurableTask/DurableTaskExtensionStartup.cs
@@ -1,20 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.
-using System;
-using Azure.Core.Serialization;
using Microsoft.Azure.Functions.Worker.Core;
using Microsoft.Azure.Functions.Worker.Extensions.DurableTask;
-using Microsoft.DurableTask;
-using Microsoft.DurableTask.Client;
-using Microsoft.DurableTask.Converters;
-using Microsoft.DurableTask.Worker;
-using Microsoft.DurableTask.Worker.Shims;
-using Microsoft.Extensions.DependencyInjection;
-using Microsoft.Extensions.DependencyInjection.Extensions;
-using Microsoft.Extensions.Hosting;
-using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Options;
[assembly: WorkerExtensionStartup(typeof(DurableTaskExtensionStartup))]
@@ -28,49 +16,6 @@ public sealed class DurableTaskExtensionStartup : WorkerExtensionStartup
///
public override void Configure(IFunctionsWorkerApplicationBuilder applicationBuilder)
{
- applicationBuilder.Services.AddSingleton();
- applicationBuilder.Services.AddOptions()
- .Configure(options => options.EnableEntitySupport = true)
- .PostConfigure((opt, sp) =>
- {
- if (GetConverter(sp) is DataConverter converter)
- {
- opt.DataConverter = converter;
- }
- });
-
- applicationBuilder.Services.AddOptions()
- .Configure(options => options.EnableEntitySupport = true)
- .PostConfigure((opt, sp) =>
- {
- if (GetConverter(sp) is DataConverter converter)
- {
- opt.DataConverter = converter;
- }
- });
-
- applicationBuilder.Services.TryAddSingleton(sp =>
- {
- DurableTaskWorkerOptions options = sp.GetRequiredService>().Value;
- ILoggerFactory factory = sp.GetRequiredService();
- return new DurableTaskShimFactory(options, factory); // For GrpcOrchestrationRunner
- });
-
- applicationBuilder.Services.Configure(o =>
- {
- o.InputConverters.Register();
- });
-
- applicationBuilder.UseMiddleware();
- }
-
- private static DataConverter? GetConverter(IServiceProvider services)
- {
- // We intentionally do not consider a DataConverter in the DI provider, or if one was already set. This is to
- // ensure serialization is consistent with the rest of Azure Functions. This is particularly important because
- // TaskActivity bindings use ObjectSerializer directly for the time being. Due to this, allowing DataConverter
- // to be set separately from ObjectSerializer would give an inconsistent serialization solution.
- WorkerOptions? worker = services.GetRequiredService>()?.Value;
- return worker?.Serializer is not null ? new ObjectConverterShim(worker.Serializer) : null;
+ applicationBuilder.ConfigureDurableExtension();
}
}
diff --git a/src/Worker.Extensions.DurableTask/FunctionsWorkerApplicationBuilderExtensions.cs b/src/Worker.Extensions.DurableTask/FunctionsWorkerApplicationBuilderExtensions.cs
new file mode 100644
index 000000000..642446dd4
--- /dev/null
+++ b/src/Worker.Extensions.DurableTask/FunctionsWorkerApplicationBuilderExtensions.cs
@@ -0,0 +1,125 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the MIT License. See License.txt in the project root for license information.
+
+using System;
+using System.Linq;
+using Azure.Core.Serialization;
+using Microsoft.Azure.Functions.Worker.Core;
+using Microsoft.Azure.Functions.Worker.Extensions.DurableTask;
+using Microsoft.DurableTask;
+using Microsoft.DurableTask.Client;
+using Microsoft.DurableTask.Converters;
+using Microsoft.DurableTask.Worker;
+using Microsoft.DurableTask.Worker.Shims;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.DependencyInjection.Extensions;
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.Azure.Functions.Worker;
+
+///
+/// Extensions for .
+///
+public static class FunctionsWorkerApplicationBuilderExtensions
+{
+ ///
+ /// Configures the Durable Functions extension for the worker.
+ ///
+ /// The builder to configure.
+ /// The for call chaining.
+ public static IFunctionsWorkerApplicationBuilder ConfigureDurableExtension(this IFunctionsWorkerApplicationBuilder builder)
+ {
+ if (builder is null)
+ {
+ throw new ArgumentNullException(nameof(builder));
+ }
+
+ builder.Services.TryAddSingleton();
+ builder.Services.TryAddEnumerable(
+ ServiceDescriptor.Singleton, ConfigureClientOptions>());
+ builder.Services.TryAddEnumerable(
+ ServiceDescriptor.Singleton, PostConfigureClientOptions>());
+ builder.Services.TryAddEnumerable(
+ ServiceDescriptor.Singleton, ConfigureWorkerOptions>());
+ builder.Services.TryAddEnumerable(
+ ServiceDescriptor.Singleton, PostConfigureWorkerOptions>());
+
+ builder.Services.TryAddSingleton(sp =>
+ {
+ DurableTaskWorkerOptions options = sp.GetRequiredService>().Value;
+ ILoggerFactory factory = sp.GetRequiredService();
+ return new DurableTaskShimFactory(options, factory); // For GrpcOrchestrationRunner
+ });
+
+ builder.Services.TryAddEnumerable(
+ ServiceDescriptor.Singleton, ConfigureInputConverter>());
+ if (!builder.Services.Any(d => d.ServiceType == typeof(DurableTaskFunctionsMiddleware)))
+ {
+ builder.UseMiddleware();
+ }
+
+ return builder;
+ }
+
+ private class ConfigureInputConverter : IConfigureOptions
+ {
+ public void Configure(WorkerOptions options)
+ {
+ options.InputConverters.Register();
+ }
+ }
+
+ private class ConfigureClientOptions : IConfigureOptions
+ {
+ public void Configure(DurableTaskClientOptions options)
+ {
+ options.EnableEntitySupport = true;
+ }
+ }
+
+ private class PostConfigureClientOptions : IPostConfigureOptions
+ {
+ readonly IOptionsMonitor workerOptions;
+
+ public PostConfigureClientOptions(IOptionsMonitor workerOptions)
+ {
+ this.workerOptions = workerOptions;
+ }
+
+ public void PostConfigure(string name, DurableTaskClientOptions options)
+ {
+ if (this.workerOptions.Get(name).Serializer is { } serializer)
+ {
+ options.DataConverter = new ObjectConverterShim(serializer);
+ }
+ }
+ }
+
+ private class ConfigureWorkerOptions : IConfigureOptions
+ {
+ public void Configure(DurableTaskWorkerOptions options)
+ {
+ options.EnableEntitySupport = true;
+ }
+ }
+
+ private class PostConfigureWorkerOptions : IPostConfigureOptions
+ {
+ readonly IOptionsMonitor workerOptions;
+
+ public PostConfigureWorkerOptions(IOptionsMonitor workerOptions)
+ {
+ this.workerOptions = workerOptions;
+ }
+
+ public void PostConfigure(string name, DurableTaskWorkerOptions options)
+ {
+ if (this.workerOptions.Get(name).Serializer is { } serializer)
+ {
+ options.DataConverter = new ObjectConverterShim(serializer);
+ }
+ }
+ }
+}
diff --git a/test/FunctionsV2/DurableTaskListenerTests.cs b/test/FunctionsV2/DurableTaskListenerTests.cs
index 3be412bd9..996a336cd 100644
--- a/test/FunctionsV2/DurableTaskListenerTests.cs
+++ b/test/FunctionsV2/DurableTaskListenerTests.cs
@@ -2,13 +2,10 @@
// Licensed under the MIT License. See LICENSE in the project root for license information.
using System;
-using System.Linq;
-using Microsoft.Azure.WebJobs.Host.Executors;
using Microsoft.Azure.WebJobs.Host.Scale;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
-using Moq;
using Xunit;
namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.Tests
@@ -40,9 +37,9 @@ public void GetMonitor_ReturnsExpectedValue()
IScaleMonitor scaleMonitor = this.listener.GetMonitor();
Assert.Equal(typeof(DurableTaskScaleMonitor), scaleMonitor.GetType());
- Assert.Equal($"{this.functionId}-DurableTaskTrigger-DurableTaskHub".ToLower(), scaleMonitor.Descriptor.Id);
+ Assert.Equal($"DurableTaskTrigger-DurableTaskHub".ToLower(), scaleMonitor.Descriptor.Id);
- var scaleMonitor2 = this.listener.GetMonitor();
+ IScaleMonitor scaleMonitor2 = this.listener.GetMonitor();
Assert.Same(scaleMonitor, scaleMonitor2);
}
diff --git a/test/FunctionsV2/DurableTaskScaleMonitorTests.cs b/test/FunctionsV2/DurableTaskScaleMonitorTests.cs
index 57566da29..330567dda 100644
--- a/test/FunctionsV2/DurableTaskScaleMonitorTests.cs
+++ b/test/FunctionsV2/DurableTaskScaleMonitorTests.cs
@@ -20,8 +20,6 @@ namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.Tests
{
public class DurableTaskScaleMonitorTests
{
- private readonly string functionId = "DurableTaskTriggerFunctionId";
- private readonly FunctionName functionName = new FunctionName("DurableTaskTriggerFunctionName");
private readonly string hubName = "DurableTaskTriggerHubName";
private readonly CloudStorageAccount storageAccount = CloudStorageAccount.Parse(TestHelpers.GetStorageConnectionString());
private readonly ITestOutputHelper output;
@@ -41,15 +39,12 @@ public DurableTaskScaleMonitorTests(ITestOutputHelper output)
this.traceHelper = new EndToEndTraceHelper(logger, false);
this.performanceMonitor = new Mock(MockBehavior.Strict, this.storageAccount, this.hubName, (int?)null);
var metricsProvider = new DurableTaskMetricsProvider(
- this.functionName.Name,
this.hubName,
logger,
this.performanceMonitor.Object,
this.storageAccount);
this.scaleMonitor = new DurableTaskScaleMonitor(
- this.functionId,
- this.functionName.Name,
this.hubName,
this.storageAccount,
logger,
@@ -61,7 +56,7 @@ public DurableTaskScaleMonitorTests(ITestOutputHelper output)
[Trait("Category", PlatformSpecificHelpers.TestCategory)]
public void ScaleMonitorDescriptor_ReturnsExpectedValue()
{
- Assert.Equal($"{this.functionId}-DurableTaskTrigger-{this.hubName}".ToLower(), this.scaleMonitor.Descriptor.Id);
+ Assert.Equal($"DurableTaskTrigger-{this.hubName}".ToLower(), this.scaleMonitor.Descriptor.Id);
}
[Fact]
diff --git a/test/FunctionsV2/DurableTaskTargetScalerTests.cs b/test/FunctionsV2/DurableTaskTargetScalerTests.cs
index cc2272e16..0fcae7ec6 100644
--- a/test/FunctionsV2/DurableTaskTargetScalerTests.cs
+++ b/test/FunctionsV2/DurableTaskTargetScalerTests.cs
@@ -47,7 +47,6 @@ public DurableTaskTargetScalerTests(ITestOutputHelper output)
CloudStorageAccount nullCloudStorageAccountMock = null;
this.metricsProviderMock = new Mock(
MockBehavior.Strict,
- "FunctionName",
"HubName",
logger,
nullPerformanceMonitorMock,
diff --git a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
index 6f975d2c5..1623f4559 100644
--- a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
+++ b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
@@ -1,6 +1,10 @@
+using System.Net;
+using Azure.Core.Serialization;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.DurableTask.Client;
+using Microsoft.Extensions.Options;
using Moq;
+using Newtonsoft.Json;
namespace Microsoft.Azure.Functions.Worker.Tests
{
@@ -9,7 +13,7 @@ namespace Microsoft.Azure.Functions.Worker.Tests
///
public class FunctionsDurableTaskClientTests
{
- private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null)
+ private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null, OrchestrationMetadata? orchestrationMetadata = null)
{
// construct mock client
@@ -21,6 +25,12 @@ private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? bas
durableClientMock.Setup(x => x.TerminateInstanceAsync(
It.IsAny(), It.IsAny(), It.IsAny())).Returns(completedTask);
+ if (orchestrationMetadata != null)
+ {
+ durableClientMock.Setup(x => x.GetInstancesAsync(orchestrationMetadata.InstanceId, It.IsAny(), It.IsAny()))
+ .ReturnsAsync(orchestrationMetadata);
+ }
+
DurableTaskClient durableClient = durableClientMock.Object;
FunctionsDurableTaskClient client = new FunctionsDurableTaskClient(durableClient, queryString: null, httpBaseUrl: baseUrl);
return client;
@@ -82,6 +92,8 @@ public void CreateHttpManagementPayload_WithHttpRequestData()
// Create mock HttpRequestData object.
var mockFunctionContext = new Mock();
var mockHttpRequestData = new Mock(mockFunctionContext.Object);
+ var headers = new HttpHeadersCollection();
+ mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers);
mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri(requestUrl));
HttpManagementPayload payload = client.CreateHttpManagementPayload(instanceId, mockHttpRequestData.Object);
@@ -89,6 +101,153 @@ public void CreateHttpManagementPayload_WithHttpRequestData()
AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId);
}
+ ///
+ /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns the expected response when the orchestration is completed.
+ /// The expected response should include OrchestrationMetadata in the body with an HttpStatusCode.OK.
+ ///
+ [Fact]
+ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenCompleted()
+ {
+ string instanceId = "test-instance-id-completed";
+ var expectedResult = new OrchestrationMetadata("TestCompleted", instanceId)
+ {
+ CreatedAt = DateTime.UtcNow,
+ LastUpdatedAt = DateTime.UtcNow,
+ RuntimeStatus = OrchestrationRuntimeStatus.Completed,
+ SerializedCustomStatus = "TestCustomStatus",
+ SerializedInput = "TestInput",
+ SerializedOutput = "TestOutput"
+ };
+
+ var client = this.GetTestFunctionsDurableTaskClient( orchestrationMetadata: expectedResult);
+
+ HttpRequestData request = this.MockHttpRequestAndResponseData();
+
+ HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId);
+
+ Assert.NotNull(response);
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+
+ // Reset stream position for reading
+ response.Body.Position = 0;
+ var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body);
+
+ // Assert the response content is not null and check the content is correct.
+ Assert.NotNull(orchestratorMetadata);
+ AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata);
+ }
+
+ ///
+ /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestrator didn't finish within
+ /// the timeout period. The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted.
+ ///
+ [Fact]
+ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning()
+ {
+ string instanceId = "test-instance-id-running";
+ var expectedResult = new OrchestrationMetadata("TestRunning", instanceId)
+ {
+ CreatedAt = DateTime.UtcNow,
+ LastUpdatedAt = DateTime.UtcNow,
+ RuntimeStatus = OrchestrationRuntimeStatus.Running,
+ };
+
+ var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult);
+
+ HttpRequestData request = this.MockHttpRequestAndResponseData();
+ HttpResponseData response;
+ using (CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)))
+ {
+ response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, cancellation: cts.Token);
+ };
+
+ Assert.NotNull(response);
+ Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);
+
+ // Reset stream position for reading
+ response.Body.Position = 0;
+ HttpManagementPayload? payload;
+ using (var reader = new StreamReader(response.Body))
+ {
+ payload = JsonConvert.DeserializeObject(await reader.ReadToEndAsync());
+ }
+
+ // Assert the response content is not null and check the content is correct.
+ Assert.NotNull(payload);
+ AssertHttpManagementPayload(payload, "https://localhost:7075/runtime/webhooks/durabletask", instanceId);
+ }
+
+ ///
+ /// Tests the `WaitForCompletionOrCreateCheckStatusResponseAsync` method to ensure it returns the correct HTTP status code
+ /// based on the `returnInternalServerErrorOnFailure` parameter when the orchestration has failed.
+ ///
+ [Theory]
+ [InlineData(true, HttpStatusCode.InternalServerError)]
+ [InlineData(false, HttpStatusCode.OK)]
+ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFailed(bool returnInternalServerErrorOnFailure, HttpStatusCode expected)
+ {
+ string instanceId = "test-instance-id-failed";
+ var expectedResult = new OrchestrationMetadata("TestFailed", instanceId)
+ {
+ CreatedAt = DateTime.UtcNow,
+ LastUpdatedAt = DateTime.UtcNow,
+ RuntimeStatus = OrchestrationRuntimeStatus.Failed,
+ SerializedOutput = "Microsoft.DurableTask.TaskFailedException: Task 'SayHello' (#0) failed with an unhandled exception: Exception while executing function: Functions.SayHello",
+ SerializedInput = null
+ };
+
+ var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult);
+
+ HttpRequestData request = this.MockHttpRequestAndResponseData();
+
+ HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, returnInternalServerErrorOnFailure: returnInternalServerErrorOnFailure);
+
+ Assert.NotNull(response);
+ Assert.Equal(expected, response.StatusCode);
+
+ // Reset stream position for reading
+ response.Body.Position = 0;
+ var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body);
+
+ // Assert the response content is not null and check the content is correct.
+ Assert.NotNull(orchestratorMetadata);
+ AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata);
+ }
+
+ ///
+ /// Tests the `GetBaseUrlFromRequest` can return the right base URL from the HttpRequestData with different forwarding or proxies.
+ /// This test covers the following scenarios:
+ /// - Using the "Forwarded" header
+ /// - Using "X-Forwarded-Proto" and "X-Forwarded-Host" headers
+ /// - Using only "X-Forwarded-Host" with default protocol
+ /// - no headers
+ ///
+ [Theory]
+ [InlineData("Forwarded", "proto=https;host=forwarded.example.com","","", "https://forwarded.example.com/runtime/webhooks/durabletask")]
+ [InlineData("X-Forwarded-Proto", "https", "X-Forwarded-Host", "xforwarded.example.com", "https://xforwarded.example.com/runtime/webhooks/durabletask")]
+ [InlineData("", "", "X-Forwarded-Host", "test.net", "https://test.net/runtime/webhooks/durabletask")]
+ [InlineData("", "", "", "", "https://localhost:7075/runtime/webhooks/durabletask")] // Default base URL for empty headers
+ public void TestHttpRequestDataForwardingHandling(string header1, string? value1, string header2, string value2, string expectedBaseUrl)
+ {
+ var headers = new HttpHeadersCollection();
+ if (!string.IsNullOrEmpty(header1))
+ {
+ headers.Add(header1, value1);
+ }
+ if (!string.IsNullOrEmpty(header2))
+ {
+ headers.Add(header2, value2);
+ }
+
+ var request = this.MockHttpRequestAndResponseData(headers);
+ var client = this.GetTestFunctionsDurableTaskClient();
+
+ var payload = client.CreateHttpManagementPayload("testInstanceId", request);
+ AssertHttpManagementPayload(payload, expectedBaseUrl, "testInstanceId");
+ }
+
+
+
private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId)
{
Assert.Equal(instanceId, payload.Id);
@@ -99,5 +258,79 @@ private static void AssertHttpManagementPayload(HttpManagementPayload payload, s
Assert.Equal($"{BaseUrl}/instances/{instanceId}/suspend?reason={{{{text}}}}", payload.SuspendPostUri);
Assert.Equal($"{BaseUrl}/instances/{instanceId}/resume?reason={{{{text}}}}", payload.ResumePostUri);
}
+
+ private static void AssertOrhcestrationMetadata(OrchestrationMetadata expectedResult, dynamic actualResult)
+ {
+ Assert.Equal(expectedResult.Name, actualResult.GetProperty("Name").GetString());
+ Assert.Equal(expectedResult.InstanceId, actualResult.GetProperty("InstanceId").GetString());
+ Assert.Equal(expectedResult.CreatedAt, actualResult.GetProperty("CreatedAt").GetDateTime());
+ Assert.Equal(expectedResult.LastUpdatedAt, actualResult.GetProperty("LastUpdatedAt").GetDateTime());
+ Assert.Equal(expectedResult.RuntimeStatus.ToString(), actualResult.GetProperty("RuntimeStatus").GetString());
+ Assert.Equal(expectedResult.SerializedInput, actualResult.GetProperty("SerializedInput").GetString());
+ Assert.Equal(expectedResult.SerializedOutput, actualResult.GetProperty("SerializedOutput").GetString());
+ Assert.Equal(expectedResult.SerializedCustomStatus, actualResult.GetProperty("SerializedCustomStatus").GetString());
+ }
+
+ // Mocks the required HttpRequestData and HttpResponseData for testing purposes.
+ // This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body.
+ // The headers of HttpRequestData can be provided as an optional parameter, otherwise an empty HttpHeadersCollection is used.
+ private HttpRequestData MockHttpRequestAndResponseData(HttpHeadersCollection? headers = null)
+ {
+ var mockObjectSerializer = new Mock();
+
+ // Setup the SerializeAsync method
+ mockObjectSerializer.Setup(s => s.SerializeAsync(It.IsAny(), It.IsAny