From aaf4136363fbf1ae8abc5fc3c34db89465d79771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Ko=C5=82aczkowski?= Date: Mon, 8 Sep 2025 14:56:48 +0200 Subject: [PATCH] Fix recall when all vectors score the same It is possible to insert non-equal vectors which score the same by the provided similarity measure. E.g. vectors (0.1, 0.1), (0.2, 0.2), (0.3, 0.3) are all effectively the same point under the cosine metric, and any pair of them would score 1.0 similarity. This edge case caused serious issues with graph connectivity, and queries returned at most 33 nodes, even for large graphs. This PR fixes it by improving the fairness of node selection when nodes score the same. This is achieved by a small modification to how node ids are encoded in the NodeQueue. When nodes scored the same, they were compared by node ids, which always preferred the nodes added earlier. That created a huge bias. Reversing the bits of the node ids shuffles their order and breaks the systematic bias towards the older nodes. Another part of the fix is making sure nodes with the same score don't block backlinks from being formed. By placing new neighbors before the neighbors with the same score, we're giving them a chance to be linked to. This will drop some other neighbor from the list, but considering it has been present in the graph for some time already, it is much more likely to be already well connected. 
--- .../jbellis/jvector/graph/NodeArray.java | 7 ++-- .../jbellis/jvector/graph/NodeQueue.java | 13 +++--- .../jbellis/jvector/graph/TestNodeArray.java | 22 +++++----- .../jvector/graph/TestVectorGraph.java | 41 +++++++++++++++++++ 4 files changed, 63 insertions(+), 20 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java index fd0d4f792..318af8122 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java @@ -173,7 +173,8 @@ public int insertSorted(int newNode, float newScore) { if (size == nodes.length) { growArrays(); } - int insertionPoint = descSortFindRightMostInsertionPoint(newScore); + + int insertionPoint = descSortFindLeftMostInsertionPoint(newScore); if (duplicateExistsNear(insertionPoint, newNode, newScore)) { return -1; } @@ -282,12 +283,12 @@ public String toString() { return sb.toString(); } - protected final int descSortFindRightMostInsertionPoint(float newScore) { + protected final int descSortFindLeftMostInsertionPoint(float newScore) { int start = 0; int end = size - 1; while (start <= end) { int mid = (start + end) / 2; - if (scores[mid] < newScore) end = mid - 1; + if (scores[mid] <= newScore) end = mid - 1; else start = mid + 1; } return start; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java index 4db8e32b8..7b228eb35 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java @@ -100,22 +100,21 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) { } /** - * Encodes the node ID and its similarity score as long. If two scores are equals, - * the smaller node ID wins. 
+ * Encodes the node ID and its similarity score as long. * *

The most significant 32 bits represent the float score, encoded as a sortable int. * *

The less significant 32 bits represent the node ID. * - *

The bits representing the node ID are complemented to guarantee the win for the smaller node - * ID. + *

The bits representing the node ID are reversed to ensure no bias towards smaller or greater IDs + * when scores are equal. * *

The AND with 0xFFFFFFFFL (a long with first 32 bit as 1) is necessary to obtain a long that * has * *

The most significant 32 bits to 0 * - *

The less significant 32 bits represent the node ID. + *

The less significant 32 bits represent the encoded node ID. * * @param node the node ID * @param score the node score @@ -124,7 +123,7 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) { private long encode(int node, float score) { assert node >= 0 : node; return order.apply( - (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); + (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & Integer.reverse(node))); } private float decodeScore(long heapValue) { @@ -132,7 +131,7 @@ private float decodeScore(long heapValue) { } private int decodeNodeId(long heapValue) { - return (int) ~(order.apply(heapValue)); + return Integer.reverse((int) order.apply(heapValue)); } /** Removes the top element and returns its node id. */ diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java index 0f3586200..f4652d33a 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java @@ -60,48 +60,50 @@ public void testScoresDescOrder() { neighbors.insertSorted(4, 1f); assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {4, 0, 3, 1}, neighbors); neighbors.insertSorted(5, 1.1f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {5, 4, 0, 3, 1}, neighbors); neighbors.insertSorted(6, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {5, 4, 0, 3, 6, 1}, neighbors); neighbors.insertSorted(7, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - 
assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {5, 4, 0, 3, 7, 6, 1}, neighbors); neighbors.removeIndex(2); assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {5, 4, 3, 7, 6, 1}, neighbors); neighbors.removeIndex(0); assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {4, 3, 7, 6, 1}, neighbors); neighbors.removeIndex(4); assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {4, 3, 7, 6}, neighbors); neighbors.removeLast(); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1}, neighbors); + assertNodesEqual(new int[] {4, 3, 7}, neighbors); neighbors.insertSorted(8, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors); + assertNodesEqual(new int[] {4, 8, 3, 7}, neighbors); } private void assertScoresEqual(float[] scores, NodeArray neighbors) { + assertEquals(scores.length, neighbors.size(), "Number of scores differs"); for (int i = 0; i < scores.length; i++) { assertEquals(scores[i], neighbors.getScore(i), 0.01f); } } private void assertNodesEqual(int[] nodes, NodeArray neighbors) { + assertEquals(nodes.length, neighbors.size(), "Number of nodes differs"); for (int i = 0; i < nodes.length; i++) { assertEquals(nodes[i], neighbors.getNode(i)); } @@ -181,7 +183,7 @@ public void testNoDuplicatesSameScores() { cna.insertSorted(3, 10.0f); cna.insertSorted(1, 10.0f); // This is a duplicate and should be ignored cna.insertSorted(3, 10.0f); // This is also a duplicate - assertArrayEquals(new int[] {1, 2, 3}, cna.copyDenseNodes()); + assertArrayEquals(new int[] {3, 2, 1}, cna.copyDenseNodes()); 
assertArrayEquals(new float[] {10.0f, 10.0f, 10.0f}, cna.copyDenseScores(), 0.01f); validateSortedByScore(cna); } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java index c11b04ebd..53381019d 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java @@ -700,6 +700,47 @@ public void testZeroCentroid(boolean addHierarchy) { } } + @Test + public void testSameScoreWithCosineSimilarity() + { + testSameScoreWithCosineSimilarity(10); + testSameScoreWithCosineSimilarity(20); + testSameScoreWithCosineSimilarity(50); + testSameScoreWithCosineSimilarity(100); + testSameScoreWithCosineSimilarity(200); + testSameScoreWithCosineSimilarity(500); + testSameScoreWithCosineSimilarity(1000); + } + + private void testSameScoreWithCosineSimilarity(final int N) { + // Create N vectors which differ in their magnitude but have the same direction, so they would + // all have the exactly same cosine similarity to the query vector. 
+ Random rand = getRandom(); + VectorFloat[] vectors = new VectorFloat[N]; + for (int i = 0; i < N; i++) { + float x = 0.01f + rand.nextFloat(); + vectors[i] = vectorTypeSupport.createFloatVector(new float[]{x, x}); + } + MockVectorValues vectorValues = MockVectorValues.fromValues(vectors); + + similarityFunction = VectorSimilarityFunction.COSINE; + GraphIndexBuilder builder = new GraphIndexBuilder(vectorValues, similarityFunction, 10, 20, 1.0f, 1.0f, false); + OnHeapGraphIndex graph = builder.build(vectorValues); + + VectorFloat query = vectorTypeSupport.createFloatVector(new float[]{0.5f, 0.5f}); + SearchResult result = GraphSearcher.search(query, N, vectorValues, similarityFunction, graph, Bits.ALL); + + // In perfect world, we should return all N vectors, but this is hard to guarantee considering + // the graph is built with a semi-randomized algorithm. And this is an edge case already, so + // we don't want to make the graph building algorithm more complex or less performant in order to satisfy + // this test. In a typical scenario we'll have many more vectors in the graph than the query limit, + // so missing some results is fine. We'd fall back to brute force search anyway if limit + // is the same order of magnitude as the graph size. + int minExpected = (int) (N * 0.5); + assertTrue("Should return almost all vectors, expected at least: " + minExpected + ", got: " + result.getNodes().length, + result.getNodes().length >= minExpected); + } + /** * Returns vectors evenly distributed around the upper unit semicircle. */