From 63f75e3642c3e26dc597e55aa2f0fe424d6b428f Mon Sep 17 00:00:00 2001 From: Alfonso Subiotto Marques Date: Sun, 12 Apr 2026 22:53:10 +0200 Subject: [PATCH] perf[vortex-array]: use from_trusted_len_iter in primitive casts Some of our scan profiles show 10% of scan cpu time is spent in integer widening casts (nullable dictionary codes). This commit simplifies primitive casts by hoisting a lot of hot loop branching logic. Specifically, this commit relies on values_fit_in to verify representability so that we can avoid a potential validity and error check in the hot loop. Additionally from_trusted_len_iter lets the destination BufferMut optimize the actual cast instead of using push_unchecked for each element. Signed-off-by: Alfonso Subiotto Marques --- vortex-array/benches/cast_primitive.rs | 3 + .../src/arrays/primitive/compute/cast.rs | 74 +++++++++---------- 2 files changed, 38 insertions(+), 39 deletions(-) diff --git a/vortex-array/benches/cast_primitive.rs b/vortex-array/benches/cast_primitive.rs index 3e1b0831736..be200ef137d 100644 --- a/vortex-array/benches/cast_primitive.rs +++ b/vortex-array/benches/cast_primitive.rs @@ -9,6 +9,7 @@ use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; +use vortex_array::expr::stats::Stat; fn main() { divan::main(); @@ -28,6 +29,8 @@ fn cast_u16_to_u32(bencher: Bencher) { } })) .into_array(); + // Pre-compute min/max so values_fit_in is a cache hit during the benchmark. + arr.statistics().compute_all(&[Stat::Min, Stat::Max]).ok(); bencher.with_inputs(|| arr.clone()).bench_refs(|a| { #[expect(clippy::unwrap_used)] a.cast(DType::Primitive(PType::U32, Nullability::Nullable)) diff --git a/vortex-array/src/arrays/primitive/compute/cast.rs b/vortex-array/src/arrays/primitive/compute/cast.rs index dc98495ef11..13cba60bfa6 100644 --- a/vortex-array/src/arrays/primitive/compute/cast.rs +++ b/vortex-array/src/arrays/primitive/compute/cast.rs @@ -1,13 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::AsPrimitive; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_mask::AllOr; -use vortex_mask::Mask; use crate::ArrayRef; use crate::ExecutionCtx; @@ -53,6 +51,14 @@ impl CastKernel for Primitive { })); } + if !values_fit_in(array, new_ptype, ctx) { + vortex_bail!( + Compute: "Cannot cast {} to {} — values exceed target range", + array.ptype(), + new_ptype, + ); + } + // Same-width integers have identical bit representations due to 2's // complement. If all values fit in the target range, reinterpret with // no allocation. @@ -60,13 +66,6 @@ impl CastKernel for Primitive { && new_ptype.is_int() && array.ptype().byte_width() == new_ptype.byte_width() { - if !values_fit_in(array, new_ptype, ctx) { - vortex_bail!( - Compute: "Cannot cast {} to {} — values exceed target range", - array.ptype(), - new_ptype, - ); - } // SAFETY: both types are integers with the same size and alignment, and // min/max confirm all valid values are representable in the target type. return Ok(Some(unsafe { @@ -79,13 +78,10 @@ impl CastKernel for Primitive { })); } - let mask = array.validity_mask(); - - // Otherwise, we need to cast the values one-by-one. + // Otherwise, cast the values element-wise. Ok(Some(match_each_native_ptype!(new_ptype, |T| { match_each_native_ptype!(array.ptype(), |F| { - PrimitiveArray::new(cast::(array.as_slice(), mask)?, new_validity) - .into_array() + PrimitiveArray::new(cast::(array.as_slice()), new_validity).into_array() }) }))) } @@ -104,30 +100,11 @@ fn values_fit_in( .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok()) } -fn cast(array: &[F], mask: Mask) -> VortexResult> { - let try_cast = |src: F| -> VortexResult { - T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE)) - }; - match mask.bit_buffer() { - AllOr::None => Ok(Buffer::zeroed(array.len())), - AllOr::All => { - let mut buffer = BufferMut::with_capacity(array.len()); - for &src in array { - // SAFETY: we've pre-allocated the required capacity - unsafe { buffer.push_unchecked(try_cast(src)?) } - } - Ok(buffer.freeze()) - } - AllOr::Some(b) => { - let mut buffer = BufferMut::with_capacity(array.len()); - for (&src, valid) in array.iter().zip(b.iter()) { - let dst = if valid { try_cast(src)? } else { T::default() }; - // SAFETY: we've pre-allocated the required capacity - unsafe { buffer.push_unchecked(dst) } - } - Ok(buffer.freeze()) - } - } +/// Caller must ensure all valid values are representable via `values_fit_in`. +/// Out-of-range values at invalid positions are truncated/wrapped by `as`, +/// which is fine because they are masked out by validity. +fn cast, T: NativePType>(array: &[F]) -> Buffer { + BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze() } #[cfg(test)] @@ -319,6 +296,23 @@ mod test { Ok(()) } + #[test] + fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> { + let arr = PrimitiveArray::new( + buffer![1000u32, 10u32, 42u32], + Validity::from_iter([false, true, true]), + ); + let casted = arr + .into_array() + .cast(DType::Primitive(PType::U8, Nullability::Nullable))? + .to_primitive(); + assert_arrays_eq!( + casted, + PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)]) + ); + Ok(()) + } + #[rstest] #[case(buffer![0u8, 1, 2, 3, 255].into_array())] #[case(buffer![0u16, 100, 1000, 65535].into_array())] @@ -329,7 +323,9 @@ mod test { #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())] #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())] #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())] + #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())] #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())] + #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())] #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())] #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())] #[case(buffer![42u32].into_array())]