Skip to content

Commit c00bedb

Browse files
authored
chore: Refactor Cast serde to avoid code duplication (apache#2242)
1 parent 636ce22 commit c00bedb

File tree

17 files changed

+156
-188
lines changed

17 files changed

+156
-188
lines changed

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -624,14 +624,6 @@ object CometConf extends ShimCometConf {
624624
.booleanConf
625625
.createWithDefault(false)
626626

627-
val COMET_CAST_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] =
628-
conf("spark.comet.cast.allowIncompatible")
629-
.doc(
630-
"Comet is not currently fully compatible with Spark for all cast operations. " +
631-
s"Set this config to true to allow them anyway. $COMPAT_GUIDE.")
632-
.booleanConf
633-
.createWithDefault(false)
634-
635627
val COMET_REGEXP_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] =
636628
conf("spark.comet.regexp.allowIncompatible")
637629
.doc(

docs/source/contributor-guide/benchmarking_aws_ec2.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ $SPARK_HOME/bin/spark-submit \
208208
--conf spark.plugins=org.apache.spark.CometPlugin \
209209
--conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
210210
--conf spark.comet.enabled=true \
211-
--conf spark.comet.cast.allowIncompatible=true \
211+
--conf spark.comet.expression.allowIncompatible=true \
212212
--conf spark.comet.exec.replaceSortMergeJoin=true \
213213
--conf spark.comet.exec.shuffle.enabled=true \
214214
--conf spark.comet.exec.shuffle.fallbackToColumnar=true \

docs/source/contributor-guide/benchmarking_macos.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ $SPARK_HOME/bin/spark-submit \
144144
--conf spark.comet.exec.shuffle.enableFastEncoding=true \
145145
--conf spark.comet.exec.shuffle.fallbackToColumnar=true \
146146
--conf spark.comet.exec.replaceSortMergeJoin=true \
147-
--conf spark.comet.cast.allowIncompatible=true \
147+
--conf spark.comet.expression.allowIncompatible=true \
148148
/path/to/datafusion-benchmarks/runners/datafusion-comet/tpcbench.py \
149149
--benchmark tpch \
150150
--data /path/to/tpch-data/ \

docs/source/user-guide/compatibility.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ The `native_datafusion` scan has some additional limitations:
7575

7676
### S3 Support with `native_iceberg_compat`
7777

78-
- When using the default AWS S3 endpoint (no custom endpoint configured), a valid region is required. Comet
78+
- When using the default AWS S3 endpoint (no custom endpoint configured), a valid region is required. Comet
7979
will attempt to resolve the region if it is not provided.
8080

8181
## ANSI Mode
@@ -130,7 +130,7 @@ Cast operations in Comet fall into three levels of support:
130130
- **Compatible**: The results match Apache Spark
131131
- **Incompatible**: The results may match Apache Spark for some inputs, but there are known issues where some inputs
132132
will result in incorrect results or exceptions. The query stage will fall back to Spark by default. Setting
133-
`spark.comet.cast.allowIncompatible=true` will allow all incompatible casts to run natively in Comet, but this is not
133+
`spark.comet.expression.allowIncompatible=true` will allow all incompatible casts to run natively in Comet, but this is not
134134
recommended for production use.
135135
- **Unsupported**: Comet does not provide a native version of this cast expression and the query stage will fall back to
136136
Spark.

docs/source/user-guide/configs.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ Comet provides the following configuration settings.
2828
|--------|-------------|---------------|
2929
| spark.comet.batchSize | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 |
3030
| spark.comet.caseConversion.enabled | Java uses locale-specific rules when converting strings to upper or lower case and Rust does not, so we disable upper and lower by default. | false |
31-
| spark.comet.cast.allowIncompatible | Comet is not currently fully compatible with Spark for all cast operations. Set this config to true to allow them anyway. For more information, refer to the Comet Compatibility Guide (https://datafusion.apache.org/comet/user-guide/compatibility.html). | false |
3231
| spark.comet.columnar.shuffle.async.enabled | Whether to enable asynchronous shuffle for Arrow-based shuffle. | false |
3332
| spark.comet.columnar.shuffle.async.max.thread.num | Maximum number of threads on an executor used for Comet async columnar shuffle. This is the upper bound of total number of shuffle threads per executor. In other words, if the number of cores * the number of shuffle threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than this config. Comet will use this config as the number of shuffle threads per executor instead. | 100 |
3433
| spark.comet.columnar.shuffle.async.thread.num | Number of threads used for Comet async columnar shuffle per shuffle task. Note that more threads means more memory requirement to buffer shuffle data before flushing to disk. Also, more threads may not always improve performance, and should be set based on the number of cores available. | 3 |

docs/source/user-guide/expressions.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ The following Spark expressions are currently available. Any known compatibility
3535

3636
## Binary Arithmetic
3737

38-
| Expression | Notes |
39-
|------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
40-
| Add (`+`) | |
41-
| Subtract (`-`) | |
42-
| Multiply (`*`) | |
43-
| Divide (`/`) | |
44-
| IntegralDivide (`div`) | All operands are cast to DecimalType (in case the input type is not already decima type) with precision 19 and scale 0. Please set `spark.comet.cast.allowIncompatible` to `true` to enable DataFusion’s cast operation for LongType inputs. |
45-
| Remainder (`%`) | |
38+
| Expression | Notes |
39+
|------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
40+
| Add (`+`) | |
41+
| Subtract (`-`) | |
42+
| Multiply (`*`) | |
43+
| Divide (`/`) | |
44+
| IntegralDivide (`div`) | All operands are cast to DecimalType (in case the input type is not already decimal type) with precision 19 and scale 0. Please set `spark.comet.expression.allowIncompatible` to `true` to enable DataFusion’s cast operation for LongType inputs. |
45+
| Remainder (`%`) | |
4646

4747
## Binary Try Arithmetic
4848

docs/source/user-guide/kubernetes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ spec:
7676
"spark.plugins": "org.apache.spark.CometPlugin"
7777
"spark.comet.enabled": "true"
7878
"spark.comet.exec.enabled": "true"
79-
"spark.comet.cast.allowIncompatible": "true"
79+
"spark.comet.expression.allowIncompatible": "true"
8080
"spark.comet.exec.shuffle.enabled": "true"
8181
"spark.comet.exec.shuffle.mode": "auto"
8282
"spark.shuffle.manager": "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager"

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,17 @@
1919

2020
package org.apache.comet.expressions
2121

22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression}
2223
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}
2324

