
Commit a6741e8

feat: CometNativeScan per-partition plan serde (#3511)

1 parent d4537bc

6 files changed (+238 -74 lines)

native/core/src/execution/planner.rs

Lines changed: 27 additions & 18 deletions
@@ -965,20 +965,31 @@ impl PhysicalPlanner {
                 ))
             }
             OpStruct::NativeScan(scan) => {
-                let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
+                // Extract common data and single partition's file list
+                // Per-partition injection happens in Scala before sending to native
+                let common = scan
+                    .common
+                    .as_ref()
+                    .ok_or_else(|| GeneralError("NativeScan missing common data".into()))?;
+
+                let data_schema =
+                    convert_spark_types_to_arrow_schema(common.data_schema.as_slice());
                 let required_schema: SchemaRef =
-                    convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
+                    convert_spark_types_to_arrow_schema(common.required_schema.as_slice());
                 let partition_schema: SchemaRef =
-                    convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
-                let projection_vector: Vec<usize> = scan
+                    convert_spark_types_to_arrow_schema(common.partition_schema.as_slice());
+                let projection_vector: Vec<usize> = common
                     .projection_vector
                     .iter()
                     .map(|offset| *offset as usize)
                     .collect();

-                // Check if this partition has any files (bucketed scan with bucket pruning may have empty partitions)
-                let partition_files = &scan.file_partitions[self.partition as usize];
+                let partition_files = scan
+                    .file_partition
+                    .as_ref()
+                    .ok_or_else(|| GeneralError("NativeScan missing file_partition".into()))?;

+                // Check if this partition has any files (bucketed scan with bucket pruning may have empty partitions)
                 if partition_files.partitioned_file.is_empty() {
                     let empty_exec = Arc::new(EmptyExec::new(required_schema));
                     return Ok((
@@ -988,19 +999,19 @@ impl PhysicalPlanner {
                 }

                 // Convert the Spark expressions to Physical expressions
-                let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
+                let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = common
                     .data_filters
                     .iter()
                     .map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
                     .collect();

-                let default_values: Option<HashMap<usize, ScalarValue>> = if !scan
+                let default_values: Option<HashMap<usize, ScalarValue>> = if !common
                     .default_values
                     .is_empty()
                 {
                     // We have default values. Extract the two lists (same length) of values and
                     // indexes in the schema, and then create a HashMap to use in the SchemaMapper.
-                    let default_values: Result<Vec<ScalarValue>, DataFusionError> = scan
+                    let default_values: Result<Vec<ScalarValue>, DataFusionError> = common
                         .default_values
                         .iter()
                         .map(|expr| {
@@ -1015,7 +1026,7 @@ impl PhysicalPlanner {
                         })
                         .collect();
                     let default_values = default_values?;
-                    let default_values_indexes: Vec<usize> = scan
+                    let default_values_indexes: Vec<usize> = common
                         .default_values_indexes
                         .iter()
                         .map(|offset| *offset as usize)
@@ -1037,7 +1048,7 @@ impl PhysicalPlanner {
                     .map(|f| f.file_path.clone())
                     .expect("partition should have files after empty check");

-                let object_store_options: HashMap<String, String> = scan
+                let object_store_options: HashMap<String, String> = common
                     .object_store_options
                     .iter()
                     .map(|(k, v)| (k.clone(), v.clone()))
@@ -1048,10 +1059,8 @@ impl PhysicalPlanner {
                     &object_store_options,
                 )?;

-                // Comet serializes all partitions' PartitionedFiles, but we only want to read this
-                // Spark partition's PartitionedFiles
-                let files =
-                    self.get_partitioned_files(&scan.file_partitions[self.partition as usize])?;
+                // Get files for this partition
+                let files = self.get_partitioned_files(partition_files)?;
                 let file_groups: Vec<Vec<PartitionedFile>> = vec![files];
                 let partition_fields: Vec<Field> = partition_schema
                     .fields()
@@ -1070,10 +1079,10 @@ impl PhysicalPlanner {
                     Some(projection_vector),
                     Some(data_filters?),
                     default_values,
-                    scan.session_timezone.as_str(),
-                    scan.case_sensitive,
+                    common.session_timezone.as_str(),
+                    common.case_sensitive,
                     self.session_ctx(),
-                    scan.encryption_enabled,
+                    common.encryption_enabled,
                 )?;
                 Ok((
                     vec![],
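
Both `common` and `file_partition` are singular message fields, so they decode to Option values on the Rust side and the planner above fails fast if either is missing. The same precondition can be checked on the JVM side before a plan is shipped to native; a minimal sketch, assuming the standard has* accessors protobuf-java generates for message fields (the validate helper itself is illustrative, not part of this commit):

import org.apache.comet.serde.OperatorOuterClass

// Illustrative mirror of the native-side guards: a NativeScan handed to the
// native planner must carry both the shared common block and exactly one
// partition's file list.
def validate(scan: OperatorOuterClass.NativeScan): Unit = {
  require(scan.hasCommon, "NativeScan missing common data")
  require(scan.hasFilePartition, "NativeScan missing file_partition")
}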

native/proto/src/proto/operator.proto

Lines changed: 22 additions & 23 deletions
@@ -85,30 +85,29 @@ message Scan {
   bool arrow_ffi_safe = 3;
 }

+// Common data shared by all partitions in split mode (sent once at planning)
+message NativeScanCommon {
+  repeated SparkStructField required_schema = 1;
+  repeated SparkStructField data_schema = 2;
+  repeated SparkStructField partition_schema = 3;
+  repeated spark.spark_expression.Expr data_filters = 4;
+  repeated int64 projection_vector = 5;
+  string session_timezone = 6;
+  repeated spark.spark_expression.Expr default_values = 7;
+  repeated int64 default_values_indexes = 8;
+  bool case_sensitive = 9;
+  map<string, string> object_store_options = 10;
+  bool encryption_enabled = 11;
+  string source = 12;
+  repeated spark.spark_expression.DataType fields = 13;
+}
+
 message NativeScan {
-  repeated spark.spark_expression.DataType fields = 1;
-  // The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This
-  // is purely for informational purposes when viewing native query plans in
-  // debug mode.
-  string source = 2;
-  repeated SparkStructField required_schema = 3;
-  repeated SparkStructField data_schema = 4;
-  repeated SparkStructField partition_schema = 5;
-  repeated spark.spark_expression.Expr data_filters = 6;
-  repeated SparkFilePartition file_partitions = 7;
-  repeated int64 projection_vector = 8;
-  string session_timezone = 9;
-  repeated spark.spark_expression.Expr default_values = 10;
-  repeated int64 default_values_indexes = 11;
-  bool case_sensitive = 12;
-  // Options for configuring object stores such as AWS S3, GCS, etc. The key-value pairs are taken
-  // from Hadoop configuration for compatibility with Hadoop FileSystem implementations of object
-  // stores.
-  // The configuration values have hadoop. or spark.hadoop. prefix trimmed. For instance, the
-  // configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in
-  // the map.
-  map<string, string> object_store_options = 13;
-  bool encryption_enabled = 14;
+  // Common data shared across partitions (schemas, filters, projections, config)
+  NativeScanCommon common = 1;

+  // Single partition's file list (injected at execution time)
+  SparkFilePartition file_partition = 2;
 }

 message CsvScan {
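
To make the restructured messages concrete: the common block is serialized once at planning time, and each task's plan pairs that block with only its own files. A minimal, hypothetical sketch under standard protobuf-java codegen (the field values and the use of getDefaultInstance are placeholders, not what Comet actually sends):

import org.apache.comet.serde.OperatorOuterClass

// Planning time: build the shared block once per scan.
val common = OperatorOuterClass.NativeScanCommon
  .newBuilder()
  .setSource("scan parquet (illustrative)")
  .setSessionTimezone("UTC")
  .setCaseSensitive(false)
  .build()

// Execution time: each task's NativeScan carries the same common block plus
// a single SparkFilePartition, instead of every partition's file list.
val planForOnePartition = OperatorOuterClass.NativeScan
  .newBuilder()
  .setCommon(common)
  .setFilePartition(OperatorOuterClass.SparkFilePartition.getDefaultInstance)
  .build()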

spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala

Lines changed: 22 additions & 21 deletions
@@ -97,14 +97,17 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {
       builder: Operator.Builder,
       childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
     val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder()
-    nativeScanBuilder.setSource(scan.simpleStringWithNodeId())
+    val commonBuilder = OperatorOuterClass.NativeScanCommon.newBuilder()
+
+    // Set source in common (used as part of injection key)
+    commonBuilder.setSource(scan.simpleStringWithNodeId())

     val scanTypes = scan.output.flatten { attr =>
       serializeDataType(attr.dataType)
     }

     if (scanTypes.length == scan.output.length) {
-      nativeScanBuilder.addAllFields(scanTypes.asJava)
+      commonBuilder.addAllFields(scanTypes.asJava)

       // Sink operators don't have children
       builder.clearChildren()
@@ -120,7 +123,7 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {
           logWarning(s"Unsupported data filter $filter")
         }
       }
-      nativeScanBuilder.addAllDataFilters(dataFilters.asJava)
+      commonBuilder.addAllDataFilters(dataFilters.asJava)
     }

     val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema)
@@ -136,20 +139,15 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {
           (Literal(expr), index.toLong.asInstanceOf[java.lang.Long])
         }
         .unzip
-      nativeScanBuilder.addAllDefaultValues(
+      commonBuilder.addAllDefaultValues(
         defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava)
-      nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava)
+      commonBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava)
     }

+    // Extract object store options from first file (S3 configs apply to all files in scan)
     var firstPartition: Option[PartitionedFile] = None
     val filePartitions = scan.getFilePartitions()
-    val filePartitionsProto = filePartitions.map { partition =>
-      if (firstPartition.isEmpty) {
-        firstPartition = partition.files.headOption
-      }
-      partition2Proto(partition, scan.relation.partitionSchema)
-    }
-    nativeScanBuilder.addAllFilePartitions(filePartitionsProto.asJava)
+    firstPartition = filePartitions.flatMap(_.files.headOption).headOption

     val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
     val requiredSchema = schema2Proto(scan.requiredSchema.fields)
@@ -166,31 +164,34 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {
     val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx =>
       idx.toLong.asInstanceOf[java.lang.Long])

-    nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava)
+    commonBuilder.addAllProjectionVector(projectionVector.toIterable.asJava)

     // In `CometScanRule`, we ensure partitionSchema is supported.
     assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)

-    nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
-    nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
-    nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
-    nativeScanBuilder.setSessionTimezone(scan.conf.getConfString("spark.sql.session.timeZone"))
-    nativeScanBuilder.setCaseSensitive(scan.conf.getConf[Boolean](SQLConf.CASE_SENSITIVE))
+    commonBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
+    commonBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
+    commonBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
+    commonBuilder.setSessionTimezone(scan.conf.getConfString("spark.sql.session.timeZone"))
+    commonBuilder.setCaseSensitive(scan.conf.getConf[Boolean](SQLConf.CASE_SENSITIVE))

     // Collect S3/cloud storage configurations
     val hadoopConf = scan.relation.sparkSession.sessionState
       .newHadoopConfWithOptions(scan.relation.options)

-    nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf))
+    commonBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf))

     firstPartition.foreach { partitionFile =>
       val objectStoreOptions =
         NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri)
       objectStoreOptions.foreach { case (key, value) =>
-        nativeScanBuilder.putObjectStoreOptions(key, value)
+        commonBuilder.putObjectStoreOptions(key, value)
       }
     }

+    // Set common data in NativeScan (file_partition will be populated at execution time)
+    nativeScanBuilder.setCommon(commonBuilder.build())
+
     Some(builder.setNativeScan(nativeScanBuilder).build())

   } else {
@@ -204,6 +205,6 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {
   }

   override def createExec(nativeOp: Operator, op: CometScanExec): CometNativeExec = {
-    CometNativeScanExec(nativeOp, op.wrapped, op.session)
+    CometNativeScanExec(nativeOp, op.wrapped, op.session, op)
   }
 }
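
The execution-time half of the change, injecting one partition's file list into the planned operator before it is sent to native, is not part of this excerpt. A hedged sketch of what that injection can look like, reusing partition2Proto from this file; the helper name and signature are illustrative, not the commit's actual code:

import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.types.StructType

// Hypothetical helper: clone the planned operator and splice in one
// partition's file list (NativeScan field 2) before serializing for native.
def injectFilePartition(
    basePlan: OperatorOuterClass.Operator,
    partition: FilePartition,
    partitionSchema: StructType): OperatorOuterClass.Operator = {
  val scanBuilder = basePlan.getNativeScan.toBuilder
  scanBuilder.setFilePartition(partition2Proto(partition, partitionSchema))
  basePlan.toBuilder.setNativeScan(scanBuilder).build()
}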
