diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index a2348528313..e70969085a6 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -718,7 +718,7 @@ mod tests { ); if let Some(expected_val) = expected_value { - let buf = result_primitive.buffer::(); + let buf = result_primitive.to_buffer::(); let result_val = buf.as_slice()[idx]; assert_eq!(result_val, expected_val, "Value mismatch at idx={idx}",); } diff --git a/encodings/bytebool/src/array.rs b/encodings/bytebool/src/array.rs index 0227c285da3..aa09130874e 100644 --- a/encodings/bytebool/src/array.rs +++ b/encodings/bytebool/src/array.rs @@ -97,7 +97,7 @@ impl VTable for ByteBoolVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let buffer = buffers[0].clone().try_to_bytes()?; + let buffer = buffers[0].clone().try_to_host()?; Ok(ByteBoolArray::new(buffer, validity)) } diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs index 462fb2aefc1..2232d9a86bc 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs @@ -238,7 +238,7 @@ impl CanonicalVTable for DecimalBytePartsVTable { // The decimal dtype matches the array's dtype, and validity is preserved. Canonical::Decimal(unsafe { DecimalArray::new_unchecked( - prim.buffer::

(), + prim.to_buffer::

(), *array.decimal_dtype(), prim.validity().clone(), ) diff --git a/encodings/fastlanes/src/bitpacking/vtable/mod.rs b/encodings/fastlanes/src/bitpacking/vtable/mod.rs index 4b7ebe8fb91..e47a6d92f0d 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/mod.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/mod.rs @@ -176,7 +176,7 @@ impl VTable for BitPackedVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let packed = buffers[0].clone().try_to_bytes()?; + let packed = buffers[0].clone().try_to_host()?; let load_validity = |child_idx: usize| { if children.len() == child_idx { diff --git a/encodings/fsst/src/array.rs b/encodings/fsst/src/array.rs index 68752e09e72..78973336fef 100644 --- a/encodings/fsst/src/array.rs +++ b/encodings/fsst/src/array.rs @@ -116,8 +116,8 @@ impl VTable for FSSTVTable { if buffers.len() != 2 { vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", buffers.len()); } - let symbols = Buffer::::from_byte_buffer(buffers[0].clone().try_to_bytes()?); - let symbol_lengths = Buffer::::from_byte_buffer(buffers[1].clone().try_to_bytes()?); + let symbols = Buffer::::from_byte_buffer(buffers[0].clone().try_to_host()?); + let symbol_lengths = Buffer::::from_byte_buffer(buffers[1].clone().try_to_host()?); if children.len() != 2 { vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len()); diff --git a/encodings/fsst/src/kernel.rs b/encodings/fsst/src/kernel.rs index 3fd0f4bdadc..93a334bdd26 100644 --- a/encodings/fsst/src/kernel.rs +++ b/encodings/fsst/src/kernel.rs @@ -86,7 +86,7 @@ impl ExecuteParentKernel for FSSTFilterKernel { .cast(DType::Primitive(PType::U32, Nullability::NonNullable))? .execute::(ctx)? 
.into_primitive() - .buffer::(); + .to_buffer::(); let decompressor = array.decompressor(); @@ -97,7 +97,7 @@ impl ExecuteParentKernel for FSSTFilterKernel { &codes_offsets, mask_values, &validity, - &uncompressed_lens.buffer::(), + &uncompressed_lens.to_buffer::(), ) }); diff --git a/encodings/pco/src/array.rs b/encodings/pco/src/array.rs index 6d265901995..399f6a6fe0d 100644 --- a/encodings/pco/src/array.rs +++ b/encodings/pco/src/array.rs @@ -141,11 +141,11 @@ impl VTable for PcoVTable { vortex_ensure!(buffers.len() >= metadata.0.chunks.len()); let chunk_metas = buffers[..metadata.0.chunks.len()] .iter() - .map(|b| b.clone().try_to_bytes()) + .map(|b| b.clone().try_to_host()) .collect::>>()?; let pages = buffers[metadata.0.chunks.len()..] .iter() - .map(|b| b.clone().try_to_bytes()) + .map(|b| b.clone().try_to_host()) .collect::>>()?; let expected_n_pages = metadata @@ -293,7 +293,7 @@ impl PcoArray { number_type, NumberType => { let chunk_end = cmp::min(n_values, chunk_start + values_per_chunk); - let values = values.buffer::(); + let values = values.to_buffer::(); let chunk = &values.as_slice()[chunk_start..chunk_end]; fc .chunk_compressor(chunk, &chunk_config) diff --git a/encodings/sparse/src/canonical.rs b/encodings/sparse/src/canonical.rs index 1b5bd301b30..ad50877db88 100644 --- a/encodings/sparse/src/canonical.rs +++ b/encodings/sparse/src/canonical.rs @@ -449,7 +449,7 @@ fn canonicalize_varbin( let len = array.len(); match_each_integer_ptype!(indices.ptype(), |I| { - let indices = indices.buffer::(); + let indices = indices.to_buffer::(); canonicalize_varbin_inner::(fill_value, indices, values, dtype, validity, len) }) } diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index 1cbaa5a992b..88337677f50 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -137,7 +137,7 @@ impl VTable for SparseVTable { } let fill_value = Scalar::new( dtype.clone(), - 
ScalarValue::from_protobytes(&buffers[0].clone().try_to_bytes()?)?, + ScalarValue::from_protobytes(&buffers[0].clone().try_to_host()?)?, ); SparseArray::try_new(patch_indices, patch_values, len, fill_value) diff --git a/encodings/zstd/src/array.rs b/encodings/zstd/src/array.rs index 3b9a7c27623..88a9571775d 100644 --- a/encodings/zstd/src/array.rs +++ b/encodings/zstd/src/array.rs @@ -142,16 +142,16 @@ impl VTable for ZstdVTable { None, buffers .iter() - .map(|b| b.clone().try_to_bytes()) + .map(|b| b.clone().try_to_host()) .collect::>>()?, ) } else { // with dictionary ( - Some(buffers[0].clone().try_to_bytes()?), + Some(buffers[0].clone().try_to_host()?), buffers[1..] .iter() - .map(|b| b.clone().try_to_bytes()) + .map(|b| b.clone().try_to_host()) .collect::>>()?, ) }; @@ -366,7 +366,7 @@ impl ZstdArray { n_values }; - let value_bytes = values.byte_buffer(); + let value_bytes = values.buffer_handle().try_to_host()?; // Align frames to buffer alignment. This is necessary for overaligned buffers. 
let alignment = *value_bytes.alignment(); let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment; @@ -379,7 +379,7 @@ impl ZstdArray { frames, frame_metas, } = Self::compress_values( - value_bytes, + &value_bytes, &frame_byte_starts, level, values_per_frame, diff --git a/fuzz/src/array/fill_null.rs b/fuzz/src/array/fill_null.rs index adb52d844c4..e56f14c6a37 100644 --- a/fuzz/src/array/fill_null.rs +++ b/fuzz/src/array/fill_null.rs @@ -90,12 +90,9 @@ fn fill_primitive_array( .vortex_expect("fill value conversion should succeed in fuzz test"); match array.validity() { - Validity::NonNullable | Validity::AllValid => PrimitiveArray::from_byte_buffer( - array.byte_buffer().clone(), - array.ptype(), - result_nullability.into(), - ) - .into_array(), + Validity::NonNullable | Validity::AllValid => { + PrimitiveArray::new(array.to_buffer::(), result_nullability.into()).into_array() + } Validity::AllInvalid => { ConstantArray::new(fill_value.clone(), array.len()).into_array() } diff --git a/fuzz/src/array/mask.rs b/fuzz/src/array/mask.rs index fcff725c954..0d05c53b23e 100644 --- a/fuzz/src/array/mask.rs +++ b/fuzz/src/array/mask.rs @@ -36,8 +36,8 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult { let new_validity = array.validity().mask(mask); - PrimitiveArray::from_byte_buffer( - array.byte_buffer().clone(), + PrimitiveArray::from_buffer_handle( + array.buffer_handle().clone(), array.ptype(), new_validity, ) diff --git a/fuzz/src/array/slice.rs b/fuzz/src/array/slice.rs index eaa22a5dcf1..b3061312358 100644 --- a/fuzz/src/array/slice.rs +++ b/fuzz/src/array/slice.rs @@ -41,10 +41,11 @@ pub fn slice_canonical_array( DType::Primitive(p, _) => { let primitive_array = array.to_primitive(); match_each_native_ptype!(p, |P| { - Ok( - PrimitiveArray::new(primitive_array.buffer::

().slice(start..stop), validity) - .into_array(), + Ok(PrimitiveArray::new( + primitive_array.to_buffer::

().slice(start..stop), + validity, ) + .into_array()) }) } DType::Utf8(_) | DType::Binary(_) => { diff --git a/vortex-array/src/array/visitor.rs b/vortex-array/src/array/visitor.rs index 5bc398db6b5..9f3c4454b9b 100644 --- a/vortex-array/src/array/visitor.rs +++ b/vortex-array/src/array/visitor.rs @@ -11,6 +11,7 @@ use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; +use crate::buffer::BufferHandle; use crate::patches::Patches; use crate::validity::Validity; @@ -113,6 +114,10 @@ pub trait ArrayVisitorExt: Array { impl ArrayVisitorExt for A {} pub trait ArrayBufferVisitor { + fn visit_buffer_handle(&mut self, handle: &BufferHandle) -> VortexResult<()> { + self.visit_buffer(&handle.clone().try_to_host()?); + Ok(()) + } fn visit_buffer(&mut self, buffer: &ByteBuffer); } diff --git a/vortex-array/src/arrays/bool/vtable/mod.rs b/vortex-array/src/arrays/bool/vtable/mod.rs index 2dc17114c93..b29dad2e7da 100644 --- a/vortex-array/src/arrays/bool/vtable/mod.rs +++ b/vortex-array/src/arrays/bool/vtable/mod.rs @@ -103,7 +103,7 @@ impl VTable for BoolVTable { vortex_bail!("Expected 0 or 1 child, got {}", children.len()); }; - let buffer = buffers[0].clone().try_to_bytes()?; + let buffer = buffers[0].clone().try_to_host()?; let bits = BitBuffer::new_with_offset(buffer, len, metadata.offset as usize); BoolArray::try_new(bits, validity) diff --git a/vortex-array/src/arrays/chunked/array.rs b/vortex-array/src/arrays/chunked/array.rs index 9f41b7e6847..63af487a89c 100644 --- a/vortex-array/src/arrays/chunked/array.rs +++ b/vortex-array/src/arrays/chunked/array.rs @@ -117,7 +117,7 @@ impl ChunkedArray { #[inline] pub fn chunk_offsets(&self) -> Buffer { - self.chunk_offsets.buffer() + self.chunk_offsets.to_buffer() } pub(crate) fn find_chunk_idx(&self, index: usize) -> (usize, usize) { diff --git a/vortex-array/src/arrays/chunked/vtable/mod.rs b/vortex-array/src/arrays/chunked/vtable/mod.rs index 7331561c00a..9c4ecfce7a2 100644 --- 
a/vortex-array/src/arrays/chunked/vtable/mod.rs +++ b/vortex-array/src/arrays/chunked/vtable/mod.rs @@ -98,7 +98,7 @@ impl VTable for ChunkedVTable { )? .to_primitive(); - let chunk_offsets_buf = chunk_offsets_array.buffer::(); + let chunk_offsets_buf = chunk_offsets_array.to_buffer::(); // The remaining children contain the actual data of the chunks let chunks = chunk_offsets_buf @@ -139,7 +139,7 @@ impl VTable for ChunkedVTable { let nchunks = children.len() - 1; let chunk_offsets_array = children[0].to_primitive(); - let chunk_offsets_buf = chunk_offsets_array.buffer::(); + let chunk_offsets_buf = chunk_offsets_array.to_buffer::(); vortex_ensure!( chunk_offsets_buf.len() == nchunks + 1, diff --git a/vortex-array/src/arrays/constant/vtable/mod.rs b/vortex-array/src/arrays/constant/vtable/mod.rs index 3671dacacba..8e8e3e7fd73 100644 --- a/vortex-array/src/arrays/constant/vtable/mod.rs +++ b/vortex-array/src/arrays/constant/vtable/mod.rs @@ -81,7 +81,7 @@ impl VTable for ConstantVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let buffer = buffers[0].clone().try_to_bytes()?; + let buffer = buffers[0].clone().try_to_host()?; let sv = ScalarValue::from_protobytes(&buffer)?; let scalar = Scalar::new(dtype.clone(), sv); Ok(ConstantArray::new(scalar, len)) diff --git a/vortex-array/src/arrays/decimal/vtable/mod.rs b/vortex-array/src/arrays/decimal/vtable/mod.rs index 1f47438efd3..1163c67595b 100644 --- a/vortex-array/src/arrays/decimal/vtable/mod.rs +++ b/vortex-array/src/arrays/decimal/vtable/mod.rs @@ -95,7 +95,7 @@ impl VTable for DecimalVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let buffer = buffers[0].clone().try_to_bytes()?; + let buffer = buffers[0].clone().try_to_host()?; let validity = if children.is_empty() { Validity::from(dtype.nullability()) diff --git a/vortex-array/src/arrays/extension/vtable/rules.rs b/vortex-array/src/arrays/extension/vtable/rules.rs index 
be9aa9f5612..7801ae1b94d 100644 --- a/vortex-array/src/arrays/extension/vtable/rules.rs +++ b/vortex-array/src/arrays/extension/vtable/rules.rs @@ -104,7 +104,7 @@ mod tests { assert_eq!(ext_result.ext_dtype().as_ref(), ext_dtype.as_ref()); // Check the storage values - let storage_result: &[i64] = &ext_result.storage().to_primitive().buffer::(); + let storage_result: &[i64] = &ext_result.storage().to_primitive().to_buffer::(); assert_eq!(storage_result, &[1, 3, 5]); } diff --git a/vortex-array/src/arrays/masked/execute.rs b/vortex-array/src/arrays/masked/execute.rs index e6d5bf456b3..02d7a387215 100644 --- a/vortex-array/src/arrays/masked/execute.rs +++ b/vortex-array/src/arrays/masked/execute.rs @@ -66,7 +66,14 @@ fn mask_validity_primitive(array: PrimitiveArray, mask: &Mask) -> PrimitiveArray let len = array.len(); let ptype = array.ptype(); let new_validity = combine_validity(array.validity(), mask, len); - PrimitiveArray::from_byte_buffer(array.into_byte_buffer(), ptype, new_validity) + // SAFETY: validity has same length as values + unsafe { + PrimitiveArray::new_unchecked_from_handle( + array.buffer_handle().clone(), + ptype, + new_validity, + ) + } } fn mask_validity_decimal(array: DecimalArray, mask: &Mask) -> DecimalArray { diff --git a/vortex-array/src/arrays/primitive/array/cast.rs b/vortex-array/src/arrays/primitive/array/cast.rs index c559f0e7a8d..74051542c79 100644 --- a/vortex-array/src/arrays/primitive/array/cast.rs +++ b/vortex-array/src/arrays/primitive/array/cast.rs @@ -5,6 +5,7 @@ use vortex_buffer::Buffer; use vortex_dtype::DType; use vortex_dtype::NativePType; use vortex_dtype::PType; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_panic; @@ -18,6 +19,10 @@ impl PrimitiveArray { /// Return a slice of the array's buffer. /// /// NOTE: these values may be nonsense if the validity buffer indicates that the value is null. 
+ /// + /// # Panic + /// + /// This operation will panic if the array is not backed by host memory. pub fn as_slice(&self) -> &[T] { if T::PTYPE != self.ptype() { vortex_panic!( @@ -26,11 +31,15 @@ impl PrimitiveArray { self.ptype() ) } - let raw_slice = self.byte_buffer().as_ptr(); + + let byte_buffer = self + .buffer + .as_host_opt() + .vortex_expect("as_slice must be called on host buffer"); + let raw_slice = byte_buffer.as_ptr(); + // SAFETY: alignment of Buffer is checked on construction - unsafe { - std::slice::from_raw_parts(raw_slice.cast(), self.byte_buffer().len() / size_of::()) - } + unsafe { std::slice::from_raw_parts(raw_slice.cast(), byte_buffer.len() / size_of::()) } } pub fn reinterpret_cast(&self, ptype: PType) -> Self { @@ -44,7 +53,11 @@ impl PrimitiveArray { "can't reinterpret cast between integers of two different widths" ); - PrimitiveArray::from_byte_buffer(self.byte_buffer().clone(), ptype, self.validity().clone()) + PrimitiveArray::from_buffer_handle( + self.buffer_handle().clone(), + ptype, + self.validity().clone(), + ) } /// Narrow the array to the smallest possible integer type that can represent all values. diff --git a/vortex-array/src/arrays/primitive/array/conversion.rs b/vortex-array/src/arrays/primitive/array/conversion.rs index 50db4459edc..8bfbd8a52d5 100644 --- a/vortex-array/src/arrays/primitive/array/conversion.rs +++ b/vortex-array/src/arrays/primitive/array/conversion.rs @@ -19,7 +19,6 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::PrimitiveArray; use crate::validity::Validity; -use crate::vtable::ValidityHelper; impl PrimitiveArray { /// Attempts to create a `PrimitiveArray` from a [`PrimitiveVector`] given a [`Nullability`]. @@ -72,7 +71,10 @@ impl PrimitiveArray { Self::new(values.freeze(), Validity::from(validity.freeze())) } - pub fn buffer(&self) -> Buffer { + /// Get a buffer in host memory holding all the values. 
+ /// + /// NOTE: some values may be nonsense if the validity buffer indicates that the value is null. + pub fn to_buffer(&self) -> Buffer { if T::PTYPE != self.ptype() { vortex_panic!( "Attempted to get buffer of type {} from array of type {}", @@ -80,9 +82,11 @@ impl PrimitiveArray { self.ptype() ) } - Buffer::from_byte_buffer(self.byte_buffer().clone()) + + Buffer::from_byte_buffer(self.buffer_handle().to_host()) } + /// Consume the array and get a host Buffer containing the data values. pub fn into_buffer(self) -> Buffer { if T::PTYPE != self.ptype() { vortex_panic!( @@ -91,26 +95,19 @@ impl PrimitiveArray { self.ptype() ) } - Buffer::from_byte_buffer(self.buffer) + + Buffer::from_byte_buffer(self.buffer.into_host()) } /// Extract a mutable buffer from the PrimitiveArray. Attempts to do this with zero-copy /// if the buffer is uniquely owned, otherwise will make a copy. pub fn into_buffer_mut(self) -> BufferMut { - if T::PTYPE != self.ptype() { - vortex_panic!( - "Attempted to get buffer_mut of type {} from array of type {}", - T::PTYPE, - self.ptype() - ) - } - self.into_buffer() - .try_into_mut() + self.try_into_buffer_mut() .unwrap_or_else(|buffer| BufferMut::::copy_from(&buffer)) } /// Try to extract a mutable buffer from the PrimitiveArray with zero copy. 
- pub fn try_into_buffer_mut(self) -> Result, PrimitiveArray> { + pub fn try_into_buffer_mut(self) -> Result, Buffer> { if T::PTYPE != self.ptype() { vortex_panic!( "Attempted to get buffer_mut of type {} from array of type {}", @@ -118,10 +115,8 @@ impl PrimitiveArray { self.ptype() ) } - let validity = self.validity().clone(); - Buffer::::from_byte_buffer(self.into_byte_buffer()) - .try_into_mut() - .map_err(|buffer| PrimitiveArray::new(buffer, validity)) + let buffer = Buffer::::from_byte_buffer(self.buffer.into_host()); + buffer.try_into_mut() } } diff --git a/vortex-array/src/arrays/primitive/array/mod.rs b/vortex-array/src/arrays/primitive/array/mod.rs index f1f4234eae4..2cb835f4101 100644 --- a/vortex-array/src/arrays/primitive/array/mod.rs +++ b/vortex-array/src/arrays/primitive/array/mod.rs @@ -31,6 +31,8 @@ mod top_value; pub use patch::chunk_range; pub use patch::patch_chunk; +use crate::buffer::BufferHandle; + /// A primitive array that stores [native types][vortex_dtype::NativePType] in a contiguous buffer /// of memory, along with an optional validity child. /// @@ -65,13 +67,32 @@ pub use patch::patch_chunk; #[derive(Clone, Debug)] pub struct PrimitiveArray { pub(super) dtype: DType, - pub(super) buffer: ByteBuffer, + pub(super) buffer: BufferHandle, pub(super) validity: Validity, pub(super) stats_set: ArrayStats, } // TODO(connor): There are a lot of places where we could be using `new_unchecked` in the codebase. impl PrimitiveArray { + /// Create a new array from a buffer handle. + /// + /// # Safety + /// + /// Should ensure that the provided BufferHandle points at sufficiently large region of aligned + /// memory to hold the `ptype` values. + pub unsafe fn new_unchecked_from_handle( + handle: BufferHandle, + ptype: PType, + validity: Validity, + ) -> Self { + Self { + buffer: handle, + dtype: DType::Primitive(ptype, validity.nullability()), + validity, + stats_set: ArrayStats::default(), + } + } + /// Creates a new [`PrimitiveArray`]. 
/// /// # Panics @@ -119,7 +140,7 @@ impl PrimitiveArray { Self { dtype: DType::Primitive(T::PTYPE, validity.nullability()), - buffer: buffer.into_byte_buffer(), + buffer: BufferHandle::new_host(buffer.into_byte_buffer()), validity, stats_set: Default::default(), } @@ -145,17 +166,33 @@ impl PrimitiveArray { pub fn empty(nullability: Nullability) -> Self { Self::new(Buffer::::empty(), nullability.into()) } +} + +impl PrimitiveArray { + /// Consume the primitive array and returns its component parts. + pub fn into_parts(self) -> (DType, BufferHandle, Validity, ArrayStats) { + (self.dtype, self.buffer, self.validity, self.stats_set) + } +} +impl PrimitiveArray { pub fn ptype(&self) -> PType { self.dtype().as_ptype() } - pub fn byte_buffer(&self) -> &ByteBuffer { + /// Get access to the buffer handle backing the array. + pub fn buffer_handle(&self) -> &BufferHandle { &self.buffer } - pub fn into_byte_buffer(self) -> ByteBuffer { - self.buffer + pub fn from_buffer_handle(handle: BufferHandle, ptype: PType, validity: Validity) -> Self { + let dtype = DType::Primitive(ptype, validity.nullability()); + Self { + buffer: handle, + dtype, + validity, + stats_set: ArrayStats::default(), + } } pub fn from_byte_buffer(buffer: ByteBuffer, ptype: PType, validity: Validity) -> Self { @@ -206,7 +243,7 @@ impl PrimitiveArray { let validity = self.validity().clone(); let buffer = match self.try_into_buffer_mut() { Ok(buffer_mut) => buffer_mut.map_each_in_place(f), - Err(parray) => BufferMut::::from_iter(parray.buffer::().iter().copied().map(f)), + Err(buffer) => BufferMut::from_iter(buffer.iter().copied().map(f)), }; PrimitiveArray::new(buffer.freeze(), validity) } @@ -223,7 +260,7 @@ impl PrimitiveArray { { let validity = self.validity(); - let buf_iter = self.buffer::().into_iter(); + let buf_iter = self.to_buffer::().into_iter(); let buffer = match &validity { Validity::NonNullable | Validity::AllValid => { diff --git a/vortex-array/src/arrays/primitive/compute/cast.rs 
b/vortex-array/src/arrays/primitive/compute/cast.rs index cd2f82446e6..ceadfbea78b 100644 --- a/vortex-array/src/arrays/primitive/compute/cast.rs +++ b/vortex-array/src/arrays/primitive/compute/cast.rs @@ -35,14 +35,15 @@ impl CastKernel for PrimitiveVTable { // If the bit width is the same, we can short-circuit and simply update the validity if array.ptype() == new_ptype { - return Ok(Some( - PrimitiveArray::from_byte_buffer( - array.byte_buffer().clone(), + // SAFETY: validity and data buffer still have same length + return Ok(Some(unsafe { + PrimitiveArray::new_unchecked_from_handle( + array.buffer_handle().clone(), array.ptype(), new_validity, ) - .into_array(), - )); + .into_array() + })); } let mask = array.validity_mask(); diff --git a/vortex-array/src/arrays/primitive/compute/fill_null.rs b/vortex-array/src/arrays/primitive/compute/fill_null.rs index 4b23de81bcb..8a3f5013e66 100644 --- a/vortex-array/src/arrays/primitive/compute/fill_null.rs +++ b/vortex-array/src/arrays/primitive/compute/fill_null.rs @@ -27,7 +27,7 @@ impl FillNullKernel for PrimitiveVTable { Validity::Array(is_valid) => { let is_invalid = is_valid.to_bool().bit_buffer().not(); match_each_native_ptype!(array.ptype(), |T| { - let mut buffer = array.buffer::().into_mut(); + let mut buffer = array.to_buffer::().into_mut(); let fill_value = fill_value .as_primitive() .typed_value::() diff --git a/vortex-array/src/arrays/primitive/compute/mask.rs b/vortex-array/src/arrays/primitive/compute/mask.rs index 8a701ce72a8..545d0847618 100644 --- a/vortex-array/src/arrays/primitive/compute/mask.rs +++ b/vortex-array/src/arrays/primitive/compute/mask.rs @@ -16,10 +16,16 @@ use crate::vtable::ValidityHelper; impl MaskKernel for PrimitiveVTable { fn mask(&self, array: &PrimitiveArray, mask: &Mask) -> VortexResult { let validity = array.validity().mask(mask); - Ok( - PrimitiveArray::from_byte_buffer(array.byte_buffer().clone(), array.ptype(), validity) - .into_array(), - ) + + // SAFETY: validity and data 
buffer still have same length + Ok(unsafe { + PrimitiveArray::new_unchecked_from_handle( + array.buffer_handle().clone(), + array.ptype(), + validity, + ) + .into_array() + }) } } diff --git a/vortex-array/src/arrays/primitive/vtable/array.rs b/vortex-array/src/arrays/primitive/vtable/array.rs index 3b0b5aafbb6..15a749ad793 100644 --- a/vortex-array/src/arrays/primitive/vtable/array.rs +++ b/vortex-array/src/arrays/primitive/vtable/array.rs @@ -16,7 +16,7 @@ use crate::vtable::BaseArrayVTable; impl BaseArrayVTable for PrimitiveVTable { fn len(array: &PrimitiveArray) -> usize { - array.byte_buffer().len() / array.ptype().byte_width() + array.buffer_handle().len() / array.ptype().byte_width() } fn dtype(array: &PrimitiveArray) -> &DType { diff --git a/vortex-array/src/arrays/primitive/vtable/mod.rs b/vortex-array/src/arrays/primitive/vtable/mod.rs index b79e04cc37a..5093bcef0c1 100644 --- a/vortex-array/src/arrays/primitive/vtable/mod.rs +++ b/vortex-array/src/arrays/primitive/vtable/mod.rs @@ -1,11 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_buffer::Alignment; -use vortex_buffer::Buffer; use vortex_dtype::DType; use vortex_dtype::PType; -use vortex_dtype::match_each_native_ptype; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -31,6 +28,7 @@ mod validity; mod visitor; pub use rules::PrimitiveMaskedValidityRule; +use vortex_buffer::Alignment; use crate::arrays::primitive::vtable::rules::RULES; use crate::vtable::ArrayId; @@ -82,7 +80,7 @@ impl VTable for PrimitiveVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let buffer = buffers[0].clone().try_to_bytes()?; + let buffer = buffers[0].clone(); let validity = if children.is_empty() { Validity::from(dtype.nullability()) @@ -95,12 +93,6 @@ impl VTable for PrimitiveVTable { let ptype = PType::try_from(dtype)?; - if 
!buffer.is_aligned(Alignment::new(ptype.byte_width())) { - vortex_bail!( - "Buffer is not aligned to {}-byte boundary", - ptype.byte_width() - ); - } if buffer.len() != ptype.byte_width() * len { vortex_bail!( "Buffer length {} does not match expected length {} for {}, {}", @@ -111,10 +103,23 @@ impl VTable for PrimitiveVTable { ); } - match_each_native_ptype!(ptype, |P| { - let buffer = Buffer::

::from_byte_buffer(buffer); - Ok(PrimitiveArray::new(buffer, validity)) - }) + // For host buffers, we eagerly check alignment on construction. + // TODO(aduffy): check for device buffers. CUDA buffers are generally 256-byte aligned, + // but not sure about other devices. + if let Some(host_buf) = buffer.as_host_opt() { + vortex_ensure!( + host_buf.is_aligned(Alignment::new(ptype.byte_width())), + "PrimitiveArray::build: Buffer must be aligned to {}", + ptype.byte_width() + ); + } + + // SAFETY: checked ahead of time + unsafe { + Ok(PrimitiveArray::new_unchecked_from_handle( + buffer, ptype, validity, + )) + } } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { diff --git a/vortex-array/src/arrays/primitive/vtable/operations.rs b/vortex-array/src/arrays/primitive/vtable/operations.rs index e0374acdbaa..9cd846ca460 100644 --- a/vortex-array/src/arrays/primitive/vtable/operations.rs +++ b/vortex-array/src/arrays/primitive/vtable/operations.rs @@ -15,9 +15,10 @@ use crate::vtable::ValidityHelper; impl OperationsVTable for PrimitiveVTable { fn slice(array: &PrimitiveArray, range: Range) -> ArrayRef { - match_each_native_ptype!(array.ptype(), |T| { - PrimitiveArray::new( - array.buffer::().slice(range.clone()), + match_each_native_ptype!(array.ptype(), |P| { + PrimitiveArray::from_buffer_handle( + array.buffer.slice_typed::

(range.clone()), + array.ptype(), array.validity().slice(range), ) .into_array() diff --git a/vortex-array/src/arrays/primitive/vtable/rules.rs b/vortex-array/src/arrays/primitive/vtable/rules.rs index 3971453491e..3b736889f8d 100644 --- a/vortex-array/src/arrays/primitive/vtable/rules.rs +++ b/vortex-array/src/arrays/primitive/vtable/rules.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_buffer::Buffer; use vortex_dtype::match_each_native_ptype; use vortex_error::VortexResult; @@ -42,11 +41,11 @@ impl ArrayParentReduceRule for PrimitiveMaskedValidityRule { // Merge the parent's validity mask into the child's validity // TODO(joe): make this lazy let masked_array = match_each_native_ptype!(array.ptype(), |T| { - // SAFETY: Since we are only flipping some bits in the validity, all invariants that - // were upheld are still upheld. + // SAFETY: masking validity does not change PrimitiveArray invariants unsafe { - PrimitiveArray::new_unchecked( - Buffer::::from_byte_buffer(array.byte_buffer().clone()), + PrimitiveArray::new_unchecked_from_handle( + array.buffer_handle().clone(), + array.ptype(), array.validity().clone().and(parent.validity().clone()), ) } diff --git a/vortex-array/src/arrays/primitive/vtable/visitor.rs b/vortex-array/src/arrays/primitive/vtable/visitor.rs index eaad838dd1e..65910a17658 100644 --- a/vortex-array/src/arrays/primitive/vtable/visitor.rs +++ b/vortex-array/src/arrays/primitive/vtable/visitor.rs @@ -10,7 +10,7 @@ use crate::vtable::VisitorVTable; impl VisitorVTable for PrimitiveVTable { fn visit_buffers(array: &PrimitiveArray, visitor: &mut dyn ArrayBufferVisitor) { - visitor.visit_buffer(array.byte_buffer()); + visitor.visit_buffer(&array.buffer_handle().to_host()); } fn visit_children(array: &PrimitiveArray, visitor: &mut dyn ArrayChildVisitor) { diff --git a/vortex-array/src/arrays/varbin/vtable/mod.rs b/vortex-array/src/arrays/varbin/vtable/mod.rs index 
5e15bccabd4..f679569760f 100644 --- a/vortex-array/src/arrays/varbin/vtable/mod.rs +++ b/vortex-array/src/arrays/varbin/vtable/mod.rs @@ -103,7 +103,7 @@ impl VTable for VarBinVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let bytes = buffers[0].clone().try_to_bytes()?; + let bytes = buffers[0].clone().try_to_host()?; VarBinArray::try_new(offsets, bytes, dtype.clone(), validity) } diff --git a/vortex-array/src/arrays/varbinview/vtable/mod.rs b/vortex-array/src/arrays/varbinview/vtable/mod.rs index 701e891d315..bc5f086d38a 100644 --- a/vortex-array/src/arrays/varbinview/vtable/mod.rs +++ b/vortex-array/src/arrays/varbinview/vtable/mod.rs @@ -83,7 +83,7 @@ impl VTable for VarBinViewVTable { } let mut buffers: Vec = buffers .iter() - .map(|b| b.clone().try_to_bytes()) + .map(|b| b.clone().try_to_host()) .collect::>>()?; let views = buffers.pop().vortex_expect("buffers non-empty"); diff --git a/vortex-array/src/arrow/executor/byte.rs b/vortex-array/src/arrow/executor/byte.rs index 0908deb7868..bfa0d102b1d 100644 --- a/vortex-array/src/arrow/executor/byte.rs +++ b/vortex-array/src/arrow/executor/byte.rs @@ -55,7 +55,7 @@ where .cast(DType::Primitive(T::Offset::PTYPE, Nullability::NonNullable))? .execute::(ctx)? .into_primitive() - .buffer::() + .to_buffer::() .into_arrow_offset_buffer(); let data = array.bytes().clone().into_arrow_buffer(); diff --git a/vortex-array/src/arrow/executor/list.rs b/vortex-array/src/arrow/executor/list.rs index c8e5493484c..dc20a25878d 100644 --- a/vortex-array/src/arrow/executor/list.rs +++ b/vortex-array/src/arrow/executor/list.rs @@ -104,7 +104,7 @@ fn list_to_list( .cast(DType::Primitive(O::PTYPE, Nullability::NonNullable))? .execute::(ctx)? .into_primitive() - .buffer::() + .to_buffer::() .into_arrow_offset_buffer(); let elements = array @@ -149,7 +149,7 @@ fn list_view_zctl( .cast(DType::Primitive(O::PTYPE, Nullability::NonNullable))? .execute::(ctx)? 
.into_primitive() - .buffer::(); + .to_buffer::(); // List arrays need one extra element in the offsets buffer to signify the end of the last list. // If the offsets original came from a list, chances are there is already capacity for this! @@ -194,12 +194,12 @@ fn list_view_to_list( .cast(DType::Primitive(O::PTYPE, Nullability::NonNullable))? .execute::(ctx)? .into_primitive() - .buffer::(); + .to_buffer::(); let sizes = sizes .cast(DType::Primitive(O::PTYPE, Nullability::NonNullable))? .execute::(ctx)? .into_primitive() - .buffer::(); + .to_buffer::(); // We create a new offsets buffer for the final list array. // And we also create an `indices` buffer for taking the elements. diff --git a/vortex-array/src/arrow/executor/list_view.rs b/vortex-array/src/arrow/executor/list_view.rs index de171ad1899..b95af1fa5d0 100644 --- a/vortex-array/src/arrow/executor/list_view.rs +++ b/vortex-array/src/arrow/executor/list_view.rs @@ -75,13 +75,13 @@ fn list_view_to_list_view( .cast(DType::Primitive(O::PTYPE, NonNullable))? .execute::(ctx)? .into_primitive() - .buffer::() + .to_buffer::() .into_arrow_scalar_buffer(); let sizes = sizes .cast(DType::Primitive(O::PTYPE, NonNullable))? .execute::(ctx)? 
.into_primitive() - .buffer::() + .to_buffer::() .into_arrow_scalar_buffer(); let null_buffer = to_arrow_null_buffer(&validity, offsets.len(), ctx)?; diff --git a/vortex-array/src/arrow/executor/temporal.rs b/vortex-array/src/arrow/executor/temporal.rs index 3c621c47a37..4677a62f48a 100644 --- a/vortex-array/src/arrow/executor/temporal.rs +++ b/vortex-array/src/arrow/executor/temporal.rs @@ -136,7 +136,7 @@ where ); let validity = primitive.validity_mask(); - let buffer = primitive.buffer::(); + let buffer = primitive.to_buffer::(); let values = buffer.into_arrow_scalar_buffer(); let nulls = to_null_buffer(validity); diff --git a/vortex-array/src/buffer.rs b/vortex-array/src/buffer.rs index 370a2501eb9..313a4359805 100644 --- a/vortex-array/src/buffer.rs +++ b/vortex-array/src/buffer.rs @@ -3,18 +3,35 @@ use std::fmt::Debug; use std::hash::Hash; +use std::hash::Hasher; +use std::ops::Range; use std::sync::Arc; use vortex_buffer::ByteBuffer; +use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_utils::dyn_traits::DynEq; use vortex_utils::dyn_traits::DynHash; -/// A buffer can be either on the CPU or on an attached device (e.g. GPU). -/// The Device implementation will come later. +use crate::ArrayEq; +use crate::ArrayHash; +use crate::Precision; + +/// A handle to a buffer allocation. +/// +/// There are two kinds of buffer allocations supported: +/// +/// * **host** allocations, which were allocated by the global allocator and reside in main memory +/// * **device** allocations, which are remote to the CPU and live on a GPU or other external +/// device. +/// +/// A device allocation can be copied to the host, yielding a new [`ByteBuffer`] containing the +/// copied data. Copying can fail at runtime; error recovery is system-dependent. +#[derive(Debug, Clone)] +pub struct BufferHandle(Inner); + #[derive(Debug, Clone)] -pub enum BufferHandle { +enum Inner { /// On the host/cpu. 
Host(ByteBuffer), /// On the device. @@ -32,11 +49,19 @@ pub trait DeviceBuffer: 'static + Send + Sync + Debug + DynEq + DynHash { } /// Attempts to copy the device buffer to a host ByteBuffer. - fn to_host(self: Arc) -> VortexResult; + /// + /// # Errors + /// + /// This operation may fail, depending on the device implementation and the underlying hardware. + fn copy_to_host(&self) -> VortexResult; + + /// Create a new buffer that references a subrange of this buffer at the given + /// slice indices. + fn slice(&self, range: Range) -> Arc; } impl Hash for dyn DeviceBuffer { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.dyn_hash(state); } } @@ -49,28 +74,199 @@ impl PartialEq for dyn DeviceBuffer { impl Eq for dyn DeviceBuffer {} impl BufferHandle { - /// Fetches the cpu buffer and fails otherwise. - pub fn bytes(&self) -> &ByteBuffer { - match self { - BufferHandle::Host(b) => b, - BufferHandle::Device(_) => todo!(), + /// Create a new handle to a host [`ByteBuffer`]. + pub fn new_host(byte_buffer: ByteBuffer) -> Self { + BufferHandle(Inner::Host(byte_buffer)) + } + + /// Create a new handle to a memory allocation that exists on an external device. + /// + /// Allocations on external devices are not cheaply accessible from the CPU and must be copied + /// into new memory when we read them. + pub fn new_device(device: Arc) -> Self { + BufferHandle(Inner::Device(device)) + } +} + +impl BufferHandle { + /// Gets the size of the buffer, in bytes. + pub fn len(&self) -> usize { + match &self.0 { + Inner::Host(bytes) => bytes.len(), + Inner::Device(device) => device.len(), + } + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Creates a new handle to a subrange of memory at the given bind range. 
+ /// + /// + /// # Example + /// + /// ``` + /// # use vortex_array::buffer::BufferHandle; + /// # use vortex_buffer::buffer; + /// let handle1 = BufferHandle::new_host(buffer![1u8,2,3,4]); + /// let handle2 = handle1.slice(1..4); + /// assert_eq!(handle2.unwrap_host(), buffer![2u8,3,4]); + /// ``` + pub fn slice(&self, range: Range) -> Self { + match &self.0 { + Inner::Host(host) => BufferHandle::new_host(host.slice(range)), + Inner::Device(device) => BufferHandle::new_device(device.slice(range)), + } + } + + /// Reinterpret the pointee as a buffer of `T` and slice the provided element range. + /// + /// # Example + /// + /// ``` + /// # use vortex_array::buffer::BufferHandle; + /// # use vortex_buffer::{buffer, Buffer}; + /// let values = buffer![1u32, 2u32, 3u32, 4u32]; + /// let handle = BufferHandle::new_host(values.into_byte_buffer()); + /// let sliced = handle.slice_typed::(1..4); + /// let result = Buffer::::from_byte_buffer(sliced.to_host()); + /// assert_eq!(result, buffer![2, 3, 4]); + /// ``` + pub fn slice_typed(&self, range: Range) -> Self { + let start = range.start * size_of::(); + let end = range.end * size_of::(); + + self.slice(start..end) + } + + #[allow(clippy::panic)] + /// Unwraps the handle as host memory. + /// + /// # Panics + /// + /// This will panic if the handle points to device memory. + pub fn unwrap_host(self) -> ByteBuffer { + match self.0 { + Inner::Host(b) => b, + Inner::Device(_) => panic!("unwrap_host called for Device allocation"), } } - /// Fetches the cpu buffer and fails otherwise. - pub fn into_bytes(self) -> ByteBuffer { - match self { - BufferHandle::Host(b) => b, - BufferHandle::Device(_) => todo!(), + #[allow(clippy::panic)] + /// Unwraps the handle as device memory. + /// + /// # Panics + /// + /// This will panic if the handle points to host memory. 
+ pub fn unwrap_device(self) -> Arc { + match self.0 { + Inner::Device(b) => b, + Inner::Host(_) => panic!("unwrap_device called for Host allocation"), } } - /// Attempts to convert this handle into a CPU ByteBuffer. - /// Returns an error if the buffer is on a device. - pub fn try_to_bytes(self) -> VortexResult { - match self { - BufferHandle::Host(b) => Ok(b), - BufferHandle::Device(_) => vortex_bail!("cannot move device_buffer to buffer"), + /// Downcast this handle as a handle to a host-resident buffer, or `None`. + pub fn as_host_opt(&self) -> Option<&ByteBuffer> { + match &self.0 { + Inner::Host(buffer) => Some(buffer), + Inner::Device(_) => None, + } + } + + /// Downcast this handle as a handle to a device buffer, or `None`. + pub fn as_device_opt(&self) -> Option<&Arc> { + match &self.0 { + Inner::Host(_) => None, + Inner::Device(device) => Some(device), + } + } + + /// Returns a host-resident copy of the data in the buffer. + /// + /// If the data was already host-resident, this is trivial. + /// + /// If the data was device-resident, data will be copied from the device to a new allocation + /// on the host. + /// + /// # Panics + /// + /// This function will never panic if the data is already host-resident. + /// + /// For a device-resident handle, any errors triggered by the copying from device to host will + /// result in a panic. + /// + /// See also: [`try_to_host`][Self::try_to_host]. + pub fn to_host(&self) -> ByteBuffer { + self.try_to_host() + .vortex_expect("to_host: copy from device to host failed") + } + + /// Returns a host-resident copy of the data behind the handle, consuming the handle. + /// + /// If the data was already host-resident, this completes trivially. + /// + /// See also [`to_host`][Self::to_host]. + /// + /// # Panics + /// + /// See the panic documentation on [`to_host`][Self::to_host]. 
+ pub fn into_host(self) -> ByteBuffer { + self.try_into_host() + .vortex_expect("into_host: copy from device to host failed") + } + + /// Attempts to load this buffer into a host-resident allocation. + /// + /// If the allocation is already host-resident, this trivially completes with success. + /// + /// If it is a device allocation, then this issues an operation that attempts to copy the data + /// from the device into a host-resident buffer, and returns a handle to that buffer. + pub fn try_to_host(&self) -> VortexResult { + match &self.0 { + Inner::Host(b) => Ok(b.clone()), + Inner::Device(device) => device.copy_to_host(), + } + } + + /// Attempts to load this buffer into a host-resident allocation, consuming the handle. + /// + /// See also [`try_to_host`][Self::try_to_host]. + pub fn try_into_host(self) -> VortexResult { + match self.0 { + Inner::Host(b) => Ok(b), + Inner::Device(device) => device.copy_to_host(), + } + } +} + +impl ArrayHash for BufferHandle { + // TODO(aduffy): implement for array hash + fn array_hash(&self, state: &mut H, precision: Precision) { + match &self.0 { + Inner::Host(host) => host.array_hash(state, precision), + Inner::Device(dev) => match precision { + Precision::Ptr => { + Arc::as_ptr(dev).hash(state); + } + Precision::Value => { + dev.hash(state); + } + }, + } + } +} + +impl ArrayEq for BufferHandle { + fn array_eq(&self, other: &Self, precision: Precision) -> bool { + match (&self.0, &other.0) { + (Inner::Host(b), Inner::Host(b2)) => b.array_eq(b2, precision), + (Inner::Device(b), Inner::Device(b2)) => match precision { + Precision::Ptr => Arc::ptr_eq(b, b2), + Precision::Value => b.eq(b2), + }, + _ => false, } } } diff --git a/vortex-array/src/canonical_to_vector.rs b/vortex-array/src/canonical_to_vector.rs index 4a0c36a15ea..c6cf44da215 100644 --- a/vortex-array/src/canonical_to_vector.rs +++ b/vortex-array/src/canonical_to_vector.rs @@ -54,8 +54,8 @@ impl Canonical { let ptype = a.ptype(); let validity = a.validity_mask(); 
match_each_native_ptype!(ptype, |T| { - let buffer = a.as_slice::(); - Vector::Primitive(PVector::::new(buffer.to_vec().into(), validity).into()) + let buffer = a.to_buffer::(); + Vector::Primitive(PVector::::new(buffer, validity).into()) }) } Canonical::Decimal(a) => { @@ -127,14 +127,10 @@ impl Canonical { match_each_native_ptype!(offsets_ptype, |O| { match_each_native_ptype!(sizes_ptype, |S| { - let offsets_vec = PVector::::new( - offsets.as_slice::().to_vec().into(), - offsets.validity_mask(), - ); - let sizes_vec = PVector::::new( - sizes.as_slice::().to_vec().into(), - sizes.validity_mask(), - ); + let offsets_vec = + PVector::::new(offsets.to_buffer::(), offsets.validity_mask()); + let sizes_vec = + PVector::::new(sizes.to_buffer::(), sizes.validity_mask()); Vector::List(unsafe { ListViewVector::new_unchecked( Arc::new(elements_vector), diff --git a/vortex-array/src/compute/conformance/take.rs b/vortex-array/src/compute/conformance/take.rs index dd82e9ccc70..886cbffb77c 100644 --- a/vortex-array/src/compute/conformance/take.rs +++ b/vortex-array/src/compute/conformance/take.rs @@ -64,7 +64,10 @@ fn test_take_all(array: &dyn Array) { // Verify elements match match (&array.to_canonical(), &result.to_canonical()) { (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => { - assert_eq!(orig_prim.byte_buffer(), result_prim.byte_buffer()); + assert_eq!( + orig_prim.buffer_handle().to_host(), + result_prim.buffer_handle().to_host() + ); } _ => { // For non-primitive types, check scalar values diff --git a/vortex-array/src/hash.rs b/vortex-array/src/hash.rs index b88a92212d7..2ac27bd8d05 100644 --- a/vortex-array/src/hash.rs +++ b/vortex-array/src/hash.rs @@ -4,7 +4,6 @@ use std::any::Any; use std::hash::Hash; use std::hash::Hasher; -use std::sync::Arc; use vortex_buffer::BitBuffer; use vortex_buffer::Buffer; @@ -12,7 +11,6 @@ use vortex_mask::Mask; use crate::Array; use crate::ArrayRef; -use crate::buffer::BufferHandle; use crate::patches::Patches; 
use crate::validity::Validity; @@ -255,25 +253,3 @@ impl ArrayEq for Patches { && self.values().array_eq(other.values(), precision) } } - -impl ArrayHash for BufferHandle { - fn array_hash(&self, state: &mut H, precision: Precision) { - match self { - BufferHandle::Host(b) => b.array_hash(state, precision), - BufferHandle::Device(_) => (), - } - } -} - -impl ArrayEq for BufferHandle { - fn array_eq(&self, other: &Self, precision: Precision) -> bool { - match (self, other) { - (BufferHandle::Host(b), BufferHandle::Host(b2)) => b.array_eq(b2, precision), - (BufferHandle::Device(b), BufferHandle::Device(b2)) => match precision { - Precision::Ptr => Arc::ptr_eq(b, b2), - Precision::Value => b.eq(b2), - }, - _ => false, - } - } -} diff --git a/vortex-array/src/serde.rs b/vortex-array/src/serde.rs index 854378385ea..77e9554dfe5 100644 --- a/vortex-array/src/serde.rs +++ b/vortex-array/src/serde.rs @@ -465,7 +465,7 @@ impl ArrayParts { segment: BufferHandle, ) -> VortexResult { // TODO: this can also work with device buffers. - let segment = segment.try_to_bytes()?; + let segment = segment.try_to_host()?; // We align each buffer individually, so we remove alignment requirements on the buffer. let segment = segment.aligned(Alignment::none()); @@ -494,7 +494,7 @@ impl ArrayParts { .aligned(Alignment::from_exponent(fb_buf.alignment_exponent())); offset += buffer_len; - BufferHandle::Host(buffer) + BufferHandle::new_host(buffer) }) .collect(); @@ -565,7 +565,7 @@ impl TryFrom for ArrayParts { .aligned(Alignment::from_exponent(fb_buffer.alignment_exponent())); offset += buffer_len; - BufferHandle::Host(buffer) + BufferHandle::new_host(buffer) }) .collect(); @@ -581,6 +581,6 @@ impl TryFrom for ArrayParts { type Error = VortexError; fn try_from(value: BufferHandle) -> Result { - Self::try_from(value.try_to_bytes()?) + Self::try_from(value.try_to_host()?) 
} } diff --git a/vortex-btrblocks/src/float/stats.rs b/vortex-btrblocks/src/float/stats.rs index 312bc2570e2..af46197bfd3 100644 --- a/vortex-btrblocks/src/float/stats.rs +++ b/vortex-btrblocks/src/float/stats.rs @@ -155,7 +155,7 @@ where let head_idx = validity .first() .vortex_expect("All null masks have been handled before"); - let buff = array.buffer::(); + let buff = array.to_buffer::(); let mut prev = buff[head_idx]; let first_valid_buff = buff.slice(head_idx..array.len()); diff --git a/vortex-btrblocks/src/integer/stats.rs b/vortex-btrblocks/src/integer/stats.rs index c8365a8cef8..0e3aa3d7b68 100644 --- a/vortex-btrblocks/src/integer/stats.rs +++ b/vortex-btrblocks/src/integer/stats.rs @@ -222,7 +222,7 @@ where let head_idx = validity .first() .vortex_expect("All null masks have been handled before"); - let buffer = array.buffer::(); + let buffer = array.to_buffer::(); let head = buffer[head_idx]; let mut loop_state = LoopState { diff --git a/vortex-buffer/src/buffer_mut.rs b/vortex-buffer/src/buffer_mut.rs index 35f785f1e2b..51e6ca2cbb4 100644 --- a/vortex-buffer/src/buffer_mut.rs +++ b/vortex-buffer/src/buffer_mut.rs @@ -436,10 +436,19 @@ impl BufferMut { buf } - /// Return a `BufferMut` with the given alignment. Where possible, this will be zero-copy. + /// Return a `BufferMut` with the same data as this one with the given alignment. + /// + /// If the data is already properly aligned, this is a metadata-only operation. + /// + /// If the data is not aligned, we copy it into a new allocation. 
pub fn aligned(self, alignment: Alignment) -> Self { if self.as_ptr().align_offset(*alignment) == 0 { - self + Self { + bytes: self.bytes, + length: self.length, + alignment, + _marker: std::marker::PhantomData, + } } else { Self::copy_from_aligned(self, alignment) } diff --git a/vortex-compute/src/take/slice/avx2.rs b/vortex-compute/src/take/slice/avx2.rs index 6826acc37a8..d660a84ade6 100644 --- a/vortex-compute/src/take/slice/avx2.rs +++ b/vortex-compute/src/take/slice/avx2.rs @@ -445,6 +445,12 @@ where // SAFETY: all elements have been initialized. unsafe { buffer.set_len(indices_len) }; + // Reset the buffer alignment to the Value type + // NOTE: if we don't do this, we pass back a Buffer which is over-aligned to the + // SIMD register width. The caller expects that this memory should be aligned to the value + // type so that we can slice it at value boundaries. + buffer = buffer.aligned(Alignment::of::()); + buffer.freeze() } diff --git a/vortex-compute/src/take/slice/portable.rs b/vortex-compute/src/take/slice/portable.rs index 38e5f804cd3..bc24e488870 100644 --- a/vortex-compute/src/take/slice/portable.rs +++ b/vortex-compute/src/take/slice/portable.rs @@ -124,6 +124,11 @@ where buffer.set_len(indices_len); } + // NOTE: if we don't do this, we pass back a Buffer which is over-aligned to the + // SIMD register width. The caller expects that this memory should be aligned to the value + // type so that we can slice it at value boundaries. 
+ buffer = buffer.aligned(Alignment::of::()); + buffer.freeze() } diff --git a/vortex-duckdb/src/exporter/primitive.rs b/vortex-duckdb/src/exporter/primitive.rs index e3132e0d8a1..2beb61cdea7 100644 --- a/vortex-duckdb/src/exporter/primitive.rs +++ b/vortex-duckdb/src/exporter/primitive.rs @@ -22,7 +22,7 @@ struct PrimitiveExporter { pub fn new_exporter(array: PrimitiveArray) -> VortexResult> { match_each_native_ptype!(array.ptype(), |T| { - let buffer = array.buffer::(); + let buffer = array.to_buffer::(); let prim = Box::new(PrimitiveExporter { len: buffer.len(), start: buffer.as_ptr(), diff --git a/vortex-file/src/segments/source.rs b/vortex-file/src/segments/source.rs index 58517505ce1..32cd58f77da 100644 --- a/vortex-file/src/segments/source.rs +++ b/vortex-file/src/segments/source.rs @@ -41,7 +41,7 @@ impl SegmentSource for FileSegmentSource { maybe_fut .ok_or_else(|| vortex_err!("Missing segment: {}", id))? .await - .map(BufferHandle::Host) + .map(BufferHandle::new_host) } .boxed() } diff --git a/vortex-layout/src/segments/cache.rs b/vortex-layout/src/segments/cache.rs index 155ed038ec7..9bcb4db32a3 100644 --- a/vortex-layout/src/segments/cache.rs +++ b/vortex-layout/src/segments/cache.rs @@ -130,11 +130,11 @@ impl SegmentSource for SegmentCacheSourceAdapter { async move { if let Ok(Some(segment)) = cache.get(id).await { tracing::debug!("Resolved segment {} from cache", id); - return Ok(BufferHandle::Host(segment)); + return Ok(BufferHandle::new_host(segment)); } let result = delegate.await?; // Cache only CPU buffers; device buffers are not cached. 
- if let BufferHandle::Host(ref buffer) = result + if let Some(buffer) = result.as_host_opt() && let Err(e) = cache.put(id, buffer.clone()).await { tracing::warn!("Failed to store segment {} in cache: {}", id, e); diff --git a/vortex-layout/src/segments/shared.rs b/vortex-layout/src/segments/shared.rs index f0a321e035f..c794daf608e 100644 --- a/vortex-layout/src/segments/shared.rs +++ b/vortex-layout/src/segments/shared.rs @@ -111,8 +111,8 @@ mod tests { // Both futures should resolve to the same data let (result1, result2) = futures::join!(future1, future2); - assert_eq!(*result1.unwrap().bytes(), data); - assert_eq!(*result2.unwrap().bytes(), data); + assert_eq!(result1.unwrap().unwrap_host(), data); + assert_eq!(result2.unwrap().unwrap_host(), data); // The inner source should have been called only once assert_eq!(source.request_count.load(Ordering::Relaxed), 1); @@ -142,7 +142,7 @@ mod tests { // A new request should still work correctly let result = shared_source.request(id).await; - assert_eq!(*result.unwrap().bytes(), data); + assert_eq!(result.unwrap().unwrap_host(), data); // Should have made 2 requests since the first was dropped before completion assert_eq!(source.request_count.load(Ordering::Relaxed), 2); diff --git a/vortex-layout/src/segments/test.rs b/vortex-layout/src/segments/test.rs index e22deec994d..d880d15cc1a 100644 --- a/vortex-layout/src/segments/test.rs +++ b/vortex-layout/src/segments/test.rs @@ -30,7 +30,7 @@ impl SegmentSource for TestSegments { let buffer = self.segments.lock().get(*id as usize).cloned(); async move { buffer - .map(BufferHandle::Host) + .map(BufferHandle::new_host) .ok_or_else(|| vortex_err!("Segment not found")) } .boxed() diff --git a/vortex-python/src/serde/parts.rs b/vortex-python/src/serde/parts.rs index b5a27f8a872..ab856ecbf3a 100644 --- a/vortex-python/src/serde/parts.rs +++ b/vortex-python/src/serde/parts.rs @@ -81,7 +81,7 @@ impl PyArrayParts { let mut buffers = Vec::with_capacity(slf.nbuffers()); for buffer 
in (0..slf.nbuffers()).map(|i| slf.buffer(i)) { - let buffer: ByteBuffer = buffer.map(|b| b.into_bytes())?; + let buffer: ByteBuffer = buffer.and_then(|b| b.try_to_host())?; let addr = buffer.as_ptr() as usize; let size = buffer.len(); diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index 1fa11ac14da..112cae0a494 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -333,7 +333,10 @@ mod test { assert_eq!(recovered_array.len(), array.len()); let recovered_primitive = recovered_array.to_primitive(); assert_eq!(recovered_primitive.validity(), array.validity()); - assert_eq!(recovered_primitive.buffer::(), array.buffer::()); + assert_eq!( + recovered_primitive.to_buffer::(), + array.to_buffer::() + ); std::fs::remove_file(&path)?;