Skip to content

Commit

Permalink
Support Union types in ScalarValue (apache#9683)
Browse files Browse the repository at this point in the history
  • Loading branch information
avantgardnerio committed Mar 27, 2024
1 parent fce8bf8 commit d88e414
Show file tree
Hide file tree
Showing 7 changed files with 503 additions and 12 deletions.
104 changes: 93 additions & 11 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ use arrow::{
use arrow_array::cast::as_list_array;
use arrow_array::types::ArrowTimestampType;
use arrow_array::{ArrowNativeTypeOp, Scalar};
use arrow_buffer::Buffer;
use arrow_schema::{UnionFields, UnionMode};

/// A dynamically typed, nullable single value, (the single-valued counter-part
/// to arrow's [`Array`])
Expand Down Expand Up @@ -187,6 +189,11 @@ pub enum ScalarValue {
DurationNanosecond(Option<i64>),
/// struct of nested ScalarValue
Struct(Option<Vec<ScalarValue>>, Fields),
/// A nested datatype that can represent slots of differing types. Components:
/// `.0`: a tuple of union `type_id` and the single value held by this Scalar
/// `.1`: the list of fields, zero-to-one of which will by set in `.0`
/// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came
Union(Option<(i8, Box<ScalarValue>)>, UnionFields, UnionMode),
/// Dictionary type: index type and value
Dictionary(Box<DataType>, Box<ScalarValue>),
}
Expand Down Expand Up @@ -287,6 +294,10 @@ impl PartialEq for ScalarValue {
(IntervalMonthDayNano(_), _) => false,
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2),
(Struct(_, _), _) => false,
(Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => {
val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2)
}
(Union(_, _, _), _) => false,
(Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2),
(Dictionary(_, _), _) => false,
(Null, Null) => true,
Expand Down Expand Up @@ -448,6 +459,14 @@ impl PartialOrd for ScalarValue {
}
}
(Struct(_, _), _) => None,
(Union(v1, t1, m1), Union(v2, t2, m2)) => {
if t1.eq(t2) && m1.eq(m2) {
v1.partial_cmp(v2)
} else {
None
}
}
(Union(_, _, _), _) => None,
(Dictionary(k1, v1), Dictionary(k2, v2)) => {
// Don't compare if the key types don't match (it is effectively a different datatype)
if k1 == k2 {
Expand Down Expand Up @@ -546,6 +565,11 @@ impl std::hash::Hash for ScalarValue {
v.hash(state);
t.hash(state);
}
Union(v, t, m) => {
v.hash(state);
t.hash(state);
m.hash(state);
}
Dictionary(k, v) => {
k.hash(state);
v.hash(state);
Expand Down Expand Up @@ -968,6 +992,7 @@ impl ScalarValue {
DataType::Duration(TimeUnit::Nanosecond)
}
ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()),
ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode),
ScalarValue::Dictionary(k, v) => {
DataType::Dictionary(k.clone(), Box::new(v.data_type()))
}
Expand Down Expand Up @@ -1167,6 +1192,7 @@ impl ScalarValue {
ScalarValue::DurationMicrosecond(v) => v.is_none(),
ScalarValue::DurationNanosecond(v) => v.is_none(),
ScalarValue::Struct(v, _) => v.is_none(),
ScalarValue::Union(v, _, _) => v.is_none(),
ScalarValue::Dictionary(_, v) => v.is_null(),
}
}
Expand Down Expand Up @@ -1992,6 +2018,39 @@ impl ScalarValue {
new_null_array(&dt, size)
}
},
ScalarValue::Union(value, fields, _mode) => match value {
Some((v_id, value)) => {
let mut field_type_ids = Vec::<i8>::with_capacity(fields.len());
let mut child_arrays =
Vec::<(Field, ArrayRef)>::with_capacity(fields.len());
for (f_id, field) in fields.iter() {
let ar = if f_id == *v_id {
value.to_array_of_size(size)?
} else {
let dt = field.data_type();
new_null_array(dt, size)
};
let field = (**field).clone();
child_arrays.push((field, ar));
field_type_ids.push(f_id);
}
let type_ids = repeat(*v_id).take(size).collect::<Vec<_>>();
let type_ids = Buffer::from_slice_ref(type_ids);
let value_offsets: Option<Buffer> = None;
let ar = UnionArray::try_new(
field_type_ids.as_slice(),
type_ids,
value_offsets,
child_arrays,
)
.map_err(|e| DataFusionError::ArrowError(e))?;
Arc::new(ar)
}
None => {
let dt = self.data_type();
new_null_array(&dt, size)
}
},
ScalarValue::Dictionary(key_type, v) => {
// values array is one element long (the value)
match key_type.as_ref() {
Expand Down Expand Up @@ -2492,6 +2551,9 @@ impl ScalarValue {
ScalarValue::Struct(_, _) => {
return _not_impl_err!("Struct is not supported yet")
}
ScalarValue::Union(_, _, _) => {
return _not_impl_err!("Union is not supported yet")
}
ScalarValue::Dictionary(key_type, v) => {
let (values_array, values_index) = match key_type.as_ref() {
DataType::Int8 => get_dict_value::<Int8Type>(array, index)?,
Expand Down Expand Up @@ -2560,22 +2622,31 @@ impl ScalarValue {
| ScalarValue::LargeBinary(b) => {
b.as_ref().map(|b| b.capacity()).unwrap_or_default()
}
ScalarValue::List(arr)
| ScalarValue::LargeList(arr)
| ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
ScalarValue::Struct(vals, fields) => {
ScalarValue::List(arr)
| ScalarValue::LargeList(arr)
| ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
ScalarValue::Struct(vals, fields) => {
vals.as_ref()
.map(|vals| {
vals.iter()
.map(|sv| sv.size() - std::mem::size_of_val(sv))
.sum::<usize>()
+ (std::mem::size_of::<ScalarValue>() * vals.capacity())
})
.unwrap_or_default()
// `fields` is boxed, so it is NOT already included in `self`
+ std::mem::size_of_val(fields)
+ (std::mem::size_of::<Field>() * fields.len())
+ fields.iter().map(|field| field.size() - std::mem::size_of_val(field)).sum::<usize>()
}
ScalarValue::Union(vals, fields, _mode) => {
vals.as_ref()
.map(|vals| {
vals.iter()
.map(|sv| sv.size() - std::mem::size_of_val(sv))
.sum::<usize>()
+ (std::mem::size_of::<ScalarValue>() * vals.capacity())
})
.map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv))
.unwrap_or_default()
// `fields` is boxed, so it is NOT already included in `self`
+ std::mem::size_of_val(fields)
+ (std::mem::size_of::<Field>() * fields.len())
+ fields.iter().map(|field| field.size() - std::mem::size_of_val(field)).sum::<usize>()
+ fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::<usize>()
}
ScalarValue::Dictionary(dt, sv) => {
// `dt` and `sv` are boxed, so they are NOT already included in `self`
Expand Down Expand Up @@ -2873,6 +2944,9 @@ impl TryFrom<&DataType> for ScalarValue {
1,
)),
DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()),
DataType::Union(fields, mode) => {
ScalarValue::Union(None, fields.clone(), *mode)
}
DataType::Null => ScalarValue::Null,
_ => {
return _not_impl_err!(
Expand Down Expand Up @@ -2971,6 +3045,10 @@ impl fmt::Display for ScalarValue {
)?,
None => write!(f, "NULL")?,
},
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "{}:{}", id, val)?,
None => write!(f, "NULL")?,
},
ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?,
ScalarValue::Null => write!(f, "NULL")?,
};
Expand Down Expand Up @@ -3069,6 +3147,10 @@ impl fmt::Debug for ScalarValue {
None => write!(f, "Struct(NULL)"),
}
}
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "Union {}:{}", id, val),
None => write!(f, "Union(NULL)"),
},
ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"),
ScalarValue::Null => write!(f, "NULL"),
}
Expand Down
35 changes: 35 additions & 0 deletions datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,9 @@ mod tests {
use crate::test::exec::StatisticsExec;
use crate::ExecutionPlan;

use crate::empty::EmptyExec;
use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::{UnionFields, UnionMode};
use datafusion_common::{ColumnStatistics, ScalarValue};
use datafusion_expr::Operator;

Expand Down Expand Up @@ -1057,4 +1059,37 @@ mod tests {
assert_eq!(statistics.total_byte_size, Precision::Inexact(1600));
Ok(())
}

#[test]
fn test_equivalence_properties_union_type() -> Result<()> {
let union_type = DataType::Union(
UnionFields::new(
vec![0, 1],
vec![
Field::new("f1", DataType::Int32, true),
Field::new("f2", DataType::Utf8, true),
],
),
UnionMode::Sparse,
);

let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", union_type, true),
]));

let exec = FilterExec::try_new(
binary(
binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), &schema)?,
Operator::And,
binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?,
&schema,
)?,
Arc::new(EmptyExec::new(schema.clone())),
)?;

exec.statistics().unwrap();

Ok(())
}
}
15 changes: 15 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,20 @@ message StructValue {
repeated Field fields = 3;
}

message UnionField {
int32 field_id = 1;
Field field = 2;
}

message UnionValue {
// Note that a null union value must have one or more fields, so we
// encode a null UnionValue as one with value_id == 128
int32 value_id = 1;
ScalarValue value = 2;
repeated UnionField fields = 3;
UnionMode mode = 4;
}

message ScalarFixedSizeBinary{
bytes values = 1;
int32 length = 2;
Expand Down Expand Up @@ -1015,6 +1029,7 @@ message ScalarValue{
IntervalMonthDayNanoValue interval_month_day_nano = 31;
StructValue struct_value = 32;
ScalarFixedSizeBinary fixed_size_binary_value = 34;
UnionValue union_value = 42;
}
}

Expand Down
Loading

0 comments on commit d88e414

Please sign in to comment.