diff --git a/vortex-array/src/arrays/primitive/compute/take/avx2.rs b/vortex-array/src/arrays/primitive/compute/take/avx2.rs index fc63fd48d66..239e668ca99 100644 --- a/vortex-array/src/arrays/primitive/compute/take/avx2.rs +++ b/vortex-array/src/arrays/primitive/compute/take/avx2.rs @@ -44,7 +44,7 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::primitive::PrimitiveArray; use crate::arrays::primitive::compute::take::TakeImpl; -use crate::arrays::primitive::compute::take::take_primitive_scalar; +use crate::arrays::primitive::compute::take::scalar::take_primitive_scalar; use crate::validity::Validity; #[allow(unused)] diff --git a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index b740d6cb03d..ea4a4e65024 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -6,21 +6,16 @@ mod avx2; #[cfg(vortex_nightly)] mod portable; +mod scalar; use std::sync::LazyLock; -use vortex_buffer::Buffer; use vortex_dtype::DType; -use vortex_dtype::IntegerPType; -use vortex_dtype::NativePType; -use vortex_dtype::match_each_integer_ptype; -use vortex_dtype::match_each_native_ptype; use vortex_error::VortexResult; use vortex_error::vortex_bail; use crate::Array; use crate::ArrayRef; -use crate::IntoArray; use crate::ToCanonical; use crate::arrays::PrimitiveVTable; use crate::arrays::primitive::PrimitiveArray; @@ -44,11 +39,11 @@ static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| if is_x86_feature_detected!("avx2") { &avx2::TakeKernelAVX2 } else { - &TakeKernelScalar + &scalar::TakeKernelScalar } } else { // stable all other platforms: scalar kernel - &TakeKernelScalar + &scalar::TakeKernelScalar } } }); @@ -62,25 +57,6 @@ trait TakeImpl: Send + Sync { ) -> VortexResult; } -#[allow(unused)] -struct TakeKernelScalar; - -impl TakeImpl for TakeKernelScalar { - fn take( - &self, - array: &PrimitiveArray, - 
indices: &PrimitiveArray, - validity: Validity, - ) -> VortexResult { - match_each_native_ptype!(array.ptype(), |T| { - match_each_integer_ptype!(indices.ptype(), |I| { - let values = take_primitive_scalar(array.as_slice::(), indices.as_slice::()); - Ok(PrimitiveArray::new(values, validity).into_array()) - }) - }) - } -} - impl TakeKernel for PrimitiveVTable { fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult { let DType::Primitive(ptype, null) = indices.dtype() else { @@ -102,13 +78,6 @@ impl TakeKernel for PrimitiveVTable { register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift()); -// Compiler may see this as unused based on enabled features -#[allow(unused)] -#[inline(always)] -fn take_primitive_scalar(array: &[T], indices: &[I]) -> Buffer { - indices.iter().map(|idx| array[idx.as_()]).collect() -} - #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] #[cfg(test)] mod test { diff --git a/vortex-array/src/arrays/primitive/compute/take/scalar.rs b/vortex-array/src/arrays/primitive/compute/take/scalar.rs new file mode 100644 index 00000000000..c84ffe73981 --- /dev/null +++ b/vortex-array/src/arrays/primitive/compute/take/scalar.rs @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::Buffer; +use vortex_dtype::IntegerPType; +use vortex_dtype::NativePType; +use vortex_dtype::match_each_integer_ptype; +use vortex_dtype::match_each_native_ptype; +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::PrimitiveArray; +use crate::arrays::primitive::compute::take::TakeImpl; +use crate::validity::Validity; +use crate::vtable::ValidityHelper; + +#[allow(unused)] +pub(super) struct TakeKernelScalar; + +impl TakeImpl for TakeKernelScalar { + #[allow(clippy::cognitive_complexity)] + fn take( + &self, + array: &PrimitiveArray, + indices: &PrimitiveArray, + validity: Validity, + ) -> VortexResult { + 
match_each_native_ptype!(array.ptype(), |T| { + match_each_integer_ptype!(indices.ptype(), |I| { + let indices_slice = indices.as_slice::(); + let indices_validity = indices.validity(); + let values = if indices_validity.all_valid(indices_slice.len()) { + // Fast path: indices have no nulls, safe to index directly + take_primitive_scalar(array.as_slice::(), indices_slice) + } else { + // Slow path: indices may have nulls with garbage values + take_primitive_scalar_with_nulls( + array.as_slice::(), + indices_slice, + indices_validity, + ) + }; + Ok(PrimitiveArray::new(values, validity).into_array()) + }) + }) + } +} + +// Compiler may see this as unused based on enabled features +#[allow(unused)] +#[inline(always)] +pub(super) fn take_primitive_scalar( + array: &[T], + indices: &[I], +) -> Buffer { + indices.iter().map(|idx| array[idx.as_()]).collect() +} + +/// Slow path for take when indices may contain nulls with garbage values. +/// Null positions are filled with `T::zero()` without indexing `array` (the value will be masked out by validity). 
+#[allow(unused)] +#[inline(always)] +fn take_primitive_scalar_with_nulls( + array: &[T], + indices: &[I], + validity: &Validity, +) -> Buffer { + indices + .iter() + .enumerate() + .map(|(i, idx)| { + if validity.is_valid(i) { + array[idx.as_()] + } else { + T::zero() + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + + use crate::IntoArray; + use crate::ToCanonical; + use crate::arrays::PrimitiveArray; + use crate::arrays::primitive::compute::take::TakeImpl; + use crate::arrays::primitive::compute::take::scalar::TakeKernelScalar; + use crate::validity::Validity; + + #[test] + fn test_scalar_basic() { + let values = buffer![1, 2, 3, 4, 5].into_array().to_primitive(); + let indices = buffer![0, 1, 1, 2, 2, 3, 4].into_array().to_primitive(); + + let result = TakeKernelScalar + .take(&values, &indices, Validity::NonNullable) + .unwrap() + .to_primitive(); + assert_eq!(result.as_slice::(), &[1, 2, 2, 3, 3, 4, 5]); + } + + #[test] + fn test_scalar_with_nulls() { + let values = buffer![1, 2, 3, 4, 5].into_array().to_primitive(); + let validity = Validity::from_iter([true, false, true, true, true]); + let indices = PrimitiveArray::new(buffer![0, 100, 2, 3, 4], validity.clone()); + + let result = TakeKernelScalar + .take(&values, &indices, validity.clone()) + .unwrap() + .to_primitive(); + + assert_eq!(result.as_slice::(), &[1, 0, 3, 4, 5]); + assert_eq!(result.validity, validity); + } +} diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 1c9e35e7e7f..5a69aabc690 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -1218,7 +1218,8 @@ mod test { let primitive_values = taken.values().to_primitive(); let primitive_indices = taken.indices().to_primitive(); assert_eq!(taken.array_len(), 2); - assert_eq!(primitive_values.as_slice::(), [44, 33]); + assert_eq!(primitive_values.scalar_at(0), Some(44i32).into()); + assert_eq!(primitive_values.scalar_at(1), Option::::None.into()); 
assert_eq!(primitive_indices.as_slice::(), [0, 1]); assert_eq!(