Skip to content

Commit

Permalink
Simple OpenAI API (#47)
Browse files Browse the repository at this point in the history
* OpenAI API proposal

* Added convenience getContent

* Applied Alex's suggestions

* Added tests

* Update foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java

Co-authored-by: Matthias Kuhr <[email protected]>

* Throw in getContent

* Added withSystemPrompt(String) and addMessages(...)

* Updated README

* PMD

---------

Co-authored-by: Matthias Kuhr <[email protected]>
  • Loading branch information
CharlesDuboisSAP and MatKuhr authored Sep 17, 2024
1 parent 64875a3 commit 95d0328
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 65 deletions.
23 changes: 15 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,29 @@ See [an example pom in our Spring Boot application](e2e-test-app/pom.xml)

### Simple chat completion

```java
final OpenAiChatCompletionOutput result =
OpenAiClient.forModel(GPT_35_TURBO)
.withSystemPrompt("You are a helpful AI")
.chatCompletion("Hello World! Why is this phrase so famous?");

final String resultMessage = result.getContent();
```

### Message history

```java
final var systemMessage =
new OpenAiChatSystemMessage().setContent("You are a helpful assistant");
final var userMessage =
new OpenAiChatUserMessage().addText("Hello World! Why is this phrase so famous?");
final var request =
new OpenAiChatCompletionParameters().setMessages(List.of(systemMessage, userMessage));
new OpenAiChatCompletionParameters().addMessages(systemMessage, userMessage);

final OpenAiChatCompletionOutput result =
OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);

final String resultMessage = result.getChoices().get(0).getMessage().getContent();
final String resultMessage = result.getContent();
```

See [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java)
Expand All @@ -213,14 +224,10 @@ This is a blocking example for streaming and printing directly to the console:
```java
String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";

OpenAiChatCompletionParameters request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));

OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);

// try-with-resources on stream ensures the connection will be closed
try( Stream<String> stream = client.streamChatCompletion(request)) {
try( Stream<String> stream = client.streamChatCompletion(msg)) {
stream.forEach(deltaString -> {
System.out.print(deltaString);
System.out.flush();
Expand All @@ -239,7 +246,7 @@ String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";

OpenAiChatCompletionParameters request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));
.addMessages(new OpenAiChatUserMessage().addText(msg));

OpenAiChatCompletionOutput totalOutput = new OpenAiChatCompletionOutput();
OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ class OpenAiController {
@GetMapping("/chatCompletion")
@Nonnull
public static OpenAiChatCompletionOutput chatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest")));

return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion("Who is the prettiest");
}

/**
Expand All @@ -59,8 +55,7 @@ public static OpenAiChatCompletionOutput chatCompletion() {
public static ResponseEntity<ResponseBodyEmitter> streamChatCompletionDeltas() {
final var msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));
new OpenAiChatCompletionParameters().addMessages(new OpenAiChatUserMessage().addText(msg));

final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request);

Expand Down Expand Up @@ -103,15 +98,11 @@ private static String objectToJson(@Nonnull final Object obj) {
@GetMapping("/streamChatCompletion")
@Nonnull
public static ResponseEntity<ResponseBodyEmitter> streamChatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(
List.of(
new OpenAiChatUserMessage()
.addText(
"Can you give me the first 100 numbers of the Fibonacci sequence?")));

final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletion(request);
final var stream =
OpenAiClient.forModel(GPT_35_TURBO)
.withSystemPrompt("Be a good, honest AI and answer the following question:")
.streamChatCompletion(
"Can you give me the first 100 numbers of the Fibonacci sequence?");

final var emitter = new ResponseBodyEmitter();

Expand Down Expand Up @@ -150,13 +141,12 @@ private static void send(
public static OpenAiChatCompletionOutput chatCompletionImage() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(
List.of(
new OpenAiChatUserMessage()
.addText("Describe the following image.")
.addImage(
"https://upload.wikimedia.org/wikipedia/commons/thumb/5/59/SAP_2011_logo.svg/440px-SAP_2011_logo.svg.png",
ImageDetailLevel.HIGH)));
.addMessages(
new OpenAiChatUserMessage()
.addText("Describe the following image.")
.addImage(
"https://upload.wikimedia.org/wikipedia/commons/thumb/5/59/SAP_2011_logo.svg/440px-SAP_2011_logo.svg.png",
ImageDetailLevel.HIGH));

return OpenAiClient.forModel(GPT_4O).chatCompletion(request);
}
Expand All @@ -180,7 +170,7 @@ public static OpenAiChatCompletionOutput chatCompletionTools() {
final var tool = new OpenAiChatCompletionTool().setType(FUNCTION).setFunction(function);
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(question)))
.addMessages(new OpenAiChatUserMessage().addText(question))
.setTools(List.of(tool))
.setToolChoiceFunction("fibonacci");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -36,7 +35,7 @@ void chatCompletionImage() {
void streamChatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest?")));
.addMessages(new OpenAiChatUserMessage().addText("Who is the prettiest?"));

final var totalOutput = new OpenAiChatCompletionOutput();
final var filledDeltaCount = new AtomicInteger(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatSystemMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.StreamedDelta;
Expand All @@ -34,6 +36,7 @@
public final class OpenAiClient {
private static final String DEFAULT_API_VERSION = "2024-02-01";
static final ObjectMapper JACKSON;
private String systemPrompt = null;

static {
JACKSON =
Expand Down Expand Up @@ -95,6 +98,36 @@ public static OpenAiClient withCustomDestination(@Nonnull final Destination dest
return new OpenAiClient(destination);
}

/**
* Add a system prompt before user prompts.
*
* @param systemPrompt the system prompt
* @return the client
*/
@Nonnull
public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) {
this.systemPrompt = systemPrompt;
return this;
}

/**
* Generate a completion for the given user prompt.
*
* @param prompt a text message.
* @return the completion output
* @throws OpenAiClientException if the request fails
*/
@Nonnull
public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt)
throws OpenAiClientException {
final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
if (systemPrompt != null) {
parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
}
parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
return chatCompletion(parameters);
}

/**
* Generate a completion for the given prompt.
*
Expand All @@ -105,19 +138,25 @@ public static OpenAiClient withCustomDestination(@Nonnull final Destination dest
@Nonnull
public OpenAiChatCompletionOutput chatCompletion(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
warnIfUnsupportedUsage();
return execute("/chat/completions", parameters, OpenAiChatCompletionOutput.class);
}

/**
* Generate a completion for the given prompt.
*
* @param parameters the prompt, including messages and other parameters.
* @param prompt a text message.
* @return A stream of message deltas
* @throws OpenAiClientException if the request fails or if the finish reason is content_filter
*/
@Nonnull
public Stream<String> streamChatCompletion(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
public Stream<String> streamChatCompletion(@Nonnull final String prompt)
throws OpenAiClientException {
final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
if (systemPrompt != null) {
parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
}
parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
return streamChatCompletionDeltas(parameters)
.peek(OpenAiClient::throwOnContentFilter)
.map(OpenAiChatCompletionDelta::getDeltaContent);
Expand All @@ -140,10 +179,18 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt
@Nonnull
public Stream<OpenAiChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
warnIfUnsupportedUsage();
parameters.enableStreaming();
return executeStream("/chat/completions", parameters, OpenAiChatCompletionDelta.class);
}

private void warnIfUnsupportedUsage() {
if (systemPrompt != null) {
log.warn(
"Previously set messages will be ignored, set it as an argument of this method instead.");
}
}

/**
* Get a vector representation of a given input that can be easily consumed by machine learning
* models and algorithms.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.sap.ai.sdk.foundationmodels.openai.model;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClientException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nonnull;
import lombok.EqualsAndHashCode;
import lombok.Getter;
Expand All @@ -28,6 +30,25 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput
@Getter(onMethod_ = @Nonnull)
private String systemFingerprint;

/**
* Get the message content from the output.
*
* <p>Note: If there are multiple choices only the first one is returned
*
* @return the message content or empty string.
* @throws OpenAiClientException if the content filter filtered the output.
*/
@Nonnull
public String getContent() throws OpenAiClientException {
if (getChoices().isEmpty()) {
return "";
}
if ("content_filter".equals(getChoices().get(0).getFinishReason())) {
throw new OpenAiClientException("Content filter filtered the output.");
}
return Objects.requireNonNullElse(getChoices().get(0).getMessage().getContent(), "");
}

/**
* Add a streamed delta to the total output.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonValue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
Expand All @@ -25,7 +26,6 @@
public class OpenAiChatCompletionParameters extends OpenAiCompletionParameters {
/** A list of messages comprising the conversation so far. */
@JsonProperty("messages")
@Setter(onParam_ = @Nonnull)
private List<OpenAiChatMessage> messages;

/**
Expand Down Expand Up @@ -197,4 +197,19 @@ private record Function(@JsonProperty("name") @Nonnull String name) {}
public OpenAiChatCompletionParameters setStop(@Nullable final String... values) {
return (OpenAiChatCompletionParameters) super.setStop(values);
}

/**
* Add messages to the conversation.
*
* @param messages The messages to add.
* @return this instance for chaining.
*/
@Nonnull
public OpenAiChatCompletionParameters addMessages(@Nonnull final OpenAiChatMessage... messages) {
if (this.messages == null) {
this.messages = new ArrayList<>();
}
this.messages.addAll(Arrays.asList(messages));
return this;
}
}
Loading

0 comments on commit 95d0328

Please sign in to comment.