24-
import org.apache.comet.serde.{Compatible, Incompatible, SupportLevel, Unsupported}
25+
import org.apache.comet.CometConf
26+
import org.apache.comet.CometSparkSessionExtensions.withInfo
27+
import org.apache.comet.serde.{CometExpressionSerde, Compatible, ExprOuterClass, Incompatible, SupportLevel, Unsupported}
28+
import org.apache.comet.serde.ExprOuterClass.Expr
29+
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProtoInternal, serializeDataType}
30+
import org.apache.comet.shims.CometExprShim
2531

26-
object CometCast {
32+
object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
2733

2834
def supportedTypes: Seq[DataType] =
2935
Seq(
@@ -42,6 +48,51 @@ object CometCast {
4248
// TODO add DataTypes.TimestampNTZType for Spark 3.4 and later
4349
// https://github.com/apache/datafusion-comet/issues/378
4450

51+
override def getSupportLevel(cast: Cast): SupportLevel = {
52+
isSupported(cast.child.dataType, cast.dataType, cast.timeZoneId, evalMode(cast))
53+
}
54+
55+
override def convert(
56+
cast: Cast,
57+
inputs: Seq[Attribute],
58+
binding: Boolean): Option[ExprOuterClass.Expr] = {
59+
val childExpr = exprToProtoInternal(cast.child, inputs, binding)
60+
if (childExpr.isDefined) {
61+
castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get, evalMode(cast))
62+
} else {
63+
withInfo(cast, cast.child)
64+
None
65+
}
66+
}
67+
68+
/**
69+
* Wrap an already serialized expression in a cast.
70+
*/
71+
def castToProto(
72+
expr: Expression,
73+
timeZoneId: Option[String],
74+
dt: DataType,
75+
childExpr: Expr,
76+
evalMode: CometEvalMode.Value): Option[Expr] = {
77+
serializeDataType(dt) match {
78+
case Some(dataType) =>
79+
val castBuilder = ExprOuterClass.Cast.newBuilder()
80+
castBuilder.setChild(childExpr)
81+
castBuilder.setDatatype(dataType)
82+
castBuilder.setEvalMode(evalModeToProto(evalMode))
83+
castBuilder.setAllowIncompat(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get())
84+
castBuilder.setTimezone(timeZoneId.getOrElse("UTC"))
85+
Some(
86+
ExprOuterClass.Expr
87+
.newBuilder()
88+
.setCast(castBuilder)
89+
.build())
90+
case _ =>
91+
withInfo(expr, s"Unsupported datatype in castToProto: $dt")
92+
None
93+
}
94+
}
95+
4596
def isSupported(
4697
fromType: DataType,
4798
toType: DataType,
@@ -62,7 +113,7 @@ object CometCast {
62113
case DataTypes.TimestampType | DataTypes.DateType | DataTypes.StringType =>
63114
Incompatible()
64115
case _ =>
65-
Unsupported
116+
unsupported(fromType, toType)
66117
}
67118
case (_: DecimalType, _: DecimalType) =>
68119
Compatible()
@@ -98,7 +149,7 @@ object CometCast {
98149
}
99150
}
100151
Compatible()
101-
case _ => Unsupported
152+
case _ => unsupported(fromType, toType)
102153
}
103154
}
104155

