@@ -28,6 +28,7 @@ use std::sync::Arc;
2828use std:: task:: { Context , Poll } ;
2929
3030use crate :: error:: { DataFusionError , Result } ;
31+ use crate :: logical_plan:: { Subquery , SubqueryType } ;
3132use crate :: physical_plan:: { DisplayFormatType , ExecutionPlan , Partitioning } ;
3233use arrow:: array:: new_null_array;
3334use arrow:: datatypes:: { Schema , SchemaRef } ;
@@ -46,7 +47,7 @@ use futures::stream::StreamExt;
4647#[ derive( Debug ) ]
4748pub struct SubqueryExec {
4849 /// Sub queries
49- subqueries : Vec < Arc < dyn ExecutionPlan > > ,
50+ subqueries : Vec < ( Arc < dyn ExecutionPlan > , SubqueryType ) > ,
5051 /// Merged schema
5152 schema : SchemaRef ,
5253 /// The input plan
@@ -58,15 +59,22 @@ pub struct SubqueryExec {
5859impl SubqueryExec {
5960 /// Create a projection on an input
6061 pub fn try_new (
61- subqueries : Vec < Arc < dyn ExecutionPlan > > ,
62+ subqueries : Vec < ( Arc < dyn ExecutionPlan > , SubqueryType ) > ,
6263 input : Arc < dyn ExecutionPlan > ,
6364 cursor : Arc < OuterQueryCursor > ,
6465 ) -> Result < Self > {
6566 let input_schema = input. schema ( ) ;
6667
6768 let mut total_fields = input_schema. fields ( ) . clone ( ) ;
68- for q in subqueries. iter ( ) {
69- total_fields. append ( & mut q. schema ( ) . fields ( ) . clone ( ) ) ;
69+ for ( q, t) in subqueries. iter ( ) {
70+ total_fields. append (
71+ & mut q
72+ . schema ( )
73+ . fields ( )
74+ . iter ( )
75+ . map ( |f| Subquery :: transform_field ( f, * t) )
76+ . collect ( ) ,
77+ ) ;
7078 }
7179
7280 let merged_schema = Schema :: new_with_metadata ( total_fields, HashMap :: new ( ) ) ;
@@ -100,7 +108,7 @@ impl ExecutionPlan for SubqueryExec {
100108
101109 fn children ( & self ) -> Vec < Arc < dyn ExecutionPlan > > {
102110 let mut res = vec ! [ self . input. clone( ) ] ;
103- res. extend ( self . subqueries . iter ( ) . cloned ( ) ) ;
111+ res. extend ( self . subqueries . iter ( ) . map ( | ( i , _ ) | i ) . cloned ( ) ) ;
104112 res
105113 }
106114
@@ -134,7 +142,13 @@ impl ExecutionPlan for SubqueryExec {
134142 }
135143
136144 Ok ( Arc :: new ( SubqueryExec :: try_new (
137- children. iter ( ) . skip ( 1 ) . cloned ( ) . collect ( ) ,
145+ children
146+ . iter ( )
147+ . skip ( 1 )
148+ . cloned ( )
149+ . zip ( self . subqueries . iter ( ) )
150+ . map ( |( p, ( _, t) ) | ( p, * t) )
151+ . collect ( ) ,
138152 children[ 0 ] . clone ( ) ,
139153 self . cursor . clone ( ) ,
140154 ) ?) )
@@ -151,71 +165,78 @@ impl ExecutionPlan for SubqueryExec {
151165 let context = context. clone ( ) ;
152166 let size_hint = stream. size_hint ( ) ;
153167 let schema = self . schema . clone ( ) ;
154- let res_stream =
155- stream. then ( move |batch| {
156- let cursor = cursor. clone ( ) ;
157- let context = context. clone ( ) ;
158- let subqueries = subqueries. clone ( ) ;
159- let schema = schema. clone ( ) ;
160- async move {
161- let batch = batch?;
162- let b = Arc :: new ( batch. clone ( ) ) ;
163- cursor. set_batch ( b) ?;
164- let mut subquery_arrays = vec ! [ Vec :: new( ) ; subqueries. len( ) ] ;
165- for i in 0 ..batch. num_rows ( ) {
166- cursor. set_position ( i) ?;
167- for ( subquery_i, subquery) in subqueries. iter ( ) . enumerate ( ) {
168- let null_array = || {
169- let schema = subquery. schema ( ) ;
170- let fields = schema. fields ( ) ;
171- if fields. len ( ) != 1 {
172- return Err ( ArrowError :: ComputeError ( format ! (
173- "Sub query should have only one column but got {}" ,
174- fields. len( )
175- ) ) ) ;
176- }
177-
178- let data_type = fields. get ( 0 ) . unwrap ( ) . data_type ( ) ;
179- Ok ( new_null_array ( data_type, 1 ) )
180- } ;
168+ let res_stream = stream. then ( move |batch| {
169+ let cursor = cursor. clone ( ) ;
170+ let context = context. clone ( ) ;
171+ let subqueries = subqueries. clone ( ) ;
172+ let schema = schema. clone ( ) ;
173+ async move {
174+ let batch = batch?;
175+ let b = Arc :: new ( batch. clone ( ) ) ;
176+ cursor. set_batch ( b) ?;
177+ let mut subquery_arrays = vec ! [ Vec :: new( ) ; subqueries. len( ) ] ;
178+ for i in 0 ..batch. num_rows ( ) {
179+ cursor. set_position ( i) ?;
180+ for ( subquery_i, ( subquery, subquery_type) ) in
181+ subqueries. iter ( ) . enumerate ( )
182+ {
183+ let schema = subquery. schema ( ) ;
184+ let fields = schema. fields ( ) ;
185+ if fields. len ( ) != 1 {
186+ return Err ( ArrowError :: ComputeError ( format ! (
187+ "Sub query should have only one column but got {}" ,
188+ fields. len( )
189+ ) ) ) ;
190+ }
191+ let data_type = fields. get ( 0 ) . unwrap ( ) . data_type ( ) ;
192+ let null_array = || new_null_array ( data_type, 1 ) ;
181193
182- if subquery. output_partitioning ( ) . partition_count ( ) != 1 {
183- return Err ( ArrowError :: ComputeError ( format ! (
184- "Sub query should have only one partition but got {}" ,
185- subquery. output_partitioning( ) . partition_count( )
186- ) ) ) ;
187- }
188- let mut stream = subquery. execute ( 0 , context. clone ( ) ) . await ?;
189- let res = stream. next ( ) . await ;
190- if let Some ( subquery_batch) = res {
191- let subquery_batch = subquery_batch?;
192- match subquery_batch. column ( 0 ) . len ( ) {
193- 0 => subquery_arrays[ subquery_i] . push ( null_array ( ) ?) ,
194+ if subquery. output_partitioning ( ) . partition_count ( ) != 1 {
195+ return Err ( ArrowError :: ComputeError ( format ! (
196+ "Sub query should have only one partition but got {}" ,
197+ subquery. output_partitioning( ) . partition_count( )
198+ ) ) ) ;
199+ }
200+ let mut stream = subquery. execute ( 0 , context. clone ( ) ) . await ?;
201+ let res = stream. next ( ) . await ;
202+ if let Some ( subquery_batch) = res {
203+ let subquery_batch = subquery_batch?;
204+ match subquery_type {
205+ SubqueryType :: Scalar => match subquery_batch
206+ . column ( 0 )
207+ . len ( )
208+ {
209+ 0 => subquery_arrays[ subquery_i] . push ( null_array ( ) ) ,
194210 1 => subquery_arrays[ subquery_i]
195211 . push ( subquery_batch. column ( 0 ) . clone ( ) ) ,
196212 _ => return Err ( ArrowError :: ComputeError (
197213 "Sub query should return no more than one row"
198214 . to_string ( ) ,
199215 ) ) ,
200- } ;
201- } else {
202- subquery_arrays[ subquery_i] . push ( null_array ( ) ?) ;
203- }
216+ } ,
217+ } ;
218+ } else {
219+ match subquery_type {
220+ SubqueryType :: Scalar => {
221+ subquery_arrays[ subquery_i] . push ( null_array ( ) )
222+ }
223+ } ;
204224 }
205225 }
206- let mut new_columns = batch. columns ( ) . to_vec ( ) ;
207- for subquery_array in subquery_arrays {
208- new_columns. push ( concat (
209- subquery_array
210- . iter ( )
211- . map ( |a| a. as_ref ( ) )
212- . collect :: < Vec < _ > > ( )
213- . as_slice ( ) ,
214- ) ?) ;
215- }
216- RecordBatch :: try_new ( schema. clone ( ) , new_columns)
217226 }
218- } ) ;
227+ let mut new_columns = batch. columns ( ) . to_vec ( ) ;
228+ for subquery_array in subquery_arrays {
229+ new_columns. push ( concat (
230+ subquery_array
231+ . iter ( )
232+ . map ( |a| a. as_ref ( ) )
233+ . collect :: < Vec < _ > > ( )
234+ . as_slice ( ) ,
235+ ) ?) ;
236+ }
237+ RecordBatch :: try_new ( schema. clone ( ) , new_columns)
238+ }
239+ } ) ;
219240 Ok ( Box :: pin ( SubQueryStream {
220241 schema : self . schema . clone ( ) ,
221242 stream : Box :: pin ( res_stream) ,
0 commit comments