Skip to content

Commit d8276cd

Browse files
feat: Implement Python ODBC, MySQL, Postgres, Flight table providers (#281)
* Implement ODBC python table provider * Update ODBC documentation * Update API * Remove unused imports * Update comments * Update comments * Fix a typo * Implement mysql and postgresql, and refactor tokio runtime * Enable flight * Fix styling and documentation * Update core/src/sql/db_connection_pool/runtime.rs Co-authored-by: Phillip LeBlanc <[email protected]> * Update python/python/datafusion_table_providers/odbc.py Co-authored-by: Phillip LeBlanc <[email protected]> * Fix formatting and address comments * Update core/src/sql/db_connection_pool/runtime.rs --------- Co-authored-by: Phillip LeBlanc <[email protected]>
1 parent bc7c00c commit d8276cd

29 files changed

+627
-92
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ let ctx = SessionContext::with_state(state);
2727
- Flight SQL
2828
- ODBC
2929

30-
## Examples
30+
## Examples (in Rust)
3131

3232
Run the included examples to see how to use the table providers:
3333

@@ -45,6 +45,7 @@ cargo run --example duckdb_function --features duckdb
4545
### SQLite
4646

4747
```bash
48+
# Run from repo folder
4849
cargo run --example sqlite --features sqlite
4950
```
5051

@@ -80,7 +81,9 @@ EOF
8081
```
8182

8283
```bash
83-
cargo run --example postgres --features postgres
84+
# Run from repo folder
85+
cargo run -p datafusion-table-providers --example postgres --features postgres
86+
8487
```
8588

8689
### MySQL
@@ -104,7 +107,8 @@ EOF
104107
```
105108

106109
```bash
107-
cargo run --example mysql --features mysql
110+
# Run from repo folder
111+
cargo run -p datafusion-table-providers --example mysql --features mysql
108112
```
109113

110114
### Flight SQL
@@ -115,16 +119,37 @@ brew install roapi
115119
# cargo install --locked --git https://github.com/roapi/roapi --branch main --bins roapi
116120
roapi -t taxi=https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-01.parquet &
117121

118-
cargo run --example flight-sql --features flight
122+
# Run from repo folder
123+
cargo run -p datafusion-table-providers --example flight-sql --features flight
119124
```
120125

121126
### ODBC
122-
123127
```bash
124128
apt-get install unixodbc-dev libsqliteodbc
125129
# or
126130
# brew install unixodbc && brew install sqliteodbc
127-
# If you use ARM Mac, please see https://github.com/pacman82/odbc-api#os-x-arm--mac-m1
128131

129132
cargo run --example odbc_sqlite --features odbc
130133
```
134+
135+
#### ARM Mac
136+
137+
Please see https://github.com/pacman82/odbc-api#os-x-arm--mac-m1 for reference.
138+
139+
Steps:
140+
1. Install unixodbc and sqliteodbc by running `brew install unixodbc sqliteodbc`.
141+
2. Find the local sqliteodbc driver path by running `brew info sqliteodbc`. The path might look like `/opt/homebrew/Cellar/sqliteodbc/0.99991`.
142+
3. Set up the ODBC config file at `~/.odbcinst.ini` with your local sqliteodbc path.
143+
Example config file:
144+
```
145+
[SQLite3]
146+
Description = SQLite3 ODBC Driver
147+
Driver = /opt/homebrew/Cellar/sqliteodbc/0.99991/lib/libsqlite3odbc.dylib
148+
```
149+
4. Test the configuration by running `odbcinst -q -d -n SQLite3`. If the path is printed out correctly, then you are all set.
150+
151+
## Examples (in Python)
152+
1. Create a Python virtual environment (venv).
153+
2. Activate the virtual environment.
154+
3. Inside the python/ folder, run `maturin develop`.
155+
4. Inside the python/examples/ folder, run the corresponding example using `python3 [file_name]`.

core/examples/duckdb.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ async fn main() {
1919
// Opening in ReadOnly mode allows multiple reader processes to access
2020
// the database at the same time
2121
let duckdb_pool = Arc::new(
22-
DuckDbConnectionPool::new_file("examples/duckdb_example.db", &AccessMode::ReadOnly)
22+
DuckDbConnectionPool::new_file("core/examples/duckdb_example.db", &AccessMode::ReadOnly)
2323
.expect("unable to create DuckDB connection pool"),
2424
);
2525

core/examples/odbc_sqlite.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async fn main() {
1717
// Create SQLite ODBC connection pool
1818
let params = to_secret_map(HashMap::from([(
1919
"connection_string".to_owned(),
20-
"driver=SQLite3;database=examples/sqlite_example.db;".to_owned(),
20+
"driver=SQLite3;database=core/examples/sqlite_example.db;".to_owned(),
2121
)]));
2222
let odbc_pool =
2323
Arc::new(ODBCPool::new(params).expect("unable to create SQLite ODBC connection pool"));

core/examples/sqlite.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ async fn main() {
2020
// - arg3: Connection timeout duration
2121
let sqlite_pool = Arc::new(
2222
SqliteConnectionPoolFactory::new(
23-
"examples/sqlite_example.db",
23+
"core/examples/sqlite_example.db",
2424
Mode::File,
2525
Duration::from_millis(5000),
2626
)

core/src/flight/exec.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::str::FromStr;
2424
use std::sync::Arc;
2525

2626
use crate::flight::{flight_channel, to_df_err, FlightMetadata, FlightProperties, SizeLimits};
27+
use crate::sql::db_connection_pool::runtime::run_async_with_tokio;
2728
use arrow_flight::error::FlightError;
2829
use arrow_flight::flight_service_client::FlightServiceClient;
2930
use arrow_flight::{FlightClient, FlightEndpoint, Ticket};
@@ -190,7 +191,8 @@ async fn flight_stream(
190191
) -> Result<SendableRecordBatchStream> {
191192
let mut errors: Vec<Box<dyn Error + Send + Sync>> = vec![];
192193
for loc in partition.locations.iter() {
193-
let client = flight_client(loc, grpc_headers.as_ref(), &size_limits).await?;
194+
let get_client = || async { flight_client(loc, grpc_headers.as_ref(), &size_limits).await };
195+
let client = run_async_with_tokio(get_client).await?;
194196
match try_fetch_stream(client, &partition.ticket, schema.clone()).await {
195197
Ok(stream) => return Ok(stream),
196198
Err(e) => errors.push(Box::new(e)),

core/src/sql/db_connection_pool/dbconnection/duckdbconn.rs

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::any::Any;
2-
use std::collections::HashSet;
32
use std::sync::Arc;
4-
use std::sync::OnceLock;
3+
use std::collections::HashSet;
54

65
use arrow::array::RecordBatch;
76
use arrow_schema::{DataType, Field};
@@ -20,9 +19,9 @@ use duckdb::{Connection, DuckdbConnectionManager};
2019
use dyn_clone::DynClone;
2120
use rand::distr::{Alphanumeric, SampleString};
2221
use snafu::{prelude::*, ResultExt};
23-
use tokio::runtime::{Handle, Runtime};
2422
use tokio::sync::mpsc::Sender;
2523

24+
use crate::sql::db_connection_pool::runtime::run_sync_with_tokio;
2625
use crate::util::schema::SchemaValidator;
2726
use crate::UnsupportedTypeAction;
2827

@@ -283,13 +282,6 @@ impl DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParamet
283282
}
284283
}
285284

286-
fn get_tokio_runtime() -> &'static Runtime {
287-
// TODO: this function is a repetition of python/src/utils.rs::get_tokio_runtime.
288-
// Think about how to refactor it
289-
static RUNTIME: OnceLock<Runtime> = OnceLock::new();
290-
RUNTIME.get_or_init(|| Runtime::new().expect("Failed to create Tokio runtime"))
291-
}
292-
293285
impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
294286
for DuckDbConnection
295287
{
@@ -449,13 +441,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
449441
)))
450442
};
451443

452-
// If calling directly from Rust, there is already tokio runtime so no
453-
// additional work is needed. If calling from Python FFI, there's no existing
454-
// tokio runtime, so we need to start a new one.
455-
match Handle::try_current() {
456-
Ok(_) => create_stream(),
457-
Err(_) => get_tokio_runtime().block_on(async { create_stream() }),
458-
}
444+
run_sync_with_tokio(create_stream)
459445
}
460446

461447
fn execute(&self, sql: &str, params: &[DuckDBParameter]) -> Result<u64> {

core/src/sql/db_connection_pool/dbconnection/odbcconn.rs

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use std::sync::Arc;
2020

2121
use crate::sql::db_connection_pool::{
2222
dbconnection::{self, AsyncDbConnection, DbConnection, GenericError},
23+
runtime::run_async_with_tokio,
2324
DbConnectionPool,
2425
};
2526
use arrow_odbc::arrow_schema_from;
@@ -42,7 +43,7 @@ use odbc_api::handles::StatementImpl;
4243
use odbc_api::parameter::InputParameter;
4344
use odbc_api::Cursor;
4445
use odbc_api::CursorImpl;
45-
use secrecy::{SecretBox, ExposeSecret, SecretString};
46+
use secrecy::{ExposeSecret, SecretBox, SecretString};
4647
use snafu::prelude::*;
4748
use snafu::Snafu;
4849
use tokio::runtime::Handle;
@@ -184,69 +185,71 @@ where
184185
let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();
185186
let secrets = Arc::clone(&self.params);
186187

187-
let join_handle = tokio::task::spawn_blocking(move || {
188-
let handle = Handle::current();
189-
let cxn = handle.block_on(async { conn.lock().await });
188+
let create_stream = async || -> Result<SendableRecordBatchStream> {
189+
let join_handle = tokio::task::spawn_blocking(move || {
190+
let handle = Handle::current();
191+
let cxn = handle.block_on(async { conn.lock().await });
190192

191-
let mut prepared = cxn.prepare(&sql)?;
192-
let schema = Arc::new(arrow_schema_from(&mut prepared, false)?);
193-
blocking_channel_send(&schema_tx, Arc::clone(&schema))?;
193+
let mut prepared = cxn.prepare(&sql)?;
194+
let schema = Arc::new(arrow_schema_from(&mut prepared, false)?);
195+
blocking_channel_send(&schema_tx, Arc::clone(&schema))?;
194196

195-
let mut statement = prepared.into_statement();
197+
let mut statement = prepared.into_statement();
196198

197-
bind_parameters(&mut statement, &params)?;
199+
bind_parameters(&mut statement, &params)?;
198200

199-
// StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
200-
let cursor = unsafe {
201-
if let SqlResult::Error { function } = statement.execute() {
202-
return Err(Error::ODBCAPIErrorNoSource {
203-
message: function.to_string(),
201+
// StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
202+
let cursor = unsafe {
203+
if let SqlResult::Error { function } = statement.execute() {
204+
return Err(Error::ODBCAPIErrorNoSource {
205+
message: function.to_string(),
206+
}
207+
.into());
204208
}
205-
.into());
206-
}
207209

208-
Ok::<_, GenericError>(CursorImpl::new(statement.as_stmt_ref()))
209-
}?;
210+
Ok::<_, GenericError>(CursorImpl::new(statement.as_stmt_ref()))
211+
}?;
210212

211-
let reader = build_odbc_reader(cursor, &schema, &secrets)?;
212-
for batch in reader {
213-
blocking_channel_send(&batch_tx, batch.context(ArrowSnafu)?)?;
214-
}
213+
let reader = build_odbc_reader(cursor, &schema, &secrets)?;
214+
for batch in reader {
215+
blocking_channel_send(&batch_tx, batch.context(ArrowSnafu)?)?;
216+
}
215217

216-
Ok::<_, GenericError>(())
217-
});
218+
Ok::<_, GenericError>(())
219+
});
218220

219-
// we need to wait for the schema first before we can build our RecordBatchStreamAdapter
220-
let Some(schema) = schema_rx.recv().await else {
221-
// if the channel drops, the task errored
222-
if !join_handle.is_finished() {
223-
unreachable!("Schema channel should not have dropped before the task finished");
224-
}
221+
// we need to wait for the schema first before we can build our RecordBatchStreamAdapter
222+
let Some(schema) = schema_rx.recv().await else {
223+
// if the channel drops, the task errored
224+
if !join_handle.is_finished() {
225+
unreachable!("Schema channel should not have dropped before the task finished");
226+
}
225227

226-
let result = join_handle.await?;
227-
let Err(err) = result else {
228-
unreachable!("Task should have errored");
228+
let result = join_handle.await?;
229+
let Err(err) = result else {
230+
unreachable!("Task should have errored");
231+
};
232+
233+
return Err(err);
229234
};
230235

231-
return Err(err);
232-
};
236+
let output_stream = stream! {
237+
while let Some(batch) = batch_rx.recv().await {
238+
yield Ok(batch);
239+
}
233240

234-
let output_stream = stream! {
235-
while let Some(batch) = batch_rx.recv().await {
236-
yield Ok(batch);
237-
}
241+
if let Err(e) = join_handle.await {
242+
yield Err(DataFusionError::Execution(format!(
243+
"Failed to execute ODBC query: {e}"
244+
)))
245+
}
246+
};
238247

239-
if let Err(e) = join_handle.await {
240-
yield Err(DataFusionError::Execution(format!(
241-
"Failed to execute ODBC query: {e}"
242-
)))
243-
}
248+
let result: SendableRecordBatchStream =
249+
Box::pin(RecordBatchStreamAdapter::new(schema, output_stream));
250+
Ok(result)
244251
};
245-
246-
Ok(Box::pin(RecordBatchStreamAdapter::new(
247-
schema,
248-
output_stream,
249-
)))
252+
run_async_with_tokio(create_stream).await
250253
}
251254

252255
async fn execute(&self, query: &str, params: &[ODBCParameter]) -> Result<u64> {

core/src/sql/db_connection_pool/duckdbpool.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,6 @@ mod test {
360360

361361
use super::*;
362362
use crate::sql::db_connection_pool::DbConnectionPool;
363-
use std::sync::Arc;
364363

365364
fn random_db_name() -> String {
366365
let mut rng = rand::rng();

core/src/sql/db_connection_pool/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub mod mysqlpool;
1111
pub mod odbcpool;
1212
#[cfg(feature = "postgres")]
1313
pub mod postgrespool;
14+
pub mod runtime;
1415
#[cfg(feature = "sqlite")]
1516
pub mod sqlitepool;
1617

0 commit comments

Comments
 (0)