Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 111 additions & 149 deletions datafusion/physical-expr-adapter/src/schema_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use std::hash::Hash;
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
use arrow::datatypes::{DataType, FieldRef, SchemaRef};
use datafusion_common::{
DataFusionError, Result, ScalarValue, exec_err,
metadata::FieldMetadata,
Expand All @@ -34,11 +34,10 @@ use datafusion_common::{
};
use datafusion_functions::core::getfield::GetFieldFunc;
use datafusion_physical_expr::PhysicalExprSimplifier;
use datafusion_physical_expr::expressions::CastColumnExpr;
use datafusion_physical_expr::projection::{ProjectionExprs, Projector};
use datafusion_physical_expr::{
ScalarFunctionExpr,
expressions::{self, Column},
expressions::{self, CastExpr, Column},
};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use itertools::Itertools;
Expand Down Expand Up @@ -423,13 +422,12 @@ impl DefaultPhysicalExprAdapterRewriter {
)));
};

if resolved_column.index() == column.index()
&& logical_field == physical_field.as_ref()
{
return Ok(Transformed::no(expr));
}
let fields_match = logical_field == physical_field.as_ref();
if fields_match {
if resolved_column.index() == column.index() {
return Ok(Transformed::no(expr));
}
Comment on lines -426 to +429
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just to create the named variable fields_match? It otherwise seems equivalent (just changing the order of evaluation).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that’s right. This part is just a readability refactor: fields_match avoids repeating the field comparison and groups the “same field” fast path together. The behavior should be unchanged here: if both field and index match we keep the original expr, and if only the index changed we return the resolved column.


if logical_field == physical_field.as_ref() {
// If the fields match (including metadata/nullability), we can use the column as is
return Ok(Transformed::yes(Arc::new(resolved_column)));
}
Expand All @@ -439,7 +437,25 @@ impl DefaultPhysicalExprAdapterRewriter {
// TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
// since that's much cheaper to evaluate.
// See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
self.create_cast_column_expr(resolved_column, physical_field, logical_field)
validate_data_type_compatibility(
resolved_column.name(),
physical_field.data_type(),
logical_field.data_type(),
)
.map_err(|e| {
DataFusionError::Execution(format!(
"Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}",
resolved_column.name(),
physical_field.data_type(),
logical_field.data_type()
))
})?;

Ok(Transformed::yes(Arc::new(CastExpr::new_with_target_field(
Arc::new(resolved_column),
Arc::new(logical_field.clone()),
None,
))))
}

/// Resolves a logical column to the corresponding physical column and field.
Expand All @@ -465,48 +481,13 @@ impl DefaultPhysicalExprAdapterRewriter {
Column::new_with_schema(column.name(), self.physical_file_schema.as_ref())?
};

Ok(Some((
column,
Arc::new(
self.physical_file_schema
.field(physical_column_index)
.clone(),
),
)))
}

/// Validates type compatibility and creates a CastColumnExpr if needed.
///
/// Checks whether the physical field can be cast to the logical field type,
/// handling both struct and scalar types. Returns a CastColumnExpr with the
/// appropriate configuration.
fn create_cast_column_expr(
&self,
column: Column,
physical_field: FieldRef,
logical_field: &Field,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
validate_data_type_compatibility(
column.name(),
physical_field.data_type(),
logical_field.data_type(),
)
.map_err(|e|
DataFusionError::Execution(format!(
"Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}",
column.name(),
physical_field.data_type(),
logical_field.data_type()
)))?;

let cast_expr = Arc::new(CastColumnExpr::new(
Arc::new(column),
physical_field,
Arc::new(logical_field.clone()),
None,
));
let physical_field = Arc::new(
self.physical_file_schema
.field(physical_column_index)
.clone(),
);

Ok(Transformed::yes(cast_expr))
Ok(Some((column, physical_field)))
}
}

Expand Down Expand Up @@ -652,10 +633,40 @@ mod tests {
Array, BooleanArray, GenericListArray, Int32Array, Int64Array, RecordBatch,
RecordBatchOptions, StringArray, StringViewArray, StructArray,
};
use arrow::datatypes::{Fields, Schema};
use arrow::datatypes::{Field, Fields, Schema};
use datafusion_common::{assert_contains, record_batch};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{Column, Literal, col, lit};
use datafusion_physical_expr::expressions::{Column, Literal, col};

