Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2706,12 +2706,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
val newChild = rewrite(p.child)
val projectList = ArrayBuffer.empty[NamedExpression]
val newPList = p.projectList.map(rewriteSQLFunctions(_, projectList))
if (newPList != newChild.output) {
val newProj = if (newPList != newChild.output) {
p.copy(newPList, Project(newChild.output ++ projectList, newChild))
} else {
assert(projectList.isEmpty)
p.copy(child = newChild)
}
newProj.copyTagsFrom(p)
newProj

case f: Filter if f.resolved && hasSQLFunctionExpression(f.expressions) =>
val newChild = rewrite(f.child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ import org.apache.spark.sql.functions._
class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSession {
import testImplicits._

test("preserve plan ID in ResolveSQLFunctions with UDF") {
// Create a simple SQL function / UDF for testing purposes.
spark.sql("""CREATE OR REPLACE FUNCTION funct(x INT) RETURNS STRING RETURN
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's wrap the test with withUserDefinedFunction

CASE WHEN x >= 0 THEN 'foo' ELSE 'bar' END""")
val df = spark
.sql("SELECT * FROM VALUES (0, 1)")
.select(expr("funct(col1)").alias("col2"))
// Now use the UDF in a query that will initiate plan rewrite by ResolveSQLFunctions.
df.groupBy("col2").agg(count("*")).explain(true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use checkAnswer?

}

test("explode") {
val actual1 = spark.tvf.explode(array(lit(1), lit(2)))
val expected1 = spark.sql("SELECT * FROM explode(array(1, 2))")
Expand Down