diff --git a/core/pom.xml b/core/pom.xml index 78ab5bbc..3ecd1a93 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -94,6 +94,10 @@ org.slf4j slf4j-api + + io.vavr + vavr + com.sap.cloud.sdk.cloudplatform @@ -127,6 +131,11 @@ assertj-core test + + org.mockito + mockito-core + test + diff --git a/core/src/main/java/com/sap/ai/sdk/core/AiCoreDeployment.java b/core/src/main/java/com/sap/ai/sdk/core/AiCoreDeployment.java new file mode 100644 index 00000000..b0620591 --- /dev/null +++ b/core/src/main/java/com/sap/ai/sdk/core/AiCoreDeployment.java @@ -0,0 +1,103 @@ +package com.sap.ai.sdk.core; + +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; +import com.sap.cloud.sdk.cloudplatform.connectivity.DestinationProperty; +import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; +import java.util.function.Function; +import javax.annotation.Nonnull; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +/** Connectivity convenience methods for AI Core with deployment. */ +@RequiredArgsConstructor(access = AccessLevel.PROTECTED) +public class AiCoreDeployment implements AiCoreDestination { + + private static final String AI_RESOURCE_GROUP = "URL.headers.AI-Resource-Group"; + + // the delegating AI Core Service instance + @Nonnull private final AiCoreService service; + + // the deployment id handler to be used, based on instance + @Nonnull private final Function deploymentId; + + // the resource group, "default" if null + @Getter(AccessLevel.PROTECTED) + @Nonnull + private final String resourceGroup; + + /** + * Default constructor with "default" resource group. + * + * @param service The AI Core service. + * @param deploymentId The deployment id handler. + */ + public AiCoreDeployment( + @Nonnull final AiCoreService service, + @Nonnull final Function deploymentId) { + this(service, deploymentId, "default"); + } + + @Nonnull + @Override + public Destination destination() { + final var dest = service.baseDestinationHandler.apply(service); + final var builder = service.builderHandler.apply(service, dest); + destinationSetUrl(builder, dest); + destinationSetHeaders(builder); + return builder.build(); + } + + @Nonnull + @Override + public ApiClient client() { + final var destination = destination(); + return service.clientHandler.apply(service, destination); + } + + /** + * Update and set the URL for the destination. + * + * @param builder The destination builder. + * @param dest The original destination reference. + */ + protected void destinationSetUrl( + @Nonnull final DefaultHttpDestination.Builder builder, @Nonnull final Destination dest) { + String uri = dest.get(DestinationProperty.URI).get(); + if (!uri.endsWith("/")) { + uri = uri + "/"; + } + builder.uri(uri + "v2/inference/deployments/%s/".formatted(getDeploymentId())); + } + + /** + * Update and set the default request headers for the destination. + * + * @param builder The destination builder. + */ + protected void destinationSetHeaders(@Nonnull final DefaultHttpDestination.Builder builder) { + builder.property(AI_RESOURCE_GROUP, getResourceGroup()); + } + + /** + * Set the resource group. + * + * @param resourceGroup The resource group. + * @return A new instance of the AI Core service. + */ + @Nonnull + public AiCoreDeployment withResourceGroup(@Nonnull final String resourceGroup) { + return new AiCoreDeployment(service, deploymentId, resourceGroup); + } + + /** + * Get the deployment id. + * + * @return The deployment id. + */ + @Nonnull + protected String getDeploymentId() { + return deploymentId.apply(this); + } +} diff --git a/core/src/main/java/com/sap/ai/sdk/core/AiCoreDestination.java b/core/src/main/java/com/sap/ai/sdk/core/AiCoreDestination.java new file mode 100644 index 00000000..a61bd494 --- /dev/null +++ b/core/src/main/java/com/sap/ai/sdk/core/AiCoreDestination.java @@ -0,0 +1,28 @@ +package com.sap.ai.sdk.core; + +import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; +import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; +import javax.annotation.Nonnull; + +/** Container for an API client and destination. */ +@FunctionalInterface +public interface AiCoreDestination { + /** + * Get the destination. + * + * @return the destination + */ + @Nonnull + Destination destination(); + + /** + * Get the API client. + * + * @return the API client + */ + @Nonnull + default ApiClient client() { + final var destination = destination(); + return new ApiClient(destination); + } +} diff --git a/core/src/main/java/com/sap/ai/sdk/core/AiCoreService.java b/core/src/main/java/com/sap/ai/sdk/core/AiCoreService.java new file mode 100644 index 00000000..c063b3c7 --- /dev/null +++ b/core/src/main/java/com/sap/ai/sdk/core/AiCoreService.java @@ -0,0 +1,181 @@ +package com.sap.ai.sdk.core; + +import static com.sap.ai.sdk.core.DestinationResolver.AI_CLIENT_TYPE_KEY; +import static com.sap.ai.sdk.core.DestinationResolver.AI_CLIENT_TYPE_VALUE; + +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.PropertyAccessor; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.google.common.collect.Iterables; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; +import com.sap.cloud.sdk.cloudplatform.connectivity.DestinationProperty; +import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException; +import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException; +import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; +import java.util.NoSuchElementException; +import java.util.function.BiFunction; +import java.util.function.Function; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.client.BufferingClientHttpRequestFactory; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.web.client.RestTemplate; + +/** Connectivity convenience methods for AI Core. */ +@Slf4j +@RequiredArgsConstructor +public class AiCoreService implements AiCoreDestination { + + final Function baseDestinationHandler; + final BiFunction clientHandler; + final BiFunction builderHandler; + + private static final DeploymentCache DEPLOYMENT_CACHE = new DeploymentCache(); + + /** The default constructor. */ + public AiCoreService() { + this( + AiCoreService::getBaseDestination, + AiCoreService::getApiClient, + AiCoreService::getDestinationBuilder); + } + + @Nonnull + @Override + public ApiClient client() { + final var destination = destination(); + return clientHandler.apply(this, destination); + } + + @Nonnull + @Override + public Destination destination() { + final var dest = baseDestinationHandler.apply(this); + return builderHandler.apply(this, dest).build(); + } + + /** + * Set a specific base destination. + * + * @param destination The destination to be used for AI Core service calls. + * @return A new instance of the AI Core Service based on the provided destination. + */ + @Nonnull + public AiCoreService withDestination(@Nonnull final Destination destination) { + return new AiCoreService((service) -> destination, clientHandler, builderHandler); + } + + /** + * Set a specific deployment by id. + * + * @param deploymentId The deployment id to be used for AI Core service calls. + * @return A new instance of the AI Core Deployment. + */ + @Nonnull + public AiCoreDeployment forDeployment(@Nonnull final String deploymentId) { + return new AiCoreDeployment(this, obj -> deploymentId); + } + + /** + * Set a specific deployment by model name. If there are multiple deployments of the same model, + * the first one is returned. + * + * @param modelName The model name to be used for AI Core service calls. + * @return A new instance of the AI Core Deployment. + * @throws NoSuchElementException if no running deployment is found for the model. + */ + @Nonnull + public AiCoreDeployment forDeploymentByModel(@Nonnull final String modelName) + throws NoSuchElementException { + return new AiCoreDeployment( + this, + obj -> + DEPLOYMENT_CACHE.getDeploymentIdByModel( + this.client(), obj.getResourceGroup(), modelName)); + } + + /** + * Set a specific deployment by scenario id. If there are multiple deployments of the same model, + * the first one is returned. + * + * @param scenarioId The scenario id to be used for AI Core service calls. + * @return A new instance of the AI Core Deployment. + * @throws NoSuchElementException if no running deployment is found for the scenario. + */ + @Nonnull + public AiCoreDeployment forDeploymentByScenario(@Nonnull final String scenarioId) + throws NoSuchElementException { + return new AiCoreDeployment( + this, + obj -> + DEPLOYMENT_CACHE.getDeploymentIdByScenario( + this.client(), obj.getResourceGroup(), scenarioId)); + } + + /** + * Get a destination using the default service binding loading logic. + * + * @return The destination. + * @throws DestinationAccessException If the destination cannot be accessed. + * @throws DestinationNotFoundException If the destination cannot be found. + */ + @Nonnull + protected Destination getBaseDestination() + throws DestinationAccessException, DestinationNotFoundException { + final var serviceKey = System.getenv("AICORE_SERVICE_KEY"); + return DestinationResolver.getDestination(serviceKey); + } + + /** + * Get the destination builder with adjustments for AI Core. + * + * @param destination The destination. + * @return The destination builder. + */ + @Nonnull + protected DefaultHttpDestination.Builder getDestinationBuilder( + @Nonnull final Destination destination) { + final var builder = DefaultHttpDestination.fromDestination(destination); + String uri = destination.get(DestinationProperty.URI).get(); + if (!uri.endsWith("/")) { + uri = uri + "/"; + } + builder.uri(uri + "v2/").property(AI_CLIENT_TYPE_KEY, AI_CLIENT_TYPE_VALUE); + return builder; + } + + /** + * Get a destination using the default service binding loading logic. + * + * @return The destination. + * @throws DestinationAccessException If the destination cannot be accessed. + * @throws DestinationNotFoundException If the destination cannot be found. + */ + @SuppressWarnings("UnstableApiUsage") + @Nonnull + protected ApiClient getApiClient(@Nonnull final Destination destination) { + final var objectMapper = + new Jackson2ObjectMapperBuilder() + .modules(new JavaTimeModule()) + .visibility(PropertyAccessor.GETTER, JsonAutoDetect.Visibility.NONE) + .visibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.NONE) + .serializationInclusion(JsonInclude.Include.NON_NULL) // THIS STOPS `null` serialization + .build(); + + final var httpRequestFactory = new HttpComponentsClientHttpRequestFactory(); + httpRequestFactory.setHttpClient(ApacheHttpClient5Accessor.getHttpClient(destination)); + + final var rt = new RestTemplate(); + Iterables.filter(rt.getMessageConverters(), MappingJackson2HttpMessageConverter.class) + .forEach(converter -> converter.setObjectMapper(objectMapper)); + rt.setRequestFactory(new BufferingClientHttpRequestFactory(httpRequestFactory)); + + return new ApiClient(rt).setBasePath(destination.asHttp().getUri().toString()); + } +} diff --git a/core/src/main/java/com/sap/ai/sdk/core/Core.java b/core/src/main/java/com/sap/ai/sdk/core/Core.java deleted file mode 100644 index e1d83d32..00000000 --- a/core/src/main/java/com/sap/ai/sdk/core/Core.java +++ /dev/null @@ -1,300 +0,0 @@ -package com.sap.ai.sdk.core; - -import static com.sap.cloud.sdk.cloudplatform.connectivity.OnBehalfOf.TECHNICAL_USER_PROVIDER; - -import com.fasterxml.jackson.annotation.JsonAutoDetect; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.PropertyAccessor; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; -import com.sap.ai.sdk.core.client.DeploymentApi; -import com.sap.ai.sdk.core.client.model.AiDeployment; -import com.sap.cloud.environment.servicebinding.api.DefaultServiceBindingAccessor; -import com.sap.cloud.environment.servicebinding.api.DefaultServiceBindingBuilder; -import com.sap.cloud.environment.servicebinding.api.ServiceBindingAccessor; -import com.sap.cloud.environment.servicebinding.api.ServiceBindingMerger; -import com.sap.cloud.environment.servicebinding.api.ServiceIdentifier; -import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; -import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; -import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; -import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination; -import com.sap.cloud.sdk.cloudplatform.connectivity.ServiceBindingDestinationLoader; -import com.sap.cloud.sdk.cloudplatform.connectivity.ServiceBindingDestinationOptions; -import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.NoSuchElementException; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import lombok.extern.slf4j.Slf4j; -import org.springframework.http.client.BufferingClientHttpRequestFactory; -import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; -import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; -import org.springframework.web.client.RestTemplate; - -/** Connectivity convenience methods for AI Core. */ -@Slf4j -public class Core { - - /** - * Requires an AI Core service binding. - * - * @param resourceGroup the resource group. - * @return a generic Orchestration ApiClient. - */ - @Nonnull - public static ApiClient getOrchestrationClient(@Nonnull final String resourceGroup) { - return getClient( - getDestinationForDeployment(getOrchestrationDeployment(resourceGroup), resourceGroup)); - } - - /** - * Get the deployment id from the scenario id. If there are multiple deployments of the same - * scenario id, the first one is returned. - * - * @param resourceGroup the resource group. - * @return the deployment id - * @throws NoSuchElementException if no deployment is found for the scenario id. - */ - private static String getOrchestrationDeployment(@Nonnull final String resourceGroup) - throws NoSuchElementException { - final var deployments = new DeploymentApi(getClient(getDestination())).query(resourceGroup); - - return deployments.getResources().stream() - .filter(deployment -> "orchestration".equals(deployment.getScenarioId())) - .map(AiDeployment::getId) - .findFirst() - .orElseThrow( - () -> new NoSuchElementException("No deployment found with scenario id orchestration")); - } - - /** - * Requires an AI Core service binding OR a service key in the environment variable {@code - * AICORE_SERVICE_KEY}. - * - * @return a generic AI Core ApiClient. - */ - @Nonnull - public static ApiClient getClient() { - return getClient(getDestination()); - } - - /** - * Get a generic AI Core ApiClient for testing purposes. - * - * @param destination The destination to use. - * @return a generic AI Core ApiClient. - */ - @Nonnull - @SuppressWarnings("UnstableApiUsage") - public static ApiClient getClient(@Nonnull final Destination destination) { - final var objectMapper = - new Jackson2ObjectMapperBuilder() - .modules(new JavaTimeModule()) - .visibility(PropertyAccessor.GETTER, JsonAutoDetect.Visibility.NONE) - .visibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.NONE) - .serializationInclusion(JsonInclude.Include.NON_NULL) // THIS STOPS `null` serialization - .build(); - - final var httpRequestFactory = new HttpComponentsClientHttpRequestFactory(); - httpRequestFactory.setHttpClient(ApacheHttpClient5Accessor.getHttpClient(destination)); - - final var restTemplate = new RestTemplate(); - restTemplate.getMessageConverters().stream() - .filter(MappingJackson2HttpMessageConverter.class::isInstance) - .map(MappingJackson2HttpMessageConverter.class::cast) - .forEach(converter -> converter.setObjectMapper(objectMapper)); - restTemplate.setRequestFactory(new BufferingClientHttpRequestFactory(httpRequestFactory)); - - return new ApiClient(restTemplate).setBasePath(destination.asHttp().getUri().toString()); - } - - /** - * Get a destination pointing to the AI Core service. - * - *

Requires an AI Core service binding OR a service key in the environment variable {@code - * AICORE_SERVICE_KEY}. - * - * @return a destination pointing to the AI Core service. - */ - @Nonnull - public static Destination getDestination() { - final var serviceKey = System.getenv("AICORE_SERVICE_KEY"); - return getDestination(serviceKey); - } - - /** - * For testing only - * - *

Get a destination pointing to the AI Core service. - * - * @param serviceKey The service key in JSON format. - * @return a destination pointing to the AI Core service. - */ - static HttpDestination getDestination(@Nullable final String serviceKey) { - final var serviceKeyPresent = serviceKey != null; - final var aiCoreBindingPresent = - DefaultServiceBindingAccessor.getInstance().getServiceBindings().stream() - .anyMatch( - serviceBinding -> - ServiceIdentifier.AI_CORE.equals( - serviceBinding.getServiceIdentifier().orElse(null))); - - if (!aiCoreBindingPresent && serviceKeyPresent) { - addServiceBinding(serviceKey); - } - - // get a destination pointing to the AI Core service - final var opts = - ServiceBindingDestinationOptions.forService(ServiceIdentifier.AI_CORE) - .onBehalfOf(TECHNICAL_USER_PROVIDER) - .build(); - var destination = ServiceBindingDestinationLoader.defaultLoaderChain().getDestination(opts); - - destination = - DefaultHttpDestination.fromDestination(destination) - // append the /v2 path here, so we don't have to do it in every request when using the - // generated code this is actually necessary, because the generated code assumes this - // path to be present on the destination - .uri(destination.getUri().resolve("/v2")) - .header("AI-Client-Type", "AI SDK Java") - .build(); - return destination; - } - - /** - * Set the AI Core service key as the service binding. This is used for local testing. - * - * @param serviceKey The service key in JSON format. - * @throws AiCoreCredentialsInvalidException if the JSON service key cannot be parsed. - */ - private static void addServiceBinding(@Nonnull final String serviceKey) { - log.info( - """ - Found a service key in environment variable "AICORE_SERVICE_KEY". - Using a service key is recommended for local testing only. - Bind the AI Core service to the application for productive usage."""); - - var credentials = new HashMap(); - try { - credentials = new ObjectMapper().readValue(serviceKey, new TypeReference<>() {}); - } catch (JsonProcessingException e) { - throw new AiCoreCredentialsInvalidException( - "Error in parsing service key from the \"AICORE_SERVICE_KEY\" environment variable.", e); - } - - final var binding = - new DefaultServiceBindingBuilder() - .withServiceIdentifier(ServiceIdentifier.AI_CORE) - .withCredentials(credentials) - .build(); - final ServiceBindingAccessor accessor = DefaultServiceBindingAccessor.getInstance(); - final var newAccessor = - new ServiceBindingMerger( - List.of(accessor, () -> List.of(binding)), ServiceBindingMerger.KEEP_EVERYTHING); - DefaultServiceBindingAccessor.setInstance(newAccessor); - } - - /** Exception thrown when the JSON AI Core service key is invalid. */ - static class AiCoreCredentialsInvalidException extends RuntimeException { - public AiCoreCredentialsInvalidException( - @Nonnull final String message, @Nonnull final Throwable cause) { - super(message, cause); - } - } - - /** - * Get a destination pointing to the inference endpoint of a deployment on AI Core. Requires an - * AI Core service binding. - * - * @param deploymentId The deployment id. - * @param resourceGroup The resource group. - * @return a destination that can be used for inference calls. - */ - @Nonnull - public static Destination getDestinationForDeployment( - @Nonnull final String deploymentId, @Nonnull final String resourceGroup) { - final var destination = getDestination().asHttp(); - final DefaultHttpDestination.Builder builder = - DefaultHttpDestination.fromDestination(destination) - .uri( - destination - .getUri() - .resolve("/v2/inference/deployments/%s/".formatted(deploymentId))); - - builder.header("AI-Resource-Group", resourceGroup); - - return builder.build(); - } - - /** - * Get a destination pointing to the inference endpoint of a deployment on AI Core. Requires an - * AI Core service binding. - * - * @param modelName The name of the foundation model that is used by a deployment. - * @param resourceGroup The resource group. - * @return a destination that can be used for inference calls. - */ - @Nonnull - public static Destination getDestinationForModel( - @Nonnull final String modelName, @Nonnull final String resourceGroup) { - return getDestinationForDeployment( - getDeploymentForModel(modelName, resourceGroup), resourceGroup); - } - - /** - * Get the deployment id from the model name. If there are multiple deployments of the same model, - * the first one is returned. - * - * @param modelName the model name. - * @param resourceGroup the resource group. - * @return the deployment id - * @throws NoSuchElementException if no deployment is found for the model name. - */ - private static String getDeploymentForModel( - @Nonnull final String modelName, @Nonnull final String resourceGroup) - throws NoSuchElementException { - final var deployments = new DeploymentApi(getClient()).query(resourceGroup); - - return deployments.getResources().stream() - .filter(deployment -> isDeploymentOfModel(modelName, deployment)) - .map(AiDeployment::getId) - .findFirst() - .orElseThrow( - () -> new NoSuchElementException("No deployment found with model name " + modelName)); - } - - /** This exists because getBackendDetails() is broken */ - private static boolean isDeploymentOfModel( - @Nonnull final String modelName, @Nonnull final AiDeployment deployment) { - final var deploymentDetails = deployment.getDetails(); - // The AI Core specification doesn't mention that this is nullable, but it can be. - // Remove this check when the specification is fixed. - if (deploymentDetails == null) { - return false; - } - final var resources = deploymentDetails.getResources(); - if (resources == null) { - return false; - } - Object detailsObject = resources.getBackendDetails(); - // workaround for AIWDF-2124 - if (detailsObject == null) { - if (!resources.getCustomFieldNames().contains("backend_details")) { - return false; - } - detailsObject = resources.getCustomField("backend_details"); - } - - if (detailsObject instanceof Map details - && details.get("model") instanceof Map model - && model.get("name") instanceof String name) { - return modelName.equals(name); - } - return false; - } -} diff --git a/core/src/main/java/com/sap/ai/sdk/core/DeploymentCache.java b/core/src/main/java/com/sap/ai/sdk/core/DeploymentCache.java new file mode 100644 index 00000000..aa54ceaf --- /dev/null +++ b/core/src/main/java/com/sap/ai/sdk/core/DeploymentCache.java @@ -0,0 +1,169 @@ +package com.sap.ai.sdk.core; + +import com.sap.ai.sdk.core.client.DeploymentApi; +import com.sap.ai.sdk.core.client.model.AiDeployment; +import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; +import com.sap.cloud.sdk.services.openapi.core.OpenApiRequestException; +import java.util.HashSet; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; +import javax.annotation.Nonnull; +import lombok.extern.slf4j.Slf4j; + +/** + * Cache for deployment IDs. This class is used to get the deployment id for the orchestration + * scenario or for a model. + */ +@Slf4j +class DeploymentCache { + + /** Cache for deployment ids. The key is the model name and the value is the deployment id. */ + protected final Set CACHE = new HashSet<>(); + + /** + * Remove all entries from the cache then load all deployments into the cache. + * + *

Call this whenever a deployment is deleted. + * + * @param client the API client to query deployments. + * @param resourceGroup the resource group, usually "default". + */ + public void resetCache(@Nonnull final ApiClient client, @Nonnull final String resourceGroup) { + clearCache(); + loadCache(client, resourceGroup); + } + + /** + * Remove all entries from the cache. + * + *

Call {@link #resetCache} whenever a deployment is deleted. + */ + protected void clearCache() { + CACHE.clear(); + } + + /** + * Load all deployments into the cache + * + *

Call {@link #resetCache} whenever a deployment is deleted. + * + * @param client the API client to query deployments. + * @param resourceGroup the resource group, usually "default". + */ + protected void loadCache(@Nonnull final ApiClient client, @Nonnull final String resourceGroup) { + try { + final var deployments = new DeploymentApi(client).query(resourceGroup).getResources(); + CACHE.addAll(deployments); + } catch (final OpenApiRequestException e) { + log.error("Failed to load deployments into cache", e); + } + } + + /** + * Get the deployment id from the foundation model name. If there are multiple deployments of the + * same model, the first one is returned. + * + * @param client the API client to maybe reset the cache if the deployment is not found. + * @param resourceGroup the resource group, usually "default". + * @param modelName the name of the foundation model. + * @return the deployment id. + * @throws NoSuchElementException if no running deployment is found for the model. + */ + @Nonnull + public String getDeploymentIdByModel( + @Nonnull final ApiClient client, + @Nonnull final String resourceGroup, + @Nonnull final String modelName) + throws NoSuchElementException { + return getDeploymentIdByModel(modelName) + .orElseGet( + () -> { + resetCache(client, resourceGroup); + return getDeploymentIdByModel(modelName) + .orElseThrow( + () -> + new NoSuchElementException( + "No running deployment found for model: " + modelName)); + }); + } + + private Optional getDeploymentIdByModel(@Nonnull final String modelName) { + return CACHE.stream() + .filter(deployment -> isDeploymentOfModel(modelName, deployment)) + .findFirst() + .map(AiDeployment::getId); + } + + /** + * Get the deployment id from the scenario id. If there are multiple deployments of the * same + * model, the first one is returned. + * + * @param client the API client to maybe reset the cache if the deployment is not found. + * @param resourceGroup the resource group, usually "default". + * @param scenarioId the scenario id, can be "orchestration". + * @return the deployment id. + * @throws NoSuchElementException if no running deployment is found for the scenario. + */ + @Nonnull + public String getDeploymentIdByScenario( + @Nonnull final ApiClient client, + @Nonnull final String resourceGroup, + @Nonnull final String scenarioId) + throws NoSuchElementException { + return getDeploymentIdByScenario(scenarioId) + .orElseGet( + () -> { + resetCache(client, resourceGroup); + return getDeploymentIdByScenario(scenarioId) + .orElseThrow( + () -> + new NoSuchElementException( + "No running deployment found for scenario: " + scenarioId)); + }); + } + + private Optional getDeploymentIdByScenario(@Nonnull final String scenarioId) { + return CACHE.stream() + .filter(deployment -> scenarioId.equals(deployment.getScenarioId())) + .findFirst() + .map(AiDeployment::getId); + } + + /** + * This exists because getBackendDetails() is broken + * + * @param modelName The model name. + * @param deployment The deployment. + * @return true if the deployment is of the model. + */ + protected static boolean isDeploymentOfModel( + @Nonnull final String modelName, @Nonnull final AiDeployment deployment) { + final var deploymentDetails = deployment.getDetails(); + // The AI Core specification doesn't mention that this is nullable, but it can be. + // Remove this check when the specification is fixed. + if (deploymentDetails == null) { + return false; + } + final var resources = deploymentDetails.getResources(); + if (resources == null) { + return false; + } + Object detailsObject = resources.getBackendDetails(); + // workaround for AIWDF-2124 + if (detailsObject == null) { + if (!resources.getCustomFieldNames().contains("backend_details")) { + return false; + } + detailsObject = resources.getCustomField("backend_details"); + } + + if (detailsObject instanceof Map details + && details.get("model") instanceof Map model + && model.get("name") instanceof String name) { + return modelName.equals(name); + } + return false; + } +} diff --git a/core/src/main/java/com/sap/ai/sdk/core/DestinationResolver.java b/core/src/main/java/com/sap/ai/sdk/core/DestinationResolver.java new file mode 100644 index 00000000..3e2eebd8 --- /dev/null +++ b/core/src/main/java/com/sap/ai/sdk/core/DestinationResolver.java @@ -0,0 +1,116 @@ +package com.sap.ai.sdk.core; + +import static com.google.common.collect.Iterables.tryFind; +import static com.sap.cloud.sdk.cloudplatform.connectivity.OnBehalfOf.TECHNICAL_USER_PROVIDER; +import static com.sap.cloud.sdk.cloudplatform.connectivity.ServiceBindingDestinationOptions.forService; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.sap.cloud.environment.servicebinding.api.DefaultServiceBindingAccessor; +import com.sap.cloud.environment.servicebinding.api.DefaultServiceBindingBuilder; +import com.sap.cloud.environment.servicebinding.api.ServiceBindingAccessor; +import com.sap.cloud.environment.servicebinding.api.ServiceBindingMerger; +import com.sap.cloud.environment.servicebinding.api.ServiceIdentifier; +import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination; +import com.sap.cloud.sdk.cloudplatform.connectivity.ServiceBindingDestinationLoader; +import java.util.HashMap; +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; + +/** Utility class to resolve the destination pointing to the AI Core service. */ +@Slf4j +class DestinationResolver { + static final String AI_CLIENT_TYPE_KEY = "URL.headers.AI-Client-Type"; + static final String AI_CLIENT_TYPE_VALUE = "AI SDK Java"; + + @Getter(AccessLevel.PROTECTED) + @Nonnull + private static ServiceBindingAccessor accessor = DefaultServiceBindingAccessor.getInstance(); + + /** + * For testing only + * + *

Get a destination pointing to the AI Core service. + * + * @param serviceKey The service key in JSON format. + * @return a destination pointing to the AI Core service. + */ + @SuppressWarnings("UnstableApiUsage") + static HttpDestination getDestination(@Nullable final String serviceKey) { + final Predicate aiCore = Optional.of(ServiceIdentifier.AI_CORE)::equals; + final var serviceBindings = accessor.getServiceBindings(); + final var aiCoreBinding = tryFind(serviceBindings, b -> aiCore.test(b.getServiceIdentifier())); + + final var serviceKeyPresent = serviceKey != null; + if (!aiCoreBinding.isPresent() && serviceKeyPresent) { + addServiceBinding(serviceKey); + } + + // get a destination pointing to the AI Core service + final var opts = + (aiCoreBinding.isPresent() + ? forService(aiCoreBinding.get()) + : forService(ServiceIdentifier.AI_CORE)) + .onBehalfOf(TECHNICAL_USER_PROVIDER) + .build(); + + return ServiceBindingDestinationLoader.defaultLoaderChain().getDestination(opts); + } + + /** + * Set the AI Core service key as the service binding. This is used for local testing. + * + * @param serviceKey The service key in JSON format. + * @throws AiCoreCredentialsInvalidException if the JSON service key cannot be parsed. + */ + private static void addServiceBinding(@Nonnull final String serviceKey) { + log.info( + """ + Found a service key in environment variable "AICORE_SERVICE_KEY". + Using a service key is recommended for local testing only. + Bind the AI Core service to the application for productive usage."""); + + var credentials = new HashMap(); + try { + credentials = new ObjectMapper().readValue(serviceKey, new TypeReference<>() {}); + } catch (JsonProcessingException e) { + throw new AiCoreCredentialsInvalidException( + "Error in parsing service key from the \"AICORE_SERVICE_KEY\" environment variable.", e); + } + + final var binding = + new DefaultServiceBindingBuilder() + .withServiceIdentifier(ServiceIdentifier.AI_CORE) + .withCredentials(credentials) + .build(); + final var newAccessor = + new ServiceBindingMerger( + List.of(accessor, () -> List.of(binding)), ServiceBindingMerger.KEEP_EVERYTHING); + DefaultServiceBindingAccessor.setInstance(newAccessor); + } + + /** Exception thrown when the JSON AI Core service key is invalid. */ + static class AiCoreCredentialsInvalidException extends RuntimeException { + public AiCoreCredentialsInvalidException( + @Nonnull final String message, @Nonnull final Throwable cause) { + super(message, cause); + } + } + + /** + * For testing set the accessor to be used for service binding resolution. + * + * @param accessor The accessor to be used for service binding resolution. + */ + static void setAccessor(@Nullable final ServiceBindingAccessor accessor) { + DestinationResolver.accessor = + accessor == null ? DefaultServiceBindingAccessor.getInstance() : accessor; + } +} diff --git a/core/src/test/java/com/sap/ai/sdk/core/AiCoreServiceTest.java b/core/src/test/java/com/sap/ai/sdk/core/AiCoreServiceTest.java new file mode 100644 index 00000000..7a275818 --- /dev/null +++ b/core/src/test/java/com/sap/ai/sdk/core/AiCoreServiceTest.java @@ -0,0 +1,127 @@ +package com.sap.ai.sdk.core; + +import static com.sap.ai.sdk.core.DestinationResolver.AI_CLIENT_TYPE_KEY; +import static com.sap.ai.sdk.core.DestinationResolver.AI_CLIENT_TYPE_VALUE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.STRING; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.sap.cloud.environment.servicebinding.api.DefaultServiceBinding; +import com.sap.cloud.environment.servicebinding.api.ServiceBindingAccessor; +import com.sap.cloud.environment.servicebinding.api.ServiceIdentifier; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; +import com.sap.cloud.sdk.cloudplatform.connectivity.DestinationProperty; +import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import javax.annotation.Nonnull; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +public class AiCoreServiceTest { + + // setup + private static final Map URLS = Map.of("AI_API_URL", "https://srv"); + private static final Map CREDENTIALS = + Map.of("clientid", "id", "clientsecret", "pw", "url", "https://auth", "serviceurls", URLS); + private static final DefaultServiceBinding BINDING = + DefaultServiceBinding.builder() + .copy(Map.of()) + .withServiceIdentifier(ServiceIdentifier.AI_CORE) + .withCredentials(CREDENTIALS) + .build(); + + @AfterEach + void tearDown() { + DestinationResolver.setAccessor(null); + } + + @Test + void testLazyEvaluation() { + // setup + final var accessor = mock(ServiceBindingAccessor.class); + DestinationResolver.setAccessor(accessor); + + // execution without errors + new AiCoreService(); + + // verification + verify(accessor, never()).getServiceBindings(); + } + + @Test + void testDefaultCase() { + // setup + final var accessor = mock(ServiceBindingAccessor.class); + DestinationResolver.setAccessor(accessor); + doReturn(List.of(BINDING)).when(accessor).getServiceBindings(); + + // execution without errors + final var core = new AiCoreService(); + final var destination = core.destination(); + final var client = core.client(); + + // verification + assertThat(destination.get(DestinationProperty.URI)).contains("https://srv/v2/"); + assertThat(destination.get(DestinationProperty.AUTH_TYPE)).isEmpty(); + assertThat(destination.get(DestinationProperty.NAME)).singleElement(STRING).contains("aicore"); + assertThat(destination.get(AI_CLIENT_TYPE_KEY)).contains(AI_CLIENT_TYPE_VALUE); + assertThat(client.getBasePath()).isEqualTo("https://srv/v2/"); + verify(accessor, times(2)).getServiceBindings(); + } + + @Test + void testBaseDestination() { + // setup + DestinationResolver.setAccessor(Collections::emptyList); + + // execution without errors + final var customDestination = DefaultHttpDestination.builder("https://foo.bar").build(); + final var core = new AiCoreService().withDestination(customDestination); + final var destination = core.destination(); + final var client = core.client(); + + // verification + assertThat(destination.get(DestinationProperty.URI)).contains("https://foo.bar/v2/"); + assertThat(destination.get(DestinationProperty.AUTH_TYPE)).isEmpty(); + assertThat(destination.get(DestinationProperty.NAME)).isEmpty(); + assertThat(destination.get(AI_CLIENT_TYPE_KEY)).contains(AI_CLIENT_TYPE_VALUE); + assertThat(client.getBasePath()).isEqualTo("https://foo.bar/v2/"); + } + + @Test + void testCustomization() { + final var customService = + new AiCoreService() { + @Nonnull + @Override + protected Destination getBaseDestination() { + return DefaultHttpDestination.builder("https://ai").build(); + } + + @Nonnull + @Override + protected ApiClient getApiClient(@Nonnull Destination destination) { + return new ApiClient().setBasePath("https://fizz.buzz").setUserAgent("SAP"); + } + }; + + final var customServiceForDeployment = + customService.forDeployment("deployment").withResourceGroup("group"); + + final var client = customServiceForDeployment.client(); + assertThat(client.getBasePath()).isEqualTo("https://fizz.buzz"); + + final var destination = customServiceForDeployment.destination().asHttp(); + assertThat(destination.getUri()).hasToString("https://ai/v2/inference/deployments/deployment/"); + + final var resourceGroup = customServiceForDeployment.getResourceGroup(); + assertThat(resourceGroup).isEqualTo("group"); + } +} diff --git a/core/src/test/java/com/sap/ai/sdk/core/CacheTest.java b/core/src/test/java/com/sap/ai/sdk/core/CacheTest.java new file mode 100644 index 00000000..bd2850ac --- /dev/null +++ b/core/src/test/java/com/sap/ai/sdk/core/CacheTest.java @@ -0,0 +1,144 @@ +package com.sap.ai.sdk.core; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; + +import com.sap.ai.sdk.core.client.WireMockTestServer; +import org.apache.hc.core5.http.HttpStatus; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class CacheTest extends WireMockTestServer { + + private final DeploymentCache cacheUnderTest = new DeploymentCache(); + + @BeforeEach + void setupCache() { + wireMockServer.resetRequests(); + } + + private static void stubGPT4() { + wireMockServer.stubFor( + get(urlPathEqualTo("/v2/lm/deployments")) + .withHeader("AI-Resource-Group", equalTo("default")) + .willReturn( + aResponse() + .withStatus(HttpStatus.SC_OK) + .withHeader("content-type", "application/json") + .withBody( + """ + { + "count": 1, + "resources": [ + { + "configurationId": "7652a231-ba9b-4fcc-b473-2c355cb21b61", + "configurationName": "gpt-4-32k", + "createdAt": "2024-04-17T15:19:53Z", + "deploymentUrl": "https://api.ai.intprod-eu12.eu-central-1.aws.ml.hana.ondemand.com/v2/inference/deployments/d19b998f347341aa", + "details": { + "resources": { + "backend_details": { + "model": { + "name": "gpt-4-32k", + "version": "latest" + } + } + }, + "scaling": { + "backend_details": {} + } + }, + "id": "d19b998f347341aa", + "lastOperation": "CREATE", + "latestRunningConfigurationId": "7652a231-ba9b-4fcc-b473-2c355cb21b61", + "modifiedAt": "2024-05-07T13:05:45Z", + "scenarioId": "foundation-models", + "startTime": "2024-04-17T15:21:15Z", + "status": "RUNNING", + "submissionTime": "2024-04-17T15:20:11Z", + "targetStatus": "RUNNING" + } + ] + } + """))); + } + + private static void stubEmpty() { + wireMockServer.stubFor( + get(urlPathEqualTo("/v2/lm/deployments")) + .withHeader("AI-Resource-Group", equalTo("default")) + .willReturn( + aResponse() + .withStatus(HttpStatus.SC_OK) + .withHeader("content-type", "application/json") + .withBody( + """ + { + "count": 0, + "resources": [] + } + """))); + } + + /** + * The user creates a deployment. + * + *

The user uses the OpenAI client and specifies only the name "foo". + * + *

The user repeatedly uses the API in the same way + * + *

Simple case, the deployment should be served from cache as much as possible + */ + @Test + void newDeployment() { + stubGPT4(); + cacheUnderTest.loadCache(client, "default"); + + cacheUnderTest.getDeploymentIdByModel(client, "default", "gpt-4-32k"); + wireMockServer.verify(1, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); + + cacheUnderTest.getDeploymentIdByModel(client, "default", "gpt-4-32k"); + wireMockServer.verify(1, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); + } + + @Test + void clearCache() { + stubGPT4(); + cacheUnderTest.loadCache(client, "default"); + + cacheUnderTest.getDeploymentIdByModel(client, "default", "gpt-4-32k"); + wireMockServer.verify(1, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); + + cacheUnderTest.clearCache(); + + cacheUnderTest.getDeploymentIdByModel(client, "default", "gpt-4-32k"); + // the deployment is not in the cache anymore, so we need to query it again + wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); + } + + /** + * The user creates a deployment after starting with an empty cache. + * + *

The user uses the OpenAI client and specifies only the name "foo". + * + *

The user repeatedly uses the API in the same way + * + *

Simple case, the deployment should be served from cache as much as possible + */ + @Test + void newDeploymentAfterReset() { + stubEmpty(); + cacheUnderTest.loadCache(client, "default"); + stubGPT4(); + + cacheUnderTest.getDeploymentIdByModel(client, "default", "gpt-4-32k"); + // 1 reset empty and 1 cache miss + wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); + + cacheUnderTest.getDeploymentIdByModel(client, "default", "gpt-4-32k"); + wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); + } +} diff --git a/core/src/test/java/com/sap/ai/sdk/core/CoreTest.java b/core/src/test/java/com/sap/ai/sdk/core/DestinationResolverTest.java similarity index 77% rename from core/src/test/java/com/sap/ai/sdk/core/CoreTest.java rename to core/src/test/java/com/sap/ai/sdk/core/DestinationResolverTest.java index 021d05aa..c6c8391a 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/CoreTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/DestinationResolverTest.java @@ -1,6 +1,5 @@ package com.sap.ai.sdk.core; -import static com.sap.ai.sdk.core.Core.getDestination; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -8,12 +7,12 @@ import lombok.SneakyThrows; import org.junit.jupiter.api.Test; -public class CoreTest { +public class DestinationResolverTest { @Test @SneakyThrows void getDestinationWithoutEnvVarFailsLocally() { - assertThatThrownBy(() -> getDestination(null)) + assertThatThrownBy(() -> DestinationResolver.getDestination(null)) .isExactlyInstanceOf(DestinationAccessException.class) .hasMessage("Could not find any matching service bindings for service identifier 'aicore'"); } @@ -21,8 +20,8 @@ void getDestinationWithoutEnvVarFailsLocally() { @Test @SneakyThrows void getDestinationWithBrokenEnvVarFailsLocally() { - assertThatThrownBy(() -> getDestination("")) - .isExactlyInstanceOf(Core.AiCoreCredentialsInvalidException.class) + assertThatThrownBy(() -> DestinationResolver.getDestination("")) + .isExactlyInstanceOf(DestinationResolver.AiCoreCredentialsInvalidException.class) .hasMessage( "Error in parsing service key from the \"AICORE_SERVICE_KEY\" environment variable."); } @@ -44,7 +43,7 @@ void getDestinationWithEnvVarSucceedsLocally() { } } """; - var result = getDestination(AICORE_SERVICE_KEY).asHttp(); - assertThat(result.getUri()).hasToString("https://api.ai.core/v2"); + var result = DestinationResolver.getDestination(AICORE_SERVICE_KEY).asHttp(); + assertThat(result.getUri()).hasToString("https://api.ai.core"); } } diff --git a/core/src/test/java/com/sap/ai/sdk/core/client/ArtifactUnitTest.java b/core/src/test/java/com/sap/ai/sdk/core/client/ArtifactUnitTest.java index a0dc8c7c..dcfbf505 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/client/ArtifactUnitTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/client/ArtifactUnitTest.java @@ -7,7 +7,6 @@ import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; -import static com.sap.ai.sdk.core.Core.getClient; import static org.assertj.core.api.Assertions.assertThat; import com.sap.ai.sdk.core.client.model.AiArtifact; @@ -25,7 +24,7 @@ public class ArtifactUnitTest extends WireMockTestServer { @Test void getArtifacts() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/artifacts")) + get(urlPathEqualTo("/v2/lm/artifacts")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -50,7 +49,7 @@ void getArtifacts() { } """))); - final AiArtifactList artifactList = new ArtifactApi(getClient(destination)).query("default"); + final AiArtifactList artifactList = new ArtifactApi(client).query("default"); assertThat(artifactList).isNotNull(); assertThat(artifactList.getCount()).isEqualTo(1); @@ -71,7 +70,7 @@ void getArtifacts() { @Test void postArtifact() { wireMockServer.stubFor( - post(urlPathEqualTo("/lm/artifacts")) + post(urlPathEqualTo("/v2/lm/artifacts")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -95,7 +94,7 @@ void postArtifact() { .scenarioId("foundation-models") .description("dataset for aicore training"); final AiArtifactCreationResponse artifact = - new ArtifactApi(getClient(destination)).create("default", artifactPostData); + new ArtifactApi(client).create("default", artifactPostData); assertThat(artifact).isNotNull(); assertThat(artifact.getId()).isEqualTo("1a84bb38-4a84-4d12-a5aa-300ae7d33fb4"); @@ -103,7 +102,7 @@ void postArtifact() { assertThat(artifact.getUrl()).isEqualTo("ai://default/spam/data"); wireMockServer.verify( - postRequestedFor(urlPathEqualTo("/lm/artifacts")) + postRequestedFor(urlPathEqualTo("/v2/lm/artifacts")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody( equalToJson( @@ -122,7 +121,7 @@ void postArtifact() { @Test void getArtifactById() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/artifacts/777dea85-e9b1-4a7b-9bea-14769b977633")) + get(urlPathEqualTo("/v2/lm/artifacts/777dea85-e9b1-4a7b-9bea-14769b977633")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -142,8 +141,7 @@ void getArtifactById() { """))); final AiArtifact artifact = - new ArtifactApi(getClient(destination)) - .get("default", "777dea85-e9b1-4a7b-9bea-14769b977633"); + new ArtifactApi(client).get("default", "777dea85-e9b1-4a7b-9bea-14769b977633"); assertThat(artifact).isNotNull(); assertThat(artifact.getCreatedAt()).isEqualTo("2024-08-23T09:13:21Z"); @@ -160,7 +158,7 @@ void getArtifactById() { @Test void getArtifactCount() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/artifacts/$count")) + get(urlPathEqualTo("/v2/lm/artifacts/$count")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -170,7 +168,7 @@ void getArtifactCount() { 4 """))); - final int count = new ArtifactApi(getClient(destination)).count("default"); + final int count = new ArtifactApi(client).count("default"); assertThat(count).isEqualTo(4); } diff --git a/core/src/test/java/com/sap/ai/sdk/core/client/ConfigurationUnitTest.java b/core/src/test/java/com/sap/ai/sdk/core/client/ConfigurationUnitTest.java index a223b778..6f2a0a5c 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/client/ConfigurationUnitTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/client/ConfigurationUnitTest.java @@ -7,7 +7,6 @@ import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; -import static com.sap.ai.sdk.core.Core.getClient; import static org.assertj.core.api.Assertions.assertThat; import com.sap.ai.sdk.core.client.model.AiArtifactArgumentBinding; @@ -27,7 +26,7 @@ public class ConfigurationUnitTest extends WireMockTestServer { @Test void getConfigurations() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/configurations")) + get(urlPathEqualTo("/v2/lm/configurations")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -60,8 +59,7 @@ void getConfigurations() { } """))); - final AiConfigurationList configurationList = - new ConfigurationApi(getClient(destination)).query("default"); + final AiConfigurationList configurationList = new ConfigurationApi(client).query("default"); assertThat(configurationList).isNotNull(); assertThat(configurationList.getCount()).isEqualTo(1); @@ -82,7 +80,7 @@ void getConfigurations() { @Test void postConfiguration() { wireMockServer.stubFor( - post(urlPathEqualTo("/lm/configurations")) + post(urlPathEqualTo("/v2/lm/configurations")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -107,14 +105,14 @@ void postConfiguration() { .scenarioId("foundation-models") .addInputArtifactBindingsItem(inputArtifactBindingsItem); final AiConfigurationCreationResponse configuration = - new ConfigurationApi(getClient(destination)).create("default", configurationBaseData); + new ConfigurationApi(client).create("default", configurationBaseData); assertThat(configuration).isNotNull(); assertThat(configuration.getId()).isEqualTo("f88e7581-ade7-45c6-94e9-807889b523ec"); assertThat(configuration.getMessage()).isEqualTo("Configuration created"); wireMockServer.verify( - postRequestedFor(urlPathEqualTo("/lm/configurations")) + postRequestedFor(urlPathEqualTo("/v2/lm/configurations")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody( equalToJson( @@ -137,7 +135,7 @@ void postConfiguration() { @Test void getConfigurationCount() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/configurations/$count")) + get(urlPathEqualTo("/v2/lm/configurations/$count")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -147,7 +145,7 @@ void getConfigurationCount() { 3 """))); - final int configurationCount = new ConfigurationApi(getClient(destination)).count("default"); + final int configurationCount = new ConfigurationApi(client).count("default"); assertThat(configurationCount).isEqualTo(3); } @@ -155,7 +153,7 @@ void getConfigurationCount() { @Test void getConfigurationById() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/configurations/6ff6cb80-87db-45f0-b718-4e1d96e66332")) + get(urlPathEqualTo("/v2/lm/configurations/6ff6cb80-87db-45f0-b718-4e1d96e66332")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -185,8 +183,7 @@ void getConfigurationById() { """))); final AiConfiguration configuration = - new ConfigurationApi(getClient(destination)) - .get("default", "6ff6cb80-87db-45f0-b718-4e1d96e66332"); + new ConfigurationApi(client).get("default", "6ff6cb80-87db-45f0-b718-4e1d96e66332"); assertThat(configuration).isNotNull(); assertThat(configuration.getCreatedAt()).isEqualTo("2024-09-11T09:14:31Z"); diff --git a/core/src/test/java/com/sap/ai/sdk/core/client/DeploymentUnitTest.java b/core/src/test/java/com/sap/ai/sdk/core/client/DeploymentUnitTest.java index 5c617a65..a6804b23 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/client/DeploymentUnitTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/client/DeploymentUnitTest.java @@ -10,7 +10,6 @@ import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; -import static com.sap.ai.sdk.core.Core.getClient; import static org.assertj.core.api.Assertions.assertThat; import com.sap.ai.sdk.core.client.model.AiDeployment; @@ -43,7 +42,7 @@ public class DeploymentUnitTest extends WireMockTestServer { @Test void getDeployments() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/deployments")) + get(urlPathEqualTo("/v2/lm/deployments")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -86,8 +85,7 @@ void getDeployments() { } """))); - final AiDeploymentList deploymentList = - new DeploymentApi(getClient(destination)).query("default"); + final AiDeploymentList deploymentList = new DeploymentApi(client).query("default"); assertThat(deploymentList).isNotNull(); assertThat(deploymentList.getCount()).isEqualTo(1); @@ -121,7 +119,7 @@ void getDeployments() { @Test void postDeployment() { wireMockServer.stubFor( - post(urlPathEqualTo("/lm/deployments")) + post(urlPathEqualTo("/v2/lm/deployments")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -141,7 +139,7 @@ void postDeployment() { AiDeploymentCreationRequest.create() .configurationId("7652a231-ba9b-4fcc-b473-2c355cb21b61"); final AiDeploymentCreationResponse deployment = - new DeploymentApi(getClient(destination)).create("default", deploymentCreationRequest); + new DeploymentApi(client).create("default", deploymentCreationRequest); assertThat(deployment).isNotNull(); assertThat(deployment.getDeploymentUrl()).isEmpty(); @@ -150,7 +148,7 @@ void postDeployment() { assertThat(deployment.getStatus()).isEqualTo(AiExecutionStatus.UNKNOWN_DEFAULT_OPEN_API); wireMockServer.verify( - postRequestedFor(urlPathEqualTo("/lm/deployments")) + postRequestedFor(urlPathEqualTo("/v2/lm/deployments")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody( equalToJson( @@ -164,7 +162,7 @@ void postDeployment() { @Test void patchDeploymentStatus() { wireMockServer.stubFor( - patch(urlPathEqualTo("/lm/deployments/d19b998f347341aa")) + patch(urlPathEqualTo("/v2/lm/deployments/d19b998f347341aa")) .willReturn( aResponse() .withStatus(HttpStatus.SC_ACCEPTED) @@ -180,8 +178,7 @@ void patchDeploymentStatus() { final AiDeploymentModificationRequest configModification = AiDeploymentModificationRequest.create().targetStatus(AiDeploymentTargetStatus.STOPPED); final AiDeploymentModificationResponse deployment = - new DeploymentApi(getClient(destination)) - .modify("default", "d19b998f347341aa", configModification); + new DeploymentApi(client).modify("default", "d19b998f347341aa", configModification); assertThat(deployment).isNotNull(); assertThat(deployment.getId()).isEqualTo("d5b764fe55b3e87c"); @@ -189,7 +186,7 @@ void patchDeploymentStatus() { // verify that null fields are absent from the sent request wireMockServer.verify( - patchRequestedFor(urlPathEqualTo("/lm/deployments/d19b998f347341aa")) + patchRequestedFor(urlPathEqualTo("/v2/lm/deployments/d19b998f347341aa")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody( equalToJson( @@ -203,7 +200,7 @@ void patchDeploymentStatus() { @Test void deleteDeployment() { wireMockServer.stubFor( - delete(urlPathEqualTo("/lm/deployments/d5b764fe55b3e87c")) + delete(urlPathEqualTo("/v2/lm/deployments/d5b764fe55b3e87c")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -219,7 +216,7 @@ void deleteDeployment() { """))); final AiDeploymentDeletionResponse deployment = - new DeploymentApi(getClient(destination)).delete("default", "d5b764fe55b3e87c"); + new DeploymentApi(client).delete("default", "d5b764fe55b3e87c"); assertThat(deployment).isNotNull(); assertThat(deployment.getId()).isEqualTo("d5b764fe55b3e87c"); @@ -230,7 +227,7 @@ void deleteDeployment() { @Test void getDeploymentById() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/deployments/db1d64d9f06be467")) + get(urlPathEqualTo("/v2/lm/deployments/db1d64d9f06be467")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -264,7 +261,7 @@ void getDeploymentById() { """))); final AiDeploymentResponseWithDetails deployment = - new DeploymentApi(getClient(destination)).get("default", "db1d64d9f06be467"); + new DeploymentApi(client).get("default", "db1d64d9f06be467"); assertThat(deployment).isNotNull(); assertThat(deployment.getConfigurationId()).isEqualTo("dd80625e-ad86-426a-b1a7-1494c083428f"); @@ -294,7 +291,7 @@ void getDeploymentById() { @Test void patchDeploymentConfiguration() { wireMockServer.stubFor( - patch(urlPathEqualTo("/lm/deployments/d03050a2ab7055cc")) + patch(urlPathEqualTo("/v2/lm/deployments/d03050a2ab7055cc")) .willReturn( aResponse() .withStatus(HttpStatus.SC_ACCEPTED) @@ -311,8 +308,7 @@ void patchDeploymentConfiguration() { AiDeploymentModificationRequest.create() .configurationId("6ff6cb80-87db-45f0-b718-4e1d96e66332"); final AiDeploymentModificationResponse deployment = - new DeploymentApi(getClient(destination)) - .modify("default", "d03050a2ab7055cc", configModification); + new DeploymentApi(client).modify("default", "d03050a2ab7055cc", configModification); assertThat(deployment).isNotNull(); assertThat(deployment.getId()).isEqualTo("d03050a2ab7055cc"); @@ -320,7 +316,7 @@ void patchDeploymentConfiguration() { // verify that null fields are absent from the sent request wireMockServer.verify( - patchRequestedFor(urlPathEqualTo("/lm/deployments/d03050a2ab7055cc")) + patchRequestedFor(urlPathEqualTo("/v2/lm/deployments/d03050a2ab7055cc")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody( equalToJson( @@ -334,7 +330,7 @@ void patchDeploymentConfiguration() { @Test void getDeploymentCount() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/deployments/$count")) + get(urlPathEqualTo("/v2/lm/deployments/$count")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -344,7 +340,7 @@ void getDeploymentCount() { 1 """))); - final int count = new DeploymentApi(getClient(destination)).count("default"); + final int count = new DeploymentApi(client).count("default"); assertThat(count).isEqualTo(1); } @@ -352,7 +348,7 @@ void getDeploymentCount() { @Test void getDeploymentLogs() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/deployments/d19b998f347341aa/logs")) + get(urlPathEqualTo("/v2/lm/deployments/d19b998f347341aa/logs")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -377,7 +373,7 @@ void getDeploymentLogs() { // `Ai-Resource-Group` header needs explicit inclusion as kubesubmitV4DeploymentsGetLogs missed // to include the header on the request. final RTALogCommonResponse logs = - new DeploymentApi(getClient(destination).addDefaultHeader("Ai-Resource-Group", "default")) + new DeploymentApi(client.addDefaultHeader("Ai-Resource-Group", "default")) .getLogs("d19b998f347341aa"); assertThat(logs).isNotNull(); @@ -399,7 +395,7 @@ void getDeploymentLogs() { @Test void patchBulkDeployments() { wireMockServer.stubFor( - patch(urlPathEqualTo("/lm/deployments")) + patch(urlPathEqualTo("/v2/lm/deployments")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -427,7 +423,7 @@ void patchBulkDeployments() { AiDeploymentModificationRequestWithIdentifier.TargetStatusEnum .STOPPED))); final AiDeploymentBulkModificationResponse bulkModificationResponse = - new DeploymentApi(getClient(destination)).batchModify("default", bulkModificationRequest); + new DeploymentApi(client).batchModify("default", bulkModificationRequest); assertThat(bulkModificationResponse).isNotNull(); assertThat(bulkModificationResponse.getDeployments()).hasSize(1); @@ -440,7 +436,7 @@ void patchBulkDeployments() { .isEqualTo("Deployment modification scheduled"); wireMockServer.verify( - patchRequestedFor(urlPathEqualTo("/lm/deployments")) + patchRequestedFor(urlPathEqualTo("/v2/lm/deployments")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody( equalToJson( diff --git a/core/src/test/java/com/sap/ai/sdk/core/client/ExecutionUnitTest.java b/core/src/test/java/com/sap/ai/sdk/core/client/ExecutionUnitTest.java index 0a260b54..53e6c9ad 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/client/ExecutionUnitTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/client/ExecutionUnitTest.java @@ -10,7 +10,6 @@ import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; -import static com.sap.ai.sdk.core.Core.getClient; import static org.assertj.core.api.Assertions.assertThat; import com.sap.ai.sdk.core.client.model.AiArtifact; @@ -41,7 +40,7 @@ public class ExecutionUnitTest extends WireMockTestServer { @Test void getExecutions() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/executions")) + get(urlPathEqualTo("/v2/lm/executions")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -83,7 +82,7 @@ void getExecutions() { } """))); - final AiExecutionList executionList = new ExecutionApi(getClient(destination)).query("default"); + final AiExecutionList executionList = new ExecutionApi(client).query("default"); assertThat(executionList).isNotNull(); assertThat(executionList.getCount()).isEqualTo(1); @@ -120,7 +119,7 @@ void getExecutions() { @Test void postExecution() { wireMockServer.stubFor( - post(urlPathEqualTo("/lm/executions")) + post(urlPathEqualTo("/v2/lm/executions")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -138,7 +137,7 @@ void postExecution() { final AiEnactmentCreationRequest enactmentCreationRequest = AiEnactmentCreationRequest.create().configurationId("e0a9eb2e-9ea1-43bf-aff5-7660db166676"); final AiExecutionCreationResponse execution = - new ExecutionApi(getClient(destination)).create("default", enactmentCreationRequest); + new ExecutionApi(client).create("default", enactmentCreationRequest); assertThat(execution).isNotNull(); assertThat(execution.getId()).isEqualTo("eab289226fe981da"); @@ -146,7 +145,7 @@ void postExecution() { assertThat(execution.getCustomField("url")).isEqualTo("ai://default/eab289226fe981da"); wireMockServer.verify( - postRequestedFor(urlPathEqualTo("/lm/executions")) + postRequestedFor(urlPathEqualTo("/v2/lm/executions")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody( equalToJson( @@ -160,7 +159,7 @@ void postExecution() { @Test void getExecutionById() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/executions/e529e8bd58740bc9")) + get(urlPathEqualTo("/v2/lm/executions/e529e8bd58740bc9")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -198,7 +197,7 @@ void getExecutionById() { """))); final AiExecutionResponseWithDetails execution = - new ExecutionApi(getClient(destination)).get("default", "e529e8bd58740bc9"); + new ExecutionApi(client).get("default", "e529e8bd58740bc9"); assertThat(execution).isNotNull(); assertThat(execution.getCompletionTime()).isEqualTo("2024-09-09T19:10:58Z"); @@ -233,7 +232,7 @@ void getExecutionById() { @Test void deleteExecution() { wireMockServer.stubFor( - delete(urlPathEqualTo("/lm/executions/e529e8bd58740bc9")) + delete(urlPathEqualTo("/v2/lm/executions/e529e8bd58740bc9")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -249,7 +248,7 @@ void deleteExecution() { """))); final AiExecutionDeletionResponse execution = - new ExecutionApi(getClient(destination)).delete("default", "e529e8bd58740bc9"); + new ExecutionApi(client).delete("default", "e529e8bd58740bc9"); assertThat(execution).isNotNull(); assertThat(execution.getId()).isEqualTo("e529e8bd58740bc9"); @@ -261,7 +260,7 @@ void deleteExecution() { @Test void patchExecution() { wireMockServer.stubFor( - patch(urlPathEqualTo("/lm/executions/eec3c6ea18bac6da")) + patch(urlPathEqualTo("/v2/lm/executions/eec3c6ea18bac6da")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -279,7 +278,7 @@ void patchExecution() { AiExecutionModificationRequest.create() .targetStatus(AiExecutionModificationRequest.TargetStatusEnum.STOPPED); final AiExecutionModificationResponse aiExecutionModificationResponse = - new ExecutionApi(getClient(destination)) + new ExecutionApi(client) .modify("default", "eec3c6ea18bac6da", aiExecutionModificationRequest); assertThat(aiExecutionModificationResponse).isNotNull(); @@ -288,7 +287,7 @@ void patchExecution() { .isEqualTo("Execution modification scheduled"); wireMockServer.verify( - patchRequestedFor(urlPathEqualTo("/lm/executions/eec3c6ea18bac6da")) + patchRequestedFor(urlPathEqualTo("/v2/lm/executions/eec3c6ea18bac6da")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody(equalToJson("{\"targetStatus\":\"STOPPED\"}"))); } @@ -296,7 +295,7 @@ void patchExecution() { @Test void getExecutionCount() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/executions/$count")) + get(urlPathEqualTo("/v2/lm/executions/$count")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -306,7 +305,7 @@ void getExecutionCount() { 1 """))); - final int count = new ExecutionApi(getClient(destination)).count("default"); + final int count = new ExecutionApi(client).count("default"); assertThat(count).isEqualTo(1); } @@ -314,7 +313,7 @@ void getExecutionCount() { @Test void getExecutionLogs() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/executions/ee467bea5af28adb/logs")) + get(urlPathEqualTo("/v2/lm/executions/ee467bea5af28adb/logs")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -337,7 +336,7 @@ void getExecutionLogs() { """))); final RTALogCommonResponse logResponse = - new ExecutionApi(getClient(destination).addDefaultHeader("AI-Resource-Group", "default")) + new ExecutionApi(client.addDefaultHeader("AI-Resource-Group", "default")) .getLogs("ee467bea5af28adb"); assertThat(logResponse).isNotNull(); @@ -358,7 +357,7 @@ void getExecutionLogs() { @Test void patchBulkExecutions() { wireMockServer.stubFor( - patch(urlPathEqualTo("/lm/executions")) + patch(urlPathEqualTo("/v2/lm/executions")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -386,8 +385,7 @@ void patchBulkExecutions() { AiExecutionModificationRequestWithIdentifier.TargetStatusEnum .STOPPED))); final AiExecutionBulkModificationResponse executionBulkModificationResponse = - new ExecutionApi(getClient(destination)) - .batchModify("default", executionBulkModificationRequest); + new ExecutionApi(client).batchModify("default", executionBulkModificationRequest); assertThat(executionBulkModificationResponse).isNotNull(); assertThat(executionBulkModificationResponse.getExecutions().size()).isEqualTo(1); @@ -400,7 +398,7 @@ void patchBulkExecutions() { .isEqualTo("Execution modification scheduled"); wireMockServer.verify( - patchRequestedFor(urlPathEqualTo("/lm/executions")) + patchRequestedFor(urlPathEqualTo("/v2/lm/executions")) .withHeader("AI-Resource-Group", equalTo("default")) .withRequestBody( equalToJson( diff --git a/core/src/test/java/com/sap/ai/sdk/core/client/ScenarioUnitTest.java b/core/src/test/java/com/sap/ai/sdk/core/client/ScenarioUnitTest.java index e536b04f..9ec203d4 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/client/ScenarioUnitTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/client/ScenarioUnitTest.java @@ -4,7 +4,6 @@ import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; import static com.github.tomakehurst.wiremock.client.WireMock.get; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; -import static com.sap.ai.sdk.core.Core.getClient; import static org.assertj.core.api.Assertions.assertThat; import com.sap.ai.sdk.core.client.model.AiModelBaseData; @@ -26,7 +25,7 @@ public class ScenarioUnitTest extends WireMockTestServer { @Test void getScenarios() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/scenarios")) + get(urlPathEqualTo("/v2/lm/scenarios")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -54,7 +53,7 @@ void getScenarios() { } """))); - final AiScenarioList scenarioList = new ScenarioApi(getClient(destination)).query("default"); + final AiScenarioList scenarioList = new ScenarioApi(client).query("default"); assertThat(scenarioList).isNotNull(); assertThat(scenarioList.getCount()).isEqualTo(1); @@ -72,7 +71,7 @@ void getScenarios() { @Test void getScenarioVersions() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/scenarios/foundation-models/versions")) + get(urlPathEqualTo("/v2/lm/scenarios/foundation-models/versions")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -94,7 +93,7 @@ void getScenarioVersions() { """))); final AiVersionList versionList = - new ScenarioApi(getClient(destination)).queryVersions("default", "foundation-models"); + new ScenarioApi(client).queryVersions("default", "foundation-models"); assertThat(versionList).isNotNull(); assertThat(versionList.getCount()).isEqualTo(1); @@ -111,7 +110,7 @@ void getScenarioVersions() { @Test void getScenarioById() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/scenarios/foundation-models")) + get(urlPathEqualTo("/v2/lm/scenarios/foundation-models")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -134,8 +133,7 @@ void getScenarioById() { } """))); - final AiScenario scenario = - new ScenarioApi(getClient(destination)).get("default", "foundation-models"); + final AiScenario scenario = new ScenarioApi(client).get("default", "foundation-models"); assertThat(scenario).isNotNull(); assertThat(scenario.getCreatedAt()).isEqualTo("2023-11-03T14:02:46+00:00"); @@ -152,7 +150,7 @@ void getScenarioById() { @Test void getScenarioModels() { wireMockServer.stubFor( - get(urlPathEqualTo("/lm/scenarios/foundation-models/models")) + get(urlPathEqualTo("/v2/lm/scenarios/foundation-models/models")) .withHeader("AI-Resource-Group", equalTo("default")) .willReturn( aResponse() @@ -181,7 +179,7 @@ void getScenarioModels() { """))); final AiModelList scenarioList = - new ScenarioApi(getClient(destination)).queryModels("foundation-models", "default"); + new ScenarioApi(client).queryModels("foundation-models", "default"); assertThat(scenarioList).isNotNull(); assertThat(scenarioList.getCount()).isEqualTo(1); diff --git a/core/src/test/java/com/sap/ai/sdk/core/client/WireMockTestServer.java b/core/src/test/java/com/sap/ai/sdk/core/client/WireMockTestServer.java index 816e1f02..cadbf181 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/client/WireMockTestServer.java +++ b/core/src/test/java/com/sap/ai/sdk/core/client/WireMockTestServer.java @@ -4,24 +4,27 @@ import com.github.tomakehurst.wiremock.WireMockServer; import com.github.tomakehurst.wiremock.core.WireMockConfiguration; +import com.sap.ai.sdk.core.AiCoreService; import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; -import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; +import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; /** Test server for all unit tests. */ -abstract class WireMockTestServer { +public abstract class WireMockTestServer { private static final WireMockConfiguration WIREMOCK_CONFIGURATION = wireMockConfig().dynamicPort(); - static WireMockServer wireMockServer; - static Destination destination; + public static WireMockServer wireMockServer; + public static ApiClient client; @BeforeAll static void setup() { wireMockServer = new WireMockServer(WIREMOCK_CONFIGURATION); wireMockServer.start(); - destination = DefaultHttpDestination.builder(wireMockServer.baseUrl()).build(); + + final var destination = DefaultHttpDestination.builder(wireMockServer.baseUrl()).build(); + client = new AiCoreService().withDestination(destination).client(); } // Reset WireMock before each test to ensure clean state diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java index 819d01f2..25a5d922 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java @@ -6,7 +6,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; -import com.sap.ai.sdk.core.Core; +import com.sap.ai.sdk.core.AiCoreService; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters; @@ -58,7 +58,11 @@ public final class OpenAiClient { */ @Nonnull public static OpenAiClient forModel(@Nonnull final OpenAiModel foundationModel) { - final var destination = Core.getDestinationForModel(foundationModel.model(), "default"); + final var destination = + new AiCoreService() + .forDeploymentByModel(foundationModel.model()) + .withResourceGroup("default") + .destination(); final var client = new OpenAiClient(destination); return client.withApiVersion(DEFAULT_API_VERSION); } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java index 860684d2..f810db42 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java @@ -1,14 +1,16 @@ package com.sap.ai.sdk.orchestration.client; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; +import static com.github.tomakehurst.wiremock.client.WireMock.get; import static com.github.tomakehurst.wiremock.client.WireMock.jsonResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.verify; -import static com.sap.ai.sdk.core.Core.getClient; import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_0; import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_4; import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST; @@ -17,6 +19,7 @@ import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.sap.ai.sdk.core.AiCoreService; import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety; import com.sap.ai.sdk.orchestration.client.model.AzureThreshold; import com.sap.ai.sdk.orchestration.client.model.ChatMessage; @@ -78,15 +81,39 @@ public class OrchestrationUnitTest { @BeforeEach void setup(WireMockRuntimeInfo server) { + stubFor( + get(urlPathEqualTo("/v2/lm/deployments")) + .withHeader("AI-Resource-Group", equalTo("my-resource-group")) + .withHeader("AI-Client-Type", equalTo("AI SDK Java")) + .willReturn( + okJson( + """ + { + "resources": [ + { + "id": "abcdef0123456789", + "scenarioId": "orchestration" + } + ] + } + """))); + final DefaultHttpDestination destination = DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); - client = new OrchestrationCompletionApi(getClient(destination)); + + final var apiClient = + new AiCoreService() + .withDestination(destination) + .forDeploymentByScenario("orchestration") + .withResourceGroup("my-resource-group") + .client(); + client = new OrchestrationCompletionApi(apiClient); } @Test void templating() throws IOException { stubFor( - post(urlPathEqualTo("/completion")) + post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) .willReturn( aResponse() .withBodyFile("templatingResponse.json") @@ -141,14 +168,16 @@ void templating() throws IOException { // verify that null fields are absent from the sent request try (var requestInputStream = TEST_FILE_LOADER.apply("templatingRequest.json")) { final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); + verify( + postRequestedFor(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) + .withRequestBody(equalToJson(request))); } } @Test void templatingBadRequest() { stubFor( - post(urlPathEqualTo("/completion")) + post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) .willReturn( jsonResponse( """ @@ -222,7 +251,7 @@ void templatingBadRequest() { @Test void filteringLoose() throws IOException { stubFor( - post(urlPathEqualTo("/completion")) + post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) .willReturn( aResponse() .withBodyFile("filteringLooseResponse.json") @@ -236,7 +265,9 @@ void filteringLoose() throws IOException { // verify that null fields are absent from the sent request try (var requestInputStream = TEST_FILE_LOADER.apply("filteringLooseRequest.json")) { final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); + verify( + postRequestedFor(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) + .withRequestBody(equalToJson(request))); } } @@ -245,7 +276,9 @@ void filteringStrict() { final String response = """ {"request_id": "bf6d6792-7adf-4d3c-9368-a73615af8c5a", "code": 400, "message": "Content filtered due to Safety violations. Please modify the prompt and try again.", "location": "Input Filter", "module_results": {"templating": [{"role": "user", "content": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."}], "input_filtering": {"message": "Content filtered due to Safety violations. Please modify the prompt and try again.", "data": {"original_service_response": {"Hate": 0, "SelfHarm": 0, "Sexual": 0, "Violence": 2}, "checked_text": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."}}}}"""; - stubFor(post(urlPathEqualTo("/completion")).willReturn(jsonResponse(response, SC_BAD_REQUEST))); + stubFor( + post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) + .willReturn(jsonResponse(response, SC_BAD_REQUEST))); final var config = FILTERING_CONFIG.apply(NUMBER_0); @@ -257,7 +290,7 @@ void filteringStrict() { @Test void messagesHistory() throws IOException { stubFor( - post(urlPathEqualTo("/completion")) + post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) .willReturn( aResponse() .withBodyFile("templatingResponse.json") @@ -282,7 +315,9 @@ void messagesHistory() throws IOException { // verify that the history is sent correctly try (var requestInputStream = TEST_FILE_LOADER.apply("messagesHistoryRequest.json")) { final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); + verify( + postRequestedFor(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) + .withRequestBody(equalToJson(request))); } } @@ -316,7 +351,7 @@ void messagesHistory() throws IOException { @Test void maskingAnonymization() throws IOException { stubFor( - post(urlPathEqualTo("/completion")) + post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) .willReturn( aResponse() .withBodyFile("maskingResponse.json") @@ -336,7 +371,9 @@ void maskingAnonymization() throws IOException { // verify that the request is sent correctly try (var requestInputStream = TEST_FILE_LOADER.apply("maskingRequest.json")) { final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); + verify( + postRequestedFor(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) + .withRequestBody(equalToJson(request))); } } } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/Application.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/Application.java index 5378b970..a78dbbba 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/Application.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/Application.java @@ -1,14 +1,20 @@ package com.sap.ai.sdk.app; +import com.sap.ai.sdk.core.AiCoreService; +import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.web.servlet.ServletComponentScan; import org.springframework.context.annotation.ComponentScan; +/** Main class to start the Spring Boot application. */ @SpringBootApplication @ComponentScan({"com.sap.cloud.sdk", "com.sap.ai.sdk.app"}) @ServletComponentScan({"com.sap.cloud.sdk", "com.sap.ai.sdk.app"}) -class Application { +public class Application { + /** The API client connected to the AI Core service. */ + public static final ApiClient API_CLIENT = new AiCoreService().client(); + /** * Main method to start the Spring Boot application. * diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/ConfigurationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/ConfigurationController.java new file mode 100644 index 00000000..6afe81b9 --- /dev/null +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/ConfigurationController.java @@ -0,0 +1,26 @@ +package com.sap.ai.sdk.app.controllers; + +import static com.sap.ai.sdk.app.Application.API_CLIENT; + +import com.sap.ai.sdk.core.client.ConfigurationApi; +import com.sap.ai.sdk.core.client.model.AiConfigurationList; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; + +/** Endpoint for Configuration operations */ +@SuppressWarnings("unused") // debug class that doesn't need to be tested +@RestController +public class ConfigurationController { + + private static final ConfigurationApi API = new ConfigurationApi(API_CLIENT); + + /** + * Get the list of configurations. + * + * @return the list of configurations + */ + @GetMapping("/configurations") + AiConfigurationList getConfigurations() { + return API.query("default"); + } +} diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/DeploymentController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/DeploymentController.java index 5e7a511d..99abb794 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/DeploymentController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/DeploymentController.java @@ -1,6 +1,6 @@ package com.sap.ai.sdk.app.controllers; -import static com.sap.ai.sdk.core.Core.getClient; +import static com.sap.ai.sdk.app.Application.API_CLIENT; import com.sap.ai.sdk.core.client.ConfigurationApi; import com.sap.ai.sdk.core.client.DeploymentApi; @@ -31,7 +31,7 @@ @RequestMapping("/deployments") class DeploymentController { - private static final DeploymentApi API = new DeploymentApi(getClient()); + private static final DeploymentApi API = new DeploymentApi(API_CLIENT); /** * Create and delete a deployment with the Java specific configuration ID @@ -154,7 +154,7 @@ public AiDeploymentCreationResponse createConfigAndDeploy(final OpenAiModel mode .addParameterBindingsItem(modelVersion); final AiConfigurationCreationResponse configuration = - new ConfigurationApi(getClient()).create("default", configurationBaseData); + new ConfigurationApi(API_CLIENT).create("default", configurationBaseData); // Create a deployment from the configuration final var deploymentCreationRequest = diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index 68ce41a5..eaeea1f7 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -1,7 +1,6 @@ package com.sap.ai.sdk.app.controllers; -import static com.sap.ai.sdk.core.Core.getOrchestrationClient; - +import com.sap.ai.sdk.core.AiCoreService; import com.sap.ai.sdk.orchestration.client.OrchestrationCompletionApi; import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety; import com.sap.ai.sdk.orchestration.client.model.AzureThreshold; @@ -37,7 +36,8 @@ class OrchestrationController { private static final OrchestrationCompletionApi API = - new OrchestrationCompletionApi(getOrchestrationClient("default")); + new OrchestrationCompletionApi( + new AiCoreService().forDeploymentByScenario("orchestration").client()); static final String MODEL = "gpt-35-turbo"; diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/ScenarioController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/ScenarioController.java index 214f2c93..1e6df585 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/ScenarioController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/ScenarioController.java @@ -1,6 +1,6 @@ package com.sap.ai.sdk.app.controllers; -import static com.sap.ai.sdk.core.Core.getClient; +import static com.sap.ai.sdk.app.Application.API_CLIENT; import com.sap.ai.sdk.core.client.ScenarioApi; import com.sap.ai.sdk.core.client.model.AiModelList; @@ -14,7 +14,7 @@ @SuppressWarnings("unused") // debug method that doesn't need to be tested public class ScenarioController { - private static final ScenarioApi API = new ScenarioApi(getClient()); + private static final ScenarioApi API = new ScenarioApi(API_CLIENT); /** * Get the list of available scenarios diff --git a/sample-code/spring-app/src/main/resources/static/index.html b/sample-code/spring-app/src/main/resources/static/index.html index 936abc9f..0fe1127f 100644 --- a/sample-code/spring-app/src/main/resources/static/index.html +++ b/sample-code/spring-app/src/main/resources/static/index.html @@ -57,6 +57,12 @@

Endpoints

  • /models All available foundation models in this region
  • +
  • Orchestration