Is your feature request related to a problem or challenge?
(based on a Discord thread with @timsaucer, @pepijnve, and @bmmeijers)
DataFusion has a TableFunction API:
https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableFunctionImpl.html
https://docs.rs/datafusion/latest/datafusion/catalog/struct.TableFunction.html
These functions take one or more arguments and produce a table as output. They are used for functions like generate_series, which works like this:
> select * from generate_series(1,2);
+-------+
| value |
+-------+
| 1     |
| 2     |
+-------+
2 row(s) fetched.
Elapsed 0.004 seconds.
However, as @bmmeijers reports, the TableFunction API does not allow a table function to access column values from another table (e.g. in a LATERAL join).
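For context, a table function is invoked once while the query is planned, and call() receives only the argument expressions, never any row data. The trait shape, paraphrased from the docs linked above, is roughly:
use std::fmt::Debug;
use std::sync::Arc;

use datafusion::catalog::TableProvider;
use datafusion::common::Result;
use datafusion::logical_expr::Expr;

// call() runs at planning time and gets only the logical argument
// expressions, so the TableProvider it returns cannot depend on values
// from another table's rows.
pub trait TableFunctionImpl: Debug + Sync + Send {
    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
}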
See commented lines 117--125 in https://pastebin.com/td76M8Fj (alternate copy below):
use datafusion::arrow::array::{ArrayRef, Int64Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::common::{Result, ScalarValue, plan_err};
use datafusion::datasource::memory::MemTable;
use datafusion::logical_expr::Expr;
use datafusion::prelude::SessionContext;
use std::sync::Arc;

#[derive(Debug)]
pub struct TransformFunction {}

impl TableFunctionImpl for TransformFunction {
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        if exprs.len() != 3 {
            return plan_err!(
                "Expected exactly three arguments: a, b, c, but got {}",
                exprs.len()
            );
        }
        println!("{:?}", exprs);
        let extract_int64 = |expr: &Expr, arg_name: &str| -> Result<i64> {
            match expr {
                Expr::Literal(ScalarValue::Int64(Some(val)), _) => Ok(*val),
                // Expr::Column()
                _ => plan_err!("Argument {} must be an Int64 literal", arg_name),
            }
        };
        let a = extract_int64(&exprs[0], "a")?;
        let b = extract_int64(&exprs[1], "b")?;
        let c = extract_int64(&exprs[2], "c")?;
        // Compute output columns: x = a + b, y = b * c
        let x = a + b;
        let y = b * c;
        // Define output schema
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int64, false),
            Field::new("y", DataType::Int64, false),
        ]));
        // Create output arrays
        let x_array = Arc::new(Int64Array::from(vec![x])) as ArrayRef;
        let y_array = Arc::new(Int64Array::from(vec![y])) as ArrayRef;
        // Create a single RecordBatch
        let batch = RecordBatch::try_new(schema.clone(), vec![x_array, y_array])?;
        // Wrap in a MemTable
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        Ok(Arc::new(provider))
    }
}

// --- Usage Example ---

/// Registers the TransformFunction as a TableUDF in the SessionContext.
fn register_udtf(ctx: &mut SessionContext) -> Result<()> {
    // 1. Create the implementation instance
    let udtf = Arc::new(TransformFunction {});
    ctx.register_udtf("my_transform", udtf);
    Ok(())
}

/// Creates a small in-memory table for demonstration.
fn create_dummy_table(ctx: &mut SessionContext) -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
        Field::new("c", DataType::Int64, false),
    ]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(StringArray::from(vec!["r1", "r2"])) as ArrayRef,
            Arc::new(Int64Array::from(vec![10, 20])) as ArrayRef,
            Arc::new(Int64Array::from(vec![5, 6])) as ArrayRef,
            Arc::new(Int64Array::from(vec![2, 3])) as ArrayRef,
        ],
    )?;
    let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("my_table", Arc::new(provider))?;
    Ok(())
}

#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = SessionContext::new();
    // 1. Register the custom UDTF
    register_udtf(&mut ctx)?;
    // 2. Register a dummy table
    create_dummy_table(&mut ctx)?;
    // 3. Define and execute the SQL query
    let sql = r#"
        SELECT
            t1.id,
            t2.x AS a_plus_b,
            t2.y AS b_times_c
        FROM
            my_table AS t1,
            LATERAL my_transform(1, 2, 3) AS t2(x, y)
    "#;
    // let sql = r#"
    //     SELECT
    //         t1.id,
    //         t2.x AS a_plus_b,
    //         t2.y AS b_times_c
    //     FROM
    //         my_table AS t1,
    //         LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)
    // "#;
    println!("Executing SQL:\n{}", sql);
    let df = ctx.sql(sql).await?;
    println!("\nQuery Result:");
    df.show().await?;
    Ok(())
}
Specifically, passing arguments to call(&self, args: &[Expr]) works as long as they are scalar values. For example, this works because the arguments to my_transform are the literals 1, 2, 3:
SELECT
    t1.id,
    t2.x AS a_plus_b,
    t2.y AS b_times_c
FROM
    my_table AS t1,
    LATERAL my_transform(1, 2, 3) AS t2(x, y)
However, this doesn't work:
SELECT
    t1.id,
    t2.x AS a_plus_b,
    t2.y AS b_times_c
FROM
    my_table AS t1,
    LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)
You get an error like:
This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(
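The underlying limitation is that call() runs during logical planning: a literal argument carries its value, but a column argument is only a reference, and the corresponding row values do not exist yet. A minimal sketch of the distinction call() can make today (classify_arg is a hypothetical helper, and exactly which Expr variant the planner produces for t1.a in this position is an assumption based on the OuterReferenceColumn error above):
use datafusion::common::{plan_err, Result};
use datafusion::logical_expr::Expr;

// Hypothetical helper for a TableFunctionImpl::call implementation: literal
// arguments carry a value at planning time, column references do not.
fn classify_arg(expr: &Expr) -> Result<()> {
    match expr {
        // e.g. my_transform(1, 2, 3): the values are known when call() runs
        Expr::Literal(..) => Ok(()),
        // e.g. my_transform(t1.a, ...): only a column name is known here; the
        // per-row values only exist later, at execution time
        Expr::Column(_) | Expr::OuterReferenceColumn(..) => plan_err!(
            "column arguments have no value available at planning time: {expr}"
        ),
        _ => plan_err!("unsupported table function argument: {expr}"),
    }
}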
Describe the solution you'd like
I would like to be able to access values from another table from within a TableFunction.
Here is a simple example of LATERAL and the range function that works today:
> SELECT *
FROM range(3) t(i), LATERAL range(3) t2(j);
+---+---+
| i | j |
+---+---+
| 0 | 0 |
| 0 | 1 |
| 0 | 2 |
| 1 | 0 |
| 1 | 1 |
| 1 | 2 |
| 2 | 0 |
| 2 | 1 |
| 2 | 2 |
+---+---+
9 row(s) fetched.
Elapsed 0.002 seconds.
However, you can't refer to the previous subquery in the argument to range, or from within a LATERAL subquery:
> SELECT *
FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);
This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "i", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "t" }), name: "i" })
I want some way to refer to other inputs in table expressions. Here is how it works in DuckDB:
D SELECT *
FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);
┌───────┬───────┐
│   i   │   j   │
│ int64 │ int64 │
├───────┼───────┤
│     0 │     1 │
│     1 │     2 │
│     2 │     3 │
└───────┴───────┘
Describe alternatives you've considered
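For simple cases like the range example above, one workaround today is to de-correlate by hand and fold the lateral computation into the projection. A minimal sketch (assuming the built-in range table function is registered, as it is in datafusion-cli and in a default SessionContext):
use datafusion::common::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // The lateral computation is expressed as a projection instead of an
    // argument to a table function, so no outer reference is needed.
    let df = ctx.sql("SELECT i, i + 1 AS j FROM range(3) t(i)").await?;
    df.show().await?;
    Ok(())
}
This only covers expressions that can be rewritten as a projection; it does not help for table functions that need per-row arguments from another table, which is what this issue is asking for.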
Here is the DuckDB explain plan:
D explain SELECT *
FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);

Physical Plan:
LEFT_DELIM_JOIN (Join Type: INNER, Conditions: i IS NOT DISTINCT FROM i, ~3 rows)
├── RANGE (Function: RANGE, ~3 rows)
├── HASH_JOIN (Join Type: INNER, Conditions: i IS NOT DISTINCT FROM i, ~3 rows)
│   ├── COLUMN_DATA_SCAN (~3 rows)
│   └── PROJECTION ((i + 1), i, ~1 row)
│       └── DELIM_SCAN (Delim Index: 1, ~1 row)
└── HASH_GROUP_BY (Groups: #0, ~1 row)

Here are some good documents about user-defined table functions in the Snowflake documentation:
Here are the docs about LATERAL joins from DuckDB: