diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/IAgent.cs b/dotnet/src/Microsoft.AutoGen/Contracts/IAgent.cs index 7595594d345c..81bdd3f327c3 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/IAgent.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/IAgent.cs @@ -6,7 +6,7 @@ namespace Microsoft.AutoGen.Contracts; /// /// Represents an agent within the runtime that can process messages, maintain state, and be closed when no longer needed. /// -public interface IAgent : ISaveState +public interface IAgent : ISaveState { /// /// Gets the unique identifier of the agent. diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs index 59bc3180fb70..1ca767b0f827 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs @@ -8,7 +8,7 @@ namespace Microsoft.AutoGen.Contracts; /// /// Defines the runtime environment for agents, managing message sending, subscriptions, agent resolution, and state persistence. /// -public interface IAgentRuntime : ISaveState +public interface IAgentRuntime : ISaveState { /// /// Sends a message to an agent and gets a response. diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs b/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs index 3a4c556063f8..b6fcdc67bda8 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs @@ -9,8 +9,7 @@ namespace Microsoft.AutoGen.Contracts; /// Defines a contract for saving and loading the state of an object. /// The state must be JSON serializable. /// -/// The type of the object implementing this interface. -public interface ISaveState +public interface ISaveState { /// /// Saves the current state of the object. @@ -20,7 +19,10 @@ public interface ISaveState /// containing the saved state. The structure of the state is implementation-defined /// but must be JSON serializable. /// - public ValueTask SaveStateAsync(); + public virtual ValueTask SaveStateAsync() + { + return new ValueTask(JsonDocument.Parse("{}").RootElement); + } /// /// Loads a previously saved state into the object. @@ -30,6 +32,9 @@ public interface ISaveState /// is implementation-defined but must be JSON serializable. /// /// A task representing the asynchronous operation. - public ValueTask LoadStateAsync(JsonElement state); + public virtual ValueTask LoadStateAsync(JsonElement state) + { + return ValueTask.CompletedTask; + } } diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/ISaveStateMixin.cs b/dotnet/src/Microsoft.AutoGen/Contracts/ISaveStateMixin.cs new file mode 100644 index 000000000000..f5828a6c6868 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Contracts/ISaveStateMixin.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ISaveStateMixin.cs + +using System.Text.Json; + +namespace Microsoft.AutoGen.Contracts; + +/// +/// Defines a contract for saving and loading the state of an object. +/// The state must be JSON serializable. +/// +/// The type of the object implementing this interface. +/// +public interface ISaveStateMixin : ISaveState +{ + /// + /// Saves the current state of the object. + /// + /// + /// A task representing the asynchronous operation, returning a dictionary + /// containing the saved state. The structure of the state is implementation-defined + /// but must be JSON serializable. + /// + async ValueTask ISaveState.SaveStateAsync() + { + var state = await SaveStateImpl(); + return JsonSerializer.SerializeToElement(state); + } + + /// + /// Loads a previously saved state into the object. + /// + /// + /// A dictionary representing the saved state. The structure of the state + /// is implementation-defined but must be JSON serializable. + /// + /// A task representing the asynchronous operation. + ValueTask ISaveState.LoadStateAsync(JsonElement state) + { + // Throw if failed to deserialize + var stateObject = JsonSerializer.Deserialize(state) ?? throw new InvalidDataException(); + return LoadStateImpl(stateObject); + } + + protected ValueTask SaveStateImpl(); + + protected ValueTask LoadStateImpl(T state); +} diff --git a/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs b/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs index 457c2cc6adf3..221ecfde9c91 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs @@ -4,7 +4,6 @@ using System.Diagnostics; using System.Reflection; -using System.Text.Json; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Logging; @@ -93,15 +92,6 @@ private Dictionary ReflectInvokers() return null; } - public virtual ValueTask SaveStateAsync() - { - return ValueTask.FromResult(JsonDocument.Parse("{}").RootElement); - } - public virtual ValueTask LoadStateAsync(JsonElement state) - { - return ValueTask.CompletedTask; - } - public ValueTask SendMessageAsync(object message, AgentId recepient, string? messageId = null, CancellationToken cancellationToken = default) { return this.Runtime.SendMessageAsync(message, recepient, sender: this.Id, messageId: messageId, cancellationToken: cancellationToken); diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs index 805fbc87102b..4952c4e7f56c 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentTests.cs +using System.Text.Json; using FluentAssertions; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Logging; @@ -146,4 +147,61 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => Assert.True(agent.ReceivedItems.Count == 1); } + + public class AgentState + { + public required string Name { get; set; } + public required int Value { get; set; } + } + + public class StateAgent(AgentId id, + IAgentRuntime runtime, + AgentState state, + Logger? logger = null) : BaseAgent(id, runtime, "Test Agent", logger), + ISaveStateMixin + + { + ValueTask ISaveStateMixin.SaveStateImpl() + { + return ValueTask.FromResult(_state); + } + + ValueTask ISaveStateMixin.LoadStateImpl(AgentState state) + { + _state = state; + return ValueTask.CompletedTask; + } + + private AgentState _state = state; + } + + [Fact] + public async Task StateMixinTest() + { + var runtime = new InProcessRuntime(); + await runtime.StartAsync(); + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + return ValueTask.FromResult(new StateAgent(id, runtime, new AgentState { Name = "TestAgent", Value = 5 })); + }); + + var agentId = new AgentId("MyAgent", "default"); + + // Get the state + var state1 = await runtime.SaveAgentStateAsync(agentId); + + Assert.Equal("TestAgent", state1.GetProperty("Name").GetString()); + Assert.Equal(5, state1.GetProperty("Value").GetInt32()); + + // Change the state + var newState = new AgentState { Name = "TestAgent", Value = 100 }; + var jsonState = JsonSerializer.SerializeToElement(newState); + await runtime.LoadAgentStateAsync(agentId, jsonState); + + // Get the state + var state2 = await runtime.SaveAgentStateAsync(agentId); + + Assert.Equal("TestAgent", state2.GetProperty("Name").GetString()); + Assert.Equal(100, state2.GetProperty("Value").GetInt32()); + } } diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs index 6c195194617e..712c7c56bea9 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs @@ -75,7 +75,7 @@ public SubscribedAgent(AgentId id, } [TypeSubscription("TestTopic")] -public class SubscribedSaveLoadAgent : TestAgent +public class SubscribedSaveLoadAgent : TestAgent, ISaveState { public SubscribedSaveLoadAgent(AgentId id, IAgentRuntime runtime, @@ -83,13 +83,13 @@ public SubscribedSaveLoadAgent(AgentId id, { } - public override ValueTask SaveStateAsync() + ValueTask ISaveState.SaveStateAsync() { var jsonDoc = JsonSerializer.SerializeToElement(_receivedMessages); return ValueTask.FromResult(jsonDoc); } - public override ValueTask LoadStateAsync(JsonElement state) + ValueTask ISaveState.LoadStateAsync(JsonElement state) { _receivedMessages = JsonSerializer.Deserialize>(state) ?? throw new InvalidOperationException("Failed to deserialize state"); return ValueTask.CompletedTask;