From fa0f42bcf42093890f1925224b1e24af350117dd Mon Sep 17 00:00:00 2001 From: Zhenghui Yan Date: Thu, 16 Jan 2025 15:22:36 +0800 Subject: [PATCH] add n-way merge for async enumerable --- .../Utilities/NWayMergeAsyncEnumerable.cs | 63 +++++++++++ .../Utils/NWayMergeAsyncEnumerableTest.cs | 101 ++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 src/Microsoft.Azure.SignalR.Common/Utilities/NWayMergeAsyncEnumerable.cs create mode 100644 test/Microsoft.Azure.SignalR.Common.Tests/Utils/NWayMergeAsyncEnumerableTest.cs diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/NWayMergeAsyncEnumerable.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/NWayMergeAsyncEnumerable.cs new file mode 100644 index 000000000..45a541bb6 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/NWayMergeAsyncEnumerable.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Threading; + +namespace Microsoft.Azure.SignalR.Common; + +public class NWayMergeAsyncEnumerable : IAsyncEnumerable +{ + private readonly IComparer _comparer; + private readonly IAsyncEnumerable[] _sources; + + public NWayMergeAsyncEnumerable(params IAsyncEnumerable[] sources) + : this(null, sources) + { + } + + public NWayMergeAsyncEnumerable(IComparer comparer, params IAsyncEnumerable[] sources) + { + _comparer = comparer ?? Comparer.Default; + _sources = sources ?? throw new ArgumentNullException(nameof(sources)); + foreach (var source in _sources) + { + if (source == null) + { + throw new ArgumentException("Item cannot be null.", nameof(sources)); + } + } + } + + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + var sources = Array.ConvertAll(_sources, source => source.GetAsyncEnumerator(cancellationToken)); + var hasMore = new bool[sources.Length]; + for (int i = 0; i < sources.Length; i++) + { + hasMore[i] = await sources[i].MoveNextAsync(); + } + while (Array.IndexOf(hasMore, true) != -1) + { + for (int i = 0; i < sources.Length; i++) + { + if (!hasMore[i]) + { + continue; + } + var current = i; + for (int j = 0; j < sources.Length; j++) + { + if (j != current && hasMore[j] && _comparer.Compare(sources[j].Current, sources[current].Current) < 0) + { + current = j; + } + } + yield return sources[current].Current; + hasMore[current] = await sources[current].MoveNextAsync(); + break; + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Utils/NWayMergeAsyncEnumerableTest.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Utils/NWayMergeAsyncEnumerableTest.cs new file mode 100644 index 000000000..27725d972 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Utils/NWayMergeAsyncEnumerableTest.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; + +using Xunit; + +namespace Microsoft.Azure.SignalR.Common.Tests; + +public class NWayMergeAsyncEnumerableTest +{ + [Fact] + public async Task TestOneWayMergeAsyncEnumerable() + { + List source = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + var sources = new IAsyncEnumerable[] + { + ToAsyncEnumerable(source), + }; + var multiWayMergeAsyncEnumerable = new NWayMergeAsyncEnumerable(sources); + var result = await ToListAsync(multiWayMergeAsyncEnumerable); + Assert.Equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], result); + } + + [Fact] + public async Task TestEmptyOneWayMergeAsyncEnumerable() + { + List source = []; + var sources = new IAsyncEnumerable[] + { + ToAsyncEnumerable(source), + }; + var multiWayMergeAsyncEnumerable = new NWayMergeAsyncEnumerable(sources); + var result = await ToListAsync(multiWayMergeAsyncEnumerable); + Assert.Equal([], result); + } + + [Fact] + public async Task TestThreeWayMergeAsyncEnumerable() + { + List source1 = [1, 3, 5, 7, 9]; + List source2 = [2, 4, 6, 8, 10]; + List source3 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + var sources = new IAsyncEnumerable[] + { + ToAsyncEnumerable(source1), + ToAsyncEnumerable(source2), + ToAsyncEnumerable(source3), + }; + var multiWayMergeAsyncEnumerable = new NWayMergeAsyncEnumerable(sources); + var result = await ToListAsync(multiWayMergeAsyncEnumerable); + Assert.Equal([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10], result); + } + + [Fact] + public async Task TestEmptyThreeWayMergeAsyncEnumerable() + { + List source1 = []; + List source2 = []; + List source3 = []; + var sources = new IAsyncEnumerable[] + { + ToAsyncEnumerable(source1), + ToAsyncEnumerable(source2), + ToAsyncEnumerable(source3), + }; + var multiWayMergeAsyncEnumerable = new NWayMergeAsyncEnumerable(sources); + var result = await ToListAsync(multiWayMergeAsyncEnumerable); + Assert.Equal([], result); + } + + [Fact] + public void TestNullCases() + { + IAsyncEnumerable[] args = null; + Assert.Throws(() => new NWayMergeAsyncEnumerable(args)); + IAsyncEnumerable arg = null; + Assert.Throws(() => new NWayMergeAsyncEnumerable(arg)); + } + + private static async IAsyncEnumerable ToAsyncEnumerable(IEnumerable source) + { + foreach (var item in source) + { + await Task.Delay(1); + yield return item; + } + } + + private static async Task> ToListAsync(IAsyncEnumerable source) + { + var list = new List(); + await foreach (var item in source) + { + list.Add(item); + } + return list; + } +}