diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index ae1daff1f00..64ff3656980 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.autoconfigure.openai; +import java.time.Duration; import java.util.List; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; @@ -33,8 +34,12 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.ssl.SslBundle; +import org.springframework.boot.web.client.ClientHttpRequestFactories; +import org.springframework.boot.web.client.ClientHttpRequestFactorySettings; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -60,8 +65,7 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { - var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), - chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); + var openAiApi = openAiApi(commonProperties, chatProperties, restClientBuilder, responseErrorHandler); if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); @@ -78,23 +82,22 @@ public OpenAiEmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties co OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { - var openAiApi = openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), - embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); + var openAiApi = openAiApi(commonProperties, embeddingProperties, restClientBuilder, responseErrorHandler); return new OpenAiEmbeddingClient(openAiApi, embeddingProperties.getMetadataMode(), embeddingProperties.getOptions(), retryTemplate); } - private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey, - RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + private OpenAiApi openAiApi(OpenAiConnectionProperties commonProperties, + T specificProperties, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { - String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; - Assert.hasText(resolvedBaseUrl, "OpenAI base URL must be set"); + OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties, + specificProperties); + RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder, + overridenCommonProperties); - String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; - Assert.hasText(resolvedApiKey, "OpenAI API key must be set"); - - return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler); + return new OpenAiApi(overridenCommonProperties.getBaseUrl(), overridenCommonProperties.getApiKey(), + overrideRestClientBuilder, responseErrorHandler); } @Bean @@ -105,16 +108,13 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp OpenAiImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { - String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey() - : commonProperties.getApiKey(); - - String baseUrl = StringUtils.hasText(imageProperties.getBaseUrl()) ? imageProperties.getBaseUrl() - : commonProperties.getBaseUrl(); - - Assert.hasText(apiKey, "OpenAI API key must be set"); - Assert.hasText(baseUrl, "OpenAI base URL must be set"); + OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties, + imageProperties); + RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder, + overridenCommonProperties); - var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler); + var openAiImageApi = new OpenAiImageApi(overridenCommonProperties.getBaseUrl(), + overridenCommonProperties.getApiKey(), overrideRestClientBuilder, responseErrorHandler); return new OpenAiImageClient(openAiImageApi, imageProperties.getOptions(), retryTemplate); } @@ -122,24 +122,18 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp @Bean @ConditionalOnMissingBean public OpenAiAudioTranscriptionClient openAiAudioTranscriptionClient(OpenAiConnectionProperties commonProperties, - OpenAiAudioTranscriptionProperties transcriptionProperties, RetryTemplate retryTemplate, - ResponseErrorHandler responseErrorHandler) { - - String apiKey = StringUtils.hasText(transcriptionProperties.getApiKey()) ? transcriptionProperties.getApiKey() - : commonProperties.getApiKey(); - - String baseUrl = StringUtils.hasText(transcriptionProperties.getBaseUrl()) - ? transcriptionProperties.getBaseUrl() : commonProperties.getBaseUrl(); - - Assert.hasText(apiKey, "OpenAI API key must be set"); - Assert.hasText(baseUrl, "OpenAI base URL must be set"); + OpenAiAudioTranscriptionProperties transcriptionProperties, RestClient.Builder restClientBuilder, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { - var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder(), responseErrorHandler); + OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties, + transcriptionProperties); + RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder, + overridenCommonProperties); - OpenAiAudioTranscriptionClient openAiChatClient = new OpenAiAudioTranscriptionClient(openAiAudioApi, - transcriptionProperties.getOptions(), retryTemplate); + var openAiAudioApi = new OpenAiAudioApi(overridenCommonProperties.getBaseUrl(), + overridenCommonProperties.getApiKey(), overrideRestClientBuilder, responseErrorHandler); - return openAiChatClient; + return new OpenAiAudioTranscriptionClient(openAiAudioApi, transcriptionProperties.getOptions(), retryTemplate); } @Bean @@ -150,4 +144,37 @@ public FunctionCallbackContext springAiFunctionManager(ApplicationContext contex return manager; } + private static OpenAiConnectionProperties checkAndOverrideProperties( + OpenAiConnectionProperties commonProperties, T specificProperties) { + + String apiKey = StringUtils.hasText(specificProperties.getApiKey()) ? specificProperties.getApiKey() + : commonProperties.getApiKey(); + + String baseUrl = StringUtils.hasText(specificProperties.getBaseUrl()) ? specificProperties.getBaseUrl() + : commonProperties.getBaseUrl(); + + Duration readTimeout = specificProperties.getReadTimeout() != null ? specificProperties.getReadTimeout() + : commonProperties.getReadTimeout(); + + Assert.hasText(apiKey, "OpenAI API key must be set"); + Assert.hasText(baseUrl, "OpenAI base URL must be set"); + Assert.notNull(readTimeout, "OpenAI base read timeout must be set"); + + OpenAiConnectionProperties overridenCommonProperties = new OpenAiConnectionProperties(); + overridenCommonProperties.setApiKey(apiKey); + overridenCommonProperties.setBaseUrl(baseUrl); + overridenCommonProperties.setReadTimeout(readTimeout); + + return overridenCommonProperties; + + } + + private static RestClient.Builder overrideRestClientBuilder(RestClient.Builder restClientBuilder, + OpenAiConnectionProperties overridenCommonProperties) { + ClientHttpRequestFactorySettings requestFactorySettings = new ClientHttpRequestFactorySettings( + Duration.ofHours(1l), overridenCommonProperties.getReadTimeout(), SslBundle.of(null)); + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(requestFactorySettings); + return restClientBuilder.clone().requestFactory(requestFactory); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java index b065deb53a5..64acd4e5ccf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java @@ -15,6 +15,8 @@ */ package org.springframework.ai.autoconfigure.openai; +import java.time.Duration; + import org.springframework.boot.context.properties.ConfigurationProperties; @ConfigurationProperties(OpenAiConnectionProperties.CONFIG_PREFIX) @@ -24,8 +26,11 @@ public class OpenAiConnectionProperties extends OpenAiParentProperties { public static final String DEFAULT_BASE_URL = "https://api.openai.com"; + public static final Duration DEFAULT_READ_TIMEOUT = Duration.ofMinutes(1); + public OpenAiConnectionProperties() { super.setBaseUrl(DEFAULT_BASE_URL); + super.setReadTimeout(DEFAULT_READ_TIMEOUT); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java index e4ba3a470da..fde567dd72f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java @@ -15,6 +15,8 @@ */ package org.springframework.ai.autoconfigure.openai; +import java.time.Duration; + /** * Internal parent properties for the OpenAI properties. * @@ -27,6 +29,8 @@ class OpenAiParentProperties { private String baseUrl; + private Duration readTimeout; + public String getApiKey() { return apiKey; } @@ -43,4 +47,12 @@ public void setBaseUrl(String baseUrl) { this.baseUrl = baseUrl; } + public Duration getReadTimeout() { + return readTimeout; + } + + public void setReadTimeout(Duration readTimeout) { + this.readTimeout = readTimeout; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index e14a6a562de..b467b8e8240 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.autoconfigure.openai; +import java.time.Duration; import org.junit.jupiter.api.Test; import org.skyscreamer.jsonassert.JSONAssert; import org.skyscreamer.jsonassert.JSONCompareMode; @@ -50,6 +51,7 @@ public void chatProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.chat.options.model=MODEL_XYZ", "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on @@ -61,9 +63,11 @@ public void chatProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(chatProperties.getApiKey()).isNull(); assertThat(chatProperties.getBaseUrl()).isNull(); + assertThat(chatProperties.getReadTimeout()).isNull(); assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); @@ -104,8 +108,10 @@ public void chatOverrideConnectionProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.chat.base-url=TEST_BASE_URL2", "spring.ai.openai.chat.api-key=456", + "spring.ai.openai.chat.read-timeout=5m", "spring.ai.openai.chat.options.model=MODEL_XYZ", "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on @@ -117,9 +123,11 @@ public void chatOverrideConnectionProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(chatProperties.getApiKey()).isEqualTo("456"); assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + assertThat(chatProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5)); assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); @@ -162,6 +170,7 @@ public void embeddingProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.embedding.options.model=MODEL_XYZ") // @formatter:on .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, @@ -172,9 +181,11 @@ public void embeddingProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(embeddingProperties.getApiKey()).isNull(); assertThat(embeddingProperties.getBaseUrl()).isNull(); + assertThat(embeddingProperties.getReadTimeout()).isNull(); assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); }); @@ -187,8 +198,10 @@ public void embeddingOverrideConnectionProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.embedding.base-url=TEST_BASE_URL2", "spring.ai.openai.embedding.api-key=456", + "spring.ai.openai.embedding.read-timeout=5m", "spring.ai.openai.embedding.options.model=MODEL_XYZ") // @formatter:on .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, @@ -199,9 +212,11 @@ public void embeddingOverrideConnectionProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(embeddingProperties.getApiKey()).isEqualTo("456"); assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + assertThat(embeddingProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5)); assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); }); @@ -213,6 +228,7 @@ public void imageProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.image.options.model=MODEL_XYZ", "spring.ai.openai.image.options.n=3") // @formatter:on @@ -224,9 +240,11 @@ public void imageProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(imageProperties.getApiKey()).isNull(); assertThat(imageProperties.getBaseUrl()).isNull(); + assertThat(imageProperties.getReadTimeout()).isNull(); assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(imageProperties.getOptions().getN()).isEqualTo(3); @@ -239,8 +257,10 @@ public void imageOverrideConnectionProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.image.base-url=TEST_BASE_URL2", "spring.ai.openai.image.api-key=456", + "spring.ai.openai.image.read-timeout=5m", "spring.ai.openai.image.options.model=MODEL_XYZ", "spring.ai.openai.image.options.n=3") // @formatter:on @@ -252,9 +272,11 @@ public void imageOverrideConnectionProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(imageProperties.getApiKey()).isEqualTo("456"); assertThat(imageProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + assertThat(imageProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5)); assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(imageProperties.getOptions().getN()).isEqualTo(3); @@ -268,6 +290,7 @@ public void chatOptionsTest() { // @formatter:off "spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.chat.options.model=MODEL_XYZ", "spring.ai.openai.chat.options.frequencyPenalty=-1.5", @@ -322,6 +345,7 @@ public void chatOptionsTest() { assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("text-embedding-ada-002"); @@ -395,6 +419,7 @@ public void embeddingOptionsTest() { // @formatter:off "spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.embedding.options.model=MODEL_XYZ", "spring.ai.openai.embedding.options.encodingFormat=MyEncodingFormat", @@ -409,6 +434,7 @@ public void embeddingOptionsTest() { assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(embeddingProperties.getOptions().getEncodingFormat()).isEqualTo("MyEncodingFormat"); @@ -422,6 +448,7 @@ public void imageOptionsTest() { // @formatter:off "spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.image.options.n=3", "spring.ai.openai.image.options.model=MODEL_XYZ", @@ -442,6 +469,7 @@ public void imageOptionsTest() { assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(imageProperties.getOptions().getN()).isEqualTo(3); assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");