From 9fab81a571ec9c88132e7ba1db584481e6b4b45c Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Thu, 24 Apr 2025 17:46:27 +0200 Subject: [PATCH 1/6] Polishing. Add dynamic projection benchmark. --- pom.xml | 10 +-------- .../RepositoryQueryMethodBenchmarks.java | 7 ++++++ .../data/jpa/benchmark/model/PersonDto.java | 22 +++++++++++++++++++ .../repository/PersonRepository.java | 3 +++ 4 files changed, 33 insertions(+), 9 deletions(-) create mode 100644 spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/model/PersonDto.java diff --git a/pom.xml b/pom.xml index 2b0f8c1740..5ccd77f57f 100755 --- a/pom.xml +++ b/pom.xml @@ -56,17 +56,9 @@ jmh - - - com.github.mp911de.microbenchmark-runner - microbenchmark-runner-junit5 - 0.4.0.RELEASE - test - - - jitpack.io + jitpack https://jitpack.io diff --git a/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/RepositoryQueryMethodBenchmarks.java b/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/RepositoryQueryMethodBenchmarks.java index f49d658a00..0f20652d65 100644 --- a/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/RepositoryQueryMethodBenchmarks.java +++ b/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/RepositoryQueryMethodBenchmarks.java @@ -42,6 +42,7 @@ import org.springframework.data.domain.Sort; import org.springframework.data.jpa.benchmark.model.Person; +import org.springframework.data.jpa.benchmark.model.PersonDto; import org.springframework.data.jpa.benchmark.model.Profile; import org.springframework.data.jpa.benchmark.repository.PersonRepository; import org.springframework.data.jpa.repository.support.JpaRepositoryFactory; @@ -195,6 +196,12 @@ public List stringBasedQueryDynamicSort(BenchmarkParameters parameters) Sort.by(COLUMN_PERSON_FIRSTNAME)); } + @Benchmark + public List stringBasedQueryDynamicSortAndProjection(BenchmarkParameters parameters) { + return 
parameters.repositoryProxy.findAllWithAnnotatedQueryByFirstname(PERSON_FIRSTNAME, + Sort.by(COLUMN_PERSON_FIRSTNAME), PersonDto.class); + } + @Benchmark public List stringBasedNativeQuery(BenchmarkParameters parameters) { return parameters.repositoryProxy.findAllWithNativeQueryByFirstname(PERSON_FIRSTNAME); diff --git a/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/model/PersonDto.java b/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/model/PersonDto.java new file mode 100644 index 0000000000..6241e6a439 --- /dev/null +++ b/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/model/PersonDto.java @@ -0,0 +1,22 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.jpa.benchmark.model; + +/** + * @author Mark Paluch + */ +public record PersonDto(String firstname, String lastname) { +} diff --git a/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/repository/PersonRepository.java b/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/repository/PersonRepository.java index 491ab736a8..81950ab3fa 100644 --- a/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/repository/PersonRepository.java +++ b/spring-data-jpa/src/jmh/java/org/springframework/data/jpa/benchmark/repository/PersonRepository.java @@ -38,6 +38,9 @@ public interface PersonRepository extends ListCrudRepository { @Query("SELECT p FROM org.springframework.data.jpa.benchmark.model.Person p WHERE p.firstname = ?1") List findAllWithAnnotatedQueryByFirstname(String firstname, Sort sort); + @Query("SELECT p FROM org.springframework.data.jpa.benchmark.model.Person p WHERE p.firstname = ?1") + List findAllWithAnnotatedQueryByFirstname(String firstname, Sort sort, Class projection); + @Query(value = "SELECT * FROM person WHERE firstname = ?1", nativeQuery = true) List findAllWithNativeQueryByFirstname(String firstname); From 838aea9ee592bd23d8dcecdfba198b97dcc28dd6 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Thu, 10 Apr 2025 11:50:32 +0200 Subject: [PATCH 2/6] Prepare issue branch. 
--- pom.xml | 4 ++-- spring-data-envers/pom.xml | 4 ++-- spring-data-jpa-distribution/pom.xml | 2 +- spring-data-jpa/pom.xml | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pom.xml b/pom.xml index 5ccd77f57f..bb7111cc06 100755 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-jpa-parent - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-SNAPSHOT pom Spring Data JPA Parent @@ -38,7 +38,7 @@ 5.0 9.1.0 42.7.4 - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-RESULT-SNAPSHOT 0.10.3 org.hibernate diff --git a/spring-data-envers/pom.xml b/spring-data-envers/pom.xml index 43c08369f6..6811c403cb 100755 --- a/spring-data-envers/pom.xml +++ b/spring-data-envers/pom.xml @@ -5,12 +5,12 @@ org.springframework.data spring-data-envers - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-SNAPSHOT org.springframework.data spring-data-jpa-parent - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-SNAPSHOT ../pom.xml diff --git a/spring-data-jpa-distribution/pom.xml b/spring-data-jpa-distribution/pom.xml index af5244a230..56b52d93f1 100644 --- a/spring-data-jpa-distribution/pom.xml +++ b/spring-data-jpa-distribution/pom.xml @@ -14,7 +14,7 @@ org.springframework.data spring-data-jpa-parent - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-SNAPSHOT ../pom.xml diff --git a/spring-data-jpa/pom.xml b/spring-data-jpa/pom.xml index 1cc6674063..ed609134b9 100644 --- a/spring-data-jpa/pom.xml +++ b/spring-data-jpa/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-jpa - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-SNAPSHOT Spring Data JPA Spring Data module for JPA repositories. @@ -16,7 +16,7 @@ org.springframework.data spring-data-jpa-parent - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-SNAPSHOT ../pom.xml From 3c4f4b1040413052c7b5359f174037e70c6f47f9 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 11 Apr 2025 14:44:56 +0200 Subject: [PATCH 3/6] Explore returning Search Results. 
--- spring-data-jpa/pom.xml | 13 + .../data/jpa/convert/VectorConverters.java | 61 +++ .../jpa/repository/aot/QueriesFactory.java | 3 +- .../repository/query/AbstractJpaQuery.java | 18 +- .../query/JpaCountQueryCreator.java | 2 +- .../query/JpaKeysetScrollQueryCreator.java | 4 +- .../query/JpaParametersParameterAccessor.java | 27 ++ .../jpa/repository/query/JpaQueryCreator.java | 210 ++++++++++- .../repository/query/JpaQueryExecution.java | 70 ++++ .../repository/query/JpqlQueryBuilder.java | 136 +++++-- .../repository/query/ParameterBinding.java | 26 +- .../query/ParameterMetadataProvider.java | 218 ++++------- .../repository/query/PartTreeJpaQuery.java | 9 +- .../repository/PgVectorIntegrationTests.java | 348 ++++++++++++++++++ .../MySqlStoredProcedureIntegrationTests.java | 9 +- ...stgresStoredProcedureIntegrationTests.java | 3 +- ...ProcedureNullHandlingIntegrationTests.java | 10 +- .../query/AbstractJpaQueryTests.java | 2 +- .../query/JpaQueryCreatorTests.java | 21 +- .../query/JpqlQueryBuilderUnitTests.java | 10 + .../ParameterMetadataProviderUnitTests.java | 21 -- .../TestcontainerConfigSupport.java} | 36 +- .../src/test/resources/scripts/pgvector.sql | 7 + 23 files changed, 1031 insertions(+), 233 deletions(-) create mode 100644 spring-data-jpa/src/main/java/org/springframework/data/jpa/convert/VectorConverters.java create mode 100644 spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java rename spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/{procedures/StoredProcedureConfigSupport.java => support/TestcontainerConfigSupport.java} (74%) create mode 100644 spring-data-jpa/src/test/resources/scripts/pgvector.sql diff --git a/spring-data-jpa/pom.xml b/spring-data-jpa/pom.xml index ed609134b9..757087ec2b 100644 --- a/spring-data-jpa/pom.xml +++ b/spring-data-jpa/pom.xml @@ -88,6 +88,12 @@ true + + org.springframework + spring-test + test + + org.junit.platform junit-platform-launcher @@ -183,6 
+189,13 @@ + + ${hibernate.groupId}.orm + hibernate-vector + ${hibernate} + true + + ${hibernate.groupId}.orm hibernate-jpamodelgen diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/convert/VectorConverters.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/convert/VectorConverters.java new file mode 100644 index 0000000000..d6cf432340 --- /dev/null +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/convert/VectorConverters.java @@ -0,0 +1,61 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jpa.convert; + +import jakarta.persistence.AttributeConverter; +import jakarta.persistence.Converter; + +import org.jspecify.annotations.Nullable; + +import org.springframework.data.domain.Vector; + +/** + * JPA {@link Converter} for {@link Vector} types. + * + * @author Mark Paluch + * @since 4.0 + */ +public class VectorConverters { + + @Converter(autoApply = true) + public static class VectorAsFloatArrayConverter implements AttributeConverter<@Nullable Vector, @Nullable float[]> { + + @Override + public @Nullable float[] convertToDatabaseColumn(@Nullable Vector vector) { + return vector == null ? null : vector.toFloatArray(); + } + + @Override + public @Nullable Vector convertToEntityAttribute(@Nullable float[] floats) { + return floats == null ? 
null : Vector.of(floats); + } + } + + @Converter(autoApply = true) + public static class VectorAsDoubleArrayConverter implements AttributeConverter<@Nullable Vector, @Nullable double[]> { + + @Override + public @Nullable double[] convertToDatabaseColumn(@Nullable Vector vector) { + return vector == null ? null : vector.toDoubleArray(); + } + + @Override + public @Nullable Vector convertToEntityAttribute(@Nullable double[] doubles) { + return doubles == null ? null : Vector.of(doubles); + } + } + +} diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/QueriesFactory.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/QueriesFactory.java index 05c49f1144..ee26bf0d06 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/QueriesFactory.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/QueriesFactory.java @@ -224,7 +224,8 @@ private AotQuery createQuery(PartTree partTree, ReturnedType returnedType, JpaPa ParameterMetadataProvider metadataProvider = new ParameterMetadataProvider(parameters, EscapeCharacter.DEFAULT, templates); - JpaQueryCreator queryCreator = new JpaQueryCreator(partTree, returnedType, metadataProvider, templates, metamodel); + JpaQueryCreator queryCreator = new JpaQueryCreator(partTree, false, returnedType, metadataProvider, templates, + metamodel); return StringAotQuery.jpqlQuery(queryCreator.createQuery(), metadataProvider.getBindings(), partTree.getResultLimit(), partTree.isDelete(), partTree.isExistsProjection()); diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java index 4e672ccc80..2718d18691 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java +++ 
b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java @@ -101,7 +101,7 @@ public AbstractJpaQuery(JpaQueryMethod method, EntityManager em) { return new StreamExecution(); } else if (method.isProcedureQuery()) { return new ProcedureExecution(method.isCollectionQuery()); - } else if (method.isCollectionQuery()) { + } else if (method.isCollectionQuery() || method.isSearchQuery()) { return new CollectionExecution(); } else if (method.isSliceQuery()) { return new SlicedExecution(); @@ -140,7 +140,9 @@ protected JpaMetamodel getMetamodel() { @Override public @Nullable Object execute(Object[] parameters) { - return doExecute(getExecution(), parameters); + + JpaParametersParameterAccessor accessor = obtainParameterAccessor(parameters); + return doExecute(getExecution(accessor), accessor); } /** @@ -148,9 +150,8 @@ protected JpaMetamodel getMetamodel() { * @param values * @return */ - private @Nullable Object doExecute(JpaQueryExecution execution, Object[] values) { + private @Nullable Object doExecute(JpaQueryExecution execution, JpaParametersParameterAccessor accessor) { - JpaParametersParameterAccessor accessor = obtainParameterAccessor(values); Object result = execution.execute(this, accessor); ResultProcessor withDynamicProjection = method.getResultProcessor().withDynamicProjection(accessor); @@ -167,10 +168,17 @@ private JpaParametersParameterAccessor obtainParameterAccessor(Object[] values) return new JpaParametersParameterAccessor(method.getParameters(), values); } - protected JpaQueryExecution getExecution() { + protected JpaQueryExecution getExecution(JpaParametersParameterAccessor accessor) { JpaQueryExecution execution = this.execution.getNullable(); + if (method.isSearchQuery()) { + + ReturnedType returnedType = method.getResultProcessor().withDynamicProjection(accessor).getReturnedType(); + return new JpaQueryExecution.SearchResultExecution(execution == null ? 
new SingleEntityExecution() : execution, + returnedType, accessor.getScoringFunction()); + } + if (execution != null) { return execution; } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaCountQueryCreator.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaCountQueryCreator.java index c0f5c49d73..b95e272b1c 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaCountQueryCreator.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaCountQueryCreator.java @@ -48,7 +48,7 @@ public class JpaCountQueryCreator extends JpaQueryCreator { public JpaCountQueryCreator(PartTree tree, ReturnedType returnedType, ParameterMetadataProvider provider, JpqlQueryTemplates templates, EntityManager em) { - super(tree, returnedType, provider, templates, em); + super(tree, returnedType, provider, templates, em.getMetamodel()); this.distinct = tree.isDistinct(); this.returnedType = returnedType; diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaKeysetScrollQueryCreator.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaKeysetScrollQueryCreator.java index 776657b2af..e7252b510a 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaKeysetScrollQueryCreator.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaKeysetScrollQueryCreator.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + import org.jspecify.annotations.Nullable; import org.springframework.data.domain.KeysetScrollPosition; @@ -49,7 +51,7 @@ public JpaKeysetScrollQueryCreator(PartTree tree, ReturnedType type, ParameterMe JpqlQueryTemplates templates, JpaEntityInformation entityInformation, KeysetScrollPosition scrollPosition, EntityManager em) { - super(tree, 
type, provider, templates, em); + super(tree, type, provider, templates, em.getMetamodel()); this.entityInformation = entityInformation; this.scrollPosition = scrollPosition; diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java index 9d22c7bbb4..a436bd1fe6 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java @@ -17,6 +17,9 @@ import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.jpa.repository.query.JpaParameters.JpaParameter; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; @@ -68,4 +71,28 @@ protected Object potentiallyUnwrap(Object parameterValue) { return parameterValue; } + public ScoringFunction getScoringFunction() { + + Score score = getScore(); + if (score != null) { + return score.getFunction(); + } + + JpaParameters parameters = getParameters(); + if (parameters.hasScoreRangeParameter()) { + + Range range = getScoreRange(); + + if (range.getUpperBound().isBounded()) { + return range.getUpperBound().getValue().get().getFunction(); + } + + if (range.getLowerBound().isBounded()) { + return range.getLowerBound().getValue().get().getFunction(); + } + } + + return ScoringFunction.UNSPECIFIED; + } + } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java index c49baf6ff9..cd4670f431 100644 --- 
a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java @@ -28,13 +28,22 @@ import jakarta.persistence.metamodel.SingularAttribute; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.VectorScoringFunctions; + import org.springframework.data.domain.Sort; import org.springframework.data.jpa.domain.JpaSort; import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.ParameterPlaceholder; @@ -65,6 +74,18 @@ */ public class JpaQueryCreator extends AbstractQueryCreator implements JpqlQueryCreator { + private static final Map DISTANCE_FUNCTIONS = Map.of(VectorScoringFunctions.COSINE, + new DistanceFunction("cosine_distance", Sort.Direction.ASC), VectorScoringFunctions.EUCLIDEAN, + new DistanceFunction("euclidean_distance", Sort.Direction.ASC), VectorScoringFunctions.TAXICAB, + new DistanceFunction("taxicab_distance", Sort.Direction.ASC), VectorScoringFunctions.HAMMING, + new DistanceFunction("hamming_distance", Sort.Direction.ASC), VectorScoringFunctions.INNER_PRODUCT, + new DistanceFunction("negative_inner_product", Sort.Direction.DESC)); + + record DistanceFunction(String distanceFunction, Sort.Direction direction) { + + } + + private final boolean searchQuery; private final ReturnedType returnedType; private final ParameterMetadataProvider provider; private final JpqlQueryTemplates templates; @@ -79,21 +100,28 @@ public class JpaQueryCreator extends 
AbstractQueryCreator paths = new ArrayList<>(requiredSelection.size()); + List paths = new ArrayList<>(requiredSelection.size()); for (String selection : requiredSelection) { paths.add(JpqlUtils.toExpressionRecursively(metamodel, entity, entityType, PropertyPath.from(selection, returnedType.getDomainType()), true)); } + JpqlQueryBuilder.Expression distance = null; + if (searchQuery) { + distance = getDistanceExpression(); + } + if (useTupleQuery()) { + if (searchQuery) { + paths.add((distance != null ? distance : JpqlQueryBuilder.literal(0)).as("distance")); + } return selectStep.select(paths); } else { - return selectStep.instantiate(returnedType.getReturnedType(), paths); + + JpqlQueryBuilder.ConstructorExpression expression = new JpqlQueryBuilder.ConstructorExpression( + returnedType.getReturnedType().getName(), new JpqlQueryBuilder.Multiselect(entity, paths)); + + List selection = new ArrayList<>(2); + selection.add(expression); + + if (searchQuery) { + selection.add((distance != null ? 
distance : JpqlQueryBuilder.literal(0)).as("distance")); + } + + return selectStep.select(selection); + } + } + + if (searchQuery) { + + JpqlQueryBuilder.Expression distance = getDistanceExpression(); + + if (distance != null) { + return selectStep.select(new JpqlQueryBuilder.Multiselect(entity, + Arrays.asList(new JpqlQueryBuilder.EntitySelection(entity), distance.as("distance")))); } } @@ -287,6 +357,34 @@ private JpqlQueryBuilder.Select doSelect(Sort sort) { } } + @org.springframework.lang.Nullable + private JpqlQueryBuilder.Expression getDistanceExpression() { + + DistanceFunction distanceFunction = DISTANCE_FUNCTIONS.get(provider.getScoringFunction()); + + if (distanceFunction != null) { + JpqlQueryBuilder.PathExpression pas = JpqlUtils.toExpressionRecursively(metamodel, entity, entityType, + getVectorPath(), true); + return JpqlQueryBuilder.function(distanceFunction.distanceFunction(), pas, + placeholder(provider.getVectorBinding())); + } + + return null; + } + + PropertyPath getVectorPath() { + + for (PartTree.OrPart parts : tree) { + for (Part part : parts) { + if (part.getType() == NEAR || part.getType() == WITHIN) { + return part.getProperty(); + } + } + } + + throw new IllegalStateException("No vector path found"); + } + Collection getRequiredSelection(Sort sort, ReturnedType returnedType) { return returnedType.getInputProperties(); } @@ -419,11 +517,83 @@ public JpqlQueryBuilder.Predicate build() { where = JpqlQueryBuilder.where(entity, property); return type.equals(IS_NOT_EMPTY) ? 
where.isNotEmpty() : where.isEmpty(); + + case WITHIN: + case NEAR: + PartTreeParameterBinding vector = provider.next(part); + PartTreeParameterBinding within = provider.next(part); + + if (within.getValue() instanceof Range r) { + + Range range = (Range) within.getValue(); + + if (range.getLowerBound().isBounded() || range.getUpperBound().isBounded()) { /* either bound qualifies: a lower-only range yields lowerPredicate below */ + + Range.Bound lower = range.getLowerBound(); + Range.Bound upper = range.getUpperBound(); + + String distanceFunction = getDistanceFunction(provider.getScoringFunction()); + JpqlQueryBuilder.Expression distance = JpqlQueryBuilder.function(distanceFunction, pas, + placeholder(vector)); + + JpqlQueryBuilder.Predicate lowerPredicate = null; + JpqlQueryBuilder.Predicate upperPredicate = null; + if (lower.isBounded()) { + + JpqlQueryBuilder.Expression distanceValue = JpqlQueryBuilder + .expression("" + lower.getValue().get().getValue()); + + where = JpqlQueryBuilder.where(distance); + + lowerPredicate = lower.isInclusive() ? where.gte(distanceValue) : where.gt(distanceValue); + } + + if (upper.isBounded()) { + + JpqlQueryBuilder.Expression distanceValue = JpqlQueryBuilder + .expression("" + upper.getValue().get().getValue()); + + where = JpqlQueryBuilder.where(distance); + + upperPredicate = upper.isInclusive() ? where.lte(distanceValue) : where.lt(distanceValue); + } + + if (lowerPredicate != null && upperPredicate != null) { + return lowerPredicate.and(upperPredicate); + } else if (lowerPredicate != null) { + return lowerPredicate; + } else if (upperPredicate != null) { + return upperPredicate; + } + } + } + + if (within.getValue() instanceof Score score) { + + String distanceFunction = getDistanceFunction(score.getFunction()); + JpqlQueryBuilder.Expression distanceValue = placeholder(within); + JpqlQueryBuilder.Expression distance = JpqlQueryBuilder.function(distanceFunction, pas, + placeholder(vector)); + + return score instanceof Similarity ? 
JpqlQueryBuilder.where(distance).lte(distanceValue) + : JpqlQueryBuilder.where(distance).gte(distanceValue); + } + default: throw new IllegalArgumentException("Unsupported keyword " + type); } } + private static String getDistanceFunction(ScoringFunction scoringFunction) { + + DistanceFunction distanceFunction = JpaQueryCreator.DISTANCE_FUNCTIONS.get(scoringFunction); + + if (distanceFunction == null) { + throw new IllegalArgumentException("Unsupported ScoringFunction: %s".formatted(scoringFunction.getName())); + } + + return distanceFunction.distanceFunction(); + } + /** * Applies an {@code UPPERCASE} conversion to the given {@link Expression} in case the underlying {@link Part} * requires ignoring case. diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java index 338a2204e8..931c6d3ddb 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java @@ -18,8 +18,10 @@ import jakarta.persistence.EntityManager; import jakarta.persistence.Query; import jakarta.persistence.StoredProcedureQuery; +import jakarta.persistence.Tuple; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; @@ -32,12 +34,17 @@ import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; import 
org.springframework.data.domain.Slice; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort; import org.springframework.data.jpa.provider.PersistenceProvider; import org.springframework.data.repository.core.support.SurroundingTransactionDetectorMethodInterceptor; +import org.springframework.data.repository.query.ReturnedType; import org.springframework.data.support.PageableExecutionUtils; import org.springframework.data.util.CloseableIterator; import org.springframework.data.util.StreamUtils; @@ -123,6 +130,69 @@ protected Object doExecute(AbstractJpaQuery query, JpaParametersParameterAccesso } } + static class SearchResultExecution extends JpaQueryExecution { + + private final JpaQueryExecution delegate; + private final ReturnedType returnedType; + private final ScoringFunction function; + + SearchResultExecution(JpaQueryExecution delegate, ReturnedType returnedType, ScoringFunction function) { + this.delegate = delegate; + this.returnedType = returnedType; + this.function = function; + } + + @Override + protected @Nullable Object doExecute(AbstractJpaQuery query, JpaParametersParameterAccessor accessor) { + + Object result = delegate.execute(query, accessor); + + if (result instanceof Tuple || result instanceof Object[]) { + return map(result); + } + + if (result instanceof Collection c) { + + List> objects = new ArrayList<>(c.size()); + + for (Object o : c) { + objects.add(o instanceof Tuple || o instanceof Object[] ? map(o) : new SearchResult<>(o, 0)); + } + + return new SearchResults<>(objects); + } + + return result; + } + + private @Nullable SearchResult map(Object result) { + + if (result instanceof Tuple t) { + + Object value = returnedType.needsCustomConstruction() ? 
t : t.get(0); + try { + return new SearchResult<>(value, Score.of(t.get("distance", Number.class).doubleValue(), function)); + } catch (RuntimeException e) { + return new SearchResult<>(value, Score.of(0, function)); + } + } + + if (result instanceof Object[] objects) { + + Object value = returnedType.needsCustomConstruction() ? objects : objects[0]; + + try { + + return new SearchResult<>(value, Score.of(((Number) (objects[objects.length - 1])).doubleValue(), function)); + } catch (RuntimeException e) { + return new SearchResult<>(value, Score.of(0, function)); + } + } + + return null; + } + } + /** * Executes the query to return a {@link org.springframework.data.domain.Window} of entities. * diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilder.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilder.java index 45c804e124..da450a560b 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilder.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilder.java @@ -28,9 +28,9 @@ import java.util.Objects; import java.util.function.Supplier; -import org.springframework.data.domain.Sort; - import org.jspecify.annotations.Nullable; + +import org.springframework.data.domain.Sort; import org.springframework.data.mapping.PropertyPath; import org.springframework.data.util.Predicates; import org.springframework.lang.CheckReturnValue; @@ -124,15 +124,20 @@ public Select count() { } @Override - public Select instantiate(String resultType, Collection paths) { + public Select instantiate(String resultType, Collection paths) { return new Select(postProcess(new ConstructorExpression(resultType, new Multiselect(from, paths))), from); } @Override - public Select select(Collection paths) { + public Select select(Collection paths) { return new Select(postProcess(new Multiselect(from, paths)), from); } + 
@Override + public Select select(Selection selection) { + return new Select(postProcess(selection), from); + } + Selection postProcess(Selection selection) { return distinct ? new DistinctSelection(selection) : selection; } @@ -239,6 +244,17 @@ public static Expression parameter(ParameterPlaceholder placeholder) { return new ParameterExpression(placeholder); } + /** + * Create a new ordering expression. + * + * @param sortExpression + * @return + * @since 4.0 + */ + public static Expression orderBy(Expression sortExpression) { + return new OrderExpression(sortExpression, null, Sort.NullHandling.NATIVE); + } + /** * Create a new ordering expression. * @@ -247,7 +263,19 @@ public static Expression parameter(ParameterPlaceholder placeholder) { * @return */ public static Expression orderBy(Expression sortExpression, Sort.Order order) { - return new OrderExpression(sortExpression, order); + return new OrderExpression(sortExpression, order.getDirection(), order.getNullHandling()); + } + + /** + * Create a new ordering expression. + * + * @param sortExpression + * @param direction + * @return + * @since 4.0 + */ + public static Expression orderBy(Expression sortExpression, Sort.Direction direction) { + return new OrderExpression(sortExpression, direction, Sort.NullHandling.NATIVE); } /** @@ -431,7 +459,7 @@ public interface SelectStep { * @return */ @CheckReturnValue - default Select instantiate(Class resultType, Collection paths) { + default Select instantiate(Class resultType, Collection paths) { return instantiate(resultType.getName(), paths); } @@ -440,10 +468,10 @@ default Select instantiate(Class resultType, Collection paths); + Select instantiate(String resultType, Collection paths); /** * Specify a multi-select. @@ -452,7 +480,7 @@ default Select instantiate(Class resultType, Collection paths); + Select select(Collection paths); /** * Select a single attribute. 
@@ -465,9 +493,18 @@ default Select select(JpqlQueryBuilder.PathExpression path) { return select(List.of(path)); } + /** + * Select the given {@link Selection}. + * + * @param selection + * @return + */ + @CheckReturnValue + Select select(Selection selection); + } - interface Selection { + public interface Selection { String render(RenderContext context); } @@ -530,7 +567,7 @@ static PathAndOrigin path(Origin origin, String path) { * * @param source */ - record EntitySelection(Entity source) implements Selection { + record EntitySelection(Entity source) implements Selection, Expression { @Override public String render(RenderContext context) { @@ -568,7 +605,7 @@ public String toString() { * @param resultType * @param multiselect */ - record ConstructorExpression(String resultType, Multiselect multiselect) implements Selection { + record ConstructorExpression(String resultType, Multiselect multiselect) implements Selection, Expression { @Override public String render(RenderContext context) { @@ -588,22 +625,22 @@ public String toString() { * @param source * @param paths */ - record Multiselect(Origin source, Collection paths) implements Selection { + record Multiselect(Origin source, Collection paths) implements Selection { @Override public String render(RenderContext context) { StringBuilder builder = new StringBuilder(); - for (PathExpression path : paths) { + for (Expression path : paths) { if (!builder.isEmpty()) { builder.append(", "); } builder.append(path.render(context)); - if (!context.isConstructorContext()) { - builder.append(" ").append(path.getPropertyPath().getSegment()); + if (!context.isConstructorContext() && path instanceof AliasedExpression ae) { + builder.append(" ").append(ae.getAlias()); } } @@ -677,6 +714,47 @@ public interface Expression { * @return */ String render(RenderContext context); + + default AliasedExpression as(String alias) { + + if (this instanceof DefaultAliasedExpression de) { + return new DefaultAliasedExpression(de.delegate, alias); + } +
+ return new DefaultAliasedExpression(this, alias); + } + } + + /** + * Aliased expression. + * + * @since 4.0 + */ + public interface AliasedExpression extends Expression { + + /** + * @return the expression alias. + */ + String getAlias(); + + } + + record DefaultAliasedExpression(Expression delegate, String alias) implements AliasedExpression { + + @Override + public String render(RenderContext context) { + return delegate.render(context); + } + + @Override + public String getAlias() { + return alias(); + } + + @Override + public String toString() { + return render(RenderContext.EMPTY); + } } /** @@ -812,7 +890,8 @@ public String toString() { } } - record OrderExpression(Expression sortExpression, Sort.Order order) implements Expression { + record OrderExpression(Expression sortExpression, @org.springframework.lang.Nullable Sort.Direction direction, + Sort.NullHandling nullHandling) implements Expression { @Override public String render(RenderContext context) { @@ -820,14 +899,17 @@ public String render(RenderContext context) { StringBuilder builder = new StringBuilder(); builder.append(sortExpression.render(context)); - builder.append(" "); - builder.append(order.isDescending() ? TOKEN_DESC : TOKEN_ASC); + if (direction != null) { - if (order.getNullHandling() == Sort.NullHandling.NULLS_FIRST) { - builder.append(" NULLS FIRST"); - } else if (order.getNullHandling() == Sort.NullHandling.NULLS_LAST) { - builder.append(" NULLS LAST"); + builder.append(" "); + builder.append(direction.isDescending() ? TOKEN_DESC : TOKEN_ASC); + + if (nullHandling == Sort.NullHandling.NULLS_FIRST) { + builder.append(" NULLS FIRST"); + } else if (nullHandling == Sort.NullHandling.NULLS_LAST) { + builder.append(" NULLS LAST"); + } } return builder.toString(); @@ -1395,7 +1477,8 @@ public String toString() { * @param origin * @param onTheJoin whether the path should target the join itself instead of matching {@link PropertyPath}. 
*/ - record PathAndOrigin(PropertyPath path, Origin origin, boolean onTheJoin) implements PathExpression { + record PathAndOrigin(PropertyPath path, Origin origin, + boolean onTheJoin) implements PathExpression, AliasedExpression { @Override public PropertyPath getPropertyPath() { @@ -1411,6 +1494,11 @@ public String render(RenderContext context) { return context.getAlias(origin()); } } + + @Override + public String getAlias() { + return path().getSegment(); + } } /** diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java index 040e84a8ed..90e90f14fb 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java @@ -27,6 +27,9 @@ import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; import org.springframework.data.expression.ValueExpression; import org.springframework.data.jpa.provider.PersistenceProvider; import org.springframework.data.jpa.repository.support.JpqlQueryTemplates; @@ -160,6 +163,19 @@ public String toString() { * @param valueToBind value to prepare */ public @Nullable Object prepare(@Nullable Object valueToBind) { + + if (valueToBind instanceof Similarity similarity) { + return 1 - similarity.getValue(); + } + + if (valueToBind instanceof Score score) { + return score.getValue(); + } + + if (valueToBind instanceof Vector v) { + return v.getType() == Float.TYPE ? 
v.toFloatArray() : v.toDoubleArray(); + } + return valueToBind; } @@ -216,6 +232,7 @@ static class PartTreeParameterBinding extends ParameterBinding { private final Type type; private final boolean ignoreCase; private final boolean noWildcards; + private final @Nullable Object value; public PartTreeParameterBinding(BindingIdentifier identifier, ParameterOrigin origin, Class parameterType, Part part, @Nullable Object value, JpqlQueryTemplates templates, EscapeCharacter escape) { @@ -225,7 +242,7 @@ public PartTreeParameterBinding(BindingIdentifier identifier, ParameterOrigin or this.parameterType = parameterType; this.templates = templates; this.escape = escape; - + this.value = value; this.type = value == null && (Type.SIMPLE_PROPERTY.equals(part.getType()) || Type.NEGATING_SIMPLE_PROPERTY.equals(part.getType())) ? Type.IS_NULL @@ -241,9 +258,14 @@ public boolean isIsNullParameter() { return Type.IS_NULL.equals(type); } + public @Nullable Object getValue() { + return value; + } + @Override public @Nullable Object prepare(@Nullable Object value) { + value = super.prepare(value); if (value == null || parameterType == null) { return value; } @@ -389,7 +411,7 @@ public Type getType() { @Override public @Nullable Object prepare(@Nullable Object value) { - Object unwrapped = PersistenceProvider.unwrapTypedParameterValue(value); + Object unwrapped = PersistenceProvider.unwrapTypedParameterValue(super.prepare(value)); if (unwrapped == null) { return null; } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java index 72d43ab5bd..96507ba1cb 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java @@ -20,33 +20,27 @@ import 
jakarta.persistence.criteria.CriteriaBuilder; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.Iterator; import java.util.LinkedHashSet; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Vector; import org.springframework.data.jpa.provider.PersistenceProvider; import org.springframework.data.jpa.repository.support.JpqlQueryTemplates; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; import org.springframework.data.repository.query.ParametersParameterAccessor; import org.springframework.data.repository.query.parser.Part; -import org.springframework.data.repository.query.parser.Part.IgnoreCaseType; -import org.springframework.data.repository.query.parser.Part.Type; import org.springframework.expression.Expression; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; -import org.springframework.util.CollectionUtils; -import org.springframework.util.ObjectUtils; /** - * Helper class to allow easy creation of {@link ParameterMetadata}s. + * Helper class to allow easy creation of {@link PartTreeParameterBinding}s. 
* * @author Oliver Gierke * @author Thomas Darimont @@ -60,9 +54,13 @@ */ public class ParameterMetadataProvider { + static final Object PLACEHOLDER = new Object(); + private final Iterator parameters; + private final @Nullable JpaParametersParameterAccessor accessor; private final List bindings; private final Set syntheticParameterNames = new LinkedHashSet<>(); + private @Nullable ParameterBinding vector; private final @Nullable Iterator bindableParameterValues; private final EscapeCharacter escape; private final JpqlQueryTemplates templates; @@ -79,7 +77,7 @@ public class ParameterMetadataProvider { */ public ParameterMetadataProvider(JpaParametersParameterAccessor accessor, EscapeCharacter escape, JpqlQueryTemplates templates) { - this(accessor.iterator(), accessor.getParameters(), escape, templates); + this(accessor.iterator(), accessor, accessor.getParameters(), escape, templates); } /** @@ -92,7 +90,7 @@ public ParameterMetadataProvider(JpaParametersParameterAccessor accessor, */ public ParameterMetadataProvider(JpaParameters parameters, EscapeCharacter escape, JpqlQueryTemplates templates) { - this(null, parameters, escape, templates); + this(null, null, parameters, escape, templates); } /** @@ -104,14 +102,15 @@ public ParameterMetadataProvider(JpaParameters parameters, EscapeCharacter escap * @param escape must not be {@literal null}. * @param templates must not be {@literal null}. 
*/ - private ParameterMetadataProvider(@Nullable Iterator bindableParameterValues, JpaParameters parameters, + private ParameterMetadataProvider(@Nullable Iterator bindableParameterValues, + @Nullable JpaParametersParameterAccessor accessor, JpaParameters parameters, EscapeCharacter escape, JpqlQueryTemplates templates) { - Assert.notNull(parameters, "Parameters must not be null"); Assert.notNull(escape, "EscapeCharacter must not be null"); Assert.notNull(templates, "JpqlQueryTemplates must not be null"); this.jpaParameters = parameters; + this.accessor = accessor; this.parameters = parameters.getBindableParameters().iterator(); this.bindings = new ArrayList<>(); this.bindableParameterValues = bindableParameterValues; @@ -119,6 +118,10 @@ private ParameterMetadataProvider(@Nullable Iterator bindableParameterVa this.templates = templates; } + public JpaParameters getParameters() { + return this.jpaParameters; + } + /** * Returns all {@link ParameterBinding}s built. * @@ -132,7 +135,7 @@ public List getBindings() { * Builds a new {@link PartTreeParameterBinding} for given {@link Part} and the next {@link Parameter}. */ @SuppressWarnings("unchecked") - public PartTreeParameterBinding next(Part part) { + PartTreeParameterBinding next(Part part) { Assert.isTrue(parameters.hasNext(), () -> String.format("No parameter available for part %s", part)); @@ -144,12 +147,11 @@ public PartTreeParameterBinding next(Part part) { * Builds a new {@link PartTreeParameterBinding} of the given {@link Part} and type. Forwards the underlying * {@link Parameters} as well. * - * @param is the type parameter of the returned {@link ParameterMetadata}. + * @param is the type parameter of the returned {@link PartTreeParameterBinding}. * @param type must not be {@literal null}. * @return ParameterMetadata for the next parameter. 
*/ - @SuppressWarnings("unchecked") - public PartTreeParameterBinding next(Part part, Class type) { + PartTreeParameterBinding next(Part part, Class type) { Parameter parameter = parameters.next(); Class typeToUse = ClassUtils.isAssignable(type, parameter.getType()) ? parameter.getType() : type; @@ -159,11 +161,11 @@ public PartTreeParameterBinding next(Part part, Class type) { /** * Builds a new {@link PartTreeParameterBinding} for the given type and name. * - * @param type parameter for the returned {@link ParameterMetadata}. + * @param type parameter for the returned {@link PartTreeParameterBinding}. * @param part must not be {@literal null}. * @param type must not be {@literal null}. - * @param parameter providing the name for the returned {@link ParameterMetadata}. - * @return a new {@link ParameterMetadata} for the given type and name. + * @param parameter providing the name for the returned {@link PartTreeParameterBinding}. + * @return a new {@link PartTreeParameterBinding} for the given type and name. */ private PartTreeParameterBinding next(Part part, Class type, Parameter parameter) { @@ -175,8 +177,7 @@ private PartTreeParameterBinding next(Part part, Class type, Parameter pa @SuppressWarnings("unchecked") Class reifiedType = Expression.class.equals(type) ? (Class) Object.class : type; - Object value = bindableParameterValues == null ? ParameterMetadata.PLACEHOLDER : bindableParameterValues.next(); - + Object value = bindableParameterValues == null ? 
PLACEHOLDER : bindableParameterValues.next(); int currentPosition = ++position; BindingIdentifier bindingIdentifier = parameter.getName().map(it -> BindingIdentifier.of(it, currentPosition)) @@ -187,158 +188,89 @@ private PartTreeParameterBinding next(Part part, Class type, Parameter pa PartTreeParameterBinding binding = new PartTreeParameterBinding(bindingIdentifier, methodParameter, reifiedType, part, value, templates, escape); + // PartTreeParameterBinding is more expressive than a potential ParameterBinding for Vector. bindings.add(binding); - return binding; - } + if (Vector.class.isAssignableFrom(parameter.getType())) { + this.vector = binding; + } - EscapeCharacter getEscape() { - return escape; + return binding; } - /** - * Builds a new synthetic {@link ParameterBinding} for the given value. - * - * @param nameHint - * @param value - * @param source - * @return a new {@link ParameterBinding} for the given value and source. - */ - public ParameterBinding nextSynthetic(String nameHint, Object value, Object source) { - - int currentPosition = ++position; - String bindingName = nameHint; - - if (!syntheticParameterNames.add(bindingName)) { + ScoringFunction getScoringFunction() { - bindingName = bindingName + "_" + currentPosition; - syntheticParameterNames.add(bindingName); + if (accessor != null) { + return accessor.getScoringFunction(); } - return new ParameterBinding(BindingIdentifier.of(bindingName, currentPosition), - ParameterOrigin.synthetic(value, source)); - } - - public JpaParameters getParameters() { - return this.jpaParameters; + return ScoringFunction.UNSPECIFIED; } - /** - * @author Oliver Gierke - * @author Thomas Darimont - * @author Andrey Kovalev - */ - public static class ParameterMetadata { - - static final Object PLACEHOLDER = new Object(); + ParameterBinding getVectorBinding() { - private final Class parameterType; - private final Type type; - private final int position; - private final JpqlQueryTemplates templates; - private final 
EscapeCharacter escape; - private final boolean ignoreCase; - private final boolean noWildcards; - - /** - * Creates a new {@link ParameterMetadata}. - */ - public ParameterMetadata(Class parameterType, Part part, @Nullable Object value, EscapeCharacter escape, - int position, JpqlQueryTemplates templates) { - - this.parameterType = parameterType; - this.position = position; - this.templates = templates; - this.type = value == null - && (Type.SIMPLE_PROPERTY.equals(part.getType()) || Type.NEGATING_SIMPLE_PROPERTY.equals(part.getType())) - ? Type.IS_NULL - : part.getType(); - this.ignoreCase = IgnoreCaseType.ALWAYS.equals(part.shouldIgnoreCase()); - this.noWildcards = part.getProperty().getLeafProperty().isCollection(); - this.escape = escape; + if (!getParameters().hasVectorParameter()) { + throw new IllegalStateException("Vector parameter not available"); } - public int getPosition() { - return position; + if (this.vector != null) { + return this.vector; } - public Class getParameterType() { - return parameterType; - } + int vectorIndex = getParameters().getVectorIndex(); - /** - * Returns whether the parameter shall be considered an {@literal IS NULL} parameter. - */ - public boolean isIsNullParameter() { - return Type.IS_NULL.equals(type); - } - - /** - * Prepares the object before it's actually bound to the {@link jakarta.persistence.Query;}. - * - * @param value can be {@literal null}. 
- */ - public @Nullable Object prepare(@Nullable Object value) { + BindingIdentifier bindingIdentifier = BindingIdentifier.of(vectorIndex + 1); - if (value == null || parameterType == null) { - return value; - } + /* identifier refers to bindable parameters, not _all_ parameters index */ + MethodInvocationArgument methodParameter = ParameterOrigin.ofParameter(bindingIdentifier); + ParameterBinding parameterBinding = new ParameterBinding(bindingIdentifier, methodParameter); - if (String.class.equals(parameterType) && !noWildcards) { + this.bindings.add(parameterBinding); - return switch (type) { - case STARTING_WITH -> String.format("%s%%", escape.escape(value.toString())); - case ENDING_WITH -> String.format("%%%s", escape.escape(value.toString())); - case CONTAINING, NOT_CONTAINING -> String.format("%%%s%%", escape.escape(value.toString())); - default -> value; - }; - } + return parameterBinding; + } - return Collection.class.isAssignableFrom(parameterType) // - ? potentiallyIgnoreCase(ignoreCase, toCollection(value)) // - : value; - } + private void maybeAdd(ParameterBinding parameterBinding) { - /** - * Returns the given argument as {@link Collection} which means it will return it as is if it's a - * {@link Collections}, turn an array into an {@link ArrayList} or simply wrap any other value into a single element - * {@link Collections}. - * - * @param value the value to be converted to a {@link Collection}. - * @return the object itself as a {@link Collection} or a {@link Collection} constructed from the value. - */ - private static @Nullable Collection toCollection(@Nullable Object value) { + boolean found = false; - if (value == null) { - return null; - } + for (ParameterBinding existing : bindings) { - if (value instanceof Collection collection) { - return collection.isEmpty() ? 
null : collection; + if (existing.isCompatibleWith(parameterBinding)) { + found = true; } + } - if (ObjectUtils.isArray(value)) { + if (!found) { + bindings.add(parameterBinding); + } + } - List collection = Arrays.asList(ObjectUtils.toObjectArray(value)); - return collection.isEmpty() ? null : collection; - } + EscapeCharacter getEscape() { + return escape; + } - return Collections.singleton(value); - } + /** + * Builds a new synthetic {@link ParameterBinding} for the given value. + * + * @param nameHint + * @param value + * @param source + * @return a new {@link ParameterBinding} for the given value and source. + */ + ParameterBinding nextSynthetic(String nameHint, Object value, Object source) { - @SuppressWarnings("unchecked") - private @Nullable Collection potentiallyIgnoreCase(boolean ignoreCase, @Nullable Collection collection) { + int currentPosition = ++position; + String bindingName = nameHint; - if (!ignoreCase || CollectionUtils.isEmpty(collection)) { - return collection; - } + if (!syntheticParameterNames.add(bindingName)) { - return ((Collection) collection).stream() // - .map(it -> it == null // - ? 
null // - : templates.ignoreCase(it)) // - .collect(Collectors.toList()); + bindingName = bindingName + "_" + currentPosition; + syntheticParameterNames.add(bindingName); } + return new ParameterBinding(BindingIdentifier.of(bindingName, currentPosition), + ParameterOrigin.synthetic(value, source)); } + } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java index 66dac47929..5990e1cb55 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java @@ -124,7 +124,7 @@ public TypedQuery doCreateCountQuery(JpaParametersParameterAccessor access } @Override - protected JpaQueryExecution getExecution() { + protected JpaQueryExecution getExecution(JpaParametersParameterAccessor accessor) { if (this.getQueryMethod().isScrollQuery()) { return new ScrollExecution(this.tree.getSort(), new ScrollDelegate<>(entityInformation)); @@ -134,7 +134,7 @@ protected JpaQueryExecution getExecution() { return new ExistsExecution(); } - return super.getExecution(); + return super.getExecution(accessor); } private static void validate(PartTree tree, JpaParameters parameters, String methodName) { @@ -297,9 +297,10 @@ protected JpqlQueryCreator createCreator(Sort sort, JpaParametersParameterAccess } JpqlQueryCreator creator = new CacheableJpqlQueryCreator(sort, - new JpaQueryCreator(tree, returnedType, provider, templates, em)); + new JpaQueryCreator(tree, getQueryMethod().isSearchQuery(), returnedType, provider, templates, + em.getMetamodel())); - if (accessor.getParameters().hasDynamicProjection()) { + if (accessor.getParameters().hasDynamicProjection() || getQueryMethod().isSearchQuery()) { return creator; } diff --git 
a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java new file mode 100644 index 0000000000..600652261b --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java @@ -0,0 +1,348 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.data.jpa.repository; + +import static org.assertj.core.api.Assertions.*; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; +import jakarta.persistence.Id; +import jakarta.persistence.Table; + +import java.net.URL; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.type.SqlTypes; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan.Filter; +import org.springframework.context.annotation.FilterType; +import org.springframework.core.io.ClassPathResource; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; +import org.springframework.data.domain.VectorScoringFunctions; +import org.springframework.data.jpa.convert.VectorConverters; +import org.springframework.data.jpa.repository.config.EnableJpaRepositories; +import org.springframework.data.jpa.repository.support.TestcontainerConfigSupport; +import org.springframework.orm.jpa.persistenceunit.PersistenceManagedTypes; +import org.springframework.test.annotation.Rollback; +import 
org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; +import org.springframework.transaction.annotation.EnableTransactionManagement; +import org.springframework.transaction.annotation.Transactional; + +import org.testcontainers.containers.PostgreSQLContainer; + +/** + * Integration tests to verify {@link Vector} search queries work with Postgres using pgvector. + * + * @author Mark Paluch + */ +@Transactional +@Rollback(value = false) +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = PgVectorIntegrationTests.Config.class) +class PgVectorIntegrationTests { + + Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f); + + @Autowired VectorSearchRepository repository; + + @BeforeEach + void setUp() { + + WithVector w1 = new WithVector("de", "one", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f }); + WithVector w2 = new WithVector("de", "two", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f }); + WithVector w3 = new WithVector("en", "three", new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f }); + WithVector w4 = new WithVector("de", "four", + new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f }); + + repository.deleteAllInBatch(); + repository.saveAllAndFlush(Arrays.asList(w1, w2, w3, w4)); + } + + @ParameterizedTest + @MethodSource("scoringFunctions") + void shouldApplyVectorSearchWithDistance(VectorScoringFunctions functions) { + + SearchResults results = repository.searchTop2ByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.of(0.1, functions)); + + assertThat(results).hasSize(2).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + .containsOnly("de", "de"); + + assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) + .containsExactlyInAnyOrder("two", "one"); + } + + static Set scoringFunctions() { + return
EnumSet.of(VectorScoringFunctions.COSINE, VectorScoringFunctions.INNER_PRODUCT, + VectorScoringFunctions.EUCLIDEAN); + } + + @Test + void shouldRunStringQuery() { + + List results = repository.findAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(2, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(WithVector::getCountry).containsOnly("de", "de", "de"); + assertThat(results).extracting(WithVector::getDescription).containsSequence("two", "one", "four"); + } + + @Test + void shouldRunStringQueryWithDistance() { + + SearchResults results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(2, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + .containsOnly("de", "de", "de"); + assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) + .containsSequence("two", "one", "four"); + + SearchResult result = results.getContent().get(0); + assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0); + assertThat(result.getScore().getFunction()).isEqualTo(VectorScoringFunctions.COSINE); + } + + @Test + void shouldApplyVectorSearchWithRange() { + + SearchResults results = repository.searchAllByCountryAndEmbeddingWithin("de", VECTOR, + Score.between(0, 1, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + .containsOnly("de", "de", "de"); + assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) + .containsSequence("two", "one", "four"); + } + + @Test + void shouldApplyVectorSearchAndReturnList() { + + List results = repository.findAllByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(0, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(WithVector::getCountry).containsOnly("de", "de", "de"); + 
assertThat(results).extracting(WithVector::getDescription).containsSequence("one", "two", "four"); + + } + + @Test + void shouldProjectVectorSearchAsInterface() { + + SearchResults results = repository.searchInterfaceProjectionByCountryAndEmbeddingWithin("de", + VECTOR, Score.of(0, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithDescription::getDescription) + .containsSequence("two", "one", "four"); + } + + @Test + void shouldProjectVectorSearchAsDto() { + + SearchResults results = repository.searchDtoByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(0, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(DescriptionDto::getDescription) + .containsSequence("two", "one", "four"); + } + + @Test + void shouldProjectVectorSearchDynamically() { + + SearchResults dtos = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(0, VectorScoringFunctions.COSINE), DescriptionDto.class); + + assertThat(dtos).hasSize(3).extracting(SearchResult::getContent).extracting(DescriptionDto::getDescription) + .containsSequence("two", "one", "four"); + + SearchResults proxies = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(0, VectorScoringFunctions.COSINE), WithDescription.class); + + assertThat(proxies).hasSize(3).extracting(SearchResult::getContent).extracting(WithDescription::getDescription) + .containsSequence("two", "one", "four"); + } + + @Entity + @Table(name = "with_vector") + public static class WithVector { + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) // + private Integer id; + + private String country; + private String description; + + @Column(name = "the_embedding") + @JdbcTypeCode(SqlTypes.VECTOR) + @Array(length = 5) private float[] embedding; + + public WithVector() {} + + public WithVector(String country, String description, float[] embedding) { + this.country = 
country; + this.description = description; + this.embedding = embedding; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getCountry() { + return country; + } + + public void setCountry(String country) { + this.country = country; + } + + public String getDescription() { + return description; + } + + public float[] getEmbedding() { + return embedding; + } + + public void setEmbedding(float[] embedding) { + this.embedding = embedding; + } + } + + interface WithDescription { + String getDescription(); + } + + static class DescriptionDto { + + private final String description; + + public DescriptionDto(String description) { + this.description = description; + } + + public String getDescription() { + return description; + } + } + + interface VectorSearchRepository extends JpaRepository { + + List findAllByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + + @Query(""" + SELECT w FROM org.springframework.data.jpa.repository.PgVectorIntegrationTests$WithVector w + WHERE w.country = ?1 + AND cosine_distance(w.embedding, :embedding) <= :distance + ORDER BY cosine_distance(w.embedding, :embedding) asc""") + List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + + @Query(""" + SELECT w, cosine_distance(w.embedding, :embedding) as distance FROM org.springframework.data.jpa.repository.PgVectorIntegrationTests$WithVector w + WHERE w.country = ?1 + AND cosine_distance(w.embedding, :embedding) <= :distance + ORDER BY distance asc""") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + Score distance); + + SearchResults searchAllByCountryAndEmbeddingWithin(String country, Vector embedding, + Range distance); + + SearchResults searchTop2ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + + SearchResults searchInterfaceProjectionByCountryAndEmbeddingWithin(String 
country, + Vector embedding, Score distance); + + SearchResults searchDtoByCountryAndEmbeddingWithin(String country, Vector embedding, + Score distance); + + SearchResults searchDynamicByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, + Class projection); + + } + + @EnableJpaRepositories(considerNestedRepositories = true, + includeFilters = @Filter(type = FilterType.ASSIGNABLE_TYPE, classes = VectorSearchRepository.class)) + @EnableTransactionManagement + static class Config extends TestcontainerConfigSupport { + + public Config() { + super(PostgreSQLDialect.class, new ClassPathResource("scripts/pgvector.sql")); + } + + @Override + protected String getSchemaAction() { + return "none"; + } + + @Override + protected PersistenceManagedTypes getManagedTypes() { + return new PersistenceManagedTypes() { + @Override + public List getManagedClassNames() { + return List.of(WithVector.class.getName(), VectorConverters.VectorAsDoubleArrayConverter.class.getName(), + VectorConverters.VectorAsFloatArrayConverter.class.getName()); + } + + @Override + public List getManagedPackages() { + return List.of(); + } + + @Override + public @Nullable URL getPersistenceUnitRootUrl() { + return null; + } + }; + } + + @SuppressWarnings("resource") + @Bean(initMethod = "start", destroyMethod = "start") + public PostgreSQLContainer container() { + + return new PostgreSQLContainer<>("pgvector/pgvector:pg17") // + .withUsername("postgres").withReuse(true); + } + } +} diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/MySqlStoredProcedureIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/MySqlStoredProcedureIntegrationTests.java index 5366736fc9..5981f5d456 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/MySqlStoredProcedureIntegrationTests.java +++ 
b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/MySqlStoredProcedureIntegrationTests.java @@ -23,6 +23,7 @@ import jakarta.persistence.Id; import jakarta.persistence.NamedStoredProcedureQuery; +import java.util.Collection; import java.util.List; import java.util.Objects; @@ -38,6 +39,7 @@ import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.config.EnableJpaRepositories; import org.springframework.data.jpa.repository.query.Procedure; +import org.springframework.data.jpa.repository.support.TestcontainerConfigSupport; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.transaction.annotation.EnableTransactionManagement; @@ -223,12 +225,17 @@ public interface EmployeeRepositoryWithNoCursor extends JpaRepository getPackagesToScan() { + return List.of(getClass().getPackageName()); + } + @SuppressWarnings("resource") @Bean(initMethod = "start", destroyMethod = "stop") public MySQLContainer container() { diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureIntegrationTests.java index a88e23f9a6..5b9a790082 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureIntegrationTests.java @@ -42,6 +42,7 @@ import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.config.EnableJpaRepositories; import org.springframework.data.jpa.repository.query.Procedure; +import org.springframework.data.jpa.repository.support.TestcontainerConfigSupport; import 
org.springframework.data.jpa.util.DisabledOnHibernate; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -292,7 +293,7 @@ public interface EmployeeRepositoryWithRefCursor extends JpaRepository { @EnableJpaRepositories(considerNestedRepositories = true, includeFilters = @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, classes = TestModelRepository.class)) @EnableTransactionManagement - static class Config extends StoredProcedureConfigSupport { + static class Config extends TestcontainerConfigSupport { public Config() { super(PostgreSQLDialect.class, new ClassPathResource("scripts/postgres-nullable-stored-procedures.sql")); } + @Override + protected Collection getPackagesToScan() { + return List.of(getClass().getPackageName()); + } + @SuppressWarnings("resource") @Bean(initMethod = "start", destroyMethod = "stop") public PostgreSQLContainer container() { diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/AbstractJpaQueryTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/AbstractJpaQueryTests.java index 8728e03229..059f090f5f 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/AbstractJpaQueryTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/AbstractJpaQueryTests.java @@ -221,7 +221,7 @@ class DummyJpaQuery extends AbstractJpaQuery { } @Override - protected JpaQueryExecution getExecution() { + protected JpaQueryExecution getExecution(JpaParametersParameterAccessor accessor) { return execution; } diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryCreatorTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryCreatorTests.java index 55e9f39122..582670223a 100644 --- 
a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryCreatorTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryCreatorTests.java @@ -40,8 +40,11 @@ import org.junit.jupiter.params.provider.ValueSource; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.jpa.repository.support.JpqlQueryTemplates; import org.springframework.data.jpa.util.TestMetaModel; import org.springframework.data.projection.ProjectionFactory; @@ -743,7 +746,8 @@ JpaQueryCreator queryCreator(PartTree tree, ReturnedType returnedType, Metamodel ParameterMetadataProvider parameterMetadataProvider = new ParameterMetadataProvider(parameterAccessor, EscapeCharacter.DEFAULT, templates); - return new JpaQueryCreator(tree, returnedType, parameterMetadataProvider, templates, entityManager); + return new JpaQueryCreator(tree, false, returnedType, parameterMetadataProvider, templates, + entityManager.getMetamodel()); } @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -979,6 +983,21 @@ public int bindingIndexFor(String placeholder) { public ParameterAccessor bindableParameters() { return new ParameterAccessor() { + @Override + public @Nullable Vector getVector() { + return null; + } + + @Override + public @Nullable Score getScore() { + return null; + } + + @Override + public @Nullable Range getScoreRange() { + return null; + } + @Override public @Nullable ScrollPosition getScrollPosition() { return null; diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilderUnitTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilderUnitTests.java index d2ac172373..46952dee71 
100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilderUnitTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilderUnitTests.java @@ -33,6 +33,7 @@ * Unit tests for {@link JpqlQueryBuilder}. * * @author Christoph Strobl + * @author Mark Paluch */ class JpqlQueryBuilderUnitTests { @@ -77,6 +78,15 @@ void literalExpressionRendersAsIs() { assertThat(expression.render(RenderContext.EMPTY)).isEqualTo("CONCAT(person.lastName, ‘, ’, person.firstName))"); } + @Test // GH- + void aliasedExpression() { + + // aliasing is contextual and happens during selection rendering. E.g. constructor expressions don't use aliases. + Expression expression = expression("CONCAT(person.lastName, ‘, ’, person.firstName)").as("concatted"); + assertThat(expression.render(RenderContext.EMPTY)) + .isEqualTo("CONCAT(person.lastName, ‘, ’, person.firstName)"); + } + @Test // GH-3588 void xxx() { diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/ParameterMetadataProviderUnitTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/ParameterMetadataProviderUnitTests.java index 4ad41bfd14..30895c94b3 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/ParameterMetadataProviderUnitTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/ParameterMetadataProviderUnitTests.java @@ -61,25 +61,4 @@ void errorMessageMentionsParametersWhenParametersAreExhausted() { .withMessageContaining("parameter"); } - @Test // GH-3137 - void returnAugmentedValueForStringExpressions() { - - when(part.getProperty().getLeafProperty().isCollection()).thenReturn(false); - when(part.getProperty().getType()).thenReturn((Class) String.class); - - assertThat(createParameterMetadata(Part.Type.STARTING_WITH).prepare("starting with")).isEqualTo("starting with%"); - 
assertThat(createParameterMetadata(Part.Type.ENDING_WITH).prepare("ending with")).isEqualTo("%ending with"); - assertThat(createParameterMetadata(Part.Type.CONTAINING).prepare("containing")).isEqualTo("%containing%"); - assertThat(createParameterMetadata(Part.Type.NOT_CONTAINING).prepare("not containing")) - .isEqualTo("%not containing%"); - assertThat(createParameterMetadata(Part.Type.LIKE).prepare("%like%")).isEqualTo("%like%"); - assertThat(createParameterMetadata(Part.Type.IS_NULL).prepare(null)).isEqualTo(null); - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private ParameterMetadataProvider.ParameterMetadata createParameterMetadata(Part.Type partType) { - - when(part.getType()).thenReturn(partType); - return new ParameterMetadataProvider.ParameterMetadata(part.getProperty().getType(), part, null, EscapeCharacter.DEFAULT, 1, JpqlQueryTemplates.LOWER); - } } diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/StoredProcedureConfigSupport.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/support/TestcontainerConfigSupport.java similarity index 74% rename from spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/StoredProcedureConfigSupport.java rename to spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/support/TestcontainerConfigSupport.java index 998245126b..505ad4b089 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/StoredProcedureConfigSupport.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/support/TestcontainerConfigSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2025 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.data.jpa.repository.procedures; +package org.springframework.data.jpa.repository.support; import jakarta.persistence.EntityManagerFactory; +import java.util.Collection; +import java.util.Collections; import java.util.Properties; import javax.sql.DataSource; @@ -29,6 +31,8 @@ import org.springframework.orm.jpa.AbstractEntityManagerFactoryBean; import org.springframework.orm.jpa.JpaTransactionManager; import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean; +import org.springframework.orm.jpa.persistenceunit.ManagedClassNameFilter; +import org.springframework.orm.jpa.persistenceunit.PersistenceManagedTypes; import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter; import org.springframework.transaction.PlatformTransactionManager; @@ -39,12 +43,12 @@ * * @author Mark Paluch */ -class StoredProcedureConfigSupport { +public class TestcontainerConfigSupport { private final Class dialect; private final Resource initScript; - StoredProcedureConfigSupport(Class dialect, Resource initScript) { + protected TestcontainerConfigSupport(Class dialect, Resource initScript) { this.dialect = dialect; this.initScript = initScript; } @@ -67,16 +71,36 @@ AbstractEntityManagerFactoryBean entityManagerFactory(DataSource dataSource) { factoryBean.setDataSource(dataSource); factoryBean.setPersistenceUnitRootLocation("simple-persistence"); factoryBean.setJpaVendorAdapter(new HibernateJpaVendorAdapter()); - factoryBean.setPackagesToScan(this.getClass().getPackage().getName()); + + factoryBean.setManagedTypes(getManagedTypes()); + factoryBean.setPackagesToScan(getPackagesToScan().toArray(new String[0])); + factoryBean.setManagedClassNameFilter(getManagedClassNameFilter()); Properties properties = new Properties(); - properties.setProperty("hibernate.hbm2ddl.auto", "create"); + 
properties.setProperty("hibernate.hbm2ddl.auto", getSchemaAction()); properties.setProperty("hibernate.dialect", dialect.getCanonicalName()); + factoryBean.setJpaProperties(properties); return factoryBean; } + protected String getSchemaAction() { + return "create"; + } + + protected PersistenceManagedTypes getManagedTypes() { + return null; + } + + protected Collection getPackagesToScan() { + return Collections.emptyList(); + } + + protected ManagedClassNameFilter getManagedClassNameFilter() { + return className -> true; + } + @Bean PlatformTransactionManager transactionManager(EntityManagerFactory entityManagerFactory) { return new JpaTransactionManager(entityManagerFactory); diff --git a/spring-data-jpa/src/test/resources/scripts/pgvector.sql b/spring-data-jpa/src/test/resources/scripts/pgvector.sql new file mode 100644 index 0000000000..4057dd9528 --- /dev/null +++ b/spring-data-jpa/src/test/resources/scripts/pgvector.sql @@ -0,0 +1,7 @@ +CREATE EXTENSION IF NOT EXISTS vector; + +DROP TABLE IF EXISTS with_vector; + +CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10),the_embedding vector(5)); + +CREATE INDEX ON with_vector USING hnsw (the_embedding vector_l2_ops); From 55e6483feff3d12140b60bb1a268239185988e77 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 28 Apr 2025 18:02:26 +0200 Subject: [PATCH 4/6] Add SimilarityNormalizer. 
--- pom.xml | 8 + .../repository/query/AbstractJpaQuery.java | 2 +- .../query/JpaParametersParameterAccessor.java | 45 ++++- .../jpa/repository/query/JpaQueryCreator.java | 72 +++++--- .../repository/query/JpaQueryExecution.java | 24 ++- .../repository/query/ParameterBinding.java | 59 ++++++- .../query/ParameterMetadataProvider.java | 159 +++++++++++++++--- .../repository/query/PartTreeJpaQuery.java | 14 +- .../query/QueryParameterSetterFactory.java | 4 + .../query/SimilarityNormalizer.java | 126 ++++++++++++++ .../repository/PgVectorIntegrationTests.java | 60 +++++-- ...meterMetadataProviderIntegrationTests.java | 59 +++++++ .../query/SimilarityNormalizerUnitTests.java | 70 ++++++++ 13 files changed, 618 insertions(+), 84 deletions(-) create mode 100644 spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/SimilarityNormalizer.java create mode 100644 spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/SimilarityNormalizerUnitTests.java diff --git a/pom.xml b/pom.xml index bb7111cc06..f5026ee712 100755 --- a/pom.xml +++ b/pom.xml @@ -56,6 +56,14 @@ jmh + + + com.github.mp911de.microbenchmark-runner + microbenchmark-runner-junit5 + 0.5.0.RELEASE + test + + jitpack diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java index 2718d18691..b7ac2b9127 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java @@ -176,7 +176,7 @@ protected JpaQueryExecution getExecution(JpaParametersParameterAccessor accessor ReturnedType returnedType = method.getResultProcessor().withDynamicProjection(accessor).getReturnedType(); return new JpaQueryExecution.SearchResultExecution(execution == null ? 
new SingleEntityExecution() : execution, - returnedType, accessor.getScoringFunction()); + returnedType, accessor.getScoringFunction(), accessor.normalizeSimilarity()); } if (execution != null) { diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java index a436bd1fe6..72b7156cc3 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java @@ -15,11 +15,16 @@ */ package org.springframework.data.jpa.repository.query; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; + import org.jspecify.annotations.Nullable; import org.springframework.data.domain.Range; import org.springframework.data.domain.Score; import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Similarity; import org.springframework.data.jpa.repository.query.JpaParameters.JpaParameter; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; @@ -71,11 +76,34 @@ protected Object potentiallyUnwrap(Object parameterValue) { return parameterValue; } + /** + * Returns the {@link ScoringFunction}. + * + * @return + */ public ScoringFunction getScoringFunction() { + return doWithScore(Score::getFunction, Score.class::isInstance, () -> ScoringFunction.UNSPECIFIED); + } + + /** + * Returns whether to normalize similarities (i.e. translate the database-specific score into {@link Similarity}). 
+ * + * @return + */ + public boolean normalizeSimilarity() { + return doWithScore(it -> true, Similarity.class::isInstance, () -> false); + } + + /** + * Applies {@code function} to the first {@link Score} (score parameter, then score range lower/upper bound) accepted by {@code scoreFilter}; otherwise returns {@code defaultValue}. 
+ * + * @return + */ + public T doWithScore(Function function, Predicate scoreFilter, Supplier defaultValue) { Score score = getScore(); - if (score != null) { - return score.getFunction(); + if (score != null && scoreFilter.test(score)) { + return function.apply(score); } JpaParameters parameters = getParameters(); @@ -83,16 +111,19 @@ public ScoringFunction getScoringFunction() { Range range = getScoreRange(); - if (range.getUpperBound().isBounded()) { - return range.getUpperBound().getValue().get().getFunction(); + if (range != null && range.getLowerBound().isBounded() + && scoreFilter.test(range.getLowerBound().getValue().get())) { + return function.apply(range.getLowerBound().getValue().get()); } - if (range.getLowerBound().isBounded()) { - return range.getLowerBound().getValue().get().getFunction(); + if (range != null && range.getUpperBound().isBounded() + && scoreFilter.test(range.getUpperBound().getValue().get())) { + return function.apply(range.getUpperBound().getValue().get()); } + } - return ScoringFunction.UNSPECIFIED; + return defaultValue.get(); } } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java index cd4670f431..fc46b7ab7e 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java @@ -40,7 +40,6 @@ import org.springframework.data.domain.Range; import org.springframework.data.domain.Score; import org.springframework.data.domain.ScoringFunction; -import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Sort; import org.springframework.data.domain.VectorScoringFunctions; @@ -75,11 +74,14 @@ public class JpaQueryCreator extends AbstractQueryCreator implements JpqlQueryCreator { private static final Map 
DISTANCE_FUNCTIONS = Map.of(VectorScoringFunctions.COSINE, - new DistanceFunction("cosine_distance", Sort.Direction.ASC), VectorScoringFunctions.EUCLIDEAN, - new DistanceFunction("euclidean_distance", Sort.Direction.ASC), VectorScoringFunctions.TAXICAB, - new DistanceFunction("taxicab_distance", Sort.Direction.ASC), VectorScoringFunctions.HAMMING, - new DistanceFunction("hamming_distance", Sort.Direction.ASC), VectorScoringFunctions.INNER_PRODUCT, - new DistanceFunction("negative_inner_product", Sort.Direction.DESC)); + new DistanceFunction("cosine_distance", Sort.Direction.ASC), // + VectorScoringFunctions.EUCLIDEAN, new DistanceFunction("euclidean_distance", Sort.Direction.ASC), // + VectorScoringFunctions.TAXICAB, new DistanceFunction("taxicab_distance", Sort.Direction.ASC), // + VectorScoringFunctions.HAMMING, new DistanceFunction("hamming_distance", Sort.Direction.ASC), // + VectorScoringFunctions.INNER_PRODUCT, new DistanceFunction("negative_inner_product", Sort.Direction.ASC), // + + // TODO: Do we need both, dot and inner product? Aren't these the same in some sense? 
+ VectorScoringFunctions.DOT, new DistanceFunction("negative_inner_product", Sort.Direction.ASC)); record DistanceFunction(String distanceFunction, Sort.Direction direction) { @@ -94,6 +96,7 @@ record DistanceFunction(String distanceFunction, Sort.Direction direction) { private final EntityType entityType; private final JpqlQueryBuilder.Entity entity; private final Metamodel metamodel; + private final SimilarityNormalizer similarityNormalizer; private final boolean useNamedParameters; /** @@ -147,6 +150,7 @@ public JpaQueryCreator(PartTree tree, boolean searchQuery, ReturnedType type, Pa this.entityType = metamodel.entity(type.getDomainType()); this.entity = JpqlQueryBuilder.entity(returnedType.getDomainType()); this.metamodel = metamodel; + this.similarityNormalizer = provider.getSimilarityNormalizer(); } Bindable getFrom() { @@ -405,7 +409,7 @@ JpqlQueryBuilder.Expression placeholder(ParameterBinding binding) { * @return */ private JpqlQueryBuilder.Predicate toPredicate(Part part) { - return new PredicateBuilder(part).build(); + return new PredicateBuilder(part, similarityNormalizer).build(); } /** @@ -413,21 +417,23 @@ private JpqlQueryBuilder.Predicate toPredicate(Part part) { * * @author Phil Webb * @author Oliver Gierke + * @author Mark Paluch */ private class PredicateBuilder { private final Part part; + private final SimilarityNormalizer normalizer; /** * Creates a new {@link PredicateBuilder} for the given {@link Part}. * * @param part must not be {@literal null}. + * @param normalizer must not be {@literal null}. 
*/ - public PredicateBuilder(Part part) { - - Assert.notNull(part, "Part must not be null"); + public PredicateBuilder(Part part, SimilarityNormalizer normalizer) { this.part = part; + this.normalizer = normalizer; } /** @@ -537,24 +543,17 @@ public JpqlQueryBuilder.Predicate build() { JpqlQueryBuilder.Predicate lowerPredicate = null; JpqlQueryBuilder.Predicate upperPredicate = null; - if (lower.isBounded()) { - - JpqlQueryBuilder.Expression distanceValue = JpqlQueryBuilder - .expression("" + lower.getValue().get().getValue()); - - where = JpqlQueryBuilder.where(distance); - lowerPredicate = lower.isInclusive() ? where.gte(distanceValue) : where.gt(distanceValue); + // Score is a distance function, you typically want less when you specify a lower boundary, + // therefore lower and upper predicates are inverted. + if (lower.isBounded()) { + JpqlQueryBuilder.Expression distanceValue = placeholder(provider.lower(within, normalizer)); + lowerPredicate = getUpperPredicate(lower.isInclusive(), distance, distanceValue); } if (upper.isBounded()) { - - JpqlQueryBuilder.Expression distanceValue = JpqlQueryBuilder - .expression("" + upper.getValue().get().getValue()); - - where = JpqlQueryBuilder.where(distance); - - upperPredicate = upper.isInclusive() ? 
where.lte(distanceValue) : where.lt(distanceValue); + JpqlQueryBuilder.Expression distanceValue = placeholder(provider.upper(within, normalizer)); + upperPredicate = getLowerPredicate(upper.isInclusive(), distance, distanceValue); } if (lowerPredicate != null && upperPredicate != null) { @@ -570,12 +569,11 @@ public JpqlQueryBuilder.Predicate build() { if (within.getValue() instanceof Score score) { String distanceFunction = getDistanceFunction(score.getFunction()); - JpqlQueryBuilder.Expression distanceValue = placeholder(within); + JpqlQueryBuilder.Expression distanceValue = placeholder(provider.normalize(within, normalizer)); JpqlQueryBuilder.Expression distance = JpqlQueryBuilder.function(distanceFunction, pas, placeholder(vector)); - return score instanceof Similarity ? JpqlQueryBuilder.where(distance).lte(distanceValue) - : JpqlQueryBuilder.where(distance).gte(distanceValue); + return getUpperPredicate(true, distance, distanceValue); } default: @@ -583,6 +581,26 @@ public JpqlQueryBuilder.Predicate build() { } } + private JpqlQueryBuilder.Predicate getLowerPredicate(boolean inclusive, JpqlQueryBuilder.Expression lhs, + JpqlQueryBuilder.Expression distance) { + return doLower(inclusive, lhs, distance); + } + + private JpqlQueryBuilder.Predicate getUpperPredicate(boolean inclusive, JpqlQueryBuilder.Expression lhs, + JpqlQueryBuilder.Expression distance) { + return doUpper(inclusive, lhs, distance); + } + + private static JpqlQueryBuilder.Predicate doLower(boolean inclusive, JpqlQueryBuilder.Expression lhs, + JpqlQueryBuilder.Expression distance) { + return inclusive ? JpqlQueryBuilder.where(lhs).gte(distance) : JpqlQueryBuilder.where(lhs).gt(distance); + } + + private static JpqlQueryBuilder.Predicate doUpper(boolean inclusive, JpqlQueryBuilder.Expression lhs, + JpqlQueryBuilder.Expression distance) { + return inclusive ? 
JpqlQueryBuilder.where(lhs).lte(distance) : JpqlQueryBuilder.where(lhs).lt(distance); + } + private static String getDistanceFunction(ScoringFunction scoringFunction) { DistanceFunction distanceFunction = JpaQueryCreator.DISTANCE_FUNCTIONS.get(scoringFunction); diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java index 931c6d3ddb..b6ab61cc54 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java @@ -39,6 +39,7 @@ import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.SearchResult; import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort; @@ -135,11 +136,17 @@ static class SearchResultExecution extends JpaQueryExecution { private final JpaQueryExecution delegate; private final ReturnedType returnedType; private final ScoringFunction function; + private final boolean normalizeSimilarity; + private final SimilarityNormalizer normalizer; + + SearchResultExecution(JpaQueryExecution delegate, ReturnedType returnedType, ScoringFunction function, + boolean normalizeSimilarity) { - SearchResultExecution(JpaQueryExecution delegate, ReturnedType returnedType, ScoringFunction function) { this.delegate = delegate; this.returnedType = returnedType; this.function = function; + this.normalizeSimilarity = normalizeSimilarity; + this.normalizer = normalizeSimilarity ? 
SimilarityNormalizer.get(function) : SimilarityNormalizer.IDENTITY; } @Override @@ -171,26 +178,31 @@ static class SearchResultExecution extends JpaQueryExecution { Object value = returnedType.needsCustomConstruction() ? t : t.get(0); try { - return new SearchResult<>(value, Score.of(t.get("distance", Number.class).doubleValue(), function)); + return new SearchResult<>(value, getScore(t.get("distance", Number.class).doubleValue())); } catch (RuntimeException e) { - return new SearchResult<>(value, Score.of(0, function)); + return new SearchResult<>(value, getScore(0)); } } if (result instanceof Object[] objects) { Object value = returnedType.needsCustomConstruction() ? objects : objects[0]; - try { - return new SearchResult<>(value, Score.of(((Number) (objects[objects.length - 1])).doubleValue(), function)); + return new SearchResult<>(value, getScore(((Number) (objects[objects.length - 1])).doubleValue())); } catch (RuntimeException e) { - return new SearchResult<>(value, Score.of(0, function)); + return new SearchResult<>(value, getScore(0)); } } return null; } + + private Score getScore(double score) { + return normalizeSimilarity ? 
Similarity.raw(normalizer.getSimilarity(score), function) + : Score.of(score, function); + } + } /** diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java index 90e90f14fb..ac5462175b 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java @@ -23,12 +23,12 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.function.Function; import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; import org.springframework.data.domain.Score; -import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Vector; import org.springframework.data.expression.ValueExpression; import org.springframework.data.jpa.provider.PersistenceProvider; @@ -164,10 +164,6 @@ public String toString() { */ public @Nullable Object prepare(@Nullable Object valueToBind) { - if (valueToBind instanceof Similarity similarity) { - return 1 - similarity.getValue(); - } - if (valueToBind instanceof Score score) { return score.getValue(); } @@ -328,6 +324,9 @@ public boolean isIsNullParameter() { return Collections.singleton(value); } + public String lower() { + return null; + } } /** @@ -566,6 +565,26 @@ default String getName() { default int getPosition() { throw new IllegalStateException("No position associated"); } + + /** + * Map the name of the binding to a new name using the given {@link Function} if the binding has a name. If the + * binding is not associated with a name, then the binding is returned unchanged. + * + * @param nameMapper must not be {@literal null}. + * @return the transformed {@link BindingIdentifier} if the binding has a name, otherwise the binding itself. 
+ * @since 4.0 + */ + BindingIdentifier mapName(Function nameMapper); + + /** + * Associate a position with the binding. + * + * @param position + * @return the new binding identifier with the position. + * @since 4.0 + */ + BindingIdentifier withPosition(int position); + } private record Named(String name) implements BindingIdentifier { @@ -584,6 +603,16 @@ public String getName() { public String toString() { return name(); } + + @Override + public BindingIdentifier mapName(Function nameMapper) { + return new Named(nameMapper.apply(name())); + } + + @Override + public BindingIdentifier withPosition(int position) { + return new NamedAndIndexed(name, position); + } } private record Indexed(int position) implements BindingIdentifier { @@ -598,6 +627,16 @@ public int getPosition() { return position(); } + @Override + public BindingIdentifier mapName(Function nameMapper) { + return this; + } + + @Override + public BindingIdentifier withPosition(int position) { + return new Indexed(position); + } + @Override public String toString() { return "[" + position() + "]"; @@ -626,6 +665,16 @@ public int getPosition() { return position(); } + @Override + public BindingIdentifier mapName(Function nameMapper) { + return new NamedAndIndexed(nameMapper.apply(name), position); + } + + @Override + public BindingIdentifier withPosition(int position) { + return new NamedAndIndexed(name, position); + } + @Override public String toString() { return "[" + name() + ", " + position() + "]"; diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java index 96507ba1cb..e904216886 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java @@ -27,6 +27,8 @@ 
import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.Vector; import org.springframework.data.jpa.provider.PersistenceProvider; @@ -66,6 +68,7 @@ public class ParameterMetadataProvider { private final JpqlQueryTemplates templates; private final JpaParameters jpaParameters; private int position; + private int bindMarker; /** * Creates a new {@link ParameterMetadataProvider} from the given {@link CriteriaBuilder} and @@ -131,6 +134,18 @@ public List getBindings() { return bindings; } + /** + * @return the {@link SimilarityNormalizer}. + */ + SimilarityNormalizer getSimilarityNormalizer() { + + if (accessor != null && accessor.normalizeSimilarity()) { + return SimilarityNormalizer.get(accessor.getScoringFunction()); + } + + return SimilarityNormalizer.IDENTITY; + } + /** * Builds a new {@link PartTreeParameterBinding} for given {@link Part} and the next {@link Parameter}. */ @@ -179,13 +194,15 @@ private PartTreeParameterBinding next(Part part, Class type, Parameter pa Object value = bindableParameterValues == null ? 
PLACEHOLDER : bindableParameterValues.next(); int currentPosition = ++position; + int currentBindMarker = ++bindMarker; BindingIdentifier bindingIdentifier = parameter.getName().map(it -> BindingIdentifier.of(it, currentPosition)) .orElseGet(() -> BindingIdentifier.of(currentPosition)); /* identifier refers to bindable parameters, not _all_ parameters index */ - MethodInvocationArgument methodParameter = ParameterOrigin.ofParameter(bindingIdentifier); - PartTreeParameterBinding binding = new PartTreeParameterBinding(bindingIdentifier, methodParameter, reifiedType, + MethodInvocationArgument methodParameter = ParameterOrigin.ofParameter(BindingIdentifier.of(currentPosition)); + PartTreeParameterBinding binding = new PartTreeParameterBinding(BindingIdentifier.of(currentBindMarker), + methodParameter, reifiedType, part, value, templates, escape); // PartTreeParameterBinding is more expressive than a potential ParameterBinding for Vector. @@ -230,22 +247,6 @@ ParameterBinding getVectorBinding() { return parameterBinding; } - private void maybeAdd(ParameterBinding parameterBinding) { - - boolean found = false; - - for (ParameterBinding existing : bindings) { - - if (existing.isCompatibleWith(parameterBinding)) { - found = true; - } - } - - if (!found) { - bindings.add(parameterBinding); - } - } - EscapeCharacter getEscape() { return escape; } @@ -260,7 +261,7 @@ EscapeCharacter getEscape() { */ ParameterBinding nextSynthetic(String nameHint, Object value, Object source) { - int currentPosition = ++position; + int currentPosition = ++bindMarker; String bindingName = nameHint; if (!syntheticParameterNames.add(bindingName)) { @@ -273,4 +274,124 @@ ParameterBinding nextSynthetic(String nameHint, Object value, Object source) { ParameterOrigin.synthetic(value, source)); } + RangeParameterBinding lower(PartTreeParameterBinding within, SimilarityNormalizer normalizer) { + + int bindMarker = within.getRequiredPosition(); + + if (!bindings.remove(within)) { + bindMarker = 
++this.bindMarker; + } + + BindingIdentifier identifier = within.getIdentifier(); + RangeParameterBinding rangeBinding = new RangeParameterBinding( + identifier.mapName(name -> name + "_lower").withPosition(bindMarker), within.getOrigin(), true, normalizer); + bindings.add(rangeBinding); + + return rangeBinding; + } + + RangeParameterBinding upper(PartTreeParameterBinding within, SimilarityNormalizer normalizer) { + + int bindMarker = within.getRequiredPosition(); + + if (!bindings.remove(within)) { + bindMarker = ++this.bindMarker; + } + + BindingIdentifier identifier = within.getIdentifier(); + RangeParameterBinding rangeBinding = new RangeParameterBinding( + identifier.mapName(name -> name + "_upper").withPosition(bindMarker), within.getOrigin(), false, normalizer); + bindings.add(rangeBinding); + + return rangeBinding; + } + + ScoreParameterBinding normalize(PartTreeParameterBinding within, SimilarityNormalizer normalizer) { + + bindings.remove(within); + + ScoreParameterBinding rangeBinding = new ScoreParameterBinding(within.getIdentifier(), within.getOrigin(), + normalizer); + bindings.add(rangeBinding); + + return rangeBinding; + } + + static class ScoreParameterBinding extends ParameterBinding { + + private final SimilarityNormalizer normalizer; + + /** + * Creates a new {@link ParameterBinding} for the parameter with the given identifier and origin. + * + * @param identifier of the parameter, must not be {@literal null}. 
+ * @param origin the origin of the parameter (expression or method argument) + */ + ScoreParameterBinding(BindingIdentifier identifier, ParameterOrigin origin, SimilarityNormalizer normalizer) { + super(identifier, origin); + this.normalizer = normalizer; + } + + @Override + public @Nullable Object prepare(@Nullable Object valueToBind) { + + if (valueToBind instanceof Score score) { + return normalizer.getScore(score.getValue()); + } + + return super.prepare(valueToBind); + } + + @Override + public boolean isCompatibleWith(ParameterBinding binding) { + + if (super.isCompatibleWith(binding) && binding instanceof ScoreParameterBinding other) { + return normalizer == other.normalizer; + } + + return false; + } + } + + static class RangeParameterBinding extends ScoreParameterBinding { + + private final boolean lower; + + /** + * Creates a new {@link ParameterBinding} for the parameter with the given identifier and origin. + * + * @param identifier of the parameter, must not be {@literal null}. 
+ * @param origin the origin of the parameter (expression or method argument) + */ + RangeParameterBinding(BindingIdentifier identifier, ParameterOrigin origin, boolean lower, + SimilarityNormalizer normalizer) { + super(identifier, origin, normalizer); + this.lower = lower; + } + + @Override + public @Nullable Object prepare(@Nullable Object valueToBind) { + + if (valueToBind instanceof Range r) { + if (lower) { + return super.prepare(r.getLowerBound().getValue().orElse(null)); + } else { + return super.prepare(r.getUpperBound().getValue().orElse(null)); + } + } + + return super.prepare(valueToBind); + } + + @Override + public boolean isCompatibleWith(ParameterBinding binding) { + + if (super.isCompatibleWith(binding) && binding instanceof RangeParameterBinding other) { + return lower == other.lower; + } + + return false; + } + } + } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java index 5990e1cb55..f04410ae51 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java @@ -296,14 +296,16 @@ protected JpqlQueryCreator createCreator(Sort sort, JpaParametersParameterAccess entityManager); } - JpqlQueryCreator creator = new CacheableJpqlQueryCreator(sort, - new JpaQueryCreator(tree, getQueryMethod().isSearchQuery(), returnedType, provider, templates, - em.getMetamodel())); - - if (accessor.getParameters().hasDynamicProjection() || getQueryMethod().isSearchQuery()) { - return creator; + JpaParameters parameters = getQueryMethod().getParameters(); + if (accessor.getParameters().hasDynamicProjection() || getQueryMethod().isSearchQuery() + || parameters.hasScoreRangeParameter() || parameters.hasScoreParameter()) { + return new JpaQueryCreator(tree, 
getQueryMethod().isSearchQuery(), returnedType, provider, templates, + em.getMetamodel()); } + JpqlQueryCreator creator = new CacheableJpqlQueryCreator(sort, new JpaQueryCreator(tree, + getQueryMethod().isSearchQuery(), returnedType, provider, templates, em.getMetamodel())); + cache.put(sort, accessor, creator); return creator; diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryParameterSetterFactory.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryParameterSetterFactory.java index 6d6196b8ef..b97c39da5b 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryParameterSetterFactory.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryParameterSetterFactory.java @@ -305,6 +305,10 @@ private PartTreeQueryParameterSetterFactory(JpaParameters parameters) { return super.create(binding, query); } + if (binding instanceof ParameterMetadataProvider.ScoreParameterBinding) { + return super.create(binding, query); + } + return null; } } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/SimilarityNormalizer.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/SimilarityNormalizer.java new file mode 100644 index 0000000000..ef27b43aab --- /dev/null +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/SimilarityNormalizer.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jpa.repository.query; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.DoubleUnaryOperator; + +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.VectorScoringFunctions; + +/** + * Normalizes the score returned by a database to a similarity value and vice versa. + * + * @author Mark Paluch + * @since 4.0 + * @see org.springframework.data.domain.Similarity + */ +public class SimilarityNormalizer { + + /** + * Identity normalizer for {@link ScoringFunction#UNSPECIFIED} scoring function without altering the score. + */ + public static final SimilarityNormalizer IDENTITY = new SimilarityNormalizer(ScoringFunction.UNSPECIFIED, + DoubleUnaryOperator.identity(), DoubleUnaryOperator.identity()); + + /** + * Normalizer for Euclidean scores using {@code euclidean_distance(…)} as the scoring function. + */ + public static final SimilarityNormalizer EUCLIDEAN = new SimilarityNormalizer(VectorScoringFunctions.EUCLIDEAN, + it -> 1 / (1.0 + Math.pow(it, 2)), it -> it == 0 ? Float.MAX_VALUE : Math.sqrt((1 / it) - 1)); + + /** + * Normalizer for Cosine scores using {@code cosine_distance(…)} as the scoring function. + */ + public static final SimilarityNormalizer COSINE = new SimilarityNormalizer(VectorScoringFunctions.COSINE, + it -> (1.0 + (1 - it)) / 2.0, it -> 1 - ((it * 2) - 1)); + + /** + * Normalizer for Negative Inner Product (Dot) scores using {@code negative_inner_product(…)} as the scoring function. 
+ */ + public static final SimilarityNormalizer DOT = new SimilarityNormalizer(VectorScoringFunctions.DOT, + it -> (1 - it) / 2, it -> 1 - (it * 2)); + + private static final Map NORMALIZERS = new HashMap<>(); + + static { + NORMALIZERS.put(EUCLIDEAN.scoringFunction, EUCLIDEAN); + NORMALIZERS.put(COSINE.scoringFunction, COSINE); + NORMALIZERS.put(DOT.scoringFunction, DOT); + NORMALIZERS.put(VectorScoringFunctions.INNER_PRODUCT, DOT); + } + + private final ScoringFunction scoringFunction; + private final DoubleUnaryOperator similarity; + private final DoubleUnaryOperator score; + + /** + * Constructor for {@link SimilarityNormalizer} using the given {@link DoubleUnaryOperator} for similarity and score + * computation. + * + * @param similarity compute the similarity from the underlying score returned by a database result. + * @param score compute the score value from a given {@link org.springframework.data.domain.Similarity} to compare + * against database results. + */ + SimilarityNormalizer(ScoringFunction scoringFunction, DoubleUnaryOperator similarity, DoubleUnaryOperator score) { + this.scoringFunction = scoringFunction; + this.score = score; + this.similarity = similarity; + } + + /** + * Lookup a {@link SimilarityNormalizer} for a given {@link ScoringFunction}. + * + * @param scoringFunction the scoring function to translate. + * @return the {@link SimilarityNormalizer} for the given {@link ScoringFunction}. + * @throws IllegalArgumentException if the {@link ScoringFunction} is not associated with a + * {@link SimilarityNormalizer}. + */ + public static SimilarityNormalizer get(ScoringFunction scoringFunction) { + + SimilarityNormalizer normalizer = NORMALIZERS.get(scoringFunction); + + if (normalizer == null) { + throw new IllegalArgumentException("No SimilarityNormalizer found for " + scoringFunction.getName()); + } + + return normalizer; + } + + /** + * @param score score value as returned by the database. 
+ * @return the {@link org.springframework.data.domain.Similarity} value. + */ + public double getSimilarity(double score) { + return similarity.applyAsDouble(score); + } + + /** + * @param similarity similarity value as requested by the query mechanism. + * @return database score value. + */ + public double getScore(double similarity) { + return score.applyAsDouble(similarity); + } + + @Override + public String toString() { + return "%s Normalizer: Similarity[0 to 1] -> Score[%f to %f]".formatted(scoringFunction.getName(), getScore(0), + getScore(1)); + } + +} diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java index 600652261b..2b45f4d5b7 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java @@ -98,14 +98,14 @@ void setUp() { @MethodSource("scoringFunctions") void shouldApplyVectorSearchWithDistance(VectorScoringFunctions functions) { - SearchResults results = repository.searchTop2ByCountryAndEmbeddingWithin("de", VECTOR, - Similarity.of(0.1, functions)); + SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.of(0, functions)); - assertThat(results).hasSize(2).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) .containsOnly("de", "de"); assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) - .containsExactlyInAnyOrder("two", "one"); + .containsExactlyInAnyOrder("two", "one", "four"); } static Set scoringFunctions() { @@ -113,6 +113,36 @@ static Set scoringFunctions() { VectorScoringFunctions.EUCLIDEAN); } + @Test + void 
shouldNormalizeEuclideanSimilarity() { + + SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.of(0.99, VectorScoringFunctions.EUCLIDEAN)); + + assertThat(results).hasSize(1); + + SearchResult two = results.getContent().get(0); + + assertThat(two.getContent().getDescription()).isEqualTo("two"); + assertThat(two.getScore()).isInstanceOf(Similarity.class); + assertThat(two.getScore().getValue()).isGreaterThan(0.99); + } + + @Test + void shouldNormalizeCosineSimilarity() { + + SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.of(0.999, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(1); + + SearchResult two = results.getContent().get(0); + + assertThat(two.getContent().getDescription()).isEqualTo("two"); + assertThat(two.getScore()).isInstanceOf(Similarity.class); + assertThat(two.getScore().getValue()).isGreaterThan(0.99); + } + @Test void shouldRunStringQuery() { @@ -143,7 +173,7 @@ void shouldRunStringQueryWithDistance() { void shouldApplyVectorSearchWithRange() { SearchResults results = repository.searchAllByCountryAndEmbeddingWithin("de", VECTOR, - Score.between(0, 1, VectorScoringFunctions.COSINE)); + Similarity.between(0, 1, VectorScoringFunctions.COSINE)); assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) .containsOnly("de", "de", "de"); @@ -155,18 +185,17 @@ void shouldApplyVectorSearchWithRange() { void shouldApplyVectorSearchAndReturnList() { List results = repository.findAllByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(0, VectorScoringFunctions.COSINE)); + Score.of(10, VectorScoringFunctions.COSINE)); assertThat(results).hasSize(3).extracting(WithVector::getCountry).containsOnly("de", "de", "de"); assertThat(results).extracting(WithVector::getDescription).containsSequence("one", "two", "four"); - } @Test void shouldProjectVectorSearchAsInterface() { SearchResults results = 
repository.searchInterfaceProjectionByCountryAndEmbeddingWithin("de", - VECTOR, Score.of(0, VectorScoringFunctions.COSINE)); + VECTOR, Score.of(10, VectorScoringFunctions.COSINE)); assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithDescription::getDescription) .containsSequence("two", "one", "four"); @@ -176,7 +205,7 @@ void shouldProjectVectorSearchAsInterface() { void shouldProjectVectorSearchAsDto() { SearchResults results = repository.searchDtoByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(0, VectorScoringFunctions.COSINE)); + Score.of(10, VectorScoringFunctions.COSINE)); assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(DescriptionDto::getDescription) .containsSequence("two", "one", "four"); @@ -186,13 +215,13 @@ void shouldProjectVectorSearchAsDto() { void shouldProjectVectorSearchDynamically() { SearchResults dtos = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(0, VectorScoringFunctions.COSINE), DescriptionDto.class); + Score.of(10, VectorScoringFunctions.COSINE), DescriptionDto.class); assertThat(dtos).hasSize(3).extracting(SearchResult::getContent).extracting(DescriptionDto::getDescription) .containsSequence("two", "one", "four"); SearchResults proxies = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(0, VectorScoringFunctions.COSINE), WithDescription.class); + Score.of(10, VectorScoringFunctions.COSINE), WithDescription.class); assertThat(proxies).hasSize(3).extracting(SearchResult::getContent).extracting(WithDescription::getDescription) .containsSequence("two", "one", "four"); @@ -248,6 +277,11 @@ public float[] getEmbedding() { public void setEmbedding(float[] embedding) { this.embedding = embedding; } + + @Override + public String toString() { + return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}'; + } } interface WithDescription { @@ -287,9 +321,9 @@ SearchResults 
searchAnnotatedByCountryAndEmbeddingWithin(String coun Score distance); SearchResults searchAllByCountryAndEmbeddingWithin(String country, Vector embedding, - Range distance); + Range distance); - SearchResults searchTop2ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + SearchResults searchTop5ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); SearchResults searchInterfaceProjectionByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/ParameterMetadataProviderIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/ParameterMetadataProviderIntegrationTests.java index 81e454c799..963d742dd1 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/ParameterMetadataProviderIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/ParameterMetadataProviderIntegrationTests.java @@ -26,6 +26,10 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; import org.springframework.data.jpa.domain.sample.User; import org.springframework.data.jpa.repository.support.JpqlQueryTemplates; import org.springframework.data.repository.query.Param; @@ -41,6 +45,7 @@ * * @author Oliver Gierke * @author Jens Schauder + * @author Mark Paluch * @soundtrack Elephants Crossing - We are (Irrelephant) */ @ExtendWith(SpringExtension.class) @@ -78,6 +83,52 @@ void doesNotApplyLikeExpansionOnNonStringProperties() throws Exception { assertThat(binding.prepare(1)).isEqualTo(1); } + @Test // GH- + void appliesScoreValuePreparation() throws Exception { + + ParameterMetadataProvider provider 
= createProvider( + Sample.class.getMethod("findByVectorWithin", Vector.class, Score.class)); + ParameterBinding.PartTreeParameterBinding vector = provider.next(new Part("VectorWithin", WithVector.class)); + ParameterBinding.PartTreeParameterBinding score = provider.next(new Part("VectorWithin", WithVector.class)); + ParameterMetadataProvider.ScoreParameterBinding binding = provider.normalize(score, SimilarityNormalizer.EUCLIDEAN); + + assertThat(binding.prepare(Score.of(1))).isEqualTo(0.0); + assertThat(binding.prepare(Score.of(0.5))).isEqualTo(1.0); + assertThat(provider.getBindings()).hasSize(2).contains(binding).doesNotContain(score); + } + + @Test // GH- + void appliesLowerRangeValuePreparation() throws Exception { + + ParameterMetadataProvider provider = createProvider( + Sample.class.getMethod("findByVectorWithin", Vector.class, Range.class)); + ParameterBinding.PartTreeParameterBinding vector = provider.next(new Part("VectorWithin", WithVector.class)); + ParameterBinding.PartTreeParameterBinding score = provider.next(new Part("VectorWithin", WithVector.class)); + ParameterMetadataProvider.ScoreParameterBinding lower = provider.lower(score, SimilarityNormalizer.EUCLIDEAN); + + Range range = Similarity.between(0.5, 1); + + assertThat(lower.prepare(range)).isEqualTo(1.0); + assertThat(provider.getBindings()).hasSize(2).contains(lower).doesNotContain(score); + } + + @Test // GH- + void appliesRangeValuePreparation() throws Exception { + + ParameterMetadataProvider provider = createProvider( + Sample.class.getMethod("findByVectorWithin", Vector.class, Range.class)); + ParameterBinding.PartTreeParameterBinding vector = provider.next(new Part("VectorWithin", WithVector.class)); + ParameterBinding.PartTreeParameterBinding score = provider.next(new Part("VectorWithin", WithVector.class)); + ParameterMetadataProvider.ScoreParameterBinding lower = provider.lower(score, SimilarityNormalizer.EUCLIDEAN); + ParameterMetadataProvider.ScoreParameterBinding upper = 
provider.upper(score, SimilarityNormalizer.EUCLIDEAN); + + Range range = Similarity.between(0.5, 1); + + assertThat(lower.prepare(range)).isEqualTo(1.0); + assertThat(upper.prepare(range)).isEqualTo(0.0); + assertThat(provider.getBindings()).hasSize(3).contains(lower, upper).doesNotContain(score); + } + private ParameterMetadataProvider createProvider(Method method) { JpaParameters parameters = new JpaParameters(ParametersSource.of(method)); @@ -102,5 +153,13 @@ interface Sample { User findByLastname(String lastname); User findByAgeContaining(@Param("age") Integer age); + + User findByVectorWithin(Vector vector, Score score); + + User findByVectorWithin(Vector vector, Range score); + } + + static class WithVector { + Vector vector; } } diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/SimilarityNormalizerUnitTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/SimilarityNormalizerUnitTests.java new file mode 100644 index 0000000000..60485dffc8 --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/SimilarityNormalizerUnitTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.jpa.repository.query; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link SimilarityNormalizer}. + * + * @author Mark Paluch + */ +class SimilarityNormalizerUnitTests { + + @Test + void normalizesEuclidean() { + + assertThat(SimilarityNormalizer.EUCLIDEAN.getSimilarity(0)).isCloseTo(1.0, offset(0.01)); + assertThat(SimilarityNormalizer.EUCLIDEAN.getSimilarity(0.223606791085977)).isCloseTo(0.9523810148239136, + offset(0.01)); + assertThat(SimilarityNormalizer.EUCLIDEAN.getSimilarity(1.1618950141221271)).isCloseTo(0.42553189396858215, + offset(0.01)); + + assertThat(SimilarityNormalizer.EUCLIDEAN.getScore(1.0)).isCloseTo(0.0, offset(0.01)); + assertThat(SimilarityNormalizer.EUCLIDEAN.getScore(0.9523810148239136)).isCloseTo(0.223606791085977, offset(0.01)); + assertThat(SimilarityNormalizer.EUCLIDEAN.getScore(0.42553189396858215)).isCloseTo(1.1618950141221271, + offset(0.01)); + } + + @Test + void normalizesCosine() { + + assertThat(SimilarityNormalizer.COSINE.getSimilarity(0)).isCloseTo(1.0, offset(0.01)); + assertThat(SimilarityNormalizer.COSINE.getSimilarity(0.004470301418728173)).isCloseTo(0.9977648258209229, + offset(0.01)); + assertThat(SimilarityNormalizer.COSINE.getSimilarity(0.05568200370295473)).isCloseTo(0.9721590280532837, + offset(0.01)); + + assertThat(SimilarityNormalizer.COSINE.getScore(1.0)).isCloseTo(0.0, offset(0.01)); + assertThat(SimilarityNormalizer.COSINE.getScore(0.9977648258209229)).isCloseTo(0.004470301418728173, offset(0.01)); + assertThat(SimilarityNormalizer.COSINE.getScore(0.9721590280532837)).isCloseTo(0.05568200370295473, offset(0.01)); + } + + @Test + void normalizesNegativeInnerProduct() { + + assertThat(SimilarityNormalizer.DOT.getSimilarity(-0.8465620279312134)).isCloseTo(0.9232810139656067, offset(0.01)); + assertThat(SimilarityNormalizer.DOT.getSimilarity(-1.0626180171966553)).isCloseTo(1.0313090085983276, offset(0.01)); + 
assertThat(SimilarityNormalizer.DOT.getSimilarity(-2.0293400287628174)).isCloseTo(1.5146700143814087, offset(0.01)); + + assertThat(SimilarityNormalizer.DOT.getScore(0.9232810139656067)).isCloseTo(-0.8465620279312134, offset(0.01)); + assertThat(SimilarityNormalizer.DOT.getScore(1.0313090085983276)).isCloseTo(-1.0626180171966553, offset(0.01)); + assertThat(SimilarityNormalizer.DOT.getScore(1.5146700143814087)).isCloseTo(-2.0293400287628174, offset(0.01)); + } + +} From 72f47675e9f8502dfb2d313c0bef5e834913aee5 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 29 Apr 2025 17:06:27 +0200 Subject: [PATCH 5/6] Documentation. --- .../query/JpaParametersParameterAccessor.java | 2 +- .../jpa/repository/query/JpaQueryCreator.java | 11 +++---- .../repository/query/JpqlQueryBuilder.java | 2 +- .../query/ParameterMetadataProvider.java | 2 +- .../query/SimilarityNormalizer.java | 9 +++-- .../repository/PgVectorIntegrationTests.java | 26 ++++++++++++++- .../query/SimilarityNormalizerUnitTests.java | 18 ++++++---- src/main/antora/modules/ROOT/nav.adoc | 1 + .../pages/repositories/vector-search.adoc | 8 +++++ .../partials/vector-search-intro-include.adoc | 31 +++++++++++++++++ ...ector-search-method-annotated-include.adoc | 29 ++++++++++++++++ .../vector-search-method-derived-include.adoc | 16 +++++++++ .../partials/vector-search-model-include.adoc | 18 ++++++++++ .../vector-search-repository-include.adoc | 22 +++++++++++++ .../vector-search-scoring-include.adoc | 33 +++++++++++++++++++ 15 files changed, 207 insertions(+), 21 deletions(-) create mode 100644 src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc create mode 100644 
src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java index 72b7156cc3..e77ab25c6e 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java @@ -82,7 +82,7 @@ protected Object potentiallyUnwrap(Object parameterValue) { * @return */ public ScoringFunction getScoringFunction() { - return doWithScore(Score::getFunction, Score.class::isInstance, () -> ScoringFunction.UNSPECIFIED); + return doWithScore(Score::getFunction, Score.class::isInstance, ScoringFunction::unspecified); } /** diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java index fc46b7ab7e..76f84f0ec3 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java @@ -78,10 +78,7 @@ public class JpaQueryCreator extends AbstractQueryCreator (1 - it) / 2, it -> 1 - (it * 2)); private static final Map NORMALIZERS = new HashMap<>(); @@ -60,8 +60,7 @@ public class SimilarityNormalizer { static { NORMALIZERS.put(EUCLIDEAN.scoringFunction, EUCLIDEAN); NORMALIZERS.put(COSINE.scoringFunction, COSINE); - NORMALIZERS.put(DOT.scoringFunction, DOT); - NORMALIZERS.put(VectorScoringFunctions.INNER_PRODUCT, 
DOT); + NORMALIZERS.put(DOT_PRODUCT.scoringFunction, DOT_PRODUCT); } private final ScoringFunction scoringFunction; diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java index 2b45f4d5b7..69e4458c0a 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java @@ -49,6 +49,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.data.domain.Range; import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.SearchResult; import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Similarity; @@ -109,7 +110,7 @@ void shouldApplyVectorSearchWithDistance(VectorScoringFunctions functions) { } static Set scoringFunctions() { - return EnumSet.of(VectorScoringFunctions.COSINE, VectorScoringFunctions.INNER_PRODUCT, + return EnumSet.of(VectorScoringFunctions.COSINE, VectorScoringFunctions.DOT_PRODUCT, VectorScoringFunctions.EUCLIDEAN); } @@ -169,6 +170,21 @@ void shouldRunStringQueryWithDistance() { assertThat(result.getScore().getFunction()).isEqualTo(VectorScoringFunctions.COSINE); } + @Test + void shouldRunStringQueryWithFloatDistance() { + + SearchResults results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, 2); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + .containsOnly("de", "de", "de"); + assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) + .containsSequence("two", "one", "four"); + + SearchResult result = results.getContent().get(0); + 
assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0); + assertThat(result.getScore().getFunction()).isEqualTo(ScoringFunction.unspecified()); + } + @Test void shouldApplyVectorSearchWithRange() { @@ -320,6 +336,14 @@ AND cosine_distance(w.embedding, :embedding) <= :distance SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + @Query(""" + SELECT w, cosine_distance(w.embedding, :embedding) as distance FROM org.springframework.data.jpa.repository.PgVectorIntegrationTests$WithVector w + WHERE w.country = ?1 + AND cosine_distance(w.embedding, :embedding) <= :distance + ORDER BY distance asc""") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + float distance); + SearchResults searchAllByCountryAndEmbeddingWithin(String country, Vector embedding, Range distance); diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/SimilarityNormalizerUnitTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/SimilarityNormalizerUnitTests.java index 60485dffc8..37f08ef12b 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/SimilarityNormalizerUnitTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/SimilarityNormalizerUnitTests.java @@ -58,13 +58,19 @@ void normalizesCosine() { @Test void normalizesNegativeInnerProduct() { - assertThat(SimilarityNormalizer.DOT.getSimilarity(-0.8465620279312134)).isCloseTo(0.9232810139656067, offset(0.01)); - assertThat(SimilarityNormalizer.DOT.getSimilarity(-1.0626180171966553)).isCloseTo(1.0313090085983276, offset(0.01)); - assertThat(SimilarityNormalizer.DOT.getSimilarity(-2.0293400287628174)).isCloseTo(1.5146700143814087, offset(0.01)); + assertThat(SimilarityNormalizer.DOT_PRODUCT.getSimilarity(-0.8465620279312134)).isCloseTo(0.9232810139656067, + offset(0.01)); + 
assertThat(SimilarityNormalizer.DOT_PRODUCT.getSimilarity(-1.0626180171966553)).isCloseTo(1.0313090085983276, + offset(0.01)); + assertThat(SimilarityNormalizer.DOT_PRODUCT.getSimilarity(-2.0293400287628174)).isCloseTo(1.5146700143814087, + offset(0.01)); - assertThat(SimilarityNormalizer.DOT.getScore(0.9232810139656067)).isCloseTo(-0.8465620279312134, offset(0.01)); - assertThat(SimilarityNormalizer.DOT.getScore(1.0313090085983276)).isCloseTo(-1.0626180171966553, offset(0.01)); - assertThat(SimilarityNormalizer.DOT.getScore(1.5146700143814087)).isCloseTo(-2.0293400287628174, offset(0.01)); + assertThat(SimilarityNormalizer.DOT_PRODUCT.getScore(0.9232810139656067)).isCloseTo(-0.8465620279312134, + offset(0.01)); + assertThat(SimilarityNormalizer.DOT_PRODUCT.getScore(1.0313090085983276)).isCloseTo(-1.0626180171966553, + offset(0.01)); + assertThat(SimilarityNormalizer.DOT_PRODUCT.getScore(1.5146700143814087)).isCloseTo(-2.0293400287628174, + offset(0.01)); } } diff --git a/src/main/antora/modules/ROOT/nav.adoc b/src/main/antora/modules/ROOT/nav.adoc index 351c162366..126f33c4af 100644 --- a/src/main/antora/modules/ROOT/nav.adoc +++ b/src/main/antora/modules/ROOT/nav.adoc @@ -14,6 +14,7 @@ ** xref:jpa/stored-procedures.adoc[] ** xref:jpa/specifications.adoc[] ** xref:repositories/query-by-example.adoc[] +** xref:repositories/vector-search.adoc[] ** xref:jpa/transactions.adoc[] ** xref:jpa/locking.adoc[] ** xref:auditing.adoc[] diff --git a/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc b/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc new file mode 100644 index 0000000000..f33e2ad4d3 --- /dev/null +++ b/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc @@ -0,0 +1,8 @@ +:vector-search-intro-include: data-jpa::partial$vector-search-intro-include.adoc +:vector-search-model-include: data-jpa::partial$vector-search-model-include.adoc +:vector-search-repository-include: 
data-jpa::partial$vector-search-repository-include.adoc +:vector-search-scoring-include: data-jpa::partial$vector-search-scoring-include.adoc +:vector-search-method-derived-include: data-jpa::partial$vector-search-method-derived-include.adoc +:vector-search-method-annotated-include: data-jpa::partial$vector-search-method-annotated-include.adoc + +include::{commons}@data-commons::page$repositories/vector-search.adoc[] diff --git a/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc new file mode 100644 index 0000000000..6ed3115489 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc @@ -0,0 +1,31 @@ +To use Hibernate Vector Search, you need to add the following dependencies to your project. + +The following example shows how to set up dependencies in Maven and Gradle: + +[tabs] +====== +Maven:: ++ +[source,xml,indent=0,subs="verbatim,quotes",role="primary"] +---- + + + org.hibernate.orm + hibernate-vector + ${hibernate.version} + + + +---- + +Gradle:: ++ +==== +[source,groovy,indent=0,subs="verbatim,quotes",role="secondary"] +---- +dependencies { + implementation 'org.hibernate.orm:hibernate-vector:${hibernateVersion}' +} +---- +==== +====== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc new file mode 100644 index 0000000000..e713cdbaf1 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc @@ -0,0 +1,29 @@ +Annotated search methods must define the entire JPQL query to run a Vector Search. 
+ +.Using `@Query` Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @Query(""" + SELECT c, cosine_distance(c.embedding, :embedding) as distance FROM Comment c + WHERE c.country = ?1 + AND cosine_distance(c.embedding, :embedding) <= :distance + ORDER BY distance asc""") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + Score distance); + + @Query(""" + SELECT c FROM Comment c + WHERE c.country = ?1 + AND cosine_distance(c.embedding, :embedding) <= :distance + ORDER BY cosine_distance(c.embedding, :embedding) asc""") + List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + +} +---- +==== + +Vector Search methods are not required to include a score or distance in their projection. +When using annotated search methods returning `SearchResults`, the execution mechanism assumes that a second projection column, if present, holds the score value. diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc new file mode 100644 index 0000000000..3a24393f8e --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc @@ -0,0 +1,16 @@ +.Using `Near` and `Within` Keywords in Repository Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + SearchResults searchByEmbeddingNear(Vector vector, Score score); + + SearchResults searchByEmbeddingWithin(Vector vector, Range range); + + SearchResults searchByCountryAndEmbeddingWithin(String country, Vector vector, Range range); +} +---- +==== + +Derived Search Methods can define domain model attributes and Vector parameters. 
diff --git a/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc new file mode 100644 index 0000000000..a6966630c2 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc @@ -0,0 +1,18 @@ +==== +[source,java] +---- +class Comment { + + @Id String id; + String country; + String comment; + + @Column(name = "the_embedding") + @JdbcTypeCode(SqlTypes.VECTOR) + @Array(length = 5) + Vector embedding; + + // getters, setters, … +} +---- +==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc new file mode 100644 index 0000000000..62e900efba --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc @@ -0,0 +1,22 @@ +.Using `SearchResult` in a Repository Search Method +==== +[source,java] +---- +interface CommentRepository extends Repository { + + SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Score distance, + Limit limit); + + @Query(""" + SELECT c, cosine_distance(c.embedding, :embedding) as distance FROM Comment c + WHERE c.country = ?1 + AND cosine_distance(c.embedding, :embedding) <= :distance + ORDER BY distance asc""") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + Score distance); + +} + +SearchResults results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10)); +---- +==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc new file mode 100644 index 0000000000..11a8fd289d --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc @@ -0,0 +1,33 @@ +Hibernate translates distance function calls to native database functions for 
PGvector and Oracle. +Their result is typically a distance. +When using `Similarity` instead of `Score`, Spring Data normalizes distance scores into a similarity score between 0 and 1. The higher the score, the more similar the two vectors are. +// END + +.Using `Score` and `Similarity` in Repository Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + SearchResults searchByEmbeddingNear(Vector vector, Score score); + + SearchResults searchByEmbeddingNear(Vector vector, Similarity similarity); + + SearchResults searchByEmbeddingNear(Vector vector, Range range); +} + +repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9, ScoringFunction.cosine())); <1> + +repository.searchByEmbeddingNear(Vector.of(…), Similarity.of(0.9, ScoringFunction.cosine())); <2> + +repository.searchByEmbeddingNear(Vector.of(…), Similarity.between(0.5, 1, ScoringFunction.euclidean()));<3> +---- + +<1> Run a search and return results with a score of `0.9` or smaller using the Cosine distance. +<2> Run a search and normalize the score into a similarity value. +Return results with a similarity of `0.9` or greater using Cosine scoring. +<3> Run a search and normalize the score into a similarity value. +Return results with a similarity between `0.5` and `1.0` using Euclidean scoring. +==== + +NOTE: JPA requires a `ScoringFunction` to be provided when creating `Score` or `Similarity` instances to select a scoring function. From 8ff4c0fa13fbe8b8cb832cc2fb3163d0ef810560 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 30 Apr 2025 09:43:32 +0200 Subject: [PATCH 6/6] Add Oracle integration tests. 
--- pom.xml | 14 + spring-data-jpa/pom.xml | 23 ++ .../data/jpa/convert/VectorConverters.java | 61 ---- .../jpa/repository/query/JpaQueryCreator.java | 21 +- .../query/ParameterMetadataProvider.java | 25 +- .../AbstractVectorIntegrationTests.java | 342 ++++++++++++++++++ .../OracleVectorIntegrationTests.java | 95 +++++ .../repository/PgVectorIntegrationTests.java | 326 +---------------- .../MySqlStoredProcedureIntegrationTests.java | 24 +- ...stgresStoredProcedureIntegrationTests.java | 24 ++ ...ProcedureNullHandlingIntegrationTests.java | 24 +- .../scripts/oracle-vector-initialize.sql | 11 + .../test/resources/scripts/oracle-vector.sql | 16 + .../partials/vector-search-intro-include.adoc | 3 +- ...ector-search-method-annotated-include.adoc | 1 - .../vector-search-method-derived-include.adoc | 2 +- .../vector-search-repository-include.adoc | 1 - .../vector-search-scoring-include.adoc | 17 +- 18 files changed, 610 insertions(+), 420 deletions(-) delete mode 100644 spring-data-jpa/src/main/java/org/springframework/data/jpa/convert/VectorConverters.java create mode 100644 spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java create mode 100644 spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/OracleVectorIntegrationTests.java create mode 100644 spring-data-jpa/src/test/resources/scripts/oracle-vector-initialize.sql create mode 100644 spring-data-jpa/src/test/resources/scripts/oracle-vector.sql diff --git a/pom.xml b/pom.xml index f5026ee712..b2b7e164ff 100755 --- a/pom.xml +++ b/pom.xml @@ -38,6 +38,7 @@ 5.0 9.1.0 42.7.4 + 23.7.0.25.01 4.0.0-SEARCH-RESULT-SNAPSHOT 0.10.3 @@ -120,6 +121,19 @@ + + oracle-test + test + + test + + + + **/Oracle*IntegrationTests.java + + + + diff --git a/spring-data-jpa/pom.xml b/spring-data-jpa/pom.xml index 757087ec2b..b2c0a730d4 100644 --- a/spring-data-jpa/pom.xml +++ b/spring-data-jpa/pom.xml @@ -167,6 +167,28 @@ test + + + + com.oracle.database.jdbc + ojdbc17 + 
${oracle} + test + + + + com.oracle.database.jdbc + ucp17 + ${oracle} + test + + + + org.testcontainers + oracle-free + test + + io.vavr vavr @@ -331,6 +353,7 @@ **/EclipseLink* **/MySql* **/Postgres* + **/Oracle* -Xmx4G diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/convert/VectorConverters.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/convert/VectorConverters.java deleted file mode 100644 index d6cf432340..0000000000 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/convert/VectorConverters.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.data.jpa.convert; - -import jakarta.persistence.AttributeConverter; -import jakarta.persistence.Converter; - -import org.jspecify.annotations.Nullable; - -import org.springframework.data.domain.Vector; - -/** - * JPA {@link Converter} for {@link Vector} types. - * - * @author Mark Paluch - * @since 4.0 - */ -public class VectorConverters { - - @Converter(autoApply = true) - public static class VectorAsFloatArrayConverter implements AttributeConverter<@Nullable Vector, @Nullable float[]> { - - @Override - public @Nullable float[] convertToDatabaseColumn(@Nullable Vector vector) { - return vector == null ? 
null : vector.toFloatArray(); - } - - @Override - public @Nullable Vector convertToEntityAttribute(@Nullable float[] floats) { - return floats == null ? null : Vector.of(floats); - } - } - - @Converter(autoApply = true) - public static class VectorAsDoubleArrayConverter implements AttributeConverter<@Nullable Vector, @Nullable double[]> { - - @Override - public @Nullable double[] convertToDatabaseColumn(@Nullable Vector vector) { - return vector == null ? null : vector.toDoubleArray(); - } - - @Override - public @Nullable Vector convertToEntityAttribute(@Nullable double[] doubles) { - return doubles == null ? null : Vector.of(doubles); - } - } - -} diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java index 76f84f0ec3..f6cda83389 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java @@ -37,13 +37,12 @@ import org.jspecify.annotations.Nullable; +import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.domain.Range; import org.springframework.data.domain.Score; import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.Sort; import org.springframework.data.domain.VectorScoringFunctions; - -import org.springframework.data.domain.Sort; import org.springframework.data.jpa.domain.JpaSort; import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.ParameterPlaceholder; import org.springframework.data.jpa.repository.query.ParameterBinding.PartTreeParameterBinding; @@ -71,7 +70,8 @@ * @author Christoph Strobl * @author Jinmyeong Kim */ -public class JpaQueryCreator extends AbstractQueryCreator implements JpqlQueryCreator { +public class JpaQueryCreator extends 
AbstractQueryCreator + implements JpqlQueryCreator { private static final Map DISTANCE_FUNCTIONS = Map.of(VectorScoringFunctions.COSINE, new DistanceFunction("cosine_distance", Sort.Direction.ASC), // @@ -111,13 +111,12 @@ public JpaQueryCreator(PartTree tree, ReturnedType type, ParameterMetadataProvid } public JpaQueryCreator(PartTree tree, ReturnedType type, ParameterMetadataProvider provider, - JpqlQueryTemplates templates, - Metamodel metamodel) { + JpqlQueryTemplates templates, Metamodel metamodel) { this(tree, false, type, provider, templates, metamodel); } public JpaQueryCreator(PartTree tree, boolean searchQuery, ReturnedType type, ParameterMetadataProvider provider, - JpqlQueryTemplates templates, Metamodel metamodel) { + JpqlQueryTemplates templates, Metamodel metamodel) { super(tree); @@ -488,11 +487,10 @@ public JpqlQueryBuilder.Predicate build() { PartTreeParameterBinding parameter = provider.next(part, String.class); JpqlQueryBuilder.Expression parameterExpression = potentiallyIgnoreCase(part.getProperty(), placeholder(parameter)); + // Predicate like = builder.like(propertyExpression, parameterExpression, escape.getEscapeCharacter()); String escapeChar = Character.toString(escape.getEscapeCharacter()); - return - - type.equals(NOT_LIKE) || type.equals(NOT_CONTAINING) + return type.equals(NOT_LIKE) || type.equals(NOT_CONTAINING) ? whereIgnoreCase.notLike(parameterExpression, escapeChar) : whereIgnoreCase.like(parameterExpression, escapeChar); case TRUE: @@ -519,7 +517,6 @@ public JpqlQueryBuilder.Predicate build() { where = JpqlQueryBuilder.where(entity, property); return type.equals(IS_NOT_EMPTY) ? 
where.isNotEmpty() : where.isEmpty(); - case WITHIN: case NEAR: PartTreeParameterBinding vector = provider.next(part); @@ -527,7 +524,7 @@ public JpqlQueryBuilder.Predicate build() { if (within.getValue() instanceof Range r) { - Range range = (Range) within.getValue(); + Range range = (Range) r; if (range.getUpperBound().isBounded() || range.getUpperBound().isBounded()) { @@ -573,6 +570,8 @@ public JpqlQueryBuilder.Predicate build() { return getUpperPredicate(true, distance, distanceValue); } + throw new InvalidDataAccessApiUsageException( + "Near/Within keywords must be used with a Score or Range type"); default: throw new IllegalArgumentException("Unsupported keyword " + type); } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java index ddfc7505fc..b1c08fc58c 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java @@ -78,8 +78,8 @@ public class ParameterMetadataProvider { * @param escape must not be {@literal null}. * @param templates must not be {@literal null}. */ - public ParameterMetadataProvider(JpaParametersParameterAccessor accessor, - EscapeCharacter escape, JpqlQueryTemplates templates) { + public ParameterMetadataProvider(JpaParametersParameterAccessor accessor, EscapeCharacter escape, + JpqlQueryTemplates templates) { this(accessor.iterator(), accessor, accessor.getParameters(), escape, templates); } @@ -91,8 +91,7 @@ public ParameterMetadataProvider(JpaParametersParameterAccessor accessor, * @param escape must not be {@literal null}. * @param templates must not be {@literal null}. 
*/ - public ParameterMetadataProvider(JpaParameters parameters, EscapeCharacter escape, - JpqlQueryTemplates templates) { + public ParameterMetadataProvider(JpaParameters parameters, EscapeCharacter escape, JpqlQueryTemplates templates) { this(null, null, parameters, escape, templates); } @@ -106,8 +105,8 @@ public ParameterMetadataProvider(JpaParameters parameters, EscapeCharacter escap * @param templates must not be {@literal null}. */ private ParameterMetadataProvider(@Nullable Iterator bindableParameterValues, - @Nullable JpaParametersParameterAccessor accessor, JpaParameters parameters, - EscapeCharacter escape, JpqlQueryTemplates templates) { + @Nullable JpaParametersParameterAccessor accessor, JpaParameters parameters, EscapeCharacter escape, + JpqlQueryTemplates templates) { Assert.notNull(parameters, "Parameters must not be null"); Assert.notNull(escape, "EscapeCharacter must not be null"); Assert.notNull(templates, "JpqlQueryTemplates must not be null"); @@ -196,14 +195,16 @@ private PartTreeParameterBinding next(Part part, Class type, Parameter pa int currentPosition = ++position; int currentBindMarker = ++bindMarker; - BindingIdentifier bindingIdentifier = parameter.getName().map(it -> BindingIdentifier.of(it, currentPosition)) + BindingIdentifier bindingIdentifier = parameter.getName().map(it -> BindingIdentifier.of(it, currentBindMarker)) + .orElseGet(() -> BindingIdentifier.of(currentBindMarker)); + + BindingIdentifier origin = parameter.getName().map(it -> BindingIdentifier.of(it, currentPosition)) .orElseGet(() -> BindingIdentifier.of(currentPosition)); /* identifier refers to bindable parameters, not _all_ parameters index */ - MethodInvocationArgument methodParameter = ParameterOrigin.ofParameter(BindingIdentifier.of(currentPosition)); - PartTreeParameterBinding binding = new PartTreeParameterBinding(BindingIdentifier.of(currentBindMarker), - methodParameter, reifiedType, - part, value, templates, escape); + MethodInvocationArgument 
methodParameter = ParameterOrigin.ofParameter(origin); + PartTreeParameterBinding binding = new PartTreeParameterBinding(bindingIdentifier, + methodParameter, reifiedType, part, value, templates, escape); // PartTreeParameterBinding is more expressive than a potential ParameterBinding for Vector. bindings.add(binding); @@ -284,7 +285,7 @@ RangeParameterBinding lower(PartTreeParameterBinding within, SimilarityNormalize BindingIdentifier identifier = within.getIdentifier(); RangeParameterBinding rangeBinding = new RangeParameterBinding( - identifier.mapName(name -> name + "_upper").withPosition(bindMarker), within.getOrigin(), true, normalizer); + identifier.mapName(name -> name + "_lower").withPosition(bindMarker), within.getOrigin(), true, normalizer); bindings.add(rangeBinding); return rangeBinding; diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java new file mode 100644 index 0000000000..f4c334d39a --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java @@ -0,0 +1,342 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.jpa.repository; + +import static org.assertj.core.api.Assertions.*; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; +import jakarta.persistence.Id; +import jakarta.persistence.Table; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; +import org.springframework.data.domain.VectorScoringFunctions; +import org.springframework.test.annotation.Rollback; +import org.springframework.transaction.annotation.Transactional; + +/** + * Test case to verify that Vector Search works with Hibernate. 
+ * + * @author Mark Paluch + */ +@Transactional +@Rollback(value = false) +abstract class AbstractVectorIntegrationTests { + + Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f); + + @Autowired VectorSearchRepository repository; + + @BeforeEach + void setUp() { + + WithVector w1 = new WithVector("de", "one", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f }); + WithVector w2 = new WithVector("de", "two", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f }); + WithVector w3 = new WithVector("en", "three", new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f }); + WithVector w4 = new WithVector("de", "four", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f }); + + repository.deleteAllInBatch(); + repository.saveAllAndFlush(Arrays.asList(w1, w2, w3, w4)); + } + + @ParameterizedTest + @MethodSource("scoringFunctions") + void shouldApplyVectorSearchWithDistance(VectorScoringFunctions functions) { + + SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.of(0, functions)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + .containsOnly("de", "de"); + + assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) + .containsExactlyInAnyOrder("two", "one", "four"); + } + + static Set scoringFunctions() { + return EnumSet.of(VectorScoringFunctions.COSINE, VectorScoringFunctions.DOT_PRODUCT, + VectorScoringFunctions.EUCLIDEAN); + } + + @Test + void shouldNormalizeEuclideanSimilarity() { + + SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.of(0.99, VectorScoringFunctions.EUCLIDEAN)); + + assertThat(results).hasSize(1); + + SearchResult two = results.getContent().get(0); + + assertThat(two.getContent().getDescription()).isEqualTo("two"); + assertThat(two.getScore()).isInstanceOf(Similarity.class); + 
assertThat(two.getScore().getValue()).isGreaterThan(0.99); + } + + @Test + void shouldNormalizeCosineSimilarity() { + + SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.of(0.999, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(1); + + SearchResult two = results.getContent().get(0); + + assertThat(two.getContent().getDescription()).isEqualTo("two"); + assertThat(two.getScore()).isInstanceOf(Similarity.class); + assertThat(two.getScore().getValue()).isGreaterThan(0.99); + } + + @Test + void shouldRunStringQuery() { + + List results = repository.findAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(2, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(WithVector::getCountry).containsOnly("de", "de", "de"); + assertThat(results).extracting(WithVector::getDescription).containsSequence("two", "one", "four"); + } + + @Test + void shouldRunStringQueryWithDistance() { + + SearchResults results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(2, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + .containsOnly("de", "de", "de"); + assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) + .containsSequence("two", "one", "four"); + + SearchResult result = results.getContent().get(0); + assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0); + assertThat(result.getScore().getFunction()).isEqualTo(VectorScoringFunctions.COSINE); + } + + @Test + void shouldRunStringQueryWithFloatDistance() { + + SearchResults results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, 2); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + .containsOnly("de", "de", "de"); + 
assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) + .containsSequence("two", "one", "four"); + + SearchResult result = results.getContent().get(0); + assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0); + assertThat(result.getScore().getFunction()).isEqualTo(ScoringFunction.unspecified()); + } + + @Test + void shouldApplyVectorSearchWithRange() { + + SearchResults results = repository.searchAllByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.between(0, 1, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) + .containsOnly("de", "de", "de"); + assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) + .containsSequence("two", "one", "four"); + } + + @Test + void shouldApplyVectorSearchAndReturnList() { + + List results = repository.findAllByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(10, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(WithVector::getCountry).containsOnly("de", "de", "de"); + assertThat(results).extracting(WithVector::getDescription).containsSequence("one", "two", "four"); + } + + @Test + void shouldProjectVectorSearchAsInterface() { + + SearchResults results = repository.searchInterfaceProjectionByCountryAndEmbeddingWithin("de", + VECTOR, Score.of(10, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithDescription::getDescription) + .containsSequence("two", "one", "four"); + } + + @Test + void shouldProjectVectorSearchAsDto() { + + SearchResults results = repository.searchDtoByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(10, VectorScoringFunctions.COSINE)); + + assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(DescriptionDto::getDescription) + .containsSequence("two", "one", "four"); + } + + @Test + void 
shouldProjectVectorSearchDynamically() { + + SearchResults dtos = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(10, VectorScoringFunctions.COSINE), DescriptionDto.class); + + assertThat(dtos).hasSize(3).extracting(SearchResult::getContent).extracting(DescriptionDto::getDescription) + .containsSequence("two", "one", "four"); + + SearchResults proxies = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, + Score.of(10, VectorScoringFunctions.COSINE), WithDescription.class); + + assertThat(proxies).hasSize(3).extracting(SearchResult::getContent).extracting(WithDescription::getDescription) + .containsSequence("two", "one", "four"); + } + + @Entity + @Table(name = "with_vector") + public static class WithVector { + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) // + private Integer id; + + private String country; + private String description; + + @Column(name = "the_embedding") + @JdbcTypeCode(SqlTypes.VECTOR) + @Array(length = 5) private float[] embedding; + + public WithVector() {} + + public WithVector(String country, String description, float[] embedding) { + this.country = country; + this.description = description; + this.embedding = embedding; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getCountry() { + return country; + } + + public void setCountry(String country) { + this.country = country; + } + + public String getDescription() { + return description; + } + + public float[] getEmbedding() { + return embedding; + } + + public void setEmbedding(float[] embedding) { + this.embedding = embedding; + } + + @Override + public String toString() { + return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}'; + } + } + + interface WithDescription { + String getDescription(); + } + + static class DescriptionDto { + + private final String description; + + public DescriptionDto(String description) { 
+ this.description = description; + } + + public String getDescription() { + return description; + } + } + + interface VectorSearchRepository extends JpaRepository { + + List findAllByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + + @Query(""" + SELECT w FROM org.springframework.data.jpa.repository.AbstractVectorIntegrationTests$WithVector w + WHERE w.country = ?1 + AND cosine_distance(w.embedding, :embedding) <= :distance + ORDER BY cosine_distance(w.embedding, :embedding) asc""") + List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + + @Query(""" + SELECT w, cosine_distance(w.embedding, :embedding) as distance FROM org.springframework.data.jpa.repository.AbstractVectorIntegrationTests$WithVector w + WHERE w.country = ?1 + AND cosine_distance(w.embedding, :embedding) <= :distance + ORDER BY distance asc""") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + Score distance); + + @Query(""" + SELECT w, cosine_distance(w.embedding, :embedding) as distance FROM org.springframework.data.jpa.repository.AbstractVectorIntegrationTests$WithVector w + WHERE w.country = ?1 + AND cosine_distance(w.embedding, :embedding) <= :distance + ORDER BY distance asc""") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + float distance); + + SearchResults searchAllByCountryAndEmbeddingWithin(String country, Vector embedding, + Range distance); + + SearchResults searchTop5ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + + SearchResults searchInterfaceProjectionByCountryAndEmbeddingWithin(String country, + Vector embedding, Score distance); + + SearchResults searchDtoByCountryAndEmbeddingWithin(String country, Vector embedding, + Score distance); + + SearchResults searchDynamicByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, + Class projection); + + } + +} diff 
--git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/OracleVectorIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/OracleVectorIntegrationTests.java new file mode 100644 index 0000000000..5b1d8779c6 --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/OracleVectorIntegrationTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.jpa.repository; + +import java.net.URL; +import java.util.List; + +import org.hibernate.dialect.OracleDialect; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan.Filter; +import org.springframework.context.annotation.FilterType; +import org.springframework.core.io.ClassPathResource; +import org.springframework.data.jpa.repository.config.EnableJpaRepositories; +import org.springframework.data.jpa.repository.support.TestcontainerConfigSupport; +import org.springframework.orm.jpa.persistenceunit.PersistenceManagedTypes; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; +import org.springframework.transaction.annotation.EnableTransactionManagement; + +import org.testcontainers.oracle.OracleContainer; +import org.testcontainers.utility.MountableFile; + +/** + * Testcase to verify Vector Search works with Oracle. 
+ * + * @author Mark Paluch + */ +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = OracleVectorIntegrationTests.Config.class) +class OracleVectorIntegrationTests extends AbstractVectorIntegrationTests { + + @EnableJpaRepositories(considerNestedRepositories = true, + includeFilters = @Filter(type = FilterType.ASSIGNABLE_TYPE, classes = VectorSearchRepository.class)) + @EnableTransactionManagement + static class Config extends TestcontainerConfigSupport { + + public Config() { + super(OracleDialect.class, new ClassPathResource("scripts/oracle-vector.sql")); + } + + @Override + protected String getSchemaAction() { + return "none"; + } + + @Override + protected PersistenceManagedTypes getManagedTypes() { + return new PersistenceManagedTypes() { + @Override + public List getManagedClassNames() { + return List.of(WithVector.class.getName()); + } + + @Override + public List getManagedPackages() { + return List.of(); + } + + @Override + public @Nullable URL getPersistenceUnitRootUrl() { + return null; + } + + }; + } + + @SuppressWarnings("resource") + @Bean(initMethod = "start", destroyMethod = "stop") + public OracleContainer container() { + + return new OracleContainer("gvenzl/oracle-free:23-slim") // + .withReuse(true) + .withCopyFileToContainer(MountableFile.forClasspathResource("/scripts/oracle-vector-initialize.sql"), + "/container-entrypoint-initdb.d/initialize.sql"); + } + + } + +} diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java index 69e4458c0a..2427e1f930 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/PgVectorIntegrationTests.java @@ -13,352 +13,36 @@ * See the License for the specific language governing permissions and * limitations under 
the License. */ - package org.springframework.data.jpa.repository; -import static org.assertj.core.api.Assertions.*; - -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.GeneratedValue; -import jakarta.persistence.GenerationType; -import jakarta.persistence.Id; -import jakarta.persistence.Table; - import java.net.URL; -import java.util.Arrays; -import java.util.EnumSet; import java.util.List; -import java.util.Set; -import org.hibernate.annotations.Array; -import org.hibernate.annotations.JdbcTypeCode; import org.hibernate.dialect.PostgreSQLDialect; -import org.hibernate.type.SqlTypes; import org.jspecify.annotations.Nullable; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.MethodSource; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan.Filter; import org.springframework.context.annotation.FilterType; import org.springframework.core.io.ClassPathResource; -import org.springframework.data.domain.Range; -import org.springframework.data.domain.Score; -import org.springframework.data.domain.ScoringFunction; -import org.springframework.data.domain.SearchResult; -import org.springframework.data.domain.SearchResults; -import org.springframework.data.domain.Similarity; -import org.springframework.data.domain.Vector; -import org.springframework.data.domain.VectorScoringFunctions; -import org.springframework.data.jpa.convert.VectorConverters; import org.springframework.data.jpa.repository.config.EnableJpaRepositories; import org.springframework.data.jpa.repository.support.TestcontainerConfigSupport; import org.springframework.orm.jpa.persistenceunit.PersistenceManagedTypes; -import org.springframework.test.annotation.Rollback; 
import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.transaction.annotation.EnableTransactionManagement; -import org.springframework.transaction.annotation.Transactional; import org.testcontainers.containers.PostgreSQLContainer; /** - * Testcase to verify {@link org.springframework.jdbc.object.StoredProcedure}s work with Postgres. + * Testcase to verify Vector Search works with Postgres (PGvector). * * @author Mark Paluch */ -@Transactional -@Rollback(value = false) @ExtendWith(SpringExtension.class) @ContextConfiguration(classes = PgVectorIntegrationTests.Config.class) -class PgVectorIntegrationTests { - - Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f); - - @Autowired VectorSearchRepository repository; - - @BeforeEach - void setUp() { - - WithVector w1 = new WithVector("de", "one", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f }); - WithVector w2 = new WithVector("de", "two", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f }); - WithVector w3 = new WithVector("en", "three", new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f }); - WithVector w4 = new WithVector("de", "four", - new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f }); - - repository.deleteAllInBatch(); - repository.saveAllAndFlush(Arrays.asList(w1, w2, w3, w4)); - } - - @ParameterizedTest - @MethodSource("scoringFunctions") - void shouldApplyVectorSearchWithDistance(VectorScoringFunctions functions) { - - SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, - Similarity.of(0, functions)); - - assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) - .containsOnly("de", "de"); - - assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) - .containsExactlyInAnyOrder("two", "one", "four"); - } - - static 
Set scoringFunctions() { - return EnumSet.of(VectorScoringFunctions.COSINE, VectorScoringFunctions.DOT_PRODUCT, - VectorScoringFunctions.EUCLIDEAN); - } - - @Test - void shouldNormalizeEuclideanSimilarity() { - - SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, - Similarity.of(0.99, VectorScoringFunctions.EUCLIDEAN)); - - assertThat(results).hasSize(1); - - SearchResult two = results.getContent().get(0); - - assertThat(two.getContent().getDescription()).isEqualTo("two"); - assertThat(two.getScore()).isInstanceOf(Similarity.class); - assertThat(two.getScore().getValue()).isGreaterThan(0.99); - } - - @Test - void shouldNormalizeCosineSimilarity() { - - SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, - Similarity.of(0.999, VectorScoringFunctions.COSINE)); - - assertThat(results).hasSize(1); - - SearchResult two = results.getContent().get(0); - - assertThat(two.getContent().getDescription()).isEqualTo("two"); - assertThat(two.getScore()).isInstanceOf(Similarity.class); - assertThat(two.getScore().getValue()).isGreaterThan(0.99); - } - - @Test - void shouldRunStringQuery() { - - List results = repository.findAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(2, VectorScoringFunctions.COSINE)); - - assertThat(results).hasSize(3).extracting(WithVector::getCountry).containsOnly("de", "de", "de"); - assertThat(results).extracting(WithVector::getDescription).containsSequence("two", "one", "four"); - } - - @Test - void shouldRunStringQueryWithDistance() { - - SearchResults results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(2, VectorScoringFunctions.COSINE)); - - assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) - .containsOnly("de", "de", "de"); - assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) - .containsSequence("two", "one", "four"); - - SearchResult 
result = results.getContent().get(0); - assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0); - assertThat(result.getScore().getFunction()).isEqualTo(VectorScoringFunctions.COSINE); - } - - @Test - void shouldRunStringQueryWithFloatDistance() { - - SearchResults results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, 2); - - assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) - .containsOnly("de", "de", "de"); - assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) - .containsSequence("two", "one", "four"); - - SearchResult result = results.getContent().get(0); - assertThat(result.getScore().getValue()).isGreaterThanOrEqualTo(0); - assertThat(result.getScore().getFunction()).isEqualTo(ScoringFunction.unspecified()); - } - - @Test - void shouldApplyVectorSearchWithRange() { - - SearchResults results = repository.searchAllByCountryAndEmbeddingWithin("de", VECTOR, - Similarity.between(0, 1, VectorScoringFunctions.COSINE)); - - assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithVector::getCountry) - .containsOnly("de", "de", "de"); - assertThat(results).extracting(SearchResult::getContent).extracting(WithVector::getDescription) - .containsSequence("two", "one", "four"); - } - - @Test - void shouldApplyVectorSearchAndReturnList() { - - List results = repository.findAllByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(10, VectorScoringFunctions.COSINE)); - - assertThat(results).hasSize(3).extracting(WithVector::getCountry).containsOnly("de", "de", "de"); - assertThat(results).extracting(WithVector::getDescription).containsSequence("one", "two", "four"); - } - - @Test - void shouldProjectVectorSearchAsInterface() { - - SearchResults results = repository.searchInterfaceProjectionByCountryAndEmbeddingWithin("de", - VECTOR, Score.of(10, VectorScoringFunctions.COSINE)); - - 
assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(WithDescription::getDescription) - .containsSequence("two", "one", "four"); - } - - @Test - void shouldProjectVectorSearchAsDto() { - - SearchResults results = repository.searchDtoByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(10, VectorScoringFunctions.COSINE)); - - assertThat(results).hasSize(3).extracting(SearchResult::getContent).extracting(DescriptionDto::getDescription) - .containsSequence("two", "one", "four"); - } - - @Test - void shouldProjectVectorSearchDynamically() { - - SearchResults dtos = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(10, VectorScoringFunctions.COSINE), DescriptionDto.class); - - assertThat(dtos).hasSize(3).extracting(SearchResult::getContent).extracting(DescriptionDto::getDescription) - .containsSequence("two", "one", "four"); - - SearchResults proxies = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, - Score.of(10, VectorScoringFunctions.COSINE), WithDescription.class); - - assertThat(proxies).hasSize(3).extracting(SearchResult::getContent).extracting(WithDescription::getDescription) - .containsSequence("two", "one", "four"); - } - - @Entity - @Table(name = "with_vector") - public static class WithVector { - - @Id - @GeneratedValue(strategy = GenerationType.IDENTITY) // - private Integer id; - - private String country; - private String description; - - @Column(name = "the_embedding") - @JdbcTypeCode(SqlTypes.VECTOR) - @Array(length = 5) private float[] embedding; - - public WithVector() {} - - public WithVector(String country, String description, float[] embedding) { - this.country = country; - this.description = description; - this.embedding = embedding; - } - - public Integer getId() { - return id; - } - - public void setId(Integer id) { - this.id = id; - } - - public String getCountry() { - return country; - } - - public void setCountry(String country) { - this.country = country; - } - - public 
String getDescription() { - return description; - } - - public float[] getEmbedding() { - return embedding; - } - - public void setEmbedding(float[] embedding) { - this.embedding = embedding; - } - - @Override - public String toString() { - return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}'; - } - } - - interface WithDescription { - String getDescription(); - } - - static class DescriptionDto { - - private final String description; - - public DescriptionDto(String description) { - this.description = description; - } - - public String getDescription() { - return description; - } - } - - interface VectorSearchRepository extends JpaRepository { - - List findAllByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); - - @Query(""" - SELECT w FROM org.springframework.data.jpa.repository.PgVectorIntegrationTests$WithVector w - WHERE w.country = ?1 - AND cosine_distance(w.embedding, :embedding) <= :distance - ORDER BY cosine_distance(w.embedding, :embedding) asc""") - List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); - - @Query(""" - SELECT w, cosine_distance(w.embedding, :embedding) as distance FROM org.springframework.data.jpa.repository.PgVectorIntegrationTests$WithVector w - WHERE w.country = ?1 - AND cosine_distance(w.embedding, :embedding) <= :distance - ORDER BY distance asc""") - SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, - Score distance); - - @Query(""" - SELECT w, cosine_distance(w.embedding, :embedding) as distance FROM org.springframework.data.jpa.repository.PgVectorIntegrationTests$WithVector w - WHERE w.country = ?1 - AND cosine_distance(w.embedding, :embedding) <= :distance - ORDER BY distance asc""") - SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, - float distance); - - SearchResults searchAllByCountryAndEmbeddingWithin(String country, Vector 
embedding, - Range distance); - - SearchResults searchTop5ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); - - SearchResults searchInterfaceProjectionByCountryAndEmbeddingWithin(String country, - Vector embedding, Score distance); - - SearchResults searchDtoByCountryAndEmbeddingWithin(String country, Vector embedding, - Score distance); - - SearchResults searchDynamicByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, - Class projection); - - } +class PgVectorIntegrationTests extends AbstractVectorIntegrationTests { @EnableJpaRepositories(considerNestedRepositories = true, includeFilters = @Filter(type = FilterType.ASSIGNABLE_TYPE, classes = VectorSearchRepository.class)) @@ -379,8 +63,7 @@ protected PersistenceManagedTypes getManagedTypes() { return new PersistenceManagedTypes() { @Override public List getManagedClassNames() { - return List.of(WithVector.class.getName(), VectorConverters.VectorAsDoubleArrayConverter.class.getName(), - VectorConverters.VectorAsFloatArrayConverter.class.getName()); + return List.of(WithVector.class.getName()); } @Override @@ -393,6 +76,7 @@ public List getManagedPackages() { return null; } }; + } @SuppressWarnings("resource") @@ -402,5 +86,7 @@ public PostgreSQLContainer container() { return new PostgreSQLContainer<>("pgvector/pgvector:pg17") // .withUsername("postgres").withReuse(true); } + } + } diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/MySqlStoredProcedureIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/MySqlStoredProcedureIntegrationTests.java index 5981f5d456..64d52bc1d4 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/MySqlStoredProcedureIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/MySqlStoredProcedureIntegrationTests.java @@ -23,11 +23,12 @@ import 
jakarta.persistence.Id; import jakarta.persistence.NamedStoredProcedureQuery; -import java.util.Collection; +import java.net.URL; import java.util.List; import java.util.Objects; import org.hibernate.dialect.MySQLDialect; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -40,6 +41,7 @@ import org.springframework.data.jpa.repository.config.EnableJpaRepositories; import org.springframework.data.jpa.repository.query.Procedure; import org.springframework.data.jpa.repository.support.TestcontainerConfigSupport; +import org.springframework.orm.jpa.persistenceunit.PersistenceManagedTypes; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.transaction.annotation.EnableTransactionManagement; @@ -232,8 +234,24 @@ public Config() { } @Override - protected Collection getPackagesToScan() { - return List.of(getClass().getPackageName()); + protected PersistenceManagedTypes getManagedTypes() { + return new PersistenceManagedTypes() { + @Override + public List getManagedClassNames() { + return List.of(Employee.class.getName()); + } + + @Override + public List getManagedPackages() { + return List.of(); + } + + @Override + public @Nullable URL getPersistenceUnitRootUrl() { + return null; + } + }; + } @SuppressWarnings("resource") diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureIntegrationTests.java index 5b9a790082..77f46a5518 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureIntegrationTests.java @@ -26,11 
+26,13 @@ import jakarta.persistence.StoredProcedureParameter; import java.math.BigDecimal; +import java.net.URL; import java.util.List; import java.util.Map; import java.util.Objects; import org.hibernate.dialect.PostgreSQLDialect; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -44,6 +46,7 @@ import org.springframework.data.jpa.repository.query.Procedure; import org.springframework.data.jpa.repository.support.TestcontainerConfigSupport; import org.springframework.data.jpa.util.DisabledOnHibernate; +import org.springframework.orm.jpa.persistenceunit.PersistenceManagedTypes; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.transaction.annotation.EnableTransactionManagement; @@ -299,6 +302,27 @@ public Config() { super(PostgreSQLDialect.class, new ClassPathResource("scripts/postgres-stored-procedures.sql")); } + @Override + protected PersistenceManagedTypes getManagedTypes() { + return new PersistenceManagedTypes() { + @Override + public List getManagedClassNames() { + return List.of(Employee.class.getName()); + } + + @Override + public List getManagedPackages() { + return List.of(); + } + + @Override + public @Nullable URL getPersistenceUnitRootUrl() { + return null; + } + }; + + } + @SuppressWarnings("resource") @Bean(initMethod = "start", destroyMethod = "stop") public PostgreSQLContainer container() { diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureNullHandlingIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureNullHandlingIntegrationTests.java index ba1961062f..125a2acbc7 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureNullHandlingIntegrationTests.java +++ 
b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/procedures/PostgresStoredProcedureNullHandlingIntegrationTests.java @@ -20,12 +20,13 @@ import jakarta.persistence.GenerationType; import jakarta.persistence.Id; -import java.util.Collection; +import java.net.URL; import java.util.Date; import java.util.List; import java.util.UUID; import org.hibernate.dialect.PostgreSQLDialect; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -40,6 +41,7 @@ import org.springframework.data.jpa.repository.query.Procedure; import org.springframework.data.jpa.repository.support.TestcontainerConfigSupport; import org.springframework.data.jpa.util.DisabledOnHibernate; +import org.springframework.orm.jpa.persistenceunit.PersistenceManagedTypes; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.transaction.annotation.EnableTransactionManagement; @@ -138,8 +140,24 @@ public Config() { } @Override - protected Collection getPackagesToScan() { - return List.of(getClass().getPackageName()); + protected PersistenceManagedTypes getManagedTypes() { + return new PersistenceManagedTypes() { + @Override + public List getManagedClassNames() { + return List.of(TestModel.class.getName()); + } + + @Override + public List getManagedPackages() { + return List.of(); + } + + @Override + public @Nullable URL getPersistenceUnitRootUrl() { + return null; + } + }; + } @SuppressWarnings("resource") diff --git a/spring-data-jpa/src/test/resources/scripts/oracle-vector-initialize.sql b/spring-data-jpa/src/test/resources/scripts/oracle-vector-initialize.sql new file mode 100644 index 0000000000..23a69dddb7 --- /dev/null +++ b/spring-data-jpa/src/test/resources/scripts/oracle-vector-initialize.sql @@ -0,0 +1,11 @@ +-- Exit on any errors +WHENEVER SQLERROR EXIT SQL.SQLCODE + +-- Configure the size of the 
Vector Pool to 1 GiB. +ALTER SYSTEM SET vector_memory_size = 1G SCOPE=SPFILE; + +SHUTDOWN +ABORT; +STARTUP; + +exit; diff --git a/spring-data-jpa/src/test/resources/scripts/oracle-vector.sql b/spring-data-jpa/src/test/resources/scripts/oracle-vector.sql new file mode 100644 index 0000000000..2d0bf06de4 --- /dev/null +++ b/spring-data-jpa/src/test/resources/scripts/oracle-vector.sql @@ -0,0 +1,16 @@ +DROP TABLE IF EXISTS with_vector;; + +CREATE TABLE IF NOT EXISTS with_vector +( + id NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY, + country varchar2(10), + description varchar2(10), + the_embedding vector(5, FLOAT32) annotations(Distance 'COSINE', IndexType 'IVF') +);; + +create +vector index if not exists vector_index_1 on with_vector (the_embedding) + organization neighbor partitions + distance COSINE +with target accuracy 95 + parameters (type IVF, neighbor partitions 10);; diff --git a/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc index 6ed3115489..2c255297f4 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc @@ -15,7 +15,6 @@ Maven:: ${hibernate.version} - ---- Gradle:: @@ -29,3 +28,5 @@ dependencies { ---- ==== ====== + +NOTE: While you can use `Vector` as type for queries, you cannot use it in your domain model as Hibernate requires float or double arrays as vector types. 
diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc index e713cdbaf1..8b27401c75 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc @@ -20,7 +20,6 @@ interface CommentRepository extends Repository { AND cosine_distance(c.embedding, :embedding) <= :distance ORDER BY cosine_distance(c.embedding, :embedding) asc""") List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); - } ---- ==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc index 3a24393f8e..9819837348 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc @@ -13,4 +13,4 @@ interface CommentRepository extends Repository { ---- ==== -Derived Search Methods can define domain model attributes and Vector parameters. +Derived search methods can declare predicates on domain model attributes and Vector parameters. 
diff --git a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc index 62e900efba..716bf5a562 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc @@ -14,7 +14,6 @@ interface CommentRepository extends Repository { ORDER BY distance asc""") SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); - } SearchResults results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10)); diff --git a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc index 11a8fd289d..4cd793dc91 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc @@ -9,6 +9,8 @@ When using `Similarity` instead of `Score`, Spring Data normalizes distance scor ---- interface CommentRepository extends Repository { + SearchResults searchByEmbeddingNear(Vector vector, ScoringFunction function); + SearchResults searchByEmbeddingNear(Vector vector, Score score); SearchResults searchByEmbeddingNear(Vector vector, Similarity similarity); @@ -16,17 +18,20 @@ interface CommentRepository extends Repository { SearchResults searchByEmbeddingNear(Vector vector, Range range); } -repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9, ScoringFunction.cosine())); <1> +repository.searchByEmbeddingNear(Vector.of(…), ScoringFunction.cosine()); <1> + +repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9, ScoringFunction.cosine())); <2> -repository.searchByEmbeddingNear(Vector.of(…), Similarity.of(0.9, ScoringFunction.cosine())); <2> +repository.searchByEmbeddingNear(Vector.of(…), Similarity.of(0.9, 
ScoringFunction.cosine())); <3> -repository.searchByEmbeddingNear(Vector.of(…), Similarity.between(0.5, 1, ScoringFunction.euclidean()));<3> +repository.searchByEmbeddingNear(Vector.of(…), Similarity.between(0.5, 1, ScoringFunction.euclidean()));<4> ---- -<1> Run a search and return results with a score of `0.9` or smaller using the Cosine distance. -<2> Run a search and normalize the score into a similarity value. -Return results with a similarity of `0.9` or greater using Cosine scoring. +<1> Run a search and return results that are similar to the given `Vector` applying Cosine scoring. +<2> Run a search and return results with a score of `0.9` or smaller using the Cosine distance. <3> Run a search and normalize the score into a similarity value. +Return results with a similarity of `0.9` or greater using Cosine scoring. +<4> Run a search and normalize the score into a similarity value. Return results with a similarity of between `0.5` and `1.0` or greater using Euclidean scoring. ====