Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
package io.github.jbellis.jvector.example.util;

import io.github.jbellis.jvector.graph.SearchResult;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.Set;

/**
* Computes accuracy metrics, such as recall and mean average precision.
Expand All @@ -41,43 +39,54 @@ public static double recallFromSearchResults(List<? extends List<Integer>> gt, L
if (gt.size() != retrieved.size()) {
throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements");
}
Long correctCount = IntStream.range(0, gt.size())
.mapToObj(i -> topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved))
.reduce(0L, Long::sum);

long correctCount = 0;
for (int i = 0; i < gt.size(); i++) {
correctCount += topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved);
}

return (double) correctCount / (kGT * gt.size());
}

private static long topKCorrect(List<Integer> gt, List<Integer> retrieved, int kGT, int kRetrieved) {
private static long topKCorrect(List<Integer> gt, SearchResult retrieved, int kGT, int kRetrieved) {
// Exception validation
var nodes = retrieved.getNodes();
if (kGT > kRetrieved) {
throw new IllegalArgumentException("kGT: " + kGT + " > kRetrieved: " + kRetrieved);
}
if (kGT > gt.size()) {
throw new IllegalArgumentException("kGT: " + kGT + " > Gt size: " + gt.size());
}
if (kRetrieved > retrieved.size()) {
throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + retrieved.size());
if (kRetrieved > nodes.length) {
throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + nodes.length);
}

var gtView = crop(gt, kGT);
var retrievedView = crop(retrieved, kRetrieved);

if (gtView.size() > retrieved.size()) {
return gtView.stream().filter(retrievedView::contains).count();
} else {
return retrievedView.stream().filter(gtView::contains).count();
// Build HashSet with explicit capacity to avoid rehashing.
// Load factor is 0.75, so sized to kGT / 0.75.
Set<Integer> gtSet = new HashSet<>((int) (kGT / 0.75f) + 1);
for (int i = 0; i < kGT; i++) {
Integer ord = gt.get(i);
if (ord == null) {
throw new IllegalArgumentException("Null ground truth ordinal in top-" + kGT + " at index " + i);
}
if (!gtSet.add(ord)) {
throw new IllegalArgumentException("Duplicate ground truth ordinal in top-" + kGT + ": " + ord);
}
}
}

private static long topKCorrect(List<Integer> gt, SearchResult retrieved, int kGT, int kRetrieved) {
var temp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node)
.boxed()
.collect(Collectors.toList());
return topKCorrect(gt, temp, kGT, kRetrieved);
}
Set<Integer> seenRetrieved = new HashSet<>((int) (kRetrieved / 0.75f) + 1);
int hits = 0;
for (int i = 0; i < kRetrieved; i++) {
int p = nodes[i].node;
if (!seenRetrieved.add(p)) {
throw new IllegalArgumentException("Duplicate retrieved ordinal in top-" + kRetrieved + ": " + p);
}
if (gtSet.contains(p)) {
hits++;
}
}

private static List<Integer> crop(List<Integer> list, int k) {
int count = Math.min(list.size(), k);
return list.subList(0, count);
return hits;
}

/**
Expand All @@ -89,33 +98,41 @@ private static List<Integer> crop(List<Integer> list, int k) {
* @return the average precision
*/
public static double averagePrecisionAtK(List<Integer> gt, SearchResult retrieved, int k) {
var retrievedTemp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node)
.boxed()
.collect(Collectors.toList());

var nodes = retrieved.getNodes();
if (k > gt.size()) {
throw new IllegalArgumentException("k: " + k + " > Gt size: " + gt.size());
}
if (k > retrievedTemp.size()) {
throw new IllegalArgumentException("k: " + k + " > retrieved size: " + retrievedTemp.size());
if (k > nodes.length) {
throw new IllegalArgumentException("k: " + k + " > retrieved size: " + nodes.length);
}

var gtView = crop(gt, k);
var retrievedView = crop(retrievedTemp, k);
// Sized hashset used for performance.
Set<Integer> gtSet = new HashSet<>((int) (k / 0.75f) + 1);
for (int i = 0; i < k; i++) {
Integer ord = gt.get(i);
if (ord == null) {
throw new IllegalArgumentException("Null ground truth ordinal in top-" + k + " at index " + i);
}
if (!gtSet.add(ord)) {
throw new IllegalArgumentException("Duplicate ground truth ordinal in top-" + k + ": " + ord);
}
}

Set<Integer> seenRetrieved = new HashSet<>((int) (k / 0.75f) + 1);
double score = 0.;
int num_hits = 0;
int i = 0;

for (var p : retrievedView) {
if (gtView.contains(p) && !retrievedView.subList(0, i).contains(p)) {
num_hits += 1;
score += num_hits / (i + 1.0);
int hits = 0;
for (int i = 0; i < k; i++) {
int p = nodes[i].node;
if (!seenRetrieved.add(p)) {
throw new IllegalArgumentException("Duplicate retrieved ordinal in top-" + k + ": " + p);
}
if (gtSet.contains(p)) {
hits++;
score += (double) hits / (i + 1);
}
i++;
}

return score / gtView.size();
return score / k;
}

/**
Expand All @@ -130,10 +147,12 @@ public static double meanAveragePrecisionAtK(List<? extends List<Integer>> gt, L
if (gt.size() != retrieved.size()) {
throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements");
}
Double apk = IntStream.range(0, gt.size())
.mapToObj(i -> averagePrecisionAtK(gt.get(i), retrieved.get(i), k))
.reduce(0., Double::sum);
return apk / gt.size();
}

double totalAp = 0;
for (int i = 0; i < gt.size(); i++) {
totalAp += averagePrecisionAtK(gt.get(i), retrieved.get(i), k);
}

return totalAp / gt.size();
}
}
Loading
Loading