Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement presence API for IServiceConnectionContainer #2125

Merged
merged 6 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand Down Expand Up @@ -168,4 +168,9 @@ private IEnumerable<IServiceConnectionContainer> GetConnections()
}
}
}

public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
{
throw new NotImplementedException();
}
}
17 changes: 17 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Interfaces/IPresenceManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Threading;

using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR;

/// <summary>
/// Manager for presence operations.
/// </summary>
internal interface IPresenceManager
{
IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace Microsoft.Azure.SignalR;

internal interface IServiceConnectionContainer : IServiceConnectionManager, IDisposable
internal interface IServiceConnectionContainer : IServiceConnectionManager, IPresenceManager, IDisposable
{
ServiceConnectionStatus Status { get; }

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand All @@ -16,7 +18,7 @@ namespace Microsoft.Azure.SignalR;
/// <summary>
/// A service connection container which sends message to multiple service endpoints.
/// </summary>
internal class MultiEndpointMessageWriter : IServiceMessageWriter
internal class MultiEndpointMessageWriter : IServiceMessageWriter, IPresenceManager
{
private readonly ILogger _logger;

Expand Down Expand Up @@ -55,8 +57,8 @@ public Task WriteAsync(ServiceMessage serviceMessage)

public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default)
{
if (serviceMessage is CheckConnectionExistenceWithAckMessage
|| serviceMessage is JoinGroupWithAckMessage
if (serviceMessage is CheckConnectionExistenceWithAckMessage
|| serviceMessage is JoinGroupWithAckMessage
|| serviceMessage is LeaveGroupWithAckMessage)
{
return WriteSingleResultAckableMessage(serviceMessage, cancellationToken);
Expand Down Expand Up @@ -172,6 +174,44 @@ private async Task WriteSingleEndpointMessageAsync(HubServiceEndpoint endpoint,
}
}

public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, [EnumeratorCancellation] CancellationToken token = default)
{
if (TargetEndpoints.Length == 0)
{
Log.NoEndpointRouted(_logger, nameof(GroupMemberQueryMessage));
yield break;
}
if (top <= 0)
{
throw new ArgumentOutOfRangeException(nameof(top), "Top must be greater than 0.");
}
foreach (var endpoint in TargetEndpoints)
{
IAsyncEnumerable<GroupMember> enumerable;
try
{
enumerable = endpoint.ConnectionContainer.ListConnectionsInGroupAsync(groupName, top, tracingId, token);
}
catch (ServiceConnectionNotActiveException)
{
Log.FailedWritingMessageToEndpoint(_logger, nameof(GroupMemberQueryMessage), null, endpoint.ToString());
continue;
}
await foreach (var member in enumerable)
{
yield return member;
if (top.HasValue)
{
top--;
if (top == 0)
{
yield break;
}
}
}
}
}

internal static class Log
{
public const string FailedWritingMessageToEndpointTemplate = "{0} message {1} is not sent to endpoint {2} because all connections to this endpoint are offline.";
Expand Down Expand Up @@ -211,4 +251,4 @@ public static void FailedWritingMessageToEndpoint(ILogger logger, string message
_failedWritingMessageToEndpoint(logger, messageType, tracingId, endpoint, null);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -154,6 +155,14 @@ public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel
return CreateMessageWriter(serviceMessage).WriteAckableMessageAsync(serviceMessage, cancellationToken);
}


public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
{
var targetEndpoints = _routerEndpoints.needRouter ? _router.GetEndpointsForGroup(groupName, _routerEndpoints.endpoints) : _routerEndpoints.endpoints;
var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _loggerFactory);
return messageWriter.ListConnectionsInGroupAsync(groupName, top, tracingId, token);
}

public Task StartGetServersPing()
{
return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StartGetServersPing()));
Expand Down Expand Up @@ -499,4 +508,4 @@ public static void FailedRemovingConnectionForEndpoint(ILogger logger, string en
_failedRemovingConnectionForEndpoint(logger, endpoint, ex);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand All @@ -8,6 +8,7 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -221,7 +222,7 @@ public virtual Task HandlePingAsync(PingMessage pingMessage)

public void HandleAck(AckMessage ackMessage)
{
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status);
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status, ackMessage.Payload);
}

public virtual Task WriteAsync(ServiceMessage serviceMessage)
Expand Down Expand Up @@ -249,6 +250,60 @@ public async Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage,
return AckHandler.HandleAckStatus(ackableMessage, status);
}