/// Downcasts the given physical expression to a [`CastExpr`], panicking when
/// the expression is any other node type. Test-only helper that keeps the
/// repetitive `as_any().downcast_ref()` boilerplate in one place.
fn assert_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> &CastExpr {
    let maybe_cast = expr.as_any().downcast_ref::<CastExpr>();
    maybe_cast.expect("Expected CastExpr")
}

/// Asserts that the input of `cast_expr` is a [`Column`] with the expected
/// `name` and `index`, panicking otherwise. Test-only helper.
fn assert_cast_column(cast_expr: &CastExpr, name: &str, index: usize) {
    let column = cast_expr
        .expr()
        .as_any()
        .downcast_ref::<Column>()
        .expect("Expected inner Column");
    assert_eq!(column.name(), name);
    assert_eq!(column.index(), index);
}

/// Builds a `(logical, physical)` schema pair in which column `a` sits at a
/// different index in each schema (logical index 0, physical index 1) and has
/// a different type (Int64 vs Int32). Used to exercise the case where a stale
/// column index must be re-resolved by name before a cast is inserted.
fn stale_index_cast_schemas() -> (SchemaRef, SchemaRef) {
    let logical_fields = vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Binary, true),
    ];
    let physical_fields = vec![
        Field::new("b", DataType::Binary, true),
        Field::new("a", DataType::Int32, false),
    ];

    (
        Arc::new(Schema::new(logical_fields)),
        Arc::new(Schema::new(physical_fields)),
    )
}

fn create_test_schema() -> (Schema, Schema) {
let physical_schema = Schema::new(vec![
Expand Down Expand Up @@ -685,7 +696,7 @@ mod tests {
let result = adapter.rewrite(column_expr).unwrap();

// Should be wrapped in a cast expression
assert!(result.as_any().downcast_ref::<CastColumnExpr>().is_some());
assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
}

#[test]
Expand All @@ -702,24 +713,19 @@ mod tests {
.unwrap();

let result = adapter.rewrite(Arc::new(Column::new("a", 0)))?;
let cast = result
.as_any()
.downcast_ref::<CastColumnExpr>()
.expect("Expected CastColumnExpr");

assert_eq!(cast.target_field().data_type(), &DataType::Int64);
assert!(!cast.target_field().is_nullable());
// Ensure the expression preserves the logical field nullability/metadata.
let return_field = result.return_field(physical_schema.as_ref())?;
assert_eq!(return_field.data_type(), &DataType::Int64);
assert!(!return_field.is_nullable());
assert_eq!(
cast.target_field()
return_field
.metadata()
.get("logical_meta")
.map(String::as_str),
Some("1")
);

// Ensure the expression reports the logical nullability regardless of input schema
assert!(!result.nullable(physical_schema.as_ref())?);

Ok(())
}

Expand Down Expand Up @@ -750,33 +756,35 @@ mod tests {
);

let result = adapter.rewrite(Arc::new(expr)).unwrap();
println!("Rewritten expression: {result}");

let expected = expressions::BinaryExpr::new(
Arc::new(CastColumnExpr::new(
Arc::new(Column::new("a", 0)),
Arc::new(Field::new("a", DataType::Int32, false)),
Arc::new(Field::new("a", DataType::Int64, false)),
None,
)),
Operator::Plus,
Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
);
let expected = Arc::new(expressions::BinaryExpr::new(
Arc::new(expected),
Operator::Or,
Arc::new(expressions::BinaryExpr::new(
lit(ScalarValue::Float64(None)), // c is missing, so it becomes null
Operator::Gt,
Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))),
)),
)) as Arc<dyn PhysicalExpr>;
let outer = result
.as_any()
.downcast_ref::<expressions::BinaryExpr>()
.expect("Expected outer BinaryExpr");
assert_eq!(*outer.op(), Operator::Or);

assert_eq!(
result.to_string(),
expected.to_string(),
"The rewritten expression did not match the expected output"
);
let left = outer
.left()
.as_any()
.downcast_ref::<expressions::BinaryExpr>()
.expect("Expected left BinaryExpr");
assert_eq!(*left.op(), Operator::Plus);

