Skip to content

Commit 98d2c6e

Browse files
Handle (partially) dictionary values in ScalarValue serde (#243)
1 parent e35bb28 commit 98d2c6e

File tree

7 files changed

+258
-3
lines changed

7 files changed

+258
-3
lines changed

datafusion/proto/proto/datafusion.proto

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -948,9 +948,15 @@ message Union{
948948

949949
// Used for List/FixedSizeList/LargeList/Struct
950950
message ScalarNestedValue {
951+
message Dictionary {
952+
bytes ipc_message = 1;
953+
bytes arrow_data = 2;
954+
}
955+
951956
bytes ipc_message = 1;
952957
bytes arrow_data = 2;
953958
Schema schema = 3;
959+
repeated Dictionary dictionaries = 4;
954960
}
955961

956962
message ScalarTime32Value {

datafusion/proto/src/generated/pbjson.rs

Lines changed: 133 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 48 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::collections::HashMap;
1819
use std::sync::Arc;
1920

2021
use crate::protobuf::{
@@ -29,6 +30,7 @@ use crate::protobuf::{
2930
OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
3031
};
3132

33+
use arrow::array::ArrayRef;
3234
use arrow::{
3335
array::AsArray,
3436
buffer::Buffer,
@@ -587,6 +589,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
587589
let protobuf::ScalarNestedValue {
588590
ipc_message,
589591
arrow_data,
592+
dictionaries,
590593
schema,
591594
} = &v;
592595

@@ -613,11 +616,55 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
613616
)
614617
})?;
615618

619+
let dict_by_id: HashMap<i64,ArrayRef> = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
620+
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
621+
Error::General(format!(
622+
"Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
623+
))
624+
})?;
625+
let buffer = Buffer::from(arrow_data);
626+
627+
let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| {
628+
Error::General(
629+
"Unexpected message type deserializing ScalarValue::List dictionary message"
630+
.to_string(),
631+
)
632+
})?;
633+
634+
let id = dict_batch.id();
635+
636+
let fields_using_this_dictionary = schema.fields_with_dict_id(id);
637+
let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
638+
Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())
639+
})?;
640+
641+
let values: ArrayRef = match first_field.data_type() {
642+
DataType::Dictionary(_, ref value_type) => {
643+
// Make a fake schema for the dictionary batch.
644+
let value = value_type.as_ref().clone();
645+
let schema = Schema::new(vec![Field::new("", value, true)]);
646+
// Read a single column
647+
let record_batch = read_record_batch(
648+
&buffer,
649+
dict_batch.data().unwrap(),
650+
Arc::new(schema),
651+
&Default::default(),
652+
None,
653+
&message.version(),
654+
)?;
655+
Ok(record_batch.column(0).clone())
656+
}
657+
_ => Err(Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())),
658+
}?;
659+
660+
Ok((id,values))
661+
}).collect::<Result<HashMap<_,_>>>()?;
662+
616663
let record_batch = read_record_batch(
617664
&buffer,
618665
ipc_batch,
619666
Arc::new(schema),
620-
&Default::default(),
667+
&dict_by_id,
621668
None,
622669
&message.version(),
623670
)

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1604,7 +1604,7 @@ fn encode_scalar_nested_value(
16041604

16051605
let gen = IpcDataGenerator {};
16061606
let mut dict_tracker = DictionaryTracker::new(false);
1607-
let (_, encoded_message) = gen
1607+
let (encoded_dictionaries, encoded_message) = gen
16081608
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
16091609
.map_err(|e| {
16101610
Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
@@ -1615,6 +1615,13 @@ fn encode_scalar_nested_value(
16151615
let scalar_list_value = protobuf::ScalarNestedValue {
16161616
ipc_message: encoded_message.ipc_message,
16171617
arrow_data: encoded_message.arrow_data,
1618+
dictionaries: encoded_dictionaries
1619+
.into_iter()
1620+
.map(|data| protobuf::scalar_nested_value::Dictionary {
1621+
ipc_message: data.ipc_message,
1622+
arrow_data: data.arrow_data,
1623+
})
1624+
.collect(),
16181625
schema: Some(schema),
16191626
};
16201627

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1092,11 +1092,60 @@ fn round_trip_scalar_values() {
10921092
)
10931093
.build()
10941094
.unwrap(),
1095+
ScalarStructBuilder::new()
1096+
.with_scalar(
1097+
Field::new("a", DataType::Int32, true),
1098+
ScalarValue::from(23i32),
1099+
)
1100+
.with_scalar(
1101+
Field::new("b", DataType::Boolean, false),
1102+
ScalarValue::from(false),
1103+
)
1104+
.with_scalar(
1105+
Field::new(
1106+
"c",
1107+
DataType::Dictionary(
1108+
Box::new(DataType::UInt16),
1109+
Box::new(DataType::Utf8),
1110+
),
1111+
false,
1112+
),
1113+
ScalarValue::Dictionary(
1114+
Box::new(DataType::UInt16),
1115+
Box::new("value".into()),
1116+
),
1117+
)
1118+
.build()
1119+
.unwrap(),
10951120
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
10961121
Field::new("a", DataType::Int32, true),
10971122
Field::new("b", DataType::Boolean, false),
10981123
])))
10991124
.unwrap(),
1125+
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
1126+
Field::new("a", DataType::Int32, true),
1127+
Field::new("b", DataType::Boolean, false),
1128+
Field::new(
1129+
"c",
1130+
DataType::Dictionary(
1131+
Box::new(DataType::UInt16),
1132+
Box::new(DataType::Binary),
1133+
),
1134+
false,
1135+
),
1136+
Field::new(
1137+
"d",
1138+
DataType::new_list(
1139+
DataType::Dictionary(
1140+
Box::new(DataType::UInt16),
1141+
Box::new(DataType::Binary),
1142+
),
1143+
false,
1144+
),
1145+
false,
1146+
),
1147+
])))
1148+
.unwrap(),
11001149
ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())),
11011150
ScalarValue::FixedSizeBinary(0, None),
11021151
ScalarValue::FixedSizeBinary(5, None),

0 commit comments

Comments
 (0)