Skip to content

Commit

Permalink
Add cancellation token parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-Sindo committed Jan 22, 2025
1 parent e4b27f5 commit 9878558
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand Down Expand Up @@ -169,8 +169,8 @@ private IEnumerable<IServiceConnectionContainer> GetConnections()
}
}

public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null)
public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
{
throw new NotSupportedException();
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Threading;

using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR;
Expand All @@ -11,5 +13,5 @@ namespace Microsoft.Azure.SignalR;
/// </summary>
internal interface IPresenceManager
{
IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null);
IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default);
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -172,7 +174,7 @@ private async Task WriteSingleEndpointMessageAsync(HubServiceEndpoint endpoint,
}
}

public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null)
public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, [EnumeratorCancellation] CancellationToken token = default)
{
if (TargetEndpoints.Length == 0)
{
Expand All @@ -188,7 +190,7 @@ public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string gr
IAsyncEnumerable<GroupMember> enumerable;
try
{
enumerable = endpoint.ConnectionContainer.ListConnectionsInGroupAsync(groupName, top, tracingId);
enumerable = endpoint.ConnectionContainer.ListConnectionsInGroupAsync(groupName, top, tracingId, token);
}
catch (ServiceConnectionNotActiveException)
{
Expand Down Expand Up @@ -249,4 +251,4 @@ public static void FailedWritingMessageToEndpoint(ILogger logger, string message
_failedWritingMessageToEndpoint(logger, messageType, tracingId, endpoint, null);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -151,11 +152,11 @@ public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel
}


public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null)
public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
{
var targetEndpoints = _routerEndpoints.needRouter ? _router.GetEndpointsForGroup(groupName, _routerEndpoints.endpoints) : _routerEndpoints.endpoints;
var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _loggerFactory);
return messageWriter.ListConnectionsInGroupAsync(groupName, top);
return messageWriter.ListConnectionsInGroupAsync(groupName, top, tracingId, token);
}

public Task StartGetServersPing()
Expand Down Expand Up @@ -517,4 +518,4 @@ public static void FailedRemovingConnectionForEndpoint(ILogger logger, string en
_failedRemovingConnectionForEndpoint(logger, endpoint, ex);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand All @@ -8,6 +8,7 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -249,7 +250,7 @@ public async Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage,
return AckHandler.HandleAckStatus(ackableMessage, status);
}

public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null)
public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, [EnumeratorCancellation] CancellationToken token = default)
{
if (string.IsNullOrWhiteSpace(groupName))
{
Expand All @@ -262,7 +263,7 @@ public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string gr
var message = new GroupMemberQueryMessage() { GroupName = groupName, Top = top, TracingId = tracingId };
do
{
var response = await InvokeAsync<GroupMemberQueryResponse>(message);
var response = await InvokeAsync<GroupMemberQueryResponse>(message, token);
foreach (var member in response.Members)
{
yield return member;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
Expand Down Expand Up @@ -31,7 +31,7 @@ public async Task ListConnectionsInGroup(int? top, int resultCount, params int?[
);
var containerMock = new Mock<IServiceConnectionContainer>();
containerMocks.Add(containerMock);
containerMock.Setup(c => c.ListConnectionsInGroupAsync(It.IsAny<string>(), It.IsAny<int?>(), null))
containerMock.Setup(c => c.ListConnectionsInGroupAsync(It.IsAny<string>(), It.IsAny<int?>(), null, default))
.Returns(resultFromConnectioContainer);
endpoint.ConnectionContainer = containerMock.Object;
targetEndpoints.Add(endpoint);
Expand All @@ -45,7 +45,7 @@ public async Task ListConnectionsInGroup(int? top, int resultCount, params int?[
Assert.Equal(resultCount, resultMembers.Count);
for (var i = 0; i < expectedTopsInInvocations.Length; i++)
{
containerMocks[i].Verify(c => c.ListConnectionsInGroupAsync("group", expectedTopsInInvocations[i], null), Times.Once());
containerMocks[i].Verify(c => c.ListConnectionsInGroupAsync("group", expectedTopsInInvocations[i], null, default), Times.Once());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR.Tests.Common;
Expand Down Expand Up @@ -114,4 +115,9 @@ public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupNam
{
throw new NotImplementedException();
}

public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
{
throw new NotImplementedException();
}
}

0 comments on commit 9878558

Please sign in to comment.