public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, [EnumeratorCancellation] CancellationToken token = default)
{
if (string.IsNullOrWhiteSpace(groupName))
{
throw new ArgumentException($"'{nameof(groupName)}' cannot be null or whitespace.", nameof(groupName));
}
if (top != null && top <= 0)
{
throw new ArgumentException($"'{nameof(top)}' must be greater than 0.", nameof(top));
}
var message = new GroupMemberQueryMessage() { GroupName = groupName, Top = top, TracingId = tracingId };
do
{
var response = await InvokeAsync<GroupMemberQueryResponse>(message, token);
foreach (var member in response.Members)
{
yield return member;
}
if (response.ContinuationToken == null)
{
yield break;
}
if (message.Top != null)
{
message.Top -= response.Members.Count;
}
message.ContinuationToken = response.ContinuationToken;
} while (true);
}

/// <summary>
/// <see cref="WriteAckableMessageAsync(ServiceMessage, CancellationToken)"/> only checks <see cref="AckMessage.Status"/> as the response,
/// while this method checks <see cref="AckMessage.Payload"/> and deserialize it to <typeparamref name="T"/>.
/// </summary>
/// Made "interval virtual" for testing
internal virtual async Task<T> InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new()
{
if (serviceMessage is not IAckableMessage ackableMessage)
{
throw new ArgumentException($"{nameof(serviceMessage)} is not {nameof(IAckableMessage)}");
}

var task = _ackHandler.CreateSingleAck<T>(out var id, null, cancellationToken);
ackableMessage.AckId = id;

// Sending regular messages completes as soon as the data leaves the outbound pipe,
// whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout).
// Therefore sending them over different connections creates a possibility for processing them out of original order.
// By sending both message types over the same connection we ensure that they are sent (and processed) in their original order.
await WriteMessageAsync(serviceMessage);

return await task;
}

public virtual Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token)
{
_terminated = true;
Expand Down
40 changes: 20 additions & 20 deletions src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ public Task<AckStatus> CreateSingleAck(out int id, TimeSpan? ackTimeout = defaul
return info.Task;
}

public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : IMessagePackSerializable, new()
{
id = NextId();
if (_disposed)
public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : notnull, new()
{
return Task.FromResult(new T());
id = NextId();
if (_disposed)
{
return Task.FromResult(new T());
}
var info = (SinglePayloadAck<T>)_acks.GetOrAdd(id, _ => new SinglePayloadAck<T>(ackTimeout ?? _defaultAckTimeout));
cancellationToken.Register(info.Cancel);
return info.Task.ContinueWith(task => task.Result);
}
var info = (IAckInfo<IMessagePackSerializable>)_acks.GetOrAdd(id, _ => new SinglePayloadAck<T>(ackTimeout ?? _defaultAckTimeout));
cancellationToken.Register(info.Cancel);
return info.Task.ContinueWith(task => (T)task.Result);
}

public static bool HandleAckStatus(IAckableMessage message, AckStatus status)
{
Expand Down Expand Up @@ -210,19 +210,19 @@ public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = nul
_tcs.TrySetResult(status);
}

private sealed class SinglePayloadAck<T> : SingleAckInfo<IMessagePackSerializable> where T : IMessagePackSerializable, new()
{
public SinglePayloadAck(TimeSpan timeout) : base(timeout) { }
public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null)
private sealed class SinglePayloadAck<T> : SingleAckInfo<T> where T : notnull, new()
{
if (status == AckStatus.Timeout)
public SinglePayloadAck(TimeSpan timeout) : base(timeout) { }
public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null)
{
return _tcs.TrySetException(new TimeoutException($"Waiting for a {typeof(T).Name} response timed out."));
}
if (payload == null)
{
return _tcs.TrySetException(new InvalidDataException($"The expected payload is null."));
}
if (status == AckStatus.Timeout)
{
return _tcs.TrySetException(new TimeoutException($"Waiting for a {typeof(T).Name} response timed out."));
}
if (payload == null)
{
return _tcs.TrySetException(new InvalidDataException($"The expected payload is null."));
}

try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Buffers;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Logging;
using Moq;
using Xunit;
using Xunit.Abstractions;

Expand Down Expand Up @@ -101,4 +107,62 @@ public void TestStrongConnectionStatus()
Assert.True(endpoint1.Online);
}
}

