diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationMutableVectorBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationMutableVectorBenchmark.java new file mode 100644 index 000000000..f3def2cab --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationMutableVectorBenchmark.java @@ -0,0 +1,144 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.quantization.MutablePQVectors; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +/** + * Benchmark that compares the distance calculation of mutable Product Quantized vectors vs full precision vectors. 
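+ * Vectors are PQ-encoded incrementally via MutablePQVectors#encodeAndSet, then scored with
+ * PQVectors#scoreFunctionFor (query vs. encoded node) and BuildScoreProvider#diversityProviderFor (node vs. node).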
+ */ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=false"}) +@Warmup(iterations = 2) +@Measurement(iterations = 3) +@Threads(1) +public class PQDistanceCalculationMutableVectorBenchmark { + private static final Logger log = LoggerFactory.getLogger(PQDistanceCalculationMutableVectorBenchmark.class); + private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + + private List> vectors; + private PQVectors pqVectors; + private List> queryVectors; + private ProductQuantization pq; + private BuildScoreProvider buildScoreProvider; + + @Param({"1536"}) + private int dimension; + + @Param({"10000"}) + private int vectorCount; + + @Param({"100"}) + private int queryCount; + + @Param({ "16","32", "64","96", "192"}) + private int M; // Number of subspaces for PQ + + @Param + private VectorSimilarityFunction vsf; + + @Setup + public void setup() throws IOException { + log.info("Creating dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount); + + // Create random vectors + vectors = new ArrayList<>(vectorCount); + for (int i = 0; i < vectorCount; i++) { + vectors.add(createRandomVector(dimension)); + } + + // Create query vectors + queryVectors = new ArrayList<>(queryCount); + for (int i = 0; i < queryCount; i++) { + queryVectors.add(createRandomVector(dimension)); + } + + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension); + // Create Mutable PQ vectors + pq = ProductQuantization.compute(ravv, M, 256, true); + pqVectors = new MutablePQVectors(pq); + // build the index vector-at-a-time (on disk) + for (int ordinal = 0; ordinal < vectors.size(); ordinal++) + { + VectorFloat v = vectors.get(ordinal); + // compress the new vector and add it to the PQVectors + ((MutablePQVectors)pqVectors).encodeAndSet(ordinal, v); + } + buildScoreProvider = BuildScoreProvider.pqBuildScoreProvider(vsf, pqVectors); + log.info("Created dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount); + } + + @Benchmark + public void scoreCalculation(Blackhole blackhole) { + float totalSimilarity = 0; + + for (VectorFloat query : queryVectors) { + + ScoreFunction.ApproximateScoreFunction asf = pqVectors.scoreFunctionFor(query, vsf); + for (int i = 0; i < vectorCount; i++) { + float similarity = asf.similarityTo(i); + totalSimilarity += similarity; + } + } + + blackhole.consume(totalSimilarity); + } + + @Benchmark + public void diversityCalculation(Blackhole blackhole) { + float totalSimilarity = 0; + + for (int q = 0; q < queryCount; q++) { + for (int i = 0; i < vectorCount; i++) { + final ScoreFunction sf = buildScoreProvider.diversityProviderFor(i).scoreFunction(); + float similarity = sf.similarityTo(q); + totalSimilarity += similarity; + } + } + + blackhole.consume(totalSimilarity); + } + + private VectorFloat createRandomVector(int dimension) { + VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); + for (int i = 0; i < dimension; i++) { + vector.set(i, (float) Math.random()); + } + return vector; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index 73e59b20f..9fbf7b2a1 100644 --- 
a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -211,33 +211,16 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, var encodedChunk = getChunk(node2); var encodedOffset = getOffsetInChunk(node2); // compute the dot product of the query and the codebook centroids corresponding to the encoded points - float dp = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; - dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); - } + float dp = VectorUtil.pqScoreDotProduct(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); // scale to [0, 1] return (1 + dp) / 2; }; case COSINE: - float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery); return (node2) -> { var encodedChunk = getChunk(node2); var encodedOffset = getOffsetInChunk(node2); - // compute the dot product of the query and the codebook centroids corresponding to the encoded points - float sum = 0; - float norm2 = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; - var codebookOffset = centroidIndex * centroidLength; - sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength); - norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength); - } - float cosine = sum / (float) Math.sqrt(norm1 * norm2); + // compute the cosine of the query and the codebook centroids corresponding to the encoded points + float cosine = VectorUtil.pqScoreCosine(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); // scale to [0, 1] return (1 + cosine) / 2; }; @@ -246,13 +229,7 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, var encodedChunk = getChunk(node2); var encodedOffset = getOffsetInChunk(node2); // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points - float sum = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; - sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); - } + float sum = VectorUtil.pqScoreEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); // scale to [0, 1] return 1 / (1 + sum); }; @@ -273,40 +250,16 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve var node2Chunk = getChunk(node2); var node2Offset = getOffsetInChunk(node2); // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points - float dp = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); - int 
centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength); - } + float dp = VectorUtil.pqScoreDotProduct(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); // scale to [0, 1] return (1 + dp) / 2; }; case COSINE: - float norm1 = 0.0f; - for (int m1 = 0; m1 < subspaceCount; m1++) { - int centroidIndex = Byte.toUnsignedInt(node1Chunk.get(m1 + node1Offset)); - int centroidLength = pq.subvectorSizesAndOffsets[m1][0]; - var codebookOffset = centroidIndex * centroidLength; - norm1 += VectorUtil.dotProduct(pq.codebooks[m1], codebookOffset, pq.codebooks[m1], codebookOffset, centroidLength); - } - final float norm1final = norm1; return (node2) -> { var node2Chunk = getChunk(node2); var node2Offset = getOffsetInChunk(node2); // compute the dot product of the query and the codebook centroids corresponding to the encoded points - float sum = 0; - float norm2 = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); - int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - int codebookOffset = centroidIndex2 * centroidLength; - sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], centroidIndex1 * centroidLength, centroidLength); - norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength); - } - float cosine = sum / (float) Math.sqrt(norm1final * norm2); + float cosine = VectorUtil.pqScoreCosine(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); // scale to [0, 1] return (1 + cosine) / 2; }; @@ -315,13 +268,7 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve var node2Chunk = getChunk(node2); var node2Offset = getOffsetInChunk(node2); // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points - float sum = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); - int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength); - } + float sum = VectorUtil.pqScoreEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); // scale to [0, 1] return 1 / (1 + sum); }; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java index 867e1c85d..5e55e40da 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java @@ -584,4 +584,90 @@ public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValu return squaredSum; } + @Override + public float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int 
node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float dp = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + dp += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength); + } + return dp; +} + + + @Override + public float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + sum += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength); + aMagnitude += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex1 * centroidLength, centroidLength); + bMagnitude += dotProduct(codebooks[m], centroidIndex2 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength); + } + return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); + } + + @Override + public float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + + sum += squareDistance(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength); + } + return sum; + + } + + @Override + public float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float dp = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int centroidOffset = subvectorSizesAndOffsets[m][1]; + dp += dotProduct(codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); + } + return dp; + } + + @Override + public float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0; + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int centroidOffset = subvectorSizesAndOffsets[m][1]; + var codebookOffset = centroidIndex * centroidLength; + sum += dotProduct(codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength); + aMagnitude += dotProduct(codebooks[m], codebookOffset, codebooks[m], codebookOffset, centroidLength); + bMagnitude += dotProduct(centeredQuery, centroidOffset, centeredQuery, 
centroidOffset, centroidLength); + } + float cosine = sum / (float) Math.sqrt(aMagnitude * bMagnitude); + return cosine; + } + + @Override + public float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float sum = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int centroidOffset = subvectorSizesAndOffsets[m][1]; + sum += squareDistance(codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); + } + return sum; + } + } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index e7e8b068f..1e88517f0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -254,4 +254,29 @@ public static float nvqLoss(VectorFloat vector, float growthRate, float midpo public static float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { return impl.nvqUniformLoss(vector, minValue, maxValue, nBits); } + + public static float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + return impl.pqScoreDotProduct(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + public static float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + return impl.pqScoreCosine(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + public static float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + return impl.pqScoreEuclidean(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + public static float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + return impl.pqScoreDotProduct(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + + public static float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + return impl.pqScoreCosine(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + + public static float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + return impl.pqScoreEuclidean(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index 
cc1f74f1b..88daf17b8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -338,4 +338,78 @@ default float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffs */ float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits); + /** + * Calculates the dotproduct for an array of codebooks, uses diversityFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param node1Chunk centroid vector for node1's subvectors + * @param node1Offset offset into ByteSequence of node1 + * @param node2Chunk centroid vector for node2's subvectors + * @param node2Offset offset into ByteSequence of node2 + * @param subspaceCount the number of PQ subspaces + * @return the dot product + */ + float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount); + + /** + * Calculates cosine for an array of codebooks, uses diversityFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param node1Chunk centroid vector for node1's subvectors + * @param node1Offset offset into ByteSequence of node1 + * @param node2Chunk centroid vector for node2's subvectors + * @param node2Offset offset into ByteSequence of node2 + * @param subspaceCount the number of PQ subspaces + * @return the cosine value + */ + float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount); + + /** + * Calculates the Euclidean distance for an array of codebooks, uses diversityFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param node1Chunk centroid vector for node1's subvectors + * @param node1Offset offset into ByteSequence of node1 + * @param node2Chunk centroid vector for node2's subvectors + * @param node2Offset offset into ByteSequence of node2 + * @param subspaceCount the number of PQ subspaces + * @return the Euclidean distance + */ + float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount); + + /** + * Overloaded function to calculate the dotproduct for an array of codebooks, uses scoreFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param encodedChunk centroid vector for encoded point + * @param encodedOffset offset into ByteSequence of encoded vector + * @param centeredQuery query + * @param subspaceCount the number of PQ subspaces + * @return the dotproduct + */ + float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount); + + /** + * Overloaded function to calculate cosine for an array of codebooks, uses scoreFunction. 
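+ * Equivalent to the cosine between the PQ-decoded vector and the centered query, accumulated per subspace
+ * without materializing the decoded vector.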
+ * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param encodedChunk centroid vector for encoded point + * @param encodedOffset offset into ByteSequence of encoded vector + * @param centeredQuery query + * @param subspaceCount the number of PQ subspaces + * @return the cosine value + */ + float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery,int subspaceCount); + + /** + * Overloaded function to calculate the Euclidean distance for an array of codebooks, uses scoreFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param encodedChunk centroid vector for encoded point + * @param encodedOffset offset into ByteSequence of encoded vector + * @param centeredQuery query + * @param subspaceCount the number of PQ subspaces + * @return the Euclidean distance + */ + float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount); } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index eacb10866..35fd1e5b3 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -1576,5 +1576,1027 @@ public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude); } + + /*-------------------- Score functions--------------------*/ + //adding SPECIES_64 & SPECIES_128 for completeness, will it get there? 
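+ // Each width-specific helper below vectorizes the per-subspace loop for one FloatVector species:
+ // a single vector op when the centroid length equals the lane count, otherwise a strided loop up to
+ // loopBound plus a scalar tail. The public overrides dispatch on the first subvector size, which is the largest.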
+ float pqScoreEuclidean_512(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean_256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean_128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + 
node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean_64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + @Override + public float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + //Since centroid length can vary, picking the first entry in the array which is the largest one + if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) { + return pqScoreEuclidean_512(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) { + return pqScoreEuclidean_256( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) { 
+ return pqScoreEuclidean_128( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + return pqScoreEuclidean_64( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + float pqScoreDotProduct_512(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct_256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct_128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128); + + for (int m = 0; m 
< subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct_64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + @Override + public float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + //Since centroid length can vary, picking the first entry in the array which is the largest one + if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) { + return pqScoreDotProduct_512(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) { + return pqScoreDotProduct_256(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) { + return pqScoreDotProduct_128(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, 
node2Chunk, node2Offset, subspaceCount); + } + return pqScoreDotProduct_64(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + float pqScoreCosine_512(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0 ; + FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2 + i); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + // Process the tail + + for (; i < centroidLength ; ++i) { + sum += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i); + bMagnitude += codebooks[m].get(length2 + i) * codebooks[m].get(length2 + i); + } + } + } + sum += vSum.reduceLanes(VectorOperators.ADD); + aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD); + bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD); + return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); + } + + float pqScoreCosine_256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0 ; + FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_256); + FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); + FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2); + vSum = a.fma(b, vSum); + vaMagnitude = 
a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2 + i); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + // Process the tail + for (; i < centroidLength ; ++i) { + sum += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i); + bMagnitude += codebooks[m].get(length2 + i) * codebooks[m].get(length2 + i); + } + } + } + sum += vSum.reduceLanes(VectorOperators.ADD); + aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD); + bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD); + return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); + } + + float pqScoreCosine_128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0 ; + FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_128); + FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_128); + FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_128); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2 + i); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + // Process the tail + for (; i < centroidLength ; ++i) { + sum += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i); + bMagnitude += codebooks[m].get(length2 + i) * codebooks[m].get(length2 + i); + } + } + } + sum += vSum.reduceLanes(VectorOperators.ADD); + aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD); + bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD); + return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); + } + + float pqScoreCosine_64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0 ; + FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_64); + FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_64); + FloatVector vbMagnitude = 
FloatVector.zero(FloatVector.SPECIES_64); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2 + i); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + // Process the tail + + for (; i < centroidLength ; ++i) { + sum += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i); + bMagnitude += codebooks[m].get(length2 + i) * codebooks[m].get(length2 + i); + } + } + } + sum += vSum.reduceLanes(VectorOperators.ADD); + aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD); + bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD); + return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); + } + + @Override + public float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + //Since centroid length can vary, picking the first entry in the array which is the largest one + if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) { + return pqScoreCosine_512(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) { + return pqScoreCosine_256(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) { + return pqScoreCosine_128(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + return pqScoreCosine_64(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + float pqScoreEuclidean_512(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength) ; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = 
fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - centeredQuery.get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean_256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength) ; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - centeredQuery.get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean_128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength) ; + + if (centroidLength == FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < 
centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - centeredQuery.get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean_64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength) ; + + if (centroidLength == FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2); + var diff = a.sub(b); + sum = diff.fma(diff,sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - centeredQuery.get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + @Override + public float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + //Since centroid length can vary, picking the first entry in the array which is the largest one + if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) { + return pqScoreEuclidean_512(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) { + return pqScoreEuclidean_256(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) { + return pqScoreEuclidean_128(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + return pqScoreEuclidean_64(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + + float pqScoreDotProduct_512(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1); + FloatVector b = 
+                sum = a.fma(b, sum);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2 + i);
+                    sum = a.fma(b, sum);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    res += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        res += sum.reduceLanes(VectorOperators.ADD);
+        return res;
+    }
+
+    float pqScoreDotProduct_256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float res = 0;
+        FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength);
+            int length1 = centroidIndex * centroidLength;
+            int length2 = subvectorSizesAndOffsets[m][1];
+
+            if (centroidLength == FloatVector.SPECIES_256.length()) {
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2);
+                sum = a.fma(b, sum);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2 + i);
+                    sum = a.fma(b, sum);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    res += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        res += sum.reduceLanes(VectorOperators.ADD);
+        return res;
+    }
+
+    float pqScoreDotProduct_128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float res = 0;
+        FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength);
+            int length1 = centroidIndex * centroidLength;
+            int length2 = subvectorSizesAndOffsets[m][1];
+
+            if (centroidLength == FloatVector.SPECIES_128.length()) {
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2);
+                sum = a.fma(b, sum);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2 + i);
+                    sum = a.fma(b, sum);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    res += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        res += sum.reduceLanes(VectorOperators.ADD);
+        return res;
+    }
+
+    float pqScoreDotProduct_64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float res = 0;
+        FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength);
+            int length1 = centroidIndex * centroidLength;
+            int length2 = subvectorSizesAndOffsets[m][1];
+
+            if (centroidLength == FloatVector.SPECIES_64.length()) {
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2);
+                sum = a.fma(b, sum);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2 + i);
+                    sum = a.fma(b, sum);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    res += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        res += sum.reduceLanes(VectorOperators.ADD);
+        return res;
+    }
+
+
+    @Override
+    public float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        // Subvector lengths can vary; the first entry is the largest, so use it to pick the species.
+        if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) {
+            return pqScoreDotProduct_512(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) {
+            return pqScoreDotProduct_256(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) {
+            return pqScoreDotProduct_128(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        return pqScoreDotProduct_64(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+    }
+
+    float pqScoreCosine_512(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength);
+            int length1 = centroidIndex * centroidLength;
+            int length2 = subvectorSizesAndOffsets[m][1];
+
+            if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) {
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2);
+                vSum = a.fma(b, vSum);
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    sum += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += centeredQuery.get(length2 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        sum += vSum.reduceLanes(VectorOperators.ADD);
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude));
+    }
+
+    float pqScoreCosine_256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_256);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength);
+            int length1 = centroidIndex * centroidLength;
+            int length2 = subvectorSizesAndOffsets[m][1];
+
+            if (centroidLength == FloatVector.SPECIES_256.length()) {
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2);
+                vSum = a.fma(b, vSum);
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    sum += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += centeredQuery.get(length2 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+
+        sum += vSum.reduceLanes(VectorOperators.ADD);
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude));
+    }
+
+    float pqScoreCosine_128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_128);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_128);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_128);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength);
+            int length1 = centroidIndex * centroidLength;
+            int length2 = subvectorSizesAndOffsets[m][1];
+
+            if (centroidLength == FloatVector.SPECIES_128.length()) {
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2);
+                vSum = a.fma(b, vSum);
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    sum += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += centeredQuery.get(length2 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        sum += vSum.reduceLanes(VectorOperators.ADD);
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude));
+    }
+
+    float pqScoreCosine_64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_64);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_64);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_64);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength);
+            int length1 = centroidIndex * centroidLength;
+            int length2 = subvectorSizesAndOffsets[m][1];
+
+            if (centroidLength == FloatVector.SPECIES_64.length()) {
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2);
+                vSum = a.fma(b, vSum);
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    sum += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += centeredQuery.get(length2 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        sum += vSum.reduceLanes(VectorOperators.ADD);
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude));
+    }
+
+    @Override
+    public float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        // Subvector lengths can vary; the first entry is the largest, so use it to pick the species.
+        if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) {
+            return pqScoreCosine_512(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) {
+            return pqScoreCosine_256(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) {
+            return pqScoreCosine_128(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        return pqScoreCosine_64(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+    }
+}
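For readers who want to sanity-check the SIMD paths above, here is a minimal scalar sketch of the same computation for the Euclidean case. It is illustrative only and not part of this patch: plain float[]/byte[] arrays stand in for VectorFloat and ByteSequence, and the method name referencePqScoreEuclidean is hypothetical. A test could compare its output against the vectorized pqScoreEuclidean to validate the main-loop and tail handling.

    // Scalar reference: for each subspace m, look up the encoded centroid index,
    // locate that centroid in the flat codebook, and accumulate the squared
    // differences against the matching slice of the centered query.
    static float referencePqScoreEuclidean(float[][] codebooks,               // one flat codebook per subspace
                                           int[][] subvectorSizesAndOffsets,  // [m][0] = subvector length, [m][1] = query offset
                                           byte[] encoded, int encodedOffset,
                                           float[] centeredQuery, int subspaceCount) {
        float res = 0;
        for (int m = 0; m < subspaceCount; m++) {
            int centroidIndex = Byte.toUnsignedInt(encoded[m + encodedOffset]);
            int len = subvectorSizesAndOffsets[m][0];
            int codebookBase = centroidIndex * len;           // centroid start within the flat codebook
            int queryBase = subvectorSizesAndOffsets[m][1];   // subspace start within the centered query
            for (int i = 0; i < len; i++) {
                float diff = codebooks[m][codebookBase + i] - centeredQuery[queryBase + i];
                res += diff * diff;
            }
        }
        return res;
    }

The dispatch in the patch picks the widest species whose lane count fits the largest subvector (16, 8, 4, or 2 floats for 512-, 256-, 128-, and 64-bit species respectively); the scalar loop above is what each specialized variant reduces to when no lanes fit.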