
Commit d88298a

bersprockets authored and MaxGekk committed
[SPARK-52738][SQL] Support aggregating the TIME type with a UDAF when the underlying buffer is an UnsafeRow
### What changes were proposed in this pull request?

- Change `BufferSetterGetterUtils` to use `InternalRow.setLong` for setting TIME values rather than `InternalRow.update`.
- Change `BufferSetterGetterUtils` to use `InternalRow.getLong` for getting TIME values.
- Update the test "udaf with all data types" in `AggregationQuerySuite` so that it checks aggregation with both an unsafe and a safe aggregation buffer. Since SPARK-41359, that test has exercised only a safe aggregation buffer.

### Why are the changes needed?

When a query uses a UDAF to aggregate a TIME column, and all other columns are "mutable" (as determined by `UnsafeRow#isMutable`), the aggregator creates an `UnsafeRow` for the low-level aggregation buffer. However, the wrapper of that buffer (`MutableAggregationBufferImpl`) fails to set up a proper field setter function for the TIME column, so it attempts to call `UnsafeRow.update` on the underlying buffer. The `UnsafeRow` instance throws `org.apache.spark.SparkUnsupportedOperationException`:

```
Exception in task 0.0 in stage 0.0 (TID 0)
org.apache.spark.SparkUnsupportedOperationException: [UNSUPPORTED_CALL.WITHOUT_SUGGESTION] Cannot call the method "update" of the class "org.apache.spark.sql.catalyst.expressions.UnsafeRow". SQLSTATE: 0A000
```

See SPARK-52738 for a reproduction example.

### Does this PR introduce _any_ user-facing change?

No. The TIME type has not been released yet.

### How was this patch tested?

Updated a unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51430 from bersprockets/time_udaf.

Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
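To make the failure mode concrete, below is a minimal sketch of a UDAF over a TIME column that would exercise the UnsafeRow-backed buffer path described above. It is illustrative only, not code from this PR: the `LatestTime` class is hypothetical, and it assumes `TimeType()` has a default (microsecond-precision) constructor and that TIME's external Java type is `java.time.LocalTime`.

```scala
import java.time.LocalTime

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Hypothetical UDAF whose buffer holds a single TIME field. Because every
// buffer field is "mutable" per UnsafeRow.isMutable, the planner backs the
// aggregation buffer with an UnsafeRow -- the code path this commit fixes.
class LatestTime extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("t", TimeType())
  override def bufferSchema: StructType = new StructType().add("latest", TimeType())
  override def dataType: DataType = TimeType()
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, LocalTime.MIN)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val t = input.getAs[LocalTime](0)
      // MutableAggregationBufferImpl routes this call through the field
      // setter that this PR fixes for TIME columns.
      if (t.isAfter(buffer.getAs[LocalTime](0))) buffer.update(0, t)
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val t = buffer2.getAs[LocalTime](0)
    if (t.isAfter(buffer1.getAs[LocalTime](0))) buffer1.update(0, t)
  }

  override def evaluate(buffer: Row): Any = buffer.getAs[LocalTime](0)
}
```

Registered via the (deprecated but still available) `spark.udf.register("latest_time", new LatestTime)` and applied to a TIME column, a UDAF like this previously hit the `SparkUnsupportedOperationException` shown above.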

File tree: 2 files changed, +12 −8 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -84,7 +84,7 @@ sealed trait BufferSetterGetterUtils {
         (row: InternalRow, ordinal: Int) =>
           if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
 
-      case TimestampType | TimestampNTZType =>
+      case TimestampType | TimestampNTZType | _: TimeType =>
         (row: InternalRow, ordinal: Int) =>
           if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
 
@@ -188,7 +188,7 @@ sealed trait BufferSetterGetterUtils {
           row.setNullAt(ordinal)
         }
 
-      case TimestampType | TimestampNTZType =>
+      case TimestampType | TimestampNTZType | _: TimeType =>
         (row: InternalRow, ordinal: Int, value: Any) =>
           if (value != null) {
             row.setLong(ordinal, value.asInstanceOf[Long])
```
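The change above works because `UnsafeRow` supports in-place primitive setters such as `setLong`, while its generic `update` method throws unconditionally, and TIME values are stored internally as a `Long`. A condensed, hypothetical sketch of the setter dispatch follows; the `setterFor` helper is invented for illustration, and the real match in `BufferSetterGetterUtils` covers many more types.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

// Condensed view (not the actual Spark source) of the setter dispatch.
// TIME values are a Long internally, so they must go through setLong;
// falling into the generic arm makes an UnsafeRow-backed buffer throw
// SparkUnsupportedOperationException at runtime.
def setterFor(dataType: DataType): (InternalRow, Int, Any) => Unit = dataType match {
  case TimestampType | TimestampNTZType | _: TimeType =>
    (row, ordinal, value) =>
      if (value != null) row.setLong(ordinal, value.asInstanceOf[Long])
      else row.setNullAt(ordinal)
  case _ =>
    // Generic arm: fine for safe rows, but UnsafeRow.update always throws.
    (row, ordinal, value) =>
      if (value != null) row.update(ordinal, value) else row.setNullAt(ordinal)
}
```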

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 10 additions & 6 deletions
```diff
@@ -23,7 +23,7 @@ import test.org.apache.spark.sql.MyDoubleAvg
 import test.org.apache.spark.sql.MyDoubleSum
 
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, RandomDataGenerator, Row}
-import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, UnsafeRow}
 import org.apache.spark.sql.classic.ClassicConversions.castToImpl
 import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
@@ -899,11 +899,15 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       ArrayType(IntegerType), MapType(StringType, LongType), struct,
       new TestUDT.MyDenseVectorUDT()) ++ dayTimeIntervalTypes ++ unsafeRowMutableFieldTypes ++
       timeTypes
-    // Right now, we will use SortAggregate to handle UDAFs.
-    // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortAggregate to use
-    // UnsafeRow as the aggregation buffer. While, dataTypes will trigger
-    // SortAggregate to use a safe row as the aggregation buffer.
-    Seq(dataTypes).foreach { dataTypes =>
+    // A schema that contains only data types where UnsafeRow.isMutable is true
+    // will trigger the aggregator to use unsafe row as the aggregation buffer.
+    // Other dataTypes will trigger the aggregator to use a safe row as the
+    // aggregation buffer.
+    //
+    // Below we want to test with *both* UnsafeRow and safe row as the underlying
+    // buffer.
+    val mutableDataTypes = dataTypes.filter(UnsafeRow.isMutable)
+    Seq(dataTypes, mutableDataTypes).foreach { dataTypes =>
       val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
         StructField(s"col$index", dataType, nullable = true)
       }
```
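The test change boils down to deriving a second, all-mutable type list and running the same assertions over both buffer layouts. A standalone sketch of that strategy, using a short placeholder type list rather than the suite's full one (and assuming a `TimeType()` default constructor):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._

// A mixed list: StringType is not mutable in an UnsafeRow, so a schema
// containing it forces a safe-row aggregation buffer.
val allTypes: Seq[DataType] =
  Seq(IntegerType, LongType, DoubleType, StringType, TimeType())

// Keeping only mutable types yields a schema the aggregator can back
// with an UnsafeRow, exercising the code path fixed in udaf.scala.
val mutableTypes = allTypes.filter(UnsafeRow.isMutable)

// Run the same aggregation checks over both lists to cover both layouts.
Seq(allTypes, mutableTypes).foreach { types =>
  val schema = StructType(types.zipWithIndex.map { case (dt, i) =>
    StructField(s"col$i", dt, nullable = true)
  })
  // ... build a DataFrame with `schema` and run the UDAF assertions ...
}
```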
