From 8e3670d156d750ee2f371ba03d9f20248f090f9e Mon Sep 17 00:00:00 2001 From: Zitong Yang Date: Thu, 16 Jan 2025 13:59:43 +0800 Subject: [PATCH] Add websocket implementation for presence API --- .../ServiceConnectionManager.cs | 5 ++ .../Interfaces/IServiceMessageWriter.cs | 2 + .../MultiEndpointMessageWriter.cs | 56 ++++++++++++++++++- ...MultiEndpointServiceConnectionContainer.cs | 5 ++ .../ServiceConnectionContainerBase.cs | 25 ++++++++- .../Utilities/AckHandler.cs | 4 +- .../HubContext/GroupManager.cs | 5 ++ .../HubContext/GroupManagerAdapter.cs | 9 ++- .../IServiceHubLifetimeManager.cs | 4 ++ .../RestHubLifetimeManager.cs | 6 ++ .../WebsocketsHubLifetimeManager.cs | 49 ++++++++++++++++ .../ServiceConnectionManager.cs | 5 ++ .../TestServiceConnectionContainer.cs | 5 ++ .../TestServiceConnectionManager.cs | 5 ++ .../ServiceHubDispatcherTests.cs | 10 +++- 15 files changed, 186 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs index 3f09eec14..ec6a32d8c 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs @@ -168,4 +168,9 @@ private IEnumerable GetConnections() } } } + + Task IServiceMessageWriter.InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken) + { + return _appConnection.InvokeAsync(serviceMessage, cancellationToken); + } } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceMessageWriter.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceMessageWriter.cs index ba0699d0e..2bde6fb3d 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceMessageWriter.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceMessageWriter.cs @@ -12,4 +12,6 @@ internal interface IServiceMessageWriter Task WriteAsync(ServiceMessage serviceMessage); Task WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default); + + Task InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new(); } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs index e9cf62fb5..458decfd0 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs @@ -55,8 +55,8 @@ public Task WriteAsync(ServiceMessage serviceMessage) public Task WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) { - if (serviceMessage is CheckConnectionExistenceWithAckMessage - || serviceMessage is JoinGroupWithAckMessage + if (serviceMessage is CheckConnectionExistenceWithAckMessage + || serviceMessage is JoinGroupWithAckMessage || serviceMessage is LeaveGroupWithAckMessage) { return WriteSingleResultAckableMessage(serviceMessage, cancellationToken); @@ -67,6 +67,58 @@ public Task WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel } } + public async Task InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new() + { + var results = new List(); + await WriteMultiEndpointMessageAsync(serviceMessage, async connection => + results.Add(await connection.InvokeAsync(serviceMessage.Clone(), cancellationToken))); + switch (serviceMessage) + { + case GroupMemberQueryMessage memberQueryMessage: + if (results is not IEnumerable memberQueryResults || typeof(T) != typeof(GroupMemberQueryResponse)) + { + throw new InvalidOperationException($"The response of {nameof(GroupMemberQueryMessage)} should be of type {nameof(GroupMemberQueryResponse)}."); + } + return (T)(object)AggregateGroupMemberQueryResult(memberQueryResults, memberQueryMessage); + default: + throw new NotSupportedException($"{serviceMessage.GetType().Name} is not supported."); + } + static GroupMemberQueryResponse AggregateGroupMemberQueryResult(IEnumerable results, GroupMemberQueryMessage message) + { + var totalMembers = results.SelectMany(r => r.Members).ToArray(); + var totalCount = totalMembers.Length; + if (totalCount <= message.Max) + { + // Quick path: Return all the members + var continuationToken = default(string); + // Select the maximum connection ID. + foreach (var member in totalMembers) + { + if (string.Compare(member.ConnectionId, continuationToken, StringComparison.InvariantCulture) > 0) + { + continuationToken = member.ConnectionId; + } + } + return new GroupMemberQueryResponse + { + Members = totalMembers, + ContinuationToken = continuationToken + }; + } + else + { + // Slow path: get the minimal N members + // Priority Queue is not available in .NET Standard 2.0, so we sort the array instead + Array.Sort(totalMembers, (m1, m2) => string.Compare(m1.ConnectionId, m2.ConnectionId, StringComparison.InvariantCulture)); + return new GroupMemberQueryResponse + { + Members = new ArraySegment(totalMembers, 0, message.Max), + ContinuationToken = totalMembers[message.Max - 1].ConnectionId + }; + } + } + } + /// /// For user or group related operations, different endpoints might return different results /// Strategy: diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs index 9c80086e2..6b787e3c2 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs @@ -150,6 +150,11 @@ public Task WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel return CreateMessageWriter(serviceMessage).WriteAckableMessageAsync(serviceMessage, cancellationToken); } + public Task InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new() + { + return CreateMessageWriter(serviceMessage).InvokeAsync(serviceMessage, cancellationToken); + } + public Task StartGetServersPing() { return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StartGetServersPing())); diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index bbe5857da..469df4633 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -221,7 +221,7 @@ public virtual Task HandlePingAsync(PingMessage pingMessage) public void HandleAck(AckMessage ackMessage) { - _ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status); + _ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status, ackMessage.Payload); } public virtual Task WriteAsync(ServiceMessage serviceMessage) @@ -249,6 +249,29 @@ public async Task WriteAckableMessageAsync(ServiceMessage serviceMessage, return AckHandler.HandleAckStatus(ackableMessage, status); } + /// + /// only checks as the response, + /// while this method checks and deserialize it to . + /// + public async Task InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new() + { + if (serviceMessage is not IAckableMessage ackableMessage) + { + throw new ArgumentException($"{nameof(serviceMessage)} is not {nameof(IAckableMessage)}"); + } + + var task = _ackHandler.CreateSingleAck(out var id, null, cancellationToken); + ackableMessage.AckId = id; + + // Sending regular messages completes as soon as the data leaves the outbound pipe, + // whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout). + // Therefore sending them over different connections creates a possibility for processing them out of original order. + // By sending both message types over the same connection we ensure that they are sent (and processed) in their original order. + await WriteMessageAsync(serviceMessage); + + return await task; + } + public virtual Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token) { _terminated = true; diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs index 86a8293d7..865adb373 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs @@ -50,7 +50,7 @@ public Task CreateSingleAck(out int id, TimeSpan? ackTimeout = defaul return info.Task; } - public Task CreateSingleAck(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : IMessagePackSerializable, new() + public Task CreateSingleAck(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : notnull, new() { id = NextId(); if (_disposed) @@ -210,7 +210,7 @@ public override bool Ack(AckStatus status, ReadOnlySequence? payload = nul _tcs.TrySetResult(status); } - private sealed class SinglePayloadAck : SingleAckInfo where T : IMessagePackSerializable, new() + private sealed class SinglePayloadAck : SingleAckInfo where T : notnull, new() { public SinglePayloadAck(TimeSpan timeout) : base(timeout) { } public override bool Ack(AckStatus status, ReadOnlySequence? payload = null) diff --git a/src/Microsoft.Azure.SignalR.Management/HubContext/GroupManager.cs b/src/Microsoft.Azure.SignalR.Management/HubContext/GroupManager.cs index bd2d86492..5584b562e 100644 --- a/src/Microsoft.Azure.SignalR.Management/HubContext/GroupManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/HubContext/GroupManager.cs @@ -1,9 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Protocol; namespace Microsoft.Azure.SignalR.Management { @@ -14,5 +17,7 @@ public abstract class GroupManager : IGroupManager public abstract Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default); public abstract Task RemoveFromAllGroupsAsync(string connectionId, CancellationToken cancellationToken = default); + + public virtual IAsyncEnumerable ListConnectionsInGroupAsync(string groupName, int? max, CancellationToken token) => throw new NotImplementedException(); } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Management/HubContext/GroupManagerAdapter.cs b/src/Microsoft.Azure.SignalR.Management/HubContext/GroupManagerAdapter.cs index 12520adea..860f54fd3 100644 --- a/src/Microsoft.Azure.SignalR.Management/HubContext/GroupManagerAdapter.cs +++ b/src/Microsoft.Azure.SignalR.Management/HubContext/GroupManagerAdapter.cs @@ -1,17 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Protocol; namespace Microsoft.Azure.SignalR.Management { internal class GroupManagerAdapter : GroupManager { - private readonly IHubLifetimeManager _lifetimeManager; + private readonly IServiceHubLifetimeManager _lifetimeManager; - public GroupManagerAdapter(IHubLifetimeManager lifetimeManager) + public GroupManagerAdapter(IServiceHubLifetimeManager lifetimeManager) { _lifetimeManager = lifetimeManager; } @@ -21,5 +22,7 @@ public GroupManagerAdapter(IHubLifetimeManager lifetimeManager) public override Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) => _lifetimeManager.RemoveFromGroupAsync(connectionId, groupName, cancellationToken); public override Task RemoveFromAllGroupsAsync(string connectionId, CancellationToken cancellationToken = default) => _lifetimeManager.RemoveFromAllGroupsAsync(connectionId, cancellationToken); + + public override IAsyncEnumerable ListConnectionsInGroupAsync(string groupName, int? top, CancellationToken token) => _lifetimeManager.ListConnectionsInGroup(groupName, top, token); } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Management/IServiceHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/IServiceHubLifetimeManager.cs index f5f7b608d..205090062 100644 --- a/src/Microsoft.Azure.SignalR.Management/IServiceHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/IServiceHubLifetimeManager.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Azure.SignalR.Protocol; namespace Microsoft.Azure.SignalR.Management { @@ -15,5 +17,7 @@ internal interface IServiceHubLifetimeManager : IHubLifetimeManager, IUserGroupH Task UserExistsAsync(string userId, CancellationToken cancellationToken); Task GroupExistsAsync(string groupName, CancellationToken cancellationToken); + + IAsyncEnumerable ListConnectionsInGroup(string groupName, int? top = null, CancellationToken token = default); } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs index cf2b968d5..538a8dabb 100644 --- a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Primitives; using static Microsoft.Azure.SignalR.Constants; @@ -310,6 +311,11 @@ await _restClient.SendWithRetryAsync(api, HttpMethod.Head, handleExpectedRespons return exists; } + public IAsyncEnumerable ListConnectionsInGroup(string groupName, int? top = null, CancellationToken token = default) + { + throw new NotImplementedException(); + } + public Task DisposeAsync() => Task.CompletedTask; private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) => diff --git a/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs index ee5983391..0911d6690 100644 --- a/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs @@ -2,6 +2,8 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. #nullable enable using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; @@ -204,6 +206,53 @@ public Task GroupExistsAsync(string groupName, CancellationToken cancellat return WriteAckableMessageAsync(message, cancellationToken); } + public async IAsyncEnumerable ListConnectionsInGroup(string groupName, int? top = null, [EnumeratorCancellation] CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(groupName)) + { + throw new ArgumentException($"'{nameof(groupName)}' cannot be null or whitespace.", nameof(groupName)); + } + + var message = new GroupMemberQueryMessage() + { + GroupName = groupName, + }; + if (top != null) + { + message.Max = top.Value; + } + AppendMessageTracingId(message); + + var currentCount = 0; + while (true) + { + var response = await ServiceConnectionContainer.InvokeAsync(message, token); + currentCount += response.Members.Count; + if (top != null) + { + top = top - currentCount; + if (top <= 0) + { + response.ContinuationToken = null; + } + } + yield return response; + if (response.ContinuationToken == null) + { + yield break; + } + message = new GroupMemberQueryMessage() + { + GroupName = groupName, + ContinuationToken = response.ContinuationToken, + }; + if (top != null) + { + message.Max = top.Value - currentCount; + } + } + } + protected override T AppendMessageTracingId(T message) { if (_serviceManagerOptions.Value.EnableMessageTracing) diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionManager.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionManager.cs index 498dcfd3c..6a6193a70 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionManager.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionManager.cs @@ -63,4 +63,9 @@ public void Dispose() { StopAsync().GetAwaiter().GetResult(); } + + Task IServiceMessageWriter.InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken) + { + return _serviceConnection.InvokeAsync(serviceMessage, cancellationToken); + } } diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionContainer.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionContainer.cs index 4bf8b6f1f..3b1681103 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionContainer.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionContainer.cs @@ -109,4 +109,9 @@ public Task CloseClientConnections(CancellationToken token) { throw new NotImplementedException(); } + + Task IServiceMessageWriter.InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } } \ No newline at end of file diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestServiceConnectionManager.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestServiceConnectionManager.cs index a2ff32333..0013c9541 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestServiceConnectionManager.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestServiceConnectionManager.cs @@ -62,4 +62,9 @@ public Task StopAsync() public Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token) => Task.CompletedTask; public Task CloseClientConnections(CancellationToken token) => Task.CompletedTask; + + Task IServiceMessageWriter.InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } } diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs index 3e4ed57e1..6f68a03da 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; @@ -141,5 +144,10 @@ public Task CloseClientConnections(CancellationToken token) { throw new NotImplementedException(); } + + Task IServiceMessageWriter.InvokeAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } } }