diff --git a/java/build.gradle.kts b/java/build.gradle.kts
index bb73c481776..94e9f2fec96 100644
--- a/java/build.gradle.kts
+++ b/java/build.gradle.kts
@@ -81,5 +81,10 @@ allprojects {
}
}
- tasks.register("format").get().dependsOn("spotlessApply")
+ if (project.name == "vortex-spark_2.12") {
+ // vortex-spark_2.12 and vortex-spark_2.13 share a projectDir; format from the 2.13 variant only.
+ tasks.register("format") { enabled = false }
+ } else {
+ tasks.register("format").get().dependsOn("spotlessApply")
+ }
}
diff --git a/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java b/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java
index e4d8978e112..950f75a8789 100644
--- a/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java
+++ b/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java
@@ -73,6 +73,10 @@ public static Expression isNull(Expression child) {
return new Expression(NativeExpression.isNull(child.nativePointer()));
}
+ public static Expression isNotNull(Expression child) {
+ return new Expression(NativeExpression.isNotNull(child.nativePointer()));
+ }
+
public static Expression literal(boolean value) {
return new Expression(NativeExpression.literalBool(value, false));
}
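Not part of the patch, for orientation only: a minimal sketch of how the new isNotNull factory composes with the Expression factories already used elsewhere in this diff (column, binary, literal, and). The column name and literal value are hypothetical.

import dev.vortex.api.Expression;

// Sketch: `amount IS NOT NULL AND amount > 10`, composed from dev.vortex.api.Expression factories.
Expression amount = Expression.column("amount");
Expression filter = Expression.and(
        Expression.isNotNull(amount),
        Expression.binary(Expression.BinaryOp.GT, amount, Expression.literal(10L)));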
diff --git a/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java b/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java
index fc5bedbe7b5..e35eb1d0215 100644
--- a/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java
+++ b/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java
@@ -27,6 +27,8 @@ private NativeExpression() {}
public static native long isNull(long childPointer);
+ public static native long isNotNull(long childPointer);
+
public static native long literalBool(boolean value, boolean isNull);
public static native long literalI8(byte value, boolean isNull);
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java
index 1d564b8c0e0..b3d7d637504 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java
@@ -20,6 +20,7 @@
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableProvider;
+import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.sources.DataSourceRegister;
import org.apache.spark.sql.types.DataType;
@@ -118,6 +119,38 @@ public StructType inferSchema(CaseInsensitiveStringMap options) {
return dataSchema;
}
+ /**
+ * Infers partition transforms by inspecting Hive-style {@code key=value} segments in a sampled file path (the last configured path, or the first file listed beneath it).
+ *
+ * <p>Spark calls this before {@link #getTable(StructType, Transform[], Map)} when the caller did not provide
+ * explicit partitioning. Returning identity transforms here lets downstream components (notably
+ * {@link dev.vortex.spark.read.VortexScanBuilder}) tell which schema columns are encoded in the directory layout
+ * rather than stored inside the Vortex files, which matters for predicate pushdown.
+ */
+ @Override
+ public Transform[] inferPartitioning(CaseInsensitiveStringMap options) {
+ var paths = getPaths(options);
+ if (paths.isEmpty()) {
+ return new Transform[0];
+ }
+ var formatOptions = buildDataSourceOptions(options.asCaseSensitiveMap());
+ String pathToInfer = Objects.requireNonNull(Iterables.getLast(paths));
+ if (!pathToInfer.endsWith(".vortex")) {
+ Optional<String> firstFile =
+ NativeFiles.listFiles(VortexSparkSession.get(formatOptions), pathToInfer, formatOptions).stream()
+ .findFirst();
+ if (firstFile.isEmpty()) {
+ return new Transform[0];
+ }
+ pathToInfer = firstFile.get();
+ }
+ Map<String, String> partitionValues = PartitionPathUtils.parsePartitionValues(pathToInfer);
+ if (partitionValues.isEmpty()) {
+ return new Transform[0];
+ }
+ return partitionValues.keySet().stream().map(Expressions::identity).toArray(Transform[]::new);
+ }
+
/**
* Creates a Vortex table instance with the given schema and properties.
*
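Not part of the patch: a short illustration of the inference above under a hypothetical Hive-style layout.

// Sketch only; the path below is hypothetical.
// Given a first listed file of
//   s3://bucket/events/date=2024-01-01/region=eu/part-00000.vortex
// PartitionPathUtils.parsePartitionValues(...) is expected to yield {date=2024-01-01, region=eu},
// so inferPartitioning returns {Expressions.identity("date"), Expressions.identity("region")}
// and VortexScanBuilder later treats "date" and "region" as partition (non-data) columns.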
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java
index 92cc55ff211..f65f74ccf19 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java
@@ -58,7 +58,7 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
Map<String, String> opts = Maps.newHashMap();
opts.putAll(formatOptions);
opts.putAll(options);
- return new VortexScanBuilder(opts)
+ return new VortexScanBuilder(opts, partitionTransforms)
.addAllPaths(paths)
.addAllColumns(Arrays.asList(CatalogV2Util.structTypeToV2Columns(schema)));
}
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java
deleted file mode 100644
index e5cd96a3958..00000000000
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java
+++ /dev/null
@@ -1,94 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright the Vortex contributors
-
-package dev.vortex.spark.read;
-
-import java.util.Iterator;
-import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.ToLongFunction;
-
-final class PrefetchingIterator<T> implements Iterator<T>, AutoCloseable {
- // Global condition variable shared between the prefetcher and consumer threads,
- // to coordinate wake ups for when the buffer may no longer be full.
- private static final Object CONDITION = new Object();
-
- private final BlockingQueue<T> fetched = new LinkedBlockingQueue<>();
- private final Thread producerThread;
- private final Iterator<T> delegate;
- private final AtomicBoolean closed = new AtomicBoolean(false);
- private final AtomicLong bufferBytes = new AtomicLong(0);
- private final long maxBufferSize;
- private final ToLongFunction<T> sizeFunc;
-
- PrefetchingIterator(Iterator<T> delegate, long maxBufferSize, ToLongFunction<T> sizeFunc) {
- this.delegate = delegate;
- this.maxBufferSize = maxBufferSize;
- this.sizeFunc = sizeFunc;
- this.producerThread = new Thread(this::prefetchLoop, "vortex-prefetch-thread");
- producerThread.setDaemon(true);
- producerThread.start();
- }
-
- private void prefetchLoop() {
- try {
- while (!closed.get() && delegate.hasNext()) {
- while (bufferBytes.get() > maxBufferSize) {
- synchronized (CONDITION) {
- CONDITION.wait();
- }
- }
- T nextElem = delegate.next();
- long elemSize = sizeFunc.applyAsLong(nextElem);
- bufferBytes.addAndGet(elemSize);
- fetched.put(nextElem);
- }
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new RuntimeException("Prefetching interrupted", e);
- } catch (Exception e) {
- throw new RuntimeException("Prefetching failed", e);
- } finally {
- closed.set(true);
- }
- }
-
- @Override
- public boolean hasNext() {
- // If the prefetcher is not finished, then we could be waiting for
- // a fetched item.
- while (!closed.get()) {
- if (!fetched.isEmpty()) {
- return true;
- }
- }
- // If the prefetcher is finished, then we can examine fetched and immediately return a result.
- return !fetched.isEmpty();
- }
-
- @Override
- public T next() {
- // We assume that this has been called after hasNext() returned true, so it is
- // safe to call take() without checking if the queue maybe be empty.
- try {
- T nextElem = this.fetched.take();
- long elemSize = sizeFunc.applyAsLong(nextElem);
- bufferBytes.addAndGet(-elemSize);
- // Notify the producer that it may now be able to add more items to the queue.
- synchronized (CONDITION) {
- CONDITION.notify();
- }
- return nextElem;
- } catch (InterruptedException e) {
- throw new RuntimeException("Prefetch queue take interrupted", e);
- }
- }
-
- @Override
- public void close() {
- closed.set(true);
- producerThread.interrupt();
- }
-}
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java
new file mode 100644
index 00000000000..19c59d5d9b0
--- /dev/null
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java
@@ -0,0 +1,235 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+package dev.vortex.spark.read;
+
+import dev.vortex.api.Expression;
+import dev.vortex.api.Expression.BinaryOp;
+import java.util.Optional;
+import java.util.Set;
+import org.apache.spark.sql.connector.expressions.LiteralValue;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.And;
+import org.apache.spark.sql.connector.expressions.filter.Not;
+import org.apache.spark.sql.connector.expressions.filter.Or;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/** Translates {@link Predicate Spark V2 predicates} into Vortex {@link Expression}s for predicate pushdown. */
+final class SparkPredicateToVortexExpression {
+
+ private SparkPredicateToVortexExpression() {}
+
+ /**
+ * Returns true if the given Spark predicate is structurally convertible to a Vortex expression and references only
+ * the supplied {@code dataColumns}.
+ *
+ * <p>This is the cheap check used in {@code SupportsPushDownV2Filters.pushPredicates} to decide which predicates
+ * Spark can drop. It does not allocate any native expressions.
+ */
+ static boolean isPushable(Predicate predicate, Set<String> dataColumns) {
+ for (NamedReference ref : predicate.references()) {
+ String[] parts = ref.fieldNames();
+ if (parts.length != 1) {
+ return false;
+ }
+ if (!dataColumns.contains(parts[0])) {
+ return false;
+ }
+ }
+ return isStructurallyPushable(predicate);
+ }
+
+ private static boolean isStructurallyPushable(Predicate predicate) {
+ if (predicate instanceof And a) {
+ return isStructurallyPushable(a.left()) && isStructurallyPushable(a.right());
+ }
+ if (predicate instanceof Or o) {
+ return isStructurallyPushable(o.left()) && isStructurallyPushable(o.right());
+ }
+ if (predicate instanceof Not) {
+ return isStructurallyPushable(((Not) predicate).child());
+ }
+ org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
+ switch (predicate.name()) {
+ case "=":
+ case ">":
+ case ">=":
+ case "<":
+ case "<=":
+ return children.length == 2 && isTopLevelFieldRef(children[0]) && isPushableLiteral(children[1]);
+ case "IS_NULL":
+ case "IS_NOT_NULL":
+ return children.length == 1 && isTopLevelFieldRef(children[0]);
+ case "IN":
+ if (children.length < 2 || !isTopLevelFieldRef(children[0])) {
+ return false;
+ }
+ for (int i = 1; i < children.length; i++) {
+ if (!isPushableLiteral(children[i])) {
+ return false;
+ }
+ }
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ /**
+ * Converts a Spark predicate to a Vortex expression. Returns {@link Optional#empty()} if the predicate is not
+ * pushable; callers should normally pre-check with {@link #isPushable}.
+ */
+ static Optional<Expression> convert(Predicate predicate) {
+ if (predicate instanceof And a) {
+ Optional<Expression> left = convert(a.left());
+ Optional<Expression> right = convert(a.right());
+ if (left.isPresent() && right.isPresent()) {
+ return Optional.of(Expression.and(left.get(), right.get()));
+ }
+ return Optional.empty();
+ }
+ if (predicate instanceof Or o) {
+ Optional<Expression> left = convert(o.left());
+ Optional<Expression> right = convert(o.right());
+ if (left.isPresent() && right.isPresent()) {
+ return Optional.of(Expression.or(left.get(), right.get()));
+ }
+ return Optional.empty();
+ }
+ if (predicate instanceof Not) {
+ return convert(((Not) predicate).child()).map(Expression::not);
+ }
+ org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
+ return switch (predicate.name()) {
+ case "=", ">", ">=", "<", "<=" -> convertBinary(predicate.name(), children);
+ case "IS_NULL" -> {
+ if (children.length != 1) {
+ yield Optional.empty();
+ }
+ yield columnNameOf(children[0]).map(name -> Expression.isNull(Expression.column(name)));
+ }
+ case "IS_NOT_NULL" -> {
+ if (children.length != 1) {
+ yield Optional.empty();
+ }
+ yield columnNameOf(children[0]).map(name -> Expression.isNotNull(Expression.column(name)));
+ }
+ case "IN" -> convertIn(children);
+ default -> Optional.empty();
+ };
+ }
+
+ private static Optional<Expression> convertBinary(
+ String op, org.apache.spark.sql.connector.expressions.Expression[] children) {
+ if (children.length != 2) {
+ return Optional.empty();
+ }
+ Optional<String> column = columnNameOf(children[0]);
+ Optional<Expression> literal = literalOf(children[1]);
+ if (column.isEmpty() || literal.isEmpty()) {
+ return Optional.empty();
+ }
+ return Optional.of(Expression.binary(toBinaryOp(op), Expression.column(column.get()), literal.get()));
+ }
+
+ private static Optional<Expression> convertIn(org.apache.spark.sql.connector.expressions.Expression[] children) {
+ if (children.length < 2) {
+ return Optional.empty();
+ }
+ Optional<String> column = columnNameOf(children[0]);
+ if (column.isEmpty()) {
+ return Optional.empty();
+ }
+ Expression columnExpr = Expression.column(column.get());
+ Expression[] eqs = new Expression[children.length - 1];
+ for (int i = 1; i < children.length; i++) {
+ Optional<Expression> literal = literalOf(children[i]);
+ if (literal.isEmpty()) {
+ return Optional.empty();
+ }
+ eqs[i - 1] = Expression.binary(BinaryOp.EQ, columnExpr, literal.get());
+ }
+ if (eqs.length == 1) {
+ return Optional.of(eqs[0]);
+ }
+ return Optional.of(Expression.or(eqs));
+ }
+
+ private static BinaryOp toBinaryOp(String name) {
+ return switch (name) {
+ case "=" -> BinaryOp.EQ;
+ case ">" -> BinaryOp.GT;
+ case ">=" -> BinaryOp.GTE;
+ case "<" -> BinaryOp.LT;
+ case "<=" -> BinaryOp.LTE;
+ default -> throw new IllegalArgumentException("not a pushable binary operator: " + name);
+ };
+ }
+
+ private static boolean isTopLevelFieldRef(org.apache.spark.sql.connector.expressions.Expression expr) {
+ return expr instanceof NamedReference && ((NamedReference) expr).fieldNames().length == 1;
+ }
+
+ private static Optional<String> columnNameOf(org.apache.spark.sql.connector.expressions.Expression expr) {
+ if (!(expr instanceof NamedReference)) {
+ return Optional.empty();
+ }
+ String[] parts = ((NamedReference) expr).fieldNames();
+ if (parts.length != 1) {
+ return Optional.empty();
+ }
+ return Optional.of(parts[0]);
+ }
+
+ private static boolean isPushableLiteral(org.apache.spark.sql.connector.expressions.Expression expr) {
+ if (!(expr instanceof LiteralValue)) {
+ return false;
+ }
+ Object v = ((LiteralValue<?>) expr).value();
+ return v instanceof Boolean
+ || v instanceof Byte
+ || v instanceof Short
+ || v instanceof Integer
+ || v instanceof Long
+ || v instanceof Float
+ || v instanceof Double
+ || v instanceof UTF8String
+ || v instanceof CharSequence;
+ }
+
+ private static Optional<Expression> literalOf(org.apache.spark.sql.connector.expressions.Expression expr) {
+ if (!(expr instanceof LiteralValue)) {
+ return Optional.empty();
+ }
+ Object value = ((LiteralValue<?>) expr).value();
+ if (value == null) {
+ return Optional.empty();
+ }
+ if (value instanceof Boolean) {
+ return Optional.of(Expression.literal((Boolean) value));
+ }
+ if (value instanceof Byte) {
+ return Optional.of(Expression.literal((Byte) value));
+ }
+ if (value instanceof Short) {
+ return Optional.of(Expression.literal((Short) value));
+ }
+ if (value instanceof Integer) {
+ return Optional.of(Expression.literal((Integer) value));
+ }
+ if (value instanceof Long) {
+ return Optional.of(Expression.literal((Long) value));
+ }
+ if (value instanceof Float) {
+ return Optional.of(Expression.literal((Float) value));
+ }
+ if (value instanceof Double) {
+ return Optional.of(Expression.literal((Double) value));
+ }
+ if (value instanceof UTF8String || value instanceof CharSequence) {
+ return Optional.of(Expression.literal(value.toString()));
+ }
+ return Optional.empty();
+ }
+}
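Not part of the patch: a minimal usage sketch of the converter above. The predicate is built with Spark's V2 connector expression API; the column name and value are hypothetical.

import java.util.Set;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.LiteralValue;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.types.DataTypes;

// Sketch: Spark's `id > 1` as a V2 Predicate.
NamedReference idRef = Expressions.column("id");
LiteralValue<Integer> one = new LiteralValue<>(1, DataTypes.IntegerType);
Predicate idGtOne = new Predicate(">", new org.apache.spark.sql.connector.expressions.Expression[] {idRef, one});

// isPushable(idGtOne, Set.of("id")) is true: a single-segment column reference compared to a supported literal.
// convert(idGtOne) then yields Expression.binary(BinaryOp.GT, Expression.column("id"), Expression.literal(1)).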
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java
index e32d26f5643..8df7ce8e1db 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java
@@ -16,6 +16,7 @@
import java.util.stream.Stream;
import org.apache.spark.sql.connector.catalog.CatalogV2Util;
import org.apache.spark.sql.connector.catalog.Column;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.Batch;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
@@ -26,6 +27,7 @@ public final class VortexBatchExec implements Batch {
private final List<String> paths;
private final StructType readSchema;
private final Map<String, String> formatOptions;
+ private final Predicate[] pushedPredicates;
private List<String> resolvedPaths;
/**
@@ -33,11 +35,15 @@ public final class VortexBatchExec implements Batch {
*
* @param paths the list of file paths to scan
* @param columns the list of columns to read from the files
+ * @param pushedPredicates predicates pushed down by Spark; converted to a single Vortex filter expression at read
+ * time
*/
- public VortexBatchExec(List<String> paths, List<Column> columns, Map<String, String> formatOptions) {
+ public VortexBatchExec(
+ List<String> paths, List<Column> columns, Map<String, String> formatOptions, Predicate[] pushedPredicates) {
this.paths = List.copyOf(paths);
this.readSchema = CatalogV2Util.v2ColumnsToStructType(columns.toArray(new Column[0]));
this.formatOptions = Map.copyOf(formatOptions);
+ this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone();
}
/**
@@ -66,7 +72,7 @@ public PartitionReaderFactory createReaderFactory() {
List<String> dataColumnNames = Arrays.stream(readSchema.fieldNames())
.filter(name -> !partitionColumns.contains(name))
.collect(Collectors.toList());
- return new VortexPartitionReaderFactory(dataColumnNames, formatOptions);
+ return new VortexPartitionReaderFactory(dataColumnNames, formatOptions, pushedPredicates);
}
private List resolvePaths() {
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java
index 9df44c07e6c..f9cd2363d59 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java
@@ -18,6 +18,8 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.vectorized.ColumnVector;
@@ -46,7 +48,11 @@ final class VortexPartitionReader implements PartitionReader<ColumnarBatch> {
private boolean currentBatchLoaded;
private boolean exhausted;
- VortexPartitionReader(VortexFilePartition spark, List<String> dataColumnNames, Map<String, String> formatOptions) {
+ VortexPartitionReader(
+ VortexFilePartition spark,
+ List<String> dataColumnNames,
+ Map<String, String> formatOptions,
+ Predicate[] pushedPredicates) {
this.spark = spark;
this.allocator = ArrowAllocation.rootAllocator();
@@ -58,9 +64,24 @@ final class VortexPartitionReader implements PartitionReader<ColumnarBatch> {
Expression projection = Expression.select(dataColumnNames.toArray(new String[0]), Expression.root());
options.projection(projection);
}
+ if (pushedPredicates != null && pushedPredicates.length > 0) {
+ buildFilterExpression(pushedPredicates).ifPresent(options::filter);
+ }
scan = dataSource.scan(options.build());
}
+ private static Optional<Expression> buildFilterExpression(Predicate[] predicates) {
+ Expression combined = null;
+ for (Predicate predicate : predicates) {
+ Optional<Expression> expr = SparkPredicateToVortexExpression.convert(predicate);
+ if (expr.isEmpty()) {
+ continue;
+ }
+ combined = combined == null ? expr.get() : Expression.and(combined, expr.get());
+ }
+ return Optional.ofNullable(combined);
+ }
+
@Override
public boolean next() {
if (exhausted) {
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java
index 9ffbfcc3cfb..e187e4863b1 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java
@@ -5,10 +5,13 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import dev.vortex.jni.NativeRuntime;
import dev.vortex.spark.VortexFilePartition;
import java.io.Serializable;
import java.util.List;
+import java.util.Map;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
@@ -27,10 +30,13 @@ public final class VortexPartitionReaderFactory implements PartitionReaderFactory {
private final ImmutableList<String> dataColumnNames;
private final ImmutableMap<String, String> formatOptions;
+ private final Predicate[] pushedPredicates;
- public VortexPartitionReaderFactory(List<String> dataColumnNames, java.util.Map<String, String> formatOptions) {
+ public VortexPartitionReaderFactory(
+ List<String> dataColumnNames, Map<String, String> formatOptions, Predicate[] pushedPredicates) {
this.dataColumnNames = ImmutableList.copyOf(dataColumnNames);
this.formatOptions = ImmutableMap.copyOf(formatOptions);
+ this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone();
}
@Override
@@ -40,8 +46,9 @@ public PartitionReader<InternalRow> createReader(InputPartition partition) {
@Override
public PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) {
+ NativeRuntime.setWorkerThreads(Integer.parseInt(formatOptions.getOrDefault("vortex.workerThreads", "4")));
VortexFilePartition spark = (VortexFilePartition) partition;
- return new VortexPartitionReader(spark, dataColumnNames, formatOptions);
+ return new VortexPartitionReader(spark, dataColumnNames, formatOptions, pushedPredicates);
}
@Override
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java
index c6ec03eef80..d5949b57a4d 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java
@@ -3,10 +3,12 @@
package dev.vortex.spark.read;
+import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.spark.sql.connector.catalog.CatalogV2Util;
import org.apache.spark.sql.connector.catalog.Column;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.Batch;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.types.StructType;
@@ -17,6 +19,7 @@ public final class VortexScan implements Scan {
private final List<String> paths;
private final List<Column> readColumns;
private final Map<String, String> formatOptions;
+ private final Predicate[] pushedPredicates;
/**
* Creates a new VortexScan for the specified file paths and columns. The caller is responsible for passing
@@ -24,11 +27,17 @@ public final class VortexScan implements Scan {
*
* @param paths the list of Vortex file paths to scan
* @param readColumns the list of columns to read from the files
+ * @param pushedPredicates predicates pushed down by Spark; {@code null} or empty means no pushdown
*/
- public VortexScan(List<String> paths, List<Column> readColumns, Map<String, String> formatOptions) {
+ public VortexScan(
+ List<String> paths,
+ List<Column> readColumns,
+ Map<String, String> formatOptions,
+ Predicate[] pushedPredicates) {
this.paths = paths;
this.readColumns = readColumns;
this.formatOptions = formatOptions;
+ this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone();
}
/**
@@ -46,7 +55,9 @@ public StructType readSchema() {
/** Logging-friendly readable description of the scan source. */
@Override
public String description() {
- return String.format("VortexScan{paths=%s, columns=%s}", paths, readColumns);
+ return String.format(
+ "VortexScan{paths=%s, columns=%s, pushedPredicates=%s}",
+ paths, readColumns, Arrays.toString(pushedPredicates));
}
/**
@@ -58,7 +69,7 @@ public String description() {
*/
@Override
public Batch toBatch() {
- return new VortexBatchExec(paths, readColumns, formatOptions);
+ return new VortexBatchExec(paths, readColumns, formatOptions, pushedPredicates);
}
/**
diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java
index a53472bc33b..849c151699e 100644
--- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java
+++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java
@@ -6,27 +6,52 @@
import static com.google.common.base.Preconditions.checkState;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import org.apache.spark.sql.connector.catalog.Column;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Transform;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
+import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
/** Spark V2 {@link ScanBuilder} for table scans over Vortex files. */
-public final class VortexScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns {
+public final class VortexScanBuilder
+ implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters {
private final ImmutableList.Builder<String> paths;
private final List<Column> columns;
private final Map<String, String> formatOptions;
+ private final Set<String> partitionColumnNames;
+ private Predicate[] pushedPredicates = new Predicate[0];
/** Creates a new VortexScanBuilder with empty paths and columns. */
public VortexScanBuilder(Map<String, String> formatOptions) {
+ this(formatOptions, new Transform[0]);
+ }
+
+ /**
+ * Creates a new VortexScanBuilder with empty paths and columns and the supplied partition transforms. Filters that
+ * reference partition columns are not pushed down, since the partition columns are not stored inside the Vortex
+ * files.
+ */
+ public VortexScanBuilder(Map<String, String> formatOptions, Transform[] partitionTransforms) {
this.paths = ImmutableList.builder();
this.columns = new ArrayList<>();
- this.formatOptions = Map.copyOf(formatOptions);
+ Map<String, String> options = Maps.newHashMap();
+ options.put("vortex.workerThreads", "4");
+ options.putAll(formatOptions);
+ this.formatOptions = options;
+ this.partitionColumnNames = collectPartitionColumnNames(partitionTransforms);
}
/**
@@ -89,7 +114,7 @@ public Scan build() {
// Allow empty columns for operations like count() that don't need actual column data
// If no columns are specified, we'll read the minimal schema needed
- return new VortexScan(paths, List.copyOf(this.columns), this.formatOptions);
+ return new VortexScan(paths, List.copyOf(this.columns), this.formatOptions, pushedPredicates);
}
/**
@@ -108,4 +133,56 @@ public void pruneColumns(StructType requiredSchema) {
columns.add(Column.create(field.name(), field.dataType()));
}
}
+
+ /**
+ * Splits the supplied predicates into pushed and not-pushed sets.
+ *
+ * A predicate is pushed when it references only data columns (not partition columns) and uses operators and
+ * literal types that {@link SparkPredicateToVortexExpression} can map to Vortex expressions. Predicates that
+ * reference partition columns or use unsupported features are returned to Spark for post-scan evaluation.
+ *
+ * @return the predicates that Spark must still evaluate
+ */
+ @Override
+ public Predicate[] pushPredicates(Predicate[] predicates) {
+ Set<String> dataColumns = new HashSet<>();
+ for (Column column : columns) {
+ if (!partitionColumnNames.contains(column.name())) {
+ dataColumns.add(column.name());
+ }
+ }
+ List<Predicate> pushed = new ArrayList<>();
+ List<Predicate> postScan = new ArrayList<>();
+ for (Predicate predicate : predicates) {
+ if (SparkPredicateToVortexExpression.isPushable(predicate, dataColumns)) {
+ pushed.add(predicate);
+ } else {
+ postScan.add(predicate);
+ }
+ }
+ this.pushedPredicates = pushed.toArray(new Predicate[0]);
+ return postScan.toArray(new Predicate[0]);
+ }
+
+ /** Returns the predicates this scan promises to apply. */
+ @Override
+ public Predicate[] pushedPredicates() {
+ return Arrays.copyOf(pushedPredicates, pushedPredicates.length);
+ }
+
+ private static Set<String> collectPartitionColumnNames(Transform[] transforms) {
+ if (transforms == null || transforms.length == 0) {
+ return Collections.emptySet();
+ }
+ Set<String> names = new HashSet<>();
+ for (Transform transform : transforms) {
+ for (NamedReference ref : transform.references()) {
+ String[] parts = ref.fieldNames();
+ if (parts.length == 1) {
+ names.add(parts[0]);
+ }
+ }
+ }
+ return names;
+ }
}
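Not part of the patch: how the SupportsPushDownV2Filters contract implemented above plays out for a hypothetical mix of data- and partition-column predicates.

// Sketch only; `idGtOne` and `groupEqA` are hypothetical Spark predicates.
//   idGtOne  : `id > 1`      (references the data column "id")
//   groupEqA : `group = 'A'` (references the partition column "group")
// pushPredicates(new Predicate[] {idGtOne, groupEqA}) records {idGtOne} internally (reported via
// pushedPredicates() and applied natively through SparkPredicateToVortexExpression.convert at read time)
// and returns {groupEqA}, which Spark re-evaluates after the scan because partition values live in the
// directory layout rather than inside the Vortex files.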
diff --git a/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java b/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java
new file mode 100644
index 00000000000..ef572204b1b
--- /dev/null
+++ b/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java
@@ -0,0 +1,227 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+package dev.vortex.spark;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.QueryExecution;
+import org.apache.spark.sql.execution.SparkPlan;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.types.DataTypes;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.api.io.TempDir;
+
+/**
+ * Tests that Spark predicate pushdown into the Vortex datasource produces correct results.
+ *
+ * <p>The tests write a Vortex dataset and then read it back applying various Spark filters. The
+ * {@code VortexScanBuilder.pushPredicates} path accepts each filter it can translate to a Vortex {@code Expression}; filters
+ * it cannot translate (or that reference partition columns) are returned to Spark for post-scan evaluation. Either way
+ * the final result must match the same query against the original DataFrame.
+ */
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public final class VortexFilterPushdownTest {
+
+ private SparkSession spark;
+
+ @TempDir
+ Path tempDir;
+
+ @BeforeAll
+ public void setUp() {
+ spark = SparkSession.builder()
+ .appName("VortexFilterPushdownTest")
+ .master("local[2]")
+ .config("spark.driver.host", "127.0.0.1")
+ .config("spark.sql.shuffle.partitions", "2")
+ .config("spark.sql.adaptive.enabled", "false")
+ .config("spark.ui.enabled", "false")
+ .getOrCreate();
+ }
+
+ @AfterAll
+ public void tearDown() {
+ if (spark != null) {
+ spark.stop();
+ }
+ }
+
+ @Test
+ @DisplayName("Equality, comparison, IS NULL, IN, AND/OR/NOT all return correct rows after pushdown")
+ public void testFilterPushdownCorrectness() throws IOException {
+ Dataset<Row> df = spark.createDataFrame(
+ Arrays.asList(
+ RowFactory.create(1, "alpha", 10L, true),
+ RowFactory.create(2, "beta", 20L, false),
+ RowFactory.create(3, "gamma", 30L, true),
+ RowFactory.create(4, "delta", null, false),
+ RowFactory.create(5, null, 50L, true)),
+ DataTypes.createStructType(Arrays.asList(
+ DataTypes.createStructField("id", DataTypes.IntegerType, false),
+ DataTypes.createStructField("name", DataTypes.StringType, true),
+ DataTypes.createStructField("amount", DataTypes.LongType, true),
+ DataTypes.createStructField("flag", DataTypes.BooleanType, false))));
+
+ Path outputPath = tempDir.resolve("pushdown_basic");
+ df.write()
+ .format("vortex")
+ .option("path", outputPath.toUri().toString())
+ .mode(SaveMode.Overwrite)
+ .save();
+
+ Dataset<Row> readDf = spark.read()
+ .format("vortex")
+ .option("path", outputPath.toUri().toString())
+ .load();
+
+ assertEquals(
+ List.of(2), idsOf(readDf.filter(readDf.col("id").equalTo(2)).orderBy("id")));
+
+ assertEquals(
+ List.of(3, 4, 5), idsOf(readDf.filter(readDf.col("id").gt(2)).orderBy("id")));
+
+ assertEquals(List.of(1, 2), idsOf(readDf.filter(readDf.col("id").leq(2)).orderBy("id")));
+
+ assertEquals(
+ List.of(1, 3),
+ idsOf(readDf.filter(readDf.col("name").isin("alpha", "gamma")).orderBy("id")));
+
+ assertEquals(
+ List.of(4), idsOf(readDf.filter(readDf.col("amount").isNull()).orderBy("id")));
+
+ assertEquals(
+ List.of(1, 2, 3, 5),
+ idsOf(readDf.filter(readDf.col("amount").isNotNull()).orderBy("id")));
+
+ assertEquals(
+ List.of(1, 3),
+ idsOf(readDf.filter(readDf.col("flag")
+ .equalTo(true)
+ .and(readDf.col("amount").lt(40L)))
+ .orderBy("id")));
+
+ assertEquals(
+ List.of(1, 4, 5),
+ idsOf(readDf.filter(readDf.col("id")
+ .equalTo(1)
+ .or(readDf.col("amount").isNull())
+ .or(readDf.col("name").isNull()))
+ .orderBy("id")));
+
+ // NOT around an unsupported predicate (string startsWith) should still produce
+ // correct results — Spark applies it as a post-scan filter.
+ assertEquals(
+ List.of(2, 3, 4),
+ idsOf(readDf.filter(functions.not(readDf.col("name").startsWith("a")))
+ .orderBy("id")));
+ }
+
+ @Test
+ @DisplayName("Filters on partition columns yield correct results without pushdown")
+ public void testFilterOnPartitionColumn() throws IOException {
+ Dataset<Row> df = spark.createDataFrame(
+ Arrays.asList(
+ RowFactory.create(1, "alpha", "A"),
+ RowFactory.create(2, "beta", "B"),
+ RowFactory.create(3, "gamma", "A"),
+ RowFactory.create(4, "delta", "B")),
+ DataTypes.createStructType(Arrays.asList(
+ DataTypes.createStructField("id", DataTypes.IntegerType, false),
+ DataTypes.createStructField("name", DataTypes.StringType, true),
+ DataTypes.createStructField("group", DataTypes.StringType, true))));
+
+ Path outputPath = tempDir.resolve("pushdown_partitioned");
+ df.write()
+ .format("vortex")
+ .partitionBy("group")
+ .option("path", outputPath.toUri().toString())
+ .mode(SaveMode.Overwrite)
+ .save();
+
+ Dataset<Row> readDf = spark.read()
+ .format("vortex")
+ .option("path", outputPath.toUri().toString())
+ .load();
+
+ assertEquals(
+ List.of(1, 3),
+ idsOf(readDf.filter(readDf.col("group").equalTo("A")).orderBy("id")));
+
+ // Predicate spanning partition + data columns must still produce the right answer.
+ assertEquals(
+ List.of(3),
+ idsOf(readDf.filter(readDf.col("group")
+ .equalTo("A")
+ .and(readDf.col("id").gt(1)))
+ .orderBy("id")));
+ }
+
+ @Test
+ @DisplayName("Pushed filters appear in the executed scan node")
+ public void testPushedFiltersInPlan() throws IOException {
+ Dataset<Row> df = spark.createDataFrame(
+ Arrays.asList(RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")),
+ DataTypes.createStructType(Arrays.asList(
+ DataTypes.createStructField("id", DataTypes.IntegerType, false),
+ DataTypes.createStructField("label", DataTypes.StringType, true))));
+
+ Path outputPath = tempDir.resolve("pushdown_plan");
+ df.write()
+ .format("vortex")
+ .option("path", outputPath.toUri().toString())
+ .mode(SaveMode.Overwrite)
+ .save();
+
+ Dataset<Row> readDf = spark.read()
+ .format("vortex")
+ .option("path", outputPath.toUri().toString())
+ .load();
+
+ Dataset<Row> filtered = readDf.filter(readDf.col("id").gt(1));
+ QueryExecution qe = filtered.queryExecution();
+ SparkPlan plan = qe.executedPlan();
+ String planString = plan.toString();
+ assertTrue(
+ planString.contains("id > 1"),
+ "Expected pushed predicate for id > 1 in the executed plan: " + planString);
+ }
+
+ private static List<Integer> idsOf(Dataset<Row> df) {
+ return df.collectAsList().stream().map(row -> row.getInt(0)).collect(Collectors.toList());
+ }
+
+ @AfterEach
+ public void cleanupTempFiles() throws IOException {
+ if (tempDir != null && Files.exists(tempDir)) {
+ try (Stream<Path> paths = Files.walk(tempDir)) {
+ paths.sorted(Comparator.reverseOrder()).forEach(path -> {
+ try {
+ Files.deleteIfExists(path);
+ } catch (IOException e) {
+ System.err.println("Failed to delete: " + path);
+ }
+ });
+ }
+ }
+ }
+}
diff --git a/vortex-jni/src/expression.rs b/vortex-jni/src/expression.rs
index aaed86efa75..5eea004a2f8 100644
--- a/vortex-jni/src/expression.rs
+++ b/vortex-jni/src/expression.rs
@@ -24,7 +24,7 @@ use jni::sys::jint;
use jni::sys::jlong;
use jni::sys::jshort;
use vortex::dtype::FieldName;
-use vortex::expr::Expression;
+use vortex::expr::{is_not_null, Expression};
use vortex::expr::and_collect;
use vortex::expr::get_item;
use vortex::expr::is_null;
@@ -201,6 +201,16 @@ pub extern "system" fn Java_dev_vortex_jni_NativeExpression_isNull(
into_raw(is_null(child))
}
+#[unsafe(no_mangle)]
+pub extern "system" fn Java_dev_vortex_jni_NativeExpression_isNotNull(
+ _env: EnvUnowned,
+ _class: JClass,
+ child: jlong,
+) -> jlong {
+ let child = unsafe { expr_ref(child) }.clone();
+ into_raw(is_not_null(child))
+}
+
#[unsafe(no_mangle)]
pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalBool(
_env: EnvUnowned,