diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/AuthUtility.cs b/src/Microsoft.Azure.SignalR.Common/Auth/AuthUtility.cs index 078c5ad50..fc229c2f7 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/AuthUtility.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/AuthUtility.cs @@ -38,6 +38,12 @@ public static string GenerateJwtBearer( { KeyId = signingKey.Id }; + + if (signingKey is AadAccessKey) + { + // disable cache when using AadAccessKey + securityKey.CryptoProviderFactory.CacheSignatureProviders = false; + } credentials = new SigningCredentials(securityKey, GetSecurityAlgorithm(algorithm)); } diff --git a/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs b/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs index 31281166a..e07c3bfc1 100644 --- a/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs +++ b/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs @@ -84,11 +84,13 @@ private static IServiceCollection AddSignalRServiceCore(this IServiceCollection .Where(service => service.ServiceType != typeof(IServiceConnectionContainer)) .Where(service => service.ServiceType != typeof(IHostedService)); services.Add(tempServices); - services.AddSingleton(sp => + // Remove the JsonHubProtocol and add new one. + // On .NET Standard 2.0, registering multiple hub protocols with the same name is forbidden. + services.Replace(ServiceDescriptor.Singleton(sp => { var objectSerializer = sp.GetRequiredService>().Value.ObjectSerializer; return objectSerializer != null ? new JsonObjectSerializerHubProtocol(objectSerializer) : new JsonHubProtocol(); - }); + })); //add dependencies for persistent mode only services .AddSingleton() diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs index e0ed9fc4e..71d101577 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs @@ -2,11 +2,13 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using System.Collections; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Linq; using System.Security.Claims; using System.Text; - +using Azure.Identity; using Microsoft.IdentityModel.Tokens; using Xunit; @@ -30,14 +32,16 @@ public void TestAccessTokenTooLongThrowsException() Assert.Equal("AccessToken must not be longer than 4K.", exception.Message); } - [Fact] - public void TestGenerateJwtBearerCaching() + [Theory] + [ClassData(typeof(CachingTestData))] + internal void TestGenerateJwtBearerCaching(AccessKey accessKey, bool shouldCache) { var count = 0; while (count < 1000) { - var accessKey = new AccessKey("http://localhost:443", SigningKey); - AuthUtility.GenerateJwtBearer(audience: Audience, expires: DateTime.UtcNow.Add(DefaultLifetime), signingKey: accessKey); + AuthUtility.GenerateJwtBearer(audience: Audience, + expires: DateTime.UtcNow.Add(DefaultLifetime), + signingKey: accessKey); count++; }; @@ -48,14 +52,35 @@ public void TestGenerateJwtBearerCaching() var value = cache.GetType().GetField("_signingSignatureProviders", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance).GetValue(cache); var signingProviders = value as ConcurrentDictionary; - + // Validate same signing key cache once. - Assert.Single(signingProviders); + if (shouldCache) + { + Assert.Single(signingProviders); + } + else + { + Assert.Empty(signingProviders); + } + signingProviders.Clear(); } - private Claim[] GenerateClaims(int count) + private static Claim[] GenerateClaims(int count) { return Enumerable.Range(0, count).Select(s => new Claim($"ClaimSubject{s}", $"ClaimValue{s}")).ToArray(); } + + public class CachingTestData : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return new object[] { new AccessKey("http://localhost:443", SigningKey), true }; + var key = new AadAccessKey(new Uri("http://localhost"), new DefaultAzureCredential()); + key.UpdateAccessKey("foo", SigningKey); + yield return new object[] { key, false }; + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } } }