Skip to content

Commit

Permalink
Add embed endpoint (#153)
Browse files Browse the repository at this point in the history
## Problem

Add embed endpoint for Java SDK to allow users to create embeddings for
text data such as passage or query using a specified model. More details
on inference api can be found
[here](https://docs.pinecone.io/guides/inference/understanding-inference).

## Solution

Users can now call embed endpoint with the following parameters:
1. `String model`: Accepts a string from the specified
[models](https://docs.pinecone.io/models/overview) .
2. `Map<String, Object> parameters`: Accepts `input_type` and `truncate`
as keys with their corresponding values in a Map. The values are
expected to be scalar. Please note that the default value of `truncate`
is set to `END` if not specified.
3. `List<String> inputs`: The list must be of size atleast 1.

As a part of this change, I have added a `getInferenceClient()` in
`Pinecone` class to follow a similar pattern with other SDKs. The
underlying client uses `OkHTTPClient` for REST calls. The method
`getInferenceClient()` returns an instance of `Inference` class which I
added as a wrapper to the Inference API. So far, this wrapper contains
`embed()` endpoint only, with the plan of adding `rerank()` soon.
Lastly, I have added docstrings and updated README with an example of
the embed endpoint.

## Type of Change

- [X] New feature (non-breaking change which adds functionality)

## Test Plan

Added integration tests.
  • Loading branch information
rohanshah18 authored Sep 25, 2024
1 parent 27ea488 commit 6f43b93
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 5 deletions.
48 changes: 44 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ The `Pinecone` class is your main entry point into the Pinecone Java SDK. You ca
your `apiKey`, either by passing it as an argument in your code or by setting it as an environment variable called
`PINECONE_API_KEY`.

Note: for pod-based indexes, you will also need an `environment` variable. You can set pass this as an argument in
your code or set it as an environment variable called `PINECONE_ENVIRONMENT`.
This internally instantiates a single shared OkHttpClient instance, which is used for both control plane and inference
operations. Note that the OkHttpClient performs best when you create a single `OkHttpClient` instance and reuse it
for all of your HTTP calls. This is because each client holds its own connection pool and thread pools. Reusing
connections and threads reduces latency and saves memory. Conversely, creating a client for each request wastes
resources on idle pools. More details on the OkHttpClient can be found [here](https://github.com/square/okhttp/blob/f2771425cb714a5b0b27238bd081b2516b4d640f/okhttp/src/main/kotlin/okhttp3/OkHttpClient.kt#L54).

```java
import io.pinecone.clients.Pinecone;
Expand Down Expand Up @@ -542,7 +545,44 @@ Pinecone pinecone = new Pinecone.Builder("PINECONE_API_KEY").build();
pinecone.deleteCollection("example-collection");
```

## Inference

The Pinecone SDK now supports creating embeddings via the [Inference API](https://docs.pinecone.io/guides/inference/understanding-inference).

```java
import io.pinecone.clients.Pinecone;
import org.openapitools.control.client.ApiException;
import org.openapitools.control.client.model.Embedding;
import org.openapitools.control.client.model.EmbeddingsList;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
...

Pinecone pinecone = new Pinecone.Builder("PINECONE_API_KEY").build();
Inference inference = pinecone.getInferenceClient();

// Prepare input sentences to be embedded
List<String> inputs = new ArrayList<>();
inputs.add("The quick brown fox jumps over the lazy dog.");
inputs.add("Lorem ipsum");

// Specify the embedding model and parameters
String embeddingModel = "multilingual-e5-large";

Map<String, Object> parameters = new HashMap<>();
parameters.put("input_type", "query");
parameters.put("truncate", "END");

// Generate embeddings for the input data
EmbeddingsList embeddings = inference.embed(embeddingModel, parameters, inputs);

// Get embedded data
List<Embedding> embeddedData = embeddings.getData();
```

## Examples

- The data and control plane operation examples can be found in `io/pinecone/integration` folder.
- A full end-to-end Semantic Search example can be found in the [Java Examples](https://github.com/pinecone-io/java-examples/tree/main) repo on Github.
- The data and control plane operation examples can be found in `io/pinecone/integration` folder.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.pinecone.integration.inference;

import io.pinecone.clients.Inference;
import io.pinecone.clients.Pinecone;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.openapitools.control.client.ApiException;
import org.openapitools.control.client.model.EmbeddingsList;

import java.util.*;

import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.assertNotNull;

public class EmbedTest {

private static final Pinecone pinecone = new Pinecone
.Builder(System.getenv("PINECONE_API_KEY"))
.withSourceTag("pinecone_test")
.build();
private static final Inference inference = pinecone.getInferenceClient();

@Test
public void testGenerateEmbeddings() throws ApiException {
List<String> inputs = new ArrayList<>(1);
inputs.add("The quick brown fox jumps over the lazy dog.");
inputs.add("Lorem ipsum");

String embeddingModel = "multilingual-e5-large";

Map<String, Object> parameters = new HashMap<>();
parameters.put("input_type", "query");
parameters.put("truncate", "END");
EmbeddingsList embeddings = inference.embed(embeddingModel, parameters, inputs);

assertNotNull(embeddings, "Expected embedding to be not null");
Assertions.assertEquals(embeddingModel, embeddings.getModel());
Assertions.assertEquals(1024, embeddings.getData().get(0).getValues().size());
Assertions.assertEquals(2, embeddings.getData().size());
}

@Test
public void testGenerateEmbeddingsInvalidInputs() throws ApiException {
String embeddingModel = "multilingual-e5-large";
List<String> inputs = new ArrayList<>();
Map<String, Object> parameters = new HashMap<>();
parameters.put("input_type", "query");
parameters.put("truncate", "END");

Exception exception = assertThrows(Exception.class, () -> {
inference.embed(embeddingModel, parameters, inputs);
});

Assertions.assertTrue(exception.getMessage().contains("Must specify at least one input"));
}
}
68 changes: 68 additions & 0 deletions src/main/java/io/pinecone/clients/Inference.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package io.pinecone.clients;

import org.openapitools.control.client.ApiClient;
import org.openapitools.control.client.ApiException;
import org.openapitools.control.client.api.InferenceApi;
import org.openapitools.control.client.model.EmbedRequest;
import org.openapitools.control.client.model.EmbedRequestInputsInner;
import org.openapitools.control.client.model.EmbedRequestParameters;
import org.openapitools.control.client.model.EmbeddingsList;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* The Inference class provides methods to interact with Pinecone's inference API through the Java SDK. It allows users
* to send input data to generate embeddings using a specified model.
* <p>
* This class utilizes the {@link InferenceApi} to make API calls to the Pinecone inference service.
*
*/

public class Inference {

private final InferenceApi inferenceApi;

/**
* Constructs an instance of {@link Inference} class.
*
* @param apiClient The ApiClient object used to configure the API connection.
*/
public Inference(ApiClient apiClient) {
inferenceApi = new InferenceApi(apiClient);
}

/**
* Sends input data and parameters to the embedding model and returns a list of embeddings.
*
* @param model The embedding model to use.
* @param parameters A map containing model-specific parameters.
* @param inputs A list of input strings to generate embeddings for.
* @return EmbeddingsList containing the embeddings for the provided inputs.
* @throws ApiException If the API call fails, an ApiException is thrown.
*/
public EmbeddingsList embed(String model, Map<String, Object> parameters, List<String> inputs) throws ApiException {
EmbedRequestParameters embedRequestParameters = new EmbedRequestParameters();
parameters.forEach(embedRequestParameters::putAdditionalProperty);

EmbedRequest embedRequest = new EmbedRequest()
.model(model)
.parameters(embedRequestParameters)
.inputs(convertToEmbedInputs(inputs));

return inferenceApi.embed(embedRequest);
}

/**
* Converts a list of input strings to EmbedRequestInputsInner objects.
*
* @param inputs A list of input strings.
* @return A list of EmbedRequestInputsInner objects containing the input data.
*/
private List<EmbedRequestInputsInner> convertToEmbedInputs(List<String> inputs) {
return inputs.stream()
.map(input -> new EmbedRequestInputsInner().text(input))
.collect(Collectors.toList());
}
}
19 changes: 18 additions & 1 deletion src/main/java/io/pinecone/clients/Pinecone.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@

/**
* The Pinecone class is the main entry point for interacting with Pinecone via the Java SDK.
* It is used to create, delete, and manage your indexes and collections.
* It is used to create, delete, and manage your indexes and collections, along with the inference api.
* Note that the Pinecone class instantiates a single shared {@link OkHttpClient} instance,
* which is used for both control plane and inference operations.The OkHttpClient performs best when you create a single
* `OkHttpClient` instance and reuse it for all of your HTTP calls. This is because each client holds its own connection
* pool and thread pools. Reusing connections and threads reduces latency and saves memory. Conversely, creating a
* client for each request wastes resources on idle pools.
* <p>
* To instantiate the Pinecone class, use the {@link Pinecone.Builder} class to pass
* an API key and any other optional configuration.
Expand Down Expand Up @@ -871,6 +876,18 @@ public AsyncIndex getAsyncIndexConnection(String indexName) throws PineconeValid
return new AsyncIndex(connection, indexName);
}

/**
* A method to create and return a new instance of the {@link Inference} client.
* <p>
* This method initializes the Inference client using the current ApiClient
* from the {@link ManageIndexesApi}. The {@link Inference} client can then be used
* to interact with Pinecone's inference API.
* @return A new {@link Inference} client instance.
*/
public Inference getInferenceClient() {
return new Inference(manageIndexesApi.getApiClient());
}

PineconeConnection getConnection(String indexName) {
return connectionsMap.computeIfAbsent(indexName, key -> new PineconeConnection(config));
}
Expand Down

0 comments on commit 6f43b93

Please sign in to comment.