[HUDI-8783] Add tests for decimal data type #12519

Open · wants to merge 3 commits into master
@@ -31,7 +31,6 @@ import org.apache.hudi.common.util.collection.{CachingIterator, ClosableIterator
 import org.apache.hudi.io.storage.{HoodieSparkFileReaderFactory, HoodieSparkParquetReader}
 import org.apache.hudi.storage.{HoodieStorage, StorageConfiguration, StoragePath}
 import org.apache.hudi.util.CloseableInternalRowIterator
-
 import org.apache.avro.Schema
 import org.apache.avro.Schema.Type
 import org.apache.avro.generic.{GenericRecord, IndexedRecord}
@@ -42,10 +41,11 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, SparkParquetReader}
+import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.hudi.SparkAdapter
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{LongType, MetadataBuilder, StructField, StructType}
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types.{DecimalType, LongType, MetadataBuilder, StructField, StructType}
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
 import org.apache.spark.unsafe.types.UTF8String

 import scala.collection.mutable
@@ -263,20 +263,23 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetRea
   }

   override def castValue(value: Comparable[_], newType: Schema.Type): Comparable[_] = {
-    value match {
+    val valueToCast = if (value == null) 0 else value
+    valueToCast match {
       case v: Integer => newType match {
         case Type.INT => v
         case Type.LONG => v.longValue()
         case Type.FLOAT => v.floatValue()
         case Type.DOUBLE => v.doubleValue()
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)

Review thread on the Type.FIXED case above. danny0405 (Contributor) commented on Dec 21, 2024:

I see special handling for the Spark decimal data type for precision <= 18 in the UnsafeRow decimal getter:

public Decimal getDecimal(int ordinal, int precision, int scale) {
  if (isNullAt(ordinal)) {
    return null;
  }
  if (precision <= Decimal.MAX_LONG_DIGITS()) {
    return Decimal.createUnsafe(getLong(ordinal), precision, scale);
  } else {
    byte[] bytes = getBinary(ordinal);
    BigInteger bigInteger = new BigInteger(bytes);
    BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
    return Decimal.apply(javaDecimal, precision, scale);
  }
}

Should we do it here too?

The PR author (Contributor) replied:

OK, let me add it here.

The PR author (Contributor) replied:

@danny0405, I tried to cast it, but I think by this point we have already lost the precision and scale information from the FIXED type, so we cannot apply that special treatment here.

         case x => throw new UnsupportedOperationException(s"Cast from Integer to $x is not supported")
       }
       case v: java.lang.Long => newType match {
         case Type.LONG => v
         case Type.FLOAT => v.floatValue()
         case Type.DOUBLE => v.doubleValue()
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Long to $x is not supported")
       }
       case v: java.lang.Float => newType match {
@@ -288,6 +291,7 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetRea
       case v: java.lang.Double => newType match {
         case Type.DOUBLE => v
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Double to $x is not supported")
       }
       case v: String => newType match {
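
For context on the review thread above: Spark's UnsafeRow stores decimals with precision up to Decimal.MAX_LONG_DIGITS (18) as an unscaled long and wider decimals as binary, which is what the quoted getDecimal snippet branches on. The sketch below is hypothetical and not part of this PR (the object and method names are made up); it only illustrates what a precision-aware conversion to a Spark Decimal could look like if the target precision and scale were known. As the author notes, castValue only receives an Avro Schema.Type, which does not carry the decimal logical type's precision and scale, so this cannot be plugged in there as-is.

import org.apache.spark.sql.types.Decimal

object DecimalCastSketch {
  // Hypothetical helper mirroring UnsafeRow#getDecimal: converts an unscaled value
  // to a Spark Decimal, assuming precision and scale are known to the caller.
  def fromUnscaled(unscaled: BigInt, precision: Int, scale: Int): Decimal = {
    if (precision <= Decimal.MAX_LONG_DIGITS) {
      // precision <= 18: the unscaled value fits in a long, so use the compact form.
      Decimal.createUnsafe(unscaled.toLong, precision, scale)
    } else {
      // Wider decimals: rebuild a BigDecimal from the unscaled value and the scale.
      Decimal(BigDecimal(unscaled, scale), precision, scale)
    }
  }
}

For example, fromUnscaled(BigInt(98765), 10, 2) would yield the decimal 987.65 via the compact long-backed path.
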
@@ -0,0 +1,130 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.hudi

import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.common.config.{HoodieReaderConfig, HoodieStorageConfig}
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.testutils.SparkClientFunctionalTestHarness
import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource

class TestDecimalTypeDataWorkflow extends SparkClientFunctionalTestHarness {
  val sparkOpts: Map[String, String] = Map(
    HoodieStorageConfig.LOGFILE_DATA_BLOCK_FORMAT.key -> "parquet",
    HoodieWriteConfig.RECORD_MERGE_IMPL_CLASSES.key -> classOf[DefaultSparkRecordMerger].getName)
  val fgReaderOpts: Map[String, String] = Map(
    HoodieReaderConfig.FILE_GROUP_READER_ENABLED.key -> "true",
    HoodieReaderConfig.MERGE_USE_RECORD_POSITIONS.key -> "true")
  val opts = sparkOpts ++ fgReaderOpts

  @ParameterizedTest
  // Precision/scale pairs cover the compact (precision <= 18) and wide (precision > 18)
  // Spark decimal representations, plus a scale-0 case.
  @CsvSource(value = Array("10,2", "15,5", "20,10", "38,18", "5,0"))
  def testDecimalInsertUpdateDeleteRead(precision: String, scale: String): Unit = {
    // Create schema
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = true),
      StructField(
        "decimal_col",
        DecimalType(Integer.valueOf(precision), Integer.valueOf(scale)),
        nullable = true)))
    // Build data conforming to the schema.
    val tablePath = basePath
    val data: Seq[(Int, Decimal)] = Seq(
      (1, Decimal("123.45")),
      (2, Decimal("987.65")),
      (3, Decimal("-10.23")),
      (4, Decimal("0.01")),
      (5, Decimal("1000.00")))
    val rows = data.map {
      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
    }
    val rddData = spark.sparkContext.parallelize(rows)

    // Insert.
    val insertDf: DataFrame = spark.sqlContext.createDataFrame(rddData, schema)
      .toDF("id", "decimal_col").sort("id")
    insertDf.write.format("hudi")
      .option(RECORDKEY_FIELD.key(), "id")
      .option(PRECOMBINE_FIELD.key(), "decimal_col")
      .option(TABLE_TYPE.key, "MERGE_ON_READ")
      .option(TABLE_NAME.key, "test_table")
      .options(opts)
      .mode(SaveMode.Overwrite)
      .save(tablePath)

    // Update.
    val update: Seq[(Int, Decimal)] = Seq(
      (1, Decimal("543.21")),
      (2, Decimal("111.11")),
      (6, Decimal("1001.00")))
    val updateRows = update.map {
      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
    }
    val rddUpdate = spark.sparkContext.parallelize(updateRows)
    val updateDf: DataFrame = spark.createDataFrame(rddUpdate, schema)
      .toDF("id", "decimal_col").sort("id")
    updateDf.write.format("hudi")
      .option(OPERATION.key(), "upsert")
      .options(opts)
      .mode(SaveMode.Append)
      .save(tablePath)

    // Delete.
    val delete: Seq[(Int, Decimal)] = Seq(
      (3, Decimal("543.21")),
      (4, Decimal("111.11")))
    val deleteRows = delete.map {
      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
    }
    val rddDelete = spark.sparkContext.parallelize(deleteRows)
    val deleteDf: DataFrame = spark.createDataFrame(rddDelete, schema)
      .toDF("id", "decimal_col").sort("id")
    deleteDf.write.format("hudi")
      .option(OPERATION.key(), "delete")
      .options(opts)
      .mode(SaveMode.Append)
      .save(tablePath)

    // Asserts
    val actual = spark.read.format("hudi").load(tablePath).select("id", "decimal_col")
    // Expected state: ids 3 and 4 are deleted; id 1 takes the updated 543.21
    // (543.21 > 123.45 on the precombine field decimal_col), while id 2 keeps
    // 987.65 because the incoming 111.11 loses the precombine comparison.
    val expected: Seq[(Int, Decimal)] = Seq(
      (1, Decimal("543.21")),
      (2, Decimal("987.65")),
      (5, Decimal("1000.00")),
      (6, Decimal("1001.00")))
    val expectedRows = expected.map {
      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
    }
    val rddExpected = spark.sparkContext.parallelize(expectedRows)
    val expectedDf: DataFrame = spark.createDataFrame(rddExpected, schema)
      .toDF("id", "decimal_col").sort("id")
    val expectedMinusActual = expectedDf.except(actual)
    val actualMinusExpected = actual.except(expectedDf)
    expectedDf.show(false)
    actual.show(false)
    expectedMinusActual.show(false)
    actualMinusExpected.show(false)
    assertTrue(expectedMinusActual.isEmpty && actualMinusExpected.isEmpty)
  }
}