Commit

Address comments
linliu-code committed Dec 19, 2024
1 parent 3c11e55 commit aad4dd7
Showing 1 changed file with 48 additions and 18 deletions.
@@ -23,10 +23,12 @@ import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.common.config.{HoodieReaderConfig, HoodieStorageConfig}
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.testutils.SparkClientFunctionalTestHarness
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource

class TestDecimalTypeDataWorkflow extends SparkClientFunctionalTestHarness{
val sparkOpts: Map[String, String] = Map(
@@ -37,58 +39,86 @@ class TestDecimalTypeDataWorkflow extends SparkClientFunctionalTestHarness{
HoodieReaderConfig.MERGE_USE_RECORD_POSITIONS.key -> "true")
val opts = sparkOpts ++ fgReaderOpts

@Test
def testDecimalInsertUpdateDeleteRead(): Unit = {
@ParameterizedTest
@CsvSource(value = Array("10,2", "20,10", "38,18", "5,0"))
def testDecimalInsertUpdateDeleteRead(precision: String, scale: String): Unit = {
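// Each @CsvSource entry is split on the comma, so e.g. "10,2" arrives as
// precision = "10" and scale = "2"; both are converted to Int when the schema is built.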
// Create schema
val schema = StructType(Seq(
StructField("id", IntegerType, nullable = true),
StructField(
"decimal_col",
DecimalType(Integer.valueOf(precision), Integer.valueOf(scale)),
nullable = true)))
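// Spark caps DecimalType precision at 38 digits, so the "38,18" case exercises the widest supported type.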
// Build data conforming to the schema.
val tablePath = basePath
val data: Seq[(Int, Decimal)] = Seq(
(1, Decimal("123.45")),
(2, Decimal("987.65")),
(3, Decimal("-10.23")),
(4, Decimal("0.01")),
(5, Decimal("1000.00")))
val insertDf: DataFrame = spark.createDataFrame(data)
.toDF("id", "decimal_value").sort("id")
val rows = data.map{
case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)}
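// Spark's internal Decimal is converted to java.math.BigDecimal so the external Row
// values line up with the DecimalType column declared in the schema.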
val rddData = spark.sparkContext.parallelize(rows)

// Insert.
val insertDf: DataFrame = spark.sqlContext.createDataFrame(rddData, schema)
.toDF("id", "decimal_col").sort("id")
insertDf.write.format("hudi")
.option(RECORDKEY_FIELD.key(), "id")
.option(PRECOMBINE_FIELD.key(), "decimal_value")
.option(PRECOMBINE_FIELD.key(), "decimal_col")
.option(TABLE_TYPE.key, "MERGE_ON_READ")
.option(TABLE_NAME.key, "test_table")
.options(opts)
.mode(SaveMode.Overwrite)
.save(tablePath)

val update: Seq[(Int, BigDecimal)] = Seq(
(1, BigDecimal("543.21")),
(2, BigDecimal("111.11")),
(6, BigDecimal("1001.00")))
val updateDf: DataFrame = spark.createDataFrame(update)
.toDF("id", "decimal_value").sort("id")
// Update.
val update: Seq[(Int, Decimal)] = Seq(
(1, Decimal("543.21")),
(2, Decimal("111.11")),
(6, Decimal("1001.00")))
val updateRows = update.map {
case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
}
val rddUpdate = spark.sparkContext.parallelize(updateRows)
val updateDf: DataFrame = spark.createDataFrame(rddUpdate, schema)
.toDF("id", "decimal_col").sort("id")
updateDf.write.format("hudi")
.option(OPERATION.key(), "upsert")
.options(opts)
.mode(SaveMode.Append)
.save(tablePath)

// Delete.
val delete: Seq[(Int, Decimal)] = Seq(
(3, Decimal("543.21")),
(4, Decimal("111.11")))
val deleteDf: DataFrame = spark.createDataFrame(delete)
.toDF("id", "decimal_value").sort("id")
val deleteRows = delete.map {
case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
}
val rddDelete = spark.sparkContext.parallelize(deleteRows)
val deleteDf: DataFrame = spark.createDataFrame(rddDelete, schema)
.toDF("id", "decimal_col").sort("id")
deleteDf.write.format("hudi")
.option(OPERATION.key(), "delete")
.options(opts)
.mode(SaveMode.Append)
.save(tablePath)

// Asserts
val actual = spark.read.format("hudi").load(tablePath).select("id", "decimal_value")
val actual = spark.read.format("hudi").load(tablePath).select("id", "decimal_col")
val expected: Seq[(Int, Decimal)] = Seq(
(1, Decimal("543.21")),
(2, Decimal("987.65")),
(5, Decimal("1000.00")),
(6, Decimal("1001.00")))
val expectedDf: DataFrame = spark.createDataFrame(expected)
.toDF("id", "decimal_value").sort("id")
val expectedRows = expected.map {
case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
}
val rddExpected = spark.sparkContext.parallelize(expectedRows)
val expectedDf: DataFrame = spark.createDataFrame(rddExpected, schema)
.toDF("id", "decimal_col").sort("id")
val expectedMinusActual = expectedDf.except(actual)
val actualMinusExpected = actual.except(expectedDf)
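// Both set differences should be empty when the table contents match the expected rows,
// regardless of row ordering.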
expectedDf.show(false)
