Merge pull request #11339 from NVIDIA/merge-branch-24.08-to-main
Merge branch-24.08 into main
nvauto authored Aug 16, 2024
2 parents fd331a5 + d60008b commit 52b43d6
Showing 4 changed files with 111 additions and 34 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,5 @@
# Change log
-Generated on 2024-08-12
+Generated on 2024-08-16

## Release 24.08

@@ -88,6 +88,8 @@ Generated on 2024-08-12
### PRs
|||
|:---|:---|
+|[#11335](https://github.com/NVIDIA/spark-rapids/pull/11335)|Fix Delta Lake truncation of min/max string values|
+|[#11304](https://github.com/NVIDIA/spark-rapids/pull/11304)|Update changelog for v24.08.0 release [skip ci]|
|[#11303](https://github.com/NVIDIA/spark-rapids/pull/11303)|Update rapids JNI and private dependency to 24.08.0|
|[#11296](https://github.com/NVIDIA/spark-rapids/pull/11296)|[DOC] update doc for 2408 release [skip CI]|
|[#11309](https://github.com/NVIDIA/spark-rapids/pull/11309)|[Doc ]Update lore doc about the range [skip ci]|
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
 *
 * This file was derived from StatisticsCollection.scala
 * in the Delta Lake project at https://github.com/delta-io/delta.
@@ -31,7 +31,7 @@ import com.nvidia.spark.rapids.delta.shims.{ShimDeltaColumnMapping, ShimDeltaUDF
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.functions.{count, lit, max, min, struct, substring, sum, when}
+import org.apache.spark.sql.functions.{count, lit, max, min, struct, sum, when}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -87,7 +87,9 @@ trait GpuStatisticsCollection extends ShimUsesMetadataFields {
    collectStats(MIN, statCollectionSchema) {
      // Truncate string min values as necessary
      case (c, GpuSkippingEligibleDataType(StringType), true) =>
-        substring(min(c), 0, stringPrefixLength)
+        val udfTruncateMin = ShimDeltaUDF.stringStringUdf(
+          GpuStatisticsCollection.truncateMinStringAgg(prefixLength)_)
+        udfTruncateMin(min(c))

      // Collect all numeric min values
      case (c, GpuSkippingEligibleDataType(_), true) =>
@@ -203,25 +205,76 @@ trait GpuStatisticsCollection extends ShimUsesMetadataFields {
}

object GpuStatisticsCollection {
+  val ASCII_MAX_CHARACTER = '\u007F'
+
+  val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT))
+
+  def truncateMinStringAgg(prefixLen: Int)(input: String): String = {
+    if (input == null || input.length <= prefixLen) {
+      return input
+    }
+    if (prefixLen <= 0) {
+      return null
+    }
+    if (Character.isHighSurrogate(input.charAt(prefixLen - 1)) &&
+        Character.isLowSurrogate(input.charAt(prefixLen))) {
+      // If the character at prefixLen - 1 is a high surrogate and the next character is a low
+      // surrogate, we need to include the next character in the prefix to ensure that we don't
+      // truncate the string in the middle of a surrogate pair.
+      input.take(prefixLen + 1)
+    } else {
+      input.take(prefixLen)
+    }
+  }

  /**
-   * Helper method to truncate the input string `x` to the given `prefixLen` length, while also
-   * appending the unicode max character to the end of the truncated string. This ensures that any
-   * value in this column is less than or equal to the max.
+   * Helper method to truncate the input string `input` to the given `prefixLen` length, while also
+   * ensuring the any value in this column is less than or equal to the truncated max in UTF-8
+   * encoding.
   */
-  def truncateMaxStringAgg(prefixLen: Int)(x: String): String = {
-    if (x == null || x.length <= prefixLen) {
-      x
-    } else {
-      // Grab the prefix. We want to append `\ufffd` as a tie-breaker, but that is only safe
-      // if the character we truncated was smaller. Keep extending the prefix until that
-      // condition holds, or we run off the end of the string.
-      // scalastyle:off nonascii
-      val tieBreaker = '\ufffd'
-      x.take(prefixLen) + x.substring(prefixLen).takeWhile(_ >= tieBreaker) + tieBreaker
-      // scalastyle:off nonascii
+  def truncateMaxStringAgg(prefixLen: Int)(originalMax: String): String = {
+    // scalastyle:off nonascii
+    if (originalMax == null || originalMax.length <= prefixLen) {
+      return originalMax
+    }
+    if (prefixLen <= 0) {
+      return null
+    }
+
+    // Grab the prefix. We want to append max Unicode code point `\uDBFF\uDFFF` as a tie-breaker,
+    // but that is only safe if the character we truncated was smaller in UTF-8 encoded binary
+    // comparison. Keep extending the prefix until that condition holds, or we run off the end of
+    // the string.
+    // We also try to use the ASCII max character `\u007F` as a tie-breaker if possible.
+    val maxLen = getExpansionLimit(prefixLen)
+    // Start with a valid prefix
+    var currLen = truncateMinStringAgg(prefixLen)(originalMax).length
+    while (currLen <= maxLen) {
+      if (currLen >= originalMax.length) {
+        // Return originalMax if we have reached the end of the string
+        return originalMax
+      } else if (currLen + 1 < originalMax.length &&
+          originalMax.substring(currLen, currLen + 2) == UTF8_MAX_CHARACTER) {
+        // Skip the UTF-8 max character. It occupies two characters in a Scala string.
+        currLen += 2
+      } else if (originalMax.charAt(currLen) < ASCII_MAX_CHARACTER) {
+        return originalMax.take(currLen) + ASCII_MAX_CHARACTER
+      } else {
+        return originalMax.take(currLen) + UTF8_MAX_CHARACTER
+      }
+    }
+
+    // Return null when the input string is too long to truncate.
+    null
+    // scalastyle:on nonascii
+  }

+  /**
+   * Calculates the upper character limit when constructing a maximum is not possible with only
+   * prefixLen chars.
+   */
+  private def getExpansionLimit(prefixLen: Int): Int = 2 * prefixLen

  def batchStatsToRow(
      schema: StructType,
      explodedDataSchema: Map[Seq[String], Int],
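
For reference (not part of this commit), here is a minimal sketch of how the new truncation helpers behave, assuming the GpuStatisticsCollection object defined above is in scope; the expected results noted in the comments follow from the logic shown in the diff:

// Min truncation: take a plain prefix, but never split a surrogate pair.
GpuStatisticsCollection.truncateMinStringAgg(8)("abcdefghij")
// => "abcdefgh"
GpuStatisticsCollection.truncateMinStringAgg(2)("a\uD83D\uDE00z")
// => "a\uD83D\uDE00" (the surrogate pair stays whole, so the prefix is three chars)

// Max truncation: extend the prefix and append a tie-breaker only while the result
// stays greater than or equal to the original value in UTF-8 binary order.
GpuStatisticsCollection.truncateMaxStringAgg(8)("bcdefghijk")
// => "bcdefghi\u007F" (ASCII_MAX_CHARACTER is a safe tie-breaker here, since 'j' < '\u007F')
GpuStatisticsCollection.truncateMaxStringAgg(8)(GpuStatisticsCollection.UTF8_MAX_CHARACTER * 9)
// => null (no valid upper bound fits within getExpansionLimit(8) = 16 chars)

A prefix length of 8 matches the spark.databricks.io.skipping.stringPrefixLength value used by the integration test below.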
4 changes: 4 additions & 0 deletions integration_tests/src/main/python/delta_lake_utils.py
@@ -123,6 +123,10 @@ def json_to_sort_key(j):
    jsons.sort(key=json_to_sort_key)
    return jsons

+def read_delta_logs(spark, path):
+    log_data = spark.sparkContext.wholeTextFiles(path).collect()
+    return dict([(os.path.basename(x), _decode_jsons(y)) for x, y in log_data])
+
def assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path):
    cpu_log_data = spark.sparkContext.wholeTextFiles(data_path + "/CPU/_delta_log/*").collect()
    gpu_log_data = spark.sparkContext.wholeTextFiles(data_path + "/GPU/_delta_log/*").collect()
50 changes: 34 additions & 16 deletions integration_tests/src/main/python/delta_lake_write_test.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import json
import pyspark.sql.functions as f
import pytest
-import sys

from asserts import *
from data_gen import *
@@ -628,27 +628,45 @@ def gen_bad_data(spark):
@allow_non_gpu(*delta_meta_allow)
@delta_lake
@ignore_order
-@pytest.mark.parametrize("num_cols", [-1, 0, 1, 2, 3 ], ids=idfn)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
-def test_delta_write_stat_column_limits(num_cols, spark_tmp_path):
+def test_delta_write_stat_column_limits(spark_tmp_path):
    data_path = spark_tmp_path + "/DELTA_DATA"
    confs = copy_and_update(delta_writes_enabled_conf, {"spark.databricks.io.skipping.stringPrefixLength": 8})
-    strgen = StringGen() \
-        .with_special_case((chr(sys.maxunicode) * 7) + "abc") \
-        .with_special_case((chr(sys.maxunicode) * 8) + "abc") \
-        .with_special_case((chr(sys.maxunicode) * 16) + "abc") \
-        .with_special_case(('\U0000FFFD' * 7) + "abc") \
-        .with_special_case(('\U0000FFFD' * 8) + "abc") \
-        .with_special_case(('\U0000FFFD' * 16) + "abc")
-    gens = [("a", StructGen([("x", strgen), ("y", StructGen([("z", strgen)]))])),
-            ("b", binary_gen),
-            ("c", strgen)]
+    # maximum unicode codepoint and maximum ascii character
+    umax, amax = chr(1114111), chr(0x7f)
+    expected_min = {"a": "abcdefgh", "b": "abcdefg�", "c": "abcdefgh",
+                    "d": "abcdefgh", "e": umax * 4, "f": umax * 4}
+    # no max expected for column f since it cannot be truncated to 8 characters and remain
+    # larger than the original value
+    expected_max = {"a": "bcdefghi", "b": "bcdefgh�", "c": "bcdefghi" + amax,
+                    "d": "bcdefghi" + umax, "e": umax * 8}
+    def write_table(spark, path):
+        df = spark.createDataFrame([
+            ("bcdefghi", "abcdefg�", "bcdefghijk", "abcdefgh�", umax * 4, umax * 9),
+            ("abcdefgh", "bcdefgh�", "abcdefghij", "bcdefghi�", umax * 8, umax * 9)],
+            "a string, b string, c string, d string, e string, f string")
+        df.repartition(1).write.format("delta").save(path)
+    def verify_stat_limits(spark):
+        log_data = read_delta_logs(spark, data_path + "/GPU/_delta_log/*.json")
+        assert len(log_data) == 1, "GPU should generate exactly one Delta log"
+        json_objs = list(log_data.values())[0]
+        json_adds = [x["add"] for x in json_objs if "add" in x]
+        assert len(json_adds) == 1, "GPU should only generate a single add in Delta log"
+        stats = json.loads(json_adds[0]["stats"])
+        actual_min = stats["minValues"]
+        assert expected_min == actual_min, \
+            f"minValues mismatch, expected: {expected_min} actual: {actual_min}"
+        actual_max = stats["maxValues"]
+        assert expected_max == actual_max, \
+            f"maxValues stats mismatch, expected: {expected_max} actual: {actual_max}"
    assert_gpu_and_cpu_writes_are_equal_collect(
-        lambda spark, path: gen_df(spark, gens).coalesce(1).write.format("delta").save(path),
-        lambda spark, path: spark.read.format("delta").load(path),
+        write_table,
+        read_delta_path,
        data_path,
        conf=confs)
-    with_cpu_session(lambda spark: assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path))
+    # Many Delta Lake versions are missing the fix from https://github.com/delta-io/delta/pull/3430
+    # so instead of a full delta log compare with the CPU, focus on the reported statistics on GPU.
+    with_cpu_session(verify_stat_limits)

@allow_non_gpu("CreateTableExec", *delta_meta_allow)
@delta_lake
