|
15 | 15 | */
|
16 | 16 | package org.springframework.data.jpa.repository.query;
|
17 | 17 |
|
| 18 | +import java.util.function.Function; |
| 19 | +import java.util.function.Predicate; |
| 20 | +import java.util.function.Supplier; |
| 21 | + |
18 | 22 | import org.jspecify.annotations.Nullable;
|
19 | 23 |
|
| 24 | +import org.springframework.data.domain.Range; |
| 25 | +import org.springframework.data.domain.Score; |
| 26 | +import org.springframework.data.domain.ScoringFunction; |
| 27 | +import org.springframework.data.domain.Similarity; |
20 | 28 | import org.springframework.data.jpa.repository.query.JpaParameters.JpaParameter;
|
21 | 29 | import org.springframework.data.repository.query.Parameter;
|
22 | 30 | import org.springframework.data.repository.query.Parameters;
|
@@ -68,4 +76,54 @@ protected Object potentiallyUnwrap(Object parameterValue) {
|
68 | 76 | return parameterValue;
|
69 | 77 | }
|
70 | 78 |
|
| 79 | + /** |
| 80 | + * Returns the {@link ScoringFunction}. |
| 81 | + * |
| 82 | + * @return |
| 83 | + */ |
| 84 | + public ScoringFunction getScoringFunction() { |
| 85 | + return doWithScore(Score::getFunction, Score.class::isInstance, ScoringFunction::unspecified); |
| 86 | + } |
| 87 | + |
| 88 | + /** |
| 89 | + * Returns whether to normalize similarities (i.e. translate the database-specific score into {@link Similarity}). |
| 90 | + * |
| 91 | + * @return |
| 92 | + */ |
| 93 | + public boolean normalizeSimilarity() { |
| 94 | + return doWithScore(it -> true, Similarity.class::isInstance, () -> false); |
| 95 | + } |
| 96 | + |
| 97 | + /** |
| 98 | + * Returns the {@link ScoringFunction}. |
| 99 | + * |
| 100 | + * @return |
| 101 | + */ |
| 102 | + public <T> T doWithScore(Function<Score, T> function, Predicate<Score> scoreFilter, Supplier<T> defaultValue) { |
| 103 | + |
| 104 | + Score score = getScore(); |
| 105 | + if (score != null && scoreFilter.test(score)) { |
| 106 | + return function.apply(score); |
| 107 | + } |
| 108 | + |
| 109 | + JpaParameters parameters = getParameters(); |
| 110 | + if (parameters.hasScoreRangeParameter()) { |
| 111 | + |
| 112 | + Range<Score> range = getScoreRange(); |
| 113 | + |
| 114 | + if (range != null && range.getLowerBound().isBounded() |
| 115 | + && scoreFilter.test(range.getLowerBound().getValue().get())) { |
| 116 | + return function.apply(range.getUpperBound().getValue().get()); |
| 117 | + } |
| 118 | + |
| 119 | + if (range != null && range.getUpperBound().isBounded() |
| 120 | + && scoreFilter.test(range.getUpperBound().getValue().get())) { |
| 121 | + return function.apply(range.getUpperBound().getValue().get()); |
| 122 | + } |
| 123 | + |
| 124 | + } |
| 125 | + |
| 126 | + return defaultValue.get(); |
| 127 | + } |
| 128 | + |
71 | 129 | }
|
0 commit comments