Skip to content

Commit

Permalink
Add websocket implementation for presence API
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-Sindo committed Jan 16, 2025
1 parent e0b874f commit 8e3670d
Show file tree
Hide file tree
Showing 15 changed files with 186 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,9 @@ private IEnumerable<IServiceConnectionContainer> GetConnections()
}
}
}

Task<T> IServiceMessageWriter.InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken)
{
return _appConnection.InvokeAsync<T>(serviceMessage, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ internal interface IServiceMessageWriter
Task WriteAsync(ServiceMessage serviceMessage);

Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default);

Task<T> InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new();
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ public Task WriteAsync(ServiceMessage serviceMessage)

public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default)
{
if (serviceMessage is CheckConnectionExistenceWithAckMessage
|| serviceMessage is JoinGroupWithAckMessage
if (serviceMessage is CheckConnectionExistenceWithAckMessage
|| serviceMessage is JoinGroupWithAckMessage
|| serviceMessage is LeaveGroupWithAckMessage)
{
return WriteSingleResultAckableMessage(serviceMessage, cancellationToken);
Expand All @@ -67,6 +67,58 @@ public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel
}
}

public async Task<T> InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new()
{
var results = new List<T>();
await WriteMultiEndpointMessageAsync(serviceMessage, async connection =>
results.Add(await connection.InvokeAsync<T>(serviceMessage.Clone(), cancellationToken)));
switch (serviceMessage)
{
case GroupMemberQueryMessage memberQueryMessage:
if (results is not IEnumerable<GroupMemberQueryResponse> memberQueryResults || typeof(T) != typeof(GroupMemberQueryResponse))
{
throw new InvalidOperationException($"The response of {nameof(GroupMemberQueryMessage)} should be of type {nameof(GroupMemberQueryResponse)}.");
}
return (T)(object)AggregateGroupMemberQueryResult(memberQueryResults, memberQueryMessage);
default:
throw new NotSupportedException($"{serviceMessage.GetType().Name} is not supported.");
}
static GroupMemberQueryResponse AggregateGroupMemberQueryResult(IEnumerable<GroupMemberQueryResponse> results, GroupMemberQueryMessage message)
{
var totalMembers = results.SelectMany(r => r.Members).ToArray();
var totalCount = totalMembers.Length;
if (totalCount <= message.Max)
{
// Quick path: Return all the members
var continuationToken = default(string);
// Select the maximum connection ID.
foreach (var member in totalMembers)
{
if (string.Compare(member.ConnectionId, continuationToken, StringComparison.InvariantCulture) > 0)
{
continuationToken = member.ConnectionId;
}
}
return new GroupMemberQueryResponse
{
Members = totalMembers,
ContinuationToken = continuationToken
};
}
else
{
// Slow path: get the minimal N members
// Priority Queue is not available in .NET Standard 2.0, so we sort the array instead
Array.Sort(totalMembers, (m1, m2) => string.Compare(m1.ConnectionId, m2.ConnectionId, StringComparison.InvariantCulture));
return new GroupMemberQueryResponse
{
Members = new ArraySegment<GroupMember>(totalMembers, 0, message.Max),
ContinuationToken = totalMembers[message.Max - 1].ConnectionId
};
}
}
}

/// <summary>
/// For user or group related operations, different endpoints might return different results
/// Strategy:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel
return CreateMessageWriter(serviceMessage).WriteAckableMessageAsync(serviceMessage, cancellationToken);
}

public Task<T> InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new()
{
return CreateMessageWriter(serviceMessage).InvokeAsync<T>(serviceMessage, cancellationToken);
}

public Task StartGetServersPing()
{
return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StartGetServersPing()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public virtual Task HandlePingAsync(PingMessage pingMessage)

public void HandleAck(AckMessage ackMessage)
{
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status);
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status, ackMessage.Payload);
}

public virtual Task WriteAsync(ServiceMessage serviceMessage)
Expand Down Expand Up @@ -249,6 +249,29 @@ public async Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage,
return AckHandler.HandleAckStatus(ackableMessage, status);
}

/// <summary>
/// <see cref="WriteAckableMessageAsync(ServiceMessage, CancellationToken)"/> only checks <see cref="AckMessage.Status"/> as the response,
/// while this method checks <see cref="AckMessage.Payload"/> and deserialize it to <typeparamref name="T"/>.
/// </summary>
public async Task<T> InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new()
{
if (serviceMessage is not IAckableMessage ackableMessage)
{
throw new ArgumentException($"{nameof(serviceMessage)} is not {nameof(IAckableMessage)}");
}

var task = _ackHandler.CreateSingleAck<T>(out var id, null, cancellationToken);
ackableMessage.AckId = id;

// Sending regular messages completes as soon as the data leaves the outbound pipe,
// whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout).
// Therefore sending them over different connections creates a possibility for processing them out of original order.
// By sending both message types over the same connection we ensure that they are sent (and processed) in their original order.
await WriteMessageAsync(serviceMessage);

return await task;
}

