diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
index 1d022489b701b..5420b72a9737e 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
@@ -26,6 +26,7 @@ import scala.util.{Failure, Success}
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
 import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}
 import org.apache.spark.util.SparkThreadUtils.awaitResult
@@ -450,4 +451,11 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
       Map("one" -> "1", "two" -> "2"))
     assert(df.as(StringEncoder).collect().toSet == Set("one", "two"))
   }
+
+  test("Non-existent columns throw exception") {
+    val e = intercept[AnalysisException] {
+      spark.range(10).col("nonexistent")
+    }
+    assert(e.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION"))
+  }
 }
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
index ec169ba114a3d..871e6076d9c78 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect
 
 import java.util
 
+import scala.annotation.tailrec
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.OrderUtils
+import org.apache.spark.sql.catalyst.util.AttributeNameParser
 import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr}
 import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.client.SparkResult
@@ -45,7 +47,7 @@ import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions.{struct, to_json}
 import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
-import org.apache.spark.sql.types.{Metadata, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.SparkClassUtils
@@ -451,7 +453,52 @@ class Dataset[T] private[sql] (
   }
 
   /** @inheritdoc */
-  def col(colName: String): Column = new Column(colName, getPlanId)
+  def col(colName: String): Column = {
+    // If the name cannot be resolved locally, validate it with a server-side analysis.
+    if (!verifyColName(colName, schema)) {
+      this.select(colName).isLocal
+    }
+    new Column(colName, getPlanId)
+  }
+
+  /**
+   * Verify whether the input column name can be resolved against the given schema. Note that
+   * this method cannot fully match the analyzer's behavior; it is a best-effort check meant to
+   * eliminate unnecessary validation RPCs.
+   */
+  private def verifyColName(name: String, schema: StructType): Boolean = {
+    val partsOpt = AttributeNameParser.parseAttributeName(name)
+    if (partsOpt == null || partsOpt.isEmpty) return false
+
+    @tailrec
+    def quickVerify(parts: Seq[String], schema: DataType): Boolean = {
+      if (parts.isEmpty) return true
+
+      val part = parts.head
+      val rest = parts.tail
+
+      if (part == "*") {
+        true
+      } else {
+        val structSchema = schema match {
+          case s: StructType => Some(s)
+          case a: ArrayType if a.elementType.isInstanceOf[StructType] =>
+            Some(a.elementType.asInstanceOf[StructType])
+          case _ => None
+        }
+
+        structSchema match {
+          case Some(s) =>
+            s.fields.find(_.name == part) match {
+              case Some(field) => quickVerify(rest, field.dataType)
+              case None => false
+            }
+          case None => false
+        }
+      }
+    }
+    quickVerify(partsOpt, schema)
+  }
 
   /** @inheritdoc */
   def metadataColumn(colName: String): Column = {
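
The intended client-side behavior can be exercised as follows. This is a minimal sketch, not part of the patch: it assumes a running Spark Connect session bound to `spark` and a ScalaTest context providing `intercept`, mirroring the `SparkSessionE2ESuite` test above.

```scala
import org.apache.spark.sql.AnalysisException

val df = spark.range(10)

// "id" is present in the schema, so the local verifyColName check succeeds
// and no extra validation RPC should be needed.
val ok = df.col("id")

// "nonexistent" fails the local check, so col() falls back to a server-side
// analysis, surfacing the error eagerly at col() time instead of at action time.
val e = intercept[AnalysisException] {
  df.col("nonexistent")
}
assert(e.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION"))
```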