diff --git a/vortex-compute/src/cast/dvector.rs b/vortex-compute/src/cast/dvector.rs index 2cdcca5b971..9ce57f1444e 100644 --- a/vortex-compute/src/cast/dvector.rs +++ b/vortex-compute/src/cast/dvector.rs @@ -6,10 +6,12 @@ use vortex_dtype::DType; use vortex_dtype::DecimalType; use vortex_dtype::NativeDecimalType; use vortex_dtype::PrecisionScale; +use vortex_dtype::i256; use vortex_dtype::match_each_decimal_value_type; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_vector::Scalar; use vortex_vector::ScalarOps; use vortex_vector::Vector; @@ -113,7 +115,59 @@ impl Cast for DScalar { { Ok(self.clone().into()) } - // TODO(connor): cast to different precision/scale + // TODO(connor): cast to different scale + DType::Decimal(ddt, n) + if ddt.scale() == self.scale() && (n.is_nullable() || self.is_valid()) => + { + let p = ddt.precision(); + if p <= ::MAX_PRECISION { + DScalar::maybe_new( + PrecisionScale::::new(ddt.precision(), ddt.scale()), + self.value().and_then(|v| v.to_i8()), + ) + .map(|ds| ds.into()) + .ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}")) + } else if p <= ::MAX_PRECISION { + DScalar::maybe_new( + PrecisionScale::::new(ddt.precision(), ddt.scale()), + self.value().and_then(|v| v.to_i16()), + ) + .map(|ds| ds.into()) + .ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}")) + } else if p <= ::MAX_PRECISION { + DScalar::maybe_new( + PrecisionScale::::new(ddt.precision(), ddt.scale()), + self.value().and_then(|v| v.to_i32()), + ) + .map(|ds| ds.into()) + .ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}")) + } else if p <= ::MAX_PRECISION { + DScalar::maybe_new( + PrecisionScale::::new(ddt.precision(), ddt.scale()), + self.value().and_then(|v| v.to_i64()), + ) + .map(|ds| ds.into()) + .ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}")) + } else if p <= ::MAX_PRECISION { + DScalar::maybe_new( + PrecisionScale::::new(ddt.precision(), ddt.scale()), + self.value().and_then(|v| v.to_i128()), + ) + .map(|ds| ds.into()) + .ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}")) + } else if p <= ::MAX_PRECISION { + DScalar::maybe_new( + PrecisionScale::::new(ddt.precision(), ddt.scale()), + self.value().and_then(|v| v.to_i256()), + ) + .map(|ds| ds.into()) + .ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}")) + } else { + vortex_bail!( + "Target precision {p} is out of range for supported decimal values" + ) + } + } DType::Decimal(..) => { vortex_bail!( "Casting DScalar to {} with different precision/scale not yet implemented", @@ -126,3 +180,198 @@ impl Cast for DScalar { } } } + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_dtype::DType; + use vortex_dtype::DecimalDType; + use vortex_dtype::DecimalTypeDowncast; + use vortex_dtype::NativeDecimalType; + use vortex_dtype::Nullability; + use vortex_dtype::PrecisionScale; + use vortex_dtype::i256; + use vortex_error::VortexResult; + use vortex_vector::ScalarOps; + use vortex_vector::decimal::DScalar; + + use crate::cast::Cast; + + #[rstest] + #[case(2, 0, 42i8)] + #[case(2, 1, 99i8)] + #[case(2, -1, 10i8)] + fn cast_dscalar_identity( + #[case] precision: u8, + #[case] scale: i8, + #[case] value: i8, + ) -> VortexResult<()> { + let ps = PrecisionScale::::new(precision, scale); + let scalar = DScalar::maybe_new(ps, Some(value)).unwrap(); + let target = DType::Decimal( + DecimalDType::new(precision, scale), + Nullability::NonNullable, + ); + let result = scalar.cast(&target)?; + let ds = result.into_decimal().into_i8(); + assert_eq!(ds.value(), Some(value)); + assert_eq!(ds.precision(), precision); + assert_eq!(ds.scale(), scale); + Ok(()) + } + + #[test] + fn cast_dscalar_null_to_nullable() -> VortexResult<()> { + let ps = PrecisionScale::::new(2, 0); + let scalar = DScalar::maybe_new(ps, None).unwrap(); + let target = DType::Decimal(DecimalDType::new(2, 0), Nullability::Nullable); + let result = scalar.cast(&target)?; + assert!(!result.as_decimal().is_valid()); + Ok(()) + } + + #[test] + fn cast_dscalar_null_to_non_nullable_fails() { + let ps = PrecisionScale::::new(2, 0); + let scalar = DScalar::maybe_new(ps, None).unwrap(); + let target = DType::Decimal(DecimalDType::new(2, 0), Nullability::NonNullable); + assert!(scalar.cast(&target).is_err()); + } + + #[rstest] + #[case(2, 4, 42i8)] // i8 -> i16 (precision 2 -> 4) + #[case(2, 9, 99i8)] // i8 -> i32 (precision 2 -> 9) + #[case(2, 18, 10i8)] // i8 -> i64 (precision 2 -> 18) + #[case(2, 38, 55i8)] // i8 -> i128 (precision 2 -> 38) + fn cast_dscalar_upcast_precision( + #[case] src_precision: u8, + #[case] target_precision: u8, + #[case] value: i8, + ) -> VortexResult<()> { + let scale = 0i8; + let ps = PrecisionScale::::new(src_precision, scale); + let scalar = DScalar::maybe_new(ps, Some(value)).unwrap(); + let target = DType::Decimal( + DecimalDType::new(target_precision, scale), + Nullability::NonNullable, + ); + let result = scalar.cast(&target)?; + let ds = result.as_decimal(); + assert!(ds.is_valid()); + assert_eq!(ds.precision(), target_precision); + assert_eq!(ds.scale(), scale); + Ok(()) + } + + #[test] + fn cast_dscalar_i8_to_i16() -> VortexResult<()> { + let ps = PrecisionScale::::new(2, 0); + let scalar = DScalar::maybe_new(ps, Some(42i8)).unwrap(); + // Precision 4 requires i16 + let target = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable); + let result = scalar.cast(&target)?; + let ds = result.into_decimal().into_i16(); + assert_eq!(ds.value(), Some(42i16)); + assert_eq!(ds.precision(), 4); + Ok(()) + } + + #[test] + fn cast_dscalar_i8_to_i32() -> VortexResult<()> { + let ps = PrecisionScale::::new(2, 0); + let scalar = DScalar::maybe_new(ps, Some(99i8)).unwrap(); + // Precision 9 requires i32 + let target = DType::Decimal(DecimalDType::new(9, 0), Nullability::NonNullable); + let result = scalar.cast(&target)?; + let ds = result.into_decimal().into_i32(); + assert_eq!(ds.value(), Some(99i32)); + assert_eq!(ds.precision(), 9); + Ok(()) + } + + #[test] + fn cast_dscalar_i16_to_i64() -> VortexResult<()> { + let ps = PrecisionScale::::new(4, 2); + let scalar = DScalar::maybe_new(ps, Some(1234i16)).unwrap(); + // Precision 18 requires i64 + let target = DType::Decimal(DecimalDType::new(18, 2), Nullability::NonNullable); + let result = scalar.cast(&target)?; + let ds = result.into_decimal().into_i64(); + assert_eq!(ds.value(), Some(1234i64)); + assert_eq!(ds.precision(), 18); + assert_eq!(ds.scale(), 2); + Ok(()) + } + + #[test] + fn cast_dscalar_i32_to_i128() -> VortexResult<()> { + let ps = PrecisionScale::::new(9, 0); + let scalar = DScalar::maybe_new(ps, Some(123456789i32)).unwrap(); + // Precision 38 requires i128 + let target = DType::Decimal(DecimalDType::new(38, 0), Nullability::NonNullable); + let result = scalar.cast(&target)?; + let ds = result.into_decimal().into_i128(); + assert_eq!(ds.value(), Some(123456789i128)); + assert_eq!(ds.precision(), 38); + Ok(()) + } + + #[test] + fn cast_dscalar_different_scale_fails() { + let ps = PrecisionScale::::new(2, 0); + let scalar = DScalar::maybe_new(ps, Some(42i8)).unwrap(); + let target = DType::Decimal(DecimalDType::new(2, 1), Nullability::NonNullable); + assert!(scalar.cast(&target).is_err()); + } + + #[test] + fn cast_dscalar_to_non_decimal_fails() { + use vortex_dtype::PType; + let ps = PrecisionScale::::new(2, 0); + let scalar = DScalar::maybe_new(ps, Some(42i8)).unwrap(); + let target = DType::Primitive(PType::I32, Nullability::NonNullable); + assert!(scalar.cast(&target).is_err()); + } + + #[test] + fn cast_dscalar_downcast_precision_within_same_type() -> VortexResult<()> { + // Downcast within the same native type (i8 precision 2 -> precision 1) + // should work as long as the value fits + let ps = PrecisionScale::::new(2, 0); + let scalar = DScalar::maybe_new(ps, Some(9i8)).unwrap(); // value 9 fits in precision 1 + let target = DType::Decimal(DecimalDType::new(1, 0), Nullability::NonNullable); + let result = scalar.cast(&target)?; + let ds = result.into_decimal().into_i8(); + assert_eq!(ds.value(), Some(9i8)); + assert_eq!(ds.precision(), 1); + Ok(()) + } + + #[test] + fn cast_dscalar_downcast_value_too_large_fails() { + // Value 42 doesn't fit in precision 1 (max 9) + let ps = PrecisionScale::::new(2, 0); + let scalar = DScalar::maybe_new(ps, Some(42i8)).unwrap(); + let target = DType::Decimal(DecimalDType::new(1, 0), Nullability::NonNullable); + assert!(scalar.cast(&target).is_err()); + } + + #[rstest] + #[case(::MAX_PRECISION)] + #[case(::MAX_PRECISION)] + #[case(::MAX_PRECISION)] + #[case(::MAX_PRECISION)] + #[case(::MAX_PRECISION)] + #[case(::MAX_PRECISION)] + fn cast_dscalar_to_max_precision_boundary(#[case] target_precision: u8) -> VortexResult<()> { + let ps = PrecisionScale::::new(1, 0); + let scalar = DScalar::maybe_new(ps, Some(1i8)).unwrap(); + let target = DType::Decimal( + DecimalDType::new(target_precision, 0), + Nullability::NonNullable, + ); + let result = scalar.cast(&target)?; + assert_eq!(result.as_decimal().precision(), target_precision); + Ok(()) + } +}