Test decimal workflow
linliu-code committed Dec 19, 2024
1 parent cb447c9 commit 3c11e55
Showing 2 changed files with 108 additions and 4 deletions.
@@ -31,7 +31,6 @@ import org.apache.hudi.common.util.collection.{CachingIterator, ClosableIterator
 import org.apache.hudi.io.storage.{HoodieSparkFileReaderFactory, HoodieSparkParquetReader}
 import org.apache.hudi.storage.{HoodieStorage, StorageConfiguration, StoragePath}
 import org.apache.hudi.util.CloseableInternalRowIterator
-
 import org.apache.avro.Schema
 import org.apache.avro.Schema.Type
 import org.apache.avro.generic.{GenericRecord, IndexedRecord}
@@ -42,10 +41,11 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, SparkParquetReader}
+import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.hudi.SparkAdapter
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{LongType, MetadataBuilder, StructField, StructType}
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types.{DecimalType, LongType, MetadataBuilder, StructField, StructType}
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
 import org.apache.spark.unsafe.types.UTF8String
 
 import scala.collection.mutable
@@ -263,20 +263,23 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetRea
   }
 
   override def castValue(value: Comparable[_], newType: Schema.Type): Comparable[_] = {
-    value match {
+    val valueToCast = if (value == null) 0 else value
+    valueToCast match {
       case v: Integer => newType match {
         case Type.INT => v
         case Type.LONG => v.longValue()
         case Type.FLOAT => v.floatValue()
         case Type.DOUBLE => v.doubleValue()
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Integer to $x is not supported")
       }
       case v: java.lang.Long => newType match {
         case Type.LONG => v
         case Type.FLOAT => v.floatValue()
         case Type.DOUBLE => v.doubleValue()
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Long to $x is not supported")
       }
       case v: java.lang.Float => newType match {
@@ -288,6 +291,7 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetRea
       case v: java.lang.Double => newType match {
         case Type.DOUBLE => v
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Double to $x is not supported")
       }
       case v: String => newType match {
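For reference, a minimal standalone sketch of the casting behavior the hunks above introduce: a null ordering value is replaced with 0 before casting, and numeric values are promoted to a Scala BigDecimal when the target Avro type is FIXED (how Avro encodes decimal logical types). The CastSketch object and its main driver are illustrative assumptions, not part of the patch; only the castValue body mirrors the diff, trimmed here to the Integer and Long cases.

import org.apache.avro.Schema.Type

object CastSketch {
  def castValue(value: Comparable[_], newType: Type): Comparable[_] = {
    // A null input is coerced to 0 before the cast, as in the patched method.
    val valueToCast = if (value == null) 0 else value
    valueToCast match {
      case v: Integer => newType match {
        case Type.INT => v
        case Type.LONG => v.longValue()
        case Type.FIXED => BigDecimal(v)   // decimal-backed ordering field
        case x => throw new UnsupportedOperationException(s"Cast from Integer to $x is not supported")
      }
      case v: java.lang.Long => newType match {
        case Type.LONG => v
        case Type.FIXED => BigDecimal(v)
        case x => throw new UnsupportedOperationException(s"Cast from Long to $x is not supported")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    println(castValue(Integer.valueOf(42), Type.FIXED)) // 42 promoted to BigDecimal
    println(castValue(null, Type.FIXED))                // null treated as 0, then BigDecimal(0)
  }
}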
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.hudi

import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.common.config.{HoodieReaderConfig, HoodieStorageConfig}
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.testutils.SparkClientFunctionalTestHarness
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.{DataFrame, SaveMode}
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test

class TestDecimalTypeDataWorkflow extends SparkClientFunctionalTestHarness {
  val sparkOpts: Map[String, String] = Map(
    HoodieStorageConfig.LOGFILE_DATA_BLOCK_FORMAT.key -> "parquet",
    HoodieWriteConfig.RECORD_MERGE_IMPL_CLASSES.key -> classOf[DefaultSparkRecordMerger].getName)
  val fgReaderOpts: Map[String, String] = Map(
    HoodieReaderConfig.FILE_GROUP_READER_ENABLED.key -> "true",
    HoodieReaderConfig.MERGE_USE_RECORD_POSITIONS.key -> "true")
  val opts = sparkOpts ++ fgReaderOpts

  @Test
  def testDecimalInsertUpdateDeleteRead(): Unit = {
    val tablePath = basePath
    val data: Seq[(Int, Decimal)] = Seq(
      (1, Decimal("123.45")),
      (2, Decimal("987.65")),
      (3, Decimal("-10.23")),
      (4, Decimal("0.01")),
      (5, Decimal("1000.00")))
    val insertDf: DataFrame = spark.createDataFrame(data)
      .toDF("id", "decimal_value").sort("id")
    insertDf.write.format("hudi")
      .option(RECORDKEY_FIELD.key(), "id")
      .option(PRECOMBINE_FIELD.key(), "decimal_value")
      .option(TABLE_TYPE.key, "MERGE_ON_READ")
      .option(TABLE_NAME.key, "test_table")
      .options(opts)
      .mode(SaveMode.Overwrite)
      .save(tablePath)

    val update: Seq[(Int, BigDecimal)] = Seq(
      (1, BigDecimal("543.21")),
      (2, BigDecimal("111.11")),
      (6, BigDecimal("1001.00")))
    val updateDf: DataFrame = spark.createDataFrame(update)
      .toDF("id", "decimal_value").sort("id")
    updateDf.write.format("hudi")
      .option(OPERATION.key(), "upsert")
      .options(opts)
      .mode(SaveMode.Append)
      .save(tablePath)

    val delete: Seq[(Int, Decimal)] = Seq(
      (3, Decimal("543.21")),
      (4, Decimal("111.11")))
    val deleteDf: DataFrame = spark.createDataFrame(delete)
      .toDF("id", "decimal_value").sort("id")
    deleteDf.write.format("hudi")
      .option(OPERATION.key(), "delete")
      .options(opts)
      .mode(SaveMode.Append)
      .save(tablePath)

    // Asserts
    val actual = spark.read.format("hudi").load(tablePath).select("id", "decimal_value")
    val expected: Seq[(Int, Decimal)] = Seq(
      (1, Decimal("543.21")),
      (2, Decimal("987.65")),
      (5, Decimal("1000.00")),
      (6, Decimal("1001.00")))
    val expectedDf: DataFrame = spark.createDataFrame(expected)
      .toDF("id", "decimal_value").sort("id")
    val expectedMinusActual = expectedDf.except(actual)
    val actualMinusExpected = actual.except(expectedDf)
    expectedDf.show(false)
    actual.show(false)
    expectedMinusActual.show(false)
    actualMinusExpected.show(false)
    assertTrue(expectedMinusActual.isEmpty && actualMinusExpected.isEmpty)
  }
}
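For reference, a hedged snippet for manually inspecting the table this test produces from a spark-shell with the Hudi bundle on the classpath. The load path is a placeholder (the test writes to the harness's basePath); the reader option reuses HoodieReaderConfig.FILE_GROUP_READER_ENABLED from the options above. Note that the expected data keeps 987.65 for id 2: decimal_value is also the precombine field, and the upserted 111.11 orders lower than the existing value, so the existing record wins.

import org.apache.hudi.common.config.HoodieReaderConfig

val df = spark.read.format("hudi")
  .option(HoodieReaderConfig.FILE_GROUP_READER_ENABLED.key, "true")
  .load("/tmp/hudi/test_table")  // placeholder path, not from the commit
df.printSchema()                 // decimal_value comes back as a DecimalType column
df.orderBy("id").show(false)     // ids 3 and 4 are deleted; id 6 is newly inserted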
