Skip to content

[SPARK-52082][PYTHON] Improve ExtractPythonUDF docs #50867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,64 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
e.exists(PythonUDF.isScalarPythonUDF)
}

/**
* Return true if we should extract the current expression, including all of its current
* children (including UDF expression, and all others), to a logical node.
* The children of the expression can be udf expressions, this would be `chaining`.
* If child udf expressions were already extracted before, then this will just extract
* the current udf expression, so they will end up in separate logical nodes. The child
* expressions will have been transformed to Attribute expressions referencing the child plan
* node's output.
*
* Return false if there is no single continuous chain of UDFs that can be extracted:
* - if there are other expression in-between, e.g., foo(1 + bar(baz())), return false. The
* caller will have to extract bar(baz()) separately first.
* - if the eval types of the udf expressions in the chain differ, return false.
* - if a UDF has more than one child, e.g. foo(bar(), baz()), return false
* If we return false here, the expectation is that the recursive calls of
* collectEvaluableUDFsFromExpressions will then visit the children and extract them first to
* separate nodes.
*/
@scala.annotation.tailrec
private def canEvaluateInPython(e: PythonUDF): Boolean = {
private def shouldExtractUDFExpressionTree(e: PythonUDF): Boolean = {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
case Seq(u: PythonUDF) => correctEvalType(e) == correctEvalType(u) && canEvaluateInPython(u)
case Seq(child: PythonUDF) => correctEvalType(e) == correctEvalType(child) &&
shouldExtractUDFExpressionTree(child)
// Python UDF can't be evaluated directly in JVM
case children => !children.exists(hasScalarPythonUDF)
}
}

/**
* We use the following terminology:
* - fusing is the act of combining multiple UDFs into a single logical node. This can be
* accomplished in different cases:
* - if the UDFs are siblings, e.g., foo(x), bar(x) - we refer to this as parallel fusing,
* where multiple independent UDFs are evaluated together over the same input.
* - if the UDFs are nested, e.g., foo(bar(...)) - we refer to this as chained fusing
* or chaining, where the output of one UDF feeds into the next in a sequential pipeline.
*
* collectEvaluableUDFsFromExpressions returns a list of UDF expressions that can be planned
* together into one plan node. collectEvaluableUDFsFromExpressions will be called multiple times
* by recursive calls of extract(plan), until no more evaluable UDFs are found.
*
* As an example, consider the following expression tree:
* udf1(udf2(udf3(x)), udf4(x))), where all UDFs are PythonUDFs of the same evaltype.
* We can only fuse UDFs of the same eval type, and never UDFs of SQL_SCALAR_PANDAS_ITER_UDF.
* The following udf expressions will be returned:
* - First, we will return Seq(udf3, udf4), as these two UDFs must be evaluated first.
* We return both in one Seq, as it is possible to do parallel fusing for udf3 an udf4.
* - As we can only chain UDFs with exactly one child, we will not fuse udf2 with its children.
* But we can chain udf1 and udf2, so a later call to collectEvaluableUDFsFromExpressions will
* return Seq(udf1, udf2).
*/
private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
// If first UDF is SQL_SCALAR_PANDAS_ITER_UDF, then only return this UDF,
// otherwise check if subsequent UDFs are of the same type as the first UDF. (since we can only
// extract UDFs of the same eval type)
// otherwise check if subsequent UDFs are of the same type as the first UDF.

var firstVisitedScalarUDFEvalType: Option[Int] = None

def canChainUDF(evalType: Int): Boolean = {
def canFuseWithParallelUDFs(evalType: Int): Boolean = {
if (evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF) {
false
} else {
Expand All @@ -195,12 +235,14 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
}

def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf)
&& shouldExtractUDFExpressionTree(udf)
&& firstVisitedScalarUDFEvalType.isEmpty =>
firstVisitedScalarUDFEvalType = Some(correctEvalType(udf))
Seq(udf)
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
&& canChainUDF(correctEvalType(udf)) =>
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf)
&& shouldExtractUDFExpressionTree(udf)
&& canFuseWithParallelUDFs(correctEvalType(udf)) =>
Seq(udf)
case e => e.children.flatMap(collectEvaluableUDFs)
}
Expand Down