Commits (47)
00d4da8 - staging changes for testing iceberg (May 13, 2025)
93b7c92 - timeboxing test changes (May 13, 2025)
2d4ed6f - bootstrapping spark test, 1/3 working on FormatTest (May 14, 2025)
878f64a - got most tests working except droppartitions and the new dateint test (May 15, 2025)
4739e93 - cleaning some local changes (May 15, 2025)
65c81a1 - reverting some local changes (May 15, 2025)
cab2c6a - formatting (May 15, 2025)
f4bfb70 - more silly local hacks (May 15, 2025)
ac505d3 - Merge branch 'main' into iceberg_unit_tests (abbywh, May 15, 2025)
ae2088f - fixed the constant derby flake (May 16, 2025)
dee01ab - refactored to match deltalake (May 16, 2025)
69cd50d - added Iceberg Kryo Serializer (May 16, 2025)
5526abe - scalafmt (May 16, 2025)
2b9c246 - iceberg circleci integration (May 16, 2025)
b423f4b - fixing typo (May 16, 2025)
9e5eaae - giving circleci a dependency (May 16, 2025)
558adc3 - removing env file (May 16, 2025)
c1eb8a2 - moving integration test to spark_embedded (May 16, 2025)
510266c - figured out why delta lake was on 2.13, need it for spark 3.2 (May 16, 2025)
58a58ff - typo (May 16, 2025)
6ea4121 - scalafmt (May 16, 2025)
cb012a6 - skipping the flink parts since it doesn't compile to 2.13.6 (May 16, 2025)
117c4f7 - including TableUtilsTest as well in CI (May 16, 2025)
7d3272f - sperating table utils and format for seperate jvms (May 16, 2025)
f009f97 - typo (May 16, 2025)
02c1463 - corrected behavior for long partitions (May 16, 2025)
b05ef1a - Merge branch 'main' into iceberg_unit_tests (abbywh, May 16, 2025)
059e16e - eventeventlongds test, more kryo registration (May 17, 2025)
907d2f8 - iceberg drop partitions (May 17, 2025)
0065406 - long partition testing (May 17, 2025)
a956c7f - Merge branch 'main' into iceberg_unit_tests (abbywh, May 17, 2025)
e980fa1 - unskipping fixed tests (May 19, 2025)
4054892 - changing test schema (May 19, 2025)
2ffd32d - updating drop partitions to be schemaless (May 19, 2025)
50445f0 - found bug during CI testing (May 21, 2025)
589eba9 - Apply suggestions from code review (abbywh, Jun 7, 2025)
8cfd661 - formatting (Jun 7, 2025)
c72b74f - propping name refactor (Jun 7, 2025)
d13dc5e - fixing some typos (Jun 7, 2025)
c1c8b3d - initial commit for AQE support (Jun 8, 2025)
38591b7 - Merge branch 'main' into support_aqe (abbywh, Jun 8, 2025)
6c25117 - i really need to add a scalafmt githook (Jun 8, 2025)
0c6a03f - updated kryo serializer (Jun 8, 2025)
a5a6716 - Merge branch 'main' into support_aqe (abbywh, Jun 10, 2025)
3091ef5 - Merge branch 'main' into support_aqe (abbywh, Jul 19, 2025)
26196b7 - simplifying the empty check (Jul 21, 2025)
ce7a16f - small refactor (Jul 21, 2025)
@@ -147,7 +147,11 @@ class ChrononKryoRegistrator extends KryoRegistrator {
"org.apache.spark.sql.catalyst.InternalRow$$anonfun$getAccessor$8",
"org.apache.spark.sql.catalyst.InternalRow$$anonfun$getAccessor$5",
"scala.collection.immutable.ArraySeq$ofRef",
"org.apache.spark.sql.catalyst.expressions.GenericInternalRow"
"org.apache.spark.sql.catalyst.expressions.GenericInternalRow",

Review comment: add an inline comment that these registrations are needed for AQE? (That's just my guess on why they are here :))

Author reply: It's actually more generic than that: AQE uses it, but some of the SQL expressions will output this too. I think this matches the current file organization. A separate PR/GitHub issue could cover cloning this class and adding more explanations, similar to
https://github.com/apache/spark/blob/dc687d4c83b877e90c8dc03fb88f13440d4ae911/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala#L575

"scala.math.Ordering$Reverse",
"java.lang.invoke.SerializedLambda",
"org.apache.spark.sql.catalyst.InternalRow$",
"org.apache.spark.util.collection.BitSet"
)
names.foreach(name => doRegister(name, kryo))
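The registrations above are applied by class name via doRegister. A minimal standalone sketch of the likely pattern follows; the class name NameBasedKryoRegistrator, the error handling, and the rationale for reflective lookup are assumptions, not the verbatim ChrononKryoRegistrator:

```scala
import com.esotericsoftware.kryo.Kryo
import org.apache.spark.serializer.KryoRegistrator

// Sketch of name-based Kryo registration; the real ChrononKryoRegistrator
// registers a much longer list of names, including the ones added in this diff.
class NameBasedKryoRegistrator extends KryoRegistrator {

  // Resolve the class reflectively and tolerate lookup failures, since some of
  // these classes are private or only exist on certain Spark/Scala versions.
  private def doRegister(name: String, kryo: Kryo): Unit =
    try kryo.register(Class.forName(name))
    catch { case _: ClassNotFoundException => () }

  override def registerClasses(kryo: Kryo): Unit = {
    val names = Seq(
      "org.apache.spark.sql.catalyst.expressions.GenericInternalRow",
      "scala.math.Ordering$Reverse",
      "org.apache.spark.util.collection.BitSet"
    )
    names.foreach(name => doRegister(name, kryo))
  }
}
```

A registrator like this is wired in through the standard spark.kryo.registrator setting when the SparkSession is built.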

@@ -123,6 +123,8 @@ object SparkSessionBuilder {
.config("spark.default.parallelism", "2")
.config("spark.testing", "true")
.config("spark.sql.adaptive.enabled", true)
.config("spark.sql.adaptive.coalescePartitions.enabled", true)
.config("spark.sql.adaptive.skewJoin.enabled", true)
.config("spark.local.dir", s"/tmp/$userName/$name")
.config("spark.sql.warehouse.dir", s"$warehouseDir/data")
.config("spark.driver.bindAddress", "127.0.0.1")
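For reference, a self-contained sketch of a local session with the AQE settings added above; the object name, master, and app name here are placeholders, and SparkSessionBuilder sets many more options than shown:

```scala
import org.apache.spark.sql.SparkSession

object AqeSessionExample {
  def main(args: Array[String]): Unit = {
    // Enable AQE together with automatic partition coalescing and skew-join handling.
    val spark = SparkSession
      .builder()
      .master("local[2]")     // placeholder master
      .appName("aqe-example") // placeholder app name
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
      .config("spark.sql.adaptive.skewJoin.enabled", "true")
      .getOrCreate()

    println(spark.conf.get("spark.sql.adaptive.coalescePartitions.enabled")) // "true"
    spark.stop()
  }
}
```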
spark/src/main/scala/ai/chronon/spark/TableUtils.scala (169 changes: 98 additions, 71 deletions)
@@ -594,6 +594,7 @@ case class TableUtils(sparkSession: SparkSession) {

def sql(query: String): DataFrame = {
val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
val autoCoalesceEnabled = sparkSession.conf.get("spark.sql.adaptive.coalescePartitions.enabled", "true").toBoolean

Review comment: Shall we default this to false if it is not set, to match the existing default behavior?

Author reply: I think the idea is to change the default behavior, since AQE auto-coalesce is much more performant in most cases.

val sw = new StringWriter()
val pw = new PrintWriter(sw)
new Throwable().printStackTrace(pw)
@@ -603,13 +604,24 @@ case class TableUtils(sparkSession: SparkSession) {
.filter(_.contains("chronon"))
.map(_.replace("at ai.chronon.spark.", ""))
.mkString("\n")

logger.info(
s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
if (!autoCoalesceEnabled) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] swap the if else branches to do the if enabled case first; i find code easier to read without explicit negations:)

Same with the val finalDf if-else below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a common style in this code base, ex: if (!tableExists(tableName)) return Seq.empty[String] which I chose to follow (I agree it's a bit wonky)

logger.info(
s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
} else {
logger.info(
s"\n----[Running query with AQE auto coalesce enabled]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
}
try {
// Run the query
val df = sparkSession.sql(query).coalesce(partitionCount)
df
val df = sparkSession.sql(query)
// if aqe auto coalesce is disabled, apply manual coalesce
val finalDf = if (!autoCoalesceEnabled) {
df.coalesce(partitionCount)
} else {
logger.info(s"AQE auto coalesce is enabled, skipping manual coalesce")
df
}
finalDf
} catch {
case e: AnalysisException if e.getMessage.contains(" already exists") =>
logger.warn(s"Non-Fatal: ${e.getMessage}. Query may result in redefinition.")
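Taken together, the sql() changes above reduce to the following control flow. This is a condensed sketch (written as a free-standing method with an explicit SparkSession parameter), not the verbatim TableUtils code:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// When AQE auto-coalesce is on (now treated as the default), let Spark decide
// the output partitioning; otherwise keep the original manual coalesce down to
// spark.default.parallelism.
def runSql(sparkSession: SparkSession, query: String): DataFrame = {
  val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
  val autoCoalesceEnabled =
    sparkSession.conf.get("spark.sql.adaptive.coalescePartitions.enabled", "true").toBoolean

  val df = sparkSession.sql(query)
  if (autoCoalesceEnabled) df else df.coalesce(partitionCount)
}
```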
@@ -702,76 +714,91 @@ case class TableUtils(sparkSession: SparkSession) {
stats: Option[DfStats],
sortByCols: Seq[String] = Seq.empty,
partitionCols: Seq[String] = Seq.empty): Unit = {
// get row count and table partition count statistics

Review comment: Is it possible to restructure this code to minimize the diff, and thus the reviewer's mental load? Something like:

if (useAqeRoute) {
   write_into_df_the_simple_way()
   return
}

The_rest_of_the_non_aqe_code_as_is()

That way The_rest_of_the_non_aqe_code_as_is() stays at its current indentation level and wouldn't show up as a diff in the PR.

Author reply: We eventually branch into the same code below; I'm not sure how to make this cleaner while keeping that. I definitely empathize that this diff rendered weirdly.

// to determine shuffle parallelism, count only top-level/first partition column
// assumed to be the date partition. If not given, use default
val partitionCol = partitionCols.headOption.getOrElse(partitionColumn)
val (rowCount: Long, tablePartitionCount: Int) =
if (df.schema.fieldNames.contains(partitionCol)) {
if (stats.isDefined && stats.get.partitionRange.wellDefined) {
stats.get.count -> stats.get.partitionRange.partitions.length
} else {
val result = df.select(count(lit(1)), approx_count_distinct(col(partitionCol))).head()
(result.getAs[Long](0), result.getAs[Long](1).toInt)
}
} else {
(df.count(), 1)
}
val useAqeRoute = sparkSession.conf.getOption("spark.sql.adaptive.enabled").contains("true") &&
sparkSession.conf.getOption("spark.sql.adaptive.coalescePartitions.enabled").contains("true")

Review comment: Is it intentional that the "is AQE on" logic here checks both config options, while the logic in def sql() above only checks spark.sql.adaptive.coalescePartitions.enabled?

It also might be a good idea to have a common helper function to abstract away the "is AQE on" logic.
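A minimal sketch of what such a helper could look like; the name isAqeCoalesceEnabled is hypothetical and not part of this diff. This variant requires both flags, matching the insertToTable check rather than the single-flag check in sql():

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical helper that keeps the "is AQE coalesce on" check in one place
// so sql() and insertToTable() cannot drift apart.
def isAqeCoalesceEnabled(spark: SparkSession): Boolean =
  spark.conf.getOption("spark.sql.adaptive.enabled").exists(_.toBoolean) &&
    spark.conf.getOption("spark.sql.adaptive.coalescePartitions.enabled").exists(_.toBoolean)
```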


// set to one if tablePartitionCount=0 to avoid division by zero
val nonZeroTablePartitionCount = if (tablePartitionCount == 0) 1 else tablePartitionCount

logger.info(s"$rowCount rows requested to be written into table $tableName")
if (rowCount > 0) {
val columnSizeEstimate = columnSizeEstimator(df.schema)

// check if spark is running in local mode or cluster mode
val isLocal = sparkSession.conf.get("spark.master").startsWith("local")

// roughly 1 partition count per 1m rows x 100 columns
val rowCountPerPartition = df.sparkSession.conf
.getOption(SparkConstants.ChrononRowCountPerPartition)
.map(_.toDouble)
.flatMap(value => if (value > 0) Some(value) else None)
.getOrElse(1e8)

val totalFileCountEstimate = math.ceil(rowCount * columnSizeEstimate / rowCountPerPartition).toInt
val dailyFileCountUpperBound = 2000
val dailyFileCountLowerBound = if (isLocal) 1 else 10
val dailyFileCountEstimate = totalFileCountEstimate / nonZeroTablePartitionCount + 1
val dailyFileCountBounded =
math.max(math.min(dailyFileCountEstimate, dailyFileCountUpperBound), dailyFileCountLowerBound)

val outputParallelism = df.sparkSession.conf
.getOption(SparkConstants.ChrononOutputParallelismOverride)
.map(_.toInt)
.flatMap(value => if (value > 0) Some(value) else None)

if (outputParallelism.isDefined) {
logger.info(s"Using custom outputParallelism ${outputParallelism.get}")
if (useAqeRoute) {
if (df.isEmpty) {
logger.info(s"Input DataFrame for table $tableName is empty. Nothing to write.")
return
}
val dailyFileCount = outputParallelism.getOrElse(dailyFileCountBounded)

// finalized shuffle parallelism
val shuffleParallelism = Math.max(dailyFileCount * nonZeroTablePartitionCount, minWriteShuffleParallelism)
val saltCol = "random_partition_salt"
val saltedDf = df.withColumn(saltCol, round(rand() * (dailyFileCount + 1)))
val sortedDf = df.sortWithinPartitions(sortByCols.map(col).toSeq: _*)

Review comment: Why are we ignoring the partitionCols that users may pass in?

Author reply: Good question, this is kind of a weird one, but IIRC (it's been a while) including them actually breaks the table schema assumptions and fails the test. We are essentially relying on AQE and the writer plugin to do our repartitioning into the correct partitions. The Catalyst engine rewrites this entire section to do the repartition before the sort, so the partition column values are already equal within each task and we don't need to sort by them here. Including them changes the plan in a way that no longer matches the previous Chronon plan. Maybe this is a Catalyst bug, but I'm pretty sure I ran into the same issue as Stripe, where results would sometimes be out of order until I removed that.

val writer = sortedDf.write.mode(saveMode)
writer.insertInto(tableName)

logger.info(
s"repartitioning data for table $tableName by $shuffleParallelism spark tasks into $tablePartitionCount table partitions and $dailyFileCount files per partition")
val (repartitionCols: immutable.Seq[String], partitionSortCols: immutable.Seq[String]) =
logger.info(s"Successfully finished writing to $tableName using AQE.")
} else {
// get row count and table partition count statistics
// to determine shuffle parallelism, count only top-level/first partition column
// assumed to be the date partition. If not given, use default
val partitionCol = partitionCols.headOption.getOrElse(partitionColumn)
val (rowCount: Long, tablePartitionCount: Int) =
if (df.schema.fieldNames.contains(partitionCol)) {
(Seq(partitionCol, saltCol), Seq(partitionCol) ++ sortByCols)
} else { (Seq(saltCol), sortByCols) }
logger.info(s"Sorting within partitions with cols: $partitionSortCols")
saltedDf
.repartition(shuffleParallelism, repartitionCols.map(saltedDf.col): _*)
.drop(saltCol)
.sortWithinPartitions(partitionSortCols.map(col): _*)
.write
.mode(saveMode)
.insertInto(tableName)
logger.info(s"Finished writing to $tableName")
if (stats.isDefined && stats.get.partitionRange.wellDefined) {
stats.get.count -> stats.get.partitionRange.partitions.length
} else {
val result = df.select(count(lit(1)), approx_count_distinct(col(partitionCol))).head()
(result.getAs[Long](0), result.getAs[Long](1).toInt)
}
} else {
(df.count(), 1)
}

// set to one if tablePartitionCount=0 to avoid division by zero
val nonZeroTablePartitionCount = if (tablePartitionCount == 0) 1 else tablePartitionCount

logger.info(s"$rowCount rows requested to be written into table $tableName")
if (rowCount > 0) {
val columnSizeEstimate = columnSizeEstimator(df.schema)

// check if spark is running in local mode or cluster mode
val isLocal = sparkSession.conf.get("spark.master").startsWith("local")

// roughly 1 partition count per 1m rows x 100 columns
val rowCountPerPartition = df.sparkSession.conf
.getOption(SparkConstants.ChrononRowCountPerPartition)
.map(_.toDouble)
.flatMap(value => if (value > 0) Some(value) else None)
.getOrElse(1e8)

val totalFileCountEstimate = math.ceil(rowCount * columnSizeEstimate / rowCountPerPartition).toInt
val dailyFileCountUpperBound = 2000
val dailyFileCountLowerBound = if (isLocal) 1 else 10
val dailyFileCountEstimate = totalFileCountEstimate / nonZeroTablePartitionCount + 1
val dailyFileCountBounded =
math.max(math.min(dailyFileCountEstimate, dailyFileCountUpperBound), dailyFileCountLowerBound)

val outputParallelism = df.sparkSession.conf
.getOption(SparkConstants.ChrononOutputParallelismOverride)
.map(_.toInt)
.flatMap(value => if (value > 0) Some(value) else None)

if (outputParallelism.isDefined) {
logger.info(s"Using custom outputParallelism ${outputParallelism.get}")
}
val dailyFileCount = outputParallelism.getOrElse(dailyFileCountBounded)

// finalized shuffle parallelism
val shuffleParallelism = Math.max(dailyFileCount * nonZeroTablePartitionCount, minWriteShuffleParallelism)
val saltCol = "random_partition_salt"
val saltedDf = df.withColumn(saltCol, round(rand() * (dailyFileCount + 1)))

logger.info(
s"repartitioning data for table $tableName by $shuffleParallelism spark tasks into $tablePartitionCount table partitions and $dailyFileCount files per partition")
val (repartitionCols: immutable.Seq[String], partitionSortCols: immutable.Seq[String]) =
if (df.schema.fieldNames.contains(partitionCol)) {
(Seq(partitionCol, saltCol), Seq(partitionCol) ++ sortByCols)
} else { (Seq(saltCol), sortByCols) }
logger.info(s"Sorting within partitions with cols: $partitionSortCols")
saltedDf
.repartition(shuffleParallelism, repartitionCols.map(saltedDf.col): _*)
.drop(saltCol)
.sortWithinPartitions(partitionSortCols.map(col): _*)
.write
.mode(saveMode)
.insertInto(tableName)
logger.info(s"Finished writing to $tableName")
}
}
}
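Because the diff above interleaves the new AQE branch with the relocated legacy path, here is a condensed sketch of just the AQE write route discussed in the review thread. It is written as a free-standing method with explicit parameters (writeWithAqe is a hypothetical name), not the verbatim insertToTable code:

```scala
import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.functions.col

// AQE branch: no salting or manual repartitioning. AQE plus the table writer
// decide the file layout, so the data is only sorted within partitions by the
// requested sort columns before being inserted.
def writeWithAqe(df: DataFrame,
                 tableName: String,
                 saveMode: SaveMode,
                 sortByCols: Seq[String]): Unit = {
  if (df.isEmpty) {
    // Nothing to write; skip the insert entirely.
    return
  }
  df.sortWithinPartitions(sortByCols.map(col): _*)
    .write
    .mode(saveMode)
    .insertInto(tableName)
}
```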
