97 changes: 97 additions & 0 deletions datafusion-examples/examples/struct_cast_reorder.rs
@@ -0,0 +1,97 @@
use arrow::array::{Int64Array, RecordBatch, StructArray};
use arrow::datatypes::{DataType, Field, Fields, Schema};
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::{cast, col};
use std::sync::Arc;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let ctx = SessionContext::new();

// Source: struct with fields [b=3, a=4]
let source_fields = Fields::from(vec![
Field::new("b", DataType::Int64, false),
Field::new("a", DataType::Int64, false),
]);

let source_struct = StructArray::new(
source_fields.clone(),
vec![
Arc::new(Int64Array::from(vec![3i64])), // b = 3
Arc::new(Int64Array::from(vec![4i64])), // a = 4
],
None,
);

let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new(
"s",
DataType::Struct(source_fields),
false,
)])),
vec![Arc::new(source_struct)],
)?;

let table = datafusion::datasource::memory::MemTable::try_new(
batch.schema(),
vec![vec![batch]],
)?;

ctx.register_table("t", Arc::new(table))?;

// Validate source data: should be b=3, a=4
let source_data = ctx.table("t").await?.collect().await?;
use arrow::array::AsArray;
let src_struct = source_data[0].column(0).as_struct();
let src_a = src_struct
.column_by_name("a")
.unwrap()
.as_primitive::<arrow::array::types::Int64Type>()
.value(0);
let src_b = src_struct
.column_by_name("b")
.unwrap()
.as_primitive::<arrow::array::types::Int64Type>()
.value(0);
assert_eq!(src_a, 4, "Source field 'a' should be 4");
assert_eq!(src_b, 3, "Source field 'b' should be 3");
println!("✓ Source validation passed: b={}, a={}", src_b, src_a);

// Target: reorder fields to [a, b]
let target_type = DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Int64, false),
]));

// Execute cast
let result = ctx
.table("t")
.await?
.select(vec![cast(col("s"), target_type)])?
.collect()
.await?;

// Validate result
let res_struct = result[0].column(0).as_struct();
let res_a = res_struct
.column_by_name("a")
.unwrap()
.as_primitive::<arrow::array::types::Int64Type>()
.value(0);
let res_b = res_struct
.column_by_name("b")
.unwrap()
.as_primitive::<arrow::array::types::Int64Type>()
.value(0);

if res_a == 4 && res_b == 3 {
println!("✓ Cast result passed: a={}, b={}", res_a, res_b);
} else {
println!(
"✗ Bug: Cast maps by position, not name. Expected a=4,b=3 but got a={}, b={}",
res_a, res_b
);
}

Ok(())
}
26 changes: 26 additions & 0 deletions datafusion/common/src/nested_struct.rs
@@ -165,6 +165,32 @@ pub fn cast_column(
}
}

/// Cast a struct array to another struct type by aligning child arrays using
/// field names instead of their physical order.
///
/// This is a convenience wrapper around [`cast_struct_column`] that accepts
/// `Fields` directly instead of requiring a `Field` wrapper.
///
/// See [`cast_column`] for detailed documentation on the casting behavior.
///
/// # Arguments
/// * `array` - The source array to cast (must be a struct array)
/// * `target_fields` - The target struct field definitions
/// * `cast_options` - Options controlling cast behavior (strictness, formatting)
///
/// # Returns
/// A `Result<ArrayRef>` containing the cast struct array
///
/// # Errors
/// Returns an error if the source is not a struct array or if field casting fails
pub fn cast_struct_array_by_name(
array: &ArrayRef,
target_fields: &arrow::datatypes::Fields,
cast_options: &CastOptions,
) -> Result<ArrayRef> {
cast_struct_column(array, target_fields.as_ref(), cast_options)
}

/// Validates compatibility between source and target struct fields for casting operations.
///
/// This function implements comprehensive struct compatibility checking by examining:
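For context, a minimal usage sketch of the new helper (illustrative, not part of this diff): it assumes `arrow` and `datafusion_common` as dependencies, and the function name `reorder_by_name` is a placeholder.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, StructArray};
use arrow::compute::CastOptions;
use arrow::datatypes::{DataType, Field, Fields};
use datafusion_common::nested_struct::cast_struct_array_by_name;
use datafusion_common::Result;

// Build a struct whose children are physically stored as [b, a], then cast it
// to a target layout of [a, b]; children are realigned by name, not position.
fn reorder_by_name() -> Result<ArrayRef> {
    let source_fields = Fields::from(vec![
        Field::new("b", DataType::Int64, false),
        Field::new("a", DataType::Int64, false),
    ]);
    let source: ArrayRef = Arc::new(StructArray::new(
        source_fields,
        vec![
            Arc::new(Int64Array::from(vec![3i64])), // b = 3
            Arc::new(Int64Array::from(vec![4i64])), // a = 4
        ],
        None,
    ));

    let target_fields = Fields::from(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ]);
    // The result keeps a = 4 and b = 3 even though the target order differs.
    cast_struct_array_by_name(&source, &target_fields, &CastOptions::default())
}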
14 changes: 13 additions & 1 deletion datafusion/common/src/scalar/mod.rs
@@ -3704,7 +3704,19 @@ impl ScalarValue {
}

let scalar_array = self.to_array()?;
let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?;

// Use name-based struct casting for struct types
let cast_arr = match (scalar_array.data_type(), target_type) {
(DataType::Struct(_), DataType::Struct(target_fields)) => {
crate::nested_struct::cast_struct_array_by_name(
&scalar_array,
target_fields,
cast_options,
)?
}
_ => cast_with_options(&scalar_array, target_type, cast_options)?,
};

ScalarValue::try_from_array(&cast_arr, 0)
}

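A minimal sketch of how the `ScalarValue` change behaves (illustrative rather than taken from this diff): it assumes the `ScalarValue::Struct` variant wraps an `Arc<StructArray>`, as in recent DataFusion releases, and `scalar_struct_cast` is a placeholder name.

use std::sync::Arc;

use arrow::array::{Int32Array, StructArray};
use arrow::compute::CastOptions;
use arrow::datatypes::{DataType, Field, Fields};
use datafusion_common::{Result, ScalarValue};

// Cast a scalar struct {b: 3, a: 4} (stored as [b, a]) to an [a, b] target;
// with this change the children are matched by name, yielding {a: 4, b: 3}.
fn scalar_struct_cast() -> Result<ScalarValue> {
    let scalar = ScalarValue::Struct(Arc::new(StructArray::new(
        Fields::from(vec![
            Field::new("b", DataType::Int32, true),
            Field::new("a", DataType::Int32, true),
        ]),
        vec![
            Arc::new(Int32Array::from(vec![Some(3)])),
            Arc::new(Int32Array::from(vec![Some(4)])),
        ],
        None,
    )));

    let target = DataType::Struct(Fields::from(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));
    scalar.cast_to_with_options(&target, &CastOptions::default())
}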
167 changes: 156 additions & 11 deletions datafusion/expr-common/src/columnar_value.rs
@@ -275,6 +275,23 @@ impl ColumnarValue {
}

/// Casts this [ColumnarValue] to the specified `DataType`
///
/// # Struct Casting Behavior
///
/// When casting struct types, fields are matched **by name** rather than position:
/// - Source fields are matched to target fields using case-sensitive name comparison
/// - Fields are reordered to match the target schema
/// - Missing target fields are filled with null arrays
/// - Extra source fields are ignored
///
/// # Example
/// ```text
/// Source: {"b": 3, "a": 4} (schema: {b: Int32, a: Int32})
/// Target: {"a": Int32, "b": Int32}
/// Result: {"a": 4, "b": 3} (values matched by field name)
/// ```
///
/// For non-struct types, uses Arrow's standard positional casting.
pub fn cast_to(
&self,
cast_type: &DataType,
@@ -283,16 +300,44 @@
let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS);
match self {
ColumnarValue::Array(array) => {
ensure_date_array_timestamp_bounds(array, cast_type)?;
Ok(ColumnarValue::Array(kernels::cast::cast_with_options(
array,
cast_type,
&cast_options,
)?))
let casted = cast_array_by_name(array, cast_type, &cast_options)?;
Ok(ColumnarValue::Array(casted))
}
ColumnarValue::Scalar(scalar) => {
// For scalars, use ScalarValue's cast which now supports name-based struct casting
Ok(ColumnarValue::Scalar(
scalar.cast_to_with_options(cast_type, &cast_options)?,
))
}
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(
scalar.cast_to_with_options(cast_type, &cast_options)?,
)),
}
}
}

fn cast_array_by_name(
array: &ArrayRef,
cast_type: &DataType,
cast_options: &CastOptions<'static>,
) -> Result<ArrayRef> {
// If types are already equal, no cast needed
if array.data_type() == cast_type {
return Ok(Arc::clone(array));
}

match (array.data_type(), cast_type) {
(DataType::Struct(_source_fields), DataType::Struct(target_fields)) => {
datafusion_common::nested_struct::cast_struct_array_by_name(
array,
target_fields,
cast_options,
)
}
_ => {
ensure_date_array_timestamp_bounds(array, cast_type)?;
Ok(kernels::cast::cast_with_options(
array,
cast_type,
cast_options,
)?)
}
}
}
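The doc comment above also states that extra source fields are ignored; a small sketch of that case (illustrative only), assuming the same imports as the tests below plus `Field` and `Int32Array`:

fn extra_source_field_is_dropped() -> Result<()> {
    // Source struct has fields [a, extra]; the target only declares `a`.
    let source = StructArray::new(
        Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("extra", DataType::Int32, true),
        ]),
        vec![
            Arc::new(Int32Array::from(vec![Some(1)])),
            Arc::new(Int32Array::from(vec![Some(9)])),
        ],
        None,
    );
    let target = DataType::Struct(Fields::from(vec![Field::new(
        "a",
        DataType::Int32,
        true,
    )]));

    // The `extra` child is dropped from the cast result.
    let casted = ColumnarValue::Array(Arc::new(source)).cast_to(&target, None)?;
    let ColumnarValue::Array(arr) = casted else {
        unreachable!("array input yields array output")
    };
    let result = arr
        .as_any()
        .downcast_ref::<StructArray>()
        .expect("expected StructArray");
    assert_eq!(result.num_columns(), 1);
    Ok(())
}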
@@ -378,8 +423,8 @@ impl fmt::Display for ColumnarValue {
mod tests {
use super::*;
use arrow::{
array::{Date64Array, Int32Array},
datatypes::TimeUnit,
array::{Date64Array, Int32Array, StructArray},
datatypes::{Fields, TimeUnit},
};

#[test]
@@ -553,6 +598,106 @@ mod tests {
);
}

#[test]
fn cast_struct_by_field_name() {
use arrow::datatypes::Field;

let source_fields = Fields::from(vec![
Field::new("b", DataType::Int32, true),
Field::new("a", DataType::Int32, true),
]);

let target_fields = Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]);

let struct_array = StructArray::new(
source_fields,
vec![
Arc::new(Int32Array::from(vec![Some(3)])),
Arc::new(Int32Array::from(vec![Some(4)])),
],
None,
);

let value = ColumnarValue::Array(Arc::new(struct_array));
let casted = value
.cast_to(&DataType::Struct(target_fields.clone()), None)
.expect("struct cast should succeed");

let ColumnarValue::Array(arr) = casted else {
panic!("expected array after cast");
};

let struct_array = arr
.as_any()
.downcast_ref::<StructArray>()
.expect("expected StructArray");

let field_a = struct_array
.column_by_name("a")
.expect("expected field a in cast result");
let field_b = struct_array
.column_by_name("b")
.expect("expected field b in cast result");

assert_eq!(
field_a
.as_any()
.downcast_ref::<Int32Array>()
.expect("expected Int32 array")
.value(0),
4
);
assert_eq!(
field_b
.as_any()
.downcast_ref::<Int32Array>()
.expect("expected Int32 array")
.value(0),
3
);
}

#[test]
fn cast_struct_missing_field_inserts_nulls() {
use arrow::datatypes::Field;

let source_fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]);

let target_fields = Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]);

let struct_array = StructArray::new(
source_fields,
vec![Arc::new(Int32Array::from(vec![Some(5)]))],
None,
);

let value = ColumnarValue::Array(Arc::new(struct_array));
let casted = value
.cast_to(&DataType::Struct(target_fields.clone()), None)
.expect("struct cast should succeed");

let ColumnarValue::Array(arr) = casted else {
panic!("expected array after cast");
};

let struct_array = arr
.as_any()
.downcast_ref::<StructArray>()
.expect("expected StructArray");

let field_b = struct_array
.column_by_name("b")
.expect("expected missing field to be added");

assert!(field_b.is_null(0));
}

#[test]
fn cast_date64_array_to_timestamp_overflow() {
let overflow_value = i64::MAX / 1_000_000 + 1;