Skip to content

Commit aa82116

Browse files
authored
Refactor serde for Alias and AttributeReference (apache#2290)
1 parent ab8a7b2 commit aa82116

File tree

2 files changed

+92
-52
lines changed

2 files changed

+92
-52
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
7575
* Mapping of Spark expression class to Comet expression handler.
7676
*/
7777
private val exprSerdeMap: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
78+
classOf[AttributeReference] -> CometAttributeReference,
79+
classOf[Alias] -> CometAlias,
7880
classOf[Add] -> CometAdd,
7981
classOf[Subtract] -> CometSubtract,
8082
classOf[Multiply] -> CometMultiply,
@@ -626,13 +628,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
626628
}
627629

628630
versionSpecificExprToProtoInternal(expr, inputs, binding).orElse(expr match {
629-
case a @ Alias(_, _) =>
630-
val r = exprToProtoInternal(a.child, inputs, binding)
631-
if (r.isEmpty) {
632-
withInfo(expr, a.child)
633-
}
634-
r
635-
636631
case cast @ Cast(_: Literal, dataType, _, _) =>
637632
// This can happen after promoting decimal precisions
638633
val value = cast.eval()
@@ -878,51 +873,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
878873
None
879874
}
880875

881-
case attr: AttributeReference =>
882-
val dataType = serializeDataType(attr.dataType)
883-
884-
if (dataType.isDefined) {
885-
if (binding) {
886-
// Spark may produce unresolvable attributes in some cases,
887-
// for example https://github.com/apache/datafusion-comet/issues/925.
888-
// So, we allow the binding to fail.
889-
val boundRef: Any = BindReferences
890-
.bindReference(attr, inputs, allowFailures = true)
891-
892-
if (boundRef.isInstanceOf[AttributeReference]) {
893-
withInfo(attr, s"cannot resolve $attr among ${inputs.mkString(", ")}")
894-
return None
895-
}
896-
897-
val boundExpr = ExprOuterClass.BoundReference
898-
.newBuilder()
899-
.setIndex(boundRef.asInstanceOf[BoundReference].ordinal)
900-
.setDatatype(dataType.get)
901-
.build()
902-
903-
Some(
904-
ExprOuterClass.Expr
905-
.newBuilder()
906-
.setBound(boundExpr)
907-
.build())
908-
} else {
909-
val unboundRef = ExprOuterClass.UnboundReference
910-
.newBuilder()
911-
.setName(attr.name)
912-
.setDatatype(dataType.get)
913-
.build()
914-
915-
Some(
916-
ExprOuterClass.Expr
917-
.newBuilder()
918-
.setUnbound(unboundRef)
919-
.build())
920-
}
921-
} else {
922-
withInfo(attr, s"unsupported datatype: ${attr.dataType}")
923-
None
924-
}
925-
926876
// abs implementation is not correct
927877
// https://github.com/apache/datafusion-comet/issues/666
928878
// case Abs(child, failOnErr) =>
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde
21+
22+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BindReferences, BoundReference}
23+
24+
import org.apache.comet.CometSparkSessionExtensions.withInfo
25+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType}
26+
27+
/**
 * Serde handler for [[Alias]] expressions.
 *
 * An alias is serialized as its child expression; nothing alias-specific is written to the
 * protobuf message.
 */
object CometAlias extends CometExpressionSerde[Alias] {

  /**
   * Converts the aliased child expression to its protobuf form.
   *
   * @param a
   *   the alias whose child is converted
   * @param inputs
   *   attributes passed through unchanged to the child conversion
   * @param binding
   *   passed through unchanged to the child conversion
   * @return
   *   the converted child, or `None` when the child could not be converted
   */
  override def convert(
      a: Alias,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] =
    exprToProtoInternal(a.child, inputs, binding) match {
      case converted @ Some(_) =>
        converted
      case None =>
        // The child conversion failed: tag the alias via withInfo, passing the child along
        // (presumably so the child's fallback reasons are attached to the alias -- confirm
        // against withInfo's definition).
        withInfo(a, a.child)
        None
    }
}
39+
40+
/**
 * Serde handler for [[AttributeReference]] expressions.
 *
 * Depending on `binding`, the attribute is serialized either as a bound reference (an ordinal
 * obtained by binding against `inputs`) or as an unbound reference (the attribute's name).
 */
object CometAttributeReference extends CometExpressionSerde[AttributeReference] {

  /**
   * Converts an attribute reference to its protobuf form.
   *
   * @param attr
   *   the attribute to serialize
   * @param inputs
   *   the attributes against which `attr` is bound when `binding` is true
   * @param binding
   *   when true, emit a bound reference with the ordinal of `attr` in `inputs`; when false,
   *   emit an unbound reference carrying the attribute name
   * @return
   *   the serialized expression, or `None` when the data type cannot be serialized or the
   *   attribute cannot be bound
   */
  override def convert(
      attr: AttributeReference,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] =
    serializeDataType(attr.dataType) match {
      case None =>
        withInfo(attr, s"unsupported datatype: ${attr.dataType}")
        None

      case Some(protoType) if binding =>
        // Spark may produce unresolvable attributes in some cases,
        // for example https://github.com/apache/datafusion-comet/issues/925.
        // So, we allow the binding to fail.
        //
        // The result is ascribed as Any so that it can be matched against BoundReference
        // without the compiler assuming it is still an AttributeReference.
        val boundRef: Any = BindReferences
          .bindReference(attr, inputs, allowFailures = true)

        boundRef match {
          case bound: BoundReference =>
            val boundExpr = ExprOuterClass.BoundReference
              .newBuilder()
              .setIndex(bound.ordinal)
              .setDatatype(protoType)
              .build()

            Some(
              ExprOuterClass.Expr
                .newBuilder()
                .setBound(boundExpr)
                .build())

          case _ =>
            // Binding did not produce a BoundReference, so the attribute is treated as
            // unresolved rather than being cast unconditionally.
            withInfo(attr, s"cannot resolve $attr among ${inputs.mkString(", ")}")
            None
        }

      case Some(protoType) =>
        val unboundRef = ExprOuterClass.UnboundReference
          .newBuilder()
          .setName(attr.name)
          .setDatatype(protoType)
          .build()

        Some(
          ExprOuterClass.Expr
            .newBuilder()
            .setUnbound(unboundRef)
            .build())
    }
}

0 commit comments

Comments
 (0)