public virtual Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token)
{
_terminated = true;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public Task<AckStatus> CreateSingleAck(out int id, TimeSpan? ackTimeout = defaul
return info.Task;
}

public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : IMessagePackSerializable, new()
public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : notnull, new()
{
id = NextId();
if (_disposed)
Expand Down Expand Up @@ -210,7 +210,7 @@ public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = nul
_tcs.TrySetResult(status);
}

private sealed class SinglePayloadAck<T> : SingleAckInfo<IMessagePackSerializable> where T : IMessagePackSerializable, new()
private sealed class SinglePayloadAck<T> : SingleAckInfo<T> where T : notnull, new()
{
public SinglePayloadAck(TimeSpan timeout) : base(timeout) { }
public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR.Management
{
Expand All @@ -14,5 +17,7 @@ public abstract class GroupManager : IGroupManager
public abstract Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default);

public abstract Task RemoveFromAllGroupsAsync(string connectionId, CancellationToken cancellationToken = default);

public virtual IAsyncEnumerable<GroupMemberQueryResponse> ListConnectionsInGroupAsync(string groupName, int? max, CancellationToken token) => throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR.Management
{
internal class GroupManagerAdapter : GroupManager
{
private readonly IHubLifetimeManager _lifetimeManager;
private readonly IServiceHubLifetimeManager _lifetimeManager;

public GroupManagerAdapter(IHubLifetimeManager lifetimeManager)
public GroupManagerAdapter(IServiceHubLifetimeManager lifetimeManager)
{
_lifetimeManager = lifetimeManager;
}
Expand All @@ -21,5 +22,7 @@ public GroupManagerAdapter(IHubLifetimeManager lifetimeManager)
public override Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) => _lifetimeManager.RemoveFromGroupAsync(connectionId, groupName, cancellationToken);

public override Task RemoveFromAllGroupsAsync(string connectionId, CancellationToken cancellationToken = default) => _lifetimeManager.RemoveFromAllGroupsAsync(connectionId, cancellationToken);

public override IAsyncEnumerable<GroupMemberQueryResponse> ListConnectionsInGroupAsync(string groupName, int? top, CancellationToken token) => _lifetimeManager.ListConnectionsInGroup(groupName, top, token);
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR.Management
{
Expand All @@ -15,5 +17,7 @@ internal interface IServiceHubLifetimeManager : IHubLifetimeManager, IUserGroupH
Task<bool> UserExistsAsync(string userId, CancellationToken cancellationToken);

Task<bool> GroupExistsAsync(string groupName, CancellationToken cancellationToken);

IAsyncEnumerable<GroupMemberQueryResponse> ListConnectionsInGroup(string groupName, int? top = null, CancellationToken token = default);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Primitives;
using static Microsoft.Azure.SignalR.Constants;

Expand Down Expand Up @@ -310,6 +311,11 @@ await _restClient.SendWithRetryAsync(api, HttpMethod.Head, handleExpectedRespons
return exists;
}

public IAsyncEnumerable<GroupMemberQueryResponse> ListConnectionsInGroup(string groupName, int? top = null, CancellationToken token = default)
{
throw new NotImplementedException();
}

public Task DisposeAsync() => Task.CompletedTask;

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
#nullable enable
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
Expand Down Expand Up @@ -204,6 +206,53 @@ public Task<bool> GroupExistsAsync(string groupName, CancellationToken cancellat
return WriteAckableMessageAsync(message, cancellationToken);
}

public async IAsyncEnumerable<GroupMemberQueryResponse> ListConnectionsInGroup(string groupName, int? top = null, [EnumeratorCancellation] CancellationToken token = default)
{
if (string.IsNullOrWhiteSpace(groupName))
{
throw new ArgumentException($"'{nameof(groupName)}' cannot be null or whitespace.", nameof(groupName));
}

var message = new GroupMemberQueryMessage()
{
GroupName = groupName,
};
if (top != null)
{
message.Max = top.Value;
}
AppendMessageTracingId(message);

var currentCount = 0;
while (true)
{
var response = await ServiceConnectionContainer.InvokeAsync<GroupMemberQueryResponse>(message, token);
currentCount += response.Members.Count;
if (top != null)
{
top = top - currentCount;
if (top <= 0)
{
response.ContinuationToken = null;
}
}
yield return response;
if (response.ContinuationToken == null)
{
yield break;
}
message = new GroupMemberQueryMessage()
{
GroupName = groupName,
ContinuationToken = response.ContinuationToken,
};
if (top != null)
{
message.Max = top.Value - currentCount;
}
}
}

protected override T AppendMessageTracingId<T>(T message)
{
if (_serviceManagerOptions.Value.EnableMessageTracing)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,9 @@ public void Dispose()
{
StopAsync().GetAwaiter().GetResult();
}

Task<T> IServiceMessageWriter.InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken)
{
return _serviceConnection.InvokeAsync<T>(serviceMessage, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,9 @@ public Task CloseClientConnections(CancellationToken token)
{
throw new NotImplementedException();
}

Task<T> IServiceMessageWriter.InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ public Task StopAsync()
public Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token) => Task.CompletedTask;

public Task CloseClientConnections(CancellationToken token) => Task.CompletedTask;

Task<T> IServiceMessageWriter.InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
}
10 changes: 9 additions & 1 deletion test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using System;
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
Expand Down Expand Up @@ -141,5 +144,10 @@ public Task CloseClientConnections(CancellationToken token)
{
throw new NotImplementedException();
}

Task<T> IServiceMessageWriter.InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
}
}

0 comments on commit 8e3670d

Please sign in to comment.