Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement IAsyncEnumerable on CosmosLinqQuery #4355

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,32 @@ public static QueryDefinition ToQueryDefinition<T>(this IQueryable<T> query)
throw new ArgumentException("ToQueryDefinition is only supported on Cosmos LINQ query operations", nameof(query));
}

/// <summary>
/// This extension method returns the query as an asynchronous enumerable.
/// </summary>
/// <typeparam name="T">the type of object to query.</typeparam>
/// <param name="query">the IQueryable{T} to be converted.</param>
/// <returns>An asynchronous enumerable to go through the items.</returns>
/// <example>
/// This example shows how to get the query as an asynchronous enumerable.
///
/// <code language="c#">
/// <![CDATA[
/// IOrderedQueryable<ToDoActivity> linqQueryable = this.Container.GetItemLinqQueryable<ToDoActivity>();
/// IAsyncEnumerable<ToDoActivity> asyncEnumerable = linqQueryable.Where(item => (item.taskNum < 100)).AsAsyncEnumerable();
/// ]]>
/// </code>
/// </example>
public static IAsyncEnumerable<T> AsAsyncEnumerable<T>(this IQueryable<T> query)
{
if (query is CosmosLinqQuery<T> asyncEnumerable)
{
return asyncEnumerable;
}

throw new ArgumentException("AsAsyncEnumerable is only supported on Cosmos LINQ query operations", nameof(query));
}

/// <summary>
/// This extension method gets the FeedIterator from LINQ IQueryable to execute query asynchronously.
/// This will create the fresh new FeedIterator when called.
Expand Down
31 changes: 29 additions & 2 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace Microsoft.Azure.Cosmos.Linq
/// This is the entry point for LINQ query creation/execution, it generate query provider, implements IOrderedQueryable.
/// </summary>
/// <seealso cref="CosmosLinqQueryProvider"/>
internal sealed class CosmosLinqQuery<T> : IDocumentQuery<T>, IOrderedQueryable<T>
internal sealed class CosmosLinqQuery<T> : IDocumentQuery<T>, IOrderedQueryable<T>, IAsyncEnumerable<T>
{
private readonly CosmosLinqQueryProvider queryProvider;
private readonly Guid correlatedActivityId;
Expand Down Expand Up @@ -109,7 +109,7 @@ public IEnumerator<T> GetEnumerator()
" use GetItemQueryIterator to execute asynchronously");
}

FeedIterator<T> localFeedIterator = this.CreateFeedIterator(false, out ScalarOperationKind scalarOperationKind);
using FeedIterator<T> localFeedIterator = this.CreateFeedIterator(false, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
Expand All @@ -128,6 +128,33 @@ public IEnumerator<T> GetEnumerator()
}
}

/// <summary>
/// Retrieves an object that can iterate through the individual results of the query asynchronously.
/// </summary>
/// <remarks>
/// This triggers an asynchronous multi-page load.
/// </remarks>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>IEnumerator</returns>
public async IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
using FeedIteratorInlineCore<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

while (localFeedIterator.HasMoreResults)
{
FeedResponse<T> response = await localFeedIterator.ReadNextAsync(cancellationToken);

foreach (T item in response)
{
yield return item;
}
}
}

/// <summary>
/// Synchronous Multi-Page load
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,31 @@ public void LinqQueryToIteratorBlockTest(bool isStreamIterator)
}
}

[TestMethod]
public async Task LinqQueryToAsyncEnumerable()
{
ToDoActivity toDoActivity = ToDoActivity.CreateRandomToDoActivity();
toDoActivity.taskNum = 20;
toDoActivity.id = "minTaskNum";
await this.Container.CreateItemAsync(toDoActivity, new PartitionKey(toDoActivity.pk));
toDoActivity.taskNum = 100;
toDoActivity.id = "maxTaskNum";
await this.Container.CreateItemAsync(toDoActivity, new PartitionKey(toDoActivity.pk));

IAsyncEnumerable<ToDoActivity> query = this.Container.GetItemLinqQueryable<ToDoActivity>()
.OrderBy(p => p.cost)
.AsAsyncEnumerable();

int found = 0;
await foreach (ToDoActivity item in query)
{
Assert.IsNotNull(item);
++found;
}

Assert.AreEqual(2, found);
}

[TestMethod]
[DataRow(false)]
[DataRow(true)]
Expand Down