diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index f3a964b4c..2d6bf1785 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -49,7 +49,7 @@ public static async Task Run() else { var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); session = new ChatSession(executor, chatHistory); } @@ -105,7 +105,7 @@ in session.RegenerateAssistantMessageAsync( await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index b46c92e49..93f5c394d 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -21,7 +21,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( @@ -46,7 +46,7 @@ public static async Task Run() await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 31b6a7718..4e96acdef 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -31,7 +31,7 @@ public static async Task Run() else { var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); session = new ChatSession(executor, chatHistory); } @@ -92,7 +92,7 @@ in session.RegenerateAssistantMessageAsync( await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 923f78f67..c8bac7f90 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -19,8 +19,8 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); - ChatSession prototypeSession = + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); + ChatSession prototypeSession = await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory); prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( new string[] { "User:", "Assistant:" }, @@ -50,10 +50,10 @@ public static async Task Run() while (userInput != "exit") { // Load the session state from the reset state - if(userInput == "reset") + if (userInput == "reset") { session.LoadSession(resetState); - Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}"); + Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.SessionChatHistory)}"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session reset."); } @@ -75,10 +75,10 @@ public static async Task Run() Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Provide assistant input: "); - + Console.ForegroundColor = ConsoleColor.Green; string assistantInputOverride = Console.ReadLine() ?? ""; - + await session.AddAndProcessUserMessage(userInputOverride); await session.AddAndProcessAssistantMessage(assistantInputOverride); @@ -90,7 +90,7 @@ public static async Task Run() await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index de3314130..080e9544e 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -19,7 +19,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); @@ -41,7 +41,7 @@ public static async Task Run() await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/LoadAndSaveSession.cs b/LLama.Examples/Examples/LoadAndSaveSession.cs index fded50e03..50530011c 100644 --- a/LLama.Examples/Examples/LoadAndSaveSession.cs +++ b/LLama.Examples/Examples/LoadAndSaveSession.cs @@ -33,7 +33,7 @@ public static async Task Run() await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, prompt), + new Message(AuthorRole.User, prompt), new InferenceParams() { Temperature = 0.6f, diff --git a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs index f1a0ebcb6..4cb803b81 100644 --- a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs +++ b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs @@ -10,7 +10,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion; public class HistoryTransform : DefaultHistoryTransform { /// - public override string HistoryToText(global::LLama.Common.ChatHistory history) + public string HistoryToText(global::LLama.Common.ChatHistory history) { return base.HistoryToText(history) + $"{AuthorRole.Assistant}: "; } diff --git a/LLama.WebAPI/Controllers/ChatController.cs b/LLama.WebAPI/Controllers/ChatController.cs index 9643ccf80..012a6137e 100644 --- a/LLama.WebAPI/Controllers/ChatController.cs +++ b/LLama.WebAPI/Controllers/ChatController.cs @@ -43,7 +43,7 @@ public async Task SendHistory([FromBody] HistoryInput input, [FromServic { var history = new ChatHistory(); - var messages = input.Messages.Select(m => new ChatHistory.Message(Enum.Parse(m.Role), m.Content)); + var messages = input.Messages.Select(m => new Message(Enum.Parse(m.Role), m.Content)); history.Messages.AddRange(messages); diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index ae2401c90..0691ce18e 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -28,7 +28,7 @@ public StatefulChatService(IConfiguration configuration, ILogger Send(SendMessageInput input) } _logger.LogInformation("Input: {text}", input.Text); var outputs = _session.ChatAsync( - new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text), + new Common.Message(Common.AuthorRole.User, input.Text), new Common.InferenceParams() { RepeatPenalty = 1.0f, @@ -74,7 +74,7 @@ public async IAsyncEnumerable SendStream(SendMessageInput input) _logger.LogInformation(input.Text); var outputs = _session.ChatAsync( - new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text!) + new Common.Message(Common.AuthorRole.User, input.Text!) , new Common.InferenceParams() { RepeatPenalty = 1.0f, diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs index 3520c29b0..799ac21d5 100644 --- a/LLama.WebAPI/Services/StatelessChatService.cs +++ b/LLama.WebAPI/Services/StatelessChatService.cs @@ -48,10 +48,9 @@ public async Task SendAsync(ChatHistory history) } public class HistoryTransform : DefaultHistoryTransform { - public override string HistoryToText(ChatHistory history) + public override string HistoryToText(IChatHistory history) { return base.HistoryToText(history) + "\n Assistant:"; } - } } diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index 9644b3e1d..a5ff28cc9 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -1,4 +1,5 @@ using LLama.Common; +using System; using System.Text.Json.Serialization; namespace LLama.Abstractions @@ -14,15 +15,16 @@ public interface IHistoryTransform /// /// The ChatHistory instance /// - string HistoryToText(ChatHistory history); - + string HistoryToText(IChatHistory history); + /// /// Converts plain text to a ChatHistory instance. /// /// The role for the author. /// The chat history as plain text. + /// The type of the chat history. /// The updated history. - ChatHistory TextToHistory(AuthorRole role, string text); + IChatHistory TextToHistory(AuthorRole role, string text, Type type); /// /// Copy the transform. diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 0a5accc5e..db18cbc04 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -52,7 +52,7 @@ public class ChatSession /// /// The chat history for this session. /// - public ChatHistory History { get; private set; } = new(); + public IChatHistory SessionChatHistory { get; private set; } = new ChatHistory(); /// /// The history transform used in this session. @@ -73,17 +73,17 @@ public class ChatSession /// Create a new chat session and preprocess history. /// /// The executor for this session - /// History for this session + /// History for this session /// public static async Task InitializeSessionFromHistoryAsync( - ILLamaExecutor executor, ChatHistory history) + ILLamaExecutor executor, IChatHistory chatHistory) { if (executor is not StatefulExecutorBase statefulExecutor) { throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor)); } - var session = new ChatSession(executor, history); - await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history)); + var session = new ChatSession(executor, chatHistory); + await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(chatHistory)); return session; } @@ -107,10 +107,10 @@ public ChatSession(ILLamaExecutor executor) /// /// /// - public ChatSession(ILLamaExecutor executor, ChatHistory history) + public ChatSession(ILLamaExecutor executor, IChatHistory history) : this(executor) { - History = history; + SessionChatHistory = history; } /// @@ -165,10 +165,10 @@ public SessionState GetSessionState() { var executorState = ((StatefulExecutorBase)Executor).GetStateData(); return new SessionState( - executorState.PastTokensCount > 0 + executorState.PastTokensCount > 0 ? Executor.Context.GetState() : null, executorState, - History, + SessionChatHistory, InputTransformPipeline, OutputTransform, HistoryTransform); @@ -198,7 +198,7 @@ public void LoadSession(SessionState state, bool loadTransforms = true) { Executor.Context.LoadState(state.ContextState); } - History = new ChatHistory(state.History); + SessionChatHistory = state.SessionChatHistory; if (loadTransforms) { InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); @@ -216,12 +216,12 @@ public void LoadSession(SessionState state, bool loadTransforms = true) /// public void LoadSession(string path, bool loadTransforms = true) { - var state = SessionState.Load(path); + var state = SessionState.Load(path, this.SessionChatHistory.GetType()); // Handle non-polymorphic serialization of executor state if (state.ExecutorState is null) { var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); - ((StatefulExecutorBase) Executor).LoadState(filename: executorPath); + ((StatefulExecutorBase)Executor).LoadState(filename: executorPath); } LoadSession(state, loadTransforms); } @@ -231,10 +231,10 @@ public void LoadSession(string path, bool loadTransforms = true) /// /// /// - public ChatSession AddMessage(ChatHistory.Message message) + public ChatSession AddMessage(Message message) { // If current message is a system message, only allow the history to be empty - if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0) + if (message.AuthorRole == AuthorRole.System && SessionChatHistory.Messages.Count > 0) { throw new ArgumentException("Cannot add a system message after another message", nameof(message)); } @@ -243,7 +243,7 @@ public ChatSession AddMessage(ChatHistory.Message message) // or the previous message to be a system message or assistant message. if (message.AuthorRole == AuthorRole.User) { - ChatHistory.Message? lastMessage = History.Messages.LastOrDefault(); + Message? lastMessage = SessionChatHistory.Messages.LastOrDefault(); if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User) { throw new ArgumentException("Cannot add a user message after another user message", nameof(message)); @@ -254,7 +254,7 @@ public ChatSession AddMessage(ChatHistory.Message message) // the previous message must be a user message. if (message.AuthorRole == AuthorRole.Assistant) { - ChatHistory.Message? lastMessage = History.Messages.LastOrDefault(); + Message? lastMessage = SessionChatHistory.Messages.LastOrDefault(); if (lastMessage is null || lastMessage.AuthorRole != AuthorRole.User) { @@ -262,7 +262,7 @@ public ChatSession AddMessage(ChatHistory.Message message) } } - History.AddMessage(message.AuthorRole, message.Content); + SessionChatHistory.AddMessage(message.AuthorRole, message.Content); return this; } @@ -272,7 +272,7 @@ public ChatSession AddMessage(ChatHistory.Message message) /// /// public ChatSession AddSystemMessage(string content) - => AddMessage(new ChatHistory.Message(AuthorRole.System, content)); + => AddMessage(new Message(AuthorRole.System, content)); /// /// Add an assistant message to the chat history. @@ -280,7 +280,7 @@ public ChatSession AddSystemMessage(string content) /// /// public ChatSession AddAssistantMessage(string content) - => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + => AddMessage(new Message(AuthorRole.Assistant, content)); /// /// Add a user message to the chat history. @@ -288,7 +288,7 @@ public ChatSession AddAssistantMessage(string content) /// /// public ChatSession AddUserMessage(string content) - => AddMessage(new ChatHistory.Message(AuthorRole.User, content)); + => AddMessage(new Message(AuthorRole.User, content)); /// /// Remove the last message from the chat history. @@ -296,7 +296,7 @@ public ChatSession AddUserMessage(string content) /// public ChatSession RemoveLastMessage() { - History.Messages.RemoveAt(History.Messages.Count - 1); + SessionChatHistory.Messages.RemoveAt(SessionChatHistory.Messages.Count - 1); return this; } @@ -305,7 +305,7 @@ public ChatSession RemoveLastMessage() /// /// /// - public async Task AddAndProcessMessage(ChatHistory.Message message) + public async Task AddAndProcessMessage(Message message) { if (Executor is not StatefulExecutorBase statefulExecutor) { @@ -329,19 +329,19 @@ public async Task AddAndProcessMessage(ChatHistory.Message message) /// Compute KV cache for the system message and add it to the chat history. /// public Task AddAndProcessSystemMessage(string content) - => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content)); + => AddAndProcessMessage(new Message(AuthorRole.System, content)); /// /// Compute KV cache for the user message and add it to the chat history. /// public Task AddAndProcessUserMessage(string content) - => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content)); + => AddAndProcessMessage(new Message(AuthorRole.User, content)); /// /// Compute KV cache for the assistant message and add it to the chat history. /// public Task AddAndProcessAssistantMessage(string content) - => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + => AddAndProcessMessage(new Message(AuthorRole.Assistant, content)); /// /// Replace a user message with a new message and remove all messages after the new message. @@ -351,8 +351,8 @@ public Task AddAndProcessAssistantMessage(string content) /// /// public ChatSession ReplaceUserMessage( - ChatHistory.Message oldMessage, - ChatHistory.Message newMessage) + Message oldMessage, + Message newMessage) { if (oldMessage.AuthorRole != AuthorRole.User) { @@ -364,16 +364,16 @@ public ChatSession ReplaceUserMessage( throw new ArgumentException("New message must be a user message", nameof(newMessage)); } - int index = History.Messages.IndexOf(oldMessage); + int index = SessionChatHistory.Messages.IndexOf(oldMessage); if (index == -1) { throw new ArgumentException("Old message does not exist in history", nameof(oldMessage)); } - History.Messages[index] = newMessage; + SessionChatHistory.Messages[index] = newMessage; // Remove all message after the new message - History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1); + SessionChatHistory.Messages.RemoveRange(index + 1, SessionChatHistory.Messages.Count - index - 1); return this; } @@ -388,7 +388,7 @@ public ChatSession ReplaceUserMessage( /// /// public async IAsyncEnumerable ChatAsync( - ChatHistory.Message message, + Message message, bool applyInputTransformPipeline, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) @@ -424,14 +424,14 @@ public async IAsyncEnumerable ChatAsync( // If the session history was added as part of new chat session history, // convert the complete history includsing system message and manually added history // to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation. - prompt = HistoryTransform.HistoryToText(History); + prompt = HistoryTransform.HistoryToText(SessionChatHistory); } else { // If the session was restored from a previous session, // convert only the current message to the prompt with the prompt template // specified in the HistoryTransform class implementation that is provided. - ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content); + IChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content, SessionChatHistory.GetType()); prompt = HistoryTransform.HistoryToText(singleMessageHistory); } @@ -460,7 +460,7 @@ in ChatAsyncInternal( /// /// public IAsyncEnumerable ChatAsync( - ChatHistory.Message message, + Message message, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { @@ -481,16 +481,16 @@ public IAsyncEnumerable ChatAsync( /// /// public IAsyncEnumerable ChatAsync( - ChatHistory history, + IChatHistory history, bool applyInputTransformPipeline, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { - ChatHistory.Message lastMessage = history.Messages.LastOrDefault() + Message lastMessage = history.Messages.LastOrDefault() ?? throw new ArgumentException("History must contain at least one message", nameof(history)); foreach ( - ChatHistory.Message message + Message message in history.Messages.Take(history.Messages.Count - 1)) { // Apply input transform pipeline @@ -523,7 +523,7 @@ var inputTransform /// /// public IAsyncEnumerable ChatAsync( - ChatHistory history, + IChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { @@ -546,7 +546,7 @@ public async IAsyncEnumerable RegenerateAssistantMessageAsync( [EnumeratorCancellation] CancellationToken cancellationToken = default) { // Make sure the last message is an assistant message (reponse from the LLM). - ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault(); + Message? lastAssistantMessage = SessionChatHistory.Messages.LastOrDefault(); if (lastAssistantMessage is null || lastAssistantMessage.AuthorRole != AuthorRole.Assistant) @@ -558,7 +558,7 @@ public async IAsyncEnumerable RegenerateAssistantMessageAsync( RemoveLastMessage(); // Get the last user message. - ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault(); + Message? lastUserMessage = SessionChatHistory.Messages.LastOrDefault(); if (lastUserMessage is null || lastUserMessage.AuthorRole != AuthorRole.User) @@ -629,29 +629,29 @@ public record SessionState /// The history transform used in this session. /// public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); - + /// /// The the chat history messages for this session. /// - public ChatHistory.Message[] History { get; set; } = Array.Empty(); + public IChatHistory SessionChatHistory; /// /// Create a new session state. /// /// /// - /// + /// /// /// /// public SessionState( - State? contextState, ExecutorBaseState executorState, - ChatHistory history, List inputTransformPipeline, + State? contextState, ExecutorBaseState executorState, + IChatHistory chatHistory, List inputTransformPipeline, ITextStreamTransform outputTransform, IHistoryTransform historyTransform) { ContextState = contextState; ExecutorState = executorState; - History = history.Messages.ToArray(); + SessionChatHistory = chatHistory; InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray(); OutputTransform = outputTransform.Clone(); HistoryTransform = historyTransform.Clone(); @@ -686,7 +686,7 @@ public void Save(string path) File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); - File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson()); + File.WriteAllText(historyFilepath, ChatHistorySerializer.ToJson(new ChatHistory(SessionChatHistory.Messages.ToArray()))); string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline)); @@ -704,7 +704,7 @@ public void Save(string path) /// /// /// Throws when session state is incorrect - public static SessionState Load(string path) + public static SessionState Load(string path, Type type) { if (string.IsNullOrWhiteSpace(path)) { @@ -717,7 +717,7 @@ public static SessionState Load(string path) } string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); - var contextState = File.Exists(modelStateFilePath) ? + var contextState = File.Exists(modelStateFilePath) ? State.FromByteArray(File.ReadAllBytes(modelStateFilePath)) : null; @@ -726,14 +726,14 @@ public static SessionState Load(string path) string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); string historyJson = File.ReadAllText(historyFilepath); - var history = ChatHistory.FromJson(historyJson) + var history = ChatHistorySerializer.FromJson(historyJson, type) ?? throw new ArgumentException("History file is invalid", nameof(path)); string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); ITextTransform[] inputTransforms; try { - inputTransforms = File.Exists(inputTransformFilepath) ? + inputTransforms = File.Exists(inputTransformFilepath) ? (JsonSerializer.Deserialize(File.ReadAllText(inputTransformFilepath)) ?? throw new ArgumentException("Input transform file is invalid", nameof(path))) : Array.Empty(); @@ -744,11 +744,10 @@ public static SessionState Load(string path) } string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); - ITextStreamTransform outputTransform; try { - outputTransform = File.Exists(outputTransformFilepath) ? + outputTransform = File.Exists(outputTransformFilepath) ? (JsonSerializer.Deserialize(File.ReadAllText(outputTransformFilepath)) ?? throw new ArgumentException("Output transform file is invalid", nameof(path))) : new LLamaTransforms.EmptyTextOutputStreamTransform(); @@ -762,7 +761,7 @@ public static SessionState Load(string path) IHistoryTransform historyTransform; try { - historyTransform = File.Exists(historyTransformFilepath) ? + historyTransform = File.Exists(historyTransformFilepath) ? (JsonSerializer.Deserialize(File.ReadAllText(historyTransformFilepath)) ?? throw new ArgumentException("History transform file is invalid", nameof(path))) : new LLamaTransforms.DefaultHistoryTransform(); diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index c22cc7c06..052d1b538 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; @@ -33,41 +34,60 @@ public enum AuthorRole // copy from semantic-kernel /// - /// The chat history class + /// The message class /// - public class ChatHistory + public class Message { - private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; + /// + /// Role of the message author, e.g. user/assistant/system + /// + [JsonConverter(typeof(JsonStringEnumConverter))] + [JsonPropertyName("author_role")] + public AuthorRole AuthorRole { get; set; } /// - /// Chat message representation + /// Message content /// - public class Message + [JsonPropertyName("content")] + public string Content { get; set; } + + /// + /// Create a new instance + /// + /// Role of message author + /// Message content + public Message(AuthorRole authorRole, string content) { - /// - /// Role of the message author, e.g. user/assistant/system - /// - [JsonConverter(typeof(JsonStringEnumConverter))] - [JsonPropertyName("author_role")] - public AuthorRole AuthorRole { get; set; } - - /// - /// Message content - /// - [JsonPropertyName("content")] - public string Content { get; set; } - - /// - /// Create a new instance - /// - /// Role of message author - /// Message content - public Message(AuthorRole authorRole, string content) - { - this.AuthorRole = authorRole; - this.Content = content; - } + this.AuthorRole = authorRole; + this.Content = content; } + } + + /// + /// Interface for chat history + /// + public interface IChatHistory + { + /// + /// List of messages in the chat + /// + List Messages { get; set; } + + /// + /// Add a message to the chat history + /// + /// Role of the message author + /// Message content + void AddMessage(AuthorRole authorRole, string content); + } + + // copy from semantic-kernel + /// + /// The chat history class + /// + public class ChatHistory : IChatHistory + { + private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; /// /// List of messages in the chat @@ -99,24 +119,33 @@ public void AddMessage(AuthorRole authorRole, string content) { this.Messages.Add(new Message(authorRole, content)); } + } + + /// + /// Serializer for chat history + /// + public class ChatHistorySerializer + { + private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; /// /// Serialize the chat history to JSON /// /// - public string ToJson() + public static string ToJson(IChatHistory chatHistory) { - return JsonSerializer.Serialize(this, _jsonOptions); + return JsonSerializer.Serialize(chatHistory, _jsonOptions); } /// /// Deserialize a chat history from JSON /// /// + /// /// - public static ChatHistory? FromJson(string json) + public static IChatHistory? FromJson(string json, Type type) { - return JsonSerializer.Deserialize(json); + return JsonSerializer.Deserialize(json, type) as IChatHistory; } } } diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs index d74d9ddaf..6ed86e777 100644 --- a/LLama/LLamaTransforms.cs +++ b/LLama/LLamaTransforms.cs @@ -1,5 +1,6 @@ using LLama.Abstractions; using LLama.Common; +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -44,7 +45,7 @@ public class DefaultHistoryTransform : IHistoryTransform /// /// /// - public DefaultHistoryTransform(string? userName = null, string? assistantName = null, + public DefaultHistoryTransform(string? userName = null, string? assistantName = null, string? systemName = null, string? unknownName = null, bool isInstructMode = false) { _userName = userName ?? defaultUserName; @@ -61,7 +62,7 @@ public IHistoryTransform Clone() } /// - public virtual string HistoryToText(ChatHistory history) + public virtual string HistoryToText(IChatHistory history) { StringBuilder sb = new(); foreach (var message in history.Messages) @@ -87,9 +88,9 @@ public virtual string HistoryToText(ChatHistory history) } /// - public virtual ChatHistory TextToHistory(AuthorRole role, string text) + public virtual IChatHistory TextToHistory(AuthorRole role, string text, Type type) { - ChatHistory history = new ChatHistory(); + IChatHistory history = (IChatHistory)(Activator.CreateInstance(type) ?? new ChatHistory()); history.AddMessage(role, TrimNamesFromText(text, role)); return history; }