diff --git a/docs/docs/multimodal-table/global-index/vector.mdx b/docs/docs/multimodal-table/global-index/vector.mdx index 0ea650ffbafd..aba7d0c3c51b 100644 --- a/docs/docs/multimodal-table/global-index/vector.mdx +++ b/docs/docs/multimodal-table/global-index/vector.mdx @@ -69,6 +69,7 @@ Supported vector index options: |---|---|---| | `.dimension` | `128` | Vector dimension for `ARRAY` columns. Ignored for `VECTOR` columns. | | `.distance.metric` | `inner_product` | Distance metric. Supported values: `l2`, `cosine`, `inner_product`. | +| `.train.sample-ratio` | `1.0` | Ratio of vectors sampled for native index training. Must be greater than `0` and less than or equal to `1`. Lower values reduce training memory and build cost, but may reduce index quality. | | `.nlist` | `256` | Number of IVF clusters used during index build. Higher values create more partitions and can improve recall for large datasets, but may increase build cost. | | `.pq.m` | `16` | Number of PQ sub-vectors for `ivf-pq`. The vector dimension must be divisible by this value. Higher values usually improve recall with larger index files. | | `.pq.use-opq` | `false` | Whether to enable OPQ for `ivf-pq`. | @@ -101,12 +102,14 @@ CREATE TABLE my_table ( 'fields.image_embedding.dimension' = '512', -- shared by every ivf-pq column, overridden only for 'image_embedding' 'ivf-pq.nlist' = '256', - 'fields.image_embedding.nlist' = '512' + 'fields.image_embedding.nlist' = '512', + -- per-column training sample ratio + 'fields.image_embedding.train.sample-ratio' = '0.5' ); ``` With the properties above, `title_embedding` is indexed with `nlist=256` while `image_embedding` -uses `nlist=512`. +uses `nlist=512` and trains with half of the non-null vectors. ## Vector Search diff --git a/paimon-vector/pom.xml b/paimon-vector/pom.xml index 9fd20e915dd0..3ea79b3be226 100644 --- a/paimon-vector/pom.xml +++ b/paimon-vector/pom.xml @@ -32,7 +32,7 @@ under the License. Paimon : Vector Index - 0.1.0 + 0.2.0-SNAPSHOT diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java index 89aa004bee21..6ff6843f001b 100644 --- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java +++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java @@ -24,6 +24,8 @@ import org.apache.paimon.globalindex.GlobalIndexSingleColumnWriter; import org.apache.paimon.globalindex.ResultEntry; import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; +import org.apache.paimon.index.vector.VectorIndexTrainer; +import org.apache.paimon.index.vector.VectorIndexTraining; import org.apache.paimon.index.vector.VectorIndexWriter; import org.apache.paimon.types.ArrayType; import org.apache.paimon.types.DataType; @@ -62,11 +64,15 @@ public class NativeVectorGlobalIndexWriter implements GlobalIndexSingleColumnWri private static final int IO_BUFFER_SIZE = 8 * 1024 * 1024; private static final int ADD_BATCH_SIZE = 10000; + private static final int TRAIN_BATCH_SIZE = 4096; + static final int MAX_FLOAT_ARRAY_LENGTH = Integer.MAX_VALUE - 8; + private static final long TRAIN_MEMORY_WARNING_BYTES = 4L * 1024 * 1024 * 1024; private final GlobalIndexFileWriter fileWriter; private final String identifier; private final Map nativeOptions; private final int dim; + private final double trainSampleRatio; private File tempVectorFile; private FileChannel writeChannel; @@ -84,11 +90,34 @@ public NativeVectorGlobalIndexWriter( DataType fieldType, Map options, String identifier) { + this( + fileWriter, + fieldType, + options, + identifier, + NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_SAMPLE_RATIO); + } + + public NativeVectorGlobalIndexWriter( + GlobalIndexFileWriter fileWriter, + DataType fieldType, + Map options, + String identifier, + double trainSampleRatio) { this.fileWriter = fileWriter; this.identifier = identifier; validateFieldType(fieldType); this.nativeOptions = options; this.dim = Integer.parseInt(options.get("dimension")); + if (Double.isNaN(trainSampleRatio) + || Double.isInfinite(trainSampleRatio) + || trainSampleRatio <= 0 + || trainSampleRatio > 1) { + throw new IllegalArgumentException( + "trainSampleRatio must be greater than 0 and less than or equal to 1: " + + trainSampleRatio); + } + this.trainSampleRatio = trainSampleRatio; this.count = 0; this.closed = false; this.recordSizeInBytes = checkedRecordSize(dim, IO_BUFFER_SIZE); @@ -224,12 +253,11 @@ private ResultEntry buildIndex() throws IOException { long buildStart = System.currentTimeMillis(); NativeVectorIndexLoader.loadJni(); - try (VectorIndexWriter writer = new VectorIndexWriter(nativeOptions)) { - - // Phase 1: Train - long phaseStart = System.currentTimeMillis(); - LOG.info("{} train phase started", identifier); - trainFromTempFile(writer); + // Phase 1: Train + long phaseStart = System.currentTimeMillis(); + LOG.info("{} train phase started", identifier); + try (VectorIndexTraining training = trainFromTempFile(); + VectorIndexWriter writer = new VectorIndexWriter(training)) { LOG.info( "{} train phase done in {} ms", identifier, @@ -271,31 +299,66 @@ private String fileNamePrefix() { return FILE_NAME_PREFIX + "-" + identifier; } - private void trainFromTempFile(VectorIndexWriter writer) throws IOException { - int trainCount = (int) count; - float[] trainData = new float[trainCount * dim]; + private VectorIndexTraining trainFromTempFile() throws IOException { + int trainCount = trainingVectorCount(count, trainSampleRatio); + int trainBatchSize = vectorBatchSize(TRAIN_BATCH_SIZE, dim); + float[] batchVectors = new float[trainBatchSize * dim]; + logTrainingMemoryEstimate(trainCount); - try (RandomAccessFile raf = new RandomAccessFile(tempVectorFile, "r"); + try (VectorIndexTrainer trainer = VectorIndexTrainer.create(nativeOptions); + RandomAccessFile raf = new RandomAccessFile(tempVectorFile, "r"); FileChannel channel = raf.getChannel()) { ByteBuffer readBuf = ByteBuffer.allocateDirect(IO_BUFFER_SIZE); readBuf.order(ByteOrder.nativeOrder()); readBuf.limit(0); - for (int i = 0; i < trainCount; i++) { + long selected = 0; + long nextSampleIndex = 0; + int batchCount = 0; + + for (long recordIndex = 0; + recordIndex < count && selected < trainCount; + recordIndex++) { ensureAvailable(readBuf, channel, recordSizeInBytes); readBuf.getLong(); // skip rowId - for (int d = 0; d < dim; d++) { - trainData[i * dim + d] = readBuf.getFloat(); + if (recordIndex == nextSampleIndex) { + for (int d = 0; d < dim; d++) { + batchVectors[batchCount * dim + d] = readBuf.getFloat(); + } + selected++; + batchCount++; + if (batchCount == trainBatchSize) { + trainer.addTrainingVectors(batchVectors, batchCount); + batchCount = 0; + } + if (selected < trainCount) { + nextSampleIndex = sampleIndex(selected, count, trainCount); + } + } else { + readBuf.position(readBuf.position() + dim * Float.BYTES); } } - } - writer.train(trainData, trainCount); + if (batchCount > 0) { + trainer.addTrainingVectors( + Arrays.copyOf(batchVectors, batchCount * dim), batchCount); + } + if (selected != trainCount) { + throw new IOException( + "Expected to select " + + trainCount + + " training vectors, but selected " + + selected); + } + + return trainer.finishTraining(); + } } private void addVectorsFromTempFile(VectorIndexWriter writer) throws IOException { - long[] batchIds = new long[ADD_BATCH_SIZE]; - float[] batchVectors = new float[ADD_BATCH_SIZE * dim]; + int addBatchSize = vectorBatchSize(ADD_BATCH_SIZE, dim); + long[] batchIds = new long[addBatchSize]; + float[] batchVectors = new float[addBatchSize * dim]; try (RandomAccessFile raf = new RandomAccessFile(tempVectorFile, "r"); FileChannel channel = raf.getChannel()) { @@ -307,7 +370,7 @@ private void addVectorsFromTempFile(VectorIndexWriter writer) throws IOException int lastLoggedPercent = -1; while (remaining > 0) { - int thisBatch = (int) Math.min(ADD_BATCH_SIZE, remaining); + int thisBatch = (int) Math.min(addBatchSize, remaining); for (int i = 0; i < thisBatch; i++) { ensureAvailable(readBuf, channel, recordSizeInBytes); batchIds[i] = readBuf.getLong(); @@ -315,7 +378,7 @@ private void addVectorsFromTempFile(VectorIndexWriter writer) throws IOException batchVectors[i * dim + d] = readBuf.getFloat(); } } - if (thisBatch == ADD_BATCH_SIZE) { + if (thisBatch == addBatchSize) { writer.addVectors(batchIds, batchVectors, thisBatch); } else { writer.addVectors( @@ -386,6 +449,85 @@ private static int checkedRecordSize(int dim, int bufferCapacity) { return (int) recordSize; } + static int trainingVectorCount(long vectorCount, double trainSampleRatio) { + if (vectorCount <= 0) { + return 0; + } + if (Double.isNaN(trainSampleRatio) + || Double.isInfinite(trainSampleRatio) + || trainSampleRatio <= 0 + || trainSampleRatio > 1) { + throw new IllegalArgumentException( + "trainSampleRatio must be greater than 0 and less than or equal to 1: " + + trainSampleRatio); + } + long trainCount = (long) Math.ceil(vectorCount * trainSampleRatio); + trainCount = Math.max(1L, Math.min(vectorCount, trainCount)); + if (trainCount > Integer.MAX_VALUE) { + throw new IllegalStateException( + "Training vector count " + + trainCount + + " exceeds Java integer capacity. Reduce train.sample-ratio."); + } + return (int) trainCount; + } + + static int vectorBatchSize(int requestedBatchSize, int dim) { + if (requestedBatchSize <= 0) { + throw new IllegalArgumentException( + "requestedBatchSize must be a positive integer: " + requestedBatchSize); + } + if (dim <= 0) { + throw new IllegalArgumentException("dim must be a positive integer: " + dim); + } + int maxBatchSize = MAX_FLOAT_ARRAY_LENGTH / dim; + if (maxBatchSize <= 0) { + throw new IllegalStateException( + "Vector dimension " + dim + " exceeds Java float array capacity"); + } + return Math.min(requestedBatchSize, maxBatchSize); + } + + private void logTrainingMemoryEstimate(int trainCount) { + long rawBytes = saturatedMultiply(saturatedMultiply(trainCount, dim), Float.BYTES); + long estimatedPeakBytes = saturatedMultiply(rawBytes, 2); + if (estimatedPeakBytes >= TRAIN_MEMORY_WARNING_BYTES) { + LOG.warn( + "{} training uses {} samples out of {} vectors (dim={}). Estimated native " + + "training peak is at least {} bytes (~{} GiB) before OPQ and " + + "temporary buffers.", + identifier, + trainCount, + count, + dim, + estimatedPeakBytes, + String.format("%.2f", estimatedPeakBytes / 1024.0 / 1024.0 / 1024.0)); + } else { + LOG.info( + "{} training uses {} samples out of {} vectors (dim={})", + identifier, + trainCount, + count, + dim); + } + } + + private static long saturatedMultiply(long left, long right) { + if (left == 0 || right == 0) { + return 0; + } + if (left > Long.MAX_VALUE / right) { + return Long.MAX_VALUE; + } + return left * right; + } + + static long sampleIndex(long sampleOrdinal, long vectorCount, int trainCount) { + long quotient = vectorCount / trainCount; + long remainder = vectorCount % trainCount; + return sampleOrdinal * quotient + sampleOrdinal * remainder / trainCount; + } + @Override public void close() { if (!closed) { diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java index f45a97d34ae5..3293d1620e58 100644 --- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java +++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java @@ -39,17 +39,40 @@ public class NativeVectorGlobalIndexer implements VectorGlobalIndexer { private final DataType fieldType; private final Map options; private final String identifier; + private final double trainSampleRatio; public NativeVectorGlobalIndexer( DataType fieldType, Map options, String identifier) { + this( + fieldType, + options, + identifier, + NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_SAMPLE_RATIO); + } + + public NativeVectorGlobalIndexer( + DataType fieldType, + Map options, + String identifier, + double trainSampleRatio) { this.fieldType = fieldType; this.options = Objects.requireNonNull(options, "options must not be null"); this.identifier = Objects.requireNonNull(identifier, "identifier must not be null"); + if (Double.isNaN(trainSampleRatio) + || Double.isInfinite(trainSampleRatio) + || trainSampleRatio <= 0 + || trainSampleRatio > 1) { + throw new IllegalArgumentException( + "trainSampleRatio must be greater than 0 and less than or equal to 1: " + + trainSampleRatio); + } + this.trainSampleRatio = trainSampleRatio; } @Override public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) { - return new NativeVectorGlobalIndexWriter(fileWriter, fieldType, options, identifier); + return new NativeVectorGlobalIndexWriter( + fileWriter, fieldType, options, identifier, trainSampleRatio); } @Override diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java index 8e4daa030fad..c351697518b4 100644 --- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java +++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java @@ -32,6 +32,8 @@ public abstract class NativeVectorGlobalIndexerFactory implements GlobalIndexerFactory { private static final int DEFAULT_DIMENSION = 128; + static final String TRAIN_SAMPLE_RATIO_OPTION = "train.sample-ratio"; + static final double DEFAULT_TRAIN_SAMPLE_RATIO = 1.0; @Override public GlobalIndexer create(DataField field, Options options) { @@ -39,7 +41,8 @@ public GlobalIndexer create(DataField field, Options options) { return new NativeVectorGlobalIndexer( field.type(), nativeOptions(field.type(), options, identifier, field.name()), - identifier); + identifier, + trainSampleRatio(options, identifier, field.name())); } static Map nativeOptions( @@ -78,6 +81,62 @@ static Map nativeOptions( return nativeOptions; } + static double trainSampleRatio(Options tableOptions, String identifier, String fieldName) { + Map tableOptionsMap = tableOptions.toMap(); + String key = + resolveFieldOverriddenKey( + tableOptionsMap, identifier, fieldName, TRAIN_SAMPLE_RATIO_OPTION); + if (key == null) { + return DEFAULT_TRAIN_SAMPLE_RATIO; + } + String value = tableOptionsMap.get(key); + + try { + double parsed = Double.parseDouble(value.trim()); + if (!Double.isNaN(parsed) && !Double.isInfinite(parsed) && parsed > 0 && parsed <= 1) { + return parsed; + } + throw invalidTrainSampleRatio(key, value); + } catch (NumberFormatException e) { + throw invalidTrainSampleRatio(key, value); + } + } + + private static IllegalArgumentException invalidTrainSampleRatio(String key, String value) { + return new IllegalArgumentException( + "Invalid value for '" + + key + + "': " + + value + + ". Must be greater than 0 and less than or equal to 1."); + } + + /** + * Resolves a single option key that supports index-level ({@code .