Skip to content

Commit

Permalink
Retrieve supportedDepth from JNI
Browse files Browse the repository at this point in the history
Signed-off-by: Yan Feng <[email protected]>
  • Loading branch information
ustcfy committed Dec 12, 2024
1 parent 7d7d57c commit 4d7cbfd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -8480,9 +8480,9 @@ are limited.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>Nested levels exceeding 8 layers are not supported;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types DECIMAL, BINARY, CALENDAR, MAP, UDT, DAYTIME, YEARMONTH</em></td>
<td><em>PS<br/>The nesting depth has a certain limit;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types DECIMAL, BINARY, CALENDAR, MAP, UDT, DAYTIME, YEARMONTH</em></td>
<td><b>NS</b></td>
<td><em>PS<br/>Nested levels exceeding 8 layers are not supported;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types DECIMAL, BINARY, CALENDAR, MAP, UDT, DAYTIME, YEARMONTH</em></td>
<td><em>PS<br/>The nesting depth has a certain limit;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types DECIMAL, BINARY, CALENDAR, MAP, UDT, DAYTIME, YEARMONTH</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.util.control.NonFatal
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF}
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.jni.Hash
import com.nvidia.spark.rapids.lore.GpuLore
import com.nvidia.spark.rapids.shims._
import com.nvidia.spark.rapids.window.{GpuDenseRank, GpuLag, GpuLead, GpuPercentRank, GpuRank, GpuRowNumber, GpuSpecialFrameBoundary, GpuWindowExecMeta, GpuWindowSpecDefinitionMeta}
Expand Down Expand Up @@ -3321,23 +3322,24 @@ object GpuOverrides extends Logging {
ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT,
repeatingParamCheck = Some(RepeatingParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested() +
TypeSig.psNote(TypeEnum.ARRAY, "Nested levels exceeding 8 layers are not supported") +
TypeSig.psNote(TypeEnum.STRUCT, "Nested levels exceeding 8 layers are not supported"),
TypeSig.psNote(TypeEnum.ARRAY, "The nesting depth has a certain limit") +
TypeSig.psNote(TypeEnum.STRUCT, "The nesting depth has a certain limit"),
TypeSig.all))),
(a, conf, p, r) => new ExprMeta[HiveHash](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
def getMaxNestedDepth(inputType: DataType): Int = {
def getMaxStackDepth(inputType: DataType): Int = {
inputType match {
case at: ArrayType => 1 + getMaxNestedDepth(at.elementType)
case at: ArrayType => 1 + getMaxStackDepth(at.elementType)
case st: StructType =>
1 + st.map(f => getMaxNestedDepth(f.dataType)).max
1 + st.map(f => getMaxStackDepth(f.dataType)).max
case _ => 0 // primitive types
}
}
val maxDepth = a.children.map(c => getMaxNestedDepth(c.dataType)).max
if (maxDepth > 8) {
willNotWorkOnGpu(s"GPU HiveHash supports 8 levels at most for " +
s"nested types, but got $maxDepth")
val maxDepth = a.children.map(c => getMaxStackDepth(c.dataType)).max
val supportedDepth = Hash.MAX_STACK_DEPTH
if (maxDepth > supportedDepth) {
willNotWorkOnGpu(s"the data type requires a stack size of $maxDepth, " +
s"which exceeds the GPU limit of $supportedDepth")
}
}

Expand Down

0 comments on commit 4d7cbfd

Please sign in to comment.