@@ -136,7 +187,7 @@ object CometCast {
136187
// https://github.com/apache/datafusion-comet/issues/328
137188
Incompatible(Some("Not all valid formats are supported"))
138189
case _ =>
139-
Unsupported
190+
unsupported(DataTypes.StringType, toType)
140191
}
141192
}
142193

@@ -171,13 +222,13 @@ object CometCast {
171222
isSupported(field.dataType, DataTypes.StringType, timeZoneId, evalMode) match {
172223
case s: Incompatible =>
173224
return s
174-
case Unsupported =>
175-
return Unsupported
225+
case u: Unsupported =>
226+
return u
176227
case _ =>
177228
}
178229
}
179230
Compatible()
180-
case _ => Unsupported
231+
case _ => unsupported(fromType, DataTypes.StringType)
181232
}
182233
}
183234

@@ -187,21 +238,21 @@ object CometCast {
187238
DataTypes.IntegerType =>
188239
// https://github.com/apache/datafusion-comet/issues/352
189240
// this seems like an edge case that isn't important for us to support
190-
Unsupported
241+
unsupported(DataTypes.TimestampType, toType)
191242
case DataTypes.LongType =>
192243
// https://github.com/apache/datafusion-comet/issues/352
193244
Compatible()
194245
case DataTypes.StringType => Compatible()
195246
case DataTypes.DateType => Compatible()
196-
case _ => Unsupported
247+
case _ => unsupported(DataTypes.TimestampType, toType)
197248
}
198249
}
199250

200251
private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
201252
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
202253
DataTypes.FloatType | DataTypes.DoubleType =>
203254
Compatible()
204-
case _ => Unsupported
255+
case _ => unsupported(DataTypes.BooleanType, toType)
205256
}
206257

207258
private def canCastFromByte(toType: DataType): SupportLevel = toType match {
@@ -212,7 +263,7 @@ object CometCast {
212263
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
213264
Compatible()
214265
case _ =>
215-
Unsupported
266+
unsupported(DataTypes.ByteType, toType)
216267
}
217268

218269
private def canCastFromShort(toType: DataType): SupportLevel = toType match {
@@ -223,7 +274,7 @@ object CometCast {
223274
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
224275
Compatible()
225276
case _ =>
226-
Unsupported
277+
unsupported(DataTypes.ShortType, toType)
227278
}
228279

229280
private def canCastFromInt(toType: DataType): SupportLevel = toType match {
@@ -236,7 +287,7 @@ object CometCast {
236287
case _: DecimalType =>
237288
Incompatible(Some("No overflow check"))
238289
case _ =>
239-
Unsupported
290+
unsupported(DataTypes.IntegerType, toType)
240291
}
241292

242293
private def canCastFromLong(toType: DataType): SupportLevel = toType match {
@@ -249,7 +300,7 @@ object CometCast {
249300
case _: DecimalType =>
250301
Incompatible(Some("No overflow check"))
251302
case _ =>
252-
Unsupported
303+
unsupported(DataTypes.LongType, toType)
253304
}
254305

255306
private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
@@ -259,7 +310,8 @@ object CometCast {
259310
case _: DecimalType =>
260311
// https://github.com/apache/datafusion-comet/issues/1371
261312
Incompatible(Some("There can be rounding differences"))
262-
case _ => Unsupported
313+
case _ =>
314+
unsupported(DataTypes.FloatType, toType)
263315
}
264316

265317
private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
@@ -269,14 +321,17 @@ object CometCast {
269321
case _: DecimalType =>
270322
// https://github.com/apache/datafusion-comet/issues/1371
271323
Incompatible(Some("There can be rounding differences"))
272-
case _ => Unsupported
324+
case _ => unsupported(DataTypes.DoubleType, toType)
273325
}
274326

275327
private def canCastFromDecimal(toType: DataType): SupportLevel = toType match {
276328
case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
277329
DataTypes.IntegerType | DataTypes.LongType =>
278330
Compatible()
279-
case _ => Unsupported
331+
case _ => Unsupported(Some(s"Cast from DecimalType to $toType is not supported"))
280332
}
281333

334+
private def unsupported(fromType: DataType, toType: DataType): Unsupported = {
335+
Unsupported(Some(s"Cast from $fromType to $toType is not supported"))
336+
}
282337
}

0 commit comments

Comments
 (0)