let left_cast = assert_cast_expr(left.left());
assert_eq!(left_cast.target_field().data_type(), &DataType::Int64);
assert_cast_column(left_cast, "a", 0);

let right = outer
.right()
.as_any()
.downcast_ref::<expressions::BinaryExpr>()
.expect("Expected right BinaryExpr");
assert_eq!(*right.op(), Operator::Gt);
let null_literal = right
.left()
.as_any()
.downcast_ref::<Literal>()
.expect("Expected null literal");
assert_eq!(*null_literal.value(), ScalarValue::Float64(None));
}

#[test]
Expand Down Expand Up @@ -841,17 +849,6 @@ mod tests {

let result = adapter.rewrite(column_expr).unwrap();

let physical_struct_fields: Fields = vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]
.into();
let physical_field = Arc::new(Field::new(
"data",
DataType::Struct(physical_struct_fields),
false,
));

let logical_struct_fields: Fields = vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8View, true),
Expand All @@ -863,9 +860,8 @@ mod tests {
false,
));

let expected = Arc::new(CastColumnExpr::new(
let expected = Arc::new(CastExpr::new_with_target_field(
Arc::new(Column::new("data", 0)),
physical_field,
logical_field,
None,
)) as Arc<dyn PhysicalExpr>;
Expand Down Expand Up @@ -1663,8 +1659,7 @@ mod tests {
Field::new("b", DataType::Utf8, true),
]);

let factory = DefaultPhysicalExprAdapterFactory;
let adapter = factory
let adapter = DefaultPhysicalExprAdapterFactory
.create(Arc::new(logical_schema), Arc::new(physical_schema))
.unwrap();

Expand All @@ -1673,20 +1668,11 @@ mod tests {

let result = adapter.rewrite(column_expr).unwrap();

// Should be a CastColumnExpr
let cast_expr = result
.as_any()
.downcast_ref::<CastColumnExpr>()
.expect("Expected CastColumnExpr");
// Should be a CastExpr
let cast_expr = assert_cast_expr(&result);

// Verify the inner column points to the correct physical index (1)
let inner_col = cast_expr
.expr()
.as_any()
.downcast_ref::<Column>()
.expect("Expected inner Column");
assert_eq!(inner_col.name(), "a");
assert_eq!(inner_col.index(), 1); // Physical index is 1
assert_cast_column(cast_expr, "a", 1);

// Verify cast types
assert_eq!(
Expand All @@ -1696,41 +1682,17 @@ mod tests {
}

#[test]
fn test_create_cast_column_expr_uses_name_lookup_not_column_index() {
// Physical schema has column `a` at index 1; index 0 is an incompatible type.
let physical_schema = Arc::new(Schema::new(vec![
Field::new("b", DataType::Binary, true),
Field::new("a", DataType::Int32, false),
]));

let logical_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Binary, true),
]));

let rewriter = DefaultPhysicalExprAdapterRewriter {
logical_file_schema: Arc::clone(&logical_schema),
physical_file_schema: Arc::clone(&physical_schema),
};
fn test_rewrite_resolves_physical_column_by_name_before_casting() {
let (logical_schema, physical_schema) = stale_index_cast_schemas();
let adapter = DefaultPhysicalExprAdapterFactory
.create(logical_schema, physical_schema)
.unwrap();

// Deliberately provide the wrong index for column `a`.
// Regression: this must still resolve against physical field `a` by name.
let transformed = rewriter
.create_cast_column_expr(
Column::new("a", 0),
Arc::new(physical_schema.field_with_name("a").unwrap().clone()),
logical_schema.field_with_name("a").unwrap(),
)
.unwrap();

let cast_expr = transformed
.data
.as_any()
.downcast_ref::<CastColumnExpr>()
.expect("Expected CastColumnExpr");

assert_eq!(cast_expr.input_field().name(), "a");
assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32);
let rewritten = adapter.rewrite(Arc::new(Column::new("a", 0))).unwrap();
let cast_expr = assert_cast_expr(&rewritten);
assert_cast_column(cast_expr, "a", 1);
assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64);
}
}
Loading