diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 78f73f8778b86..43d29ab27e156 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -146,9 +146,13 @@ class EquivalentExpressions( // There are some special expressions that we should not recurse into all of its children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 2. ConditionalExpression: use its children that will always be evaluated. + // 3. HigherOrderFunction: lambda functions operate in the context of local lambdas and can't + // be called outside of that scope, only the arguments can be evaluated ahead of + // time. private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { case _: CodegenFallback => Nil case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) + case h: HigherOrderFunction => h.arguments case other => skipForShortcut(other).children } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2564d4eab9bd6..b87107e9a79f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -174,6 +174,41 @@ class CodegenContext extends Logging { */ var currentVars: Seq[ExprCode] = null + /** + * Holding a map of current lambda variables. 
+ */ + var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty + + def withLambdaVars( + namedLambdas: Seq[NamedLambdaVariable], + f: Seq[ExprCode] => ExprCode): ExprCode = { + val lambdaVars = namedLambdas.map { lambda => + val id = lambda.exprId.id + if (currentLambdaVars.get(id).nonEmpty) { + throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id) + } + val isNull = if (lambda.nullable) { + JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull")) + } else { + FalseLiteral + } + val value = addMutableState(javaType(lambda.dataType), "lambdaValue") + val lambdaVar = ExprCode(isNull, JavaCode.global(value, lambda.dataType)) + currentLambdaVars.put(id, lambdaVar) + lambdaVar + } + + val result = f(lambdaVars) + namedLambdas.map(_.exprId.id).foreach(currentLambdaVars.remove) + result + } + + def getLambdaVar(id: Long): ExprCode = { + currentLambdaVars.getOrElse( + id, + throw QueryExecutionErrors.lambdaVariableNotDefinedError(id)) + } + /** * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a * 2-tuple: java type, variable name. @@ -411,29 +446,11 @@ class CodegenContext extends Logging { partitionInitializationStatements.mkString("\n") } - /** - * Holds expressions that are equivalent. Used to perform subexpression elimination - * during codegen. - * - * For expressions that appear more than once, generate additional code to prevent - * recomputing the value. - * - * For example, consider two expression generated from this SQL statement: - * SELECT (col1 + col2), (col1 + col2) / col3. - * - * equivalentExpressions will match the tree containing `col1 + col2` and it will only - * be evaluated once. - */ - private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - // Foreach expression that is participating in subexpression elimination, the state to use. // Visible for testing. 
private[expressions] var subExprEliminationExprs = Map.empty[ExpressionEquals, SubExprEliminationState] - // The collection of sub-expression result resetting methods that need to be called on each row. - private val subexprFunctions = mutable.ArrayBuffer.empty[String] - val outerClassName = "OuterClass" /** @@ -1064,24 +1081,15 @@ class CodegenContext extends Logging { } } - /** - * Returns the code for subexpression elimination after splitting it if necessary. - */ - def subexprFunctionsCode: String = { - // Whole-stage codegen's subexpression elimination is handled in another code path - assert(currentVars == null || subexprFunctions.isEmpty) - splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) - } - /** * Perform a function which generates a sequence of ExprCodes with a given mapping between - * expressions and common expressions, instead of using the mapping in current context. + * expressions and common expressions. Restores previous mapping after execution. */ def withSubExprEliminationExprs( newSubExprEliminationExprs: Map[ExpressionEquals, SubExprEliminationState])( f: => Seq[ExprCode]): Seq[ExprCode] = { val oldsubExprEliminationExprs = subExprEliminationExprs - subExprEliminationExprs = newSubExprEliminationExprs + subExprEliminationExprs = oldsubExprEliminationExprs ++ newSubExprEliminationExprs val genCodes = f @@ -1090,25 +1098,26 @@ class CodegenContext extends Logging { genCodes } + private def collectSubExprCodes(subExprStates: Seq[SubExprEliminationState]): Seq[String] = { + subExprStates.flatMap { state => + val codes = collectSubExprCodes(state.children) :+ state.eval.code.toString() + state.eval.code = EmptyBlock + codes + } + } + /** * Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After * evaluating a subexpression, this method will clean up the code block to avoid duplicate * evaluation. 
*/ def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = { - val code = new StringBuilder() - - subExprStates.foreach { state => - val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code - code.append(currentCode + "\n") - state.eval.code = EmptyBlock - } - - code.toString() + val codes = collectSubExprCodes(subExprStates.toSeq) + splitExpressionsWithCurrentInputs(codes, "subexprFunc_split") } /** - * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen. + * Checks and sets up the state and codegen for subexpression elimination. * * This finds the common subexpressions, generates the code snippets that evaluate those * expressions and populates the mapping of common subexpressions to the generated code snippets. @@ -1141,10 +1150,12 @@ class CodegenContext extends Logging { * (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression * evaluation, we can look for generated subexpressions and do replacement. */ - def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { + def subexpressionElimination( + expressions: Seq[Expression], + variablePrefix: String = ""): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val localSubExprEliminationExprsForNonSplit = + val localSubExprEliminationExprs = mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. @@ -1152,23 +1163,51 @@ class CodegenContext extends Logging { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. + // + // Filter out any expressions that are already existing subexpressions. 
This can happen + // when finding common subexpressions inside a lambda function, and the common expression + // does not reference the lambda variables for that function, but top level attributes or + // outer lambda variables. val commonExprs = equivalentExpressions.getCommonSubexpressions + .filter(e => !subExprEliminationExprs.contains(ExpressionEquals(e))) val nonSplitCode = { val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] commonExprs.map { expr => - withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { + withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { val eval = expr.genCode(this) + + val value = addMutableState(javaType(expr.dataType), s"${variablePrefix}subExprValue") + + val isNullLiteral = eval.isNull match { + case TrueLiteral | FalseLiteral => true + case _ => false + } + val (isNull, isNullEvalCode) = if (!isNullLiteral) { + val v = addMutableState(JAVA_BOOLEAN, s"${variablePrefix}subExprIsNull") + (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") + } else { + (eval.isNull, "") + } + + val code = code""" + |${eval.code} + |$isNullEvalCode + |$value = ${eval.value}; + """ + // Collects other subexpressions from the children. 
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] expr.foreach { e => - subExprEliminationExprs.get(ExpressionEquals(e)) match { + localSubExprEliminationExprs.get(ExpressionEquals(e)) match { case Some(state) => childrenSubExprs += state case _ => } } - val state = SubExprEliminationState(eval, childrenSubExprs.toSeq) - localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state) + val state = SubExprEliminationState( + ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + childrenSubExprs.toSeq) + localSubExprEliminationExprs.put(ExpressionEquals(expr), state) allStates += state Seq(eval) } @@ -1188,38 +1227,18 @@ class CodegenContext extends Logging { val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold val (subExprsMap, exprCodes) = if (needSplit) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { - val localSubExprEliminationExprs = - mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] commonExprs.zipWithIndex.foreach { case (expr, i) => - val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { - Seq(expr.genCode(this)) - }.head - - val value = addMutableState(javaType(expr.dataType), "subExprValue") - - val isNullLiteral = eval.isNull match { - case TrueLiteral | FalseLiteral => true - case _ => false - } - val (isNull, isNullEvalCode) = if (!isNullLiteral) { - val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") - } else { - (eval.isNull, "") - } - // Generate the code for this expression tree and wrap it in a function. 
val fnName = freshName("subExpr") val inputVars = inputVarsForAllFuncs(i) val argList = inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") + val subExprState = localSubExprEliminationExprs.remove(ExpressionEquals(expr)).get val fn = s""" |private void $fnName(${argList.mkString(", ")}) { - | ${eval.code} - | $isNullEvalCode - | $value = ${eval.value}; + | ${subExprState.eval.code} |} """.stripMargin @@ -1235,7 +1254,7 @@ class CodegenContext extends Logging { val inputVariables = inputVars.map(_.variableName).mkString(", ") val code = code"${addNewFunction(fnName, fn)}($inputVariables);" val state = SubExprEliminationState( - ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + subExprState.eval.copy(code = code), childrenSubExprs.toSeq) localSubExprEliminationExprs.put(ExpressionEquals(expr), state) } @@ -1248,67 +1267,15 @@ class CodegenContext extends Logging { throw SparkException.internalError(errMsg) } else { logInfo(errMsg) - (localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprs, Seq.empty) } } } else { - (localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprs, Seq.empty) } SubExprCodes(subExprsMap.toMap, exprCodes.flatten) } - /** - * Checks and sets up the state and codegen for subexpression elimination. This finds the - * common subexpressions, generates the functions that evaluate those expressions and populates - * the mapping of common subexpressions to the generated functions. - */ - private def subexpressionElimination(expressions: Seq[Expression]): Unit = { - // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree(_)) - - // Get all the expressions that appear at least twice and set up the state for subexpression - // elimination. 
- val commonExprs = equivalentExpressions.getCommonSubexpressions - commonExprs.foreach { expr => - val fnName = freshName("subExpr") - val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - val value = addMutableState(javaType(expr.dataType), "subExprValue") - - // Generate the code for this expression tree and wrap it in a function. - val eval = expr.genCode(this) - val fn = - s""" - |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code} - | $isNull = ${eval.isNull}; - | $value = ${eval.value}; - |} - """.stripMargin - - // Add a state and a mapping of the common subexpressions that are associate with this - // state. Adding this expression to subExprEliminationExprMap means it will call `fn` - // when it is code generated. This decision should be a cost based one. - // - // The cost of doing subexpression elimination is: - // 1. Extra function call, although this is probably *good* as the JIT can decide to - // inline or not. - // The benefit doing subexpression elimination is: - // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 - // above. - // 2. Less code. - // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with - // at least two nodes) as the cost of doing it is expected to be low. - - val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - subexprFunctions += subExprCode - val state = SubExprEliminationState( - ExprCode(code"$subExprCode", - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType))) - subExprEliminationExprs += ExpressionEquals(expr) -> state - } - } - /** * Generates code for expressions. If doSubexpressionElimination is true, subexpression * elimination will be performed. 
Subexpression elimination assumes that the code for each @@ -1316,12 +1283,20 @@ class CodegenContext extends Logging { */ def generateExpressions( expressions: Seq[Expression], - doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { + doSubexpressionElimination: Boolean = false): (Seq[ExprCode], String) = { // We need to make sure that we do not reuse stateful expressions. This is needed for codegen // as well because some expressions may implement `CodegenFallback`. val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression()) - if (doSubexpressionElimination) subexpressionElimination(cleanedExpressions) - cleanedExpressions.map(e => e.genCode(this)) + if (doSubexpressionElimination) { + val subExprs = subexpressionElimination(cleanedExpressions) + val generatedExprs = withSubExprEliminationExprs(subExprs.states) { + cleanedExpressions.map(e => e.genCode(this)) + } + val subExprCode = evaluateSubExprEliminationState(subExprs.states.values) + (generatedExprs, subExprCode) + } else { + (cleanedExpressions.map(e => e.genCode(this)), "") + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 2e018de07101e..6db00654ad1f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -61,7 +61,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case (NoOp, _) => false case _ => true } - val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) + val (exprVals, evalSubexpr) = + ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value 
variable name, column index) val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { @@ -91,9 +92,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP (code, update) } - // Evaluate all the subexpressions. - val evalSubexpr = ctx.subexprFunctionsCode - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c246d07f189b4..2383ffc0839eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,8 +38,8 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { val ctx = newCodeGenContext() // Do sub-expression elimination for predicates. 
- val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head - val evalSubexpr = ctx.subexprFunctionsCode + val (evalExprs, evalSubexpr) = ctx.generateExpressions(Seq(predicate), useSubexprElimination) + val eval = evalExprs.head val codeBody = s""" public SpecificPredicate generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 459c1d9a8ba11..d180215783db3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -287,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { - val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) + val (exprEvals, evalSubexpr) = ctx.generateExpressions(expressions, useSubexprElimination) val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) val numVarLenFields = exprSchemas.count { @@ -299,9 +299,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") - // Evaluate all the subexpression. 
- val evalSubexpr = ctx.subexprFunctionsCode - val writeExpressions = writeExpressionsToBuffer( ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2a5a38e93706c..beb2a3ac490a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -81,8 +82,7 @@ case class NamedLambdaVariable( exprId: ExprId = NamedExpression.newExprId, value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression - with NamedExpression - with CodegenFallback { + with NamedExpression { override def qualifier: Seq[String] = Seq.empty @@ -103,6 +103,10 @@ case class NamedLambdaVariable( override def simpleString(maxFields: Int): String = { s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.getLambdaVar(exprId.id) + } } /** @@ -114,7 +118,7 @@ case class LambdaFunction( function: Expression, arguments: Seq[NamedExpression], hidden: Boolean = false) - extends Expression with CodegenFallback { + extends Expression 
{ override def children: Seq[Expression] = function +: arguments override def dataType: DataType = function.dataType @@ -132,6 +136,21 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val subExprCodes = ctx.subexpressionElimination(Seq(function), "lambda_") + + val functionCode = ctx.withSubExprEliminationExprs(subExprCodes.states) { + Seq(function.genCode(ctx)) + }.head + + val subExprEval = ctx.evaluateSubExprEliminationState(subExprCodes.states.values) + functionCode.copy(code = code""" + |// lambda common sub-expressions + |$subExprEval + |${functionCode.code} + """) + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): LambdaFunction = copy( @@ -239,6 +258,63 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { val canonicalizedChildren = cleaned.children.map(_.canonicalized) withNewChildren(canonicalizedChildren) } + + + protected def assignAtomic( + atomicRef: String, + value: String, + isNull: String = FalseLiteral, + nullable: Boolean = false) = { + if (nullable) { + s""" + if ($isNull) { + $atomicRef.set(null); + } else { + $atomicRef.set($value); + } + """ + } else { + s"$atomicRef.set($value);" + } + } + + protected def assignArrayElement( + ctx: CodegenContext, + arrayName: String, + elementCode: ExprCode, + elementVar: NamedLambdaVariable, + index: String): String = { + val elementType = elementVar.dataType + val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) + val extractElement = CodeGenerator.getValue(arrayName, elementType, index) + val atomicAssign = assignAtomic(elementAtomic, elementCode.value, + elementCode.isNull, elementVar.nullable) + + if (elementVar.nullable) { + s""" + ${elementCode.value} = $extractElement; + ${elementCode.isNull} = $arrayName.isNullAt($index); + $atomicAssign + """ + } else { + s""" + ${elementCode.value} = 
$extractElement; + $atomicAssign + """ + } + } + + protected def assignIndex( + ctx: CodegenContext, + indexCode: ExprCode, + indexVar: NamedLambdaVariable, + index: String): String = { + val indexAtomic = ctx.addReferenceObj(indexVar.name, indexVar.value) + s""" + ${indexCode.value} = $index; + ${assignAtomic(indexAtomic, indexCode.value)} + """ + } } /** @@ -284,6 +360,29 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr } } + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: String => String): ExprCode = { + val argumentGen = argument.genCode(ctx) + val resultCode = f(argumentGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode) + ev.copy(code = code""" + |${argumentGen.code} + |boolean ${ev.isNull} = ${argumentGen.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """) + } else { + ev.copy(code = code""" + |${argumentGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """, isNull = FalseLiteral) + } + } } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { @@ -312,7 +411,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { case class ArrayTransform( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) @@ -354,6 +453,49 @@ case class ArrayTransform( result } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, varCodes => { + val elementCode = varCodes.head + val indexCode = varCodes.tail.headOption + + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val 
arrayData = ctx.freshName("arrayData") + val i = ctx.freshName("i") + + val initialization = CodeGenerator.createArrayData( + arrayData, dataType.elementType, numElements, s" $prettyName failed.") + + val functionCode = function.genCode(ctx) + + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) + val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") + + // Some expressions return internal buffers that we have to copy + val copy = if (CodeGenerator.isPrimitiveType(function.dataType)) { + s"${functionCode.value}" + } else { + s"InternalRow.copyValue(${functionCode.value})" + } + val resultNull = if (function.nullable) Some(functionCode.isNull.toString) else None + val resultAssignment = CodeGenerator.setArrayElement(arrayData, dataType.elementType, + i, copy, isNull = resultNull) + + s""" + |final int $numElements = $arg.numElements(); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | $varAssignments + | ${functionCode.code} + | $resultAssignment + |} + |${ev.value} = $arrayData; + """.stripMargin + }) + }) + } + override def nodeName: String = "transform" override protected def withNewChildrenInternal( @@ -581,7 +723,7 @@ case class MapFilter( case class ArrayFilter( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: DataType = argument.dataType @@ -622,6 +764,72 @@ case class ArrayFilter( new GenericArrayData(buffer) } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, varCodes => { + val elementCode = varCodes.head + val indexCode = varCodes.tail.headOption + + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val count = ctx.freshName("count") + val arrayTracker = 
ctx.freshName("arrayTracker") + val arrayData = ctx.freshName("arrayData") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + + val arrayType = dataType.asInstanceOf[ArrayType] + + val trackerInit = CodeGenerator.createArrayData( + arrayTracker, BooleanType, numElements, s" $prettyName failed.") + val resultInit = CodeGenerator.createArrayData( + arrayData, arrayType.elementType, count, s" $prettyName failed.") + + val functionCode = function.genCode(ctx) + + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) + val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") + + val resultAssignment = CodeGenerator.setArrayElement(arrayTracker, BooleanType, + i, functionCode.value, isNull = None) + + val getTrackerValue = CodeGenerator.getValue(arrayTracker, BooleanType, i) + val copy = CodeGenerator.createArrayAssignment(arrayData, arrayType.elementType, arg, + j, i, arrayType.containsNull) + + // This takes two passes to avoid evaluating the predicate multiple times. + // The first pass evaluates each element in the array, tracks how many elements + // returned true, and tracks the result of each element in a boolean array `arrayTracker`. + // The second pass copies elements from the original array to the new array created + // based on the number of elements matching the first pass. 
+ + s""" + |final int $numElements = $arg.numElements(); + |$trackerInit + |int $count = 0; + |for (int $i = 0; $i < $numElements; $i++) { + | $varAssignments + | ${functionCode.code} + | $resultAssignment + | if ((boolean)${functionCode.value}) { + | $count++; + | } + |} + | + |$resultInit + |int $j = 0; + |for (int $i = 0; $i < $numElements; $i++) { + | if ($getTrackerValue) { + | $copy + | $j++; + | } + |} + |${ev.value} = $arrayData; + """.stripMargin + }) + }) + } + override def nodeName: String = "filter" override protected def withNewChildrenInternal( @@ -653,7 +861,7 @@ case class ArrayExists( argument: Expression, function: Expression, followThreeValuedLogic: Boolean) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback with Predicate { + extends ArrayBasedSimpleHigherOrderFunction with Predicate { def this(argument: Expression, function: Expression) = { this( @@ -706,6 +914,50 @@ case class ArrayExists( } } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar), { case Seq(elementCode) => + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val exists = ctx.freshName("exists") + val foundNull = ctx.freshName("foundNull") + val i = ctx.freshName("i") + + val functionCode = function.genCode(ctx) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val threeWayLogic = if (followThreeValuedLogic) TrueLiteral else FalseLiteral + + val nullCheck = if (nullable) { + s""" + if ($threeWayLogic && !$exists && $foundNull) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + + s""" + |final int $numElements = ${arg}.numElements(); + |boolean $exists = false; + |boolean $foundNull = false; + |int $i = 0; + |while ($i < $numElements && !$exists) { + | $elementAssignment + | ${functionCode.code} + | if (${functionCode.isNull}) { + | $foundNull = true; + | } else if (${functionCode.value}) { + | $exists = true; + | } + | 
$i++; + |} + |$nullCheck + |${ev.value} = $exists; + """.stripMargin + }) + }) + } + override def nodeName: String = "exists" override protected def withNewChildrenInternal( @@ -740,7 +992,7 @@ object ArrayExists { case class ArrayForAll( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback with Predicate { + extends ArrayBasedSimpleHigherOrderFunction with Predicate { override def nullable: Boolean = super.nullable || function.nullable @@ -785,6 +1037,49 @@ case class ArrayForAll( } } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar), { case Seq(elementCode) => + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val forall = ctx.freshName("forall") + val foundNull = ctx.freshName("foundNull") + val i = ctx.freshName("i") + + val functionCode = function.genCode(ctx) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + + val nullCheck = if (nullable) { + s""" + if ($forall && $foundNull) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + + s""" + |final int $numElements = ${arg}.numElements(); + |boolean $forall = true; + |boolean $foundNull = false; + |int $i = 0; + |while ($i < $numElements && $forall) { + | $elementAssignment + | ${functionCode.code} + | if (${functionCode.isNull}) { + | $foundNull = true; + | } else if (!${functionCode.value}) { + | $forall = false; + | } + | $i++; + |} + |$nullCheck + |${ev.value} = $forall; + """.stripMargin + }) + }) + } + override def nodeName: String = "forall" override protected def withNewChildrenInternal( @@ -816,7 +1111,7 @@ case class ArrayAggregate( zero: Expression, merge: Expression, finish: Expression) - extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] { + extends HigherOrderFunction with QuaternaryLike[Expression] { def this(argument: Expression, zero: Expression, merge: Expression) = 
{ this(argument, zero, merge, LambdaFunction.identity) @@ -886,6 +1181,114 @@ case class ArrayAggregate( } } + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: String => String): ExprCode = { + val argumentGen = argument.genCode(ctx) + val resultCode = f(argumentGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode) + ev.copy(code = code""" + |${argumentGen.code} + |boolean ${ev.isNull} = ${argumentGen.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """) + } else { + ev.copy(code = code""" + |${argumentGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """, isNull = FalseLiteral) + } + } + + protected def assignVar( + varCode: ExprCode, + atomicVar: String, + value: String, + isNull: String, + nullable: Boolean): String = { + val atomicAssign = assignAtomic(atomicVar, value, isNull, nullable) + if (nullable) { + s""" + ${varCode.value} = $value; + ${varCode.isNull} = $isNull; + $atomicAssign + """ + } else { + s""" + ${varCode.value} = $value; + $atomicAssign + """ + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), varCodes => { + val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes + + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val i = ctx.freshName("i") + + val zeroCode = zero.genCode(ctx) + val mergeCode = merge.genCode(ctx) + val finishCode = finish.genCode(ctx) + + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val mergeAtomic = ctx.addReferenceObj(accForMergeVar.name, + accForMergeVar.value) + val finishAtomic = ctx.addReferenceObj(accForFinishVar.name, + accForFinishVar.value) + + val mergeJavaType = 
CodeGenerator.javaType(accForMergeVar.dataType) + val finishJavaType = CodeGenerator.javaType(accForFinishVar.dataType) + + // Some expressions return internal buffers that we have to copy + val mergeCopy = if (CodeGenerator.isPrimitiveType(merge.dataType)) { + s"${mergeCode.value}" + } else { + s"($mergeJavaType)InternalRow.copyValue(${mergeCode.value})" + } + + val nullCheck = if (nullable) { + s"${ev.isNull} = ${finishCode.isNull};" + } else { + "" + } + + val initialAssignment = assignVar(accForMergeCode, mergeAtomic, zeroCode.value, + zeroCode.isNull, zero.nullable) + + val mergeAssignment = assignVar(accForMergeCode, mergeAtomic, mergeCopy, + mergeCode.isNull, merge.nullable) + + val finishAssignment = assignVar(accForFinishCode, finishAtomic, accForMergeCode.value, + accForMergeCode.isNull, merge.nullable) + + s""" + |final int $numElements = ${arg}.numElements(); + |${zeroCode.code} + |$initialAssignment + | + |for (int $i = 0; $i < $numElements; $i++) { + | $elementAssignment + | ${mergeCode.code} + | $mergeAssignment + |} + | + |$finishAssignment + |${finishCode.code} + |${ev.value} = ${finishCode.value}; + |$nullCheck + """.stripMargin + }) + }) + } + override def nodeName: String = "aggregate" override def first: Expression = argument diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 68d4fe6900073..87baa3b590335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -439,6 +439,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE s"failed to match ${toSQLId(funcName)} at `addNewFunction`.") } + def lambdaVariableAlreadyDefinedError(id: Long): Throwable = { + new IllegalArgumentException(s"Lambda variable $id cannot be redefined") + } + + def 
lambdaVariableNotDefinedError(id: Long): Throwable = { + new IllegalArgumentException( + s"Lambda variable $id is not defined in the current codegen scope") + } + def cannotGenerateCodeForIncomparableTypeError( codeType: String, dataType: DataType): Throwable = { SparkException.internalError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7ce14bcedf4ba..2b61b85ad8152 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -473,42 +473,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { JavaCode.variable("dummy", BooleanType))) // raw testing of basic functionality - { - val ctx = new CodegenContext - val e = ref.genCode(ctx) - // before - ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState( - ExprCode(EmptyBlock, e.isNull, e.value)) - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) - Seq.empty - } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) - } - - // emulate an actual codegen workload - { - val ctx = new CodegenContext - // before - ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(wrap(ref) -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - 
assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) - Seq.empty - } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState( + ExprCode(EmptyBlock, e.isNull, e.value)) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + // call withSubExprEliminationExprs, should now contain both + ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) { assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + Seq.empty } + // after, should only contain the original + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) } test("SPARK-23986: freshName can generate duplicated names") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index cc36cd73d6d77..3f3782733eda2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -149,15 
+151,21 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val plusOne: Expression => Expression = x => x + 1 val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + val plusIndexRepeated: (Expression, Expression) => Expression = + (x, i) => plusIndex(x, i) * plusIndex(x, i) + val plusOneFallback: Expression => Expression = x => CodegenFallbackExpr(x + 1) checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) + checkEvaluation(transform(ai0, plusIndexRepeated), Seq(1, 9, 25)) checkEvaluation(transform(transform(ai0, plusIndex), plusOne), Seq(2, 4, 6)) checkEvaluation(transform(ai1, plusOne), Seq(2, null, 4)) checkEvaluation(transform(ai1, plusIndex), Seq(1, null, 5)) checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6)) checkEvaluation(transform(ain, plusOne), null) + checkEvaluation(transform(ai0, plusOneFallback), Seq(2, 3, 4)) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) @@ -277,15 +285,21 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isEven: Expression => Expression = x => x % 2 === 0 val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 } + val plusIndexRepeatedEven: (Expression, Expression) => Expression = + (x, i) => ((x + i) * (x + i)) % 2 === 0 + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(filter(ai0, isEven), Seq(2)) checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) checkEvaluation(filter(ai0, indexIsEven), Seq(1, 3)) + checkEvaluation(filter(ai0, plusIndexRepeatedEven), Seq.empty) checkEvaluation(filter(ai1, 
isEven), Seq.empty) checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3)) checkEvaluation(filter(ain, isEven), null) checkEvaluation(filter(ain, isNullOrOdd), null) + checkEvaluation(filter(ai0, isEvenFallback), Seq(2)) + val as0 = Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) @@ -321,6 +335,9 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val squareRepeatedEven: Expression => Expression = + x => ((x * x) + (x * x)) % 2 === 0 + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) for (followThreeValuedLogic <- Seq(false, true)) { withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key @@ -329,6 +346,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(exists(ai0, isNullOrOdd), true) checkEvaluation(exists(ai0, alwaysFalse), false) checkEvaluation(exists(ai0, alwaysNull), if (followThreeValuedLogic) null else false) + checkEvaluation(exists(ai0, squareRepeatedEven), true) checkEvaluation(exists(ai1, isEven), if (followThreeValuedLogic) null else false) checkEvaluation(exists(ai1, isNullOrOdd), true) checkEvaluation(exists(ai1, alwaysFalse), false) @@ -337,6 +355,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(exists(ain, isNullOrOdd), null) checkEvaluation(exists(ain, alwaysFalse), null) checkEvaluation(exists(ain, alwaysNull), null) + checkEvaluation(exists(ai0, isEvenFallback), true) } } @@ -383,11 +402,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x 
=> x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val squareRepeatedEven: Expression => Expression = + x => ((x * x) + (x * x)) % 2 === 0 + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(forall(ai0, isEven), true) checkEvaluation(forall(ai0, isNullOrOdd), false) checkEvaluation(forall(ai0, alwaysFalse), false) checkEvaluation(forall(ai0, alwaysNull), null) + checkEvaluation(forall(ai0, squareRepeatedEven), true) checkEvaluation(forall(ai1, isEven), false) checkEvaluation(forall(ai1, isNullOrOdd), true) checkEvaluation(forall(ai1, alwaysFalse), false) @@ -401,6 +424,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(forall(ain, alwaysFalse), null) checkEvaluation(forall(ain, alwaysNull), null) + checkEvaluation(forall(ai0, isEvenFallback), true) + val as0 = Literal.create(Seq("a0", "a1", "a2", "a3"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true)) @@ -428,6 +453,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40) checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0) checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null) + checkEvaluation(aggregate( + ai0, + 1, + (acc, elem) => (acc * elem) + (acc * elem), + acc => (acc * acc) + (acc * acc) + ), 4608) val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) @@ -886,3 +917,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ))) } } + +case class 
CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback { + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + override lazy val resolved = child.resolved + override def eval(input: InternalRow): Any = child.eval(input) + override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr = + copy(child = newChild) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index e9faeba2411ce..dfbbaf59c0752 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -278,7 +278,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel ExprCode(TrueLiteral, oneVar), ExprCode(TrueLiteral, twoVar)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } @@ -408,7 +408,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val exprs = Seq(add1, add1, add2, add2) val ctx = new CodegenContext() - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) val add2State = subExprs.states(ExpressionEquals(add2)) val add1State = subExprs.states(ExpressionEquals(add1)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala index 40112979c6d46..fe01eed633633 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala @@ -210,7 +210,7 @@ trait AggregateCodegenSupport val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 24528b6f4da15..5904e0c9cacef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -629,7 +629,7 @@ case class HashAggregateExec( // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, bindReferences[Expression](groupingExpressions, child.output)) - val fastRowKeys = ctx.generateExpressions( + val (fastRowKeys, _) = ctx.generateExpressions( bindReferences[Expression](groupingExpressions, child.output)) val unsafeRowKeys = unsafeRowKeyCode.value val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash") @@ -732,7 +732,7 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val effectiveCodes = 
ctx.evaluateSubExprEliminationState(subExprs.states.values) val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { @@ -778,7 +778,7 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 70ade390c7336..cc2cfff3c73e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -69,7 +69,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) val exprs = bindReferences[Expression](projectList, child.output) val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) { // subexpression elimination - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fc6d3023ed072..13b524646ea77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3790,6 +3790,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { testArrayOfPrimitiveTypeContainsNull() } + test("transform function - subexpression elimination") { + val df = Seq[Seq[Integer]]( + Seq(1, 2, 3, 4, 5) + ).toDF("i") + + var count = spark.sparkContext.longAccumulator + val func = udf((x: Integer) => { + count.add(1) + x + }) + + val result = df.select( + transform(col("i"), x => func(x) + func(x)) + ) + + // Run it once to verify the count of UDF calls + result.collect() + assert(count.value == 5) + + checkAnswer(result, Seq(Row(Seq(2, 4, 6, 8, 10)))) + } + + test("transform function - subexpression elimination inside and outside lambda") { + val df = spark.read.json(Seq( + """ + { + "outer": { + "inner": { + "a": 1, + "b": 2, + "c": 3 + } + }, + "arr": [ + 1, + 2, + 3 + ] + } + """).toDS()) + + val result = df.select( + col("outer.inner.b"), + col("outer.inner.c"), + transform(col("arr"), x => x + col("outer.inner.a") + col("outer.inner.a")) + ) + + checkAnswer(result, Seq(Row(2, 3, Seq(3, 4, 5)))) + } + test("transform function - array for non-primitive type") { val df = Seq( Seq("c", "a", "b"),