diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 907e61b09f7b..71338a6921e9 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -145,54 +145,82 @@ impl DecimalCast for i256 { } } -pub(crate) fn cast_decimal_to_decimal_error( +/// Construct closures to upscale decimals from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)`. +/// +/// Returns `(f_fallible, f_infallible)` where: +/// * `f_fallible` yields `None` when the requested cast would overflow +/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None` +/// and callers must fall back to `f_fallible` +/// +/// Returns `None` if the required scale increase `delta_scale = output_scale - input_scale` +/// exceeds the supported precomputed precision table `O::MAX_FOR_EACH_PRECISION`. +/// In that case, the caller should treat this as an overflow for the output scale +/// and handle it accordingly (e.g., return a cast error). +#[allow(clippy::type_complexity)] +fn make_upscaler( + input_precision: u8, + input_scale: i8, output_precision: u8, output_scale: i8, -) -> impl Fn(::Native) -> ArrowError +) -> Option<( + impl Fn(I::Native) -> Option, + Option O::Native>, +)> where - I: DecimalType, - O: DecimalType, I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - move |x: I::Native| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). Overflowing on {:?}", - O::PREFIX, - output_precision, - output_scale, - x - )) - } + let delta_scale = output_scale - input_scale; + + // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...). + // Adding 1 yields exactly 10^k without computing a power at runtime. + // Using the precomputed table avoids pow(10, k) and its checked/overflow + // handling, which is faster and simpler for scaling by 10^delta_scale. + let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; + let mul = max.add_wrapping(O::Native::ONE); + let f_fallible = move |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); + + // if the gain in precision (digits) is greater than the multiplication due to scaling + // every number will fit into the output type + // Example: If we are starting with any number of precision 5 [xxxxx], + // then an increase of scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type + // needs to provide at least 8 digits precision + let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); + let f_infallible = is_infallible_cast + .then_some(move |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul)); + Some((f_fallible, f_infallible)) } -pub(crate) fn convert_to_smaller_scale_decimal( - array: &PrimitiveArray, +/// Construct closures to downscale decimals from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)`. +/// +/// Returns `(f_fallible, f_infallible)` where: +/// * `f_fallible` yields `None` when the requested cast would overflow +/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None` +/// and callers must fall back to `f_fallible` +/// +/// Returns `None` if the required scale reduction `delta_scale = input_scale - output_scale` +/// exceeds the supported precomputed precision table `I::MAX_FOR_EACH_PRECISION`. +/// In this scenario, any value would round to zero (e.g., dividing by 10^k where k exceeds the +/// available precision). Callers should therefore produce zero values (preserving nulls) rather +/// than returning an error. +#[allow(clippy::type_complexity)] +fn make_downscaler( input_precision: u8, input_scale: i8, output_precision: u8, output_scale: i8, - cast_options: &CastOptions, -) -> Result, ArrowError> +) -> Option<( + impl Fn(I::Native) -> Option, + Option O::Native>, +)> where - I: DecimalType, - O: DecimalType, I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); let delta_scale = input_scale - output_scale; - // if the reduction of the input number through scaling (dividing) is greater - // than a possible precision loss (plus potential increase via rounding) - // every input number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then and decrease the scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). - // The rounding may add an additional digit, so the cast to be infallible, - // the output type needs to have at least 3 digits of precision. - // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: - // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible - let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the // scale change divides out more digits than the input has precision and the result of the cast @@ -200,16 +228,13 @@ where // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even // smaller results, which also round to zero. In that case, just return an array of zeros. - let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize) else { - let zeros = vec![O::Native::ZERO; array.len()]; - return Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())); - }; + let max = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; let div = max.add_wrapping(I::Native::ONE); let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); let half_neg = half.neg_wrapping(); - let f = |x: I::Native| { + let f_fallible = move |x: I::Native| { // div is >= 10 and so this cannot overflow let d = x.div_wrapping(div); let r = x.mod_wrapping(div); @@ -223,24 +248,136 @@ where O::Native::from_decimal(adjusted) }; - Ok(if is_infallible_cast { - // make sure we don't perform calculations that don't make sense w/o validation - validate_decimal_precision_and_scale::(output_precision, output_scale)?; - let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed - // to fit into the target type - array.unary(g) + // if the reduction of the input number through scaling (dividing) is greater + // than a possible precision loss (plus potential increase via rounding) + // every input number will fit into the output type + // Example: If we are starting with any number of precision 5 [xxxxx], + // then and decrease the scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). + // The rounding may add a digit, so the cast to be infallible, + // the output type needs to have at least 3 digits of precision. + // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: + // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible + let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); + let f_infallible = is_infallible_cast.then_some(move |x| f_fallible(x).unwrap()); + Some((f_fallible, f_infallible)) +} + +/// Apply the rescaler function to the value. +/// If the rescaler is infallible, use the infallible function. +/// Otherwise, use the fallible function and validate the precision. +fn apply_rescaler( + value: I::Native, + output_precision: u8, + f: impl Fn(I::Native) -> Option, + f_infallible: Option O::Native>, +) -> Option +where + I::Native: DecimalCast, + O::Native: DecimalCast, +{ + if let Some(f_infallible) = f_infallible { + Some(f_infallible(value)) + } else { + f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) + } +} + +/// Rescales a decimal value from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)` and returns the converted number when it fits +/// within the output precision. +/// +/// The function first validates that the requested precision and scale are supported for +/// both the source and destination decimal types. It then either upscales (multiplying +/// by an appropriate power of ten) or downscales (dividing with rounding) the input value. +/// When the scaling factor exceeds the precision table of the destination type, the value +/// is treated as an overflow for upscaling, or rounded to zero for downscaling (as any +/// possible result would be zero at the requested scale). +/// +/// This mirrors the column-oriented helpers of decimal casting but operates on a single value +/// (row-level) instead of an entire array. +/// +/// Returns `None` if the value cannot be represented with the requested precision. +pub fn rescale_decimal( + value: I::Native, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, +) -> Option +where + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + validate_decimal_precision_and_scale::(input_precision, input_scale).ok()?; + validate_decimal_precision_and_scale::(output_precision, output_scale).ok()?; + + if input_scale <= output_scale { + let (f, f_infallible) = + make_upscaler::(input_precision, input_scale, output_precision, output_scale)?; + apply_rescaler::(value, output_precision, f, f_infallible) + } else { + let Some((f, f_infallible)) = + make_downscaler::(input_precision, input_scale, output_precision, output_scale) + else { + // Scale reduction exceeds supported precision; result mathematically rounds to zero + return Some(O::Native::ZERO); + }; + apply_rescaler::(value, output_precision, f, f_infallible) + } +} + +fn cast_decimal_to_decimal_error( + output_precision: u8, + output_scale: i8, +) -> impl Fn(::Native) -> ArrowError +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + move |x: I::Native| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + O::PREFIX, + output_precision, + output_scale, + x + )) + } +} + +fn apply_decimal_cast( + array: &PrimitiveArray, + output_precision: u8, + output_scale: i8, + f_fallible: impl Fn(I::Native) -> Option, + f_infallible: Option O::Native>, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array = if let Some(f_infallible) = f_infallible { + array.unary(f_infallible) } else if cast_options.safe { - array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) + array.unary_opt(|x| { + f_fallible(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) + }) } else { + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); array.try_unary(|x| { - f(x).ok_or_else(|| error(x)).and_then(|v| { + f_fallible(x).ok_or_else(|| error(x)).and_then(|v| { O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) }) })? - }) + }; + Ok(array) } -pub(crate) fn convert_to_bigger_or_equal_scale_decimal( +fn convert_to_smaller_scale_decimal( array: &PrimitiveArray, input_precision: u8, input_scale: i8, @@ -254,36 +391,58 @@ where I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); - let delta_scale = output_scale - input_scale; - let mul = O::Native::from_decimal(10_i128) - .unwrap() - .pow_checked(delta_scale as u32)?; + if let Some((f_fallible, f_infallible)) = + make_downscaler::(input_precision, input_scale, output_precision, output_scale) + { + apply_decimal_cast( + array, + output_precision, + output_scale, + f_fallible, + f_infallible, + cast_options, + ) + } else { + // Scale reduction exceeds supported precision; result mathematically rounds to zero + let zeros = vec![O::Native::ZERO; array.len()]; + Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())) + } +} - // if the gain in precision (digits) is greater than the multiplication due to scaling - // every number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then an increase of scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type - // needs to provide at least 8 digits precision - let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); - let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); - - Ok(if is_infallible_cast { - // make sure we don't perform calculations that don't make sense w/o validation - validate_decimal_precision_and_scale::(output_precision, output_scale)?; - // unwrapping is safe since the result is guaranteed to fit into the target type - let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul); - array.unary(f) - } else if cast_options.safe { - array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) +fn convert_to_bigger_or_equal_scale_decimal( + array: &PrimitiveArray, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + if let Some((f, f_infallible)) = + make_upscaler::(input_precision, input_scale, output_precision, output_scale) + { + apply_decimal_cast( + array, + output_precision, + output_scale, + f, + f_infallible, + cast_options, + ) } else { - array.try_unary(|x| { - f(x).ok_or_else(|| error(x)).and_then(|v| { - O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) - }) - })? - }) + // Scale increase exceeds supported precision; return overflow error + Err(ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Value overflows for output scale", + O::PREFIX, + output_precision, + output_scale + ))) + } } // Only support one type of decimal cast operations @@ -763,4 +922,58 @@ mod tests { ); Ok(()) } + + #[test] + fn test_rescale_decimal_upscale_within_precision() { + let result = rescale_decimal::( + 12_345_i128, // 123.45 with scale 2 + 5, + 2, + 8, + 5, + ); + assert_eq!(result, Some(12_345_000_i128)); + } + + #[test] + fn test_rescale_decimal_downscale_rounds_half_away_from_zero() { + let positive = rescale_decimal::( + 1_050_i128, // 1.050 with scale 3 + 5, 3, 5, 1, + ); + assert_eq!(positive, Some(11_i128)); // 1.1 with scale 1 + + let negative = rescale_decimal::( + -1_050_i128, // -1.050 with scale 3 + 5, + 3, + 5, + 1, + ); + assert_eq!(negative, Some(-11_i128)); // -1.1 with scale 1 + } + + #[test] + fn test_rescale_decimal_downscale_large_delta_returns_zero() { + let result = rescale_decimal::(12_345_i32, 9, 9, 9, 4); + assert_eq!(result, Some(0_i32)); + } + + #[test] + fn test_rescale_decimal_upscale_overflow_returns_none() { + let result = rescale_decimal::(9_999_i32, 4, 0, 5, 2); + assert_eq!(result, None); + } + + #[test] + fn test_rescale_decimal_invalid_input_precision_scale_returns_none() { + let result = rescale_decimal::(123_i128, 39, 39, 38, 38); + assert_eq!(result, None); + } + + #[test] + fn test_rescale_decimal_invalid_output_precision_scale_returns_none() { + let result = rescale_decimal::(123_i128, 38, 38, 39, 39); + assert_eq!(result, None); + } } diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index fe38298b017c..4c03de7ea1eb 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -67,7 +67,7 @@ use arrow_schema::*; use arrow_select::take::take; use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive}; -pub use decimal::DecimalCast; +pub use decimal::{DecimalCast, rescale_decimal}; /// CastOptions provides a way to override the default cast behaviors #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 28087d7541e4..d15664f5af9e 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -17,8 +17,7 @@ //! Module for transforming a typed arrow `Array` to `VariantArray`. -use arrow::array::ArrowNativeTypeOp; -use arrow::compute::DecimalCast; +use arrow::compute::{DecimalCast, rescale_decimal}; use arrow::datatypes::{ self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type, DecimalType, @@ -190,90 +189,6 @@ where } } -/// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale) -/// and return the scaled value if it fits the output precision. Similar to the implementation in -/// decimal.rs in arrow-cast. -pub(crate) fn rescale_decimal( - value: I::Native, - input_precision: u8, - input_scale: i8, - output_precision: u8, - output_scale: i8, -) -> Option -where - I::Native: DecimalCast, - O::Native: DecimalCast, -{ - let delta_scale = output_scale - input_scale; - - let (scaled, is_infallible_cast) = if delta_scale >= 0 { - // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...). - // Adding 1 yields exactly 10^k without computing a power at runtime. - // Using the precomputed table avoids pow(10, k) and its checked/overflow - // handling, which is faster and simpler for scaling by 10^delta_scale. - let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; - let mul = max.add_wrapping(O::Native::ONE); - - // if the gain in precision (digits) is greater than the multiplication due to scaling - // every number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then an increase of scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type - // needs to provide at least 8 digits precision - let is_infallible_cast = input_precision as i8 + delta_scale <= output_precision as i8; - let value = O::Native::from_decimal(value); - let scaled = if is_infallible_cast { - Some(value.unwrap().mul_wrapping(mul)) - } else { - value.and_then(|x| x.mul_checked(mul).ok()) - }; - (scaled, is_infallible_cast) - } else { - // the abs of delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. - // If so, the scale change divides out more digits than the input has precision and the result - // of the cast is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, - // the largest possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. - // Smaller values (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) - // produce even smaller results, which also round to zero. In that case, just return zero. - let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale.unsigned_abs() as usize) else { - return Some(O::Native::ZERO); - }; - let div = max.add_wrapping(I::Native::ONE); - let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); - let half_neg = half.neg_wrapping(); - - // div is >= 10 and so this cannot overflow - let d = value.div_wrapping(div); - let r = value.mod_wrapping(div); - - // Round result - let adjusted = match value >= I::Native::ZERO { - true if r >= half => d.add_wrapping(I::Native::ONE), - false if r <= half_neg => d.sub_wrapping(I::Native::ONE), - _ => d, - }; - - // if the reduction of the input number through scaling (dividing) is greater - // than a possible precision loss (plus potential increase via rounding) - // every input number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then and decrease the scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). - // The rounding may add a digit, so for the cast to be infallible, - // the output type needs to have at least 3 digits of precision. - // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: - // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible - let is_infallible_cast = input_precision as i8 + delta_scale < output_precision as i8; - (O::Native::from_decimal(adjusted), is_infallible_cast) - }; - - if is_infallible_cast { - scaled - } else { - scaled.filter(|v| O::is_valid_decimal_precision(*v, output_precision)) - } -} - /// Convert the value at a specific index in the given array into a `Variant`. macro_rules! non_generic_conversion_single_value { ($array:expr, $cast_fn:expr, $index:expr) => {{