diff --git a/java/build.gradle.kts b/java/build.gradle.kts
index bb73c481776..94e9f2fec96 100644
--- a/java/build.gradle.kts
+++ b/java/build.gradle.kts
@@ -81,5 +81,10 @@ allprojects {
         }
     }
 
-    tasks.register("format").get().dependsOn("spotlessApply")
+    if (project.name == "vortex-spark_2.12") {
+        // vortex-spark_2.12 and vortex-spark_2.13 share a projectDir; format from the 2.13 variant only.
+        tasks.register("format") { enabled = false }
+    } else {
+        tasks.register("format").get().dependsOn("spotlessApply")
+    }
 }
diff --git a/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java b/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java
index e4d8978e112..950f75a8789 100644
--- a/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java
+++ b/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java
@@ -73,6 +73,10 @@ public static Expression isNull(Expression child) {
         return new Expression(NativeExpression.isNull(child.nativePointer()));
     }
 
+    public static Expression isNotNull(Expression child) {
+        return new Expression(NativeExpression.isNotNull(child.nativePointer()));
+    }
+
     public static Expression literal(boolean value) {
         return new Expression(NativeExpression.literalBool(value, false));
     }
diff --git a/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java b/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java
index fc5bedbe7b5..e35eb1d0215 100644
--- a/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java
+++ b/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java
@@ -27,6 +27,8 @@ private NativeExpression() {}
 
     public static native long isNull(long childPointer);
 
+    public static native long isNotNull(long childPointer);
+
     public static native long literalBool(boolean value, boolean isNull);
 
     public static native long literalI8(byte value, boolean isNull);
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java
index 1d564b8c0e0..b3d7d637504 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java
@@ -20,6 +20,7 @@
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.catalog.Table;
 import org.apache.spark.sql.connector.catalog.TableProvider;
+import org.apache.spark.sql.connector.expressions.Expressions;
 import org.apache.spark.sql.connector.expressions.Transform;
 import org.apache.spark.sql.sources.DataSourceRegister;
 import org.apache.spark.sql.types.DataType;
@@ -118,6 +119,38 @@ public StructType inferSchema(CaseInsensitiveStringMap options) {
         return dataSchema;
     }
 
+    /**
+     * Infers partition transforms by inspecting Hive-style {@code key=value} segments in the first listed file path.
+     *
+     * <p>Spark calls this before {@link #getTable(StructType, Transform[], Map)} when the caller did not provide
+     * explicit partitioning. Returning identity transforms here lets downstream components (notably
+     * {@link dev.vortex.spark.read.VortexScanBuilder}) tell which schema columns are encoded in the directory layout
+     * rather than stored inside the Vortex files, which matters for predicate pushdown.
+     */
+    @Override
+    public Transform[] inferPartitioning(CaseInsensitiveStringMap options) {
+        var paths = getPaths(options);
+        if (paths.isEmpty()) {
+            return new Transform[0];
+        }
+        var formatOptions = buildDataSourceOptions(options.asCaseSensitiveMap());
+        String pathToInfer = Objects.requireNonNull(Iterables.getLast(paths));
+        if (!pathToInfer.endsWith(".vortex")) {
+            Optional<String> firstFile =
+                    NativeFiles.listFiles(VortexSparkSession.get(formatOptions), pathToInfer, formatOptions).stream()
+                            .findFirst();
+            if (firstFile.isEmpty()) {
+                return new Transform[0];
+            }
+            pathToInfer = firstFile.get();
+        }
+        Map<String, String> partitionValues = PartitionPathUtils.parsePartitionValues(pathToInfer);
+        if (partitionValues.isEmpty()) {
+            return new Transform[0];
+        }
+        return partitionValues.keySet().stream().map(Expressions::identity).toArray(Transform[]::new);
+    }
+
     /**
      * Creates a Vortex table instance with the given schema and properties.
      *
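For context on what the inference above keys off: a partitioned layout stores partition values in the directory names, e.g. `.../group=A/date=2024-01-01/part-0.vortex`. `PartitionPathUtils` is referenced but not included in this diff, so the snippet below is only an illustrative sketch of the kind of `key=value` parsing it is assumed to perform; the helper class and method names are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class PartitionPathSketch {
    /** Illustrative only: pull key=value directory segments out of a file path. */
    static Map<String, String> parse(String path) {
        Map<String, String> values = new LinkedHashMap<>();
        for (String segment : path.split("/")) {
            int eq = segment.indexOf('=');
            if (eq > 0 && !segment.endsWith(".vortex")) {
                values.put(segment.substring(0, eq), segment.substring(eq + 1));
            }
        }
        return values; // e.g. {group=A, date=2024-01-01}
    }
}
```

Returning each discovered key as an identity transform is what later lets `VortexScanBuilder` treat those columns as partition columns rather than data columns.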
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java
index 92cc55ff211..f65f74ccf19 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java
@@ -58,7 +58,7 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
         Map<String, String> opts = Maps.newHashMap();
         opts.putAll(formatOptions);
         opts.putAll(options);
-        return new VortexScanBuilder(opts)
+        return new VortexScanBuilder(opts, partitionTransforms)
                 .addAllPaths(paths)
                 .addAllColumns(Arrays.asList(CatalogV2Util.structTypeToV2Columns(schema)));
     }
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java
deleted file mode 100644
index e5cd96a3958..00000000000
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java
+++ /dev/null
@@ -1,94 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright the Vortex contributors
-
-package dev.vortex.spark.read;
-
-import java.util.Iterator;
-import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.ToLongFunction;
-
-final class PrefetchingIterator<T> implements Iterator<T>, AutoCloseable {
-    // Global condition variable shared between the prefetcher and consumer threads,
-    // to coordinate wake ups for when the buffer may no longer be full.
-    private static final Object CONDITION = new Object();
-
-    private final BlockingQueue<T> fetched = new LinkedBlockingQueue<>();
-    private final Thread producerThread;
-    private final Iterator<T> delegate;
-    private final AtomicBoolean closed = new AtomicBoolean(false);
-    private final AtomicLong bufferBytes = new AtomicLong(0);
-    private final long maxBufferSize;
-    private final ToLongFunction<T> sizeFunc;
-
-    PrefetchingIterator(Iterator<T> delegate, long maxBufferSize, ToLongFunction<T> sizeFunc) {
-        this.delegate = delegate;
-        this.maxBufferSize = maxBufferSize;
-        this.sizeFunc = sizeFunc;
-        this.producerThread = new Thread(this::prefetchLoop, "vortex-prefetch-thread");
-        producerThread.setDaemon(true);
-        producerThread.start();
-    }
-
-    private void prefetchLoop() {
-        try {
-            while (!closed.get() && delegate.hasNext()) {
-                while (bufferBytes.get() > maxBufferSize) {
-                    synchronized (CONDITION) {
-                        CONDITION.wait();
-                    }
-                }
-                T nextElem = delegate.next();
-                long elemSize = sizeFunc.applyAsLong(nextElem);
-                bufferBytes.addAndGet(elemSize);
-                fetched.put(nextElem);
-            }
-        } catch (InterruptedException e) {
-            Thread.currentThread().interrupt();
-            throw new RuntimeException("Prefetching interrupted", e);
-        } catch (Exception e) {
-            throw new RuntimeException("Prefetching failed", e);
-        } finally {
-            closed.set(true);
-        }
-    }
-
-    @Override
-    public boolean hasNext() {
-        // If the prefetcher is not finished, then we could be waiting for
-        // a fetched item.
-        while (!closed.get()) {
-            if (!fetched.isEmpty()) {
-                return true;
-            }
-        }
-        // If the prefetcher is finished, then we can examine fetched and immediately return a result.
-        return !fetched.isEmpty();
-    }
-
-    @Override
-    public T next() {
-        // We assume that this has been called after hasNext() returned true, so it is
-        // safe to call take() without checking if the queue maybe be empty.
-        try {
-            T nextElem = this.fetched.take();
-            long elemSize = sizeFunc.applyAsLong(nextElem);
-            bufferBytes.addAndGet(-elemSize);
-            // Notify the producer that it may now be able to add more items to the queue.
-            synchronized (CONDITION) {
-                CONDITION.notify();
-            }
-            return nextElem;
-        } catch (InterruptedException e) {
-            throw new RuntimeException("Prefetch queue take interrupted", e);
-        }
-    }
-
-    @Override
-    public void close() {
-        closed.set(true);
-        producerThread.interrupt();
-    }
-}
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java
new file mode 100644
index 00000000000..19c59d5d9b0
--- /dev/null
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java
@@ -0,0 +1,235 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+package dev.vortex.spark.read;
+
+import dev.vortex.api.Expression;
+import dev.vortex.api.Expression.BinaryOp;
+import java.util.Optional;
+import java.util.Set;
+import org.apache.spark.sql.connector.expressions.LiteralValue;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.And;
+import org.apache.spark.sql.connector.expressions.filter.Not;
+import org.apache.spark.sql.connector.expressions.filter.Or;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/** Translates {@link Predicate Spark V2 predicates} into Vortex {@link Expression}s for predicate pushdown. */
+final class SparkPredicateToVortexExpression {
+
+    private SparkPredicateToVortexExpression() {}
+
+    /**
+     * Returns true if the given Spark predicate is structurally convertible to a Vortex expression and references only
+     * the supplied {@code dataColumns}.
+     *
+     * <p>This is the cheap check used in {@code SupportsPushDownV2Filters.pushPredicates} to decide which predicates
+     * Spark can drop. It does not allocate any native expressions.
+     */
+    static boolean isPushable(Predicate predicate, Set<String> dataColumns) {
+        for (NamedReference ref : predicate.references()) {
+            String[] parts = ref.fieldNames();
+            if (parts.length != 1) {
+                return false;
+            }
+            if (!dataColumns.contains(parts[0])) {
+                return false;
+            }
+        }
+        return isStructurallyPushable(predicate);
+    }
+
+    private static boolean isStructurallyPushable(Predicate predicate) {
+        if (predicate instanceof And a) {
+            return isStructurallyPushable(a.left()) && isStructurallyPushable(a.right());
+        }
+        if (predicate instanceof Or o) {
+            return isStructurallyPushable(o.left()) && isStructurallyPushable(o.right());
+        }
+        if (predicate instanceof Not) {
+            return isStructurallyPushable(((Not) predicate).child());
+        }
+        org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
+        switch (predicate.name()) {
+            case "=":
+            case ">":
+            case ">=":
+            case "<":
+            case "<=":
+                return children.length == 2 && isTopLevelFieldRef(children[0]) && isPushableLiteral(children[1]);
+            case "IS_NULL":
+            case "IS_NOT_NULL":
+                return children.length == 1 && isTopLevelFieldRef(children[0]);
+            case "IN":
+                if (children.length < 2 || !isTopLevelFieldRef(children[0])) {
+                    return false;
+                }
+                for (int i = 1; i < children.length; i++) {
+                    if (!isPushableLiteral(children[i])) {
+                        return false;
+                    }
+                }
+                return true;
+            default:
+                return false;
+        }
+    }
+
+    /**
+     * Converts a Spark predicate to a Vortex expression. Returns {@link Optional#empty()} if the predicate is not
+     * pushable; callers should normally pre-check with {@link #isPushable}.
+     */
+    static Optional<Expression> convert(Predicate predicate) {
+        if (predicate instanceof And a) {
+            Optional<Expression> left = convert(a.left());
+            Optional<Expression> right = convert(a.right());
+            if (left.isPresent() && right.isPresent()) {
+                return Optional.of(Expression.and(left.get(), right.get()));
+            }
+            return Optional.empty();
+        }
+        if (predicate instanceof Or o) {
+            Optional<Expression> left = convert(o.left());
+            Optional<Expression> right = convert(o.right());
+            if (left.isPresent() && right.isPresent()) {
+                return Optional.of(Expression.or(left.get(), right.get()));
+            }
+            return Optional.empty();
+        }
+        if (predicate instanceof Not) {
+            return convert(((Not) predicate).child()).map(Expression::not);
+        }
+        org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
+        return switch (predicate.name()) {
+            case "=", ">", ">=", "<", "<=" -> convertBinary(predicate.name(), children);
+            case "IS_NULL" -> {
+                if (children.length != 1) {
+                    yield Optional.empty();
+                }
+                yield columnNameOf(children[0]).map(name -> Expression.isNull(Expression.column(name)));
+            }
+            case "IS_NOT_NULL" -> {
+                if (children.length != 1) {
+                    yield Optional.empty();
+                }
+                yield columnNameOf(children[0]).map(name -> Expression.isNotNull(Expression.column(name)));
+            }
+            case "IN" -> convertIn(children);
+            default -> Optional.empty();
+        };
+    }
+
+    private static Optional<Expression> convertBinary(
+            String op, org.apache.spark.sql.connector.expressions.Expression[] children) {
+        if (children.length != 2) {
+            return Optional.empty();
+        }
+        Optional<String> column = columnNameOf(children[0]);
+        Optional<Expression> literal = literalOf(children[1]);
+        if (column.isEmpty() || literal.isEmpty()) {
+            return Optional.empty();
+        }
+        return Optional.of(Expression.binary(toBinaryOp(op), Expression.column(column.get()), literal.get()));
+    }
+
+    private static Optional<Expression> convertIn(
+            org.apache.spark.sql.connector.expressions.Expression[] children) {
+        if (children.length < 2) {
+            return Optional.empty();
+        }
+        Optional<String> column = columnNameOf(children[0]);
+        if (column.isEmpty()) {
+            return Optional.empty();
+        }
+        Expression columnExpr = Expression.column(column.get());
+        Expression[] eqs = new Expression[children.length - 1];
+        for (int i = 1; i < children.length; i++) {
+            Optional<Expression> literal = literalOf(children[i]);
+            if (literal.isEmpty()) {
+                return Optional.empty();
+            }
+            eqs[i - 1] = Expression.binary(BinaryOp.EQ, columnExpr, literal.get());
+        }
+        if (eqs.length == 1) {
+            return Optional.of(eqs[0]);
+        }
+        return Optional.of(Expression.or(eqs));
+    }
+
+    private static BinaryOp toBinaryOp(String name) {
+        return switch (name) {
+            case "=" -> BinaryOp.EQ;
+            case ">" -> BinaryOp.GT;
+            case ">=" -> BinaryOp.GTE;
+            case "<" -> BinaryOp.LT;
+            case "<=" -> BinaryOp.LTE;
+            default -> throw new IllegalArgumentException("not a pushable binary operator: " + name);
+        };
+    }
+
+    private static boolean isTopLevelFieldRef(org.apache.spark.sql.connector.expressions.Expression expr) {
+        return expr instanceof NamedReference && ((NamedReference) expr).fieldNames().length == 1;
+    }
+
+    private static Optional<String> columnNameOf(org.apache.spark.sql.connector.expressions.Expression expr) {
+        if (!(expr instanceof NamedReference)) {
+            return Optional.empty();
+        }
+        String[] parts = ((NamedReference) expr).fieldNames();
+        if (parts.length != 1) {
+            return Optional.empty();
+        }
+        return Optional.of(parts[0]);
+    }
+
+    private static boolean isPushableLiteral(org.apache.spark.sql.connector.expressions.Expression expr) {
+        if (!(expr instanceof LiteralValue)) {
+            return false;
+        }
+        Object v = ((LiteralValue) expr).value();
+        return v instanceof Boolean
+                || v instanceof Byte
+                || v instanceof Short
+                || v instanceof Integer
+                || v instanceof Long
+                || v instanceof Float
+                || v instanceof Double
+                || v instanceof UTF8String
+                || v instanceof CharSequence;
+    }
+
+    private static Optional<Expression> literalOf(org.apache.spark.sql.connector.expressions.Expression expr) {
+        if (!(expr instanceof LiteralValue)) {
+            return Optional.empty();
+        }
+        Object value = ((LiteralValue) expr).value();
+        if (value == null) {
+            return Optional.empty();
+        }
+        if (value instanceof Boolean) {
+            return Optional.of(Expression.literal((Boolean) value));
+        }
+        if (value instanceof Byte) {
+            return Optional.of(Expression.literal((Byte) value));
+        }
+        if (value instanceof Short) {
+            return Optional.of(Expression.literal((Short) value));
+        }
+        if (value instanceof Integer) {
+            return Optional.of(Expression.literal((Integer) value));
+        }
+        if (value instanceof Long) {
+            return Optional.of(Expression.literal((Long) value));
+        }
+        if (value instanceof Float) {
+            return Optional.of(Expression.literal((Float) value));
+        }
+        if (value instanceof Double) {
+            return Optional.of(Expression.literal((Double) value));
+        }
+        if (value instanceof UTF8String || value instanceof CharSequence) {
+            return Optional.of(Expression.literal(value.toString()));
+        }
+        return Optional.empty();
+    }
+}
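As a worked example of the translation above: an IN predicate over literals is expanded into equality comparisons joined by OR. The sketch below uses only the `dev.vortex.api.Expression` factory methods that appear in this diff (`column`, `literal`, `binary`, `or`); it is illustrative, not part of the change.

```java
import dev.vortex.api.Expression;
import dev.vortex.api.Expression.BinaryOp;

final class InPredicateSketch {
    // Roughly the expansion convert() is expected to produce for: name IN ('alpha', 'gamma')
    static Expression nameInAlphaOrGamma() {
        Expression name = Expression.column("name");
        return Expression.or(
                Expression.binary(BinaryOp.EQ, name, Expression.literal("alpha")),
                Expression.binary(BinaryOp.EQ, name, Expression.literal("gamma")));
    }
}
```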
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java
index e32d26f5643..8df7ce8e1db 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java
@@ -16,6 +16,7 @@
 import java.util.stream.Stream;
 import org.apache.spark.sql.connector.catalog.CatalogV2Util;
 import org.apache.spark.sql.connector.catalog.Column;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.Batch;
 import org.apache.spark.sql.connector.read.InputPartition;
 import org.apache.spark.sql.connector.read.PartitionReaderFactory;
@@ -26,6 +27,7 @@ public final class VortexBatchExec implements Batch {
     private final List<String> paths;
     private final StructType readSchema;
     private final Map<String, String> formatOptions;
+    private final Predicate[] pushedPredicates;
     private List<String> resolvedPaths;
 
     /**
@@ -33,11 +35,15 @@ public final class VortexBatchExec implements Batch {
      *
      * @param paths the list of file paths to scan
      * @param columns the list of columns to read from the files
+     * @param pushedPredicates predicates pushed down by Spark; converted to a single Vortex filter expression at read
+     *     time
      */
-    public VortexBatchExec(List<String> paths, List<Column> columns, Map<String, String> formatOptions) {
+    public VortexBatchExec(
+            List<String> paths, List<Column> columns, Map<String, String> formatOptions, Predicate[] pushedPredicates) {
         this.paths = List.copyOf(paths);
         this.readSchema = CatalogV2Util.v2ColumnsToStructType(columns.toArray(new Column[0]));
         this.formatOptions = Map.copyOf(formatOptions);
+        this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone();
     }
 
     /**
@@ -66,7 +72,7 @@ public PartitionReaderFactory createReaderFactory() {
         List<String> dataColumnNames = Arrays.stream(readSchema.fieldNames())
                 .filter(name -> !partitionColumns.contains(name))
                 .collect(Collectors.toList());
-        return new VortexPartitionReaderFactory(dataColumnNames, formatOptions);
+        return new VortexPartitionReaderFactory(dataColumnNames, formatOptions, pushedPredicates);
     }
 
     private List<String> resolvePaths() {
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java
index 9df44c07e6c..f9cd2363d59 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java
@@ -18,6 +18,8 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.PartitionReader;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.vectorized.ColumnVector;
@@ -46,7 +48,11 @@ final class VortexPartitionReader implements PartitionReader<ColumnarBatch> {
     private boolean currentBatchLoaded;
     private boolean exhausted;
 
-    VortexPartitionReader(VortexFilePartition spark, List<String> dataColumnNames, Map<String, String> formatOptions) {
+    VortexPartitionReader(
+            VortexFilePartition spark,
+            List<String> dataColumnNames,
+            Map<String, String> formatOptions,
+            Predicate[] pushedPredicates) {
         this.spark = spark;
 
         this.allocator = ArrowAllocation.rootAllocator();
@@ -58,9 +64,24 @@ final class VortexPartitionReader implements PartitionReader<ColumnarBatch> {
             Expression projection = Expression.select(dataColumnNames.toArray(new String[0]), Expression.root());
             options.projection(projection);
         }
+        if (pushedPredicates != null && pushedPredicates.length > 0) {
+            buildFilterExpression(pushedPredicates).ifPresent(options::filter);
+        }
         scan = dataSource.scan(options.build());
     }
 
+    private static Optional<Expression> buildFilterExpression(Predicate[] predicates) {
+        Expression combined = null;
+        for (Predicate predicate : predicates) {
+            Optional<Expression> expr = SparkPredicateToVortexExpression.convert(predicate);
+            if (expr.isEmpty()) {
+                continue;
+            }
+            combined = combined == null ? expr.get() : Expression.and(combined, expr.get());
+        }
+        return Optional.ofNullable(combined);
+    }
+
     @Override
     public boolean next() {
         if (exhausted) {
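The helper above silently skips any pushed predicate it cannot convert and ANDs together the rest into one filter for the native scan. Assuming the same `Expression` factory methods used elsewhere in this diff, the combined filter for the two pushed predicates `flag = true` and `amount < 40` would look roughly like this sketch (illustrative only):

```java
import dev.vortex.api.Expression;
import dev.vortex.api.Expression.BinaryOp;

final class CombinedFilterSketch {
    // Roughly what buildFilterExpression() yields for: flag = true, amount < 40
    static Expression flagAndAmount() {
        return Expression.and(
                Expression.binary(BinaryOp.EQ, Expression.column("flag"), Expression.literal(true)),
                Expression.binary(BinaryOp.LT, Expression.column("amount"), Expression.literal(40L)));
    }
}
```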
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java
index 9ffbfcc3cfb..e187e4863b1 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java
@@ -5,10 +5,13 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import dev.vortex.jni.NativeRuntime;
 import dev.vortex.spark.VortexFilePartition;
 import java.io.Serializable;
 import java.util.List;
+import java.util.Map;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.InputPartition;
 import org.apache.spark.sql.connector.read.PartitionReader;
 import org.apache.spark.sql.connector.read.PartitionReaderFactory;
@@ -27,10 +30,13 @@ public final class VortexPartitionReaderFactory implements PartitionReaderFactory {
 
     private final ImmutableList<String> dataColumnNames;
     private final ImmutableMap<String, String> formatOptions;
+    private final Predicate[] pushedPredicates;
 
-    public VortexPartitionReaderFactory(List<String> dataColumnNames, java.util.Map<String, String> formatOptions) {
+    public VortexPartitionReaderFactory(
+            List<String> dataColumnNames, Map<String, String> formatOptions, Predicate[] pushedPredicates) {
         this.dataColumnNames = ImmutableList.copyOf(dataColumnNames);
         this.formatOptions = ImmutableMap.copyOf(formatOptions);
+        this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone();
     }
 
     @Override
@@ -40,8 +46,9 @@ public PartitionReader<InternalRow> createReader(InputPartition partition) {
 
     @Override
     public PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) {
+        NativeRuntime.setWorkerThreads(Integer.parseInt(formatOptions.getOrDefault("vortex.workerThreads", "4")));
         VortexFilePartition spark = (VortexFilePartition) partition;
-        return new VortexPartitionReader(spark, dataColumnNames, formatOptions);
+        return new VortexPartitionReader(spark, dataColumnNames, formatOptions, pushedPredicates);
     }
 
     @Override
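The worker-thread count read here comes from the scan's format options; `VortexScanBuilder` (below) seeds the default of `4`, and callers can override it per read. A hedged usage sketch, assuming reader options are merged into the scan's format options the way `VortexTable.newScanBuilder` does; the path is hypothetical:

```java
// Illustrative only: overriding the native worker-thread default for one read.
Dataset<Row> df = spark.read()
        .format("vortex")
        .option("path", "s3://bucket/events")   // hypothetical location
        .option("vortex.workerThreads", "8")    // consumed via NativeRuntime.setWorkerThreads(...)
        .load();
```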
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java
index c6ec03eef80..d5949b57a4d 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java
@@ -3,10 +3,12 @@
 
 package dev.vortex.spark.read;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import org.apache.spark.sql.connector.catalog.CatalogV2Util;
 import org.apache.spark.sql.connector.catalog.Column;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.Batch;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.types.StructType;
@@ -17,6 +19,7 @@ public final class VortexScan implements Scan {
     private final List<String> paths;
     private final List<Column> readColumns;
     private final Map<String, String> formatOptions;
+    private final Predicate[] pushedPredicates;
 
     /**
      * Creates a new VortexScan for the specified file paths and columns. The caller is responsible for passing
@@ -24,11 +27,17 @@ public final class VortexScan implements Scan {
      *
      * @param paths the list of Vortex file paths to scan
      * @param readColumns the list of columns to read from the files
+     * @param pushedPredicates predicates pushed down by Spark; {@code null} or empty means no pushdown
      */
-    public VortexScan(List<String> paths, List<Column> readColumns, Map<String, String> formatOptions) {
+    public VortexScan(
+            List<String> paths,
+            List<Column> readColumns,
+            Map<String, String> formatOptions,
+            Predicate[] pushedPredicates) {
         this.paths = paths;
         this.readColumns = readColumns;
         this.formatOptions = formatOptions;
+        this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone();
     }
 
     /**
@@ -46,7 +55,9 @@ public StructType readSchema() {
     /** Logging-friendly readable description of the scan source. */
     @Override
     public String description() {
-        return String.format("VortexScan{paths=%s, columns=%s}", paths, readColumns);
+        return String.format(
+                "VortexScan{paths=%s, columns=%s, pushedPredicates=%s}",
+                paths, readColumns, Arrays.toString(pushedPredicates));
     }
 
     /**
@@ -58,7 +69,7 @@ public String description() {
      */
     @Override
     public Batch toBatch() {
-        return new VortexBatchExec(paths, readColumns, formatOptions);
+        return new VortexBatchExec(paths, readColumns, formatOptions, pushedPredicates);
     }
 
     /**
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java
index a53472bc33b..849c151699e 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java
@@ -6,27 +6,52 @@ import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import org.apache.spark.sql.connector.catalog.Column;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Transform;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.connector.read.ScanBuilder;
 import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
+import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 /** Spark V2 {@link ScanBuilder} for table scans over Vortex files. */
-public final class VortexScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns {
+public final class VortexScanBuilder
+        implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters {
 
     private final ImmutableList.Builder<String> paths;
     private final List<Column> columns;
     private final Map<String, String> formatOptions;
+    private final Set<String> partitionColumnNames;
+    private Predicate[] pushedPredicates = new Predicate[0];
 
     /** Creates a new VortexScanBuilder with empty paths and columns. */
     public VortexScanBuilder(Map<String, String> formatOptions) {
+        this(formatOptions, new Transform[0]);
+    }
+
+    /**
+     * Creates a new VortexScanBuilder with empty paths and columns and the supplied partition transforms. Filters that
+     * reference partition columns are not pushed down, since the partition columns are not stored inside the Vortex
+     * files.
+     */
+    public VortexScanBuilder(Map<String, String> formatOptions, Transform[] partitionTransforms) {
         this.paths = ImmutableList.builder();
         this.columns = new ArrayList<>();
-        this.formatOptions = Map.copyOf(formatOptions);
+        Map<String, String> options = Maps.newHashMap();
+        options.put("vortex.workerThreads", "4");
+        options.putAll(formatOptions);
+        this.formatOptions = options;
+        this.partitionColumnNames = collectPartitionColumnNames(partitionTransforms);
     }
 
     /**
@@ -89,7 +114,7 @@ public Scan build() {
 
         // Allow empty columns for operations like count() that don't need actual column data
         // If no columns are specified, we'll read the minimal schema needed
-        return new VortexScan(paths, List.copyOf(this.columns), this.formatOptions);
+        return new VortexScan(paths, List.copyOf(this.columns), this.formatOptions, pushedPredicates);
     }
 
     /**
@@ -108,4 +133,56 @@ public void pruneColumns(StructType requiredSchema) {
             columns.add(Column.create(field.name(), field.dataType()));
         }
     }
+
+    /**
+     * Splits the supplied predicates into pushed and not-pushed sets.
+     *
+     * <p>A predicate is pushed when it references only data columns (not partition columns) and uses operators and
+     * literal types that {@link SparkPredicateToVortexExpression} can map to Vortex expressions. Predicates that
+     * reference partition columns or use unsupported features are returned to Spark for post-scan evaluation.
+     *
+     * @return the predicates that Spark must still evaluate
+     */
+    @Override
+    public Predicate[] pushPredicates(Predicate[] predicates) {
+        Set<String> dataColumns = new HashSet<>();
+        for (Column column : columns) {
+            if (!partitionColumnNames.contains(column.name())) {
+                dataColumns.add(column.name());
+            }
+        }
+        List<Predicate> pushed = new ArrayList<>();
+        List<Predicate> postScan = new ArrayList<>();
+        for (Predicate predicate : predicates) {
+            if (SparkPredicateToVortexExpression.isPushable(predicate, dataColumns)) {
+                pushed.add(predicate);
+            } else {
+                postScan.add(predicate);
+            }
+        }
+        this.pushedPredicates = pushed.toArray(new Predicate[0]);
+        return postScan.toArray(new Predicate[0]);
+    }
+
+    /** Returns the predicates this scan promises to apply. */
+    @Override
+    public Predicate[] pushedPredicates() {
+        return Arrays.copyOf(pushedPredicates, pushedPredicates.length);
+    }
+
+    private static Set<String> collectPartitionColumnNames(Transform[] transforms) {
+        if (transforms == null || transforms.length == 0) {
+            return Collections.emptySet();
+        }
+        Set<String> names = new HashSet<>();
+        for (Transform transform : transforms) {
+            for (NamedReference ref : transform.references()) {
+                String[] parts = ref.fieldNames();
+                if (parts.length == 1) {
+                    names.add(parts[0]);
+                }
+            }
+        }
+        return names;
+    }
 }
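To make the split concrete: for a table partitioned by `group` with data columns `id` and `name` (mirroring the partitioned test below), a mixed filter is expected to be divided as sketched here. `basePath` is a placeholder; only the standard Spark reader API is used.

```java
// Illustrative only: `group` is a partition column, `id` is a data column.
Dataset<Row> rows = spark.read()
        .format("vortex")
        .option("path", basePath)               // placeholder
        .load()
        .filter("group = 'A' AND id > 1");
// Expected outcome of pushPredicates():
//   pushed into the Vortex scan:          id > 1
//   returned to Spark for post-scan use:  group = 'A'
```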
diff --git a/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java b/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java
new file mode 100644
index 00000000000..ef572204b1b
--- /dev/null
+++ b/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java
@@ -0,0 +1,227 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+package dev.vortex.spark;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.QueryExecution;
+import org.apache.spark.sql.execution.SparkPlan;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.types.DataTypes;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.api.io.TempDir;
+
+/**
+ * Tests that Spark predicate pushdown into the Vortex datasource produces correct results.
+ *
+ * <p>The tests write a Vortex dataset and then read it back applying various Spark filters. The
+ * {@code VortexScanBuilder.pushPredicates} path attempts to translate each filter to a Vortex {@code Expression}; filters
+ * it cannot translate (or that reference partition columns) are returned to Spark for post-scan evaluation. Either way
+ * the final result must match the same query against the original DataFrame.
+ */
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public final class VortexFilterPushdownTest {
+
+    private SparkSession spark;
+
+    @TempDir
+    Path tempDir;
+
+    @BeforeAll
+    public void setUp() {
+        spark = SparkSession.builder()
+                .appName("VortexFilterPushdownTest")
+                .master("local[2]")
+                .config("spark.driver.host", "127.0.0.1")
+                .config("spark.sql.shuffle.partitions", "2")
+                .config("spark.sql.adaptive.enabled", "false")
+                .config("spark.ui.enabled", "false")
+                .getOrCreate();
+    }
+
+    @AfterAll
+    public void tearDown() {
+        if (spark != null) {
+            spark.stop();
+        }
+    }
+
+    @Test
+    @DisplayName("Equality, comparison, IS NULL, IN, AND/OR/NOT all return correct rows after pushdown")
+    public void testFilterPushdownCorrectness() throws IOException {
+        Dataset<Row> df = spark.createDataFrame(
+                Arrays.asList(
+                        RowFactory.create(1, "alpha", 10L, true),
+                        RowFactory.create(2, "beta", 20L, false),
+                        RowFactory.create(3, "gamma", 30L, true),
+                        RowFactory.create(4, "delta", null, false),
+                        RowFactory.create(5, null, 50L, true)),
+                DataTypes.createStructType(Arrays.asList(
+                        DataTypes.createStructField("id", DataTypes.IntegerType, false),
+                        DataTypes.createStructField("name", DataTypes.StringType, true),
+                        DataTypes.createStructField("amount", DataTypes.LongType, true),
+                        DataTypes.createStructField("flag", DataTypes.BooleanType, false))));
+
+        Path outputPath = tempDir.resolve("pushdown_basic");
+        df.write()
+                .format("vortex")
+                .option("path", outputPath.toUri().toString())
+                .mode(SaveMode.Overwrite)
+                .save();
+
+        Dataset<Row> readDf = spark.read()
+                .format("vortex")
+                .option("path", outputPath.toUri().toString())
+                .load();
+
+        assertEquals(
+                List.of(2), idsOf(readDf.filter(readDf.col("id").equalTo(2)).orderBy("id")));
+
+        assertEquals(
+                List.of(3, 4, 5), idsOf(readDf.filter(readDf.col("id").gt(2)).orderBy("id")));
+
+        assertEquals(List.of(1, 2), idsOf(readDf.filter(readDf.col("id").leq(2)).orderBy("id")));
+
+        assertEquals(
+                List.of(1, 3),
+                idsOf(readDf.filter(readDf.col("name").isin("alpha", "gamma")).orderBy("id")));
+
+        assertEquals(
+                List.of(4), idsOf(readDf.filter(readDf.col("amount").isNull()).orderBy("id")));
+
+        assertEquals(
+                List.of(1, 2, 3, 5),
+                idsOf(readDf.filter(readDf.col("amount").isNotNull()).orderBy("id")));
+
+        assertEquals(
+                List.of(1, 3),
+                idsOf(readDf.filter(readDf.col("flag")
+                                .equalTo(true)
+                                .and(readDf.col("amount").lt(40L)))
+                        .orderBy("id")));
+
+        assertEquals(
+                List.of(1, 4, 5),
+                idsOf(readDf.filter(readDf.col("id")
+                                .equalTo(1)
+                                .or(readDf.col("amount").isNull())
+                                .or(readDf.col("name").isNull()))
+                        .orderBy("id")));
+
+        // NOT around an unsupported predicate (string startsWith) should still produce
+        // correct results — Spark applies it as a post-scan filter.
+        assertEquals(
+                List.of(2, 3, 4),
+                idsOf(readDf.filter(functions.not(readDf.col("name").startsWith("a")))
+                        .orderBy("id")));
+    }
+
+    @Test
+    @DisplayName("Filters on partition columns yield correct results without pushdown")
+    public void testFilterOnPartitionColumn() throws IOException {
+        Dataset<Row> df = spark.createDataFrame(
+                Arrays.asList(
+                        RowFactory.create(1, "alpha", "A"),
+                        RowFactory.create(2, "beta", "B"),
+                        RowFactory.create(3, "gamma", "A"),
+                        RowFactory.create(4, "delta", "B")),
+                DataTypes.createStructType(Arrays.asList(
+                        DataTypes.createStructField("id", DataTypes.IntegerType, false),
+                        DataTypes.createStructField("name", DataTypes.StringType, true),
+                        DataTypes.createStructField("group", DataTypes.StringType, true))));
+
+        Path outputPath = tempDir.resolve("pushdown_partitioned");
+        df.write()
+                .format("vortex")
+                .partitionBy("group")
+                .option("path", outputPath.toUri().toString())
+                .mode(SaveMode.Overwrite)
+                .save();
+
+        Dataset<Row> readDf = spark.read()
+                .format("vortex")
+                .option("path", outputPath.toUri().toString())
+                .load();
+
+        assertEquals(
+                List.of(1, 3),
+                idsOf(readDf.filter(readDf.col("group").equalTo("A")).orderBy("id")));
+
+        // Predicate spanning partition + data columns must still produce the right answer.
+        assertEquals(
+                List.of(3),
+                idsOf(readDf.filter(readDf.col("group")
+                                .equalTo("A")
+                                .and(readDf.col("id").gt(1)))
+                        .orderBy("id")));
+    }
+
+    @Test
+    @DisplayName("Pushed filters appear in the executed scan node")
+    public void testPushedFiltersInPlan() throws IOException {
+        Dataset<Row> df = spark.createDataFrame(
+                Arrays.asList(RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")),
+                DataTypes.createStructType(Arrays.asList(
+                        DataTypes.createStructField("id", DataTypes.IntegerType, false),
+                        DataTypes.createStructField("label", DataTypes.StringType, true))));
+
+        Path outputPath = tempDir.resolve("pushdown_plan");
+        df.write()
+                .format("vortex")
+                .option("path", outputPath.toUri().toString())
+                .mode(SaveMode.Overwrite)
+                .save();
+
+        Dataset<Row> readDf = spark.read()
+                .format("vortex")
+                .option("path", outputPath.toUri().toString())
+                .load();
+
+        Dataset<Row> filtered = readDf.filter(readDf.col("id").gt(1));
+        QueryExecution qe = filtered.queryExecution();
+        SparkPlan plan = qe.executedPlan();
+        String planString = plan.toString();
+        assertTrue(
+                planString.contains("id > 1"),
+                "Expected pushed predicate for id > 1 in the executed plan: " + planString);
+    }
+
+    private static List<Integer> idsOf(Dataset<Row> df) {
+        return df.collectAsList().stream().map(row -> row.getInt(0)).collect(Collectors.toList());
+    }
+
+    @AfterEach
+    public void cleanupTempFiles() throws IOException {
+        if (tempDir != null && Files.exists(tempDir)) {
+            try (Stream<Path> paths = Files.walk(tempDir)) {
+                paths.sorted(Comparator.reverseOrder()).forEach(path -> {
+                    try {
+                        Files.deleteIfExists(path);
+                    } catch (IOException e) {
+                        System.err.println("Failed to delete: " + path);
+                    }
+                });
+            }
+        }
+    }
+}
diff --git a/vortex-jni/src/expression.rs b/vortex-jni/src/expression.rs
index aaed86efa75..5eea004a2f8 100644
--- a/vortex-jni/src/expression.rs
+++ b/vortex-jni/src/expression.rs
@@ -24,7 +24,7 @@ use jni::sys::jint;
 use jni::sys::jlong;
 use jni::sys::jshort;
 use vortex::dtype::FieldName;
-use vortex::expr::Expression;
+use vortex::expr::{is_not_null, Expression};
 use vortex::expr::and_collect;
 use vortex::expr::get_item;
 use vortex::expr::is_null;
@@ -201,6 +201,16 @@ pub extern "system" fn Java_dev_vortex_jni_NativeExpression_isNull(
     into_raw(is_null(child))
 }
 
+#[unsafe(no_mangle)]
+pub extern "system" fn Java_dev_vortex_jni_NativeExpression_isNotNull(
+    _env: EnvUnowned,
+    _class: JClass,
+    child: jlong,
+) -> jlong {
+    let child = unsafe { expr_ref(child) }.clone();
+    into_raw(is_not_null(child))
+}
+
 #[unsafe(no_mangle)]
 pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalBool(
     _env: EnvUnowned,
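Putting the Java `Expression.isNotNull` wrapper, the JNI binding above, and the predicate converter together, an `IS NOT NULL` filter can now reach the native scan instead of being evaluated row by row in Spark. A hedged end-to-end sketch of the user-visible effect; the path and column name are hypothetical:

```java
// Illustrative only: with pushdown in place this filter is expected to arrive at the scan
// as Expression.isNotNull(Expression.column("amount")).
Dataset<Row> nonNull = spark.read()
        .format("vortex")
        .option("path", "/data/events")         // hypothetical path
        .load()
        .filter(functions.col("amount").isNotNull());
```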