Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions arrow-arith/src/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use arrow_array::cast::AsArray;
use cast::as_primitive_array;
use chrono::{Datelike, TimeZone, Timelike, Utc};

use arrow_array::ree_map;
use arrow_array::temporal_conversions::{
MICROSECONDS, MICROSECONDS_IN_DAY, MILLISECONDS, MILLISECONDS_IN_DAY, NANOSECONDS,
NANOSECONDS_IN_DAY, SECONDS_IN_DAY, date32_to_datetime, date64_to_datetime,
Expand Down Expand Up @@ -249,15 +248,12 @@ pub fn date_part(array: &dyn Array, part: DatePart) -> Result<ArrayRef, ArrowErr
let new_array = array.with_values(values);
Ok(new_array)
}
DataType::RunEndEncoded(k, _) => match k.data_type() {
DataType::Int16 => ree_map!(array, Int16Type, |a| date_part(a, part)),
DataType::Int32 => ree_map!(array, Int32Type, |a| date_part(a, part)),
DataType::Int64 => ree_map!(array, Int64Type, |a| date_part(a, part)),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Invalid run-end type: {:?}",
k.data_type()
))),
},
DataType::RunEndEncoded(_, _) => {
let array = array.as_any_ree();
let values = date_part(array.values(), part)?;
let new_array = array.with_values(values);
Ok(new_array)
}
t => return_compute_error_with!(format!("{part} does not support"), t),
)
}
Expand Down
57 changes: 53 additions & 4 deletions arrow-array/src/array/run_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,20 @@ impl<R: RunEndIndexType> RunArray<R> {
/// assert_eq!(new_run_array.run_ends().values(), &[2, 3, 5]);
/// ```
pub fn with_values(&self, values: ArrayRef) -> Self {
assert_eq!(values.len(), self.values().len());
assert_eq!(values.len(), self.values.len());
let (run_ends_field, values_field) = match &self.data_type {
DataType::RunEndEncoded(r, v) => (r, v),
DataType::RunEndEncoded(r, v) => {
let new_v = Arc::new(Field::new(
v.name(),
values.data_type().clone(),
v.is_nullable(),
));
(r, new_v)
}
_ => unreachable!("RunArray should have type RunEndEncoded"),
};
let data_type =
DataType::RunEndEncoded(Arc::clone(run_ends_field), Arc::clone(values_field));
let data_type = DataType::RunEndEncoded(Arc::clone(run_ends_field), values_field);

Self {
data_type,
run_ends: self.run_ends.clone(),
Expand Down Expand Up @@ -781,6 +788,28 @@ where
RunArrayIter::new(self)
}
}
/// An array that can be downcast to a [`RunArray`] of any run end type and any value type.
///
/// This can be used to efficiently implement kernels for all possible run end
/// types without needing to create specialized implementations for each run end type.
///
/// A typical kernel reads the values with [`AnyRunEndArray::values`], computes a
/// new values array of the same length, and re-wraps it with
/// [`AnyRunEndArray::with_values`] so the run ends are reused unchanged.
pub trait AnyRunEndArray: Array {
/// Returns the values of this array.
fn values(&self) -> &Arc<dyn Array>;

/// Returns a new run-end encoded array with the given values, preserving the
/// existing run ends.
///
/// The implementation for [`RunArray`] forwards to [`RunArray::with_values`],
/// which asserts that `values` has the same length as the current values and
/// therefore panics on a length mismatch.
fn with_values(&self, values: ArrayRef) -> ArrayRef;
}

impl<R: RunEndIndexType> AnyRunEndArray for RunArray<R> {
fn values(&self) -> &Arc<dyn Array> {
&self.values
}

fn with_values(&self, values: ArrayRef) -> ArrayRef {
Comment thread
Rich-T-kid marked this conversation as resolved.
Comment thread
Rich-T-kid marked this conversation as resolved.
Arc::new(RunArray::<R>::with_values(self, values))
}
}

#[cfg(test)]
mod tests {
Expand All @@ -789,6 +818,7 @@ mod tests {
use rand::seq::SliceRandom;

use super::*;
use crate::Int64Array;
use crate::builder::PrimitiveRunBuilder;
use crate::cast::AsArray;
use crate::new_empty_array;
Expand Down Expand Up @@ -1055,6 +1085,25 @@ mod tests {
let expected = ArrowError::InvalidArgumentError("The run_ends array length should be the same as values array length. Run_ends array length is 3, values array length is 4".to_string());
assert_eq!(expected.to_string(), actual.err().unwrap().to_string());
}
#[test]
fn test_run_array_with_values_changes_value_type() {
    // Start from a run array whose values are strings.
    let run_ends: Int32Array = [Some(1), Some(2), Some(3)].into_iter().collect();
    let string_values = StringArray::from(vec!["foo", "bar", "baz"]);
    let original = RunArray::<Int32Type>::try_new(&run_ends, &string_values).unwrap();

    // Replace the values with an array of a different data type.
    let swapped = original.with_values(Arc::new(Int64Array::from(vec![10, 20, 30])));

    // The declared values field must follow the new values' data type.
    let DataType::RunEndEncoded(_, values_field) = swapped.data_type() else {
        panic!("expected RunEndEncoded, got {:?}", swapped.data_type());
    };
    assert_eq!(values_field.data_type(), &DataType::Int64);

    // The stored values must be the replacement array.
    let stored = swapped.values();
    assert_eq!(stored.data_type(), &DataType::Int64);
    assert_eq!(stored.len(), 3);
}

#[test]
fn test_run_array_run_ends_with_null() {
Expand Down
20 changes: 20 additions & 0 deletions arrow-array/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,14 @@ pub trait AsArray: private::Sealed {
fn as_any_dictionary(&self) -> &dyn AnyDictionaryArray {
self.as_any_dictionary_opt().expect("any dictionary array")
}

/// Downcasts this to an [`AnyRunEndArray`], returning `None` if not possible
fn as_any_ree_opt(&self) -> Option<&dyn AnyRunEndArray>;

/// Downcasts this to an [`AnyRunEndArray`], panicking if not possible
///
/// # Panics
///
/// Panics with the message `any run end array` when
/// [`Self::as_any_ree_opt`] returns `None`.
fn as_any_ree(&self) -> &dyn AnyRunEndArray {
self.as_any_ree_opt().expect("any run end array")
}
}

impl private::Sealed for dyn Array + '_ {}
Expand Down Expand Up @@ -1049,6 +1057,14 @@ impl AsArray for dyn Array + '_ {
_ => None
}
}

// Uses `downcast_run_array!` to obtain a concretely typed array and returns it
// as a `&dyn AnyRunEndArray`; the fallback arm yields `None`.
fn as_any_ree_opt(&self) -> Option<&dyn AnyRunEndArray> {
// NOTE(review): `self` is rebound to a local name before the macro call,
// presumably because the macro expects a plain identifier — confirm against
// the definition of `downcast_run_array!`.
let array = self;
downcast_run_array! {
array => Some(array),
_ => None
}
}
}

impl private::Sealed for ArrayRef {}
Expand Down Expand Up @@ -1105,6 +1121,10 @@ impl AsArray for ArrayRef {
self.as_ref().as_any_dictionary_opt()
}

// Forwards to the `as_any_ree_opt` of the value obtained from `self.as_ref()`.
fn as_any_ree_opt(&self) -> Option<&dyn AnyRunEndArray> {
self.as_ref().as_any_ree_opt()
}

// Forwards to the `as_run_opt` of the value obtained from `self.as_ref()`.
fn as_run_opt<K: RunEndIndexType>(&self) -> Option<&RunArray<K>> {
self.as_ref().as_run_opt()
}
Expand Down
27 changes: 8 additions & 19 deletions arrow-string/src/length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

//! Defines kernel for length of string arrays and binary arrays

use arrow_array::ree_map;
use arrow_array::*;
use arrow_array::{cast::AsArray, types::*};
use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer};
Expand Down Expand Up @@ -59,6 +58,10 @@ pub fn length(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
let lengths = length(d.values().as_ref())?;
return Ok(d.with_values(lengths));
}
if let Some(ree) = array.as_any_ree_opt() {
let lengths = length(ree.values())?;
return Ok(ree.with_values(lengths));
}
match array.data_type() {
DataType::List(_) => {
let list = array.as_list::<i32>();
Expand Down Expand Up @@ -117,15 +120,6 @@ pub fn length(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
list.nulls().cloned(),
)?))
}
DataType::RunEndEncoded(k, _) => match k.data_type() {
DataType::Int16 => ree_map!(array, Int16Type, length),
DataType::Int32 => ree_map!(array, Int32Type, length),
DataType::Int64 => ree_map!(array, Int64Type, length),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Invalid run-end type: {:?}",
k.data_type()
))),
},
other => Err(ArrowError::ComputeError(format!(
"length not supported for {other:?}"
))),
Expand All @@ -144,6 +138,10 @@ pub fn bit_length(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
let lengths = bit_length(d.values().as_ref())?;
return Ok(d.with_values(lengths));
}
if let Some(ree) = array.as_any_ree_opt() {
let lengths = bit_length(ree.values())?;
return Ok(ree.with_values(lengths));
}

match array.data_type() {
DataType::Utf8 => {
Expand Down Expand Up @@ -190,15 +188,6 @@ pub fn bit_length(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
array.nulls().cloned(),
)?))
}
DataType::RunEndEncoded(k, _) => match k.data_type() {
DataType::Int16 => ree_map!(array, Int16Type, bit_length),
DataType::Int32 => ree_map!(array, Int32Type, bit_length),
DataType::Int64 => ree_map!(array, Int64Type, bit_length),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Invalid run-end type: {:?}",
k.data_type()
))),
},
other => Err(ArrowError::ComputeError(format!(
"bit_length not supported for {other:?}"
))),
Expand Down
Loading