[Fact]
public async Task TestInvokeAsync()
{
var endpoint1 = new TestHubServiceEndpoint();
var conn1 = new TestServiceConnection();
var scf = new TestServiceConnectionFactory(endpoint1 => conn1);
var container = new WeakServiceConnectionContainer(scf, 5, endpoint1, Mock.Of<ILogger>());
var queryMessage = new GroupMemberQueryMessage() { GroupName = "group" };
var invokeTask = container.InvokeAsync<GroupMemberQueryResponse>(queryMessage, default);

var expectedResponse = new GroupMemberQueryResponse()
{
ContinuationToken = "abc",
Members = [new() { ConnectionId = "1" }, new() { ConnectionId = "2" }]
};
var buffer = new ArrayBufferWriter<byte>();
new ServiceProtocol().WriteMessagePayload(expectedResponse, buffer);
AckHandler.Singleton.TriggerAck(queryMessage.AckId, AckStatus.Ok, new ReadOnlySequence<byte>(buffer.WrittenMemory));
var response = await invokeTask;
Assert.Equal(queryMessage, conn1.ReceivedMessages.Single());
Assert.Equal(expectedResponse.ContinuationToken, response.ContinuationToken);
Assert.True(expectedResponse.Members.SequenceEqual(response.Members));
}

[Fact]
public async Task TestListConnectionsInGroupAsync()
{
var conn = new TestServiceConnection();
var groupName = "groupName";
var top = 3;
var tracingId = (ulong)1;
var connectionContainerMock = new Mock<ServiceConnectionContainerBase>(
new TestServiceConnectionFactory(endpoint => conn),
5,
new TestHubServiceEndpoint(),
null,
Mock.Of<ILogger>(),
null);
connectionContainerMock.SetupSequence(c => c.InvokeAsync<GroupMemberQueryResponse>(
It.IsAny<ServiceMessage>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new GroupMemberQueryResponse() { ContinuationToken = "abc", Members = [new() { ConnectionId = "1" }, new() { ConnectionId = "2" }] })
.ReturnsAsync(new GroupMemberQueryResponse() { ContinuationToken = null, Members = [new() { ConnectionId = "3" }] });
var enumerator = connectionContainerMock.Object
.ListConnectionsInGroupAsync(groupName, top, tracingId)
.GetAsyncEnumerator();
Assert.True(await enumerator.MoveNextAsync());
Assert.Equal("1", enumerator.Current.ConnectionId);
connectionContainerMock.Verify(c => c.InvokeAsync<GroupMemberQueryResponse>(
It.Is<GroupMemberQueryMessage>(m => m.GroupName == groupName && m.Top == 3 && m.TracingId == tracingId), It.IsAny<CancellationToken>()), Times.Once);
connectionContainerMock.Invocations.Clear();
Assert.True(await enumerator.MoveNextAsync());
Assert.True(await enumerator.MoveNextAsync());
Assert.Equal("3", enumerator.Current.ConnectionId);
connectionContainerMock.Verify(c => c.InvokeAsync<GroupMemberQueryResponse>(
It.Is<GroupMemberQueryMessage>(m => m.GroupName == groupName && m.Top == 1 && m.TracingId == tracingId && m.ContinuationToken == "abc"), It.IsAny<CancellationToken>()), Times.Once);
Assert.False(await enumerator.MoveNextAsync());
}
}
Loading
Loading