diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs
index f3a964b4c..2d6bf1785 100644
--- a/LLama.Examples/Examples/ChatChineseGB2312.cs
+++ b/LLama.Examples/Examples/ChatChineseGB2312.cs
@@ -49,7 +49,7 @@ public static async Task Run()
else
{
var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json");
- ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
+ IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();
session = new ChatSession(executor, chatHistory);
}
@@ -105,7 +105,7 @@ in session.RegenerateAssistantMessageAsync(
await foreach (
var text
in session.ChatAsync(
- new ChatHistory.Message(AuthorRole.User, userInput),
+ new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs
index b46c92e49..93f5c394d 100644
--- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs
+++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs
@@ -21,7 +21,7 @@ public static async Task Run()
var executor = new InteractiveExecutor(context);
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
- ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
+ IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();
ChatSession session = new(executor, chatHistory);
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
@@ -46,7 +46,7 @@ public static async Task Run()
await foreach (
var text
in session.ChatAsync(
- new ChatHistory.Message(AuthorRole.User, userInput),
+ new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs
index 31b6a7718..4e96acdef 100644
--- a/LLama.Examples/Examples/ChatSessionWithHistory.cs
+++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs
@@ -31,7 +31,7 @@ public static async Task Run()
else
{
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
- ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
+ IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();
session = new ChatSession(executor, chatHistory);
}
@@ -92,7 +92,7 @@ in session.RegenerateAssistantMessageAsync(
await foreach (
var text
in session.ChatAsync(
- new ChatHistory.Message(AuthorRole.User, userInput),
+ new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs
index 923f78f67..c8bac7f90 100644
--- a/LLama.Examples/Examples/ChatSessionWithRestart.cs
+++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs
@@ -19,8 +19,8 @@ public static async Task Run()
var executor = new InteractiveExecutor(context);
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
- ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
- ChatSession prototypeSession =
+ IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();
+ ChatSession prototypeSession =
await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);
prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:" },
@@ -50,10 +50,10 @@ public static async Task Run()
while (userInput != "exit")
{
// Load the session state from the reset state
- if(userInput == "reset")
+ if (userInput == "reset")
{
session.LoadSession(resetState);
- Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}");
+ Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.SessionChatHistory)}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session reset.");
}
@@ -75,10 +75,10 @@ public static async Task Run()
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Provide assistant input: ");
-
+
Console.ForegroundColor = ConsoleColor.Green;
string assistantInputOverride = Console.ReadLine() ?? "";
-
+
await session.AddAndProcessUserMessage(userInputOverride);
await session.AddAndProcessAssistantMessage(assistantInputOverride);
@@ -90,7 +90,7 @@ public static async Task Run()
await foreach (
var text
in session.ChatAsync(
- new ChatHistory.Message(AuthorRole.User, userInput),
+ new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs
index de3314130..080e9544e 100644
--- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs
+++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs
@@ -19,7 +19,7 @@ public static async Task Run()
var executor = new InteractiveExecutor(context);
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
- ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
+ IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();
ChatSession session = new(executor, chatHistory);
@@ -41,7 +41,7 @@ public static async Task Run()
await foreach (
var text
in session.ChatAsync(
- new ChatHistory.Message(AuthorRole.User, userInput),
+ new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
diff --git a/LLama.Examples/Examples/LoadAndSaveSession.cs b/LLama.Examples/Examples/LoadAndSaveSession.cs
index fded50e03..50530011c 100644
--- a/LLama.Examples/Examples/LoadAndSaveSession.cs
+++ b/LLama.Examples/Examples/LoadAndSaveSession.cs
@@ -33,7 +33,7 @@ public static async Task Run()
await foreach (
var text
in session.ChatAsync(
- new ChatHistory.Message(AuthorRole.User, prompt),
+ new Message(AuthorRole.User, prompt),
new InferenceParams()
{
Temperature = 0.6f,
diff --git a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
index f1a0ebcb6..4cb803b81 100644
--- a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
+++ b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
@@ -10,7 +10,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
public class HistoryTransform : DefaultHistoryTransform
{
///
- public override string HistoryToText(global::LLama.Common.ChatHistory history)
+ public string HistoryToText(global::LLama.Common.ChatHistory history)
{
return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
}
diff --git a/LLama.WebAPI/Controllers/ChatController.cs b/LLama.WebAPI/Controllers/ChatController.cs
index 9643ccf80..012a6137e 100644
--- a/LLama.WebAPI/Controllers/ChatController.cs
+++ b/LLama.WebAPI/Controllers/ChatController.cs
@@ -43,7 +43,7 @@ public async Task SendHistory([FromBody] HistoryInput input, [FromServic
{
var history = new ChatHistory();
- var messages = input.Messages.Select(m => new ChatHistory.Message(Enum.Parse(m.Role), m.Content));
+ var messages = input.Messages.Select(m => new Message(Enum.Parse(m.Role), m.Content));
history.Messages.AddRange(messages);
diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs
index ae2401c90..0691ce18e 100644
--- a/LLama.WebAPI/Services/StatefulChatService.cs
+++ b/LLama.WebAPI/Services/StatefulChatService.cs
@@ -28,7 +28,7 @@ public StatefulChatService(IConfiguration configuration, ILogger Send(SendMessageInput input)
}
_logger.LogInformation("Input: {text}", input.Text);
var outputs = _session.ChatAsync(
- new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
+ new Common.Message(Common.AuthorRole.User, input.Text),
new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
@@ -74,7 +74,7 @@ public async IAsyncEnumerable SendStream(SendMessageInput input)
_logger.LogInformation(input.Text);
var outputs = _session.ChatAsync(
- new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text!)
+ new Common.Message(Common.AuthorRole.User, input.Text!)
, new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs
index 3520c29b0..799ac21d5 100644
--- a/LLama.WebAPI/Services/StatelessChatService.cs
+++ b/LLama.WebAPI/Services/StatelessChatService.cs
@@ -48,10 +48,9 @@ public async Task SendAsync(ChatHistory history)
}
public class HistoryTransform : DefaultHistoryTransform
{
- public override string HistoryToText(ChatHistory history)
+ public override string HistoryToText(IChatHistory history)
{
return base.HistoryToText(history) + "\n Assistant:";
}
-
}
}
diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs
index 9644b3e1d..a5ff28cc9 100644
--- a/LLama/Abstractions/IHistoryTransform.cs
+++ b/LLama/Abstractions/IHistoryTransform.cs
@@ -1,4 +1,5 @@
using LLama.Common;
+using System;
using System.Text.Json.Serialization;
namespace LLama.Abstractions
@@ -14,15 +15,16 @@ public interface IHistoryTransform
///
/// The ChatHistory instance
///
- string HistoryToText(ChatHistory history);
-
+ string HistoryToText(IChatHistory history);
+
///
/// Converts plain text to a ChatHistory instance.
///
/// The role for the author.
/// The chat history as plain text.
+ /// The type of the chat history.
/// The updated history.
- ChatHistory TextToHistory(AuthorRole role, string text);
+ IChatHistory TextToHistory(AuthorRole role, string text, Type type);
///
/// Copy the transform.
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 0a5accc5e..db18cbc04 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -52,7 +52,7 @@ public class ChatSession
///
/// The chat history for this session.
///
- public ChatHistory History { get; private set; } = new();
+ public IChatHistory SessionChatHistory { get; private set; } = new ChatHistory();
///
/// The history transform used in this session.
@@ -73,17 +73,17 @@ public class ChatSession
/// Create a new chat session and preprocess history.
///
/// The executor for this session
- /// History for this session
+ /// History for this session
///
public static async Task InitializeSessionFromHistoryAsync(
- ILLamaExecutor executor, ChatHistory history)
+ ILLamaExecutor executor, IChatHistory chatHistory)
{
if (executor is not StatefulExecutorBase statefulExecutor)
{
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
}
- var session = new ChatSession(executor, history);
- await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
+ var session = new ChatSession(executor, chatHistory);
+ await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(chatHistory));
return session;
}
@@ -107,10 +107,10 @@ public ChatSession(ILLamaExecutor executor)
///
///
///
- public ChatSession(ILLamaExecutor executor, ChatHistory history)
+ public ChatSession(ILLamaExecutor executor, IChatHistory history)
: this(executor)
{
- History = history;
+ SessionChatHistory = history;
}
///
@@ -165,10 +165,10 @@ public SessionState GetSessionState()
{
var executorState = ((StatefulExecutorBase)Executor).GetStateData();
return new SessionState(
- executorState.PastTokensCount > 0
+ executorState.PastTokensCount > 0
? Executor.Context.GetState() : null,
executorState,
- History,
+ SessionChatHistory,
InputTransformPipeline,
OutputTransform,
HistoryTransform);
@@ -198,7 +198,7 @@ public void LoadSession(SessionState state, bool loadTransforms = true)
{
Executor.Context.LoadState(state.ContextState);
}
- History = new ChatHistory(state.History);
+ SessionChatHistory = state.SessionChatHistory;
if (loadTransforms)
{
InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
@@ -216,12 +216,12 @@ public void LoadSession(SessionState state, bool loadTransforms = true)
///
public void LoadSession(string path, bool loadTransforms = true)
{
- var state = SessionState.Load(path);
+ var state = SessionState.Load(path, this.SessionChatHistory.GetType());
// Handle non-polymorphic serialization of executor state
if (state.ExecutorState is null)
{
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
- ((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
+ ((StatefulExecutorBase)Executor).LoadState(filename: executorPath);
}
LoadSession(state, loadTransforms);
}
@@ -231,10 +231,10 @@ public void LoadSession(string path, bool loadTransforms = true)
///
///
///
- public ChatSession AddMessage(ChatHistory.Message message)
+ public ChatSession AddMessage(Message message)
{
// If current message is a system message, only allow the history to be empty
- if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
+ if (message.AuthorRole == AuthorRole.System && SessionChatHistory.Messages.Count > 0)
{
throw new ArgumentException("Cannot add a system message after another message", nameof(message));
}
@@ -243,7 +243,7 @@ public ChatSession AddMessage(ChatHistory.Message message)
// or the previous message to be a system message or assistant message.
if (message.AuthorRole == AuthorRole.User)
{
- ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
+ Message? lastMessage = SessionChatHistory.Messages.LastOrDefault();
if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
{
throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
@@ -254,7 +254,7 @@ public ChatSession AddMessage(ChatHistory.Message message)
// the previous message must be a user message.
if (message.AuthorRole == AuthorRole.Assistant)
{
- ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
+ Message? lastMessage = SessionChatHistory.Messages.LastOrDefault();
if (lastMessage is null
|| lastMessage.AuthorRole != AuthorRole.User)
{
@@ -262,7 +262,7 @@ public ChatSession AddMessage(ChatHistory.Message message)
}
}
- History.AddMessage(message.AuthorRole, message.Content);
+ SessionChatHistory.AddMessage(message.AuthorRole, message.Content);
return this;
}
@@ -272,7 +272,7 @@ public ChatSession AddMessage(ChatHistory.Message message)
///
///
public ChatSession AddSystemMessage(string content)
- => AddMessage(new ChatHistory.Message(AuthorRole.System, content));
+ => AddMessage(new Message(AuthorRole.System, content));
///
/// Add an assistant message to the chat history.
@@ -280,7 +280,7 @@ public ChatSession AddSystemMessage(string content)
///
///
public ChatSession AddAssistantMessage(string content)
- => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+ => AddMessage(new Message(AuthorRole.Assistant, content));
///
/// Add a user message to the chat history.
@@ -288,7 +288,7 @@ public ChatSession AddAssistantMessage(string content)
///
///
public ChatSession AddUserMessage(string content)
- => AddMessage(new ChatHistory.Message(AuthorRole.User, content));
+ => AddMessage(new Message(AuthorRole.User, content));
///
/// Remove the last message from the chat history.
@@ -296,7 +296,7 @@ public ChatSession AddUserMessage(string content)
///
public ChatSession RemoveLastMessage()
{
- History.Messages.RemoveAt(History.Messages.Count - 1);
+ SessionChatHistory.Messages.RemoveAt(SessionChatHistory.Messages.Count - 1);
return this;
}
@@ -305,7 +305,7 @@ public ChatSession RemoveLastMessage()
///
///
///
- public async Task AddAndProcessMessage(ChatHistory.Message message)
+ public async Task AddAndProcessMessage(Message message)
{
if (Executor is not StatefulExecutorBase statefulExecutor)
{
@@ -329,19 +329,19 @@ public async Task AddAndProcessMessage(ChatHistory.Message message)
/// Compute KV cache for the system message and add it to the chat history.
///
public Task AddAndProcessSystemMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
+ => AddAndProcessMessage(new Message(AuthorRole.System, content));
///
/// Compute KV cache for the user message and add it to the chat history.
///
public Task AddAndProcessUserMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
+ => AddAndProcessMessage(new Message(AuthorRole.User, content));
///
/// Compute KV cache for the assistant message and add it to the chat history.
///
public Task AddAndProcessAssistantMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+ => AddAndProcessMessage(new Message(AuthorRole.Assistant, content));
///
/// Replace a user message with a new message and remove all messages after the new message.
@@ -351,8 +351,8 @@ public Task AddAndProcessAssistantMessage(string content)
///
///
public ChatSession ReplaceUserMessage(
- ChatHistory.Message oldMessage,
- ChatHistory.Message newMessage)
+ Message oldMessage,
+ Message newMessage)
{
if (oldMessage.AuthorRole != AuthorRole.User)
{
@@ -364,16 +364,16 @@ public ChatSession ReplaceUserMessage(
throw new ArgumentException("New message must be a user message", nameof(newMessage));
}
- int index = History.Messages.IndexOf(oldMessage);
+ int index = SessionChatHistory.Messages.IndexOf(oldMessage);
if (index == -1)
{
throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
}
- History.Messages[index] = newMessage;
+ SessionChatHistory.Messages[index] = newMessage;
// Remove all message after the new message
- History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);
+ SessionChatHistory.Messages.RemoveRange(index + 1, SessionChatHistory.Messages.Count - index - 1);
return this;
}
@@ -388,7 +388,7 @@ public ChatSession ReplaceUserMessage(
///
///
public async IAsyncEnumerable ChatAsync(
- ChatHistory.Message message,
+ Message message,
bool applyInputTransformPipeline,
IInferenceParams? inferenceParams = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -424,14 +424,14 @@ public async IAsyncEnumerable ChatAsync(
// If the session history was added as part of new chat session history,
// convert the complete history includsing system message and manually added history
// to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation.
- prompt = HistoryTransform.HistoryToText(History);
+ prompt = HistoryTransform.HistoryToText(SessionChatHistory);
}
else
{
// If the session was restored from a previous session,
// convert only the current message to the prompt with the prompt template
// specified in the HistoryTransform class implementation that is provided.
- ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
+ IChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content, SessionChatHistory.GetType());
prompt = HistoryTransform.HistoryToText(singleMessageHistory);
}
@@ -460,7 +460,7 @@ in ChatAsyncInternal(
///
///
public IAsyncEnumerable ChatAsync(
- ChatHistory.Message message,
+ Message message,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
@@ -481,16 +481,16 @@ public IAsyncEnumerable ChatAsync(
///
///
public IAsyncEnumerable ChatAsync(
- ChatHistory history,
+ IChatHistory history,
bool applyInputTransformPipeline,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
- ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
+ Message lastMessage = history.Messages.LastOrDefault()
?? throw new ArgumentException("History must contain at least one message", nameof(history));
foreach (
- ChatHistory.Message message
+ Message message
in history.Messages.Take(history.Messages.Count - 1))
{
// Apply input transform pipeline
@@ -523,7 +523,7 @@ var inputTransform
///
///
public IAsyncEnumerable ChatAsync(
- ChatHistory history,
+ IChatHistory history,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
@@ -546,7 +546,7 @@ public async IAsyncEnumerable RegenerateAssistantMessageAsync(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Make sure the last message is an assistant message (reponse from the LLM).
- ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();
+ Message? lastAssistantMessage = SessionChatHistory.Messages.LastOrDefault();
if (lastAssistantMessage is null
|| lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
@@ -558,7 +558,7 @@ public async IAsyncEnumerable RegenerateAssistantMessageAsync(
RemoveLastMessage();
// Get the last user message.
- ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();
+ Message? lastUserMessage = SessionChatHistory.Messages.LastOrDefault();
if (lastUserMessage is null
|| lastUserMessage.AuthorRole != AuthorRole.User)
@@ -629,29 +629,29 @@ public record SessionState
/// The history transform used in this session.
///
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
-
+
///
/// The the chat history messages for this session.
///
- public ChatHistory.Message[] History { get; set; } = Array.Empty();
+ public IChatHistory SessionChatHistory;
///
/// Create a new session state.
///
///
///
- ///
+ ///
///
///
///
public SessionState(
- State? contextState, ExecutorBaseState executorState,
- ChatHistory history, List inputTransformPipeline,
+ State? contextState, ExecutorBaseState executorState,
+ IChatHistory chatHistory, List inputTransformPipeline,
ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
{
ContextState = contextState;
ExecutorState = executorState;
- History = history.Messages.ToArray();
+ SessionChatHistory = chatHistory;
InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray();
OutputTransform = outputTransform.Clone();
HistoryTransform = historyTransform.Clone();
@@ -686,7 +686,7 @@ public void Save(string path)
File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));
string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
- File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson());
+ File.WriteAllText(historyFilepath, ChatHistorySerializer.ToJson(new ChatHistory(SessionChatHistory.Messages.ToArray())));
string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline));
@@ -704,7 +704,7 @@ public void Save(string path)
///
///
/// Throws when session state is incorrect
- public static SessionState Load(string path)
+ public static SessionState Load(string path, Type type)
{
if (string.IsNullOrWhiteSpace(path))
{
@@ -717,7 +717,7 @@ public static SessionState Load(string path)
}
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
- var contextState = File.Exists(modelStateFilePath) ?
+ var contextState = File.Exists(modelStateFilePath) ?
State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
: null;
@@ -726,14 +726,14 @@ public static SessionState Load(string path)
string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
string historyJson = File.ReadAllText(historyFilepath);
- var history = ChatHistory.FromJson(historyJson)
+ var history = ChatHistorySerializer.FromJson(historyJson, type)
?? throw new ArgumentException("History file is invalid", nameof(path));
string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
ITextTransform[] inputTransforms;
try
{
- inputTransforms = File.Exists(inputTransformFilepath) ?
+ inputTransforms = File.Exists(inputTransformFilepath) ?
(JsonSerializer.Deserialize(File.ReadAllText(inputTransformFilepath))
?? throw new ArgumentException("Input transform file is invalid", nameof(path)))
: Array.Empty();
@@ -744,11 +744,10 @@ public static SessionState Load(string path)
}
string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
-
ITextStreamTransform outputTransform;
try
{
- outputTransform = File.Exists(outputTransformFilepath) ?
+ outputTransform = File.Exists(outputTransformFilepath) ?
(JsonSerializer.Deserialize(File.ReadAllText(outputTransformFilepath))
?? throw new ArgumentException("Output transform file is invalid", nameof(path)))
: new LLamaTransforms.EmptyTextOutputStreamTransform();
@@ -762,7 +761,7 @@ public static SessionState Load(string path)
IHistoryTransform historyTransform;
try
{
- historyTransform = File.Exists(historyTransformFilepath) ?
+ historyTransform = File.Exists(historyTransformFilepath) ?
(JsonSerializer.Deserialize(File.ReadAllText(historyTransformFilepath))
?? throw new ArgumentException("History transform file is invalid", nameof(path)))
: new LLamaTransforms.DefaultHistoryTransform();
diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs
index c22cc7c06..052d1b538 100644
--- a/LLama/Common/ChatHistory.cs
+++ b/LLama/Common/ChatHistory.cs
@@ -1,4 +1,5 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -33,41 +34,60 @@ public enum AuthorRole
// copy from semantic-kernel
///
- /// The chat history class
+ /// The message class
///
- public class ChatHistory
+ public class Message
{
- private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };
+ ///
+ /// Role of the message author, e.g. user/assistant/system
+ ///
+ [JsonConverter(typeof(JsonStringEnumConverter))]
+ [JsonPropertyName("author_role")]
+ public AuthorRole AuthorRole { get; set; }
///
- /// Chat message representation
+ /// Message content
///
- public class Message
+ [JsonPropertyName("content")]
+ public string Content { get; set; }
+
+ ///
+ /// Create a new instance
+ ///
+ /// Role of message author
+ /// Message content
+ public Message(AuthorRole authorRole, string content)
{
- ///
- /// Role of the message author, e.g. user/assistant/system
- ///
- [JsonConverter(typeof(JsonStringEnumConverter))]
- [JsonPropertyName("author_role")]
- public AuthorRole AuthorRole { get; set; }
-
- ///
- /// Message content
- ///
- [JsonPropertyName("content")]
- public string Content { get; set; }
-
- ///
- /// Create a new instance
- ///
- /// Role of message author
- /// Message content
- public Message(AuthorRole authorRole, string content)
- {
- this.AuthorRole = authorRole;
- this.Content = content;
- }
+ this.AuthorRole = authorRole;
+ this.Content = content;
}
+ }
+
+ ///
+ /// Interface for chat history
+ ///
+ public interface IChatHistory
+ {
+ ///
+ /// List of messages in the chat
+ ///
+ List Messages { get; set; }
+
+ ///
+ /// Add a message to the chat history
+ ///
+ /// Role of the message author
+ /// Message content
+ void AddMessage(AuthorRole authorRole, string content);
+ }
+
+ // copy from semantic-kernel
+ ///
+ /// The chat history class
+ ///
+ public class ChatHistory : IChatHistory
+ {
+ private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };
///
/// List of messages in the chat
@@ -99,24 +119,33 @@ public void AddMessage(AuthorRole authorRole, string content)
{
this.Messages.Add(new Message(authorRole, content));
}
+ }
+
+ ///
+ /// Serializer for chat history
+ ///
+ public class ChatHistorySerializer
+ {
+ private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };
///
/// Serialize the chat history to JSON
///
///
- public string ToJson()
+ public static string ToJson(IChatHistory chatHistory)
{
- return JsonSerializer.Serialize(this, _jsonOptions);
+ return JsonSerializer.Serialize(chatHistory, _jsonOptions);
}
///
/// Deserialize a chat history from JSON
///
///
+ ///
///
- public static ChatHistory? FromJson(string json)
+ public static IChatHistory? FromJson(string json, Type type)
{
- return JsonSerializer.Deserialize(json);
+ return JsonSerializer.Deserialize(json, type) as IChatHistory;
}
}
}
diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs
index d74d9ddaf..6ed86e777 100644
--- a/LLama/LLamaTransforms.cs
+++ b/LLama/LLamaTransforms.cs
@@ -1,5 +1,6 @@
using LLama.Abstractions;
using LLama.Common;
+using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -44,7 +45,7 @@ public class DefaultHistoryTransform : IHistoryTransform
///
///
///
- public DefaultHistoryTransform(string? userName = null, string? assistantName = null,
+ public DefaultHistoryTransform(string? userName = null, string? assistantName = null,
string? systemName = null, string? unknownName = null, bool isInstructMode = false)
{
_userName = userName ?? defaultUserName;
@@ -61,7 +62,7 @@ public IHistoryTransform Clone()
}
///
- public virtual string HistoryToText(ChatHistory history)
+ public virtual string HistoryToText(IChatHistory history)
{
StringBuilder sb = new();
foreach (var message in history.Messages)
@@ -87,9 +88,9 @@ public virtual string HistoryToText(ChatHistory history)
}
///
- public virtual ChatHistory TextToHistory(AuthorRole role, string text)
+ public virtual IChatHistory TextToHistory(AuthorRole role, string text, Type type)
{
- ChatHistory history = new ChatHistory();
+ IChatHistory history = (IChatHistory)(Activator.CreateInstance(type) ?? new ChatHistory());
history.AddMessage(role, TrimNamesFromText(text, role));
return history;
}