Support aqe #1001
base: main
The PR touches `TableUtils`. The first change teaches `sql` to skip its manual coalesce when AQE auto-coalescing is enabled:

```diff
@@ -594,6 +594,7 @@ case class TableUtils(sparkSession: SparkSession) {
   def sql(query: String): DataFrame = {
     val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
+    val autoCoalesceEnabled = sparkSession.conf.get("spark.sql.adaptive.coalescePartitions.enabled", "true").toBoolean
     val sw = new StringWriter()
     val pw = new PrintWriter(sw)
     new Throwable().printStackTrace(pw)
@@ -603,13 +604,24 @@ case class TableUtils(sparkSession: SparkSession) {
       .filter(_.contains("chronon"))
       .map(_.replace("at ai.chronon.spark.", ""))
       .mkString("\n")

-    logger.info(
-      s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
+    if (!autoCoalesceEnabled) {
+      logger.info(
+        s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
+    } else {
+      logger.info(
+        s"\n----[Running query with AQE auto coalesce enabled]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
+    }
     try {
       // Run the query
-      val df = sparkSession.sql(query).coalesce(partitionCount)
-      df
+      val df = sparkSession.sql(query)
+      // if aqe auto coalesce is disabled, apply manual coalesce
+      val finalDf = if (!autoCoalesceEnabled) {
+        df.coalesce(partitionCount)
+      } else {
+        logger.info(s"AQE auto coalesce is enabled, skipping manual coalesce")
+        df
+      }
+      finalDf
     } catch {
       case e: AnalysisException if e.getMessage.contains(" already exists") =>
         logger.warn(s"Non-Fatal: ${e.getMessage}. Query may result in redefinition.")
```

Review thread on the `autoCoalesceEnabled` default:

> **Reviewer:** Shall we default this to `false` if it is not set, to match the existing default behavior?

> **Author:** I think the idea is to change the default behavior, since it's much more performant in most cases.

Review thread on the `if (!autoCoalesceEnabled)` branching:

> **Reviewer:** [nit] Swap the if/else branches to do the enabled case first; I find code easier to read without explicit negations :) Same with the […]

> **Author:** This is a common style in this code base, ex: […]
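For context on the settings in play: `spark.sql.adaptive.coalescePartitions.enabled` only takes effect when `spark.sql.adaptive.enabled` is also on. A minimal, standalone sketch of enabling both (standard Spark configuration, not code from this PR):

```scala
import org.apache.spark.sql.SparkSession

object AqeCoalesceDemo extends App {
  val spark = SparkSession
    .builder()
    .master("local[*]") // demo master; any cluster manager works
    .appName("aqe-coalesce-demo")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()

  // With both flags on, Spark merges small post-shuffle partitions at runtime,
  // which is what makes the manual .coalesce(partitionCount) above redundant.
  val df = spark.sql("SELECT id % 10 AS bucket, id FROM range(1000000)")
  df.groupBy("bucket").count().show()

  spark.stop()
}
```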
The second change reworks the write path so the manual salt-and-repartition machinery is only used when AQE cannot do the coalescing:

```diff
@@ -702,76 +714,91 @@ case class TableUtils(sparkSession: SparkSession) {
       stats: Option[DfStats],
       sortByCols: Seq[String] = Seq.empty,
       partitionCols: Seq[String] = Seq.empty): Unit = {
-    // get row count and table partition count statistics
-    // to determine shuffle parallelism, count only top-level/first partition column
-    // assumed to be the date partition. If not given, use default
-    val partitionCol = partitionCols.headOption.getOrElse(partitionColumn)
-    val (rowCount: Long, tablePartitionCount: Int) =
-      if (df.schema.fieldNames.contains(partitionCol)) {
-        if (stats.isDefined && stats.get.partitionRange.wellDefined) {
-          stats.get.count -> stats.get.partitionRange.partitions.length
-        } else {
-          val result = df.select(count(lit(1)), approx_count_distinct(col(partitionCol))).head()
-          (result.getAs[Long](0), result.getAs[Long](1).toInt)
-        }
-      } else {
-        (df.count(), 1)
-      }
+    val useAqeRoute = sparkSession.conf.getOption("spark.sql.adaptive.enabled").contains("true") &&
+      sparkSession.conf.getOption("spark.sql.adaptive.coalescePartitions.enabled").contains("true")
```

Review thread on restructuring:

> **Reviewer:** Is it possible to restructure this code to minimize the diff, and thus the reviewer's mental load to keep track of things? Can we do something like an early-return guard, `if (aqe_mode) { the_new_aqe_code; return }`, so the `The_rest_of_the_non_aqe_code_as_is()` code still remains at the same indentation level as-is and wouldn't show up as a diff in the PR?

> **Author:** We eventually branch into the same code below; not sure how to make this cleaner while keeping that. I definitely empathize that this diff rendered weirdly.
Review thread on the `useAqeRoute` check:

> **Reviewer:** Is it intentional that the "is AQE on" logic here checks both config options, while the logic in `sql` above checks only `spark.sql.adaptive.coalescePartitions.enabled`? It also might be a good idea to have a common helper function to abstract away the "is AQE on" logic.
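A sketch of the common helper the reviewer suggests (hypothetical, not part of the PR), written as a method on `TableUtils` so `sql` and the write path could share one definition. Note that `getOption(...).contains("true")` treats an unset flag as off, whereas the `sql` hunk above defaults the coalesce flag to `"true"` when unset; a shared helper would force that choice to be made once:

```scala
// Hypothetical shared predicate for "is AQE auto-coalescing on".
// Assumes it lives inside TableUtils, where `sparkSession` is in scope.
def aqeAutoCoalesceEnabled: Boolean =
  sparkSession.conf.getOption("spark.sql.adaptive.enabled").contains("true") &&
    sparkSession.conf.getOption("spark.sql.adaptive.coalescePartitions.enabled").contains("true")
```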
The hunk continues; the AQE route writes directly, while the else branch keeps the original logic re-indented:

```diff
-    // set to one if tablePartitionCount=0 to avoid division by zero
-    val nonZeroTablePartitionCount = if (tablePartitionCount == 0) 1 else tablePartitionCount
-
-    logger.info(s"$rowCount rows requested to be written into table $tableName")
-    if (rowCount > 0) {
-      val columnSizeEstimate = columnSizeEstimator(df.schema)
-
-      // check if spark is running in local mode or cluster mode
-      val isLocal = sparkSession.conf.get("spark.master").startsWith("local")
-
-      // roughly 1 partition count per 1m rows x 100 columns
-      val rowCountPerPartition = df.sparkSession.conf
-        .getOption(SparkConstants.ChrononRowCountPerPartition)
-        .map(_.toDouble)
-        .flatMap(value => if (value > 0) Some(value) else None)
-        .getOrElse(1e8)
-
-      val totalFileCountEstimate = math.ceil(rowCount * columnSizeEstimate / rowCountPerPartition).toInt
-      val dailyFileCountUpperBound = 2000
-      val dailyFileCountLowerBound = if (isLocal) 1 else 10
-      val dailyFileCountEstimate = totalFileCountEstimate / nonZeroTablePartitionCount + 1
-      val dailyFileCountBounded =
-        math.max(math.min(dailyFileCountEstimate, dailyFileCountUpperBound), dailyFileCountLowerBound)
-
-      val outputParallelism = df.sparkSession.conf
-        .getOption(SparkConstants.ChrononOutputParallelismOverride)
-        .map(_.toInt)
-        .flatMap(value => if (value > 0) Some(value) else None)
-
-      if (outputParallelism.isDefined) {
-        logger.info(s"Using custom outputParallelism ${outputParallelism.get}")
-      }
-      val dailyFileCount = outputParallelism.getOrElse(dailyFileCountBounded)
-
-      // finalized shuffle parallelism
-      val shuffleParallelism = Math.max(dailyFileCount * nonZeroTablePartitionCount, minWriteShuffleParallelism)
-      val saltCol = "random_partition_salt"
-      val saltedDf = df.withColumn(saltCol, round(rand() * (dailyFileCount + 1)))
-
-      logger.info(
-        s"repartitioning data for table $tableName by $shuffleParallelism spark tasks into $tablePartitionCount table partitions and $dailyFileCount files per partition")
-      val (repartitionCols: immutable.Seq[String], partitionSortCols: immutable.Seq[String]) =
-        if (df.schema.fieldNames.contains(partitionCol)) {
-          (Seq(partitionCol, saltCol), Seq(partitionCol) ++ sortByCols)
-        } else { (Seq(saltCol), sortByCols) }
-      logger.info(s"Sorting within partitions with cols: $partitionSortCols")
-      saltedDf
-        .repartition(shuffleParallelism, repartitionCols.map(saltedDf.col): _*)
-        .drop(saltCol)
-        .sortWithinPartitions(partitionSortCols.map(col): _*)
-        .write
-        .mode(saveMode)
-        .insertInto(tableName)
-      logger.info(s"Finished writing to $tableName")
-    }
+    if (useAqeRoute) {
+      if (df.isEmpty) {
+        logger.info(s"Input DataFrame for table $tableName is empty. Nothing to write.")
+        return
+      }
+      val sortedDf = df.sortWithinPartitions(sortByCols.map(col).toSeq: _*)
+      val writer = sortedDf.write.mode(saveMode)
+      writer.insertInto(tableName)
+
+      logger.info(s"Successfully finished writing to $tableName using AQE.")
+    } else {
+      // get row count and table partition count statistics
+      // to determine shuffle parallelism, count only top-level/first partition column
+      // assumed to be the date partition. If not given, use default
+      val partitionCol = partitionCols.headOption.getOrElse(partitionColumn)
+      val (rowCount: Long, tablePartitionCount: Int) =
+        if (df.schema.fieldNames.contains(partitionCol)) {
+          if (stats.isDefined && stats.get.partitionRange.wellDefined) {
+            stats.get.count -> stats.get.partitionRange.partitions.length
+          } else {
+            val result = df.select(count(lit(1)), approx_count_distinct(col(partitionCol))).head()
+            (result.getAs[Long](0), result.getAs[Long](1).toInt)
+          }
+        } else {
+          (df.count(), 1)
+        }
+
+      // set to one if tablePartitionCount=0 to avoid division by zero
+      val nonZeroTablePartitionCount = if (tablePartitionCount == 0) 1 else tablePartitionCount
+
+      logger.info(s"$rowCount rows requested to be written into table $tableName")
+      if (rowCount > 0) {
+        val columnSizeEstimate = columnSizeEstimator(df.schema)
+
+        // check if spark is running in local mode or cluster mode
+        val isLocal = sparkSession.conf.get("spark.master").startsWith("local")
+
+        // roughly 1 partition count per 1m rows x 100 columns
+        val rowCountPerPartition = df.sparkSession.conf
+          .getOption(SparkConstants.ChrononRowCountPerPartition)
+          .map(_.toDouble)
+          .flatMap(value => if (value > 0) Some(value) else None)
+          .getOrElse(1e8)
+
+        val totalFileCountEstimate = math.ceil(rowCount * columnSizeEstimate / rowCountPerPartition).toInt
+        val dailyFileCountUpperBound = 2000
+        val dailyFileCountLowerBound = if (isLocal) 1 else 10
+        val dailyFileCountEstimate = totalFileCountEstimate / nonZeroTablePartitionCount + 1
+        val dailyFileCountBounded =
+          math.max(math.min(dailyFileCountEstimate, dailyFileCountUpperBound), dailyFileCountLowerBound)
+
+        val outputParallelism = df.sparkSession.conf
+          .getOption(SparkConstants.ChrononOutputParallelismOverride)
+          .map(_.toInt)
+          .flatMap(value => if (value > 0) Some(value) else None)
+
+        if (outputParallelism.isDefined) {
+          logger.info(s"Using custom outputParallelism ${outputParallelism.get}")
+        }
+        val dailyFileCount = outputParallelism.getOrElse(dailyFileCountBounded)
+
+        // finalized shuffle parallelism
+        val shuffleParallelism = Math.max(dailyFileCount * nonZeroTablePartitionCount, minWriteShuffleParallelism)
+        val saltCol = "random_partition_salt"
+        val saltedDf = df.withColumn(saltCol, round(rand() * (dailyFileCount + 1)))
+
+        logger.info(
+          s"repartitioning data for table $tableName by $shuffleParallelism spark tasks into $tablePartitionCount table partitions and $dailyFileCount files per partition")
+        val (repartitionCols: immutable.Seq[String], partitionSortCols: immutable.Seq[String]) =
+          if (df.schema.fieldNames.contains(partitionCol)) {
+            (Seq(partitionCol, saltCol), Seq(partitionCol) ++ sortByCols)
+          } else { (Seq(saltCol), sortByCols) }
+        logger.info(s"Sorting within partitions with cols: $partitionSortCols")
+        saltedDf
+          .repartition(shuffleParallelism, repartitionCols.map(saltedDf.col): _*)
+          .drop(saltCol)
+          .sortWithinPartitions(partitionSortCols.map(col): _*)
+          .write
+          .mode(saveMode)
+          .insertInto(tableName)
+        logger.info(s"Finished writing to $tableName")
+      }
+    }
   }
```

Review thread on `val sortedDf = df.sortWithinPartitions(sortByCols.map(col).toSeq: _*)`:

> **Reviewer:** Why are we ignoring the partition columns here?

> **Author:** Good question, this is kind of a weird one, but IIRC (and it's been a while) including them actually breaks the table schema assumptions and fails the test. We are essentially relying on AQE as well as the writer plugin to do our repartitioning into the correct partitions. The Catalyst engine rewrites this entire section to do the repartition before the sort, IIRC, so we know the partition columns are all equal to each other within a task, and therefore we don't sort by them here. Including them changes the plan in such a way that it does not match the previous Chronon plan. Maybe this is a Catalyst bug or something, but I'm pretty sure I ran into the same issue as Stripe, where the results would sometimes be out of order until I removed that.
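For readers less familiar with the salting approach the non-AQE branch keeps, here is a standalone sketch of the technique in isolation (my own illustration, not PR code; it assumes a DataFrame with a date partition column named "ds"):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, rand, round}

// Salt-and-repartition: spread each date's rows over ~filesPerPartition shuffle
// tasks so one hot date does not land on a single writer, then drop the salt
// and sort within each task before writing.
def saltedLayout(df: DataFrame, filesPerPartition: Int, parallelism: Int): DataFrame = {
  val saltCol = "random_partition_salt"
  df
    .withColumn(saltCol, round(rand() * (filesPerPartition + 1))) // random salt in [0, filesPerPartition]
    .repartition(parallelism, col("ds"), col(saltCol))            // hash by (date, salt)
    .drop(saltCol)
    .sortWithinPartitions(col("ds"))                              // cluster rows per date within each task
}
```

AQE with auto-coalescing makes this bookkeeping unnecessary: post-shuffle partitions are sized at runtime instead of being estimated up front from row and column counts.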
A final thread, on Kryo registrations added elsewhere in the PR:

> **Reviewer:** Add an inline comment that these registrations are needed for AQE (that's just my guess on why they are here :))?

> **Author:** It's actually more generic than that: AQE uses it, but some of the SQL expressions will output this too. I think this matches the current file organization. A separate PR/GitHub issue could clone this and add more explanations to the class: https://github.com/apache/spark/blob/dc687d4c83b877e90c8dc03fb88f13440d4ae911/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala#L575
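For reference, the general shape of Kryo class registration in Spark looks like the following. This is illustrative only; the class list the PR actually registers is not visible in this excerpt, so the entries below are placeholders:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  // with registrationRequired, serializing an unregistered class fails fast,
  // which is how missing AQE/SQL-internal classes typically surface
  .set("spark.kryo.registrationRequired", "true")
  // placeholder classes, not the ones registered by this PR
  .registerKryoClasses(Array(
    classOf[Array[Long]],
    classOf[scala.collection.mutable.WrappedArray.ofRef[_]]
  ))
```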