@@ -16,6 +16,9 @@ use crate::{
16
16
17
17
/// Arrow schema interop trait for the fields of a struct array type.
18
18
pub trait StructArrayTypeFields {
19
+ /// The names of the fields.
20
+ const NAMES : & ' static [ & ' static str ] ;
21
+
19
22
/// Returns the fields of this struct array.
20
23
fn fields ( ) -> Fields ;
21
24
}
@@ -102,26 +105,48 @@ where
102
105
impl < T : StructArrayType , Buffer : BufferType > From < arrow_array:: StructArray >
103
106
for StructArray < T , false , Buffer >
104
107
where
105
- <T as StructArrayType >:: Array < Buffer > : From < Vec < Arc < dyn arrow_array:: Array > > > ,
108
+ <T as StructArrayType >:: Array < Buffer > :
109
+ From < Vec < Arc < dyn arrow_array:: Array > > > + StructArrayTypeFields ,
106
110
{
107
111
fn from ( value : arrow_array:: StructArray ) -> Self {
108
- let ( _fields, arrays, nulls_opt) = value. into_parts ( ) ;
112
+ let ( fields, arrays, nulls_opt) = value. into_parts ( ) ;
113
+ // Project
114
+ let projected = <<T as StructArrayType >:: Array < Buffer > as StructArrayTypeFields >:: NAMES
115
+ . iter ( )
116
+ . map ( |field| {
117
+ fields
118
+ . find ( field)
119
+ . unwrap_or_else ( || panic ! ( "expected struct array with field: {field}" ) )
120
+ } )
121
+ . map ( |( idx, _) | Arc :: clone ( & arrays[ idx] ) )
122
+ . collect :: < Vec < _ > > ( ) ;
109
123
match nulls_opt {
110
124
Some ( _) => panic ! ( "expected array without a null buffer" ) ,
111
- None => StructArray ( arrays . into ( ) ) ,
125
+ None => StructArray ( projected . into ( ) ) ,
112
126
}
113
127
}
114
128
}
115
129
116
130
impl < T : StructArrayType , Buffer : BufferType > From < arrow_array:: StructArray >
117
131
for StructArray < T , true , Buffer >
118
132
where
119
- <T as StructArrayType >:: Array < Buffer > : From < Vec < Arc < dyn arrow_array:: Array > > > + Length ,
133
+ <T as StructArrayType >:: Array < Buffer > :
134
+ From < Vec < Arc < dyn arrow_array:: Array > > > + Length + StructArrayTypeFields ,
120
135
Bitmap < Buffer > : From < NullBuffer > + FromIterator < bool > ,
121
136
{
122
137
fn from ( value : arrow_array:: StructArray ) -> Self {
123
- let ( _fields, arrays, nulls_opt) = value. into_parts ( ) ;
124
- let data = arrays. into ( ) ;
138
+ let ( fields, arrays, nulls_opt) = value. into_parts ( ) ;
139
+ // Project
140
+ let projected = <<T as StructArrayType >:: Array < Buffer > as StructArrayTypeFields >:: NAMES
141
+ . iter ( )
142
+ . map ( |field| {
143
+ fields
144
+ . find ( field)
145
+ . unwrap_or_else ( || panic ! ( "expected struct array with field: {field}" ) )
146
+ } )
147
+ . map ( |( idx, _) | Arc :: clone ( & arrays[ idx] ) )
148
+ . collect :: < Vec < _ > > ( ) ;
149
+ let data = projected. into ( ) ;
125
150
match nulls_opt {
126
151
Some ( null_buffer) => StructArray ( Nullable {
127
152
data,
@@ -264,6 +289,7 @@ mod tests {
264
289
type Array < Buffer : BufferType > = FooArray < Buffer > ;
265
290
}
266
291
impl < Buffer : BufferType > StructArrayTypeFields for FooArray < Buffer > {
292
+ const NAMES : & ' static [ & ' static str ] = & [ "a" ] ;
267
293
fn fields ( ) -> Fields {
268
294
Fields :: from ( vec ! [ Field :: new( "a" , DataType :: UInt32 , false ) ] )
269
295
}
@@ -437,4 +463,47 @@ mod tests {
437
463
) ) )
438
464
) ;
439
465
}
466
+
467
+ #[ test]
468
+ #[ should_panic( expected = "expected struct array with field: c" ) ]
469
+ #[ cfg( feature = "derive" ) ]
470
+ fn projected ( ) {
471
+ #[ derive( narrow_derive:: ArrayType ) ]
472
+ struct Foo {
473
+ a : u32 ,
474
+ b : bool ,
475
+ c : u64 ,
476
+ }
477
+
478
+ #[ derive( narrow_derive:: ArrayType , Debug , PartialEq ) ]
479
+ struct Bar {
480
+ b : bool ,
481
+ a : u32 ,
482
+ }
483
+
484
+ let foo_array = [
485
+ Foo {
486
+ a : 1 ,
487
+ b : false ,
488
+ c : 2 ,
489
+ } ,
490
+ Foo {
491
+ a : 2 ,
492
+ b : true ,
493
+ c : 3 ,
494
+ } ,
495
+ ]
496
+ . into_iter ( )
497
+ . collect :: < StructArray < Foo > > ( ) ;
498
+
499
+ let arrow_array = arrow_array:: StructArray :: from ( foo_array) ;
500
+ let bar_array = StructArray :: < Bar > :: from ( arrow_array) ;
501
+ assert_eq ! (
502
+ bar_array. clone( ) . into_iter( ) . collect:: <Vec <_>>( ) ,
503
+ [ Bar { b: false , a: 1 } , Bar { b: true , a: 2 } ]
504
+ ) ;
505
+
506
+ let bar_arrow_array = arrow_array:: StructArray :: from ( bar_array) ;
507
+ let _ = StructArray :: < Foo > :: from ( bar_arrow_array) ;
508
+ }
440
509
}
0 commit comments