Skip to content

Commit

Permalink
Implement persence API in websocket transport
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-Sindo committed Jan 20, 2025
1 parent 6e4f3a7 commit d4806ce
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,9 @@ private IEnumerable<IServiceConnectionContainer> GetConnections()
}
}
}

public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top)
{
throw new NotSupportedException();
}
}
15 changes: 15 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Interfaces/IPresenceManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR;

/// <summary>
/// Manager for presence operations.
/// </summary>
internal interface IPresenceManager
{
IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace Microsoft.Azure.SignalR;

internal interface IServiceConnectionContainer : IServiceConnectionManager, IDisposable
internal interface IServiceConnectionContainer : IServiceConnectionManager, IPresenceManager, IDisposable
{
ServiceConnectionStatus Status { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.Azure.SignalR;
/// <summary>
/// A service connection container which sends message to multiple service endpoints.
/// </summary>
internal class MultiEndpointMessageWriter : IServiceMessageWriter
internal class MultiEndpointMessageWriter : IServiceMessageWriter, IPresenceManager
{
private readonly ILogger _logger;

Expand Down Expand Up @@ -55,8 +55,8 @@ public Task WriteAsync(ServiceMessage serviceMessage)

public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default)
{
if (serviceMessage is CheckConnectionExistenceWithAckMessage
|| serviceMessage is JoinGroupWithAckMessage
if (serviceMessage is CheckConnectionExistenceWithAckMessage
|| serviceMessage is JoinGroupWithAckMessage
|| serviceMessage is LeaveGroupWithAckMessage)
{
return WriteSingleResultAckableMessage(serviceMessage, cancellationToken);
Expand Down Expand Up @@ -172,6 +172,32 @@ private async Task WriteSingleEndpointMessageAsync(HubServiceEndpoint endpoint,
}
}

public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top)
{
if (TargetEndpoints.Length == 0)
{
Log.NoEndpointRouted(_logger, nameof(GroupMemberQueryMessage));
yield break;
}
foreach (var endpoint in TargetEndpoints)
{
IAsyncEnumerable<GroupMember> enumerable;
try
{
enumerable = endpoint.ConnectionContainer.ListConnectionsInGroupAsync(groupName, top);
}
catch (ServiceConnectionNotActiveException)
{
Log.FailedWritingMessageToEndpoint(_logger, nameof(GroupMemberQueryMessage), null, endpoint.ToString());
continue;
}
await foreach (var member in enumerable)
{
yield return member;
}
}
}

internal static class Log
{
public const string FailedWritingMessageToEndpointTemplate = "{0} message {1} is not sent to endpoint {2} because all connections to this endpoint are offline.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel
return CreateMessageWriter(serviceMessage).WriteAckableMessageAsync(serviceMessage, cancellationToken);
}


public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top)
{
var targetEndpoints = _routerEndpoints.needRouter ? _router.GetEndpointsForGroup(groupName, _routerEndpoints.endpoints) : _routerEndpoints.endpoints;
var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _loggerFactory);
return messageWriter.ListConnectionsInGroupAsync(groupName, top);
}

public Task StartGetServersPing()
{
return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StartGetServersPing()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public virtual Task HandlePingAsync(PingMessage pingMessage)

public void HandleAck(AckMessage ackMessage)
{
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status);
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status, ackMessage.Payload);
}

public virtual Task WriteAsync(ServiceMessage serviceMessage)
Expand Down Expand Up @@ -249,6 +249,53 @@ public async Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage,
return AckHandler.HandleAckStatus(ackableMessage, status);
}

public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top)
{
var currentCount = 0;
do
{
if (top != null)
{
top -= currentCount;
}
var message = new GroupMemberQueryMessage() { GroupName = groupName, Top = top };
var response = await InvokeAsync<GroupMemberQueryResponse>(message);
foreach (var member in response.Members)
{
yield return member;
currentCount++;
if (top != null && currentCount >= top || response.ContinuationToken == null)
{
yield break;
}
}
message.ContinuationToken = response.ContinuationToken;
} while (true);
}

/// <summary>
/// <see cref="WriteAckableMessageAsync(ServiceMessage, CancellationToken)"/> only checks <see cref="AckMessage.Status"/> as the response,
/// while this method checks <see cref="AckMessage.Payload"/> and deserialize it to <typeparamref name="T"/>.
/// </summary>
private async Task<T> InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new()
{
if (serviceMessage is not IAckableMessage ackableMessage)
{
throw new ArgumentException($"{nameof(serviceMessage)} is not {nameof(IAckableMessage)}");
}

var task = _ackHandler.CreateSingleAck<T>(out var id, null, cancellationToken);
ackableMessage.AckId = id;

// Sending regular messages completes as soon as the data leaves the outbound pipe,
// whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout).
// Therefore sending them over different connections creates a possibility for processing them out of original order.
// By sending both message types over the same connection we ensure that they are sent (and processed) in their original order.
await WriteMessageAsync(serviceMessage);

return await task;
}

public virtual Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token)
{
_terminated = true;
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ public Task<AckStatus> CreateSingleAck(out int id, TimeSpan? ackTimeout = defaul
return info.Task;
}

public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : IMessagePackSerializable, new()
public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : notnull, new()
{
id = NextId();
if (_disposed)
{
return Task.FromResult(new T());
}
var info = (IAckInfo<IMessagePackSerializable>)_acks.GetOrAdd(id, _ => new SinglePayloadAck<T>(ackTimeout ?? _defaultAckTimeout));
var info = (SinglePayloadAck<T>)_acks.GetOrAdd(id, _ => new SinglePayloadAck<T>(ackTimeout ?? _defaultAckTimeout));
cancellationToken.Register(info.Cancel);
return info.Task.ContinueWith(task => (T)task.Result);
return info.Task.ContinueWith(task => task.Result);
}

public static bool HandleAckStatus(IAckableMessage message, AckStatus status)
Expand Down Expand Up @@ -210,7 +210,7 @@ public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = nul
_tcs.TrySetResult(status);
}

private sealed class SinglePayloadAck<T> : SingleAckInfo<IMessagePackSerializable> where T : IMessagePackSerializable, new()
private sealed class SinglePayloadAck<T> : SingleAckInfo<T> where T : notnull, new()
{
public SinglePayloadAck(TimeSpan timeout) : base(timeout) { }
public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,9 @@ public Task CloseClientConnections(CancellationToken token)
{
throw new NotImplementedException();
}

public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top)
{
throw new NotImplementedException();
}
}

0 comments on commit d4806ce

Please sign in to comment.