Skip to content

Commit

Permalink
feat: Batch update points (#16)
Browse files Browse the repository at this point in the history
* feat: batch update points

* chore: update filter param doc comments

* chore: upadate namedVectors() to accept Map<String, Vector>

* chore: typo fix QdrantClient.java

* Apply suggestions from code review

Co-authored-by: Russ Cam <[email protected]>

* chore: review updates

* chore: simplify examples README.md

* docs: simplify examples README.md

* docs: Updated doc URLREADME.md

---------

Co-authored-by: Russ Cam <[email protected]>
  • Loading branch information
Anush008 and russcam authored Jan 9, 2024
1 parent b416d22 commit 776b909
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 41 deletions.
69 changes: 37 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,26 @@ To install the library, add the following lines to your build config file.
<dependency>
<groupId>io.qdrant</groupId>
<artifactId>client</artifactId>
<version>1.7.0</version>
<version>1.7.1</version>
</dependency>
```

#### Scala SBT
#### SBT

```sbt
libraryDependencies += "io.qdrant" % "client" % "1.7.0"
libraryDependencies += "io.qdrant" % "client" % "1.7.1"
```

#### Gradle

```gradle
implementation 'io.qdrant:client:1.7.0'
implementation 'io.qdrant:client:1.7.1'
```

## 📖 Documentation

- [`QdrantClient` Reference](https://qdrant.github.io/java-client/io/qdrant/client/QdrantClient.html#constructor-detail)
- [JavaDoc Reference](https://qdrant.github.io/java-client/)
- Usage examples are available throughout the [Qdrant documentation](https://qdrant.tech/documentation/quick-start/)

## 🔌 Getting started

Expand Down Expand Up @@ -125,39 +126,43 @@ Insert vectors into a collection
// import static convenience methods
import static io.qdrant.client.PointIdFactory.id;
import static io.qdrant.client.ValueFactory.value;
import static io.qdrant.client.VectorsFactory.vector;

Random random = new Random();
List<PointStruct> points = IntStream.range(1, 101)
.mapToObj(i -> PointStruct.newBuilder()
.setId(id(i))
.setVectors(vector(IntStream.range(1, 101)
.mapToObj(v -> random.nextFloat())
.collect(Collectors.toList())))
.putAllPayload(ImmutableMap.of(
"color", value("red"),
"rand_number", value(i % 10))
)
.build()
)
.collect(Collectors.toList());
import static io.qdrant.client.VectorsFactory.vectors;

List<PointStruct> points =
List.of(
PointStruct.newBuilder()
.setId(id(1))
.setVectors(vectors(0.32f, 0.52f, 0.21f, 0.52f))
.putAllPayload(
Map.of(
"color", value("red"),
"rand_number", value(32)))
.build(),
PointStruct.newBuilder()
.setId(id(2))
.setVectors(vectors(0.42f, 0.52f, 0.67f, 0.632f))
.putAllPayload(
Map.of(
"color", value("black"),
"rand_number", value(53),
"extra_field", value(true)))
.build());

UpdateResult updateResult = client.upsertAsync("my_collection", points).get();
```

Search for similar vectors

```java
List<Float> queryVector = IntStream.range(1, 101)
.mapToObj(v -> random.nextFloat())
.collect(Collectors.toList());

List<ScoredPoint> points = client.searchAsync(SearchPoints.newBuilder()
.setCollectionName("my_collection")
.addAllVector(queryVector)
.setLimit(5)
.build()
).get();
List<ScoredPoint> anush =
client
.searchAsync(
SearchPoints.newBuilder()
.setCollectionName("my_collection")
.addAllVector(List.of(0.6235f, 0.123f, 0.532f, 0.123f))
.setLimit(5)
.build())
.get();
```

Search for similar vectors with filtering condition
Expand All @@ -168,7 +173,7 @@ import static io.qdrant.client.ConditionFactory.range;

List<ScoredPoint> points = client.searchAsync(SearchPoints.newBuilder()
.setCollectionName("my_collection")
.addAllVector(queryVector)
.addAllVector(List.of(0.6235f, 0.123f, 0.532f, 0.123f))
.setFilter(Filter.newBuilder()
.addMust(range("rand_number", Range.newBuilder().setGte(3).build()))
.build())
Expand Down
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ plugins {
id 'maven-publish'

id 'com.google.protobuf' version '0.9.4'
id "net.ltgt.errorprone" version '3.1.0'
id 'io.github.gradle-nexus.publish-plugin' version "1.3.0"
id 'net.ltgt.errorprone' version '3.1.0'
id 'io.github.gradle-nexus.publish-plugin' version '1.3.0'
}

group = 'io.qdrant'
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ qdrantProtosVersion=v1.7.0
qdrantVersion=v1.7.0

# The version of the client to generate
packageVersion=1.7.0
packageVersion=1.7.1
71 changes: 68 additions & 3 deletions src/main/java/io/qdrant/client/QdrantClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
import static io.qdrant.client.grpc.Points.DiscoverBatchResponse;
import static io.qdrant.client.grpc.Points.DiscoverPoints;
import static io.qdrant.client.grpc.Points.DiscoverResponse;
import static io.qdrant.client.grpc.Points.PointsUpdateOperation;
import static io.qdrant.client.grpc.Points.UpdateBatchPoints;
import static io.qdrant.client.grpc.Points.UpdateBatchResponse;
import static io.qdrant.client.grpc.Collections.GetCollectionInfoRequest;
import static io.qdrant.client.grpc.Collections.GetCollectionInfoResponse;
import static io.qdrant.client.grpc.Collections.ListAliasesRequest;
Expand Down Expand Up @@ -1551,7 +1554,7 @@ public ListenableFuture<UpdateResult> overwritePayloadAsync(
}

/**
* Overwrites the payload for the given ids.
* Overwrites the payload for the filtered points.
*
* @param collectionName The name of the collection.
* @param payload New payload values
Expand Down Expand Up @@ -1696,7 +1699,7 @@ public ListenableFuture<UpdateResult> deletePayloadAsync(
}

/**
* Delete specified key payload for the given ids.
* Delete specified key payload for the filtered points.
*
* @param collectionName The name of the collection.
* @param keys List of keys to delete.
Expand Down Expand Up @@ -1832,7 +1835,7 @@ public ListenableFuture<UpdateResult> clearPayloadAsync(
}

/**
* Removes all payload for the given ids.
* Removes all payload for the filtered points.
*
* @param collectionName The name of the collection.
* @param filter A filter selecting the points for which to remove the payload.
Expand Down Expand Up @@ -2204,6 +2207,68 @@ public ListenableFuture<List<BatchResult>> recommendBatchAsync(
MoreExecutors.directExecutor());
}

/**
* Performs a batch update of points.
*
* @param collectionName The name of the collection.
* @param operations The list of point update operations.
*
* @return a new instance of {@link ListenableFuture}
*/
public ListenableFuture<List<UpdateResult>> batchUpdateAsync(String collectionName, List<PointsUpdateOperation> operations) {
return batchUpdateAsync(collectionName, operations, null, null, null);
}

/**
* Performs a batch update of points.
*
* @param collectionName The name of the collection.
* @param operations The list of point update operations.
* @param wait Whether to wait until the changes have been applied. Defaults to <code>true</code>.
* @param ordering Write ordering guarantees.
* @param timeout The timeout for the call.
*
* @return a new instance of {@link ListenableFuture}
*/
public ListenableFuture<List<UpdateResult>> batchUpdateAsync(
String collectionName,
List<PointsUpdateOperation> operations,
@Nullable Boolean wait,
@Nullable WriteOrdering ordering,
@Nullable Duration timeout) {

UpdateBatchPoints.Builder requestBuilder = UpdateBatchPoints.newBuilder()
.setCollectionName(collectionName)
.addAllOperations(operations)
.setWait(wait == null || wait);

if (ordering != null) {
requestBuilder.setOrdering(ordering);
}
return batchUpdateAsync(requestBuilder.build(), timeout);
}


/**
* Performs a batch update of points.
*
* @param request The update batch request.
* @param timeout The timeout for the call.
*
* @return a new instance of {@link ListenableFuture}
*/
public ListenableFuture<List<UpdateResult>> batchUpdateAsync(UpdateBatchPoints request, @Nullable Duration timeout) {
String collectionName = request.getCollectionName();
Preconditions.checkArgument(!collectionName.isEmpty(), "Collection name must not be empty");
logger.debug("Batch update points on '{}'", collectionName);
ListenableFuture<UpdateBatchResponse> future = getPoints(timeout).updateBatch(request);
addLogFailureCallback(future, "Batch update points");
return Futures.transform(
future,
UpdateBatchResponse::getResultList,
MoreExecutors.directExecutor());
}

/**
* Look for the points which are closer to stored positive examples and at the same time further to negative
* examples, grouped by a given field
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/io/qdrant/client/VectorsFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static io.qdrant.client.VectorFactory.vector;
import static io.qdrant.client.grpc.Points.NamedVectors;
import io.qdrant.client.grpc.Points.Vector;
import static io.qdrant.client.grpc.Points.Vectors;

/**
Expand All @@ -18,13 +19,13 @@ private VectorsFactory() {

/**
* Creates named vectors
* @param values A map of vector names to values
* @param values A map of vector names to {@link Vector}
* @return a new instance of {@link Vectors}
*/
public static Vectors namedVectors(Map<String, List<Float>> values) {
public static Vectors namedVectors(Map<String, Vector> values) {
return Vectors.newBuilder()
.setVectors(NamedVectors.newBuilder()
.putAllVectors(Maps.transformValues(values, v -> vector(v)))
.putAllVectors(values)
)
.build();
}
Expand Down
30 changes: 30 additions & 0 deletions src/test/java/io/qdrant/client/PointsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
import org.testcontainers.shaded.com.google.common.collect.ImmutableSet;
import io.qdrant.client.container.QdrantContainer;
import io.qdrant.client.grpc.Points.DiscoverPoints;
import io.qdrant.client.grpc.Points.PointVectors;
import io.qdrant.client.grpc.Points.PointsIdsList;
import io.qdrant.client.grpc.Points.PointsSelector;
import io.qdrant.client.grpc.Points.PointsUpdateOperation;
import io.qdrant.client.grpc.Points.UpdateBatchResponse;
import io.qdrant.client.grpc.Points.PointsUpdateOperation.ClearPayload;
import io.qdrant.client.grpc.Points.PointsUpdateOperation.UpdateVectors;

import java.util.List;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -49,6 +56,7 @@
import static io.qdrant.client.TargetVectorFactory.targetVector;
import static io.qdrant.client.ValueFactory.value;
import static io.qdrant.client.VectorFactory.vector;
import static io.qdrant.client.VectorsFactory.vectors;

@Testcontainers
class PointsTest {
Expand Down Expand Up @@ -540,6 +548,28 @@ public void delete_by_filter() throws ExecutionException, InterruptedException {
assertEquals(0, points.size());
}

@Test
public void batchPointUpdate() throws ExecutionException, InterruptedException {
createAndSeedCollection(testName);

List<PointsUpdateOperation> operations = List.of(
PointsUpdateOperation.newBuilder()
.setClearPayload(ClearPayload.newBuilder().setPoints(
PointsSelector.newBuilder().setPoints(PointsIdsList.newBuilder().addIds(id(9))))
.build())
.build(),
PointsUpdateOperation.newBuilder()
.setUpdateVectors(UpdateVectors.newBuilder()
.addPoints(PointVectors.newBuilder()
.setId(id(9))
.setVectors(vectors(0.6f, 0.7f))))
.build());

List<UpdateResult> response = client.batchUpdateAsync(testName, operations).get();

response.forEach(result -> assertEquals(UpdateStatus.Completed, result.getStatus()));
}

private void createAndSeedCollection(String collectionName) throws ExecutionException, InterruptedException {
CreateCollection request = CreateCollection.newBuilder()
.setCollectionName(collectionName)
Expand Down

0 comments on commit 776b909

Please sign in to comment.