diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse.cs index 001cfd9469..6749e89395 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse.cs @@ -74,7 +74,6 @@ public AgentRunResponse(ChatResponse response) this.RawRepresentation = response; this.ResponseId = response.ResponseId; this.Usage = response.Usage; - this.ContinuationToken = response.ContinuationToken; } /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseUpdate.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseUpdate.cs index ccf3deae54..6556fb1184 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseUpdate.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseUpdate.cs @@ -75,7 +75,6 @@ public AgentRunResponseUpdate(ChatResponseUpdate chatResponseUpdate) this.RawRepresentation = chatResponseUpdate; this.ResponseId = chatResponseUpdate.ResponseId; this.Role = chatResponseUpdate.Role; - this.ContinuationToken = chatResponseUpdate.ContinuationToken; } /// Gets or sets the name of the author of the response update. diff --git a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs index fe3f73b28b..7a0e629238 100644 --- a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Text.Encodings.Web; using System.Text.Json; using System.Text.Json.Serialization; +using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI; @@ -68,6 +70,8 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(ChatClientAgentThread.ThreadState))] [JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))] [JsonSerializable(typeof(ChatHistoryMemoryProvider.ChatHistoryMemoryProviderState))] + [JsonSerializable(typeof(IReadOnlyCollection))] + [JsonSerializable(typeof(IReadOnlyCollection))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index a5a34d24a9..5952ba1d83 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -162,7 +162,10 @@ static Task GetResponseAsync(IChatClient chatClient, List RunStreamingAsync (ChatClientAgentThread safeThread, ChatOptions? chatOptions, List inputMessagesForChatClient, IList? aiContextProviderMessages) = await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false); - ValidateStreamResumptionAllowed(chatOptions?.ContinuationToken, safeThread); + var continuationToken = ParseContinuationToken(options?.ContinuationToken); + + ValidateStreamResumptionAllowed(continuationToken, safeThread); var chatClient = this.ChatClient; @@ -214,7 +219,7 @@ public override async IAsyncEnumerable RunStreamingAsync this._logger.LogAgentChatClientInvokingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType); - List responseUpdates = []; + List responseUpdates = GetResponseUpdates(continuationToken); IAsyncEnumerator responseUpdatesEnumerator; @@ -225,7 +230,7 @@ public override async IAsyncEnumerable RunStreamingAsync } catch (Exception ex) { - await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -239,7 +244,7 @@ public override async IAsyncEnumerable RunStreamingAsync } catch (Exception ex) { - await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -251,7 +256,12 @@ public override async IAsyncEnumerable RunStreamingAsync update.AuthorName ??= this.Name; responseUpdates.Add(update); - yield return new(update) { AgentId = this.Id }; + + yield return new(update) + { + AgentId = this.Id, + ContinuationToken = WrapContinuationToken(update.ContinuationToken, GetInputMessages(inputMessages, continuationToken), responseUpdates) + }; } try @@ -260,7 +270,7 @@ public override async IAsyncEnumerable RunStreamingAsync } catch (Exception ex) { - await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } } @@ -272,10 +282,10 @@ public override async IAsyncEnumerable RunStreamingAsync this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); // To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request. - await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false); + await NotifyMessageStoreOfNewMessagesAsync(safeThread, GetInputMessages(inputMessages, continuationToken).Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfSuccessAsync(safeThread, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -382,6 +392,8 @@ private async Task RunCoreAsync inputMessagesForChatClient, IList? aiContextProviderMessages) = await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false); + ValidatePollingAllowed(chatOptions?.ContinuationToken, safeThread); + var chatClient = this.ChatClient; chatClient = ApplyRunOptionsTransformations(options, chatClient); @@ -583,12 +595,16 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider static ChatOptions? ApplyBackgroundResponsesProperties(ChatOptions? chatOptions, AgentRunOptions? agentRunOptions) { - // If any of the background response properties are set in the run options, we should apply both to the chat options. - if (agentRunOptions?.AllowBackgroundResponses is not null || agentRunOptions?.ContinuationToken is not null) + if (agentRunOptions?.AllowBackgroundResponses is not null) { chatOptions ??= new ChatOptions(); chatOptions.AllowBackgroundResponses = agentRunOptions.AllowBackgroundResponses; - chatOptions.ContinuationToken = agentRunOptions.ContinuationToken; + } + + if ((agentRunOptions?.ContinuationToken ?? chatOptions?.ContinuationToken) is { } continuationToken) + { + chatOptions ??= new ChatOptions(); + chatOptions.ContinuationToken = ParseContinuationToken(continuationToken)!.InnerToken; } return chatOptions; @@ -630,11 +646,6 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider throw new InvalidOperationException("Input messages are not allowed when continuing a background response using a continuation token."); } - if (chatOptions?.ContinuationToken is not null && typedThread.ConversationId is null && typedThread.MessageStore is null) - { - throw new InvalidOperationException("Continuation tokens are not allowed to be used for initial runs."); - } - List inputMessagesForChatClient = []; IList? aiContextProviderMessages = null; @@ -739,26 +750,80 @@ private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread t return Task.CompletedTask; } - private static void ValidateStreamResumptionAllowed(ResponseContinuationToken? continuationToken, ChatClientAgentThread safeThread) + private static void ValidateStreamResumptionAllowed(ChatClientAgentContinuationToken? continuationToken, ChatClientAgentThread safeThread) + { + if (continuationToken is null) + { + return; + } + + // If neither input messages nor response updates are present in the token, + // it means it's an initial run that cannot be resumed. + if (continuationToken.InputMessages is not { Count: > 0 } && continuationToken.ResponseUpdates is not { Count: > 0 }) + { + throw new InvalidOperationException("Continuation tokens are not allowed to be used for initial runs."); + } + } + + private static void ValidatePollingAllowed(ResponseContinuationToken? continuationToken, ChatClientAgentThread safeThread) { if (continuationToken is null) { return; } - // Streaming resumption is only supported with chat history managed by the agent service because, currently, there's no good solution - // to collect updates received in failed runs and pass them to the last successful run so it can store them to the message store. - if (safeThread.ConversationId is null) + // If neither conversation id nor message store are set on the thread, + // it means it's an initial run that cannot be polled. + if (safeThread.ConversationId is null && safeThread.MessageStore is null) + { + throw new InvalidOperationException("Continuation tokens are not allowed to be used for initial runs."); + } + } + + private static ChatClientAgentContinuationToken? ParseContinuationToken(ResponseContinuationToken? continuationToken) + { + return continuationToken is null + ? null + : ChatClientAgentContinuationToken.FromToken(continuationToken); + } + + private static ChatClientAgentContinuationToken? WrapContinuationToken(ResponseContinuationToken? continuationToken, IReadOnlyCollection? inputMessages = null, List? responseUpdates = null) + { + if (continuationToken is null) { - throw new NotSupportedException("Streaming resumption is only supported when chat history is stored and managed by the underlying AI service."); + return null; } - // Similarly, streaming resumption is not supported when a context provider is used because, currently, there's no good solution - // to collect updates received in failed runs and pass them to the last successful run so it can notify the context provider of the updates. - if (safeThread.AIContextProvider is not null) + return new(continuationToken) + { + // Save input messages to the continuation token so they can be added to the thread and + // provided to the context provider in the last successful streaming resumption run. + // That's necessary for scenarios where initial streaming run fails and streaming is resumed later. + InputMessages = inputMessages?.Count > 0 ? inputMessages : null, + + // Save all updates received so far to the continuation token so they can be provided to the + // message store and context provider in the last successful streaming resumption run. + // That's necessary for scenarios where a streaming run fails after some updates were received. + ResponseUpdates = responseUpdates?.Count > 0 ? responseUpdates : null + }; + } + + private static IReadOnlyCollection GetInputMessages(IReadOnlyCollection inputMessages, ChatClientAgentContinuationToken? token) + { + // First, use input messages if provided. + if (inputMessages.Count > 0) { - throw new NotSupportedException("Using context provider with streaming resumption is not supported."); + return inputMessages; } + + // Fallback to messages saved in the continuation token if available. + return token?.InputMessages ?? []; + } + + private static List GetResponseUpdates(ChatClientAgentContinuationToken? token) + { + // Restore any previously received updates from the continuation token. + return token?.ResponseUpdates?.ToList() ?? []; } private string GetLoggingAgentName() => this.Name ?? "UnnamedAgent"; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentContinuationToken.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentContinuationToken.cs new file mode 100644 index 0000000000..0057bdf7da --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentContinuationToken.cs @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// Represents a continuation token for ChatClientAgent operations. +/// +internal class ChatClientAgentContinuationToken : ResponseContinuationToken +{ + private const string TokenTypeName = "chatClientAgentContinuationToken"; + private const string TypeDiscriminator = "type"; + + /// + /// Initializes a new instance of the class. + /// + /// A continuation token provided by the underlying . + [JsonConstructor] + internal ChatClientAgentContinuationToken(ResponseContinuationToken innerToken) + { + this.InnerToken = innerToken; + } + + public override ReadOnlyMemory ToBytes() + { + using MemoryStream stream = new(); + using Utf8JsonWriter writer = new(stream); + + writer.WriteStartObject(); + + // This property should be the first one written to identify the type during deserialization. + writer.WriteString(TypeDiscriminator, TokenTypeName); + + writer.WriteString("innerToken", JsonSerializer.Serialize(this.InnerToken, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken)))); + + if (this.InputMessages?.Count > 0) + { + writer.WriteString("inputMessages", JsonSerializer.Serialize(this.InputMessages, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IReadOnlyCollection)))); + } + + if (this.ResponseUpdates?.Count > 0) + { + writer.WriteString("responseUpdates", JsonSerializer.Serialize(this.ResponseUpdates, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IReadOnlyCollection)))); + } + + writer.WriteEndObject(); + + writer.Flush(); + + return stream.ToArray(); + } + + /// + /// Create a new instance of from the provided . + /// + /// The token to create the from. + /// A equivalent of the provided . + internal static ChatClientAgentContinuationToken FromToken(ResponseContinuationToken token) + { + if (token is ChatClientAgentContinuationToken chatClientContinuationToken) + { + return chatClientContinuationToken; + } + + ReadOnlyMemory data = token.ToBytes(); + + if (data.Length == 0) + { + Throw.ArgumentException(nameof(token), "Failed to create ChatClientAgentContinuationToken from provided token because it does not contain any data."); + } + + Utf8JsonReader reader = new(data.Span); + + // Move to the start object token. + _ = reader.Read(); + + // Validate that the token is of this type. + ValidateTokenType(reader, token); + + ResponseContinuationToken? innerToken = null; + IReadOnlyCollection? inputMessages = null; + IReadOnlyCollection? responseUpdates = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + continue; + } + switch (reader.GetString()) + { + case "innerToken": + _ = reader.Read(); + var innerTokenJson = reader.GetString() ?? throw new ArgumentException("No content for innerToken property.", nameof(token)); + innerToken = (ResponseContinuationToken?)JsonSerializer.Deserialize(innerTokenJson, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken))); + break; + case "inputMessages": + _ = reader.Read(); + var innerMessagesJson = reader.GetString() ?? throw new ArgumentException("No content for inputMessages property.", nameof(token)); + inputMessages = (IReadOnlyCollection?)JsonSerializer.Deserialize(innerMessagesJson, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IReadOnlyCollection))); + break; + case "responseUpdates": + _ = reader.Read(); + var responseUpdatesJson = reader.GetString() ?? throw new ArgumentException("No content for responseUpdates property.", nameof(token)); + responseUpdates = (IReadOnlyCollection?)JsonSerializer.Deserialize(responseUpdatesJson, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IReadOnlyCollection))); + break; + default: + break; + } + } + + if (innerToken is null) + { + Throw.ArgumentException(nameof(token), "Failed to create ChatClientAgentContinuationToken from provided token because it does not contain an inner token."); + } + + return new ChatClientAgentContinuationToken(innerToken) + { + InputMessages = inputMessages, + ResponseUpdates = responseUpdates + }; + } + + private static void ValidateTokenType(Utf8JsonReader reader, ResponseContinuationToken token) + { + try + { + // Move to the first property. + _ = reader.Read(); + + // If the first property name is not "type", or its value does not match this token type name, then we know its not this token type. + if (reader.GetString() != TypeDiscriminator || !reader.Read() || reader.GetString() != TokenTypeName) + { + Throw.ArgumentException(nameof(token), "Failed to create ChatClientAgentContinuationToken from provided token because it is not of the correct type."); + } + } + catch (JsonException ex) + { + Throw.ArgumentException(nameof(token), "Failed to create ChatClientAgentContinuationToken from provided token because it could not be parsed.", ex); + } + } + + /// + /// Gets a continuation token provided by the underlying . + /// + internal ResponseContinuationToken InnerToken { get; } + + /// + /// Gets or sets the input messages used for streaming run. + /// + internal IReadOnlyCollection? InputMessages { get; set; } + + /// + /// Gets or sets the response updates received so far. + /// + internal IReadOnlyCollection? ResponseUpdates { get; set; } +} diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs index 913be969c6..9a535cd645 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs @@ -154,7 +154,10 @@ async Task> GetResponseAsync(IChatClient chatClient, List CreateResponse(ChatResponse chatResponse) { - return new ChatClientAgentRunResponse(chatResponse); + return new ChatClientAgentRunResponse(chatResponse) + { + ContinuationToken = WrapContinuationToken(chatResponse.ContinuationToken) + }; } return this.RunCoreAsync(GetResponseAsync, CreateResponse, messages, thread, options, cancellationToken); diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs index 981f1e3933..41867a90ba 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs @@ -57,7 +57,6 @@ public void ConstructorWithChatResponseRoundtrips() RawRepresentation = new object(), ResponseId = "responseId", Usage = new UsageDetails(), - ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }), }; AgentRunResponse response = new(chatResponse); @@ -67,7 +66,6 @@ public void ConstructorWithChatResponseRoundtrips() Assert.Equal(chatResponse.ResponseId, response.ResponseId); Assert.Same(chatResponse, response.RawRepresentation as ChatResponse); Assert.Same(chatResponse.Usage, response.Usage); - Assert.Equivalent(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }), response.ContinuationToken); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateTests.cs index 32b7acd673..63bada4edb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateTests.cs @@ -42,7 +42,6 @@ public void ConstructorWithChatResponseUpdateRoundtrips() RawRepresentation = new object(), ResponseId = "responseId", Role = ChatRole.Assistant, - ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }), }; AgentRunResponseUpdate response = new(chatResponseUpdate); @@ -54,7 +53,6 @@ public void ConstructorWithChatResponseUpdateRoundtrips() Assert.Same(chatResponseUpdate, response.RawRepresentation as ChatResponseUpdate); Assert.Equal(chatResponseUpdate.ResponseId, response.ResponseId); Assert.Equal(chatResponseUpdate.Role, response.Role); - Assert.Same(chatResponseUpdate.ContinuationToken, response.ContinuationToken); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentContinuationTokenTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentContinuationTokenTests.cs new file mode 100644 index 0000000000..c8a90f25c5 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentContinuationTokenTests.cs @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.UnitTests.ChatClient; + +public class ChatClientAgentContinuationTokenTests +{ + [Fact] + public void ToBytes_Roundtrip() + { + // Arrange + ResponseContinuationToken originalToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3, 4, 5 }); + + ChatClientAgentContinuationToken chatClientToken = new(originalToken) + { + InputMessages = + [ + new ChatMessage(ChatRole.User, "Hello!"), + new ChatMessage(ChatRole.User, "How are you?") + ], + ResponseUpdates = + [ + new ChatResponseUpdate(ChatRole.Assistant, "I'm fine, thank you."), + new ChatResponseUpdate(ChatRole.Assistant, "How can I assist you today?") + ] + }; + + // Act + ReadOnlyMemory bytes = chatClientToken.ToBytes(); + + ChatClientAgentContinuationToken tokenFromBytes = ChatClientAgentContinuationToken.FromToken(ResponseContinuationToken.FromBytes(bytes)); + + // Assert + Assert.NotNull(tokenFromBytes); + Assert.Equal(chatClientToken.ToBytes().ToArray(), tokenFromBytes.ToBytes().ToArray()); + + // Verify InnerToken + Assert.Equal(chatClientToken.InnerToken.ToBytes().ToArray(), tokenFromBytes.InnerToken.ToBytes().ToArray()); + + // Verify InputMessages + Assert.NotNull(tokenFromBytes.InputMessages); + Assert.Equal(chatClientToken.InputMessages.Count, tokenFromBytes.InputMessages.Count); + for (int i = 0; i < chatClientToken.InputMessages.Count; i++) + { + Assert.Equal(chatClientToken.InputMessages.ElementAt(i).Role, tokenFromBytes.InputMessages.ElementAt(i).Role); + Assert.Equal(chatClientToken.InputMessages.ElementAt(i).Text, tokenFromBytes.InputMessages.ElementAt(i).Text); + } + + // Verify ResponseUpdates + Assert.NotNull(tokenFromBytes.ResponseUpdates); + Assert.Equal(chatClientToken.ResponseUpdates.Count, tokenFromBytes.ResponseUpdates.Count); + for (int i = 0; i < chatClientToken.ResponseUpdates.Count; i++) + { + Assert.Equal(chatClientToken.ResponseUpdates.ElementAt(i).Role, tokenFromBytes.ResponseUpdates.ElementAt(i).Role); + Assert.Equal(chatClientToken.ResponseUpdates.ElementAt(i).Text, tokenFromBytes.ResponseUpdates.ElementAt(i).Text); + } + } + + [Fact] + public void Serialization_Roundtrip() + { + // Arrange + ResponseContinuationToken originalToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3, 4, 5 }); + + ChatClientAgentContinuationToken chatClientToken = new(originalToken) + { + InputMessages = + [ + new ChatMessage(ChatRole.User, "Hello!"), + new ChatMessage(ChatRole.User, "How are you?") + ], + ResponseUpdates = + [ + new ChatResponseUpdate(ChatRole.Assistant, "I'm fine, thank you."), + new ChatResponseUpdate(ChatRole.Assistant, "How can I assist you today?") + ] + }; + + // Act + string json = JsonSerializer.Serialize(chatClientToken, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken))); + + ResponseContinuationToken? deserializedToken = (ResponseContinuationToken?)JsonSerializer.Deserialize(json, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken))); + + ChatClientAgentContinuationToken deserializedChatClientToken = ChatClientAgentContinuationToken.FromToken(deserializedToken!); + + // Assert + Assert.NotNull(deserializedChatClientToken); + Assert.Equal(chatClientToken.ToBytes().ToArray(), deserializedChatClientToken.ToBytes().ToArray()); + + // Verify InnerToken + Assert.Equal(chatClientToken.InnerToken.ToBytes().ToArray(), deserializedChatClientToken.InnerToken.ToBytes().ToArray()); + + // Verify InputMessages + Assert.NotNull(deserializedChatClientToken.InputMessages); + Assert.Equal(chatClientToken.InputMessages.Count, deserializedChatClientToken.InputMessages.Count); + for (int i = 0; i < chatClientToken.InputMessages.Count; i++) + { + Assert.Equal(chatClientToken.InputMessages.ElementAt(i).Role, deserializedChatClientToken.InputMessages.ElementAt(i).Role); + Assert.Equal(chatClientToken.InputMessages.ElementAt(i).Text, deserializedChatClientToken.InputMessages.ElementAt(i).Text); + } + + // Verify ResponseUpdates + Assert.NotNull(deserializedChatClientToken.ResponseUpdates); + Assert.Equal(chatClientToken.ResponseUpdates.Count, deserializedChatClientToken.ResponseUpdates.Count); + for (int i = 0; i < chatClientToken.ResponseUpdates.Count; i++) + { + Assert.Equal(chatClientToken.ResponseUpdates.ElementAt(i).Role, deserializedChatClientToken.ResponseUpdates.ElementAt(i).Role); + Assert.Equal(chatClientToken.ResponseUpdates.ElementAt(i).Text, deserializedChatClientToken.ResponseUpdates.ElementAt(i).Text); + } + } + + [Fact] + public void FromToken_WithChatClientAgentContinuationToken_ReturnsSameInstance() + { + // Arrange + ChatClientAgentContinuationToken originalToken = new(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3, 4, 5 })); + + // Act + ChatClientAgentContinuationToken fromToken = ChatClientAgentContinuationToken.FromToken(originalToken); + + // Assert + Assert.Same(originalToken, fromToken); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index 583a0815ca..cb14fcbe7a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -18,10 +18,10 @@ public class ChatClientAgent_BackgroundResponsesTests [Theory] [InlineData(true)] [InlineData(false)] - public async Task RunAsyncPropagatesBackgroundResponsesPropertiesToChatClientAsync(bool providePropsViaChatOptions) + public async Task RunAsync_PropagatesBackgroundResponsesPropertiesToChatClientAsync(bool providePropsViaChatOptions) { // Arrange - var continuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); + var continuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })); ChatOptions? capturedChatOptions = null; Mock mockChatClient = new(); mockChatClient @@ -63,15 +63,15 @@ public async Task RunAsyncPropagatesBackgroundResponsesPropertiesToChatClientAsy // Assert Assert.NotNull(capturedChatOptions); Assert.True(capturedChatOptions.AllowBackgroundResponses); - Assert.Same(continuationToken, capturedChatOptions.ContinuationToken); + Assert.Same(continuationToken.InnerToken, capturedChatOptions.ContinuationToken); } [Fact] - public async Task RunAsyncPrioritizesBackgroundResponsesPropertiesFromAgentRunOptionsOverOnesFromChatOptionsAsync() + public async Task RunAsync_WhenPropertiesSetInBothLocations_PrioritizesAgentRunOptionsOverChatOptionsAsync() { // Arrange - var continuationToken1 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); - var continuationToken2 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); + var continuationToken1 = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })); + var continuationToken2 = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })); ChatOptions? capturedChatOptions = null; Mock mockChatClient = new(); mockChatClient @@ -104,13 +104,13 @@ public async Task RunAsyncPrioritizesBackgroundResponsesPropertiesFromAgentRunOp // Assert Assert.NotNull(capturedChatOptions); Assert.False(capturedChatOptions.AllowBackgroundResponses); - Assert.Same(continuationToken2, capturedChatOptions.ContinuationToken); + Assert.Same(continuationToken2.InnerToken, capturedChatOptions.ContinuationToken); } [Theory] [InlineData(true)] [InlineData(false)] - public async Task RunStreamingAsyncPropagatesBackgroundResponsesPropertiesToChatClientAsync(bool providePropsViaChatOptions) + public async Task RunStreamingAsync_PropagatesBackgroundResponsesPropertiesToChatClientAsync(bool providePropsViaChatOptions) { // Arrange ChatResponseUpdate[] returnUpdates = @@ -119,7 +119,7 @@ public async Task RunStreamingAsyncPropagatesBackgroundResponsesPropertiesToChat new ChatResponseUpdate(role: ChatRole.Assistant, content: "at?") { ConversationId = "conversation-id" }, ]; - var continuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); + var continuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) { InputMessages = [new ChatMessage()] }; ChatOptions? capturedChatOptions = null; Mock mockChatClient = new(); mockChatClient @@ -164,11 +164,11 @@ public async Task RunStreamingAsyncPropagatesBackgroundResponsesPropertiesToChat Assert.NotNull(capturedChatOptions); Assert.True(capturedChatOptions.AllowBackgroundResponses); - Assert.Same(continuationToken, capturedChatOptions.ContinuationToken); + Assert.Same(continuationToken.InnerToken, capturedChatOptions.ContinuationToken); } [Fact] - public async Task RunStreamingAsyncPrioritizesBackgroundResponsesPropertiesFromAgentRunOptionsOverOnesFromChatOptionsAsync() + public async Task RunStreamingAsync_WhenPropertiesSetInBothLocations_PrioritizesAgentRunOptionsOverChatOptionsAsync() { // Arrange ChatResponseUpdate[] returnUpdates = @@ -176,8 +176,8 @@ public async Task RunStreamingAsyncPrioritizesBackgroundResponsesPropertiesFromA new ChatResponseUpdate(role: ChatRole.Assistant, content: "wh") { ConversationId = "conversation-id" }, ]; - var continuationToken1 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); - var continuationToken2 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); + var continuationToken1 = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) { InputMessages = [new ChatMessage()] }; + var continuationToken2 = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) { InputMessages = [new ChatMessage()] }; ChatOptions? capturedChatOptions = null; Mock mockChatClient = new(); mockChatClient @@ -212,11 +212,11 @@ public async Task RunStreamingAsyncPrioritizesBackgroundResponsesPropertiesFromA // Assert Assert.NotNull(capturedChatOptions); Assert.False(capturedChatOptions.AllowBackgroundResponses); - Assert.Same(continuationToken2, capturedChatOptions.ContinuationToken); + Assert.Same(continuationToken2.InnerToken, capturedChatOptions.ContinuationToken); } [Fact] - public async Task RunAsyncPropagatesContinuationTokenFromChatResponseToAgentRunResponseAsync() + public async Task RunAsync_WhenContinuationTokenReceivedFromChatResponse_WrapsContinuationTokenAsync() { // Arrange var continuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); @@ -237,11 +237,11 @@ public async Task RunAsyncPropagatesContinuationTokenFromChatResponseToAgentRunR var response = await agent.RunAsync([new(ChatRole.User, "hi")], thread, options: runOptions); // Assert - Assert.Same(continuationToken, response.ContinuationToken); + Assert.Same(continuationToken, (response.ContinuationToken as ChatClientAgentContinuationToken)?.InnerToken); } [Fact] - public async Task RunStreamingAsyncPropagatesContinuationTokensFromUpdatesAsync() + public async Task RunStreamingAsync_WhenContinuationTokenReceived_WrapsContinuationTokenAsync() { // Arrange var token1 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); @@ -272,19 +272,19 @@ public async Task RunStreamingAsyncPropagatesContinuationTokensFromUpdatesAsync( // Assert Assert.Equal(2, actualUpdates.Count); - Assert.Same(token1, actualUpdates[0].ContinuationToken); + Assert.Same(token1, (actualUpdates[0].ContinuationToken as ChatClientAgentContinuationToken)?.InnerToken); Assert.Null(actualUpdates[1].ContinuationToken); // last update has null token } [Fact] - public async Task RunAsyncThrowsWhenMessagesProvidedWithContinuationTokenAsync() + public async Task RunAsync_WhenMessagesProvidedWithContinuationToken_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); ChatClientAgent agent = new(mockChatClient.Object); - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() { ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) }; IEnumerable inputMessages = [new ChatMessage(ChatRole.User, "test message")]; @@ -301,14 +301,14 @@ public async Task RunAsyncThrowsWhenMessagesProvidedWithContinuationTokenAsync() } [Fact] - public async Task RunStreamingAsyncThrowsWhenMessagesProvidedWithContinuationTokenAsync() + public async Task RunStreamingAsync_WhenMessagesProvidedWithContinuationToken_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); ChatClientAgent agent = new(mockChatClient.Object); - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() { ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) }; IEnumerable inputMessages = [new ChatMessage(ChatRole.User, "test message")]; @@ -331,7 +331,7 @@ await Assert.ThrowsAsync(async () => } [Fact] - public async Task RunAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync() + public async Task RunAsync_WhenContinuationTokenProvided_SkipsThreadMessagePopulationAsync() { // Arrange List capturedMessages = []; @@ -371,7 +371,10 @@ public async Task RunAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync AIContextProvider = mockContextProvider.Object }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + }; // Act await agent.RunAsync([], thread, options: runOptions); @@ -393,7 +396,7 @@ public async Task RunAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync } [Fact] - public async Task RunStreamingAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync() + public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsThreadMessagePopulationAsync() { // Arrange List capturedMessages = []; @@ -433,14 +436,15 @@ public async Task RunStreamingAsyncSkipsThreadMessagePopulationWithContinuationT AIContextProvider = mockContextProvider.Object }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) { InputMessages = [new ChatMessage()] } + }; // Act - var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync(thread, options: runOptions).ToListAsync()); + await agent.RunStreamingAsync(thread, options: runOptions).ToListAsync(); // Assert - Assert.Equal("Streaming resumption is only supported when chat history is stored and managed by the underlying AI service.", exception.Message); - // With continuation token, thread message population should be skipped Assert.Empty(capturedMessages); @@ -456,7 +460,7 @@ public async Task RunStreamingAsyncSkipsThreadMessagePopulationWithContinuationT } [Fact] - public async Task RunAsyncThrowsWhenNoThreadProvideForBackgroundResponsesAsync() + public async Task RunAsync_WhenNoThreadProvidedForBackgroundResponses_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); @@ -480,7 +484,7 @@ public async Task RunAsyncThrowsWhenNoThreadProvideForBackgroundResponsesAsync() } [Fact] - public async Task RunStreamingAsyncThrowsWhenNoThreadProvideForBackgroundResponsesAsync() + public async Task RunStreamingAsync_WhenNoThreadProvidedForBackgroundResponses_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); @@ -510,7 +514,7 @@ await Assert.ThrowsAsync(async () => } [Fact] - public async Task RunAsyncThrowsWhenContinuationTokenProvidedForInitialRunAsync() + public async Task RunAsync_WhenContinuationTokenProvidedForInitialRun_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); @@ -520,7 +524,7 @@ public async Task RunAsyncThrowsWhenContinuationTokenProvidedForInitialRunAsync( // Create a new thread with no ConversationId and no MessageStore (initial run state) ChatClientAgentThread thread = new(); - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() { ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) }; // Act & Assert var exception = await Assert.ThrowsAsync(() => agent.RunAsync(thread: thread, options: runOptions)); @@ -536,17 +540,24 @@ public async Task RunAsyncThrowsWhenContinuationTokenProvidedForInitialRunAsync( } [Fact] - public async Task RunStreamingAsyncThrowsWhenContinuationTokenProvidedForInitialRunAsync() + public async Task RunStreamingAsync_WhenContinuationTokenProvidedForInitialRun_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); ChatClientAgent agent = new(mockChatClient.Object); - // Create a new thread with no ConversationId and no MessageStore (initial run state) - ChatClientAgentThread thread = new(); + // Create a new thread with no ConversationId and no MessageStore imitating initial run + ChatClientAgentThread thread = new() + { + ConversationId = null, + MessageStore = null + }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + }; // Act & Assert var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync(thread: thread, options: runOptions).ToListAsync()); @@ -562,74 +573,287 @@ public async Task RunStreamingAsyncThrowsWhenContinuationTokenProvidedForInitial } [Fact] - public async Task RunStreamingAsyncThrowsWhenContinuationTokenUsedWithClientSideManagedChatHistoryAsync() + public async Task RunStreamingAsync_WhenInputMessagesPresentInContinuationToken_ResumesStreamingAsync() { // Arrange + ChatResponseUpdate[] returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "continuation") { ConversationId = "conversation-id" }, + ]; + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(returnUpdates)); ChatClientAgent agent = new(mockChatClient.Object); - // Create a thread with a MessageStore - ChatClientAgentThread thread = new() + ChatClientAgentThread thread = new() { ConversationId = "conversation-id" }; + + AgentRunOptions runOptions = new() { - MessageStore = new InMemoryChatMessageStore(), // Setting a message store to skip checking the continuation token in the initial run - ConversationId = null, // No conversation ID to simulate client-side managed chat history + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + { + InputMessages = [new ChatMessage(ChatRole.User, "previous message")] + } }; - // Create run options with a continuation token - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + // Act + var updates = new List(); + await foreach (var update in agent.RunStreamingAsync(thread, options: runOptions)) + { + updates.Add(update); + } - // Act & Assert - var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync(thread: thread, options: runOptions).ToListAsync()); - Assert.Equal("Streaming resumption is only supported when chat history is stored and managed by the underlying AI service.", exception.Message); + // Assert + Assert.Single(updates); - // Verify that the IChatClient was never called due to early validation + // Verify that the IChatClient was called mockChatClient.Verify( c => c.GetStreamingResponseAsync( It.IsAny>(), It.IsAny(), It.IsAny()), - Times.Never); + Times.Once); } [Fact] - public async Task RunStreamingAsyncThrowsWhenContinuationTokenUsedWithAIContextProviderAsync() + public async Task RunStreamingAsync_WhenResponseUpdatesPresentInContinuationToken_ResumesStreamingAsync() { // Arrange + ChatResponseUpdate[] returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "continuation") { ConversationId = "conversation-id" }, + ]; + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(returnUpdates)); ChatClientAgent agent = new(mockChatClient.Object); - // Create a mock AIContextProvider + ChatClientAgentThread thread = new() { ConversationId = "conversation-id" }; + + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + { + ResponseUpdates = [new ChatResponseUpdate(ChatRole.Assistant, "previous update")] + } + }; + + // Act + var updates = new List(); + await foreach (var update in agent.RunStreamingAsync(thread, options: runOptions)) + { + updates.Add(update); + } + + // Assert + Assert.Single(updates); + + // Verify that the IChatClient was called + mockChatClient.Verify( + c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny()), + Times.Once); + } + + [Fact] + public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitialRunForContextProviderAndMessageStoreAsync() + { + // Arrange + ChatResponseUpdate[] returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "upon"), + new ChatResponseUpdate(role: ChatRole.Assistant, content: " a"), + new ChatResponseUpdate(role: ChatRole.Assistant, content: " time"), + ]; + + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(returnUpdates)); + + ChatClientAgent agent = new(mockChatClient.Object); + + List capturedMessagesAddedToStore = []; + var mockMessageStore = new Mock(); + mockMessageStore + .Setup(ms => ms.AddMessagesAsync(It.IsAny>(), It.IsAny())) + .Callback, CancellationToken>((messages, ct) => capturedMessagesAddedToStore.AddRange(messages)) + .Returns(Task.CompletedTask); + + AIContextProvider.InvokedContext? capturedInvokedContext = null; var mockContextProvider = new Mock(); mockContextProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) - .ReturnsAsync(new AIContext()); + .Setup(cp => cp.InvokedAsync(It.IsAny(), It.IsAny())) + .Callback((context, ct) => capturedInvokedContext = context) + .Returns(new ValueTask()); + + ChatClientAgentThread thread = new() + { + MessageStore = mockMessageStore.Object, + AIContextProvider = mockContextProvider.Object + }; + + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + { + ResponseUpdates = [new ChatResponseUpdate(ChatRole.Assistant, "once ")] + } + }; + + // Act + await agent.RunStreamingAsync(thread, options: runOptions).ToListAsync(); + + // Assert + mockMessageStore.Verify(ms => ms.AddMessagesAsync(It.IsAny>(), It.IsAny()), Times.Once); + Assert.Single(capturedMessagesAddedToStore); + Assert.Contains("once upon a time", capturedMessagesAddedToStore[0].Text); + + mockContextProvider.Verify(cp => cp.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + Assert.NotNull(capturedInvokedContext?.ResponseMessages); + Assert.Single(capturedInvokedContext.ResponseMessages); + Assert.Contains("once upon a time", capturedInvokedContext.ResponseMessages.ElementAt(0).Text); + } + + [Fact] + public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromInitialRunForContextProviderAndMessageStoreAsync() + { + // Arrange + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(Array.Empty())); + + ChatClientAgent agent = new(mockChatClient.Object); + + List capturedMessagesAddedToStore = []; + var mockMessageStore = new Mock(); + mockMessageStore + .Setup(ms => ms.AddMessagesAsync(It.IsAny>(), It.IsAny())) + .Callback, CancellationToken>((messages, ct) => capturedMessagesAddedToStore.AddRange(messages)) + .Returns(Task.CompletedTask); + + AIContextProvider.InvokedContext? capturedInvokedContext = null; + var mockContextProvider = new Mock(); mockContextProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Setup(cp => cp.InvokedAsync(It.IsAny(), It.IsAny())) + .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); - // Create a thread with an AIContextProvider and conversation ID to simulate non-initial run ChatClientAgentThread thread = new() { - ConversationId = "existing-conversation-id", + MessageStore = mockMessageStore.Object, AIContextProvider = mockContextProvider.Object }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + { + InputMessages = [new ChatMessage(ChatRole.User, "Tell me a story")], + } + }; - // Act & Assert - var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync(thread: thread, options: runOptions).ToListAsync()); + // Act + await agent.RunStreamingAsync(thread, options: runOptions).ToListAsync(); - Assert.Equal("Using context provider with streaming resumption is not supported.", exception.Message); + // Assert + mockMessageStore.Verify(ms => ms.AddMessagesAsync(It.IsAny>(), It.IsAny()), Times.Once); + Assert.Single(capturedMessagesAddedToStore); + Assert.Contains("Tell me a story", capturedMessagesAddedToStore[0].Text); + + mockContextProvider.Verify(cp => cp.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + Assert.NotNull(capturedInvokedContext?.RequestMessages); + Assert.Single(capturedInvokedContext.RequestMessages); + Assert.Contains("Tell me a story", capturedInvokedContext.RequestMessages.ElementAt(0).Text); + } - // Verify that the IChatClient was never called due to early validation - mockChatClient.Verify( - c => c.GetStreamingResponseAsync( + [Fact] + public async Task RunStreamingAsync_WhenResumingStreaming_SavesInputMessagesAndUpdatesInContinuationTokenAsync() + { + // Arrange + List returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "Once") { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }, + new ChatResponseUpdate(role: ChatRole.Assistant, content: " upon") { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }, + new ChatResponseUpdate(role: ChatRole.Assistant, content: " a") { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }, + new ChatResponseUpdate(role: ChatRole.Assistant, content: " time"){ ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }, + ]; + + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( It.IsAny>(), It.IsAny(), - It.IsAny()), - Times.Never); + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(returnUpdates)); + + ChatClientAgent agent = new(mockChatClient.Object); + + ChatClientAgentThread thread = new() { }; + + List capturedContinuationTokens = []; + + ChatMessage userMessage = new(ChatRole.User, "Tell me a story"); + + // Act + + // Do the initial run + await foreach (var update in agent.RunStreamingAsync(userMessage, thread)) + { + capturedContinuationTokens.Add(Assert.IsType(update.ContinuationToken)); + break; + } + + // Now resume the run using the captured continuation token + returnUpdates.RemoveAt(0); // remove the first mock update as it was already processed + var options = new AgentRunOptions { ContinuationToken = capturedContinuationTokens[0] }; + await foreach (var update in agent.RunStreamingAsync(thread, options: options)) + { + capturedContinuationTokens.Add(Assert.IsType(update.ContinuationToken)); + } + + // Assert + Assert.Equal(4, capturedContinuationTokens.Count); + + // Verify that the first continuation token has the initial input and first update + Assert.NotNull(capturedContinuationTokens[0].InputMessages); + Assert.Single(capturedContinuationTokens[0].InputMessages!); + Assert.Equal("Tell me a story", capturedContinuationTokens[0].InputMessages!.Last().Text); + Assert.NotNull(capturedContinuationTokens[0].ResponseUpdates); + Assert.Single(capturedContinuationTokens[0].ResponseUpdates!); + Assert.Equal("Once", capturedContinuationTokens[0].ResponseUpdates!.Last().Text); + + // Verify the last continuation token has the input and all updates + var lastToken = capturedContinuationTokens[^1]; + Assert.NotNull(lastToken.InputMessages); + Assert.Single(lastToken.InputMessages!); + Assert.Equal("Tell me a story", lastToken.InputMessages!.Last().Text); + Assert.NotNull(lastToken.ResponseUpdates); + Assert.Equal(4, lastToken.ResponseUpdates!.Count); + Assert.Equal("Once", lastToken.ResponseUpdates!.ElementAt(0).Text); + Assert.Equal(" upon", lastToken.ResponseUpdates!.ElementAt(1).Text); + Assert.Equal(" a", lastToken.ResponseUpdates!.ElementAt(2).Text); + Assert.Equal(" time", lastToken.ResponseUpdates!.ElementAt(3).Text); } private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values)