From d5c8e27e0c4b3d1ee3254c882af6cb3696c46dd1 Mon Sep 17 00:00:00 2001 From: Mehdi Hadeli Date: Sat, 30 Nov 2024 12:04:00 +0100 Subject: [PATCH] =?UTF-8?q?refactor:=20=E2=99=BB=EF=B8=8F=20enhance=20cach?= =?UTF-8?q?e-models=20and=20model-options=20loading=20(#20)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/AIAssist/Commands/CodeAssistCommand.cs | 54 ++++++-- .../DependencyInjectionExtensions.cs | 34 +++-- .../EmbeddingCodeAssist.cs | 4 +- .../TreeSitterCodeAssistSummary.cs | 4 +- src/AIAssist/Services/LLMClientManager.cs | 15 ++- .../SpectreConsole/ColorTheme.cs | 2 + .../Contracts/ISpectreUtilities.cs | 74 ++++++++-- .../Markdown/SpectreMarkdownBlockRendering.cs | 2 +- .../SpectreMarkdownInlineRendering.cs | 2 +- .../SpectreConsole/SpectreUtilities.cs | 126 ++++++++++++++---- .../SpectreConsole/Themes/dracula.json | 1 + src/Clients/AnthropicClient.cs | 23 ++-- src/Clients/AzureClient.cs | 59 ++++---- src/Clients/CacheModels.cs | 125 ++++++++--------- .../Converters/AIProviderTypeConverter.cs | 53 -------- .../Converters/CodeAssistTypeConverter.cs | 43 ------ .../Converters/CodeDiffTypeConverter.cs | 49 ------- src/Clients/Converters/ModelTypeConverter.cs | 46 ------- src/Clients/Converters/RoleTypeConverter.cs | 1 + src/Clients/LLMs/models_information_list.json | 8 +- src/Clients/Models/Model.cs | 24 +++- src/Clients/Models/ModelInformation.cs | 13 +- src/Clients/Models/ModelOption.cs | 8 +- src/Clients/OllamaClient.cs | 45 +++---- src/Clients/OpenAiClient.cs | 45 +++---- 25 files changed, 420 insertions(+), 440 deletions(-) delete mode 100644 src/Clients/Converters/AIProviderTypeConverter.cs delete mode 100644 src/Clients/Converters/CodeAssistTypeConverter.cs delete mode 100644 src/Clients/Converters/CodeDiffTypeConverter.cs delete mode 100644 src/Clients/Converters/ModelTypeConverter.cs diff --git a/src/AIAssist/Commands/CodeAssistCommand.cs b/src/AIAssist/Commands/CodeAssistCommand.cs index fe30c6e..c1be460 100644 --- a/src/AIAssist/Commands/CodeAssistCommand.cs +++ b/src/AIAssist/Commands/CodeAssistCommand.cs @@ -30,7 +30,7 @@ IOptions appOptions private readonly AppOptions _appOptions = appOptions.Value; private readonly Model _chatModel = cacheModels.GetModel(llmOptions.Value.ChatModel) - ?? throw new KeyNotFoundException($"Model '{llmOptions.Value.ChatModel}' not found in the ModelCache."); + ?? throw new ArgumentNullException($"Model '{llmOptions.Value.ChatModel}' not found in the ModelCache."); private readonly Model? _embeddingModel = cacheModels.GetModel(llmOptions.Value.EmbeddingsModel); private static bool _running = true; @@ -122,10 +122,22 @@ public override async Task ExecuteAsync(CommandContext context, Settings se SetupOptions(settings); spectreUtilities.SummaryTextLine("Code assist mode is activated!"); - spectreUtilities.SummaryTextLine( - $"Chat model: {_chatModel.Name} | Embedding model: {_embeddingModel?.Name ?? "-"} | CodeAssistType: {_chatModel.ModelOption.CodeAssistType} | CodeDiffType: {_chatModel.ModelOption.CodeDiffType}" + spectreUtilities.NormalText("Chat model: "); + spectreUtilities.HighlightTextLine(_chatModel.Name); + + spectreUtilities.NormalText("Embedding model: "); + spectreUtilities.HighlightTextLine(_embeddingModel?.Name ?? "-"); + + spectreUtilities.NormalText("CodeAssistType: "); + spectreUtilities.HighlightTextLine(_chatModel.CodeAssistType.ToString()); + + spectreUtilities.NormalText("CodeDiffType: "); + spectreUtilities.HighlightTextLine(_chatModel.CodeDiffType.ToString()); + + spectreUtilities.NormalTextLine( + "Please 'Ctrl+H' to see all available commands in the code assist mode.", + decoration: Decoration.Bold ); - spectreUtilities.SummaryTextLine("Please 'Ctrl+H' to see all available commands in the code assist mode."); spectreUtilities.WriteRule(); await AnsiConsole @@ -189,42 +201,42 @@ private void SetupOptions(Settings settings) if (!string.IsNullOrEmpty(settings.ChatModelApiKey)) { - _chatModel.ModelOption.ApiKey = settings.ChatModelApiKey.Trim(); + _chatModel.ApiKey = settings.ChatModelApiKey.Trim(); } if (!string.IsNullOrEmpty(settings.ChatApiVersion)) { - _chatModel.ModelOption.ApiVersion = settings.ChatApiVersion.Trim(); + _chatModel.ApiVersion = settings.ChatApiVersion.Trim(); } if (!string.IsNullOrEmpty(settings.ChatDeploymentId)) { - _chatModel.ModelOption.DeploymentId = settings.ChatDeploymentId.Trim(); + _chatModel.DeploymentId = settings.ChatDeploymentId.Trim(); } if (!string.IsNullOrEmpty(settings.ChatBaseAddress)) { - _chatModel.ModelOption.BaseAddress = settings.ChatBaseAddress.Trim(); + _chatModel.BaseAddress = settings.ChatBaseAddress.Trim(); } if (!string.IsNullOrEmpty(settings.EmbeddingsModelApiKey) && _embeddingModel is not null) { - _embeddingModel.ModelOption.ApiKey = settings.EmbeddingsModelApiKey.Trim(); + _embeddingModel.ApiKey = settings.EmbeddingsModelApiKey.Trim(); } if (!string.IsNullOrEmpty(settings.EmbeddingsApiVersion) && _embeddingModel is not null) { - _embeddingModel.ModelOption.ApiVersion = settings.EmbeddingsApiVersion.Trim(); + _embeddingModel.ApiVersion = settings.EmbeddingsApiVersion.Trim(); } if (!string.IsNullOrEmpty(settings.EmbeddingsDeploymentId) && _embeddingModel is not null) { - _embeddingModel.ModelOption.DeploymentId = settings.EmbeddingsDeploymentId.Trim(); + _embeddingModel.DeploymentId = settings.EmbeddingsDeploymentId.Trim(); } if (!string.IsNullOrEmpty(settings.EmbeddingsBaseAddress) && _embeddingModel is not null) { - _embeddingModel.ModelOption.BaseAddress = settings.EmbeddingsBaseAddress.Trim(); + _embeddingModel.BaseAddress = settings.EmbeddingsBaseAddress.Trim(); } _appOptions.ContextWorkingDirectory = !string.IsNullOrEmpty(settings.ContextWorkingDirectory) @@ -246,21 +258,37 @@ private void SetupOptions(Settings settings) if (settings.CodeDiffType is not null) { _llmOptions.CodeDiffType = settings.CodeDiffType.Value; + _chatModel.CodeDiffType = settings.CodeDiffType.Value; + + if (_embeddingModel != null) + _embeddingModel.CodeDiffType = settings.CodeDiffType.Value; } if (settings.CodeAssistType is not null) { _llmOptions.CodeAssistType = settings.CodeAssistType.Value; + _chatModel.CodeAssistType = settings.CodeAssistType.Value; + + if (_embeddingModel != null) + _embeddingModel.CodeAssistType = settings.CodeAssistType.Value; } - if (settings.Threshold is not null && _embeddingModel is not null) + if (settings.Threshold is not null) { _llmOptions.Threshold = settings.Threshold.Value; + _chatModel.Threshold = settings.Threshold.Value; + + if (_embeddingModel != null) + _embeddingModel.Threshold = settings.Threshold.Value; } if (settings.Temperature is not null) { _llmOptions.Temperature = settings.Temperature.Value; + _chatModel.Temperature = settings.Temperature.Value; + + if (_embeddingModel != null) + _embeddingModel.Temperature = settings.Temperature.Value; } } } diff --git a/src/AIAssist/Extensions/DependencyInjectionExtensions.cs b/src/AIAssist/Extensions/DependencyInjectionExtensions.cs index 0a70674..ee07906 100644 --- a/src/AIAssist/Extensions/DependencyInjectionExtensions.cs +++ b/src/AIAssist/Extensions/DependencyInjectionExtensions.cs @@ -212,7 +212,9 @@ private static void AddCodeAssistDependencies(HostApplicationBuilder builder) var chatModel = cacheModels.GetModel(llmOptions.Value.ChatModel); - ICodeAssist codeAssist = factory.Create(chatModel.ModelOption.CodeAssistType); + ArgumentNullException.ThrowIfNull(chatModel); + + ICodeAssist codeAssist = factory.Create(chatModel.CodeAssistType); return new CodeAssistantManager(codeAssist, codeDiffManager); }); @@ -284,17 +286,19 @@ private static void AddClientDependencies(HostApplicationBuilder builder) var options = sp.GetRequiredService>().Value; var policyOptions = sp.GetRequiredService>().Value; - var cacheModels = sp.GetRequiredService(); ArgumentException.ThrowIfNullOrEmpty(options.ChatModel); + + var cacheModels = sp.GetRequiredService(); var chatModel = cacheModels.GetModel(options.ChatModel); + ArgumentNullException.ThrowIfNull(chatModel); client.Timeout = TimeSpan.FromSeconds(policyOptions.TimeoutSeconds); var chatApiKey = Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatModelApiKey) - ?? chatModel.ModelOption.ApiKey; + ?? chatModel.ApiKey; - switch (chatModel.ModelInformation.AIProvider) + switch (chatModel.AIProvider) { case AIProvider.Openai: { @@ -303,7 +307,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder) var baseAddress = Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatBaseAddress) - ?? chatModel.ModelOption.BaseAddress + ?? chatModel.BaseAddress ?? "https://api.openai.com"; client.BaseAddress = new Uri(baseAddress.Trim()); @@ -320,7 +324,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder) var baseAddress = Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatBaseAddress) - ?? chatModel.ModelOption.BaseAddress; + ?? chatModel.BaseAddress; ArgumentException.ThrowIfNullOrEmpty(baseAddress); client.BaseAddress = new Uri(baseAddress.Trim()); @@ -332,7 +336,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder) { var baseAddress = Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatBaseAddress) - ?? chatModel.ModelOption.BaseAddress + ?? chatModel.BaseAddress ?? "http://localhost:11434"; // https://github.com/ollama/ollama/blob/main/docs/api.md @@ -359,15 +363,17 @@ private static void AddClientDependencies(HostApplicationBuilder builder) var cacheModels = sp.GetRequiredService(); ArgumentException.ThrowIfNullOrEmpty(options.EmbeddingsModel); + var embeddingModel = cacheModels.GetModel(options.EmbeddingsModel); + ArgumentNullException.ThrowIfNull(embeddingModel); client.Timeout = TimeSpan.FromSeconds(policyOptions.TimeoutSeconds); var embeddingsApiKey = Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsModelApiKey) - ?? embeddingModel.ModelOption.ApiKey; + ?? embeddingModel.ApiKey; - switch (embeddingModel.ModelInformation.AIProvider) + switch (embeddingModel.AIProvider) { case AIProvider.Openai: { @@ -376,7 +382,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder) var baseAddress = Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsBaseAddress) - ?? embeddingModel.ModelOption.BaseAddress + ?? embeddingModel.BaseAddress ?? "https://api.openai.com"; client.BaseAddress = new Uri(baseAddress.Trim()); @@ -393,7 +399,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder) var baseAddress = Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsBaseAddress) - ?? embeddingModel.ModelOption.BaseAddress; + ?? embeddingModel.BaseAddress; ArgumentException.ThrowIfNullOrEmpty(baseAddress); client.BaseAddress = new Uri(baseAddress.Trim()); @@ -405,7 +411,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder) { var baseAddress = Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsBaseAddress) - ?? embeddingModel.ModelOption.BaseAddress + ?? embeddingModel.BaseAddress ?? "http://localhost:11434"; // https://github.com/ollama/ollama/blob/main/docs/api.md @@ -505,7 +511,9 @@ private static void AddCodeDiffDependency(HostApplicationBuilder builder) var cacheModels = sp.GetRequiredService(); var chatModel = cacheModels.GetModel(options.Value.ChatModel); - var codeDiffParser = factory.Create(chatModel.ModelOption.CodeDiffType); + ArgumentNullException.ThrowIfNull(chatModel); + + var codeDiffParser = factory.Create(chatModel.CodeDiffType); var codeDiffUpdater = sp.GetRequiredService(); diff --git a/src/AIAssist/Services/CodeAssistStrategies/EmbeddingCodeAssist.cs b/src/AIAssist/Services/CodeAssistStrategies/EmbeddingCodeAssist.cs index 469f063..d3fde7f 100644 --- a/src/AIAssist/Services/CodeAssistStrategies/EmbeddingCodeAssist.cs +++ b/src/AIAssist/Services/CodeAssistStrategies/EmbeddingCodeAssist.cs @@ -85,8 +85,8 @@ public Task> GetCodeTreeContents(IList? codeFiles) var systemPrompt = promptManager.GetSystemPrompt( embeddingOriginalTreeCodes, - llmClientManager.ChatModel.ModelOption.CodeAssistType, - llmClientManager.ChatModel.ModelOption.CodeDiffType + llmClientManager.ChatModel.CodeAssistType, + llmClientManager.ChatModel.CodeDiffType ); // Generate a response from the language model (e.g., OpenAI or Llama) diff --git a/src/AIAssist/Services/CodeAssistStrategies/TreeSitterCodeAssistSummary.cs b/src/AIAssist/Services/CodeAssistStrategies/TreeSitterCodeAssistSummary.cs index debe278..0abdd49 100644 --- a/src/AIAssist/Services/CodeAssistStrategies/TreeSitterCodeAssistSummary.cs +++ b/src/AIAssist/Services/CodeAssistStrategies/TreeSitterCodeAssistSummary.cs @@ -50,8 +50,8 @@ public Task> GetCodeTreeContents(IList? codeFiles) var systemPrompt = promptManager.GetSystemPrompt( summaryTreeCodes, - llmClientManager.ChatModel.ModelOption.CodeAssistType, - llmClientManager.ChatModel.ModelOption.CodeDiffType + llmClientManager.ChatModel.CodeAssistType, + llmClientManager.ChatModel.CodeDiffType ); // Generate a response from the language model (e.g., OpenAI or Llama) diff --git a/src/AIAssist/Services/LLMClientManager.cs b/src/AIAssist/Services/LLMClientManager.cs index 4782d0e..9b9e8c8 100644 --- a/src/AIAssist/Services/LLMClientManager.cs +++ b/src/AIAssist/Services/LLMClientManager.cs @@ -29,12 +29,14 @@ ICacheModels cacheModels _tokenizer = tokenizer; EmbeddingModel = cacheModels.GetModel(llmOptions.Value.EmbeddingsModel); - ChatModel = cacheModels.GetModel(llmOptions.Value.ChatModel); - EmbeddingThreshold = EmbeddingModel.ModelOption.Threshold; + ChatModel = + cacheModels.GetModel(llmOptions.Value.ChatModel) + ?? throw new ArgumentNullException($"Model '{llmOptions.Value.ChatModel}' not found in the CacheModels."); + EmbeddingThreshold = EmbeddingModel?.Threshold ?? 0.2m; } public Model ChatModel { get; } - public Model EmbeddingModel { get; } + public Model? EmbeddingModel { get; } public decimal EmbeddingThreshold { get; } public async IAsyncEnumerable GetCompletionStreamAsync( @@ -50,7 +52,7 @@ ICacheModels cacheModels var chatItems = chatSession.GetChatItemsFromHistory(); - var llmClientStratgey = _clientFactory.CreateClient(ChatModel.ModelInformation.AIProvider); + var llmClientStratgey = _clientFactory.CreateClient(ChatModel.AIProvider); var chatCompletionResponseStreams = llmClientStratgey.GetCompletionStreamAsync( new ChatCompletionRequest(chatItems.Select(x => new ChatCompletionRequestItem(x.Role, x.Prompt))), @@ -94,14 +96,15 @@ public async Task GetEmbeddingAsync( CancellationToken cancellationToken = default ) { - var llmClientStratgey = _clientFactory.CreateClient(EmbeddingModel.ModelInformation.AIProvider); + ArgumentNullException.ThrowIfNull(EmbeddingModel); + var llmClientStratgey = _clientFactory.CreateClient(EmbeddingModel.AIProvider); var embeddingResponse = await llmClientStratgey.GetEmbeddingAsync(inputs, path, cancellationToken); // in embedding output tokens and its cost is 0 var inputTokens = embeddingResponse?.TokenUsage?.InputTokens ?? await _tokenizer.GetTokenCount(string.Concat(inputs)); - var cost = inputTokens * EmbeddingModel.ModelInformation.InputCostPerToken; + var cost = inputTokens * EmbeddingModel.InputCostPerToken; return new GetEmbeddingResult(embeddingResponse?.Embeddings ?? new List>(), inputTokens, cost); } diff --git a/src/BuildingBlocks/SpectreConsole/ColorTheme.cs b/src/BuildingBlocks/SpectreConsole/ColorTheme.cs index 4095602..779c420 100644 --- a/src/BuildingBlocks/SpectreConsole/ColorTheme.cs +++ b/src/BuildingBlocks/SpectreConsole/ColorTheme.cs @@ -7,6 +7,8 @@ public class ColorTheme { public string Name { get; set; } = default!; + public string? Foreground { get; set; } = default!; + [JsonPropertyName("console")] public ConsoleStyle ConsoleStyle { get; set; } = default!; diff --git a/src/BuildingBlocks/SpectreConsole/Contracts/ISpectreUtilities.cs b/src/BuildingBlocks/SpectreConsole/Contracts/ISpectreUtilities.cs index a4b33a4..fb5f4bd 100644 --- a/src/BuildingBlocks/SpectreConsole/Contracts/ISpectreUtilities.cs +++ b/src/BuildingBlocks/SpectreConsole/Contracts/ISpectreUtilities.cs @@ -6,18 +6,68 @@ public interface ISpectreUtilities { bool ConfirmationPrompt(string message); string? UserPrompt(string? promptMessage = null); - void InformationTextLine(string message, Justify? justify = null, Overflow? overflow = null); - void InformationText(string message, Justify? justify = null, Overflow? overflow = null); - public void SummaryTextLine(string message, Justify? justify = null, Overflow? overflow = null); - public void SummaryText(string message, Justify? justify = null, Overflow? overflow = null); - public void HighlightTextLine(string message, Justify? justify = null, Overflow? overflow = null); - public void HighlightText(string message, Justify? justify = null, Overflow? overflow = null); - void NormalTextLine(string message, Justify? justify = null, Overflow? overflow = null); - void NormalText(string message, Justify? justify = null, Overflow? overflow = null); - void WarningTextLine(string message, Justify? justify = null, Overflow? overflow = null); - void WarningText(string message, Justify? justify = null, Overflow? overflow = null); - void ErrorTextLine(string message, Justify? justify = null, Overflow? overflow = null); - void SuccessTextLine(string message, Justify? justify = null, Overflow? overflow = null); + void InformationTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + void InformationText( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + public void SummaryTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + public void SummaryText( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + public void HighlightTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + public void HighlightText( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + void NormalTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + void NormalText(string message, Justify? justify = null, Overflow? overflow = null, Decoration? decoration = null); + void WarningTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + void WarningText(string message, Justify? justify = null, Overflow? overflow = null, Decoration? decoration = null); + void ErrorTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); + void SuccessTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ); void WriteCursor(); void WriteRule(); void Exception(string errorMessage, Exception ex); diff --git a/src/BuildingBlocks/SpectreConsole/Markdown/SpectreMarkdownBlockRendering.cs b/src/BuildingBlocks/SpectreConsole/Markdown/SpectreMarkdownBlockRendering.cs index 6cdce6c..3808203 100644 --- a/src/BuildingBlocks/SpectreConsole/Markdown/SpectreMarkdownBlockRendering.cs +++ b/src/BuildingBlocks/SpectreConsole/Markdown/SpectreMarkdownBlockRendering.cs @@ -272,7 +272,7 @@ private string CreateStringStyle(StyleBase styleBase) var bold = styleBase.Bold ? "bold" : "default"; var underline = styleBase.Underline ? "underline" : "default"; - return $"{styleBase.Foreground ?? "default"} on {styleBase.Background ?? "default"} {italic} {bold} {underline}"; + return $"{styleBase.Foreground ?? _colorTheme.Foreground ?? "default"} on {styleBase.Background ?? "default"} {italic} {bold} {underline}"; } public void Dispose() diff --git a/src/BuildingBlocks/SpectreConsole/Markdown/SpectreMarkdownInlineRendering.cs b/src/BuildingBlocks/SpectreConsole/Markdown/SpectreMarkdownInlineRendering.cs index 5e3eee1..0ef9b38 100644 --- a/src/BuildingBlocks/SpectreConsole/Markdown/SpectreMarkdownInlineRendering.cs +++ b/src/BuildingBlocks/SpectreConsole/Markdown/SpectreMarkdownInlineRendering.cs @@ -120,7 +120,7 @@ private string CreateStringStyle(StyleBase styleBase) var style = $"{ - styleBase.Foreground ?? "default" + styleBase.Foreground ?? colorTheme.Foreground ?? "default" } on { styleBase.Background ?? "default" } { diff --git a/src/BuildingBlocks/SpectreConsole/SpectreUtilities.cs b/src/BuildingBlocks/SpectreConsole/SpectreUtilities.cs index ade21e9..fab3c2d 100644 --- a/src/BuildingBlocks/SpectreConsole/SpectreUtilities.cs +++ b/src/BuildingBlocks/SpectreConsole/SpectreUtilities.cs @@ -29,16 +29,29 @@ public bool ConfirmationPrompt(string message) return input; } - public void InformationTextLine(string message, Justify? justify = null, Overflow? overflow = null) + public void InformationTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { - InformationText(message, justify: justify, overflow: overflow); + InformationText(message, justify: justify, overflow: overflow, decoration: decoration); console.WriteLine(); } - public void InformationText(string message, Justify? justify = null, Overflow? overflow = null) + public void InformationText( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { console.Write( - new Markup($"[{CreateStringStyle(theme.ConsoleStyle.Information)}]{message}[/]") + new Markup( + $"[{CreateStringStyle(theme.ConsoleStyle.Information)}]{message}[/]", + new Style(decoration: decoration) + ) { Overflow = overflow, Justification = justify, @@ -46,16 +59,29 @@ public void InformationText(string message, Justify? justify = null, Overflow? o ); } - public void SummaryTextLine(string message, Justify? justify = null, Overflow? overflow = null) + public void SummaryTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { - SummaryText(message, justify: justify, overflow: overflow); + SummaryText(message, justify: justify, overflow: overflow, decoration: decoration); console.WriteLine(); } - public void SummaryText(string message, Justify? justify = null, Overflow? overflow = null) + public void SummaryText( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { console.Write( - new Markup($"[{CreateStringStyle(theme.ConsoleStyle.Summary)}]{message}[/]") + new Markup( + $"[{CreateStringStyle(theme.ConsoleStyle.Summary)}]{message}[/]", + new Style(decoration: decoration) + ) { Overflow = overflow, Justification = justify, @@ -63,16 +89,29 @@ public void SummaryText(string message, Justify? justify = null, Overflow? overf ); } - public void HighlightTextLine(string message, Justify? justify = null, Overflow? overflow = null) + public void HighlightTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { - HighlightText(message, justify: justify, overflow: overflow); + HighlightText(message, justify: justify, overflow: overflow, decoration: decoration); console.WriteLine(); } - public void HighlightText(string message, Justify? justify = null, Overflow? overflow = null) + public void HighlightText( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { console.Write( - new Markup($"[{CreateStringStyle(theme.ConsoleStyle.Highlight)}]{message}[/]") + new Markup( + $"[{CreateStringStyle(theme.ConsoleStyle.Highlight)}]{message}[/]", + new Style(decoration: decoration) + ) { Overflow = overflow, Justification = justify, @@ -80,16 +119,26 @@ public void HighlightText(string message, Justify? justify = null, Overflow? ove ); } - public void NormalTextLine(string message, Justify? justify = null, Overflow? overflow = null) + public void NormalTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { - NormalText(message, justify: justify, overflow: overflow); + NormalText(message, justify: justify, overflow: overflow, decoration: decoration); console.WriteLine(); } - public void NormalText(string message, Justify? justify = null, Overflow? overflow = null) + public void NormalText( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { console.Write( - new Markup($"[{CreateStringStyle(theme.ConsoleStyle.Text)}]{message}[/]") + new Markup($"[{CreateStringStyle(theme.ConsoleStyle.Text)}]{message}[/]", new Style(decoration: decoration)) { Overflow = overflow, Justification = justify, @@ -97,16 +146,29 @@ public void NormalText(string message, Justify? justify = null, Overflow? overfl ); } - public void WarningTextLine(string message, Justify? justify = null, Overflow? overflow = null) + public void WarningTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { - WarningText(message, justify: justify, overflow: overflow); + WarningText(message, justify: justify, overflow: overflow, decoration: decoration); console.WriteLine(); } - public void WarningText(string message, Justify? justify = null, Overflow? overflow = null) + public void WarningText( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { console.Write( - new Markup($"[{CreateStringStyle(theme.ConsoleStyle.Warning)}]{message}[/]") + new Markup( + $"[{CreateStringStyle(theme.ConsoleStyle.Warning)}]{message}[/]", + new Style(decoration: decoration) + ) { Overflow = overflow, Justification = justify, @@ -114,10 +176,18 @@ public void WarningText(string message, Justify? justify = null, Overflow? overf ); } - public void ErrorTextLine(string message, Justify? justify = null, Overflow? overflow = null) + public void ErrorTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { console.Write( - new Markup($"[{CreateStringStyle(theme.ConsoleStyle.Error)}]{message}[/]" + Environment.NewLine) + new Markup( + $"[{CreateStringStyle(theme.ConsoleStyle.Error)}]{message}[/]" + Environment.NewLine, + new Style(decoration: decoration) + ) { Overflow = overflow, Justification = justify, @@ -125,10 +195,18 @@ public void ErrorTextLine(string message, Justify? justify = null, Overflow? ove ); } - public void SuccessTextLine(string message, Justify? justify = null, Overflow? overflow = null) + public void SuccessTextLine( + string message, + Justify? justify = null, + Overflow? overflow = null, + Decoration? decoration = null + ) { console.Write( - new Markup($"[{CreateStringStyle(theme.ConsoleStyle.Success)}]{message}[/]" + Environment.NewLine) + new Markup( + $"[{CreateStringStyle(theme.ConsoleStyle.Success)}]{message}[/]" + Environment.NewLine, + new Style(decoration: decoration) + ) { Overflow = overflow, Justification = justify, diff --git a/src/BuildingBlocks/SpectreConsole/Themes/dracula.json b/src/BuildingBlocks/SpectreConsole/Themes/dracula.json index c891b13..f7147fd 100644 --- a/src/BuildingBlocks/SpectreConsole/Themes/dracula.json +++ b/src/BuildingBlocks/SpectreConsole/Themes/dracula.json @@ -1,6 +1,7 @@ { "$schema": "./spectre_console_schema.json", "name": "dracula", + "foreground": "#f8f8f2", "console": { "prompt": { "foreground": "#bd93f9" diff --git a/src/Clients/AnthropicClient.cs b/src/Clients/AnthropicClient.cs index a6c6581..ccdbc21 100644 --- a/src/Clients/AnthropicClient.cs +++ b/src/Clients/AnthropicClient.cs @@ -51,7 +51,7 @@ AsyncPolicyWrap combinedPolicy role = x.Role.Humanize(LetterCasing.LowerCase), content = x.Prompt, }), - temperature = _chatModel.ModelOption.Temperature, + temperature = _chatModel.Temperature, }; var client = httpClientFactory.CreateClient("llm_chat_client"); @@ -80,8 +80,8 @@ AsyncPolicyWrap combinedPolicy var inputTokens = completionResponse.Usage?.InputTokens ?? 0; var outTokens = completionResponse.Usage?.OutputTokens ?? 0; - var inputCostPerToken = _chatModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _chatModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _chatModel.InputCostPerToken; + var outputCostPerToken = _chatModel.OutputCostPerToken; ValidateChatMaxToken(inputTokens + outTokens); @@ -107,7 +107,7 @@ AsyncPolicyWrap combinedPolicy role = x.Role.Humanize(LetterCasing.LowerCase), content = x.Prompt, }), - temperature = _chatModel.ModelOption.Temperature, + temperature = _chatModel.Temperature, stream = true, }; @@ -165,8 +165,8 @@ AsyncPolicyWrap combinedPolicy // we have the usage in the last chunk and done state var inputTokens = completionStreamResponse.Usage?.InputTokens ?? 0; var outTokens = completionStreamResponse.Usage?.OutputTokens ?? 0; - var inputCostPerToken = _chatModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _chatModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _chatModel.InputCostPerToken; + var outputCostPerToken = _chatModel.OutputCostPerToken; ValidateChatMaxToken(inputTokens + outTokens); @@ -241,17 +241,14 @@ private async Task ValidateChatMaxInputToken(ChatCompletionRequest chatCompletio string.Concat(chatCompletionRequest.Items.Select(x => x.Prompt)) ); - if ( - _chatModel.ModelInformation.MaxInputTokens > 0 - && inputTokenCount > _chatModel.ModelInformation.MaxInputTokens - ) + if (_chatModel.MaxInputTokens > 0 && inputTokenCount > _chatModel.MaxInputTokens) { throw new AnthropicException( new AnthropicError { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"current chat 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_chatModel.ModelInformation.MaxInputTokens.FormatCommas()}", + $"current chat 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_chatModel.MaxInputTokens.FormatCommas()}", }, HttpStatusCode.BadRequest ); @@ -260,14 +257,14 @@ private async Task ValidateChatMaxInputToken(ChatCompletionRequest chatCompletio private void ValidateChatMaxToken(int maxTokenCount) { - if (_chatModel.ModelInformation.MaxTokens > 0 && maxTokenCount > _chatModel.ModelInformation.MaxTokens) + if (_chatModel.MaxTokens > 0 && maxTokenCount > _chatModel.MaxTokens) { throw new AnthropicException( new AnthropicError { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"current chat 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.ModelInformation.MaxTokens.FormatCommas()}.", + $"current chat 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.MaxTokens.FormatCommas()}.", }, HttpStatusCode.BadRequest ); diff --git a/src/Clients/AzureClient.cs b/src/Clients/AzureClient.cs index 313cf6d..94cbd39 100644 --- a/src/Clients/AzureClient.cs +++ b/src/Clients/AzureClient.cs @@ -53,18 +53,17 @@ AsyncPolicyWrap combinedPolicy role = x.Role.Humanize(LetterCasing.LowerCase), content = x.Prompt, }), - temperature = _chatModel.ModelOption.Temperature, + temperature = _chatModel.Temperature, }; var client = httpClientFactory.CreateClient("llm_chat_client"); var apiVersion = - Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatApiVersion) - ?? _chatModel.ModelOption.ApiVersion; + Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatApiVersion) ?? _chatModel.ApiVersion; var deploymentId = Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatDeploymentId) - ?? _chatModel.ModelOption.DeploymentId; + ?? _chatModel.DeploymentId; ArgumentException.ThrowIfNullOrEmpty(apiVersion); ArgumentException.ThrowIfNullOrEmpty(deploymentId); @@ -101,8 +100,8 @@ AsyncPolicyWrap combinedPolicy var inputTokens = completionResponse.Usage?.PromptTokens ?? 0; var outTokens = completionResponse.Usage?.CompletionTokens ?? 0; - var inputCostPerToken = _chatModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _chatModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _chatModel.InputCostPerToken; + var outputCostPerToken = _chatModel.OutputCostPerToken; ValidateChatMaxToken(inputTokens + outTokens); @@ -128,7 +127,7 @@ AsyncPolicyWrap combinedPolicy role = x.Role.Humanize(LetterCasing.LowerCase), content = x.Prompt, }), - temperature = _chatModel.ModelOption.Temperature, + temperature = _chatModel.Temperature, stream = true, // https://cookbook.openai.com/examples/how_to_stream_completions#4-how-to-get-token-usage-data-for-streamed-chat-completion-response stream_options = new { include_usage = true }, @@ -137,12 +136,11 @@ AsyncPolicyWrap combinedPolicy var client = httpClientFactory.CreateClient("llm_chat_client"); var apiVersion = - Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatApiVersion) - ?? _chatModel.ModelOption.ApiVersion; + Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatApiVersion) ?? _chatModel.ApiVersion; var deploymentId = Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatDeploymentId) - ?? _chatModel.ModelOption.DeploymentId; + ?? _chatModel.DeploymentId; ArgumentException.ThrowIfNullOrEmpty(apiVersion); ArgumentException.ThrowIfNullOrEmpty(deploymentId); @@ -218,8 +216,8 @@ AsyncPolicyWrap combinedPolicy // Capture the `usage` data from the final chunk and after done var inputTokens = completionStreamResponse.Usage?.PromptTokens ?? 0; var outTokens = completionStreamResponse.Usage?.CompletionTokens ?? 0; - var inputCostPerToken = _chatModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _chatModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _chatModel.InputCostPerToken; + var outputCostPerToken = _chatModel.OutputCostPerToken; ValidateChatMaxToken(inputTokens + outTokens); @@ -254,18 +252,18 @@ AsyncPolicyWrap combinedPolicy { input = inputs, model = _embeddingModel.Name.Trim(), - dimensions = _embeddingModel.ModelInformation.EmbeddingDimensions, + dimensions = _embeddingModel.EmbeddingDimensions, }; var client = httpClientFactory.CreateClient("llm_embeddings_client"); var apiVersion = Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsApiVersion) - ?? _embeddingModel.ModelOption.ApiVersion; + ?? _embeddingModel.ApiVersion; var deploymentId = Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsDeploymentId) - ?? _embeddingModel.ModelOption.DeploymentId; + ?? _embeddingModel.DeploymentId; ArgumentException.ThrowIfNullOrEmpty(apiVersion); ArgumentException.ThrowIfNullOrEmpty(deploymentId); @@ -298,8 +296,8 @@ AsyncPolicyWrap combinedPolicy var inputTokens = embeddingResponse.Usage?.PromptTokens ?? 0; var outTokens = embeddingResponse.Usage?.CompletionTokens ?? 0; - var inputCostPerToken = _embeddingModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _embeddingModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _embeddingModel.InputCostPerToken; + var outputCostPerToken = _embeddingModel.OutputCostPerToken; ValidateEmbeddingMaxToken(inputTokens + outTokens, path); @@ -344,17 +342,14 @@ private async Task ValidateChatMaxInputToken(ChatCompletionRequest chatCompletio string.Concat(chatCompletionRequest.Items.Select(x => x.Prompt), false) ); - if ( - _chatModel.ModelInformation.MaxInputTokens > 0 - && inputTokenCount > _chatModel.ModelInformation.MaxInputTokens - ) + if (_chatModel.MaxInputTokens > 0 && inputTokenCount > _chatModel.MaxInputTokens) { throw new OpenAIException( new OpenAIError { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"current chat 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_chatModel.ModelInformation.MaxInputTokens.FormatCommas()}.", + $"current chat 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_chatModel.MaxInputTokens.FormatCommas()}.", }, HttpStatusCode.BadRequest ); @@ -365,10 +360,8 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n { var inputTokenCount = await tokenizer.GetTokenCount(input); - if ( - _embeddingModel.ModelInformation.MaxInputTokens > 0 - && inputTokenCount > _embeddingModel.ModelInformation.MaxInputTokens - ) + ArgumentNullException.ThrowIfNull(_embeddingModel); + if (_embeddingModel.MaxInputTokens > 0 && inputTokenCount > _embeddingModel.MaxInputTokens) { var moreInfo = path is not null ? $"if file '{ @@ -381,7 +374,7 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"embedding {path} 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_embeddingModel.ModelInformation.MaxInputTokens.FormatCommas()}. {moreInfo}", + $"embedding {path} 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_embeddingModel.MaxInputTokens.FormatCommas()}. {moreInfo}", }, HttpStatusCode.BadRequest ); @@ -390,14 +383,14 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n private void ValidateChatMaxToken(int maxTokenCount) { - if (_chatModel.ModelInformation.MaxTokens > 0 && maxTokenCount > _chatModel.ModelInformation.MaxTokens) + if (_chatModel.MaxTokens > 0 && maxTokenCount > _chatModel.MaxTokens) { throw new OpenAIException( new OpenAIError { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"current chat 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.ModelInformation.MaxTokens.FormatCommas()}.", + $"current chat 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.MaxTokens.FormatCommas()}.", }, HttpStatusCode.BadRequest ); @@ -406,17 +399,15 @@ private void ValidateChatMaxToken(int maxTokenCount) private void ValidateEmbeddingMaxToken(int maxTokenCount, string? path) { - if ( - _embeddingModel.ModelInformation.MaxTokens > 0 - && maxTokenCount > _embeddingModel.ModelInformation.MaxTokens - ) + ArgumentNullException.ThrowIfNull(_embeddingModel); + if (_embeddingModel.MaxTokens > 0 && maxTokenCount > _embeddingModel.MaxTokens) { throw new OpenAIException( new OpenAIError { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"embedding {path} 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_embeddingModel.ModelInformation.MaxTokens.FormatCommas()}.", + $"embedding {path} 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_embeddingModel.MaxTokens.FormatCommas()}.", }, HttpStatusCode.BadRequest ); diff --git a/src/Clients/CacheModels.cs b/src/Clients/CacheModels.cs index 3540703..683ed39 100644 --- a/src/Clients/CacheModels.cs +++ b/src/Clients/CacheModels.cs @@ -51,10 +51,11 @@ IOptions llmOptions private bool TryGetModelWithFallback(string modelName, out Model model) { var parts = modelName.Split('/'); + if ( parts.Length == 2 && _models.TryGetValue(parts[1], out var fallbackModel) - && fallbackModel.ModelInformation.AIProvider.ToString() == parts[0] + && fallbackModel.AIProvider.ToString() == parts[0] ) { model = fallbackModel; @@ -104,66 +105,68 @@ private void InitCache() { Name = GetName(originalName), OriginalName = originalName, - ModelOption = new ModelOption - { - CodeAssistType = - overrideModelOption?.CodeAssistType - ?? _llmOptions.CodeAssistType - ?? predefinedModelOption?.CodeAssistType - ?? CodeAssistType.Embedding, - CodeDiffType = - overrideModelOption?.CodeDiffType - ?? _llmOptions.CodeDiffType - ?? predefinedModelOption?.CodeDiffType - ?? CodeDiffType.CodeBlockDiff, - Threshold = - overrideModelOption?.Threshold - ?? _llmOptions.Threshold - ?? predefinedModelOption?.Threshold - ?? 0.4m, - Temperature = - overrideModelOption?.Temperature - ?? _llmOptions.Temperature - ?? predefinedModelOption?.Temperature - ?? 0.2m, - ApiVersion = overrideModelOption?.ApiVersion ?? predefinedModelOption?.ApiVersion, - BaseAddress = overrideModelOption?.BaseAddress ?? predefinedModelOption?.BaseAddress, - DeploymentId = overrideModelOption?.DeploymentId ?? predefinedModelOption?.DeploymentId, - }, - ModelInformation = new ModelInformation - { - AIProvider = overrideModelInformation?.AIProvider ?? predefinedModelInformation.AIProvider, - ModelType = overrideModelInformation?.ModelType ?? predefinedModelInformation.ModelType, - MaxTokens = overrideModelInformation?.MaxTokens ?? predefinedModelInformation.MaxTokens, - MaxInputTokens = - overrideModelInformation?.MaxInputTokens ?? predefinedModelInformation.MaxInputTokens, - MaxOutputTokens = - overrideModelInformation?.MaxOutputTokens ?? predefinedModelInformation.MaxOutputTokens, - InputCostPerToken = - overrideModelInformation?.InputCostPerToken ?? predefinedModelInformation.InputCostPerToken, - OutputCostPerToken = - overrideModelInformation?.OutputCostPerToken ?? predefinedModelInformation.OutputCostPerToken, - OutputVectorSize = - overrideModelInformation?.OutputVectorSize ?? predefinedModelInformation.OutputVectorSize, - Enabled = overrideModelInformation?.Enabled ?? predefinedModelInformation.Enabled, - SupportsFunctionCalling = - overrideModelInformation?.SupportsFunctionCalling - ?? predefinedModelInformation.SupportsFunctionCalling, - SupportsParallelFunctionCalling = - overrideModelInformation?.SupportsParallelFunctionCalling - ?? predefinedModelInformation.SupportsParallelFunctionCalling, - SupportsVision = - overrideModelInformation?.SupportsVision ?? predefinedModelInformation.SupportsVision, - EmbeddingDimensions = - overrideModelInformation?.EmbeddingDimensions ?? predefinedModelInformation.EmbeddingDimensions, - SupportsAudioInput = - overrideModelInformation?.SupportsAudioInput ?? predefinedModelInformation.SupportsAudioInput, - SupportsAudioOutput = - overrideModelInformation?.SupportsAudioOutput ?? predefinedModelInformation.SupportsAudioOutput, - SupportsPromptCaching = - overrideModelInformation?.SupportsPromptCaching - ?? predefinedModelInformation.SupportsPromptCaching, - }, + + // Model Options + CodeAssistType = + overrideModelOption?.CodeAssistType + ?? _llmOptions.CodeAssistType + ?? predefinedModelOption?.CodeAssistType + ?? CodeAssistType.Embedding, + CodeDiffType = + overrideModelOption?.CodeDiffType + ?? _llmOptions.CodeDiffType + ?? predefinedModelOption?.CodeDiffType + ?? CodeDiffType.CodeBlockDiff, + Threshold = + overrideModelOption?.Threshold ?? _llmOptions.Threshold ?? predefinedModelOption?.Threshold ?? 0.4m, + Temperature = + overrideModelOption?.Temperature + ?? _llmOptions.Temperature + ?? predefinedModelOption?.Temperature + ?? 0.2m, + ApiVersion = overrideModelOption?.ApiVersion ?? predefinedModelOption?.ApiVersion, + BaseAddress = overrideModelOption?.BaseAddress ?? predefinedModelOption?.BaseAddress, + DeploymentId = overrideModelOption?.DeploymentId ?? predefinedModelOption?.DeploymentId, + + // Model Information + AIProvider = + overrideModelInformation?.AIProvider + ?? predefinedModelInformation.AIProvider + ?? throw new ArgumentException($"AI Provider not set for model {originalName}."), + ModelType = + overrideModelInformation?.ModelType + ?? predefinedModelInformation.ModelType + ?? throw new ArgumentException($"Model Type not set for model {originalName}."), + MaxTokens = + overrideModelInformation?.MaxTokens + ?? predefinedModelInformation.MaxTokens + ?? throw new ArgumentException($"Max tokens not set for model {originalName}."), + MaxInputTokens = + overrideModelInformation?.MaxInputTokens + ?? predefinedModelInformation.MaxInputTokens + ?? throw new ArgumentException($"Max input tokens not set for model {originalName}."), + MaxOutputTokens = + overrideModelInformation?.MaxOutputTokens + ?? predefinedModelInformation.MaxOutputTokens + ?? throw new ArgumentException($"Max output tokens not set for model {originalName}."), + InputCostPerToken = + overrideModelInformation?.InputCostPerToken ?? predefinedModelInformation.InputCostPerToken, + OutputCostPerToken = + overrideModelInformation?.OutputCostPerToken ?? predefinedModelInformation.OutputCostPerToken, + OutputVectorSize = + overrideModelInformation?.OutputVectorSize ?? predefinedModelInformation.OutputVectorSize, + Enabled = overrideModelInformation?.Enabled ?? predefinedModelInformation.Enabled, + SupportsFunctionCalling = + overrideModelInformation?.SupportsFunctionCalling + ?? predefinedModelInformation.SupportsFunctionCalling, + SupportsParallelFunctionCalling = + overrideModelInformation?.SupportsParallelFunctionCalling + ?? predefinedModelInformation.SupportsParallelFunctionCalling, + SupportsVision = overrideModelInformation?.SupportsVision ?? predefinedModelInformation.SupportsVision, + EmbeddingDimensions = + overrideModelInformation?.EmbeddingDimensions ?? predefinedModelInformation.EmbeddingDimensions, + SupportsPromptCaching = + overrideModelInformation?.SupportsPromptCaching ?? predefinedModelInformation.SupportsPromptCaching, }; _models[originalName] = model; diff --git a/src/Clients/Converters/AIProviderTypeConverter.cs b/src/Clients/Converters/AIProviderTypeConverter.cs deleted file mode 100644 index 6feeda6..0000000 --- a/src/Clients/Converters/AIProviderTypeConverter.cs +++ /dev/null @@ -1,53 +0,0 @@ -using System.Text.Json; -using System.Text.Json.Serialization; -using Clients.Models; -using Humanizer; - -namespace Clients.Converters; - -public class AIProviderTypeConverter : JsonConverter -{ - public override AIProvider Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - // Read the string value from JSON and convert it to snake_case - string aiProviderValue = reader.GetString() ?? string.Empty; - var snakeCaseValue = aiProviderValue.Underscore(); - - // Define snake_case mappings for each enum value - var openAI = AIProvider.Openai.ToString().Underscore(); - var ollama = AIProvider.Ollama.ToString().Underscore(); - var azure = AIProvider.Azure.ToString().Underscore(); - var anthropic = AIProvider.Anthropic.ToString().Underscore(); - - // Convert snake_case string to AIProvider enum value - return snakeCaseValue switch - { - var type when type == openAI => AIProvider.Openai, - var type when type == ollama => AIProvider.Ollama, - var type when type == azure => AIProvider.Azure, - var type when type == anthropic => AIProvider.Anthropic, - _ => throw new JsonException($"Unknown AIProvider type: {aiProviderValue}"), - }; - } - - public override void Write(Utf8JsonWriter writer, AIProvider value, JsonSerializerOptions options) - { - // Define snake_case strings for each enum value - string openAI = AIProvider.Openai.ToString().Underscore(); - string ollama = AIProvider.Ollama.ToString().Underscore(); - string azure = AIProvider.Azure.ToString().Underscore(); - string anthropic = AIProvider.Anthropic.ToString().Underscore(); - - // Convert AIProvider enum to corresponding snake_case string - string aiProviderString = value switch - { - AIProvider.Openai => openAI, - AIProvider.Ollama => ollama, - AIProvider.Azure => azure, - AIProvider.Anthropic => anthropic, - _ => throw new JsonException($"Unknown AIProvider type: {value}"), - }; - - writer.WriteStringValue(aiProviderString); - } -} diff --git a/src/Clients/Converters/CodeAssistTypeConverter.cs b/src/Clients/Converters/CodeAssistTypeConverter.cs deleted file mode 100644 index 139f38b..0000000 --- a/src/Clients/Converters/CodeAssistTypeConverter.cs +++ /dev/null @@ -1,43 +0,0 @@ -using System.Text.Json; -using System.Text.Json.Serialization; -using Clients.Models; -using Humanizer; - -namespace Clients.Converters; - - -// public class CodeAssistTypeConverter : JsonConverter -// { -// public override CodeAssistType Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) -// { -// string modelType = reader.GetString() ?? string.Empty; -// -// var snakeModelType = modelType.Underscore(); -// var embedding = CodeAssistType.Embedding.ToString().Underscore(); -// var summary = CodeAssistType.Summary.ToString().Underscore(); -// -// // Convert snake_case string to CodeDiff enum -// return snakeModelType switch -// { -// var type when type == embedding => CodeAssistType.Embedding, -// var type when type == summary => CodeAssistType.Summary, -// _ => throw new JsonException($"Unknown CodeAssistType: {modelType}"), -// }; -// } -// -// public override void Write(Utf8JsonWriter writer, CodeAssistType value, JsonSerializerOptions options) -// { -// string embedding = CodeAssistType.Embedding.ToString().Underscore(); -// string summary = CodeAssistType.Summary.ToString().Underscore(); -// -// // Convert CodeDiffType enum back to snake_case string -// string modelTypeString = value switch -// { -// CodeAssistType.Embedding => embedding, -// CodeAssistType.Summary => summary, -// _ => throw new JsonException($"Unknown CodeAssistType value: {value}"), -// }; -// -// writer.WriteStringValue(modelTypeString); -// } -// } diff --git a/src/Clients/Converters/CodeDiffTypeConverter.cs b/src/Clients/Converters/CodeDiffTypeConverter.cs deleted file mode 100644 index 6455263..0000000 --- a/src/Clients/Converters/CodeDiffTypeConverter.cs +++ /dev/null @@ -1,49 +0,0 @@ -using System.Text.Json; -using System.Text.Json.Serialization; -using Clients.Models; -using Humanizer; - -namespace Clients.Converters; - -// -// public class CodeDiffTypeConverter : JsonConverter -// { -// public CodeDiffTypeConverter() { } -// -// public override CodeDiffType Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) -// { -// string modelType = reader.GetString() ?? string.Empty; -// -// var snakeModelType = modelType.Underscore(); -// var codeBlock = CodeDiffType.CodeBlockDiff.ToString().Underscore(); -// var unifiedDiff = CodeDiffType.UnifiedDiff.ToString().Underscore(); -// var mergeConflict = CodeDiffType.MergeConflictDiff.ToString().Underscore(); -// -// // Convert snake_case string to CodeDiff enum -// return snakeModelType switch -// { -// var type when type == codeBlock => CodeDiffType.CodeBlockDiff, -// var type when type == unifiedDiff => CodeDiffType.UnifiedDiff, -// var type when type == mergeConflict => CodeDiffType.MergeConflictDiff, -// _ => throw new JsonException($"Unknown CodeDiffType: {modelType}"), -// }; -// } -// -// public override void Write(Utf8JsonWriter writer, CodeDiffType value, JsonSerializerOptions options) -// { -// string codeBlock = CodeDiffType.CodeBlockDiff.ToString().Underscore(); -// string unifiedDiff = CodeDiffType.UnifiedDiff.ToString().Underscore(); -// string mergeConflict = CodeDiffType.MergeConflictDiff.ToString().Underscore(); -// -// // Convert CodeDiffType enum back to snake_case string -// string modelTypeString = value switch -// { -// CodeDiffType.CodeBlockDiff => codeBlock, -// CodeDiffType.UnifiedDiff => unifiedDiff, -// CodeDiffType.MergeConflictDiff => mergeConflict, -// _ => throw new JsonException($"Unknown CodeDiffType value: {value}"), -// }; -// -// writer.WriteStringValue(modelTypeString); -// } -// } diff --git a/src/Clients/Converters/ModelTypeConverter.cs b/src/Clients/Converters/ModelTypeConverter.cs deleted file mode 100644 index 546dcc7..0000000 --- a/src/Clients/Converters/ModelTypeConverter.cs +++ /dev/null @@ -1,46 +0,0 @@ -using System.Globalization; -using System.Text.Json; -using System.Text.Json.Serialization; -using Clients.Models; -using Humanizer; - -namespace Clients.Converters; - -public class ModelTypeConverter : JsonConverter -{ - public override ModelType Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - // Read the string value from JSON and convert it to snake_case - string modelType = reader.GetString() ?? string.Empty; - var snakeCaseType = modelType.Underscore(); - - // Define snake_case mappings for each enum value - var chat = ModelType.Chat.ToString().Underscore(); - var embedding = ModelType.Embedding.ToString().Underscore(); - - // Convert snake_case string to ModelType enum - return snakeCaseType switch - { - var type when type == chat => ModelType.Chat, - var type when type == embedding => ModelType.Embedding, - _ => throw new JsonException($"Unknown model type: {modelType}"), - }; - } - - public override void Write(Utf8JsonWriter writer, ModelType value, JsonSerializerOptions options) - { - // Define snake_case strings for each enum value - string chat = ModelType.Chat.ToString().Underscore(); - string embedding = ModelType.Embedding.ToString().Underscore(); - - // Convert ModelType enum to corresponding snake_case string - string modelTypeString = value switch - { - ModelType.Chat => chat, - ModelType.Embedding => embedding, - _ => throw new JsonException($"Unknown model type: {value}"), - }; - - writer.WriteStringValue(modelTypeString); - } -} diff --git a/src/Clients/Converters/RoleTypeConverter.cs b/src/Clients/Converters/RoleTypeConverter.cs index b0a218d..6692094 100644 --- a/src/Clients/Converters/RoleTypeConverter.cs +++ b/src/Clients/Converters/RoleTypeConverter.cs @@ -5,6 +5,7 @@ namespace Clients.Converters; +// we use RoleTypeConverter when Role type particpate in a model and we use the model inside of serialization mechanism nor binding configuration because they are not serialization based public class RoleTypeConverter : JsonConverter { public override RoleType Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) diff --git a/src/Clients/LLMs/models_information_list.json b/src/Clients/LLMs/models_information_list.json index eafa543..a8cd5ef 100644 --- a/src/Clients/LLMs/models_information_list.json +++ b/src/Clients/LLMs/models_information_list.json @@ -14,6 +14,7 @@ "text-embedding-3-large": { "MaxTokens": 8191, "MaxInputTokens": 8191, + "MaxOutputTokens": 8191, "OutputVectorSize": 3072, "InputCostPerToken": 0.00000013, "OutputCostPerToken": 0.000000, @@ -24,6 +25,7 @@ "text-embedding-3-small": { "MaxTokens": 8191, "MaxInputTokens": 8191, + "MaxOutputTokens": 8191, "OutputVectorSize": 1536, "InputCostPerToken": 0.00000002, "OutputCostPerToken": 0.000000, @@ -33,8 +35,8 @@ }, "azure/gpt-4o": { "MaxTokens": 8192, - "MaxInputTokens": 128000, "MaxOutputTokens": 16384, + "MaxInputTokens": 128000, "InputCostPerToken": 0.000005, "OutputCostPerToken": 0.000015, "AIProvider": "Azure", @@ -45,6 +47,7 @@ "azure/text-embedding-3-large": { "MaxTokens": 8191, "MaxInputTokens": 8191, + "MaxOutputTokens": 8191, "InputCostPerToken": 0.00000013, "OutputCostPerToken": 0.000000, "AIProvider": "Azure", @@ -54,6 +57,7 @@ "azure/text-embedding-3-small": { "MaxTokens": 8191, "MaxInputTokens": 8191, + "MaxOutputTokens": 8191, "InputCostPerToken": 0.00000002, "OutputCostPerToken": 0.000000, "AIProvider": "Azure", @@ -159,6 +163,7 @@ "ollama/nomic-embed-text": { "MaxTokens": 8192, "MaxInputTokens": 8192, + "MaxOutputTokens": 8192, "InputCostPerToken": 0.0, "OutputCostPerToken": 0.0, "AIProvider": "Ollama", @@ -168,6 +173,7 @@ "ollama/mxbai-embed-large": { "MaxTokens": 8192, "MaxInputTokens": 8192, + "MaxOutputTokens": 8192, "InputCostPerToken": 0.0, "OutputCostPerToken": 0.0, "AIProvider": "Ollama", diff --git a/src/Clients/Models/Model.cs b/src/Clients/Models/Model.cs index f9a2329..5adcaa0 100644 --- a/src/Clients/Models/Model.cs +++ b/src/Clients/Models/Model.cs @@ -11,6 +11,26 @@ public class Model /// Model name with an AI provider type with '/' prefix /// public string OriginalName { get; set; } = default!; - public ModelInformation ModelInformation { get; set; } = default!; - public ModelOption ModelOption { get; set; } = default!; + public AIProvider AIProvider { get; set; } + public ModelType ModelType { get; set; } + public CodeDiffType CodeDiffType { get; set; } + public CodeAssistType CodeAssistType { get; set; } + public decimal Threshold { get; set; } + public decimal Temperature { get; set; } + public string? BaseAddress { get; set; } + public string? ApiVersion { get; set; } + public string? DeploymentId { get; set; } + public string? ApiKey { get; set; } + public int MaxTokens { get; set; } + public int MaxInputTokens { get; set; } + public int MaxOutputTokens { get; set; } + public decimal InputCostPerToken { get; set; } + public decimal OutputCostPerToken { get; set; } + public int? OutputVectorSize { get; set; } + public bool SupportsFunctionCalling { get; set; } + public bool SupportsParallelFunctionCalling { get; set; } + public bool SupportsVision { get; set; } + public int? EmbeddingDimensions { get; set; } + public bool SupportsPromptCaching { get; set; } + public bool Enabled { get; set; } = true; } diff --git a/src/Clients/Models/ModelInformation.cs b/src/Clients/Models/ModelInformation.cs index 349b9f9..fa57e91 100644 --- a/src/Clients/Models/ModelInformation.cs +++ b/src/Clients/Models/ModelInformation.cs @@ -6,11 +6,11 @@ namespace Clients.Models; public class ModelInformation { - public AIProvider AIProvider { get; set; } - public ModelType ModelType { get; set; } - public int MaxTokens { get; set; } - public int MaxInputTokens { get; set; } - public int MaxOutputTokens { get; set; } + public AIProvider? AIProvider { get; set; } + public ModelType? ModelType { get; set; } + public int? MaxTokens { get; set; } + public int? MaxInputTokens { get; set; } + public int? MaxOutputTokens { get; set; } public decimal InputCostPerToken { get; set; } public decimal OutputCostPerToken { get; set; } public int? OutputVectorSize { get; set; } @@ -18,9 +18,6 @@ public class ModelInformation public bool SupportsParallelFunctionCalling { get; set; } public bool SupportsVision { get; set; } public int? EmbeddingDimensions { get; set; } - public bool SupportsAudioInput { get; set; } - public bool SupportsAudioOutput { get; set; } - public bool SupportsPromptCaching { get; set; } public bool Enabled { get; set; } = true; } diff --git a/src/Clients/Models/ModelOption.cs b/src/Clients/Models/ModelOption.cs index e38786b..7867299 100644 --- a/src/Clients/Models/ModelOption.cs +++ b/src/Clients/Models/ModelOption.cs @@ -2,10 +2,10 @@ namespace Clients.Models; public class ModelOption { - public CodeDiffType CodeDiffType { get; set; } - public CodeAssistType CodeAssistType { get; set; } - public decimal Threshold { get; set; } - public decimal Temperature { get; set; } + public CodeDiffType? CodeDiffType { get; set; } + public CodeAssistType? CodeAssistType { get; set; } + public decimal? Threshold { get; set; } + public decimal? Temperature { get; set; } public string? BaseAddress { get; set; } public string? ApiVersion { get; set; } public string? DeploymentId { get; set; } diff --git a/src/Clients/OllamaClient.cs b/src/Clients/OllamaClient.cs index ae5c354..da71b55 100644 --- a/src/Clients/OllamaClient.cs +++ b/src/Clients/OllamaClient.cs @@ -53,7 +53,7 @@ AsyncPolicyWrap combinedPolicy role = x.Role.Humanize(LetterCasing.LowerCase), content = x.Prompt, }), - options = new { temperature = _chatModel.ModelOption.Temperature }, + options = new { temperature = _chatModel.Temperature }, keep_alive = "30m", stream = false, }; @@ -79,8 +79,8 @@ AsyncPolicyWrap combinedPolicy var inputTokens = completionResponse.PromptEvalCount; var outTokens = completionResponse.EvalCount; - var inputCostPerToken = _chatModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _chatModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _chatModel.InputCostPerToken; + var outputCostPerToken = _chatModel.OutputCostPerToken; ValidateChatMaxToken(inputTokens + outTokens); @@ -109,7 +109,7 @@ AsyncPolicyWrap combinedPolicy role = x.Role.Humanize(LetterCasing.LowerCase), content = x.Prompt, }), - options = new { temperature = _chatModel.ModelOption.Temperature }, + options = new { temperature = _chatModel.Temperature }, stream = true, keep_alive = "30m", }; @@ -158,8 +158,8 @@ AsyncPolicyWrap combinedPolicy // https://github.com/ollama/ollama/blob/main/docs/api.md#response-9 var inputTokens = completionStreamResponse.PromptEvalCount; var outTokens = completionStreamResponse.EvalCount; - var inputCostPerToken = _chatModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _chatModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _chatModel.InputCostPerToken; + var outputCostPerToken = _chatModel.OutputCostPerToken; ValidateChatMaxToken(inputTokens + outTokens); @@ -199,7 +199,7 @@ AsyncPolicyWrap combinedPolicy { input = inputs, model = _embeddingModel.Name, - options = new { temperature = _embeddingModel.ModelOption.Temperature }, + options = new { temperature = _embeddingModel.Temperature }, keep_alive = "30m", }; @@ -225,8 +225,8 @@ AsyncPolicyWrap combinedPolicy var inputTokens = embeddingResponse.PromptEvalCount; var outTokens = embeddingResponse.EvalCount; - var inputCostPerToken = _embeddingModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _embeddingModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _embeddingModel.InputCostPerToken; + var outputCostPerToken = _embeddingModel.OutputCostPerToken; ValidateEmbeddingMaxToken(inputTokens + outTokens, path); @@ -262,13 +262,10 @@ private async Task ValidateChatMaxInputToken(ChatCompletionRequest chatCompletio string.Concat(chatCompletionRequest.Items.Select(x => x.Prompt)) ); - if ( - _chatModel.ModelInformation.MaxInputTokens > 0 - && inputTokenCount > _chatModel.ModelInformation.MaxInputTokens - ) + if (_chatModel.MaxInputTokens > 0 && inputTokenCount > _chatModel.MaxInputTokens) { throw new OllamaException( - $"current chat 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_chatModel.ModelInformation.MaxInputTokens.FormatCommas()}.", + $"current chat 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_chatModel.MaxInputTokens.FormatCommas()}.", HttpStatusCode.BadRequest ); } @@ -278,10 +275,8 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n { var inputTokenCount = await tokenizer.GetTokenCount(input); - if ( - _embeddingModel.ModelInformation.MaxInputTokens > 0 - && inputTokenCount > _embeddingModel.ModelInformation.MaxInputTokens - ) + ArgumentNullException.ThrowIfNull(_embeddingModel); + if (_embeddingModel.MaxInputTokens > 0 && inputTokenCount > _embeddingModel.MaxInputTokens) { var moreInfo = path is not null ? $"if file '{ @@ -290,7 +285,7 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n : ""; throw new OllamaException( - $"embedding {path} 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_embeddingModel.ModelInformation.MaxInputTokens.FormatCommas()}. {moreInfo}", + $"embedding {path} 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_embeddingModel.MaxInputTokens.FormatCommas()}. {moreInfo}", HttpStatusCode.BadRequest ); } @@ -298,10 +293,10 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n private void ValidateChatMaxToken(int maxTokenCount) { - if (_chatModel.ModelInformation.MaxTokens > 0 && maxTokenCount > _chatModel.ModelInformation.MaxTokens) + if (_chatModel.MaxTokens > 0 && maxTokenCount > _chatModel.MaxTokens) { throw new OllamaException( - $"current chat 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.ModelInformation.MaxTokens.FormatCommas()}.", + $"current chat 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.MaxTokens.FormatCommas()}.", HttpStatusCode.BadRequest ); } @@ -309,13 +304,11 @@ private void ValidateChatMaxToken(int maxTokenCount) private void ValidateEmbeddingMaxToken(int maxTokenCount, string? path) { - if ( - _embeddingModel.ModelInformation.MaxTokens > 0 - && maxTokenCount > _embeddingModel.ModelInformation.MaxTokens - ) + ArgumentNullException.ThrowIfNull(_embeddingModel); + if (_embeddingModel.MaxTokens > 0 && maxTokenCount > _embeddingModel.MaxTokens) { throw new OllamaException( - $"embedding {path} 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_embeddingModel.ModelInformation.MaxTokens.FormatCommas()}.", + $"embedding {path} 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_embeddingModel.MaxTokens.FormatCommas()}.", HttpStatusCode.BadRequest ); } diff --git a/src/Clients/OpenAiClient.cs b/src/Clients/OpenAiClient.cs index 710bf1a..6313887 100644 --- a/src/Clients/OpenAiClient.cs +++ b/src/Clients/OpenAiClient.cs @@ -52,7 +52,7 @@ AsyncPolicyWrap combinedPolicy role = x.Role.Humanize(LetterCasing.LowerCase), content = x.Prompt, }), - temperature = _chatModel.ModelOption.Temperature, + temperature = _chatModel.Temperature, }; var client = httpClientFactory.CreateClient("llm_chat_client"); @@ -83,8 +83,8 @@ AsyncPolicyWrap combinedPolicy var inputTokens = completionResponse.Usage?.PromptTokens ?? 0; var outTokens = completionResponse.Usage?.CompletionTokens ?? 0; - var inputCostPerToken = _chatModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _chatModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _chatModel.InputCostPerToken; + var outputCostPerToken = _chatModel.OutputCostPerToken; ValidateChatMaxToken(inputTokens + outTokens); @@ -110,7 +110,7 @@ AsyncPolicyWrap combinedPolicy role = x.Role.Humanize(LetterCasing.LowerCase), content = x.Prompt, }), - temperature = _chatModel.ModelOption.Temperature, + temperature = _chatModel.Temperature, stream = true, // https://cookbook.openai.com/examples/how_to_stream_completions#4-how-to-get-token-usage-data-for-streamed-chat-completion-response stream_options = new { include_usage = true }, @@ -183,8 +183,8 @@ AsyncPolicyWrap combinedPolicy // Capture the `usage` data from the final chunk and after done var inputTokens = completionStreamResponse.Usage?.PromptTokens ?? 0; var outTokens = completionStreamResponse.Usage?.CompletionTokens ?? 0; - var inputCostPerToken = _chatModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _chatModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _chatModel.InputCostPerToken; + var outputCostPerToken = _chatModel.OutputCostPerToken; ValidateChatMaxToken(inputTokens + outTokens); @@ -220,7 +220,7 @@ AsyncPolicyWrap combinedPolicy { input = inputs, model = _embeddingModel.Name.Trim(), - dimensions = _embeddingModel.ModelInformation.EmbeddingDimensions, + dimensions = _embeddingModel.EmbeddingDimensions, }; var client = httpClientFactory.CreateClient("llm_embeddings_client"); @@ -247,8 +247,8 @@ AsyncPolicyWrap combinedPolicy var inputTokens = embeddingResponse.Usage?.PromptTokens ?? 0; var outTokens = embeddingResponse.Usage?.CompletionTokens ?? 0; - var inputCostPerToken = _embeddingModel.ModelInformation.InputCostPerToken; - var outputCostPerToken = _embeddingModel.ModelInformation.OutputCostPerToken; + var inputCostPerToken = _embeddingModel.InputCostPerToken; + var outputCostPerToken = _embeddingModel.OutputCostPerToken; ValidateEmbeddingMaxToken(inputTokens + outTokens, path); @@ -293,17 +293,14 @@ private async Task ValidateChatMaxInputToken(ChatCompletionRequest chatCompletio string.Concat(chatCompletionRequest.Items.Select(x => x.Prompt)) ); - if ( - _chatModel.ModelInformation.MaxInputTokens > 0 - && inputTokenCount > _chatModel.ModelInformation.MaxInputTokens - ) + if (_chatModel.MaxInputTokens > 0 && inputTokenCount > _chatModel.MaxInputTokens) { throw new OpenAIException( new OpenAIError { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"current chat 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_chatModel.ModelInformation.MaxInputTokens.FormatCommas()}", + $"current chat 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_chatModel.MaxInputTokens.FormatCommas()}", }, HttpStatusCode.BadRequest ); @@ -314,10 +311,8 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n { var inputTokenCount = await tokenizer.GetTokenCount(input); - if ( - _embeddingModel.ModelInformation.MaxInputTokens > 0 - && inputTokenCount > _embeddingModel.ModelInformation.MaxInputTokens - ) + ArgumentNullException.ThrowIfNull(_embeddingModel); + if (_embeddingModel.MaxInputTokens > 0 && inputTokenCount > _embeddingModel.MaxInputTokens) { var moreInfo = path is not null ? $"if file '{ @@ -330,7 +325,7 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"embedding {path} 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_embeddingModel.ModelInformation.MaxInputTokens.FormatCommas()}. {moreInfo}", + $"embedding {path} 'max_input_token' count: {inputTokenCount.FormatCommas()} is larger than configured 'max_input_token' count: {_embeddingModel.MaxInputTokens.FormatCommas()}. {moreInfo}", }, HttpStatusCode.BadRequest ); @@ -339,14 +334,14 @@ private async Task ValidateEmbeddingMaxInputToken(string input, string? path = n private void ValidateChatMaxToken(int maxTokenCount) { - if (_chatModel.ModelInformation.MaxTokens > 0 && maxTokenCount > _chatModel.ModelInformation.MaxTokens) + if (_chatModel.MaxTokens > 0 && maxTokenCount > _chatModel.MaxTokens) { throw new OpenAIException( new OpenAIError { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"current chat 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.ModelInformation.MaxTokens.FormatCommas()}.", + $"current chat 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.MaxTokens.FormatCommas()}.", }, HttpStatusCode.BadRequest ); @@ -355,17 +350,15 @@ private void ValidateChatMaxToken(int maxTokenCount) private void ValidateEmbeddingMaxToken(int maxTokenCount, string? path) { - if ( - _embeddingModel.ModelInformation.MaxTokens > 0 - && maxTokenCount > _embeddingModel.ModelInformation.MaxTokens - ) + ArgumentNullException.ThrowIfNull(_embeddingModel); + if (_embeddingModel.MaxTokens > 0 && maxTokenCount > _embeddingModel.MaxTokens) { throw new OpenAIException( new OpenAIError { StatusCode = (int)HttpStatusCode.BadRequest, Message = - $"embedding {path} 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.ModelInformation.MaxTokens.FormatCommas()}.", + $"embedding {path} 'max_token' count: {maxTokenCount.FormatCommas()} is larger than configured 'max_token' count: {_chatModel.MaxTokens.FormatCommas()}.", }, HttpStatusCode.BadRequest );