Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vortex-array/benches/cast_primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::PType;
use vortex_array::expr::stats::Stat;

fn main() {
divan::main();
Expand All @@ -28,6 +29,8 @@ fn cast_u16_to_u32(bencher: Bencher) {
}
}))
.into_array();
// Pre-compute min/max so values_fit_in is a cache hit during the benchmark.
arr.statistics().compute_all(&[Stat::Min, Stat::Max]).ok();
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is reasonable to do to avoid a noisy benchmark on the first run. I think it's reasonable to assume these will be computed.

bencher.with_inputs(|| arr.clone()).bench_refs(|a| {
#[expect(clippy::unwrap_used)]
a.cast(DType::Primitive(PType::U32, Nullability::Nullable))
Expand Down
74 changes: 35 additions & 39 deletions vortex-array/src/arrays/primitive/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use num_traits::AsPrimitive;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_mask::AllOr;
use vortex_mask::Mask;

use crate::ArrayRef;
use crate::ExecutionCtx;
Expand Down Expand Up @@ -53,20 +51,21 @@ impl CastKernel for Primitive {
}));
}

if !values_fit_in(array, new_ptype, ctx) {
vortex_bail!(
Compute: "Cannot cast {} to {} — values exceed target range",
array.ptype(),
new_ptype,
);
}

// Same-width integers have identical bit representations due to 2's
// complement. If all values fit in the target range, reinterpret with
// no allocation.
if array.ptype().is_int()
&& new_ptype.is_int()
&& array.ptype().byte_width() == new_ptype.byte_width()
{
if !values_fit_in(array, new_ptype, ctx) {
vortex_bail!(
Compute: "Cannot cast {} to {} — values exceed target range",
array.ptype(),
new_ptype,
);
}
// SAFETY: both types are integers with the same size and alignment, and
// min/max confirm all valid values are representable in the target type.
return Ok(Some(unsafe {
Expand All @@ -79,13 +78,10 @@ impl CastKernel for Primitive {
}));
}

let mask = array.validity_mask();

// Otherwise, we need to cast the values one-by-one.
// Otherwise, cast the values element-wise.
Ok(Some(match_each_native_ptype!(new_ptype, |T| {
match_each_native_ptype!(array.ptype(), |F| {
PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
.into_array()
PrimitiveArray::new(cast::<F, T>(array.as_slice()), new_validity).into_array()
})
})))
}
Expand All @@ -104,30 +100,11 @@ fn values_fit_in(
.is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
}

fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
let try_cast = |src: F| -> VortexResult<T> {
T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE))
};
match mask.bit_buffer() {
AllOr::None => Ok(Buffer::zeroed(array.len())),
AllOr::All => {
let mut buffer = BufferMut::with_capacity(array.len());
for &src in array {
// SAFETY: we've pre-allocated the required capacity
unsafe { buffer.push_unchecked(try_cast(src)?) }
}
Ok(buffer.freeze())
}
AllOr::Some(b) => {
let mut buffer = BufferMut::with_capacity(array.len());
for (&src, valid) in array.iter().zip(b.iter()) {
let dst = if valid { try_cast(src)? } else { T::default() };
// SAFETY: we've pre-allocated the required capacity
unsafe { buffer.push_unchecked(dst) }
}
Ok(buffer.freeze())
}
}
/// Caller must ensure all valid values are representable via `values_fit_in`.
/// Out-of-range values at invalid positions are truncated/wrapped by `as`,
/// which is fine because they are masked out by validity.
fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze()
}

#[cfg(test)]
Expand Down Expand Up @@ -319,6 +296,23 @@ mod test {
Ok(())
}

#[test]
fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
let arr = PrimitiveArray::new(
buffer![1000u32, 10u32, 42u32],
Validity::from_iter([false, true, true]),
);
let casted = arr
.into_array()
.cast(DType::Primitive(PType::U8, Nullability::Nullable))?
.to_primitive();
assert_arrays_eq!(
casted,
PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)])
);
Ok(())
}

#[rstest]
#[case(buffer![0u8, 1, 2, 3, 255].into_array())]
#[case(buffer![0u16, 100, 1000, 65535].into_array())]
Expand All @@ -329,7 +323,9 @@ mod test {
#[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
#[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
#[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
#[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
#[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
#[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
#[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
#[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
#[case(buffer![42u32].into_array())]
Expand Down
Loading