diff --git a/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala b/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala
index f5365d499..d06380082 100644
--- a/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala
+++ b/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala
@@ -147,7 +147,11 @@ class ChrononKryoRegistrator extends KryoRegistrator {
       "org.apache.spark.sql.catalyst.InternalRow$$anonfun$getAccessor$8",
       "org.apache.spark.sql.catalyst.InternalRow$$anonfun$getAccessor$5",
       "scala.collection.immutable.ArraySeq$ofRef",
-      "org.apache.spark.sql.catalyst.expressions.GenericInternalRow"
+      "org.apache.spark.sql.catalyst.expressions.GenericInternalRow",
+      "scala.math.Ordering$Reverse",
+      "java.lang.invoke.SerializedLambda",
+      "org.apache.spark.sql.catalyst.InternalRow$",
+      "org.apache.spark.util.collection.BitSet"
     )
     names.foreach(name => doRegister(name, kryo))
diff --git a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala
index d1c8eacda..9eb00df0c 100644
--- a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala
+++ b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala
@@ -123,6 +123,8 @@ object SparkSessionBuilder {
       .config("spark.default.parallelism", "2")
       .config("spark.testing", "true")
       .config("spark.sql.adaptive.enabled", true)
+      .config("spark.sql.adaptive.coalescePartitions.enabled", true)
+      .config("spark.sql.adaptive.skewJoin.enabled", true)
       .config("spark.local.dir", s"/tmp/$userName/$name")
       .config("spark.sql.warehouse.dir", s"$warehouseDir/data")
       .config("spark.driver.bindAddress", "127.0.0.1")
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index a83943407..5c1f91563 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -594,6 +594,7 @@ case class TableUtils(sparkSession: SparkSession) {
   def sql(query: String): DataFrame = {
     val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
+    val autoCoalesceEnabled = sparkSession.conf.get("spark.sql.adaptive.coalescePartitions.enabled", "true").toBoolean
     val sw = new StringWriter()
     val pw = new PrintWriter(sw)
     new Throwable().printStackTrace(pw)
@@ -603,13 +604,24 @@ case class TableUtils(sparkSession: SparkSession) {
       .filter(_.contains("chronon"))
       .map(_.replace("at ai.chronon.spark.", ""))
       .mkString("\n")
-
-    logger.info(
-      s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
+    if (!autoCoalesceEnabled) {
+      logger.info(
+        s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
+    } else {
+      logger.info(
+        s"\n----[Running query with AQE auto coalesce enabled]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
+    }
     try {
       // Run the query
-      val df = sparkSession.sql(query).coalesce(partitionCount)
-      df
+      val df = sparkSession.sql(query)
+      // if aqe auto coalesce is disabled, apply manual coalesce
+      val finalDf = if (!autoCoalesceEnabled) {
+        df.coalesce(partitionCount)
+      } else {
+        logger.info(s"AQE auto coalesce is enabled, skipping manual coalesce")
+        df
+      }
+      finalDf
     } catch {
       case e: AnalysisException if e.getMessage.contains(" already exists") =>
         logger.warn(s"Non-Fatal: ${e.getMessage}. Query may result in redefinition.")
@@ -702,76 +714,91 @@ case class TableUtils(sparkSession: SparkSession) {
                       stats: Option[DfStats],
                       sortByCols: Seq[String] = Seq.empty,
                       partitionCols: Seq[String] = Seq.empty): Unit = {
-    // get row count and table partition count statistics
-    // to determine shuffle parallelism, count only top-level/first partition column
-    // assumed to be the date partition. If not given, use default
-    val partitionCol = partitionCols.headOption.getOrElse(partitionColumn)
-    val (rowCount: Long, tablePartitionCount: Int) =
-      if (df.schema.fieldNames.contains(partitionCol)) {
-        if (stats.isDefined && stats.get.partitionRange.wellDefined) {
-          stats.get.count -> stats.get.partitionRange.partitions.length
-        } else {
-          val result = df.select(count(lit(1)), approx_count_distinct(col(partitionCol))).head()
-          (result.getAs[Long](0), result.getAs[Long](1).toInt)
-        }
-      } else {
-        (df.count(), 1)
-      }
+    val useAqeRoute = sparkSession.conf.getOption("spark.sql.adaptive.enabled").contains("true") &&
+      sparkSession.conf.getOption("spark.sql.adaptive.coalescePartitions.enabled").contains("true")
 
-    // set to one if tablePartitionCount=0 to avoid division by zero
-    val nonZeroTablePartitionCount = if (tablePartitionCount == 0) 1 else tablePartitionCount
-
-    logger.info(s"$rowCount rows requested to be written into table $tableName")
-    if (rowCount > 0) {
-      val columnSizeEstimate = columnSizeEstimator(df.schema)
-
-      // check if spark is running in local mode or cluster mode
-      val isLocal = sparkSession.conf.get("spark.master").startsWith("local")
-
-      // roughly 1 partition count per 1m rows x 100 columns
-      val rowCountPerPartition = df.sparkSession.conf
-        .getOption(SparkConstants.ChrononRowCountPerPartition)
-        .map(_.toDouble)
-        .flatMap(value => if (value > 0) Some(value) else None)
-        .getOrElse(1e8)
-
-      val totalFileCountEstimate = math.ceil(rowCount * columnSizeEstimate / rowCountPerPartition).toInt
-      val dailyFileCountUpperBound = 2000
-      val dailyFileCountLowerBound = if (isLocal) 1 else 10
-      val dailyFileCountEstimate = totalFileCountEstimate / nonZeroTablePartitionCount + 1
-      val dailyFileCountBounded =
-        math.max(math.min(dailyFileCountEstimate, dailyFileCountUpperBound), dailyFileCountLowerBound)
-
-      val outputParallelism = df.sparkSession.conf
-        .getOption(SparkConstants.ChrononOutputParallelismOverride)
-        .map(_.toInt)
-        .flatMap(value => if (value > 0) Some(value) else None)
-
-      if (outputParallelism.isDefined) {
-        logger.info(s"Using custom outputParallelism ${outputParallelism.get}")
+    if (useAqeRoute) {
+      if (df.isEmpty) {
+        logger.info(s"Input DataFrame for table $tableName is empty. Nothing to write.")
+        return
      }
-      val dailyFileCount = outputParallelism.getOrElse(dailyFileCountBounded)
-
-      // finalized shuffle parallelism
-      val shuffleParallelism = Math.max(dailyFileCount * nonZeroTablePartitionCount, minWriteShuffleParallelism)
-      val saltCol = "random_partition_salt"
-      val saltedDf = df.withColumn(saltCol, round(rand() * (dailyFileCount + 1)))
+      val sortedDf = df.sortWithinPartitions(sortByCols.map(col).toSeq: _*)
+      val writer = sortedDf.write.mode(saveMode)
+      writer.insertInto(tableName)
 
-      logger.info(
-        s"repartitioning data for table $tableName by $shuffleParallelism spark tasks into $tablePartitionCount table partitions and $dailyFileCount files per partition")
-      val (repartitionCols: immutable.Seq[String], partitionSortCols: immutable.Seq[String]) =
+      logger.info(s"Successfully finished writing to $tableName using AQE.")
+    } else {
+      // get row count and table partition count statistics
+      // to determine shuffle parallelism, count only top-level/first partition column
+      // assumed to be the date partition. If not given, use default
+      val partitionCol = partitionCols.headOption.getOrElse(partitionColumn)
+      val (rowCount: Long, tablePartitionCount: Int) =
         if (df.schema.fieldNames.contains(partitionCol)) {
-          (Seq(partitionCol, saltCol), Seq(partitionCol) ++ sortByCols)
-        } else { (Seq(saltCol), sortByCols) }
-      logger.info(s"Sorting within partitions with cols: $partitionSortCols")
-      saltedDf
-        .repartition(shuffleParallelism, repartitionCols.map(saltedDf.col): _*)
-        .drop(saltCol)
-        .sortWithinPartitions(partitionSortCols.map(col): _*)
-        .write
-        .mode(saveMode)
-        .insertInto(tableName)
-      logger.info(s"Finished writing to $tableName")
+          if (stats.isDefined && stats.get.partitionRange.wellDefined) {
+            stats.get.count -> stats.get.partitionRange.partitions.length
+          } else {
+            val result = df.select(count(lit(1)), approx_count_distinct(col(partitionCol))).head()
+            (result.getAs[Long](0), result.getAs[Long](1).toInt)
+          }
+        } else {
+          (df.count(), 1)
+        }
+
+      // set to one if tablePartitionCount=0 to avoid division by zero
+      val nonZeroTablePartitionCount = if (tablePartitionCount == 0) 1 else tablePartitionCount
+
+      logger.info(s"$rowCount rows requested to be written into table $tableName")
+      if (rowCount > 0) {
+        val columnSizeEstimate = columnSizeEstimator(df.schema)
+
+        // check if spark is running in local mode or cluster mode
+        val isLocal = sparkSession.conf.get("spark.master").startsWith("local")
+
+        // roughly 1 partition count per 1m rows x 100 columns
+        val rowCountPerPartition = df.sparkSession.conf
+          .getOption(SparkConstants.ChrononRowCountPerPartition)
+          .map(_.toDouble)
+          .flatMap(value => if (value > 0) Some(value) else None)
+          .getOrElse(1e8)
+
+        val totalFileCountEstimate = math.ceil(rowCount * columnSizeEstimate / rowCountPerPartition).toInt
+        val dailyFileCountUpperBound = 2000
+        val dailyFileCountLowerBound = if (isLocal) 1 else 10
+        val dailyFileCountEstimate = totalFileCountEstimate / nonZeroTablePartitionCount + 1
+        val dailyFileCountBounded =
+          math.max(math.min(dailyFileCountEstimate, dailyFileCountUpperBound), dailyFileCountLowerBound)
+
+        val outputParallelism = df.sparkSession.conf
+          .getOption(SparkConstants.ChrononOutputParallelismOverride)
+          .map(_.toInt)
+          .flatMap(value => if (value > 0) Some(value) else None)
+
+        if (outputParallelism.isDefined) {
+          logger.info(s"Using custom outputParallelism ${outputParallelism.get}")
+        }
+        val dailyFileCount = outputParallelism.getOrElse(dailyFileCountBounded)
+
+        // finalized shuffle parallelism
+        val shuffleParallelism = Math.max(dailyFileCount * nonZeroTablePartitionCount, minWriteShuffleParallelism)
+        val saltCol = "random_partition_salt"
+        val saltedDf = df.withColumn(saltCol, round(rand() * (dailyFileCount + 1)))
+
+        logger.info(
+          s"repartitioning data for table $tableName by $shuffleParallelism spark tasks into $tablePartitionCount table partitions and $dailyFileCount files per partition")
+        val (repartitionCols: immutable.Seq[String], partitionSortCols: immutable.Seq[String]) =
+          if (df.schema.fieldNames.contains(partitionCol)) {
+            (Seq(partitionCol, saltCol), Seq(partitionCol) ++ sortByCols)
+          } else { (Seq(saltCol), sortByCols) }
+        logger.info(s"Sorting within partitions with cols: $partitionSortCols")
+        saltedDf
+          .repartition(shuffleParallelism, repartitionCols.map(saltedDf.col): _*)
+          .drop(saltCol)
+          .sortWithinPartitions(partitionSortCols.map(col): _*)
+          .write
+          .mode(saveMode)
+          .insertInto(tableName)
+        logger.info(s"Finished writing to $tableName")
+      }
     }
   }
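
Reviewer note (not part of the diff): a minimal, self-contained sketch of the gating pattern the TableUtils.sql change applies, assuming only a local SparkSession configured like the test builder above; the object name AqeCoalesceSketch and the helper runQuery are illustrative and do not exist in the Chronon codebase.

import org.apache.spark.sql.{DataFrame, SparkSession}

object AqeCoalesceSketch {
  def main(args: Array[String]): Unit = {
    // Local session mirroring the AQE configs added to SparkSessionBuilder in this diff.
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .appName("aqe-coalesce-sketch")
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
      .getOrCreate()

    // Hypothetical helper showing the same gate as the updated TableUtils.sql:
    // coalesce manually only when AQE auto coalesce is disabled.
    def runQuery(query: String): DataFrame = {
      val partitionCount = spark.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
      val autoCoalesceEnabled =
        spark.conf.get("spark.sql.adaptive.coalescePartitions.enabled", "true").toBoolean
      val df = spark.sql(query)
      if (autoCoalesceEnabled) df else df.coalesce(partitionCount)
    }

    runQuery("SELECT id FROM range(10)").show()
    spark.stop()
  }
}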