Skip to content

Commit 22608e3

Browse files
authored
feat!: project struct array fields in arrow conversion (#254)
1 parent 60f078c commit 22608e3

File tree

3 files changed

+83
-6
lines changed

3 files changed

+83
-6
lines changed

narrow-derive/src/struct.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ impl Struct<'_> {
272272
let field_name = self.ident.to_string();
273273
let tokens = if matches!(self.fields, Fields::Unit) {
274274
quote!(impl #impl_generics #narrow::arrow::StructArrayTypeFields for #ident #ty_generics #where_clause {
275+
const NAMES: &'static [&'static str] = &[#field_name];
275276
fn fields() -> ::arrow_schema::Fields {
276277
::arrow_schema::Fields::from([
277278
::std::sync::Arc::new(::arrow_schema::Field::new(#field_name, ::arrow_schema::DataType::Null, true)),
@@ -281,6 +282,7 @@ impl Struct<'_> {
281282
} else {
282283
// Fields
283284
let field_ident = self.field_idents().map(|ident| ident.to_string());
285+
let field_name = field_ident.clone();
284286
let field_ty = self.field_types();
285287
let field_ty_drop = self.field_types_drop_option();
286288
let fields = quote!(
@@ -290,6 +292,11 @@ impl Struct<'_> {
290292
);
291293
quote! {
292294
impl #impl_generics #narrow::arrow::StructArrayTypeFields for #ident #ty_generics #where_clause {
295+
const NAMES: &'static [&'static str] = &[
296+
#(
297+
#field_name,
298+
)*
299+
];
293300
fn fields() -> ::arrow_schema::Fields {
294301
::arrow_schema::Fields::from([
295302
#fields

src/arrow/array/null.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ mod tests {
7272
const INPUT: [(); 4] = [(), (), (), ()];
7373

7474
#[test]
75+
#[cfg(feature = "derive")]
7576
fn derive() {
7677
#[derive(crate::ArrayType, Copy, Clone, Debug, Default)]
7778
struct Unit;

src/arrow/array/struct.rs

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ use crate::{
1616

1717
/// Arrow schema interop trait for the fields of a struct array type.
1818
pub trait StructArrayTypeFields {
19+
/// The names of the fields, used to look up the matching child arrays by name when converting from an arrow struct array.
20+
const NAMES: &'static [&'static str];
21+
1922
/// Returns the fields of this struct array.
2023
fn fields() -> Fields;
2124
}
@@ -102,26 +105,48 @@ where
102105
impl<T: StructArrayType, Buffer: BufferType> From<arrow_array::StructArray>
103106
for StructArray<T, false, Buffer>
104107
where
105-
<T as StructArrayType>::Array<Buffer>: From<Vec<Arc<dyn arrow_array::Array>>>,
108+
<T as StructArrayType>::Array<Buffer>:
109+
From<Vec<Arc<dyn arrow_array::Array>>> + StructArrayTypeFields,
106110
{
107111
fn from(value: arrow_array::StructArray) -> Self {
108-
let (_fields, arrays, nulls_opt) = value.into_parts();
112+
let (fields, arrays, nulls_opt) = value.into_parts();
113+
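// Project: for each name in `NAMES`, in that order, find the field by name and take its child array; panics if a field is missing.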
114+
let projected = <<T as StructArrayType>::Array<Buffer> as StructArrayTypeFields>::NAMES
115+
.iter()
116+
.map(|field| {
117+
fields
118+
.find(field)
119+
.unwrap_or_else(|| panic!("expected struct array with field: {field}"))
120+
})
121+
.map(|(idx, _)| Arc::clone(&arrays[idx]))
122+
.collect::<Vec<_>>();
109123
match nulls_opt {
110124
Some(_) => panic!("expected array without a null buffer"),
111-
None => StructArray(arrays.into()),
125+
None => StructArray(projected.into()),
112126
}
113127
}
114128
}
115129

116130
impl<T: StructArrayType, Buffer: BufferType> From<arrow_array::StructArray>
117131
for StructArray<T, true, Buffer>
118132
where
119-
<T as StructArrayType>::Array<Buffer>: From<Vec<Arc<dyn arrow_array::Array>>> + Length,
133+
<T as StructArrayType>::Array<Buffer>:
134+
From<Vec<Arc<dyn arrow_array::Array>>> + Length + StructArrayTypeFields,
120135
Bitmap<Buffer>: From<NullBuffer> + FromIterator<bool>,
121136
{
122137
fn from(value: arrow_array::StructArray) -> Self {
123-
let (_fields, arrays, nulls_opt) = value.into_parts();
124-
let data = arrays.into();
138+
let (fields, arrays, nulls_opt) = value.into_parts();
139+
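// Project: for each name in `NAMES`, in that order, find the field by name and take its child array; panics if a field is missing.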
140+
let projected = <<T as StructArrayType>::Array<Buffer> as StructArrayTypeFields>::NAMES
141+
.iter()
142+
.map(|field| {
143+
fields
144+
.find(field)
145+
.unwrap_or_else(|| panic!("expected struct array with field: {field}"))
146+
})
147+
.map(|(idx, _)| Arc::clone(&arrays[idx]))
148+
.collect::<Vec<_>>();
149+
let data = projected.into();
125150
match nulls_opt {
126151
Some(null_buffer) => StructArray(Nullable {
127152
data,
@@ -264,6 +289,7 @@ mod tests {
264289
type Array<Buffer: BufferType> = FooArray<Buffer>;
265290
}
266291
impl<Buffer: BufferType> StructArrayTypeFields for FooArray<Buffer> {
292+
const NAMES: &'static [&'static str] = &["a"];
267293
fn fields() -> Fields {
268294
Fields::from(vec![Field::new("a", DataType::UInt32, false)])
269295
}
@@ -437,4 +463,47 @@ mod tests {
437463
)))
438464
);
439465
}
466+
467+
#[test]
468+
#[should_panic(expected = "expected struct array with field: c")]
469+
#[cfg(feature = "derive")]
470+
fn projected() {
471+
#[derive(narrow_derive::ArrayType)]
472+
struct Foo {
473+
a: u32,
474+
b: bool,
475+
c: u64,
476+
}
477+
478+
#[derive(narrow_derive::ArrayType, Debug, PartialEq)]
479+
struct Bar {
480+
b: bool,
481+
a: u32,
482+
}
483+
484+
let foo_array = [
485+
Foo {
486+
a: 1,
487+
b: false,
488+
c: 2,
489+
},
490+
Foo {
491+
a: 2,
492+
b: true,
493+
c: 3,
494+
},
495+
]
496+
.into_iter()
497+
.collect::<StructArray<Foo>>();
498+
499+
let arrow_array = arrow_array::StructArray::from(foo_array);
500+
let bar_array = StructArray::<Bar>::from(arrow_array);
501+
assert_eq!(
502+
bar_array.clone().into_iter().collect::<Vec<_>>(),
503+
[Bar { b: false, a: 1 }, Bar { b: true, a: 2 }]
504+
);
505+
506+
let bar_arrow_array = arrow_array::StructArray::from(bar_array);
507+
let _ = StructArray::<Foo>::from(bar_arrow_array);
508+
}
440509
}

0 commit comments

Comments
 (0)