diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index a43a31ed7..b99b71fe4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -325,7 +325,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, this.simdExecutor = simdExecutor; this.parallelExecutor = parallelExecutor; - this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha)); + this.graph = new OnHeapGraphIndex(maxDegrees, dimension, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha)); this.searchers = ExplicitThreadLocal.withInitial(() -> { var gs = new GraphSearcher(graph); @@ -1001,7 +1001,7 @@ public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in, var diversityProvider = new VamanaDiversityProvider(buildScoreProvider, alpha); - try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, overflowRatio, diversityProvider);) { + try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, newVectors.dimension(), overflowRatio, diversityProvider);) { GraphIndexBuilder builder = new GraphIndexBuilder( buildScoreProvider, diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java index 088f9a1af..0fc6d27f8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java @@ -80,6 +80,11 @@ default int size() { List maxDegrees(); + /** + * @return the dimension of the vectors in the graph + */ + int getDimension(); + /** * @return the first ordinal greater than all node ids in the graph. Equal to size() in simple cases; * May be different from size() if nodes are being added concurrently, or if nodes have been diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 29999bfde..711304f79 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -73,15 +73,17 @@ public class OnHeapGraphIndex implements MutableGraphIndex { // Maximum number of neighbors (edges) per node per layer final List maxDegrees; + private final int dimension; // The ratio by which we can overflow the neighborhood of a node during construction. Since it is a multiplicative // ratio, i.e., the maximum allowable degree if maxDegree * overflowRatio, it should be higher than 1. private final double overflowRatio; private volatile boolean allMutationsCompleted = false; - OnHeapGraphIndex(List maxDegrees, double overflowRatio, DiversityProvider diversityProvider) { + OnHeapGraphIndex(List maxDegrees, int dimension, double overflowRatio, DiversityProvider diversityProvider) { this.overflowRatio = overflowRatio; this.maxDegrees = new IntArrayList(); + this.dimension = dimension; setDegrees(maxDegrees); entryPoint = new AtomicReference<>(); this.completions = new CompletionTracker(1024); @@ -369,6 +371,11 @@ public void setDegrees(List layerDegrees) { maxDegrees.addAll(layerDegrees); } + @Override + public int getDimension() { + return dimension; + } + @Override public void setAllMutationsCompleted() { allMutationsCompleted = true; @@ -541,7 +548,7 @@ public void save(DataOutput out) throws IOException { */ @Experimental @Deprecated - public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, DiversityProvider diversityProvider) throws IOException { + public static OnHeapGraphIndex load(RandomAccessReader in, int dimension, double overflowRatio, DiversityProvider diversityProvider) throws IOException { int magic = in.readInt(); // the magic number if (magic != OnHeapGraphIndex.MAGIC) { throw new IOException("Unsupported magic number: " + magic); @@ -561,7 +568,7 @@ public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, int entryNode = in.readInt(); - var graph = new OnHeapGraphIndex(layerDegrees, overflowRatio, diversityProvider); + var graph = new OnHeapGraphIndex(layerDegrees, dimension, overflowRatio, diversityProvider); Map nodeLevelMap = new HashMap<>(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index 647334e91..8f18ffcf4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -227,6 +227,11 @@ public Set getFeatureSet() { return features.keySet(); } + @Override + public int getDimension() { + return dimension; + } + @Override public int size(int level) { return layerInfo.get(level).size; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index 21de0fede..05a3f7195 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -264,6 +264,11 @@ public List maxDegrees() { throw new NotImplementedException(); } + @Override + public int getDimension() { + throw new NotImplementedException(); + } + @Override public int getIdUpperBound() { return ImmutableGraphIndex.super.getIdUpperBound(); @@ -424,6 +429,11 @@ public List maxDegrees() { throw new NotImplementedException(); } + @Override + public int getDimension() { + throw new NotImplementedException(); + } + @Override public int getIdUpperBound() { return ImmutableGraphIndex.super.getIdUpperBound(); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java index c2cf9fc90..65d14f91b 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java @@ -151,7 +151,7 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException { log.info("Reading on-heap graph from {}", heapGraphOutputPath); MutableGraphIndex reconstructedOnHeapGraphIndex; try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) { - reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(baseBuildScoreProvider, ALPHA)); + reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), baseVectorsRavv.dimension(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(baseBuildScoreProvider, ALPHA)); } try (var readerSupplier = new SimpleMappedReader.Supplier(graphOutputPath.toAbsolutePath());