@@ -20,6 +20,7 @@ use std::sync::Arc;
20
20
21
21
use crate :: sql:: db_connection_pool:: {
22
22
dbconnection:: { self , AsyncDbConnection , DbConnection , GenericError } ,
23
+ runtime:: run_async_with_tokio,
23
24
DbConnectionPool ,
24
25
} ;
25
26
use arrow_odbc:: arrow_schema_from;
@@ -42,7 +43,7 @@ use odbc_api::handles::StatementImpl;
42
43
use odbc_api:: parameter:: InputParameter ;
43
44
use odbc_api:: Cursor ;
44
45
use odbc_api:: CursorImpl ;
45
- use secrecy:: { SecretBox , ExposeSecret , SecretString } ;
46
+ use secrecy:: { ExposeSecret , SecretBox , SecretString } ;
46
47
use snafu:: prelude:: * ;
47
48
use snafu:: Snafu ;
48
49
use tokio:: runtime:: Handle ;
@@ -184,69 +185,71 @@ where
184
185
let params = params. iter ( ) . map ( dyn_clone:: clone) . collect :: < Vec < _ > > ( ) ;
185
186
let secrets = Arc :: clone ( & self . params ) ;
186
187
187
- let join_handle = tokio:: task:: spawn_blocking ( move || {
188
- let handle = Handle :: current ( ) ;
189
- let cxn = handle. block_on ( async { conn. lock ( ) . await } ) ;
188
+ let create_stream = async || -> Result < SendableRecordBatchStream > {
189
+ let join_handle = tokio:: task:: spawn_blocking ( move || {
190
+ let handle = Handle :: current ( ) ;
191
+ let cxn = handle. block_on ( async { conn. lock ( ) . await } ) ;
190
192
191
- let mut prepared = cxn. prepare ( & sql) ?;
192
- let schema = Arc :: new ( arrow_schema_from ( & mut prepared, false ) ?) ;
193
- blocking_channel_send ( & schema_tx, Arc :: clone ( & schema) ) ?;
193
+ let mut prepared = cxn. prepare ( & sql) ?;
194
+ let schema = Arc :: new ( arrow_schema_from ( & mut prepared, false ) ?) ;
195
+ blocking_channel_send ( & schema_tx, Arc :: clone ( & schema) ) ?;
194
196
195
- let mut statement = prepared. into_statement ( ) ;
197
+ let mut statement = prepared. into_statement ( ) ;
196
198
197
- bind_parameters ( & mut statement, & params) ?;
199
+ bind_parameters ( & mut statement, & params) ?;
198
200
199
- // StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
200
- let cursor = unsafe {
201
- if let SqlResult :: Error { function } = statement. execute ( ) {
202
- return Err ( Error :: ODBCAPIErrorNoSource {
203
- message : function. to_string ( ) ,
201
+ // StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
202
+ let cursor = unsafe {
203
+ if let SqlResult :: Error { function } = statement. execute ( ) {
204
+ return Err ( Error :: ODBCAPIErrorNoSource {
205
+ message : function. to_string ( ) ,
206
+ }
207
+ . into ( ) ) ;
204
208
}
205
- . into ( ) ) ;
206
- }
207
209
208
- Ok :: < _ , GenericError > ( CursorImpl :: new ( statement. as_stmt_ref ( ) ) )
209
- } ?;
210
+ Ok :: < _ , GenericError > ( CursorImpl :: new ( statement. as_stmt_ref ( ) ) )
211
+ } ?;
210
212
211
- let reader = build_odbc_reader ( cursor, & schema, & secrets) ?;
212
- for batch in reader {
213
- blocking_channel_send ( & batch_tx, batch. context ( ArrowSnafu ) ?) ?;
214
- }
213
+ let reader = build_odbc_reader ( cursor, & schema, & secrets) ?;
214
+ for batch in reader {
215
+ blocking_channel_send ( & batch_tx, batch. context ( ArrowSnafu ) ?) ?;
216
+ }
215
217
216
- Ok :: < _ , GenericError > ( ( ) )
217
- } ) ;
218
+ Ok :: < _ , GenericError > ( ( ) )
219
+ } ) ;
218
220
219
- // we need to wait for the schema first before we can build our RecordBatchStreamAdapter
220
- let Some ( schema) = schema_rx. recv ( ) . await else {
221
- // if the channel drops, the task errored
222
- if !join_handle. is_finished ( ) {
223
- unreachable ! ( "Schema channel should not have dropped before the task finished" ) ;
224
- }
221
+ // we need to wait for the schema first before we can build our RecordBatchStreamAdapter
222
+ let Some ( schema) = schema_rx. recv ( ) . await else {
223
+ // if the channel drops, the task errored
224
+ if !join_handle. is_finished ( ) {
225
+ unreachable ! ( "Schema channel should not have dropped before the task finished" ) ;
226
+ }
225
227
226
- let result = join_handle. await ?;
227
- let Err ( err) = result else {
228
- unreachable ! ( "Task should have errored" ) ;
228
+ let result = join_handle. await ?;
229
+ let Err ( err) = result else {
230
+ unreachable ! ( "Task should have errored" ) ;
231
+ } ;
232
+
233
+ return Err ( err) ;
229
234
} ;
230
235
231
- return Err ( err) ;
232
- } ;
236
+ let output_stream = stream ! {
237
+ while let Some ( batch) = batch_rx. recv( ) . await {
238
+ yield Ok ( batch) ;
239
+ }
233
240
234
- let output_stream = stream ! {
235
- while let Some ( batch) = batch_rx. recv( ) . await {
236
- yield Ok ( batch) ;
237
- }
241
+ if let Err ( e) = join_handle. await {
242
+ yield Err ( DataFusionError :: Execution ( format!(
243
+ "Failed to execute ODBC query: {e}"
244
+ ) ) )
245
+ }
246
+ } ;
238
247
239
- if let Err ( e) = join_handle. await {
240
- yield Err ( DataFusionError :: Execution ( format!(
241
- "Failed to execute ODBC query: {e}"
242
- ) ) )
243
- }
248
+ let result: SendableRecordBatchStream =
249
+ Box :: pin ( RecordBatchStreamAdapter :: new ( schema, output_stream) ) ;
250
+ Ok ( result)
244
251
} ;
245
-
246
- Ok ( Box :: pin ( RecordBatchStreamAdapter :: new (
247
- schema,
248
- output_stream,
249
- ) ) )
252
+ run_async_with_tokio ( create_stream) . await
250
253
}
251
254
252
255
async fn execute ( & self , query : & str , params : & [ ODBCParameter ] ) -> Result < u64 > {
0 commit comments