diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/pom.xml b/pinot-plugins/pinot-input-format/pinot-arrow/pom.xml index 22cac4ce43b0..71ce5a6db8f4 100644 --- a/pinot-plugins/pinot-input-format/pinot-arrow/pom.xml +++ b/pinot-plugins/pinot-input-format/pinot-arrow/pom.xml @@ -56,6 +56,11 @@ org.apache.arrow arrow-vector + + org.apache.pinot + pinot-segment-local + test + diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowAccumulators.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowAccumulators.java new file mode 100644 index 000000000000..4667ed988af1 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowAccumulators.java @@ -0,0 +1,184 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.plugin.inputformat.arrow; + +import com.google.common.base.Preconditions; +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.util.VectorAppender; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.data.readers.ColumnReader; + + +/** + * Package-private helper shared by {@link ArrowColumnReaderFactory} and + * {@link InMemoryArrowColumnReaderFactory}: walks every record batch in an {@link ArrowReader}, + * concatenates each wanted column's values into a per-column accumulator {@link FieldVector} via + * Arrow's {@link VectorAppender}, and produces one {@link ArrowColumnReader} per accumulator. + * + *

Accumulator vectors are allocated against the caller-supplied {@link BufferAllocator}. On
+ * success the caller (factory) owns and closes them via {@link Result#getAccumulators()}; this
+ * helper does not retain references. If population fails part-way, the vectors allocated so far
+ * are released here before the failure propagates, because the caller never receives a
+ * {@link Result} through which it could close them.
+ */
+final class ArrowAccumulators {
+
+  private ArrowAccumulators() {
+  }
+
+  /**
+   * Drains {@code reader} and concatenates each wanted column into one accumulator vector.
+   *
+   * @param reader Arrow reader to drain; not closed by this method
+   * @param allocator allocator the accumulator vectors are created against
+   * @param targetSchema Pinot schema whose non-virtual columns are read when {@code colsToRead}
+   *     is {@code null} or empty
+   * @param colsToRead explicit subset of columns to read, or {@code null} / empty to read the
+   *     schema columns
+   * @return the accumulators, one {@link ArrowColumnReader} per accumulator, and the names of all
+   *     columns present in the Arrow source
+   * @throws IOException if reading a batch fails or the source holds no non-empty record batch
+   * @throws IllegalArgumentException if a wanted column is dictionary-encoded
+   */
+  static Result populate(ArrowReader reader, BufferAllocator allocator, Schema targetSchema,
+      @Nullable Set<String> colsToRead)
+      throws IOException {
+    Set<String> wantedColumns = computeWantedColumns(targetSchema, colsToRead);
+
+    VectorSchemaRoot perBatchRoot = reader.getVectorSchemaRoot();
+    Set<String> availableColumns = Collections.unmodifiableSet(collectAvailableNames(perBatchRoot));
+
+    Map<String, FieldVector> accumulators = new LinkedHashMap<>();
+    try {
+      Map<String, VectorAppender> appenders = new LinkedHashMap<>();
+      for (FieldVector source : perBatchRoot.getFieldVectors()) {
+        String name = source.getField().getName();
+        if (!wantedColumns.isEmpty() && !wantedColumns.contains(name)) {
+          continue;
+        }
+        // Dictionary-encoded columns surface as their index type (e.g. Int32) rather than the
+        // decoded logical type. Reject loudly so we don't silently produce a wrong segment. The
+        // row-major ArrowRecordExtractor decodes via DictionaryEncoder.decode; adding the same
+        // here is left as a follow-up once a real use case appears.
+        Preconditions.checkArgument(source.getField().getDictionary() == null,
+            "Dictionary-encoded Arrow column '%s' is not supported by Arrow column-major build. "
+                + "Use ArrowRecordReader (row-major) for files containing dictionary-encoded columns.",
+            name);
+        FieldVector accumulator = source.getField().createVector(allocator);
+        // Register the vector before allocating so a failing allocateNew() is still cleaned up
+        // by the catch block below.
+        accumulators.put(name, accumulator);
+        // Pre-allocate buffers so VectorAppender can read offsets / validity from them on the
+        // first append (otherwise BaseVariableWidthVector throws an IndexOutOfBoundsException on
+        // an empty offset buffer).
+        accumulator.allocateNew();
+        appenders.put(name, new VectorAppender(accumulator));
+      }
+
+      // Walk every record batch and bulk-append each wanted column into its accumulator via
+      // Arrow's VectorAppender (Visitor-based; grows offset / data buffers once per batch and
+      // bulk-copies, rather than per-row copyValueSafe). Single-batch and multi-batch inputs go
+      // through the same code path.
+      boolean anyBatchSeen = false;
+      while (reader.loadNextBatch()) {
+        if (perBatchRoot.getRowCount() == 0) {
+          continue;
+        }
+        anyBatchSeen = true;
+        for (FieldVector source : perBatchRoot.getFieldVectors()) {
+          VectorAppender appender = appenders.get(source.getField().getName());
+          if (appender != null) {
+            source.accept(appender, null);
+          }
+        }
+      }
+      if (!anyBatchSeen) {
+        throw new IOException("Arrow source contains no non-empty record batches");
+      }
+
+      Map<String, ColumnReader> readers = new LinkedHashMap<>();
+      for (Map.Entry<String, FieldVector> entry : accumulators.entrySet()) {
+        readers.put(entry.getKey(), new ArrowColumnReader(entry.getKey(), entry.getValue()));
+      }
+
+      return new Result(accumulators, readers, availableColumns);
+    } catch (IOException | RuntimeException e) {
+      // The caller never sees a Result on this path, so nobody else can release the buffers
+      // allocated above. Free them here and keep the original failure as the primary exception.
+      IOException closeFailure = closeAll(accumulators);
+      if (closeFailure != null) {
+        e.addSuppressed(closeFailure);
+      }
+      throw e;
+    }
+  }
+
+  /**
+   * Returns the column names to read: {@code colsToRead} when it is non-empty, otherwise every
+   * non-virtual column of {@code targetSchema}.
+   */
+  private static Set<String> computeWantedColumns(Schema targetSchema, @Nullable Set<String> colsToRead) {
+    if (colsToRead != null && !colsToRead.isEmpty()) {
+      return new HashSet<>(colsToRead);
+    }
+    Set<String> targetColumns = new HashSet<>();
+    for (FieldSpec fieldSpec : targetSchema.getAllFieldSpecs()) {
+      if (!fieldSpec.isVirtualColumn()) {
+        targetColumns.add(fieldSpec.getName());
+      }
+    }
+    return targetColumns;
+  }
+
+  /** Returns the field names of every vector in {@code root}. */
+  private static Set<String> collectAvailableNames(VectorSchemaRoot root) {
+    Set<String> names = new HashSet<>();
+    for (FieldVector vector : root.getFieldVectors()) {
+      names.add(vector.getField().getName());
+    }
+    return names;
+  }
+
+  /**
+   * Close each accumulator vector, recording the first failure as an {@link IOException}.
+   * Every vector is attempted even after a failure. Used by both factory {@code close()} paths
+   * so the per-vector close loop lives once.
+   *
+   * @param accumulators per-column accumulator vectors to close; may be {@code null}
+   * @return the first close failure encountered, or {@code null} if all closes succeeded
+   */
+  @Nullable
+  static IOException closeAll(@Nullable Map<String, FieldVector> accumulators) {
+    if (accumulators == null) {
+      return null;
+    }
+    IOException firstException = null;
+    for (FieldVector vector : accumulators.values()) {
+      try {
+        vector.close();
+      } catch (Exception e) {
+        if (firstException == null) {
+          firstException = new IOException("Failed to close Arrow accumulator vector", e);
+        }
+      }
+    }
+    return firstException;
+  }
+
+  /** Immutable holder for the outcome of {@link #populate}. */
+  static final class Result {
+    private final Map<String, FieldVector> _accumulators;
+    private final Map<String, ColumnReader> _readers;
+    private final Set<String> _availableColumns;
+
+    Result(Map<String, FieldVector> accumulators, Map<String, ColumnReader> readers,
+        Set<String> availableColumns) {
+      _accumulators = accumulators;
+      _readers = readers;
+      _availableColumns = availableColumns;
+    }
+
+    /** Accumulator vectors keyed by column name; the receiver is responsible for closing them. */
+    Map<String, FieldVector> getAccumulators() {
+      return _accumulators;
+    }
+
+    /** One reader per accumulator, keyed by column name. */
+    Map<String, ColumnReader> getReaders() {
+      return _readers;
+    }
+
+    /** Names of all columns present in the Arrow source, whether or not they were read. */
+    Set<String> getAvailableColumns() {
+      return _availableColumns;
+    }
+  }
+}
diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReader.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReader.java new file mode 100644 index 000000000000..db76220b3dc4 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReader.java @@ -0,0 +1,595 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.util.BitSet; +import javax.annotation.Nullable; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.pinot.spi.data.readers.ColumnReader; +import org.apache.pinot.spi.data.readers.MultiValueResult; + + +/** + * Column reader for Apache Arrow {@link FieldVector}. + * + *

Wraps a single Arrow {@link FieldVector} and exposes sequential and random-access + * read patterns conforming to the {@link ColumnReader} contract. The vector is owned by + * the enclosing {@link ArrowColumnReaderFactory} which is responsible for its lifecycle; + * closing this reader is a no-op for the underlying vector. + * + *

Supported Arrow types map to Pinot's stored types as follows: + *

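+ * <ul>
+ *   <li>{@code IntVector}, {@code BitVector} &rarr; INT</li>
+ *   <li>{@code BigIntVector} &rarr; LONG</li>
+ *   <li>{@code Float4Vector} &rarr; FLOAT</li>
+ *   <li>{@code Float8Vector} &rarr; DOUBLE</li>
+ *   <li>{@code DecimalVector} &rarr; BIG_DECIMAL</li>
+ *   <li>{@code VarCharVector} &rarr; STRING</li>
+ *   <li>{@code VarBinaryVector} &rarr; BYTES</li>
+ *   <li>{@code ListVector} whose elements are one of the above &rarr; the matching multi-value
+ *       type</li>
+ * </ul>
+ *
+ *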

The list above applies to the typed primitive accessors only ({@link #getInt}, + * {@link #getString}, {@link #getIntMV}, ...). Complex types (Map, Struct, Union, ...) + * and temporal types are still readable via the generic {@link #getValue(int)} / + * {@link #next()} accessors, which delegate to {@link ArrowToPinotTypeConverter} and + * return the same canonical JDK types as the row-major path ({@link + * ArrowRecordExtractor}) — e.g. {@code Map} for Struct / Map, + * {@code Object[]} for List variants, {@code LocalDate} / {@code LocalTime} / + * {@code Timestamp} for temporal types. + * + *

This class is not thread-safe.
+ */
+public class ArrowColumnReader implements ColumnReader {
+
+  private final String _columnName;
+  private final FieldVector _vector;
+  private final int _totalDocs;
+  private final boolean _isSingleValue;
+
+  // Position of the document the sequential next*() accessors will return next.
+  private int _nextDocId;
+
+  /**
+   * Construct an ArrowColumnReader for the given vector.
+   *
+   * @param columnName Pinot column name
+   * @param vector Arrow field vector backing this column; its value count at construction time
+   *     becomes {@link #getTotalDocs()}
+   */
+  public ArrowColumnReader(String columnName, FieldVector vector) {
+    _columnName = columnName;
+    _vector = vector;
+    _totalDocs = vector.getValueCount();
+    _isSingleValue = !(vector instanceof ListVector);
+    _nextDocId = 0;
+  }
+
+  // ---------------------------------------------------------------------------------------------
+  // Sequential access: each next*() reads the document at the cursor, then advances it.
+  // ---------------------------------------------------------------------------------------------
+
+  @Override
+  public boolean hasNext() {
+    return _nextDocId < _totalDocs;
+  }
+
+  @Override
+  @Nullable
+  public Object next()
+      throws IOException {
+    Object value = getValue(_nextDocId);
+    _nextDocId++;
+    return value;
+  }
+
+  /**
+   * Reports whether the document at the cursor is null without advancing.
+   *
+   * @throws IndexOutOfBoundsException if the reader is exhausted, matching {@link #isNull(int)}
+   */
+  @Override
+  public boolean isNextNull()
+      throws IOException {
+    // Go through isNull(int) so an exhausted cursor fails the same bounds check as random access
+    // instead of probing the validity buffer past the last document.
+    return isNull(_nextDocId);
+  }
+
+  @Override
+  public void skipNext()
+      throws IOException {
+    _nextDocId++;
+  }
+
+  @Override
+  public boolean isSingleValue() {
+    return _isSingleValue;
+  }
+
+  @Override
+  public boolean isInt() {
+    return _isSingleValue && (_vector instanceof IntVector || _vector instanceof BitVector);
+  }
+
+  @Override
+  public boolean isLong() {
+    return _isSingleValue && _vector instanceof BigIntVector;
+  }
+
+  @Override
+  public boolean isFloat() {
+    return _isSingleValue && _vector instanceof Float4Vector;
+  }
+
+  @Override
+  public boolean isDouble() {
+    return _isSingleValue && _vector instanceof Float8Vector;
+  }
+
+  @Override
+  public boolean isBigDecimal() {
+    return _isSingleValue && _vector instanceof DecimalVector;
+  }
+
+  @Override
+  public boolean isString() {
+    return _isSingleValue && _vector instanceof VarCharVector;
+  }
+
+  @Override
+  public boolean isBytes() {
+    return _isSingleValue && _vector instanceof VarBinaryVector;
+  }
+
+  @Override
+  public int nextInt()
+      throws IOException {
+    int value = getInt(_nextDocId);
+    _nextDocId++;
+    return value;
+  }
+
+  @Override
+  public long nextLong()
+      throws IOException {
+    long value = getLong(_nextDocId);
+    _nextDocId++;
+    return value;
+  }
+
+  @Override
+  public float nextFloat()
+      throws IOException {
+    float value = getFloat(_nextDocId);
+    _nextDocId++;
+    return value;
+  }
+
+  @Override
+  public double nextDouble()
+      throws IOException {
+    double value = getDouble(_nextDocId);
+    _nextDocId++;
+    return value;
+  }
+
+  @Override
+  public BigDecimal nextBigDecimal()
+      throws IOException {
+    BigDecimal value = getBigDecimal(_nextDocId);
+    _nextDocId++;
+    return value;
+  }
+
+  @Override
+  public String nextString()
+      throws IOException {
+    String value = getString(_nextDocId);
+    _nextDocId++;
+    return value;
+  }
+
+  @Override
+  public byte[] nextBytes()
+      throws IOException {
+    byte[] value = getBytes(_nextDocId);
+    _nextDocId++;
+    return value;
+  }
+
+  @Override
+  public MultiValueResult<int[]> nextIntMV()
+      throws IOException {
+    MultiValueResult<int[]> result = getIntMV(_nextDocId);
+    _nextDocId++;
+    return result;
+  }
+
+  @Override
+  public MultiValueResult<long[]> nextLongMV()
+      throws IOException {
+    MultiValueResult<long[]> result = getLongMV(_nextDocId);
+    _nextDocId++;
+    return result;
+  }
+
+  @Override
+  public MultiValueResult<float[]> nextFloatMV()
+      throws IOException {
+    MultiValueResult<float[]> result = getFloatMV(_nextDocId);
+    _nextDocId++;
+    return result;
+  }
+
+  @Override
+  public MultiValueResult<double[]> nextDoubleMV()
+      throws IOException {
+    MultiValueResult<double[]> result = getDoubleMV(_nextDocId);
+    _nextDocId++;
+    return result;
+  }
+
+  @Override
+  public BigDecimal[] nextBigDecimalMV()
+      throws IOException {
+    BigDecimal[] result = getBigDecimalMV(_nextDocId);
+    _nextDocId++;
+    return result;
+  }
+
+  @Override
+  public String[] nextStringMV()
+      throws IOException {
+    String[] result = getStringMV(_nextDocId);
+    _nextDocId++;
+    return result;
+  }
+
+  @Override
+  public byte[][] nextBytesMV()
+      throws IOException {
+    byte[][] result = getBytesMV(_nextDocId);
+    _nextDocId++;
+    return result;
+  }
+
+  @Override
+  public void rewind()
+      throws IOException {
+    _nextDocId = 0;
+  }
+
+  @Override
+  public String getColumnName() {
+    return _columnName;
+  }
+
+  @Override
+  public int getTotalDocs() {
+    return _totalDocs;
+  }
+
+  // ---------------------------------------------------------------------------------------------
+  // Random access: every accessor bounds-checks docId, then fails with an IOException when the
+  // backing vector is not of the requested type.
+  // ---------------------------------------------------------------------------------------------
+
+  @Override
+  public boolean isNull(int docId) {
+    checkBounds(docId);
+    return _vector.isNull(docId);
+  }
+
+  @Override
+  public int getInt(int docId)
+      throws IOException {
+    checkBounds(docId);
+    if (_vector instanceof IntVector) {
+      return ((IntVector) _vector).get(docId);
+    }
+    if (_vector instanceof BitVector) {
+      return ((BitVector) _vector).get(docId);
+    }
+    throw typeMismatch("INT");
+  }
+
+  @Override
+  public long getLong(int docId)
+      throws IOException {
+    checkBounds(docId);
+    if (_vector instanceof BigIntVector) {
+      return ((BigIntVector) _vector).get(docId);
+    }
+    throw typeMismatch("LONG");
+  }
+
+  @Override
+  public float getFloat(int docId)
+      throws IOException {
+    checkBounds(docId);
+    if (_vector instanceof Float4Vector) {
+      return ((Float4Vector) _vector).get(docId);
+    }
+    throw typeMismatch("FLOAT");
+  }
+
+  @Override
+  public double getDouble(int docId)
+      throws IOException {
+    checkBounds(docId);
+    if (_vector instanceof Float8Vector) {
+      return ((Float8Vector) _vector).get(docId);
+    }
+    throw typeMismatch("DOUBLE");
+  }
+
+  @Override
+  public BigDecimal getBigDecimal(int docId)
+      throws IOException {
+    checkBounds(docId);
+    if (_vector instanceof DecimalVector) {
+      return ((DecimalVector) _vector).getObject(docId);
+    }
+    throw typeMismatch("BIG_DECIMAL");
+  }
+
+  @Override
+  public String getString(int docId)
+      throws IOException {
+    checkBounds(docId);
+    if (_vector instanceof VarCharVector) {
+      byte[] bytes = ((VarCharVector) _vector).get(docId);
+      return bytes == null ? null : new String(bytes, StandardCharsets.UTF_8);
+    }
+    throw typeMismatch("STRING");
+  }
+
+  @Override
+  public byte[] getBytes(int docId)
+      throws IOException {
+    checkBounds(docId);
+    if (_vector instanceof VarBinaryVector) {
+      return ((VarBinaryVector) _vector).get(docId);
+    }
+    throw typeMismatch("BYTES");
+  }
+
+  @Override
+  public Object getValue(int docId)
+      throws IOException {
+    checkBounds(docId);
+    Object value = _vector.getObject(docId);
+    if (value == null) {
+      return null;
+    }
+    // Delegate Arrow → Pinot type conversion to the shared utility extracted from
+    // ArrowRecordExtractor. Returns canonical JDK types: String for Utf8 / LargeUtf8 (unwrapped
+    // from Arrow's Text), Object[] for List variants (with recursive element conversion),
+    // LocalDate / LocalTime / Timestamp for temporal types, etc.
+    return ArrowToPinotTypeConverter.toPinotValue(_vector.getField(), value, false);
+  }
+
+  @Override
+  public MultiValueResult<int[]> getIntMV(int docId)
+      throws IOException {
+    return readPrimitiveMV(docId, int[].class);
+  }
+
+  @Override
+  public MultiValueResult<long[]> getLongMV(int docId)
+      throws IOException {
+    return readPrimitiveMV(docId, long[].class);
+  }
+
+  @Override
+  public MultiValueResult<float[]> getFloatMV(int docId)
+      throws IOException {
+    return readPrimitiveMV(docId, float[].class);
+  }
+
+  @Override
+  public MultiValueResult<double[]> getDoubleMV(int docId)
+      throws IOException {
+    return readPrimitiveMV(docId, double[].class);
+  }
+
+  /**
+   * Reads the list at {@code docId} as decimals; null elements become {@code null} entries.
+   *
+   * @throws IOException if the column is not a list of {@link DecimalVector} elements
+   */
+  @Override
+  public BigDecimal[] getBigDecimalMV(int docId)
+      throws IOException {
+    checkBounds(docId);
+    requireListVector();
+    ListVector list = (ListVector) _vector;
+    FieldVector elements = list.getDataVector();
+    // Verify the element type up front so a mismatch surfaces as the same IOException the
+    // primitive MV readers raise, not as a ClassCastException.
+    if (!(elements instanceof DecimalVector)) {
+      throw typeMismatch("BIG_DECIMAL_MV");
+    }
+    DecimalVector decimals = (DecimalVector) elements;
+    int start = list.getElementStartIndex(docId);
+    int length = list.getElementEndIndex(docId) - start;
+    BigDecimal[] out = new BigDecimal[length];
+    for (int i = 0; i < length; i++) {
+      out[i] = decimals.isNull(start + i) ? null : decimals.getObject(start + i);
+    }
+    return out;
+  }
+
+  /**
+   * Reads the list at {@code docId} as UTF-8 strings; null elements become {@code null} entries.
+   *
+   * @throws IOException if the column is not a list of {@link VarCharVector} elements
+   */
+  @Override
+  public String[] getStringMV(int docId)
+      throws IOException {
+    checkBounds(docId);
+    requireListVector();
+    ListVector list = (ListVector) _vector;
+    FieldVector elements = list.getDataVector();
+    if (!(elements instanceof VarCharVector)) {
+      throw typeMismatch("STRING_MV");
+    }
+    VarCharVector strings = (VarCharVector) elements;
+    int start = list.getElementStartIndex(docId);
+    int length = list.getElementEndIndex(docId) - start;
+    String[] out = new String[length];
+    for (int i = 0; i < length; i++) {
+      if (strings.isNull(start + i)) {
+        out[i] = null;
+      } else {
+        out[i] = new String(strings.get(start + i), StandardCharsets.UTF_8);
+      }
+    }
+    return out;
+  }
+
+  /**
+   * Reads the list at {@code docId} as byte arrays; null elements become {@code null} entries.
+   *
+   * @throws IOException if the column is not a list of {@link VarBinaryVector} elements
+   */
+  @Override
+  public byte[][] getBytesMV(int docId)
+      throws IOException {
+    checkBounds(docId);
+    requireListVector();
+    ListVector list = (ListVector) _vector;
+    FieldVector elements = list.getDataVector();
+    if (!(elements instanceof VarBinaryVector)) {
+      throw typeMismatch("BYTES_MV");
+    }
+    VarBinaryVector binaries = (VarBinaryVector) elements;
+    int start = list.getElementStartIndex(docId);
+    int length = list.getElementEndIndex(docId) - start;
+    byte[][] out = new byte[length][];
+    for (int i = 0; i < length; i++) {
+      out[i] = binaries.isNull(start + i) ? null : binaries.get(start + i);
+    }
+    return out;
+  }
+
+  /**
+   * Read a primitive multi-value from a {@link ListVector}, populating a fresh array and a
+   * nulls BitSet for element-level validity. The BitSet stays {@code null} when no element of
+   * the list is null; null elements keep the array's default value of zero.
+   *
+   * @param docId document whose list is read
+   * @param arrayClass one of {@code int[].class}, {@code long[].class}, {@code float[].class},
+   *     {@code double[].class}
+   * @throws IOException if the column is not a list, its element vector does not match
+   *     {@code arrayClass}, or {@code arrayClass} is not one of the supported array types
+   */
+  @SuppressWarnings("unchecked")
+  private <T> MultiValueResult<T> readPrimitiveMV(int docId, Class<T> arrayClass)
+      throws IOException {
+    checkBounds(docId);
+    requireListVector();
+    ListVector list = (ListVector) _vector;
+    int start = list.getElementStartIndex(docId);
+    int length = list.getElementEndIndex(docId) - start;
+    FieldVector elements = list.getDataVector();
+    BitSet nulls = null;
+
+    Object array;
+    if (arrayClass == int[].class) {
+      int[] values = new int[length];
+      if (elements instanceof IntVector) {
+        IntVector iv = (IntVector) elements;
+        for (int i = 0; i < length; i++) {
+          if (iv.isNull(start + i)) {
+            nulls = markNull(nulls, i, length);
+          } else {
+            values[i] = iv.get(start + i);
+          }
+        }
+      } else if (elements instanceof BitVector) {
+        BitVector bv = (BitVector) elements;
+        for (int i = 0; i < length; i++) {
+          if (bv.isNull(start + i)) {
+            nulls = markNull(nulls, i, length);
+          } else {
+            values[i] = bv.get(start + i);
+          }
+        }
+      } else {
+        throw typeMismatch("INT_MV");
+      }
+      array = values;
+    } else if (arrayClass == long[].class) {
+      if (!(elements instanceof BigIntVector)) {
+        throw typeMismatch("LONG_MV");
+      }
+      BigIntVector lv = (BigIntVector) elements;
+      long[] values = new long[length];
+      for (int i = 0; i < length; i++) {
+        if (lv.isNull(start + i)) {
+          nulls = markNull(nulls, i, length);
+        } else {
+          values[i] = lv.get(start + i);
+        }
+      }
+      array = values;
+    } else if (arrayClass == float[].class) {
+      if (!(elements instanceof Float4Vector)) {
+        throw typeMismatch("FLOAT_MV");
+      }
+      Float4Vector fv = (Float4Vector) elements;
+      float[] values = new float[length];
+      for (int i = 0; i < length; i++) {
+        if (fv.isNull(start + i)) {
+          nulls = markNull(nulls, i, length);
+        } else {
+          values[i] = fv.get(start + i);
+        }
+      }
+      array = values;
+    } else if (arrayClass == double[].class) {
+      if (!(elements instanceof Float8Vector)) {
+        throw typeMismatch("DOUBLE_MV");
+      }
+      Float8Vector dv = (Float8Vector) elements;
+      double[] values = new double[length];
+      for (int i = 0; i < length; i++) {
+        if (dv.isNull(start + i)) {
+          nulls = markNull(nulls, i, length);
+        } else {
+          values[i] = dv.get(start + i);
+        }
+      }
+      array = values;
+    } else {
+      throw new IOException("Unsupported primitive MV array type: " + arrayClass.getName());
+    }
+    // Safe: each branch above builds exactly the array type that arrayClass names.
+    return MultiValueResult.of((T) array, nulls);
+  }
+
+  /** Sets bit {@code index} in {@code nulls}, creating the BitSet on the first null element. */
+  private static BitSet markNull(@Nullable BitSet nulls, int index, int length) {
+    BitSet target = nulls != null ? nulls : new BitSet(length);
+    target.set(index);
+    return target;
+  }
+
+  /** Fails with an {@link IOException} unless the backing vector is a {@link ListVector}. */
+  private void requireListVector()
+      throws IOException {
+    if (!(_vector instanceof ListVector)) {
+      throw new IOException(
+          "Column " + _columnName + " is not a ListVector; cannot read multi-value");
+    }
+  }
+
+  /** Rejects any {@code docId} outside {@code [0, _totalDocs)}. */
+  private void checkBounds(int docId) {
+    if (docId < 0 || docId >= _totalDocs) {
+      throw new IndexOutOfBoundsException(
+          "docId " + docId + " is out of range [0, " + _totalDocs + ") for column " + _columnName);
+    }
+  }
+
+  /** Builds the exception thrown when the column cannot be read as {@code expectedType}. */
+  private IOException typeMismatch(String expectedType) {
+    return new IOException("Column " + _columnName + " (Arrow type " + _vector.getField().getType()
+        + ") cannot be read as " + expectedType);
+  }
+
+  /**
+   * The underlying vector is owned by {@link ArrowColumnReaderFactory}; closing this reader does
+   * not release the vector's memory.
+   */
+  @Override
+  public void close() {
+    // No-op: factory owns the vector lifecycle.
+  }
+}
diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReaderFactory.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReaderFactory.java new file mode 100644 index 000000000000..29f659686490 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReaderFactory.java @@ -0,0 +1,143 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.data.readers.ColumnReader; +import org.apache.pinot.spi.data.readers.ColumnReaderFactory; + + +/** + * {@link ColumnReaderFactory} backed by a caller-managed {@link ArrowReader}. + * + *

The caller supplies an already-open {@link ArrowReader} (any subclass — {@code + * ArrowFileReader}, {@code ArrowStreamReader}, or a custom subclass that yields in-process record + * batches) plus the {@link BufferAllocator} the reader was opened against. Per-column accumulator + * vectors are allocated against that allocator and concatenated from the reader's batches via the + * shared {@link ArrowAccumulators} helper, then exposed as one {@link ArrowColumnReader} per + * accumulator. + * + *

Ownership: the supplied {@link ArrowReader} and {@link BufferAllocator} are owned by + * the caller and are NOT closed by this factory. Per-column accumulator vectors created during + * {@link #init} ARE owned by the factory and are released on {@link #close}. For the file-backed + * convenience that opens and owns its own reader and allocator, see {@link + * ArrowFileColumnReaderFactory}. + * + *

This class is not thread-safe.
+ */
+public class ArrowColumnReaderFactory implements ColumnReaderFactory {
+
+  private final ArrowReader _reader;
+  private final BufferAllocator _allocator;
+
+  // Per-column accumulator vectors created by init(); owned by this factory and freed in close().
+  private transient Map<String, FieldVector> _accumulatorVectors;
+  private transient Map<String, ColumnReader> _columnReaders;
+  private transient Set<String> _availableColumnNames;
+  private transient boolean _initialized;
+
+  /**
+   * Construct a factory reading from the given Arrow reader and allocator.
+   *
+   * @param reader Caller-managed Arrow reader (e.g. {@code ArrowStreamReader} over a byte buffer,
+   *               or a custom reader yielding pre-populated record batches). Not closed by this
+   *               factory.
+   * @param allocator Caller-managed Arrow allocator used for per-column accumulator allocations.
+   *                  Not closed by this factory.
+   */
+  public ArrowColumnReaderFactory(ArrowReader reader, BufferAllocator allocator) {
+    _reader = reader;
+    _allocator = allocator;
+  }
+
+  @Override
+  public void init(Schema targetSchema)
+      throws IOException {
+    init(targetSchema, null, Collections.emptyMap());
+  }
+
+  @Override
+  public void init(Schema targetSchema, Set<String> colsToRead)
+      throws IOException {
+    init(targetSchema, colsToRead, Collections.emptyMap());
+  }
+
+  /**
+   * Initialise the factory. {@code colsToRead == null} or an empty set both mean "read all
+   * non-virtual columns from {@code targetSchema} that the Arrow source actually contains"; pass a
+   * non-empty set to restrict to a subset. The {@code configs} map is ignored — allocator sizing
+   * is the caller's responsibility for this factory.
+   *
+   * <p>Calling this on an already-initialised factory first releases the accumulator vectors of
+   * the previous initialisation. If initialisation then fails, the factory is left uninitialised.
+   *
+   * @throws IOException if releasing previous accumulators or draining the reader fails
+   */
+  @Override
+  public void init(Schema targetSchema, @Nullable Set<String> colsToRead, Map<String, String> configs)
+      throws IOException {
+    // Free accumulators from any earlier init() before the fields below are overwritten;
+    // otherwise their buffers would stay allocated in the caller's allocator with no owner.
+    close();
+    ArrowAccumulators.Result built =
+        ArrowAccumulators.populate(_reader, _allocator, targetSchema, colsToRead);
+    _accumulatorVectors = built.getAccumulators();
+    _columnReaders = built.getReaders();
+    _availableColumnNames = built.getAvailableColumns();
+    _initialized = true;
+  }
+
+  /** Returns the names of all columns present in the Arrow source. */
+  @Override
+  public Set<String> getAvailableColumns() {
+    requireInitialized();
+    return _availableColumnNames;
+  }
+
+  /** Returns the reader for {@code columnName}, or {@code null} if that column was not read. */
+  @Override
+  @Nullable
+  public ColumnReader getColumnReader(String columnName) {
+    requireInitialized();
+    return _columnReaders.get(columnName);
+  }
+
+  /** Returns an unmodifiable view of every column reader created by {@code init}. */
+  @Override
+  public Map<String, ColumnReader> getAllColumnReaders() {
+    requireInitialized();
+    return Collections.unmodifiableMap(_columnReaders);
+  }
+
+  private void requireInitialized() {
+    if (!_initialized) {
+      throw new IllegalStateException("ArrowColumnReaderFactory must be initialized before use");
+    }
+  }
+
+  /**
+   * Releases the accumulator vectors and returns the factory to the uninitialised state. Safe to
+   * call repeatedly or before {@code init}.
+   *
+   * @throws IOException the first failure hit while closing an accumulator vector
+   */
+  @Override
+  public void close()
+      throws IOException {
+    // _reader and _allocator are caller-owned; only release the accumulator vectors we created.
+    IOException accumulatorCloseFailure = ArrowAccumulators.closeAll(_accumulatorVectors);
+    _accumulatorVectors = null;
+    _columnReaders = null;
+    _availableColumnNames = null;
+    _initialized = false;
+    if (accumulatorCloseFailure != null) {
+      throw accumulatorCloseFailure;
+    }
+  }
+}
diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowFileColumnReaderFactory.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowFileColumnReaderFactory.java new file mode 100644 index 000000000000..ca0a3552c6c4 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowFileColumnReaderFactory.java @@ -0,0 +1,214 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ipc.ArrowFileReader; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.data.readers.ColumnReader; +import org.apache.pinot.spi.data.readers.ColumnReaderFactory; + + +/** + * {@link ColumnReaderFactory} backed by an Apache Arrow IPC file on disk. + * + *

File-specialised convenience over {@link ArrowColumnReaderFactory}: this class opens a private + * {@link RootAllocator} sized by {@link #CONFIG_ALLOCATOR_LIMIT}, opens the file via + * {@link ArrowFileReader}, concatenates all record batches into per-column accumulators via the + * shared {@link ArrowAccumulators} helper, and closes the file, reader, and allocator on + * {@link #close}. Callers that already manage an {@link org.apache.arrow.vector.ipc.ArrowReader} + * and {@link org.apache.arrow.memory.BufferAllocator} themselves should use + * {@link ArrowColumnReaderFactory} directly. + * + *

{@link #getAvailableColumns()} lists only the columns physically present in the Arrow file.
+ * Target-schema columns that are absent from the file are therefore missing from that set, and
+ * {@link #getColumnReader(String)} returns {@code null} for them; schema-evolution defaults are
+ * the columnar build driver's responsibility.
+ *
+ *

This class is not thread-safe. For very large inputs the per-column accumulators materialise
+ * the full column set in the Arrow allocator; partition upstream so each factory instance handles
+ * one shard.
+ */
+public class ArrowFileColumnReaderFactory implements ColumnReaderFactory {
+
+  /// Key in the `configs` map passed to `init` whose value, parsed as a `long`, overrides the
+  /// limit of the private allocator.
+  public static final String CONFIG_ALLOCATOR_LIMIT = "arrowAllocatorLimit";
+  /// Default allocator limit when no `configs` override is supplied. Matches
+  /// {@link ArrowRecordReaderConfig#DEFAULT_ALLOCATOR_LIMIT} so users get the same
+  /// memory ceiling whether they pick the row-major or column-major reader path.
+  public static final long DEFAULT_ALLOCATOR_LIMIT = ArrowRecordReaderConfig.DEFAULT_ALLOCATOR_LIMIT;
+
+  private final File _dataFile;
+
+  private transient RootAllocator _allocator;
+  private transient FileInputStream _fileInputStream;
+  private transient ArrowFileReader _arrowFileReader;
+  // Per-column accumulator vectors holding values concatenated across all input batches.
+  // Owned by this factory; released in close() before the allocator.
+  private transient Map<String, FieldVector> _accumulatorVectors;
+  // Cached ColumnReader instances keyed by Pinot column name.
+  private transient Map<String, ColumnReader> _columnReaders;
+  private transient Set<String> _availableColumnNames;
+  private transient boolean _initialized;
+
+  /**
+   * Construct a factory reading from the given Arrow IPC file.
+   *
+   * @param dataFile Path to the Arrow IPC file to read
+   */
+  public ArrowFileColumnReaderFactory(File dataFile) {
+    _dataFile = dataFile;
+  }
+
+  @Override
+  public void init(Schema targetSchema)
+      throws IOException {
+    init(targetSchema, null, Collections.emptyMap());
+  }
+
+  @Override
+  public void init(Schema targetSchema, Set<String> colsToRead)
+      throws IOException {
+    init(targetSchema, colsToRead, Collections.emptyMap());
+  }
+
+  /**
+   * Initialise the factory. {@code colsToRead == null} or an empty set both mean "read all
+   * non-virtual columns from {@code targetSchema} that the Arrow file actually contains"; pass a
+   * non-empty set to restrict to a subset.
+   *
+   * <p>Resources held from an earlier {@code init} are released first. If any step fails, every
+   * resource opened by this call is closed before the failure propagates, with close failures
+   * attached as suppressed exceptions, and the factory is left uninitialised.
+   *
+   * @throws IOException if the file cannot be opened or read
+   * @throws IllegalArgumentException if {@link #CONFIG_ALLOCATOR_LIMIT} is not a valid long
+   */
+  @Override
+  public void init(Schema targetSchema, @Nullable Set<String> colsToRead, Map<String, String> configs)
+      throws IOException {
+    // Parse configuration before touching any resource so a bad value changes nothing.
+    long allocatorLimit = parseAllocatorLimit(configs);
+    // Release whatever a previous init() opened; the assignments below would otherwise orphan
+    // the old allocator, stream, reader and accumulators.
+    close();
+    try {
+      _allocator = new RootAllocator(allocatorLimit);
+      _fileInputStream = new FileInputStream(_dataFile);
+      _arrowFileReader = new ArrowFileReader(_fileInputStream.getChannel(), _allocator);
+
+      ArrowAccumulators.Result built =
+          ArrowAccumulators.populate(_arrowFileReader, _allocator, targetSchema, colsToRead);
+      _accumulatorVectors = built.getAccumulators();
+      _columnReaders = built.getReaders();
+      _availableColumnNames = built.getAvailableColumns();
+
+      _initialized = true;
+    } catch (IOException | RuntimeException e) {
+      // Do not leave a half-open allocator / file handle behind when initialisation fails.
+      try {
+        close();
+      } catch (IOException closeFailure) {
+        e.addSuppressed(closeFailure);
+      }
+      throw e;
+    }
+  }
+
+  /**
+   * Reads {@link #CONFIG_ALLOCATOR_LIMIT} from {@code configs}, falling back to
+   * {@link #DEFAULT_ALLOCATOR_LIMIT} when the map or the entry is {@code null}.
+   */
+  private long parseAllocatorLimit(Map<String, String> configs) {
+    if (configs == null) {
+      return DEFAULT_ALLOCATOR_LIMIT;
+    }
+    String raw = configs.get(CONFIG_ALLOCATOR_LIMIT);
+    if (raw == null) {
+      return DEFAULT_ALLOCATOR_LIMIT;
+    }
+    try {
+      return Long.parseLong(raw);
+    } catch (NumberFormatException e) {
+      throw new IllegalArgumentException(
+          "Invalid value '" + raw + "' for config '" + CONFIG_ALLOCATOR_LIMIT + "': expected a long",
+          e);
+    }
+  }
+
+  /** Returns the names of all columns present in the Arrow file. */
+  @Override
+  public Set<String> getAvailableColumns() {
+    requireInitialized();
+    return _availableColumnNames;
+  }
+
+  /** Returns the reader for {@code columnName}, or {@code null} if that column was not read. */
+  @Override
+  @Nullable
+  public ColumnReader getColumnReader(String columnName) {
+    requireInitialized();
+    return _columnReaders.get(columnName);
+  }
+
+  /** Returns an unmodifiable view of every column reader created by {@code init}. */
+  @Override
+  public Map<String, ColumnReader> getAllColumnReaders() {
+    requireInitialized();
+    return Collections.unmodifiableMap(_columnReaders);
+  }
+
+  private void requireInitialized() {
+    if (!_initialized) {
+      throw new IllegalStateException("ArrowFileColumnReaderFactory must be initialized before use");
+    }
+  }
+
+  /**
+   * Releases everything this factory opened and returns it to the uninitialised state. Safe to
+   * call repeatedly or before {@code init}.
+   *
+   * @throws IOException the first failure encountered; later close steps still run
+   */
+  @Override
+  public void close()
+      throws IOException {
+    // Close ordering: accumulator vectors first, then the file resources we opened (reader,
+    // stream, allocator). Each step's first failure is preserved and re-thrown at the end so
+    // later steps still run.
+    IOException firstException = ArrowAccumulators.closeAll(_accumulatorVectors);
+    _accumulatorVectors = null;
+
+    if (_arrowFileReader != null) {
+      try {
+        _arrowFileReader.close();
+      } catch (IOException e) {
+        if (firstException == null) {
+          firstException = e;
+        }
+      } finally {
+        _arrowFileReader = null;
+      }
+    }
+
+    if (_fileInputStream != null) {
+      try {
+        _fileInputStream.close();
+      } catch (IOException e) {
+        if (firstException == null) {
+          firstException = e;
+        }
+      } finally {
+        _fileInputStream = null;
+      }
+    }
+
+    if (_allocator != null) {
+      try {
+        _allocator.close();
+      } catch (Exception e) {
+        if (firstException == null) {
+          firstException = new IOException("Failed to close Arrow allocator", e);
+        }
+      } finally {
+        _allocator = null;
+      }
+    }
+
+    _columnReaders = null;
+    _availableColumnNames = null;
+    _initialized = false;
+
+    if (firstException != null) {
+      throw firstException;
+    }
+  }
+}
diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractor.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractor.java index 4bdfadeb25b1..648c1e39e34c 100644 --- a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractor.java +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractor.java @@ -18,14 +18,7 @@ */ package org.apache.pinot.plugin.inputformat.arrow; -import com.google.common.collect.Maps; import java.io.IOException; -import java.sql.Timestamp; -import java.time.Instant; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.time.ZoneOffset; import java.util.ArrayList; import
java.util.List; import java.util.Map; @@ -33,17 +26,13 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.ipc.ArrowReader; -import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; -import org.apache.arrow.vector.types.pojo.Field; import org.apache.pinot.spi.data.readers.BaseRecordExtractor; import org.apache.pinot.spi.data.readers.GenericRow; import org.apache.pinot.spi.data.readers.RecordExtractorConfig; -import org.apache.pinot.spi.utils.TimestampUtils; /// Extracts a single Arrow row into a [GenericRow]. Reader-scoped state ([VectorSchemaRoot] + @@ -208,247 +197,10 @@ public GenericRow extract(Record from, GenericRow to) { ValueVector activeVector = activeVectors[i]; Object rawValue = activeVector.getObject(from._rowId); to.putValue(fieldVectors[i].getField().getName(), - rawValue != null ? convert(activeVector.getField(), rawValue) : null); + rawValue != null + ? ArrowToPinotTypeConverter.toPinotValue(activeVector.getField(), rawValue, _extractRawTimeValues) + : null); } return to; } - - /// Schema-driven dispatch — one branch per [ArrowType.ArrowTypeID]; complex types recurse with - /// their child [Field]s, scalars normalize per the contract. - @Nullable - private Object convert(Field field, Object value) { - ArrowType type = field.getType(); - switch (type.getTypeID()) { - // Pass-through — Arrow boxes these directly into the contract output type. 
- case Bool: // Boolean - case FloatingPoint: // Float / Double - case Decimal: // BigDecimal - case Binary: // byte[] - case LargeBinary: // byte[] - case FixedSizeBinary: // byte[] - return value; - // toString — `Utf8` / `LargeUtf8` produce `String`; `Interval` / `Duration` produce ISO-8601 - // (`java.time.Period` / `java.time.Duration` / `PeriodDuration` all have meaningful toString). - case Utf8: - case LargeUtf8: - case Interval: - case Duration: - return value.toString(); - // Integer — `Byte` widens to `Integer` per contract (sign-extended for signed `TinyIntVector`, - // zero-extended via `& 0xFF` for unsigned `UInt1Vector`); `Short` (signed `SmallIntVector`) - // sign-extends; `Character` (unsigned 16, from `UInt2Vector`) widens to its `int` code point; - // `Integer` / `Long` pass through. - case Int: - if (value instanceof Byte) { - int v = (Byte) value; - return ((ArrowType.Int) type).getIsSigned() ? v : v & 0xFF; - } - if (value instanceof Short) { - return ((Short) value).intValue(); - } - if (value instanceof Character) { - return (int) (Character) value; - } - return value; - // Null — NullVector.getObject always returns null; extractValue short-circuits on null, so - // this branch is unreachable in practice. Defensive return. - case Null: - return null; - // Logical temporal — schema's `TimeUnit` drives the conversion. - case Timestamp: - return convertTimestamp((ArrowType.Timestamp) type, value); - case Date: - return convertDate((ArrowType.Date) type, value); - case Time: - return convertTime((ArrowType.Time) type, value); - // Multi-value — `List` (and primitive-array lists) → `Object[]`. - case List: - case LargeList: - case FixedSizeList: - return convertList(field.getChildren().get(0), (List) value); - // Map / nested complex types. - case Map: - // The Map field's children are [entriesStruct]; the entries struct's children are - // [keyField, valueField] (named per MapVector.KEY_NAME / VALUE_NAME). 
- Field entriesField = field.getChildren().get(0); - return convertMap(entriesField.getChildren().get(0), entriesField.getChildren().get(1), (List) value); - case Struct: - return convertStruct(field.getChildren(), (Map) value); - case Union: - // The chosen branch isn't visible from the resolved value alone — dispatch by the value's - // runtime Java type. Nested complex sub-branches fall back to `value.toString()`. - return convertByRuntimeType(value); - default: - // `NONE` is a placeholder; any other ID is a future Arrow addition. - throw new IllegalStateException("Unsupported Arrow type: " + type + " for field: " + field.getName()); - } - } - - /// Constructs a [Timestamp] from an Arrow `Timestamp` value. No-TZ vectors surface as - /// `LocalDateTime` (interpreted as UTC); with-TZ vectors surface as `Long` epoch counted in the - /// schema's `TimeUnit`. Sub-millisecond precision is preserved via [TimestampUtils]. - /// With [#_extractRawTimeValues] the raw `long` epoch in the schema's `TimeUnit` is returned. - private Object convertTimestamp(ArrowType.Timestamp type, Object value) { - if (_extractRawTimeValues) { - if (value instanceof LocalDateTime) { - // No-TZ vector — convert the LocalDateTime back to an epoch `long` in the declared unit. - Instant instant = ((LocalDateTime) value).toInstant(ZoneOffset.UTC); - return toEpochInUnit(instant, type.getUnit()); - } - // With-TZ vector — already raw `long` in the declared unit. 
- return value; - } - if (value instanceof LocalDateTime) { - return Timestamp.from(((LocalDateTime) value).toInstant(ZoneOffset.UTC)); - } - long raw = ((Number) value).longValue(); - switch (type.getUnit()) { - case SECOND: - return new Timestamp(raw * 1000L); - case MILLISECOND: - return new Timestamp(raw); - case MICROSECOND: - return TimestampUtils.fromMicrosSinceEpoch(raw); - case NANOSECOND: - return TimestampUtils.fromNanosSinceEpoch(raw); - default: - throw new IllegalStateException("Unsupported Timestamp unit: " + type.getUnit()); - } - } - - private static long toEpochInUnit(Instant instant, org.apache.arrow.vector.types.TimeUnit unit) { - switch (unit) { - case SECOND: - return instant.getEpochSecond(); - case MILLISECOND: - return instant.toEpochMilli(); - case MICROSECOND: - return Math.addExact(Math.multiplyExact(instant.getEpochSecond(), 1_000_000L), instant.getNano() / 1_000L); - case NANOSECOND: - return Math.addExact(Math.multiplyExact(instant.getEpochSecond(), 1_000_000_000L), instant.getNano()); - default: - throw new IllegalStateException("Unsupported Timestamp unit: " + unit); - } - } - - /// Reduces an Arrow `Date` value to its contract Java type ([LocalDate]), or to `int` - /// days-since-epoch when [#_extractRawTimeValues] is set. `DateDayVector` surfaces as `Integer` - /// raw days; `DateMilliVector` surfaces as `LocalDateTime` at UTC midnight. - private Object convertDate(ArrowType.Date type, Object value) { - int days; - switch (type.getUnit()) { - case DAY: - days = (Integer) value; - break; - case MILLISECOND: - days = (int) ((LocalDateTime) value).toLocalDate().toEpochDay(); - break; - default: - throw new IllegalStateException("Unsupported Date unit: " + type.getUnit()); - } - return _extractRawTimeValues ? 
days : LocalDate.ofEpochDay(days); - } - - /// Constructs a [LocalTime] from an Arrow `Time` value, dispatched by the schema's `TimeUnit`: - /// `TimeMilliVector` surfaces as `LocalDateTime`; `TimeSecVector` as `Integer`; - /// `TimeMicroVector` / `TimeNanoVector` as `Long`. All collapse onto nanoseconds-since-midnight. - /// With [#_extractRawTimeValues] the raw count in the schema's `TimeUnit` is returned instead. - private Object convertTime(ArrowType.Time type, Object value) { - if (_extractRawTimeValues) { - if (value instanceof LocalDateTime) { - // `TimeMilliVector` surfaces as `LocalDateTime`; raw is `int` ms since midnight. - return (int) (((LocalDateTime) value).toLocalTime().toNanoOfDay() / 1_000_000L); - } - // `TimeSecVector` (Integer) / `TimeMicroVector` / `TimeNanoVector` (Long) — already raw. - return value; - } - if (value instanceof LocalDateTime) { - return ((LocalDateTime) value).toLocalTime(); - } - long raw = ((Number) value).longValue(); - switch (type.getUnit()) { - case SECOND: - return LocalTime.ofSecondOfDay(raw); - case MILLISECOND: - return LocalTime.ofNanoOfDay(raw * 1_000_000L); - case MICROSECOND: - return LocalTime.ofNanoOfDay(raw * 1_000L); - case NANOSECOND: - return LocalTime.ofNanoOfDay(raw); - default: - throw new IllegalStateException("Unsupported Time unit: " + type.getUnit()); - } - } - - private Object[] convertList(Field elementField, List list) { - int size = list.size(); - Object[] result = new Object[size]; - int i = 0; - for (Object element : list) { - result[i++] = element != null ? convert(elementField, element) : null; - } - return result; - } - - /// Flattens an Arrow `Map` column's entry list (`List>`) into a - /// `Map`, recursing into each value via [#convert] and stringifying each key via - /// [BaseRecordExtractor#stringifyMapKey] per the contract. Entries with a `null` key (input or - /// post-conversion) are dropped. 
- private Map convertMap(Field keyField, Field valueField, List entries) { - Map result = Maps.newLinkedHashMapWithExpectedSize(entries.size()); - for (Object entry : entries) { - if (entry == null) { - continue; - } - Map entryMap = (Map) entry; - Object rawKey = entryMap.get(MapVector.KEY_NAME); - if (rawKey == null) { - continue; - } - Object convertedKey = convert(keyField, rawKey); - if (convertedKey == null) { - continue; - } - Object rawValue = entryMap.get(MapVector.VALUE_NAME); - result.put(stringifyMapKey(convertedKey), rawValue != null ? convert(valueField, rawValue) : null); - } - return result; - } - - private Map convertStruct(List childFields, Map value) { - Map result = Maps.newHashMapWithExpectedSize(childFields.size()); - for (Field childField : childFields) { - String name = childField.getName(); - Object rawValue = value.get(name); - result.put(name, rawValue != null ? convert(childField, rawValue) : null); - } - return result; - } - - /// Runtime-type dispatch used by the `Union` case (where the chosen branch isn't accessible - /// from the resolved value). Mirrors the scalar handling of [#convert] for the common Arrow - /// boxed types; nested complex types fall back to `value.toString()` because their child - /// [Field]s aren't reachable from here. - private static Object convertByRuntimeType(Object value) { - if (value instanceof Number) { - if (value instanceof Byte || value instanceof Short) { - return ((Number) value).intValue(); - } - return value; - } - if (value instanceof Boolean || value instanceof byte[]) { - return value; - } - if (value instanceof Character) { - // `UInt2Vector` surfaces as `Character`; widen to `int` per the Int(16) contract. - return (int) (Character) value; - } - if (value instanceof LocalDateTime) { - // Ambiguous between Timestamp / Date / Time — best-effort: treat as Timestamp UTC. 
- return Timestamp.from(((LocalDateTime) value).toInstant(ZoneOffset.UTC)); - } - // `Text` (Utf8 / LargeUtf8), `Period` / `Duration` / `PeriodDuration` (Interval / Duration), and - // anything unrecognized fall through to `toString()`. - return value.toString(); - } } diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowToPinotTypeConverter.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowToPinotTypeConverter.java new file mode 100644 index 000000000000..3bebd57f4ec6 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowToPinotTypeConverter.java @@ -0,0 +1,313 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.plugin.inputformat.arrow; + +import com.google.common.collect.Maps; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.pinot.spi.data.readers.BaseRecordExtractor; +import org.apache.pinot.spi.utils.TimestampUtils; + + +/** + * Stateless schema-driven Arrow → Pinot value converter shared by the row-major and + * column-major Arrow ingestion paths. + * + *

The conversion preserves the contract that {@link ArrowRecordExtractor} implemented before + * this logic was moved into its own class (see apache/pinot#18434 for the original row-major + * refactor): one branch per {@link ArrowType.ArrowTypeID}, with complex types recursing through + * their child {@link Field}s and scalars normalised to Pinot's expected JDK types. + * + *

Reused by {@link ArrowRecordExtractor} (row-major) and the column-major Arrow ingestion path. + *

+ * + *

All conversion methods are static; the only per-extraction state is the + * {@code extractRawTimeValues} flag, passed through as a method parameter. + */ +public final class ArrowToPinotTypeConverter { + + private ArrowToPinotTypeConverter() { + } + + /** + * Convert an Arrow vector value to its Pinot-canonical JDK representation. + * + * @param field the Arrow {@link Field} describing the value's type + * @param value the raw value emitted by {@code FieldVector.getObject(docId)} + * @param extractRawTimeValues when {@code true}, {@code Date} / {@code Time} / {@code Timestamp} + * surface as raw {@code int} / {@code long} in the schema's + * {@link org.apache.arrow.vector.types.TimeUnit} instead of the + * corresponding {@code java.time} / {@link Timestamp} contract type + * @return the Pinot-canonical value, or {@code null} for {@link ArrowType.Null} + */ + @Nullable + public static Object toPinotValue(Field field, Object value, boolean extractRawTimeValues) { + ArrowType type = field.getType(); + switch (type.getTypeID()) { + // Pass-through — Arrow boxes these directly into the contract output type. + case Bool: // Boolean + case FloatingPoint: // Float / Double + case Decimal: // BigDecimal + case Binary: // byte[] + case LargeBinary: // byte[] + case FixedSizeBinary: // byte[] + return value; + // toString — `Utf8` / `LargeUtf8` produce `String`; `Interval` / `Duration` produce ISO-8601 + // (`java.time.Period` / `java.time.Duration` / `PeriodDuration` all have meaningful toString). + case Utf8: + case LargeUtf8: + case Interval: + case Duration: + return value.toString(); + // Integer — `Byte` widens to `Integer` per contract (sign-extended for signed `TinyIntVector`, + // zero-extended via `& 0xFF` for unsigned `UInt1Vector`); `Short` (signed `SmallIntVector`) + // sign-extends; `Character` (unsigned 16, from `UInt2Vector`) widens to its `int` code point; + // `Integer` / `Long` pass through. 
+ case Int: + if (value instanceof Byte) { + int v = (Byte) value; + return ((ArrowType.Int) type).getIsSigned() ? v : v & 0xFF; + } + if (value instanceof Short) { + return ((Short) value).intValue(); + } + if (value instanceof Character) { + return (int) (Character) value; + } + return value; + // Null — NullVector.getObject always returns null; callers should short-circuit on null, + // so this branch is unreachable in practice. Defensive return. + case Null: + return null; + // Logical temporal — schema's `TimeUnit` drives the conversion. + case Timestamp: + return convertTimestamp((ArrowType.Timestamp) type, value, extractRawTimeValues); + case Date: + return convertDate((ArrowType.Date) type, value, extractRawTimeValues); + case Time: + return convertTime((ArrowType.Time) type, value, extractRawTimeValues); + // Multi-value — `List` (and primitive-array lists) → `Object[]`. + case List: + case LargeList: + case FixedSizeList: + return convertList(field.getChildren().get(0), (List) value, extractRawTimeValues); + // Map / nested complex types. + case Map: + // The Map field's children are [entriesStruct]; the entries struct's children are + // [keyField, valueField] (named per MapVector.KEY_NAME / VALUE_NAME). + Field entriesField = field.getChildren().get(0); + return convertMap(entriesField.getChildren().get(0), entriesField.getChildren().get(1), + (List) value, extractRawTimeValues); + case Struct: + return convertStruct(field.getChildren(), (Map) value, extractRawTimeValues); + case Union: + // The chosen branch isn't visible from the resolved value alone — dispatch by the value's + // runtime Java type. Nested complex sub-branches fall back to `value.toString()`. + return convertByRuntimeType(value); + default: + // `NONE` is a placeholder; any other ID is a future Arrow addition. 
+ throw new IllegalStateException("Unsupported Arrow type: " + type + " for field: " + field.getName()); + } + } + + /// Constructs a [Timestamp] from an Arrow `Timestamp` value. No-TZ vectors surface as + /// `LocalDateTime` (interpreted as UTC); with-TZ vectors surface as `Long` epoch counted in the + /// schema's `TimeUnit`. Sub-millisecond precision is preserved via [TimestampUtils]. + /// With `extractRawTimeValues` the raw `long` epoch in the schema's `TimeUnit` is returned. + private static Object convertTimestamp(ArrowType.Timestamp type, Object value, boolean extractRawTimeValues) { + if (extractRawTimeValues) { + if (value instanceof LocalDateTime) { + // No-TZ vector — convert the LocalDateTime back to an epoch `long` in the declared unit. + Instant instant = ((LocalDateTime) value).toInstant(ZoneOffset.UTC); + return toEpochInUnit(instant, type.getUnit()); + } + // With-TZ vector — already raw `long` in the declared unit. + return value; + } + if (value instanceof LocalDateTime) { + return Timestamp.from(((LocalDateTime) value).toInstant(ZoneOffset.UTC)); + } + long raw = ((Number) value).longValue(); + switch (type.getUnit()) { + case SECOND: + return new Timestamp(raw * 1000L); + case MILLISECOND: + return new Timestamp(raw); + case MICROSECOND: + return TimestampUtils.fromMicrosSinceEpoch(raw); + case NANOSECOND: + return TimestampUtils.fromNanosSinceEpoch(raw); + default: + throw new IllegalStateException("Unsupported Timestamp unit: " + type.getUnit()); + } + } + + private static long toEpochInUnit(Instant instant, org.apache.arrow.vector.types.TimeUnit unit) { + switch (unit) { + case SECOND: + return instant.getEpochSecond(); + case MILLISECOND: + return instant.toEpochMilli(); + case MICROSECOND: + return Math.addExact(Math.multiplyExact(instant.getEpochSecond(), 1_000_000L), instant.getNano() / 1_000L); + case NANOSECOND: + return Math.addExact(Math.multiplyExact(instant.getEpochSecond(), 1_000_000_000L), instant.getNano()); + default: + throw 
new IllegalStateException("Unsupported Timestamp unit: " + unit); + } + } + + /// Reduces an Arrow `Date` value to its contract Java type ([LocalDate]), or to `int` + /// days-since-epoch when `extractRawTimeValues` is set. `DateDayVector` surfaces as `Integer` + /// raw days; `DateMilliVector` surfaces as `LocalDateTime` at UTC midnight. + private static Object convertDate(ArrowType.Date type, Object value, boolean extractRawTimeValues) { + int days; + switch (type.getUnit()) { + case DAY: + days = (Integer) value; + break; + case MILLISECOND: + days = (int) ((LocalDateTime) value).toLocalDate().toEpochDay(); + break; + default: + throw new IllegalStateException("Unsupported Date unit: " + type.getUnit()); + } + return extractRawTimeValues ? days : LocalDate.ofEpochDay(days); + } + + /// Constructs a [LocalTime] from an Arrow `Time` value, dispatched by the schema's `TimeUnit`: + /// `TimeMilliVector` surfaces as `LocalDateTime`; `TimeSecVector` as `Integer`; + /// `TimeMicroVector` / `TimeNanoVector` as `Long`. All collapse onto nanoseconds-since-midnight. + /// With `extractRawTimeValues` the raw count in the schema's `TimeUnit` is returned instead. + private static Object convertTime(ArrowType.Time type, Object value, boolean extractRawTimeValues) { + if (extractRawTimeValues) { + if (value instanceof LocalDateTime) { + // `TimeMilliVector` surfaces as `LocalDateTime`; raw is `int` ms since midnight. + return (int) (((LocalDateTime) value).toLocalTime().toNanoOfDay() / 1_000_000L); + } + // `TimeSecVector` (Integer) / `TimeMicroVector` / `TimeNanoVector` (Long) — already raw. 
+ return value; + } + if (value instanceof LocalDateTime) { + return ((LocalDateTime) value).toLocalTime(); + } + long raw = ((Number) value).longValue(); + switch (type.getUnit()) { + case SECOND: + return LocalTime.ofSecondOfDay(raw); + case MILLISECOND: + return LocalTime.ofNanoOfDay(raw * 1_000_000L); + case MICROSECOND: + return LocalTime.ofNanoOfDay(raw * 1_000L); + case NANOSECOND: + return LocalTime.ofNanoOfDay(raw); + default: + throw new IllegalStateException("Unsupported Time unit: " + type.getUnit()); + } + } + + private static Object[] convertList(Field elementField, List list, boolean extractRawTimeValues) { + int size = list.size(); + Object[] result = new Object[size]; + int i = 0; + for (Object element : list) { + result[i++] = element != null ? toPinotValue(elementField, element, extractRawTimeValues) : null; + } + return result; + } + + /// Flattens an Arrow `Map` column's entry list (`List>`) into a + /// `Map`, recursing into each value via [#toPinotValue] and stringifying each + /// key via [BaseRecordExtractor#stringifyMapKey] per the contract. Entries with a `null` key + /// (input or post-conversion) are dropped. + private static Map convertMap(Field keyField, Field valueField, List entries, + boolean extractRawTimeValues) { + Map result = Maps.newLinkedHashMapWithExpectedSize(entries.size()); + for (Object entry : entries) { + if (entry == null) { + continue; + } + Map entryMap = (Map) entry; + Object rawKey = entryMap.get(MapVector.KEY_NAME); + if (rawKey == null) { + continue; + } + Object convertedKey = toPinotValue(keyField, rawKey, extractRawTimeValues); + if (convertedKey == null) { + continue; + } + Object rawValue = entryMap.get(MapVector.VALUE_NAME); + result.put(BaseRecordExtractor.stringifyMapKey(convertedKey), + rawValue != null ? 
toPinotValue(valueField, rawValue, extractRawTimeValues) : null); + } + return result; + } + + private static Map convertStruct(List childFields, Map value, + boolean extractRawTimeValues) { + Map result = Maps.newHashMapWithExpectedSize(childFields.size()); + for (Field childField : childFields) { + String name = childField.getName(); + Object rawValue = value.get(name); + result.put(name, rawValue != null ? toPinotValue(childField, rawValue, extractRawTimeValues) : null); + } + return result; + } + + /// Runtime-type dispatch used by the `Union` case (where the chosen branch isn't accessible + /// from the resolved value). Mirrors the scalar handling of [#toPinotValue] for the common + /// Arrow boxed types; nested complex types fall back to `value.toString()` because their child + /// [Field]s aren't reachable from here. + private static Object convertByRuntimeType(Object value) { + if (value instanceof Number) { + if (value instanceof Byte || value instanceof Short) { + return ((Number) value).intValue(); + } + return value; + } + if (value instanceof Boolean || value instanceof byte[]) { + return value; + } + if (value instanceof Character) { + // `UInt2Vector` surfaces as `Character`; widen to `int` per the Int(16) contract. + return (int) (Character) value; + } + if (value instanceof LocalDateTime) { + // Ambiguous between Timestamp / Date / Time — best-effort: treat as Timestamp UTC. + return Timestamp.from(((LocalDateTime) value).toInstant(ZoneOffset.UTC)); + } + // `Text` (Utf8 / LargeUtf8), `Period` / `Duration` / `PeriodDuration` (Interval / Duration), and + // anything unrecognized fall through to `toString()`. 
+ return value.toString(); + } +} diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReaderFactoryTest.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReaderFactoryTest.java new file mode 100644 index 000000000000..e1281a94427e --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnReaderFactoryTest.java @@ -0,0 +1,189 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.util.Arrays; +import java.util.Collections; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.pinot.spi.data.DimensionFieldSpec; +import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.apache.pinot.spi.data.readers.ColumnReader; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + + +/** + * Verifies {@link ArrowColumnReaderFactory} can consume a caller-managed {@link ArrowStreamReader} + * backed by a {@link ByteArrayInputStream}, without disk I/O, and that the caller-owned + * {@link RootAllocator} remains usable after the factory is closed. + */ +public class ArrowColumnReaderFactoryTest { + + private static final int ROW_COUNT = 32; + + @Test + public void testReadsValuesFromInMemoryReader() + throws Exception { + // Caller owns this allocator. Verify the factory does not close it. 
+ try (RootAllocator callerAllocator = new RootAllocator(Long.MAX_VALUE)) { + byte[] ipcBytes = writeArrowStreamFixture(callerAllocator); + + try (ArrowStreamReader streamReader = new ArrowStreamReader( + Channels.newChannel(new ByteArrayInputStream(ipcBytes)), callerAllocator); + ArrowColumnReaderFactory factory = + new ArrowColumnReaderFactory(streamReader, callerAllocator)) { + + factory.init(newSchema()); + + try (ColumnReader intCol = factory.getColumnReader("intCol"); + ColumnReader longCol = factory.getColumnReader("longCol"); + ColumnReader stringCol = factory.getColumnReader("stringCol")) { + + assertNotNull(intCol); + assertNotNull(longCol); + assertNotNull(stringCol); + assertEquals(intCol.getTotalDocs(), ROW_COUNT); + assertEquals(longCol.getTotalDocs(), ROW_COUNT); + assertEquals(stringCol.getTotalDocs(), ROW_COUNT); + + for (int i = 0; i < ROW_COUNT; i++) { + assertEquals(intCol.getInt(i), i); + assertEquals(longCol.getLong(i), (long) i * 1000); + assertEquals(stringCol.getString(i), "row_" + i); + } + } + } + + // Caller's allocator should still be usable after the factory closes — proves the factory + // did not close it. Allocate-and-release a small buffer to confirm. 
+ try (IntVector probe = new IntVector("probe", callerAllocator)) { + probe.allocateNew(1); + probe.set(0, 42); + probe.setValueCount(1); + assertEquals(probe.get(0), 42); + } + } + } + + @Test + public void testGetAllColumnReadersReflectsSourceSchema() + throws Exception { + try (RootAllocator callerAllocator = new RootAllocator(Long.MAX_VALUE)) { + byte[] ipcBytes = writeArrowStreamFixture(callerAllocator); + + try (ArrowStreamReader streamReader = new ArrowStreamReader( + Channels.newChannel(new ByteArrayInputStream(ipcBytes)), callerAllocator); + ArrowColumnReaderFactory factory = + new ArrowColumnReaderFactory(streamReader, callerAllocator)) { + + factory.init(newSchema()); + + assertEquals(factory.getAvailableColumns().size(), 3); + assertTrue(factory.getAvailableColumns().contains("intCol")); + assertTrue(factory.getAvailableColumns().contains("longCol")); + assertTrue(factory.getAvailableColumns().contains("stringCol")); + assertEquals(factory.getAllColumnReaders().size(), 3); + } + } + } + + @Test + public void testColumnSubsetFiltering() + throws Exception { + try (RootAllocator callerAllocator = new RootAllocator(Long.MAX_VALUE)) { + byte[] ipcBytes = writeArrowStreamFixture(callerAllocator); + + try (ArrowStreamReader streamReader = new ArrowStreamReader( + Channels.newChannel(new ByteArrayInputStream(ipcBytes)), callerAllocator); + ArrowColumnReaderFactory factory = + new ArrowColumnReaderFactory(streamReader, callerAllocator)) { + + factory.init(newSchema(), Collections.singleton("intCol")); + + // Only the requested column has a reader; the others fall back to null per the SPI. 
+ assertNotNull(factory.getColumnReader("intCol")); + assertEquals(factory.getAllColumnReaders().size(), 1); + assertFalse(factory.getAllColumnReaders().containsKey("longCol")); + assertFalse(factory.getAllColumnReaders().containsKey("stringCol")); + } + } + } + + private byte[] writeArrowStreamFixture(RootAllocator allocator) + throws IOException { + Field intField = new Field("intCol", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field longField = new Field("longCol", FieldType.nullable(new ArrowType.Int(64, true)), null); + Field stringField = new Field("stringCol", FieldType.nullable(new ArrowType.Utf8()), null); + Schema arrowSchema = new Schema(Arrays.asList(intField, longField, stringField)); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))) { + IntVector intVec = (IntVector) root.getVector("intCol"); + BigIntVector longVec = (BigIntVector) root.getVector("longCol"); + VarCharVector stringVec = (VarCharVector) root.getVector("stringCol"); + + intVec.allocateNew(ROW_COUNT); + longVec.allocateNew(ROW_COUNT); + stringVec.allocateNew(ROW_COUNT * 8, ROW_COUNT); + + for (int i = 0; i < ROW_COUNT; i++) { + intVec.set(i, i); + longVec.set(i, (long) i * 1000); + stringVec.set(i, ("row_" + i).getBytes()); + } + intVec.setValueCount(ROW_COUNT); + longVec.setValueCount(ROW_COUNT); + stringVec.setValueCount(ROW_COUNT); + root.setRowCount(ROW_COUNT); + + writer.start(); + writer.writeBatch(); + writer.end(); + } + return out.toByteArray(); + } + + private org.apache.pinot.spi.data.Schema newSchema() { + org.apache.pinot.spi.data.Schema schema = new org.apache.pinot.spi.data.Schema(); + schema.setSchemaName("inMemoryArrowTest"); + schema.addField(new DimensionFieldSpec("intCol", DataType.INT, true)); + schema.addField(new DimensionFieldSpec("longCol", DataType.LONG, true)); + 
schema.addField(new DimensionFieldSpec("stringCol", DataType.STRING, true)); + return schema; + } +} diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnarBuildIntegrationTest.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnarBuildIntegrationTest.java new file mode 100644 index 000000000000..0c27f201b679 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowColumnarBuildIntegrationTest.java @@ -0,0 +1,289 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader; +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl; +import org.apache.pinot.segment.spi.ColumnMetadata; +import org.apache.pinot.segment.spi.IndexSegment; +import org.apache.pinot.segment.spi.SegmentMetadata; +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.config.table.TableType; +import org.apache.pinot.spi.data.DimensionFieldSpec; +import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.apache.pinot.spi.utils.ReadMode; +import org.apache.pinot.spi.utils.builder.TableConfigBuilder; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + + +/** + * End-to-end tests that build a Pinot segment from 
the same Arrow data via row-major and + * column-major paths and assert the produced segments carry equivalent metadata. + * + *

Paths exercised: + *

+ * + *

The test asserts identical per-column cardinality, min, max, totalDocs, data type, and + * segment-level doc count across each compared pair. + */ +public class ArrowColumnarBuildIntegrationTest { + + private static final String TABLE_NAME = "arrowColumnarBuildTest"; + private static final int ROW_COUNT = 64; + + private Path _tempDir; + + @BeforeMethod + public void setUp() + throws IOException { + _tempDir = Files.createTempDirectory("arrow-columnar-integration"); + } + + @AfterMethod + public void tearDown() + throws IOException { + if (_tempDir != null) { + try (var stream = Files.walk(_tempDir)) { + stream.sorted((a, b) -> b.compareTo(a)).forEach(p -> { + try { + Files.deleteIfExists(p); + } catch (IOException ignored) { + // best effort + } + }); + } + } + } + + @Test + public void testSegmentMetadataEquivalenceRowMajorVsFileColumnar() + throws Exception { + File arrowFile = writeArrowFileFixture("equivalence-file.arrow"); + + File rowMajorSegmentDir = buildSegmentRowMajor(arrowFile, "rowMajor"); + File columnarSegmentDir = buildSegmentFileColumnar(arrowFile, "fileColumnar"); + + assertSegmentMetadataEquivalence(rowMajorSegmentDir, columnarSegmentDir); + } + + @Test + public void testSegmentMetadataEquivalenceRowMajorVsInMemoryColumnar() + throws Exception { + // Same data is materialised two ways. The row-major path needs a file because + // ArrowRecordReader.init only accepts File; the in-memory column-major path consumes the + // same data as Arrow IPC stream bytes via ArrowStreamReader — no file touched on that path. 
+ File arrowFile = writeArrowFileFixture("equivalence-inmem-source.arrow"); + byte[] streamBytes = writeArrowStreamBytes(); + + File rowMajorSegmentDir = buildSegmentRowMajor(arrowFile, "rowMajor"); + File columnarSegmentDir = buildSegmentInMemoryColumnar(streamBytes, "inMemColumnar"); + + assertSegmentMetadataEquivalence(rowMajorSegmentDir, columnarSegmentDir); + } + + private void assertSegmentMetadataEquivalence(File rowMajorSegmentDir, File columnarSegmentDir) + throws Exception { + IndexSegment rowMajorSegment = ImmutableSegmentLoader.load(rowMajorSegmentDir, ReadMode.heap); + IndexSegment columnarSegment = ImmutableSegmentLoader.load(columnarSegmentDir, ReadMode.heap); + try { + SegmentMetadata rowMajorMeta = rowMajorSegment.getSegmentMetadata(); + SegmentMetadata columnarMeta = columnarSegment.getSegmentMetadata(); + + assertEquals(columnarMeta.getTotalDocs(), rowMajorMeta.getTotalDocs(), + "Segment-level doc count must match between row-major and columnar builds"); + + // Compare only the user-defined columns — virtual columns ($segmentName, $docId, ...) + // are populated post-build and their min/max naturally differ across segments. 
+ for (String column : rowMajorMeta.getAllColumns()) { + if (column.startsWith("$")) { + continue; + } + ColumnMetadata rmCol = rowMajorMeta.getColumnMetadataFor(column); + ColumnMetadata colCol = columnarMeta.getColumnMetadataFor(column); + assertNotNull(rmCol, "Row-major missing column metadata for " + column); + assertNotNull(colCol, "Columnar missing column metadata for " + column); + assertEquals(colCol.getCardinality(), rmCol.getCardinality(), + "Cardinality mismatch on column " + column); + assertEquals(colCol.getMinValue(), rmCol.getMinValue(), + "Min value mismatch on column " + column); + assertEquals(colCol.getMaxValue(), rmCol.getMaxValue(), + "Max value mismatch on column " + column); + assertEquals(colCol.getTotalDocs(), rmCol.getTotalDocs(), + "Per-column doc count mismatch on column " + column); + assertEquals(colCol.getDataType(), rmCol.getDataType(), + "Data type mismatch on column " + column); + } + } finally { + rowMajorSegment.destroy(); + columnarSegment.destroy(); + } + } + + private File buildSegmentRowMajor(File arrowFile, String segmentName) + throws Exception { + File outputDir = _tempDir.resolve("rm-out-" + segmentName).toFile(); + SegmentGeneratorConfig config = newSegmentConfig(outputDir, segmentName); + try (ArrowRecordReader recordReader = new ArrowRecordReader()) { + recordReader.init(arrowFile, null, null); + SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl(); + driver.init(config, recordReader); + driver.build(); + } + return new File(outputDir, segmentName); + } + + private File buildSegmentFileColumnar(File arrowFile, String segmentName) + throws Exception { + File outputDir = _tempDir.resolve("col-file-out-" + segmentName).toFile(); + SegmentGeneratorConfig config = newSegmentConfig(outputDir, segmentName); + try (ArrowFileColumnReaderFactory factory = new ArrowFileColumnReaderFactory(arrowFile)) { + SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl(); + driver.init(config, 
factory); + driver.build(); + } + return new File(outputDir, segmentName); + } + + private File buildSegmentInMemoryColumnar(byte[] streamBytes, String segmentName) + throws Exception { + File outputDir = _tempDir.resolve("col-inmem-out-" + segmentName).toFile(); + SegmentGeneratorConfig config = newSegmentConfig(outputDir, segmentName); + // Caller manages the allocator and the ArrowStreamReader; the factory borrows them. + try (RootAllocator callerAllocator = new RootAllocator(Long.MAX_VALUE); + ArrowStreamReader streamReader = new ArrowStreamReader( + Channels.newChannel(new ByteArrayInputStream(streamBytes)), callerAllocator); + ArrowColumnReaderFactory factory = + new ArrowColumnReaderFactory(streamReader, callerAllocator)) { + SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl(); + driver.init(config, factory); + driver.build(); + } + return new File(outputDir, segmentName); + } + + private SegmentGeneratorConfig newSegmentConfig(File outputDir, String segmentName) { + org.apache.pinot.spi.data.Schema schema = new org.apache.pinot.spi.data.Schema(); + schema.setSchemaName(TABLE_NAME); + schema.addField(new DimensionFieldSpec("intCol", DataType.INT, true)); + schema.addField(new DimensionFieldSpec("longCol", DataType.LONG, true)); + schema.addField(new DimensionFieldSpec("stringCol", DataType.STRING, true)); + + TableConfig tableConfig = + new TableConfigBuilder(TableType.OFFLINE).setTableName(TABLE_NAME).build(); + + SegmentGeneratorConfig config = new SegmentGeneratorConfig(tableConfig, schema); + config.setOutDir(outputDir.getAbsolutePath()); + config.setSegmentName(segmentName); + return config; + } + + private File writeArrowFileFixture(String fileName) + throws IOException { + File out = _tempDir.resolve(fileName).toFile(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.create(fixtureArrowSchema(), allocator); + FileOutputStream fos = new FileOutputStream(out); + 
FileChannel channel = fos.getChannel(); + ArrowFileWriter writer = new ArrowFileWriter(root, null, channel)) { + populateFixtureVectors(root); + writer.start(); + writer.writeBatch(); + writer.end(); + } + return out; + } + + private byte[] writeArrowStreamBytes() + throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.create(fixtureArrowSchema(), allocator); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))) { + populateFixtureVectors(root); + writer.start(); + writer.writeBatch(); + writer.end(); + } + return out.toByteArray(); + } + + private Schema fixtureArrowSchema() { + Field intField = new Field("intCol", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field longField = new Field("longCol", FieldType.nullable(new ArrowType.Int(64, true)), null); + Field stringField = new Field("stringCol", FieldType.nullable(new ArrowType.Utf8()), null); + return new Schema(Arrays.asList(intField, longField, stringField)); + } + + private void populateFixtureVectors(VectorSchemaRoot root) { + IntVector intVec = (IntVector) root.getVector("intCol"); + BigIntVector longVec = (BigIntVector) root.getVector("longCol"); + VarCharVector stringVec = (VarCharVector) root.getVector("stringCol"); + + intVec.allocateNew(ROW_COUNT); + longVec.allocateNew(ROW_COUNT); + stringVec.allocateNew(ROW_COUNT * 8, ROW_COUNT); + + for (int i = 0; i < ROW_COUNT; i++) { + intVec.set(i, i); + longVec.set(i, (long) i * 1000); + stringVec.set(i, ("row_" + i).getBytes()); + } + intVec.setValueCount(ROW_COUNT); + longVec.setValueCount(ROW_COUNT); + stringVec.setValueCount(ROW_COUNT); + root.setRowCount(ROW_COUNT); + } +} diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowFileColumnReaderFactoryTest.java 
b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowFileColumnReaderFactoryTest.java new file mode 100644 index 000000000000..2858890553bb --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowFileColumnReaderFactoryTest.java @@ -0,0 +1,456 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.pinot.spi.data.DimensionFieldSpec; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.readers.ColumnReader; +import org.apache.pinot.spi.data.readers.MultiValueResult; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + + +/** + * Unit tests for {@link ArrowColumnReaderFactory} and {@link ArrowColumnReader}. + * + *

Each test materialises a small Arrow IPC file on disk with a known fixture and verifies + * that the factory can read it back column-by-column using all three documented patterns: + * generic sequential {@code next()}, typed sequential {@code nextInt() / nextLong() / ...}, + * and random-access {@code getInt(docId) / getLong(docId) / ...}. + */ +public class ArrowFileColumnReaderFactoryTest { + + private static final int ROWS = 8; + private Path _tempDir; + + @BeforeMethod + public void setUp() + throws IOException { + _tempDir = Files.createTempDirectory("arrow-column-reader-test"); + } + + @AfterMethod + public void tearDown() + throws IOException { + if (_tempDir != null) { + try (var stream = Files.walk(_tempDir)) { + stream.sorted((a, b) -> b.compareTo(a)).forEach(p -> { + try { + Files.deleteIfExists(p); + } catch (IOException ignored) { + // best effort + } + }); + } + } + } + + @Test + public void testSequentialReadAllPrimitiveTypes() + throws IOException { + File arrowFile = writePrimitiveFixture("primitive.arrow"); + org.apache.pinot.spi.data.Schema schema = primitiveSchema(); + + try (ArrowFileColumnReaderFactory factory = new ArrowFileColumnReaderFactory(arrowFile)) { + factory.init(schema); + + Map readers = factory.getAllColumnReaders(); + assertEquals(readers.size(), 6); + for (ColumnReader r : readers.values()) { + assertEquals(r.getTotalDocs(), ROWS); + assertTrue(r.isSingleValue()); + } + + ColumnReader intReader = readers.get("intCol"); + ColumnReader longReader = readers.get("longCol"); + ColumnReader floatReader = readers.get("floatCol"); + ColumnReader doubleReader = readers.get("doubleCol"); + ColumnReader stringReader = readers.get("stringCol"); + ColumnReader bytesReader = readers.get("bytesCol"); + + assertTrue(intReader.isInt()); + assertTrue(longReader.isLong()); + assertTrue(floatReader.isFloat()); + assertTrue(doubleReader.isDouble()); + assertTrue(stringReader.isString()); + assertTrue(bytesReader.isBytes()); + + // Pattern 2: typed 
sequential read with null handling + for (int i = 0; i < ROWS; i++) { + if (i == 3) { + assertTrue(intReader.isNextNull(), "Row 3 INT should be null"); + intReader.skipNext(); + assertTrue(longReader.isNextNull(), "Row 3 LONG should be null"); + longReader.skipNext(); + assertTrue(stringReader.isNextNull(), "Row 3 STRING should be null"); + stringReader.skipNext(); + floatReader.nextFloat(); + doubleReader.nextDouble(); + bytesReader.nextBytes(); + } else { + assertFalse(intReader.isNextNull()); + assertEquals(intReader.nextInt(), i * 10); + assertEquals(longReader.nextLong(), (long) i * 100); + assertEquals(floatReader.nextFloat(), i + 0.5f, 0.0001f); + assertEquals(doubleReader.nextDouble(), i + 0.25, 0.0001); + assertEquals(stringReader.nextString(), "s_" + i); + assertEquals(new String(bytesReader.nextBytes()), "b_" + i); + } + } + + // Verify rewind enables a second pass. + intReader.rewind(); + stringReader.rewind(); + assertTrue(intReader.hasNext()); + assertEquals(intReader.nextInt(), 0); + assertEquals(stringReader.nextString(), "s_0"); + } + } + + @Test + public void testRandomAccessByDocId() + throws IOException { + File arrowFile = writePrimitiveFixture("random.arrow"); + org.apache.pinot.spi.data.Schema schema = primitiveSchema(); + + try (ArrowFileColumnReaderFactory factory = new ArrowFileColumnReaderFactory(arrowFile)) { + factory.init(schema); + ColumnReader intReader = factory.getColumnReader("intCol"); + ColumnReader longReader = factory.getColumnReader("longCol"); + + // Pattern 3: read out of order. + assertEquals(intReader.getInt(5), 50); + assertEquals(intReader.getInt(0), 0); + assertEquals(intReader.getInt(7), 70); + assertTrue(intReader.isNull(3)); + assertEquals(longReader.getLong(2), 200L); + assertEquals(longReader.getValue(3), null); + + // Sequential read is unaffected by prior random-access reads on the same reader. 
+ assertTrue(intReader.hasNext()); + assertEquals(intReader.nextInt(), 0); + } + } + + @Test + public void testGenericNextHandlesNulls() + throws IOException { + File arrowFile = writePrimitiveFixture("generic.arrow"); + org.apache.pinot.spi.data.Schema schema = primitiveSchema(); + + try (ArrowFileColumnReaderFactory factory = new ArrowFileColumnReaderFactory(arrowFile)) { + factory.init(schema); + ColumnReader stringReader = factory.getColumnReader("stringCol"); + + int nullCount = 0; + int nonNullCount = 0; + while (stringReader.hasNext()) { + Object value = stringReader.next(); + if (value == null) { + nullCount++; + } else { + nonNullCount++; + } + } + assertEquals(nullCount, 1); + assertEquals(nonNullCount, ROWS - 1); + } + } + + @Test + public void testMultiValueIntColumn() + throws IOException { + File arrowFile = writeMultiValueIntFixture("mv-int.arrow"); + org.apache.pinot.spi.data.Schema schema = new org.apache.pinot.spi.data.Schema.SchemaBuilder() + .setSchemaName("mv") + .addMultiValueDimension("intArr", FieldSpec.DataType.INT) + .build(); + + try (ArrowFileColumnReaderFactory factory = new ArrowFileColumnReaderFactory(arrowFile)) { + factory.init(schema); + ColumnReader reader = factory.getColumnReader("intArr"); + assertNotNull(reader); + assertFalse(reader.isSingleValue()); + assertEquals(reader.getTotalDocs(), 3); + + MultiValueResult doc0 = reader.getIntMV(0); + assertEquals(doc0.getValues(), new int[]{1, 2, 3}); + assertFalse(doc0.hasNulls()); + + MultiValueResult doc1 = reader.getIntMV(1); + assertEquals(doc1.getValues().length, 2); + assertTrue(doc1.hasNulls()); + assertTrue(doc1.isNull(1)); + assertFalse(doc1.isNull(0)); + assertEquals(doc1.getValues()[0], 4); + + MultiValueResult doc2 = reader.getIntMV(2); + assertEquals(doc2.getValues().length, 0); + } + } + + @Test + public void testInitSubsetOfColumns() + throws IOException { + File arrowFile = writePrimitiveFixture("subset.arrow"); + org.apache.pinot.spi.data.Schema schema = 
primitiveSchema(); + + try (ArrowFileColumnReaderFactory factory = new ArrowFileColumnReaderFactory(arrowFile)) { + factory.init(schema, java.util.Set.of("intCol", "stringCol"), Collections.emptyMap()); + Map readers = factory.getAllColumnReaders(); + assertEquals(readers.size(), 2); + assertTrue(readers.containsKey("intCol")); + assertTrue(readers.containsKey("stringCol")); + assertNull(factory.getColumnReader("longCol")); + } + } + + @Test + public void testMultiBatchConcatenation() + throws IOException { + File arrowFile = writeMultiBatchFixture("multi-batch.arrow", 3, 4); + org.apache.pinot.spi.data.Schema schema = new org.apache.pinot.spi.data.Schema.SchemaBuilder() + .setSchemaName("multiBatch") + .addSingleValueDimension("seqCol", FieldSpec.DataType.INT) + .build(); + + try (ArrowFileColumnReaderFactory factory = new ArrowFileColumnReaderFactory(arrowFile)) { + factory.init(schema); + ColumnReader reader = factory.getColumnReader("seqCol"); + assertNotNull(reader); + assertEquals(reader.getTotalDocs(), 12, "Expected 3 batches of 4 rows = 12 docs total"); + + // Sequential traversal yields the global doc order. + for (int i = 0; i < 12; i++) { + assertEquals(reader.nextInt(), i); + } + assertFalse(reader.hasNext()); + + // Random access across batch boundaries. 
+ reader.rewind(); + assertEquals(reader.getInt(0), 0); + assertEquals(reader.getInt(3), 3); // last row of batch 0 + assertEquals(reader.getInt(4), 4); // first row of batch 1 + assertEquals(reader.getInt(7), 7); // last row of batch 1 + assertEquals(reader.getInt(8), 8); // first row of batch 2 + assertEquals(reader.getInt(11), 11); // last row overall + } + } + + @Test + public void testGetAvailableColumnsIncludesAllSourceColumns() + throws IOException { + File arrowFile = writePrimitiveFixture("available.arrow"); + org.apache.pinot.spi.data.Schema schema = primitiveSchema(); + + try (ArrowFileColumnReaderFactory factory = new ArrowFileColumnReaderFactory(arrowFile)) { + factory.init(schema, java.util.Set.of("intCol"), Collections.emptyMap()); + // getAvailableColumns reflects the source schema, not the requested subset. + assertTrue(factory.getAvailableColumns().containsAll( + Arrays.asList("intCol", "longCol", "floatCol", "doubleCol", "stringCol", "bytesCol"))); + } + } + + // ===== Fixture builders ===== + + private org.apache.pinot.spi.data.Schema primitiveSchema() { + org.apache.pinot.spi.data.Schema schema = new org.apache.pinot.spi.data.Schema(); + schema.setSchemaName("primitive"); + schema.addField(new DimensionFieldSpec("intCol", FieldSpec.DataType.INT, true)); + schema.addField(new DimensionFieldSpec("longCol", FieldSpec.DataType.LONG, true)); + schema.addField(new DimensionFieldSpec("floatCol", FieldSpec.DataType.FLOAT, true)); + schema.addField(new DimensionFieldSpec("doubleCol", FieldSpec.DataType.DOUBLE, true)); + schema.addField(new DimensionFieldSpec("stringCol", FieldSpec.DataType.STRING, true)); + schema.addField(new DimensionFieldSpec("bytesCol", FieldSpec.DataType.BYTES, true)); + return schema; + } + + private File writePrimitiveFixture(String fileName) + throws IOException { + File out = _tempDir.resolve(fileName).toFile(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Field intField = new Field("intCol", 
FieldType.nullable(new ArrowType.Int(32, true)), null); + Field longField = new Field("longCol", FieldType.nullable(new ArrowType.Int(64, true)), null); + Field floatField = + new Field("floatCol", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null); + Field doubleField = + new Field("doubleCol", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null); + Field stringField = new Field("stringCol", FieldType.nullable(new ArrowType.Utf8()), null); + Field bytesField = new Field("bytesCol", FieldType.nullable(new ArrowType.Binary()), null); + Schema schema = + new Schema(Arrays.asList(intField, longField, floatField, doubleField, stringField, bytesField)); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + IntVector intVec = (IntVector) root.getVector("intCol"); + BigIntVector longVec = (BigIntVector) root.getVector("longCol"); + Float4Vector floatVec = (Float4Vector) root.getVector("floatCol"); + Float8Vector doubleVec = (Float8Vector) root.getVector("doubleCol"); + VarCharVector stringVec = (VarCharVector) root.getVector("stringCol"); + VarBinaryVector bytesVec = (VarBinaryVector) root.getVector("bytesCol"); + + intVec.allocateNew(ROWS); + longVec.allocateNew(ROWS); + floatVec.allocateNew(ROWS); + doubleVec.allocateNew(ROWS); + stringVec.allocateNew(ROWS * 8, ROWS); + bytesVec.allocateNew(ROWS * 8, ROWS); + + for (int i = 0; i < ROWS; i++) { + if (i == 3) { + intVec.setNull(i); + longVec.setNull(i); + stringVec.setNull(i); + } else { + intVec.set(i, i * 10); + longVec.set(i, (long) i * 100); + stringVec.set(i, ("s_" + i).getBytes()); + } + floatVec.set(i, i + 0.5f); + doubleVec.set(i, i + 0.25); + bytesVec.set(i, ("b_" + i).getBytes()); + } + intVec.setValueCount(ROWS); + longVec.setValueCount(ROWS); + floatVec.setValueCount(ROWS); + doubleVec.setValueCount(ROWS); + stringVec.setValueCount(ROWS); + bytesVec.setValueCount(ROWS); + root.setRowCount(ROWS); + + writeIpc(out, 
root); + } + } + return out; + } + + private File writeMultiBatchFixture(String fileName, int batchCount, int rowsPerBatch) + throws IOException { + File out = _tempDir.resolve(fileName).toFile(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Field seqField = new Field("seqCol", FieldType.nullable(new ArrowType.Int(32, true)), null); + Schema schema = new Schema(Collections.singletonList(seqField)); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + FileOutputStream fos = new FileOutputStream(out); + FileChannel channel = fos.getChannel(); + ArrowFileWriter writer = new ArrowFileWriter(root, null, channel)) { + writer.start(); + IntVector seqVec = (IntVector) root.getVector("seqCol"); + + int globalSeq = 0; + for (int b = 0; b < batchCount; b++) { + seqVec.allocateNew(rowsPerBatch); + for (int i = 0; i < rowsPerBatch; i++) { + seqVec.set(i, globalSeq++); + } + seqVec.setValueCount(rowsPerBatch); + root.setRowCount(rowsPerBatch); + writer.writeBatch(); + } + writer.end(); + } + } + return out; + } + + private File writeMultiValueIntFixture(String fileName) + throws IOException { + File out = _tempDir.resolve(fileName).toFile(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Field elementField = + new Field("element", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field listField = new Field("intArr", FieldType.nullable(ArrowType.List.INSTANCE), + Collections.singletonList(elementField)); + Schema schema = new Schema(Collections.singletonList(listField)); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + ListVector listVec = (ListVector) root.getVector("intArr"); + listVec.allocateNew(); + UnionListWriter writer = listVec.getWriter(); + + writer.startList(); + writer.setPosition(0); + writer.startList(); + writer.integer().writeInt(1); + writer.integer().writeInt(2); + writer.integer().writeInt(3); + writer.endList(); + + writer.setPosition(1); + 
writer.startList(); + writer.integer().writeInt(4); + writer.writeNull(); + writer.endList(); + + writer.setPosition(2); + writer.startList(); + writer.endList(); + + listVec.setValueCount(3); + root.setRowCount(3); + + writeIpc(out, root); + } + } + return out; + } + + private void writeIpc(File out, VectorSchemaRoot root) + throws IOException { + try (FileOutputStream fos = new FileOutputStream(out); + FileChannel channel = fos.getChannel(); + ArrowFileWriter writer = new ArrowFileWriter(root, null, channel)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + } +}