Skip to content

Commit 8aeb23f

Browse files
committed
Add subexpression elimination to higher order functions
1 parent 9b94470 commit 8aeb23f

File tree

5 files changed

+109
-24
lines changed

5 files changed

+109
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,13 +1083,13 @@ class CodegenContext extends Logging {
10831083

10841084
/**
10851085
* Perform a function which generates a sequence of ExprCodes with a given mapping between
1086-
* expressions and common expressions, instead of using the mapping in current context.
1086+
* expressions and common expressions. Restores previous mapping after execution.
10871087
*/
10881088
def withSubExprEliminationExprs(
10891089
newSubExprEliminationExprs: Map[ExpressionEquals, SubExprEliminationState])(
10901090
f: => Seq[ExprCode]): Seq[ExprCode] = {
10911091
val oldsubExprEliminationExprs = subExprEliminationExprs
1092-
subExprEliminationExprs = newSubExprEliminationExprs
1092+
subExprEliminationExprs = oldsubExprEliminationExprs ++ newSubExprEliminationExprs
10931093

10941094
val genCodes = f
10951095

@@ -1150,7 +1150,9 @@ class CodegenContext extends Logging {
11501150
* (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression
11511151
* evaluation, we can look for generated subexpressions and do replacement.
11521152
*/
1153-
def subexpressionElimination(expressions: Seq[Expression]): SubExprCodes = {
1153+
def subexpressionElimination(
1154+
expressions: Seq[Expression],
1155+
variablePrefix: String = ""): SubExprCodes = {
11541156
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
11551157
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
11561158
val localSubExprEliminationExprs =
@@ -1161,22 +1163,28 @@ class CodegenContext extends Logging {
11611163

11621164
// Get all the expressions that appear at least twice and set up the state for subexpression
11631165
// elimination.
1166+
//
1167+
// Filter out any expressions that are already existing subexpressions. This can happen
1168+
// when finding common subexpressions inside a lambda function, and the common expression
1169+
// does not reference the lambda variables for that function, but top level attributes or
1170+
// outer lambda variables.
11641171
val commonExprs = equivalentExpressions.getCommonSubexpressions
1172+
.filter(e => !subExprEliminationExprs.contains(ExpressionEquals(e)))
11651173

11661174
val nonSplitCode = {
11671175
val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
11681176
commonExprs.map { expr =>
11691177
withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
11701178
val eval = expr.genCode(this)
11711179

1172-
val value = addMutableState(javaType(expr.dataType), "subExprValue")
1180+
val value = addMutableState(javaType(expr.dataType), s"${variablePrefix}subExprValue")
11731181

11741182
val isNullLiteral = eval.isNull match {
11751183
case TrueLiteral | FalseLiteral => true
11761184
case _ => false
11771185
}
11781186
val (isNull, isNullEvalCode) = if (!isNullLiteral) {
1179-
val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
1187+
val v = addMutableState(JAVA_BOOLEAN, s"${variablePrefix}subExprIsNull")
11801188
(JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};")
11811189
} else {
11821190
(eval.isNull, "")
@@ -1191,7 +1199,7 @@ class CodegenContext extends Logging {
11911199
// Collects other subexpressions from the children.
11921200
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
11931201
expr.foreach { e =>
1194-
subExprEliminationExprs.get(ExpressionEquals(e)) match {
1202+
localSubExprEliminationExprs.get(ExpressionEquals(e)) match {
11951203
case Some(state) => childrenSubExprs += state
11961204
case _ =>
11971205
}
@@ -1282,7 +1290,7 @@ class CodegenContext extends Logging {
12821290
if (doSubexpressionElimination) {
12831291
val subExprs = subexpressionElimination(cleanedExpressions)
12841292
val generatedExprs = withSubExprEliminationExprs(subExprs.states) {
1285-
cleanedExpressions.map(e => e.genCode(this))
1293+
cleanedExpressions.map(e => e.genCode(this))
12861294
}
12871295
val subExprCode = evaluateSubExprEliminationState(subExprs.states.values)
12881296
(generatedExprs, subExprCode)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,18 @@ case class LambdaFunction(
137137
override def eval(input: InternalRow): Any = function.eval(input)
138138

139139
/**
 * Generates code for the lambda body, eliminating subexpressions repeated within it.
 *
 * Common subexpressions of `function` are collected first, the body is then generated with
 * those subexpressions registered in the context, and the code that evaluates them is
 * emitted ahead of the body's own code.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  // Collect repeated expressions inside the lambda body, passing "lambda_" as variablePrefix.
  val eliminated = ctx.subexpressionElimination(Seq(function), "lambda_")
  val commonStates = eliminated.states

  // Generate the body while the lambda-local common subexpressions are visible.
  val bodyCode = ctx.withSubExprEliminationExprs(commonStates) {
    Seq(function.genCode(ctx))
  }.head

  // The common subexpressions must be evaluated before the body that references them.
  val commonExprEval = ctx.evaluateSubExprEliminationState(commonStates.values)
  val combinedCode = code"""
    |// lambda common sub-expressions
    |$commonExprEval
    |${bodyCode.code}
  """
  bodyCode.copy(code = combinedCode)
}
142153

143154
override protected def withNewChildrenInternal(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -473,24 +473,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
473473
JavaCode.variable("dummy", BooleanType)))
474474

475475
// raw testing of basic functionality
476-
{
477-
val ctx = new CodegenContext
478-
val e = ref.genCode(ctx)
479-
// before
480-
ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState(
481-
ExprCode(EmptyBlock, e.isNull, e.value))
482-
assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
483-
// call withSubExprEliminationExprs
484-
ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) {
485-
assert(ctx.subExprEliminationExprs.contains(wrap(add1)))
486-
assert(!ctx.subExprEliminationExprs.contains(wrap(ref)))
487-
Seq.empty
488-
}
489-
// after
490-
assert(ctx.subExprEliminationExprs.nonEmpty)
476+
val ctx = new CodegenContext
477+
val e = ref.genCode(ctx)
478+
// before
479+
ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState(
480+
ExprCode(EmptyBlock, e.isNull, e.value))
481+
assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
482+
// call withSubExprEliminationExprs, should now contain both
483+
ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) {
484+
assert(ctx.subExprEliminationExprs.contains(wrap(add1)))
491485
assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
492-
assert(!ctx.subExprEliminationExprs.contains(wrap(add1)))
486+
Seq.empty
493487
}
488+
// after, should only contain the original
489+
assert(ctx.subExprEliminationExprs.nonEmpty)
490+
assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
491+
assert(!ctx.subExprEliminationExprs.contains(wrap(add1)))
494492
}
495493

496494
test("SPARK-23986: freshName can generate duplicated names") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
151151

152152
val plusOne: Expression => Expression = x => x + 1
153153
val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i
154+
val plusIndexRepeated: (Expression, Expression) => Expression =
155+
(x, i) => plusIndex(x, i) * plusIndex(x, i)
154156
val plusOneFallback: Expression => Expression = x => CodegenFallbackExpr(x + 1)
155157

156158
checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4))
157159
checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5))
160+
checkEvaluation(transform(ai0, plusIndexRepeated), Seq(1, 9, 25))
158161
checkEvaluation(transform(transform(ai0, plusIndex), plusOne), Seq(2, 4, 6))
159162
checkEvaluation(transform(ai1, plusOne), Seq(2, null, 4))
160163
checkEvaluation(transform(ai1, plusIndex), Seq(1, null, 5))
@@ -282,11 +285,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
282285
val isEven: Expression => Expression = x => x % 2 === 0
283286
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
284287
val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 }
288+
val plusIndexRepeatedEven: (Expression, Expression) => Expression =
289+
(x, i) => ((x + i) * (x + i)) % 2 === 0
285290
val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0)
286291

287292
checkEvaluation(filter(ai0, isEven), Seq(2))
288293
checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3))
289294
checkEvaluation(filter(ai0, indexIsEven), Seq(1, 3))
295+
checkEvaluation(filter(ai0, plusIndexRepeatedEven), Seq.empty)
290296
checkEvaluation(filter(ai1, isEven), Seq.empty)
291297
checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3))
292298
checkEvaluation(filter(ain, isEven), null)
@@ -329,6 +335,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
329335
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
330336
val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral
331337
val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType)
338+
val squareRepeatedEven: Expression => Expression =
339+
x => ((x * x) + (x * x)) % 2 === 0
332340
val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0)
333341

334342
for (followThreeValuedLogic <- Seq(false, true)) {
@@ -338,6 +346,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
338346
checkEvaluation(exists(ai0, isNullOrOdd), true)
339347
checkEvaluation(exists(ai0, alwaysFalse), false)
340348
checkEvaluation(exists(ai0, alwaysNull), if (followThreeValuedLogic) null else false)
349+
checkEvaluation(exists(ai0, squareRepeatedEven), true)
341350
checkEvaluation(exists(ai1, isEven), if (followThreeValuedLogic) null else false)
342351
checkEvaluation(exists(ai1, isNullOrOdd), true)
343352
checkEvaluation(exists(ai1, alwaysFalse), false)
@@ -393,12 +402,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
393402
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
394403
val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral
395404
val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType)
405+
val squareRepeatedEven: Expression => Expression =
406+
x => ((x * x) + (x * x)) % 2 === 0
396407
val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0)
397408

398409
checkEvaluation(forall(ai0, isEven), true)
399410
checkEvaluation(forall(ai0, isNullOrOdd), false)
400411
checkEvaluation(forall(ai0, alwaysFalse), false)
401412
checkEvaluation(forall(ai0, alwaysNull), null)
413+
checkEvaluation(forall(ai0, squareRepeatedEven), true)
402414
checkEvaluation(forall(ai1, isEven), false)
403415
checkEvaluation(forall(ai1, isNullOrOdd), true)
404416
checkEvaluation(forall(ai1, alwaysFalse), false)
@@ -441,6 +453,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
441453
checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40)
442454
checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0)
443455
checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null)
456+
checkEvaluation(aggregate(
457+
ai0,
458+
1,
459+
(acc, elem) => (acc * elem) + (acc * elem),
460+
acc => (acc * acc) + (acc * acc)
461+
), 4608)
444462

445463
val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false))
446464
val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3790,6 +3790,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
37903790
testArrayOfPrimitiveTypeContainsNull()
37913791
}
37923792

3793+
// Verifies that a UDF call repeated inside a `transform` lambda is evaluated once per
// array element rather than once per occurrence in the lambda body.
test("transform function - subexpression elimination") {
  val df = Seq[Seq[Integer]](
    Seq(1, 2, 3, 4, 5)
  ).toDF("i")

  // The accumulator reference is never reassigned, so it is declared as a val.
  val count = spark.sparkContext.longAccumulator
  // Identity UDF that records every invocation in the accumulator.
  val func = udf((x: Integer) => {
    count.add(1)
    x
  })

  val result = df.select(
    transform(col("i"), x => func(x) + func(x))
  )

  // Run it once to verify the count of UDF calls: five elements, one call each.
  result.collect()
  assert(count.value == 5)

  checkAnswer(result, Seq(Row(Seq(2, 4, 6, 8, 10))))
}
3814+
3815+
// Checks a `transform` lambda that repeats a reference to a top-level nested field
// (`outer.inner.a`) while sibling fields of the same struct are selected outside the lambda.
test("transform function - subexpression elimination inside and outside lambda") {
  // A single JSON record with a nested struct and an integer array.
  val jsonRecord =
    """
{
"outer": {
"inner": {
"a": 1,
"b": 2,
"c": 3
}
},
"arr": [
1,
2,
3
]
}
"""
  val input = spark.read.json(Seq(jsonRecord).toDS())

  val projected = input.select(
    col("outer.inner.b"),
    col("outer.inner.c"),
    transform(col("arr"), x => x + col("outer.inner.a") + col("outer.inner.a"))
  )

  // Each array element has `a` (= 1) added twice.
  checkAnswer(projected, Seq(Row(2, 3, Seq(3, 4, 5))))
}
3842+
37933843
test("transform function - array for non-primitive type") {
37943844
val df = Seq(
37953845
Seq("c", "a", "b"),

0 commit comments

Comments
 (0)