From 776b909d06a263492cdb8c1bbfe4ea7b8499ab18 Mon Sep 17 00:00:00 2001 From: Anush Date: Wed, 10 Jan 2024 00:32:36 +0530 Subject: [PATCH] feat: Batch update points (#16) * feat: batch update points * chore: update filter param doc comments * chore: upadate namedVectors() to accept Map * chore: typo fix QdrantClient.java * Apply suggestions from code review Co-authored-by: Russ Cam * chore: review updates * chore: simplify examples README.md * docs: simplify examples README.md * docs: Updated doc URLREADME.md --------- Co-authored-by: Russ Cam --- README.md | 69 +++++++++--------- build.gradle | 4 +- gradle.properties | 2 +- .../java/io/qdrant/client/QdrantClient.java | 71 ++++++++++++++++++- .../java/io/qdrant/client/VectorsFactory.java | 7 +- .../java/io/qdrant/client/PointsTest.java | 30 ++++++++ 6 files changed, 142 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 51c562b..46c52d5 100644 --- a/README.md +++ b/README.md @@ -34,25 +34,26 @@ To install the library, add the following lines to your build config file. io.qdrant client - 1.7.0 + 1.7.1 ``` -#### Scala SBT +#### SBT ```sbt -libraryDependencies += "io.qdrant" % "client" % "1.7.0" +libraryDependencies += "io.qdrant" % "client" % "1.7.1" ``` #### Gradle ```gradle -implementation 'io.qdrant:client:1.7.0' +implementation 'io.qdrant:client:1.7.1' ``` ## 📖 Documentation -- [`QdrantClient` Reference](https://qdrant.github.io/java-client/io/qdrant/client/QdrantClient.html#constructor-detail) +- [JavaDoc Reference](https://qdrant.github.io/java-client/) +- Usage examples are available throughout the [Qdrant documentation](https://qdrant.tech/documentation/quick-start/) ## 🔌 Getting started @@ -125,22 +126,27 @@ Insert vectors into a collection // import static convenience methods import static io.qdrant.client.PointIdFactory.id; import static io.qdrant.client.ValueFactory.value; -import static io.qdrant.client.VectorsFactory.vector; - -Random random = new Random(); -List points = IntStream.range(1, 101) - .mapToObj(i -> PointStruct.newBuilder() - .setId(id(i)) - .setVectors(vector(IntStream.range(1, 101) - .mapToObj(v -> random.nextFloat()) - .collect(Collectors.toList()))) - .putAllPayload(ImmutableMap.of( - "color", value("red"), - "rand_number", value(i % 10)) - ) - .build() - ) - .collect(Collectors.toList()); +import static io.qdrant.client.VectorsFactory.vectors; + +List points = + List.of( + PointStruct.newBuilder() + .setId(id(1)) + .setVectors(vectors(0.32f, 0.52f, 0.21f, 0.52f)) + .putAllPayload( + Map.of( + "color", value("red"), + "rand_number", value(32))) + .build(), + PointStruct.newBuilder() + .setId(id(2)) + .setVectors(vectors(0.42f, 0.52f, 0.67f, 0.632f)) + .putAllPayload( + Map.of( + "color", value("black"), + "rand_number", value(53), + "extra_field", value(true))) + .build()); UpdateResult updateResult = client.upsertAsync("my_collection", points).get(); ``` @@ -148,16 +154,15 @@ UpdateResult updateResult = client.upsertAsync("my_collection", points).get(); Search for similar vectors ```java -List queryVector = IntStream.range(1, 101) - .mapToObj(v -> random.nextFloat()) - .collect(Collectors.toList()); - -List points = client.searchAsync(SearchPoints.newBuilder() - .setCollectionName("my_collection") - .addAllVector(queryVector) - .setLimit(5) - .build() -).get(); +List anush = + client + .searchAsync( + SearchPoints.newBuilder() + .setCollectionName("my_collection") + .addAllVector(List.of(0.6235f, 0.123f, 0.532f, 0.123f)) + .setLimit(5) + .build()) + .get(); ``` Search for similar vectors with filtering condition @@ -168,7 +173,7 @@ import static io.qdrant.client.ConditionFactory.range; List points = client.searchAsync(SearchPoints.newBuilder() .setCollectionName("my_collection") - .addAllVector(queryVector) + .addAllVector(List.of(0.6235f, 0.123f, 0.532f, 0.123f)) .setFilter(Filter.newBuilder() .addMust(range("rand_number", Range.newBuilder().setGte(3).build())) .build()) diff --git a/build.gradle b/build.gradle index 5ca962e..ab70027 100644 --- a/build.gradle +++ b/build.gradle @@ -19,8 +19,8 @@ plugins { id 'maven-publish' id 'com.google.protobuf' version '0.9.4' - id "net.ltgt.errorprone" version '3.1.0' - id 'io.github.gradle-nexus.publish-plugin' version "1.3.0" + id 'net.ltgt.errorprone' version '3.1.0' + id 'io.github.gradle-nexus.publish-plugin' version '1.3.0' } group = 'io.qdrant' diff --git a/gradle.properties b/gradle.properties index a609339..b5b794c 100644 --- a/gradle.properties +++ b/gradle.properties @@ -5,4 +5,4 @@ qdrantProtosVersion=v1.7.0 qdrantVersion=v1.7.0 # The version of the client to generate -packageVersion=1.7.0 \ No newline at end of file +packageVersion=1.7.1 \ No newline at end of file diff --git a/src/main/java/io/qdrant/client/QdrantClient.java b/src/main/java/io/qdrant/client/QdrantClient.java index a8dcae4..9c057ff 100644 --- a/src/main/java/io/qdrant/client/QdrantClient.java +++ b/src/main/java/io/qdrant/client/QdrantClient.java @@ -39,6 +39,9 @@ import static io.qdrant.client.grpc.Points.DiscoverBatchResponse; import static io.qdrant.client.grpc.Points.DiscoverPoints; import static io.qdrant.client.grpc.Points.DiscoverResponse; +import static io.qdrant.client.grpc.Points.PointsUpdateOperation; +import static io.qdrant.client.grpc.Points.UpdateBatchPoints; +import static io.qdrant.client.grpc.Points.UpdateBatchResponse; import static io.qdrant.client.grpc.Collections.GetCollectionInfoRequest; import static io.qdrant.client.grpc.Collections.GetCollectionInfoResponse; import static io.qdrant.client.grpc.Collections.ListAliasesRequest; @@ -1551,7 +1554,7 @@ public ListenableFuture overwritePayloadAsync( } /** - * Overwrites the payload for the given ids. + * Overwrites the payload for the filtered points. * * @param collectionName The name of the collection. * @param payload New payload values @@ -1696,7 +1699,7 @@ public ListenableFuture deletePayloadAsync( } /** - * Delete specified key payload for the given ids. + * Delete specified key payload for the filtered points. * * @param collectionName The name of the collection. * @param keys List of keys to delete. @@ -1832,7 +1835,7 @@ public ListenableFuture clearPayloadAsync( } /** - * Removes all payload for the given ids. + * Removes all payload for the filtered points. * * @param collectionName The name of the collection. * @param filter A filter selecting the points for which to remove the payload. @@ -2204,6 +2207,68 @@ public ListenableFuture> recommendBatchAsync( MoreExecutors.directExecutor()); } + /** + * Performs a batch update of points. + * + * @param collectionName The name of the collection. + * @param operations The list of point update operations. + * + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> batchUpdateAsync(String collectionName, List operations) { + return batchUpdateAsync(collectionName, operations, null, null, null); + } + + /** + * Performs a batch update of points. + * + * @param collectionName The name of the collection. + * @param operations The list of point update operations. + * @param wait Whether to wait until the changes have been applied. Defaults to true. + * @param ordering Write ordering guarantees. + * @param timeout The timeout for the call. + * + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> batchUpdateAsync( + String collectionName, + List operations, + @Nullable Boolean wait, + @Nullable WriteOrdering ordering, + @Nullable Duration timeout) { + + UpdateBatchPoints.Builder requestBuilder = UpdateBatchPoints.newBuilder() + .setCollectionName(collectionName) + .addAllOperations(operations) + .setWait(wait == null || wait); + + if (ordering != null) { + requestBuilder.setOrdering(ordering); + } + return batchUpdateAsync(requestBuilder.build(), timeout); + } + + + /** + * Performs a batch update of points. + * + * @param request The update batch request. + * @param timeout The timeout for the call. + * + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> batchUpdateAsync(UpdateBatchPoints request, @Nullable Duration timeout) { + String collectionName = request.getCollectionName(); + Preconditions.checkArgument(!collectionName.isEmpty(), "Collection name must not be empty"); + logger.debug("Batch update points on '{}'", collectionName); + ListenableFuture future = getPoints(timeout).updateBatch(request); + addLogFailureCallback(future, "Batch update points"); + return Futures.transform( + future, + UpdateBatchResponse::getResultList, + MoreExecutors.directExecutor()); + } + /** * Look for the points which are closer to stored positive examples and at the same time further to negative * examples, grouped by a given field diff --git a/src/main/java/io/qdrant/client/VectorsFactory.java b/src/main/java/io/qdrant/client/VectorsFactory.java index 6fe9864..bd3e285 100644 --- a/src/main/java/io/qdrant/client/VectorsFactory.java +++ b/src/main/java/io/qdrant/client/VectorsFactory.java @@ -7,6 +7,7 @@ import static io.qdrant.client.VectorFactory.vector; import static io.qdrant.client.grpc.Points.NamedVectors; +import io.qdrant.client.grpc.Points.Vector; import static io.qdrant.client.grpc.Points.Vectors; /** @@ -18,13 +19,13 @@ private VectorsFactory() { /** * Creates named vectors - * @param values A map of vector names to values + * @param values A map of vector names to {@link Vector} * @return a new instance of {@link Vectors} */ - public static Vectors namedVectors(Map> values) { + public static Vectors namedVectors(Map values) { return Vectors.newBuilder() .setVectors(NamedVectors.newBuilder() - .putAllVectors(Maps.transformValues(values, v -> vector(v))) + .putAllVectors(values) ) .build(); } diff --git a/src/test/java/io/qdrant/client/PointsTest.java b/src/test/java/io/qdrant/client/PointsTest.java index 75cb540..7dba74b 100644 --- a/src/test/java/io/qdrant/client/PointsTest.java +++ b/src/test/java/io/qdrant/client/PointsTest.java @@ -14,6 +14,13 @@ import org.testcontainers.shaded.com.google.common.collect.ImmutableSet; import io.qdrant.client.container.QdrantContainer; import io.qdrant.client.grpc.Points.DiscoverPoints; +import io.qdrant.client.grpc.Points.PointVectors; +import io.qdrant.client.grpc.Points.PointsIdsList; +import io.qdrant.client.grpc.Points.PointsSelector; +import io.qdrant.client.grpc.Points.PointsUpdateOperation; +import io.qdrant.client.grpc.Points.UpdateBatchResponse; +import io.qdrant.client.grpc.Points.PointsUpdateOperation.ClearPayload; +import io.qdrant.client.grpc.Points.PointsUpdateOperation.UpdateVectors; import java.util.List; import java.util.concurrent.ExecutionException; @@ -49,6 +56,7 @@ import static io.qdrant.client.TargetVectorFactory.targetVector; import static io.qdrant.client.ValueFactory.value; import static io.qdrant.client.VectorFactory.vector; +import static io.qdrant.client.VectorsFactory.vectors; @Testcontainers class PointsTest { @@ -540,6 +548,28 @@ public void delete_by_filter() throws ExecutionException, InterruptedException { assertEquals(0, points.size()); } + @Test + public void batchPointUpdate() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List operations = List.of( + PointsUpdateOperation.newBuilder() + .setClearPayload(ClearPayload.newBuilder().setPoints( + PointsSelector.newBuilder().setPoints(PointsIdsList.newBuilder().addIds(id(9)))) + .build()) + .build(), + PointsUpdateOperation.newBuilder() + .setUpdateVectors(UpdateVectors.newBuilder() + .addPoints(PointVectors.newBuilder() + .setId(id(9)) + .setVectors(vectors(0.6f, 0.7f)))) + .build()); + + List response = client.batchUpdateAsync(testName, operations).get(); + + response.forEach(result -> assertEquals(UpdateStatus.Completed, result.getStatus())); + } + private void createAndSeedCollection(String collectionName) throws ExecutionException, InterruptedException { CreateCollection request = CreateCollection.newBuilder() .setCollectionName(collectionName)