diff --git a/samples/ChatGpt.TelegramBotExample/Helpers.cs b/samples/ChatGpt.TelegramBotExample/Helpers.cs index b8f2d21..67c66c3 100644 --- a/samples/ChatGpt.TelegramBotExample/Helpers.cs +++ b/samples/ChatGpt.TelegramBotExample/Helpers.cs @@ -1,4 +1,6 @@ -namespace ChatGpt.TelegramBotExample; +using OpenAI.ChatGpt; + +namespace ChatGpt.TelegramBotExample; public static class Helpers { diff --git a/src/Directory.Build.props b/src/Directory.Build.props index aeac60a..9728ba1 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -1,6 +1,6 @@ - 4.0.0 + 4.0.0-alpha net6.0;net7.0;net8.0 enable enable diff --git a/src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs b/src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs index 1ae1d07..25d58b3 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs @@ -45,7 +45,8 @@ private static void ThrowUnkownProviderException(string provider) } /// - public Task GetChatCompletions(UserOrSystemMessage dialog, + public Task GetChatCompletions( + UserOrSystemMessage dialog, int maxTokens = ChatCompletionRequest.MaxTokensDefault, string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default, string? user = null, bool jsonMode = false, long? seed = null, @@ -57,7 +58,8 @@ public Task GetChatCompletions(UserOrSystemMessage dialog, } /// - public Task GetChatCompletions(IEnumerable messages, + public Task GetChatCompletions( + IEnumerable messages, int maxTokens = ChatCompletionRequest.MaxTokensDefault, string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default, string? user = null, bool jsonMode = false, long? seed = null, @@ -69,7 +71,8 @@ public Task GetChatCompletions(IEnumerable messag } /// - public Task GetChatCompletionsRaw(IEnumerable messages, + public Task GetChatCompletionsRaw( + IEnumerable messages, int maxTokens = ChatCompletionRequest.MaxTokensDefault, string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default, string? user = null, bool jsonMode = false, long? seed = null, @@ -81,7 +84,8 @@ public Task GetChatCompletionsRaw(IEnumerable - public IAsyncEnumerable StreamChatCompletions(IEnumerable messages, + public IAsyncEnumerable StreamChatCompletions( + IEnumerable messages, int maxTokens = ChatCompletionRequest.MaxTokensDefault, string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default, string? user = null, bool jsonMode = false, long? seed = null, @@ -93,27 +97,32 @@ public IAsyncEnumerable StreamChatCompletions(IEnumerable - public IAsyncEnumerable StreamChatCompletions(UserOrSystemMessage messages, + public IAsyncEnumerable StreamChatCompletions( + UserOrSystemMessage messages, int maxTokens = ChatCompletionRequest.MaxTokensDefault, string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default, string? user = null, bool jsonMode = false, long? seed = null, Action? requestModifier = null, CancellationToken cancellationToken = default) { - return _client.StreamChatCompletions(messages, maxTokens, model, temperature, user, jsonMode, seed, - requestModifier, cancellationToken); + return _client.StreamChatCompletions( + messages, maxTokens, model, temperature, user, jsonMode, seed, requestModifier, cancellationToken); } /// - public IAsyncEnumerable StreamChatCompletions(ChatCompletionRequest request, + public IAsyncEnumerable StreamChatCompletions( + ChatCompletionRequest request, CancellationToken cancellationToken = default) { return _client.StreamChatCompletions(request, cancellationToken); } /// - public IAsyncEnumerable StreamChatCompletionsRaw(ChatCompletionRequest request, + public IAsyncEnumerable StreamChatCompletionsRaw( + ChatCompletionRequest request, CancellationToken cancellationToken = default) { return _client.StreamChatCompletionsRaw(request, cancellationToken); } + + internal IAiClient GetInnerClient() => _client; } \ No newline at end of file diff --git a/src/OpenAI.ChatGpt.AspNetCore/AiClientStartupValidationBackgroundService.cs b/src/OpenAI.ChatGpt.AspNetCore/AiClientStartupValidationBackgroundService.cs index 0b0bfd8..8ed0f4c 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/AiClientStartupValidationBackgroundService.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/AiClientStartupValidationBackgroundService.cs @@ -4,11 +4,9 @@ namespace OpenAI.ChatGpt.AspNetCore; internal class AiClientStartupValidationBackgroundService : BackgroundService { - private readonly AiClientFromConfiguration _aiClient; - public AiClientStartupValidationBackgroundService(AiClientFromConfiguration aiClient) + public AiClientStartupValidationBackgroundService(AiClientFromConfiguration _) { - _aiClient = aiClient ?? throw new ArgumentNullException(nameof(aiClient)); } protected override Task ExecuteAsync(CancellationToken stoppingToken) => Task.CompletedTask; diff --git a/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs b/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs index 4c0d3d5..0eff924 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs @@ -1,5 +1,5 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; namespace OpenAI.ChatGpt.AspNetCore.Extensions; @@ -14,12 +14,16 @@ public static class ServiceCollectionExtensions public static IServiceCollection AddChatGptInMemoryIntegration( this IServiceCollection services, + IConfiguration configuration, bool injectInMemoryChatService = true, - string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault, + string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault, + string azureOpenAiCredentialsConfigSectionPath = AzureOpenAiCredentialsConfigSectionPathDefault, + string openRouterCredentialsConfigSectionPath = OpenRouterCredentialsConfigSectionPathDefault, bool validateAiClientProviderOnStart = true) { ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configuration); if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath)) { throw new ArgumentException("Value cannot be null or whitespace.", @@ -32,6 +36,17 @@ public static IServiceCollection AddChatGptInMemoryIntegration( nameof(completionsConfigSectionPath)); } + if (string.IsNullOrWhiteSpace(azureOpenAiCredentialsConfigSectionPath)) + { + throw new ArgumentException("Value cannot be null or whitespace.", + nameof(azureOpenAiCredentialsConfigSectionPath)); + } + if (string.IsNullOrWhiteSpace(openRouterCredentialsConfigSectionPath)) + { + throw new ArgumentException("Value cannot be null or whitespace.", + nameof(openRouterCredentialsConfigSectionPath)); + } + services.AddSingleton(); if (injectInMemoryChatService) { @@ -39,8 +54,11 @@ public static IServiceCollection AddChatGptInMemoryIntegration( } return services.AddChatGptIntegrationCore( - credentialsConfigSectionPath: credentialsConfigSectionPath, + configuration, completionsConfigSectionPath: completionsConfigSectionPath, + credentialsConfigSectionPath: credentialsConfigSectionPath, + azureOpenAiCredentialsConfigSectionPath, + openRouterCredentialsConfigSectionPath, validateAiClientProviderOnStart: validateAiClientProviderOnStart ); } @@ -69,14 +87,16 @@ private static ChatService CreateChatService(IServiceProvider provider) } public static IServiceCollection AddChatGptIntegrationCore(this IServiceCollection services, - string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault, + IConfiguration configuration, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault, + string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault, string azureOpenAiCredentialsConfigSectionPath = AzureOpenAiCredentialsConfigSectionPathDefault, string openRouterCredentialsConfigSectionPath = OpenRouterCredentialsConfigSectionPathDefault, ServiceLifetime gptFactoryLifetime = ServiceLifetime.Scoped, bool validateAiClientProviderOnStart = true) { ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configuration); if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath)) { throw new ArgumentException("Value cannot be null or whitespace.", @@ -89,12 +109,55 @@ public static IServiceCollection AddChatGptIntegrationCore(this IServiceCollecti nameof(completionsConfigSectionPath)); } + if (string.IsNullOrWhiteSpace(azureOpenAiCredentialsConfigSectionPath)) + { + throw new ArgumentException("Value cannot be null or whitespace.", + nameof(azureOpenAiCredentialsConfigSectionPath)); + } + if (string.IsNullOrWhiteSpace(openRouterCredentialsConfigSectionPath)) + { + throw new ArgumentException("Value cannot be null or whitespace.", + nameof(openRouterCredentialsConfigSectionPath)); + } + services.AddOptions() + .BindConfiguration(completionsConfigSectionPath) + .Configure(_ => { }) //make optional + .ValidateDataAnnotations() + .ValidateOnStart(); + + services.AddSingleton(); + services.Add(new ServiceDescriptor(typeof(ChatGPTFactory), typeof(ChatGPTFactory), gptFactoryLifetime)); + + services.AddAiClient(configuration, credentialsConfigSectionPath, azureOpenAiCredentialsConfigSectionPath, openRouterCredentialsConfigSectionPath, validateAiClientProviderOnStart); + + return services; + } + + internal static void AddAiClient( + this IServiceCollection services, + IConfiguration configuration, + string credentialsConfigSectionPath, + string azureOpenAiCredentialsConfigSectionPath, + string openRouterCredentialsConfigSectionPath, + bool validateAiClientProviderOnStart) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configuration); + if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(credentialsConfigSectionPath)); + if (string.IsNullOrWhiteSpace(azureOpenAiCredentialsConfigSectionPath)) + throw new ArgumentException("Value cannot be null or whitespace.", + nameof(azureOpenAiCredentialsConfigSectionPath)); + if (string.IsNullOrWhiteSpace(openRouterCredentialsConfigSectionPath)) + throw new ArgumentException("Value cannot be null or whitespace.", + nameof(openRouterCredentialsConfigSectionPath)); + services.AddOptions() .BindConfiguration(credentialsConfigSectionPath) .Configure(_ => { }) //make optional .ValidateDataAnnotations() - .ValidateOnStart(); + .ValidateOnStart(); services.AddOptions() .BindConfiguration(azureOpenAiCredentialsConfigSectionPath) .Configure(_ => { }) //make optional @@ -105,22 +168,15 @@ public static IServiceCollection AddChatGptIntegrationCore(this IServiceCollecti .Configure(_ => { }) //make optional .ValidateDataAnnotations() .ValidateOnStart(); - - services.AddOptions() - .BindConfiguration(completionsConfigSectionPath) - .Configure(_ => { }) //make optional - .ValidateDataAnnotations() - .ValidateOnStart(); - - services.AddSingleton(); - services.Add(new ServiceDescriptor(typeof(ChatGPTFactory), typeof(ChatGPTFactory), gptFactoryLifetime)); services.AddHttpClient(nameof(OpenAiClient)); services.AddHttpClient(nameof(AzureOpenAiClient)); services.AddHttpClient(nameof(OpenRouterClient)); services.AddSingleton(); + services.AddSingleton(); #pragma warning disable CS0618 // Type or member is obsolete + // will be removed in 5.0 services.AddSingleton(); #pragma warning restore CS0618 // Type or member is obsolete @@ -128,7 +184,5 @@ public static IServiceCollection AddChatGptIntegrationCore(this IServiceCollecti { services.AddHostedService(); } - - return services; } } \ No newline at end of file diff --git a/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs b/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs index 781c4cd..8efe406 100644 --- a/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs +++ b/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs @@ -1,4 +1,5 @@ using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using static OpenAI.ChatGpt.AspNetCore.Extensions.ServiceCollectionExtensions; @@ -9,10 +10,12 @@ public static class ServiceCollectionExtensions /// /// Adds the implementation using Entity Framework Core. /// - public static IServiceCollection AddChatGptEntityFrameworkIntegration(this IServiceCollection services, + public static IServiceCollection AddChatGptEntityFrameworkIntegration( + this IServiceCollection services, Action optionsAction, - string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault, + IConfiguration configuration, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault, + string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault, string azureOpenAiCredentialsConfigSectionPath = AzureOpenAiCredentialsConfigSectionPathDefault, string openRouterCredentialsConfigSectionPath = OpenRouterCredentialsConfigSectionPathDefault, ServiceLifetime serviceLifetime = ServiceLifetime.Scoped, @@ -20,6 +23,7 @@ public static IServiceCollection AddChatGptEntityFrameworkIntegration(this IServ { ArgumentNullException.ThrowIfNull(services); ArgumentNullException.ThrowIfNull(optionsAction); + ArgumentNullException.ThrowIfNull(configuration); if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath)) { throw new ArgumentException("Value cannot be null or whitespace.", @@ -48,8 +52,10 @@ public static IServiceCollection AddChatGptEntityFrameworkIntegration(this IServ throw new ArgumentOutOfRangeException(nameof(serviceLifetime), serviceLifetime, null); } - return services.AddChatGptIntegrationCore(credentialsConfigSectionPath: credentialsConfigSectionPath, + return services.AddChatGptIntegrationCore( + configuration, completionsConfigSectionPath: completionsConfigSectionPath, + credentialsConfigSectionPath: credentialsConfigSectionPath, azureOpenAiCredentialsConfigSectionPath: azureOpenAiCredentialsConfigSectionPath, openRouterCredentialsConfigSectionPath: openRouterCredentialsConfigSectionPath, serviceLifetime, diff --git a/src/OpenAI.ChatGpt/AzureOpenAiClient.cs b/src/OpenAI.ChatGpt/AzureOpenAiClient.cs index ecf5503..d9b25b1 100644 --- a/src/OpenAI.ChatGpt/AzureOpenAiClient.cs +++ b/src/OpenAI.ChatGpt/AzureOpenAiClient.cs @@ -48,12 +48,14 @@ internal static void SetupHttpClient(HttpClient httpClient, string endpointUrl, httpClient.DefaultRequestHeaders.Add("api-key", azureKey); } - public AzureOpenAiClient(HttpClient httpClient, string apiVersion) : base(httpClient) + public AzureOpenAiClient(HttpClient httpClient, string apiVersion) + : base(httpClient, validateAuthorizationHeader: false, validateBaseAddress: true) { _apiVersion = apiVersion ?? throw new ArgumentNullException(nameof(apiVersion)); } - public AzureOpenAiClient(HttpClient httpClient) : base(httpClient) + public AzureOpenAiClient(HttpClient httpClient) + : base(httpClient, validateAuthorizationHeader: false, validateBaseAddress: true) { _apiVersion = DefaultApiVersion; } diff --git a/src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs b/src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs index 7eca7e2..b9a7bdb 100644 --- a/src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs +++ b/src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs @@ -14,7 +14,6 @@ public class ChatGPTConfig }; private int? _maxTokens; - private string? _model; private float? _temperature; /// @@ -75,7 +74,7 @@ public int? MaxTokens { if (value is { } maxTokens) { - if (_model is { } model) + if (Model is { } model) { ChatCompletionModels.EnsureMaxTokensIsSupported(model, maxTokens); } @@ -93,11 +92,7 @@ public int? MaxTokens /// ID of the model to use. One of: /// Maps to: /// - public string? Model - { - get => _model; - set => _model = value; - } + public string? Model { get; set; } /// /// What sampling temperature to use, between 0 and 2. @@ -161,7 +156,7 @@ internal void ModifyRequest(ChatCompletionRequest request) (not null, null) => baseConfig, _ => new ChatGPTConfig() { - _model = config._model ?? baseConfig._model, + Model = config.Model ?? baseConfig.Model, _maxTokens = config._maxTokens ?? baseConfig._maxTokens, _temperature = config._temperature ?? baseConfig._temperature, PassUserIdToOpenAiRequests = config.PassUserIdToOpenAiRequests ?? diff --git a/src/OpenAI.ChatGpt/OpenAiClient.cs b/src/OpenAI.ChatGpt/OpenAiClient.cs index 7800363..e98513d 100644 --- a/src/OpenAI.ChatGpt/OpenAiClient.cs +++ b/src/OpenAI.ChatGpt/OpenAiClient.cs @@ -64,10 +64,18 @@ public OpenAiClient(string apiKey, string? host = DefaultHost) /// Indicates that OpenAI API key is not set in /// .. header. /// - public OpenAiClient(HttpClient httpClient) + public OpenAiClient(HttpClient httpClient) + : this(httpClient, true, true) + { + } + + internal OpenAiClient( + HttpClient httpClient, + bool validateAuthorizationHeader, + bool validateBaseAddress) { HttpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); - ValidateHttpClient(httpClient); + ValidateHttpClient(httpClient, validateAuthorizationHeader, validateBaseAddress); IsHttpClientInjected = true; } @@ -105,10 +113,13 @@ private static Uri ValidateHost(string? host) return uri; } - private static void ValidateHttpClient(HttpClient httpClient) + private static void ValidateHttpClient( + HttpClient httpClient, + bool validateAuthorizationHeader, + bool validateBaseAddress) { ArgumentNullException.ThrowIfNull(httpClient); - if (httpClient.DefaultRequestHeaders.Authorization is null) + if (validateAuthorizationHeader && httpClient.DefaultRequestHeaders.Authorization is null) { throw new ArgumentException( "HttpClient must have an Authorization header set. " + @@ -117,21 +128,24 @@ private static void ValidateHttpClient(HttpClient httpClient) ); } - if (httpClient.BaseAddress is null) - { - throw new ArgumentException( - "HttpClient must have a BaseAddress set." + - "It should be set to OpenAI's API endpoint.", - nameof(httpClient) - ); - } - if(!httpClient.BaseAddress.AbsoluteUri.EndsWith("/")) + if (validateBaseAddress) { - throw new ArgumentException( - "HttpClient's BaseAddress must end with a slash." + - " It should be set to OpenAI's API endpoint.", - nameof(httpClient) - ); + if (httpClient.BaseAddress is null) + { + throw new ArgumentException( + "HttpClient must have a BaseAddress set." + + "It should be set to AI service API endpoint.", + nameof(httpClient) + ); + } + + if (!httpClient.BaseAddress.AbsoluteUri.EndsWith("/")) + { + throw new ArgumentException( + "HttpClient's BaseAddress must end with a slash.", + nameof(httpClient) + ); + } } } diff --git a/src/OpenAI.ChatGpt/OpenRouterClient.cs b/src/OpenAI.ChatGpt/OpenRouterClient.cs index 2baa68b..58c4068 100644 --- a/src/OpenAI.ChatGpt/OpenRouterClient.cs +++ b/src/OpenAI.ChatGpt/OpenRouterClient.cs @@ -22,7 +22,8 @@ public OpenRouterClient(string apiKey, string? host = DefaultHost) { } - public OpenRouterClient(HttpClient httpClient) : base(httpClient) + public OpenRouterClient(HttpClient httpClient) + : base(httpClient, validateAuthorizationHeader: true, validateBaseAddress: true) { } } \ No newline at end of file diff --git a/tests/OpenAI.ChatGpt.IntegrationTests/ChatGptEntityFrameworkIntegrationTests.cs b/tests/OpenAI.ChatGpt.IntegrationTests/ChatGptEntityFrameworkIntegrationTests.cs deleted file mode 100644 index 0ae3854..0000000 --- a/tests/OpenAI.ChatGpt.IntegrationTests/ChatGptEntityFrameworkIntegrationTests.cs +++ /dev/null @@ -1,52 +0,0 @@ -using Microsoft.EntityFrameworkCore; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.DependencyInjection; -using static OpenAI.ChatGpt.AspNetCore.Extensions.ServiceCollectionExtensions; - -namespace OpenAI.ChatGpt.IntegrationTests; - -[Collection("OpenAiTestCollection")] //to prevent parallel execution -public class ChatGptEntityFrameworkIntegrationTests -{ - [Fact] - public async void AddChatGptEntityFrameworkIntegration_works() - { - // Arrange - var services = CreateServiceCollection(); - - // Act - services.AddChatGptEntityFrameworkIntegration( - options => options.UseSqlite(connectionString: "Data Source=chatgpt.db")); - - // Assert - await using var provider = services.BuildServiceProvider(); - - var storage = provider.GetRequiredService(); - storage.Should().BeOfType(); - - _ = provider.GetRequiredService(); - - var dbContext = provider.GetRequiredService(); - await dbContext.Database.EnsureDeletedAsync(); - } - - private static ServiceCollection CreateServiceCollection() - { - var services = new ServiceCollection(); - services.AddSingleton(CreateConfiguration()); - - return services; - - IConfiguration CreateConfiguration() - { - var builder = new ConfigurationBuilder() - .AddInMemoryCollection(new Dictionary() - { - { $"{OpenAiCredentialsConfigSectionPathDefault}:{nameof(OpenAICredentials.ApiKey)}", "test-api-key" }, - { ChatGPTConfigSectionPathDefault, ""}, - }); - - return builder.Build(); - } - } -} \ No newline at end of file diff --git a/tests/OpenAI.ChatGpt.UnitTests/ChatGptEntityFrameworkIntegrationTests.cs b/tests/OpenAI.ChatGpt.UnitTests/ChatGptEntityFrameworkIntegrationTests.cs new file mode 100644 index 0000000..97b0d68 --- /dev/null +++ b/tests/OpenAI.ChatGpt.UnitTests/ChatGptEntityFrameworkIntegrationTests.cs @@ -0,0 +1,68 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using static OpenAI.ChatGpt.AspNetCore.Extensions.ServiceCollectionExtensions; + +namespace OpenAI.ChatGpt.UnitTests; + +public class ChatGptEntityFrameworkIntegrationTests +{ + [Fact] + public async void AddChatGptEntityFrameworkIntegration_works() + { + // Arrange + var configuration = CreateConfiguration(); + var services = CreateServiceCollection(configuration); + + // Act + services.AddChatGptEntityFrameworkIntegration( + options => options.UseInMemoryDatabase("chats"), + configuration + ); + + // Assert + await using var provider = services.BuildServiceProvider(); + + var storage = provider.GetRequiredService(); + storage.Should().BeOfType(); + + _ = provider.GetRequiredService(); + var client = provider.GetRequiredService(); + AssertAiClientOfType(client); + + var dbContext = provider.GetRequiredService(); + await dbContext.Database.EnsureDeletedAsync(); + } + + private static ServiceCollection CreateServiceCollection(IConfiguration configuration) + { + var services = new ServiceCollection(); + services.AddSingleton(configuration); + + return services; + } + + private IConfiguration CreateConfiguration() + { + var builder = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary() + { + { $"{OpenAiCredentialsConfigSectionPathDefault}:{nameof(OpenAICredentials.ApiKey)}", "test-api-key" }, + { ChatGPTConfigSectionPathDefault, ""}, + }); + + return builder.Build(); + } + + private void AssertAiClientOfType(IAiClient client) + { + if (client is AiClientFromConfiguration aiClientEx) + { + aiClientEx.GetInnerClient().Should().BeOfType(); + } + else + { + client.Should().BeOfType(); + } + } +} \ No newline at end of file diff --git a/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs b/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs index 0d38b34..2d0eca4 100644 --- a/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs +++ b/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs @@ -12,33 +12,34 @@ public class ChatGptServicesIntegrationTests public void AddChatGptCoreIntegration_added_expected_services() { // Arrange - var services = CreateServiceCollection(); + var configuration = CreateConfiguration(); + var services = CreateServiceCollection(configuration); var initialServiceCount = services.Count; // Act - services.AddChatGptIntegrationCore(); + services.AddChatGptIntegrationCore(configuration); // Assert services.Count.Should().BeGreaterThan(initialServiceCount); using var provider = services.BuildServiceProvider(); - provider.GetRequiredService>(); - provider.GetRequiredService>(); + _ = provider.GetRequiredService>(); + _ = provider.GetRequiredService>(); - provider.GetRequiredService(); - provider.GetRequiredService(); - provider.GetRequiredService(); + _ = provider.GetRequiredService(); + _ = provider.GetRequiredService(); } [Fact] public async void AddChatGptInMemoryIntegration_works() { // Arrange - var services = CreateServiceCollection(); + var configuration = CreateConfiguration(); + var services = CreateServiceCollection(configuration); // Act - services.AddChatGptInMemoryIntegration(); + services.AddChatGptInMemoryIntegration(configuration); // Assert await using var provider = services.BuildServiceProvider(); @@ -55,10 +56,11 @@ public async void AddChatGptInMemoryIntegration_works() public async void AddChatGptInMemoryIntegration_with_Chat_injection_works() { // Arrange - var services = CreateServiceCollection(); + var configuration = CreateConfiguration(); + var services = CreateServiceCollection(configuration); // Act - services.AddChatGptInMemoryIntegration(injectInMemoryChatService: true); + services.AddChatGptInMemoryIntegration(configuration, injectInMemoryChatService: true); // Assert await using var provider = services.BuildServiceProvider(); @@ -73,11 +75,13 @@ public async void AddChatGptInMemoryIntegration_with_Chat_injection_works() public async void AddChatGptEntityFrameworkIntegration_works() { // Arrange - var services = CreateServiceCollection(); + var configuration = CreateConfiguration(); + var services = CreateServiceCollection(configuration); // Act services.AddChatGptEntityFrameworkIntegration( - options => options.UseInMemoryDatabase("ChatGptInMemoryDb")); + options => options.UseInMemoryDatabase("ChatGptInMemoryDb"), + configuration); // Assert await using var provider = services.BuildServiceProvider(); @@ -90,21 +94,21 @@ public async void AddChatGptEntityFrameworkIntegration_works() await factory.Create("test-user-id", ensureStorageCreated: true); } - private static ServiceCollection CreateServiceCollection() + private static ServiceCollection CreateServiceCollection(IConfiguration configuration) { var services = new ServiceCollection(); - services.AddSingleton(CreateConfiguration()); + services.AddSingleton(configuration); return services; - - IConfiguration CreateConfiguration() - { - var builder = new ConfigurationBuilder() - .AddInMemoryCollection(new Dictionary() - { - { $"{OpenAiCredentialsConfigSectionPathDefault}:{nameof(OpenAICredentials.ApiKey)}", "test-api-key" }, - { ChatGPTConfigSectionPathDefault, ""}, - }); - return builder.Build(); - } + } + + private IConfiguration CreateConfiguration() + { + var builder = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary() + { + { $"{OpenAiCredentialsConfigSectionPathDefault}:{nameof(OpenAICredentials.ApiKey)}", "test-api-key" }, + { ChatGPTConfigSectionPathDefault, ""}, + }); + return builder.Build(); } } \ No newline at end of file diff --git a/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/DifferentClientsIntegrationsTests.cs b/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/DifferentClientsIntegrationsTests.cs new file mode 100644 index 0000000..19e2193 --- /dev/null +++ b/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/DifferentClientsIntegrationsTests.cs @@ -0,0 +1,68 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using static OpenAI.ChatGpt.AspNetCore.Extensions.ServiceCollectionExtensions; + +namespace OpenAI.ChatGpt.UnitTests.DependencyInjectionTests; + +public class DifferentClientsIntegrationsTests +{ + [Fact] + public async void AddAzureOpenAiClient_succeeded() + { + // Arrange + var configuration = CreateConfiguration(); + var services = CreateServiceCollection(configuration); + + // Act + services.AddChatGptEntityFrameworkIntegration( + options => options.UseInMemoryDatabase("chats"), + configuration + ); + + // Assert + await using var provider = services.BuildServiceProvider(); + + var storage = provider.GetRequiredService(); + storage.Should().BeOfType(); + + _ = provider.GetRequiredService(); + var client = provider.GetRequiredService(); + AssertAiClientOfType(client); + } + + private static ServiceCollection CreateServiceCollection(IConfiguration configuration) + { + var services = new ServiceCollection(); + services.AddSingleton(configuration); + + return services; + } + + private IConfiguration CreateConfiguration() + { + var builder = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary() + { + { "AIProvider", "azure_openai"}, + { $"{AzureOpenAiCredentialsConfigSectionPathDefault}:{nameof(AzureOpenAICredentials.ApiKey)}", "test-api-key" }, + { $"{AzureOpenAiCredentialsConfigSectionPathDefault}:{nameof(AzureOpenAICredentials.ApiHost)}", "https://endopoint.openai.azure.com/" }, + { $"{AzureOpenAiCredentialsConfigSectionPathDefault}:{nameof(AzureOpenAICredentials.DeploymentName)}", "deployment" }, + { ChatGPTConfigSectionPathDefault, ""}, + }); + + return builder.Build(); + } + + private void AssertAiClientOfType(IAiClient client) + { + if (client is AiClientFromConfiguration aiClientEx) + { + aiClientEx.GetInnerClient().Should().BeOfType(); + } + else + { + client.Should().BeOfType(); + } + } +} \ No newline at end of file