Support table inputs for user defined table functions #18121

@alamb

Is your feature request related to a problem or challenge?

(based on discord thread with @timsaucer @pepijnve and @bmmeijers )

DataFusion has a TableFunction API:
https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableFunctionImpl.html
https://docs.rs/datafusion/latest/datafusion/catalog/struct.TableFunction.html

These functions take one or more arguments and produce a table as output. They are used for functions like generate_series, which works like this:

> select * from generate_series(1,2);
+-------+
| value |
+-------+
| 1     |
| 2     |
+-------+
2 row(s) fetched.
Elapsed 0.004 seconds.

However, as @bmmeijers reports, the TableFunction API does not allow access to data in columns from another table (e.g. via LATERAL joins).

See commented lines 117--125 in https://pastebin.com/td76M8Fj

(alternate copy)

use datafusion::arrow::array::{ArrayRef, Int64Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::common::{Result, ScalarValue, plan_err};
use datafusion::datasource::memory::MemTable;
use datafusion::logical_expr::Expr;
use datafusion::prelude::SessionContext;
use std::sync::Arc;
 
#[derive(Debug)]
pub struct TransformFunction {}
 
impl TableFunctionImpl for TransformFunction {
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        if exprs.len() != 3 {
            return plan_err!(
                "Expected exactly three arguments: a, b, c, but got {}",
                exprs.len()
            );
        }
 
        println!("{:?}", exprs);
 
        let extract_int64 = |expr: &Expr, arg_name: &str| -> Result<i64> {
            match expr {
                Expr::Literal(ScalarValue::Int64(Some(val)), _) => Ok(*val),
                // Expr::Column()
                _ => plan_err!("Argument {} must be an Int64 literal", arg_name),
            }
        };
 
        let a = extract_int64(&exprs[0], "a")?;
        let b = extract_int64(&exprs[1], "b")?;
        let c = extract_int64(&exprs[2], "c")?;
 
        // Compute output columns: x = a + b, y = b * c
        let x = a + b;
        let y = b * c;
 
        // Define output schema
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int64, false),
            Field::new("y", DataType::Int64, false),
        ]));
 
        // Create output arrays
        let x_array = Arc::new(Int64Array::from(vec![x])) as ArrayRef;
        let y_array = Arc::new(Int64Array::from(vec![y])) as ArrayRef;
 
        // Create a single RecordBatch
        let batch = RecordBatch::try_new(schema.clone(), vec![x_array, y_array])?;
 
        // Wrap in a MemTable
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
 
        Ok(Arc::new(provider))
    }
}
 
// --- Usage Example ---
 
/// Registers the TransformFunction as a TableUDF in the SessionContext.
fn register_udtf(ctx: &mut SessionContext) -> Result<()> {
    // 1. Create the implementation instance
    let udtf = Arc::new(TransformFunction {});
    ctx.register_udtf("my_transform", udtf);
 
    Ok(())
}
 
/// Creates a small in-memory table for demonstration.
fn create_dummy_table(ctx: &mut SessionContext) -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
        Field::new("c", DataType::Int64, false),
    ]));
 
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(StringArray::from(vec!["r1", "r2"])) as ArrayRef,
            Arc::new(Int64Array::from(vec![10, 20])) as ArrayRef,
            Arc::new(Int64Array::from(vec![5, 6])) as ArrayRef,
            Arc::new(Int64Array::from(vec![2, 3])) as ArrayRef,
        ],
    )?;
 
    let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("my_table", Arc::new(provider))?;
    Ok(())
}
 
#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = SessionContext::new();
 
    // 1. Register the custom UDTF
    register_udtf(&mut ctx)?;
 
    // 2. Register a dummy table
    create_dummy_table(&mut ctx)?;
 
    // 3. Define and execute the SQL query
    let sql = r#"
        SELECT 
            t1.id, 
            t2.x AS a_plus_b, 
            t2.y AS b_times_c
        FROM 
            my_table AS t1,
            LATERAL my_transform(1, 2, 3) AS t2(x, y)
    "#;
 
    // let sql = r#"
    //     SELECT 
    //         t1.id, 
    //         t2.x AS a_plus_b, 
    //         t2.y AS b_times_c
    //     FROM 
    //         my_table AS t1,
    //         LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)
    // "#;
 
 
    println!("Executing SQL:\n{}", sql);
 
    let df = ctx.sql(sql).await?;
 
    println!("\nQuery Result:");
    df.show().await?;
 
    Ok(())
}

Specifically, the args passed to call(&self, args: &[Expr]) can be literal expressions but not column references. For example, this works because the arguments to my_transform are the scalar values 1, 2, 3:

        SELECT 
            t1.id, 
            t2.x AS a_plus_b, 
            t2.y AS b_times_c
        FROM 
            my_table AS t1,
            LATERAL my_transform(1, 2, 3) AS t2(x, y)

However, this doesn't work:

         SELECT 
             t1.id, 
             t2.x AS a_plus_b, 
             t2.y AS b_times_c
         FROM 
             my_table AS t1,
             LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)

Instead, you get an error like: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(
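
To make the desired semantics concrete, here is a minimal sketch in plain Rust (standard library only, no DataFusion APIs) of what the failing query would compute if column arguments were supported: the function is evaluated once per row of my_table and its output rows are paired with that row. The names Row, my_transform, and lateral_join are hypothetical and only mirror the example above. This illustrates the expected result and is not a description of how DataFusion plans or executes LATERAL.

```rust
/// One row of the example table `my_table` (id, a, b, c).
#[derive(Debug, Clone, PartialEq)]
struct Row {
    id: String,
    a: i64,
    b: i64,
    c: i64,
}

/// Stand-in for the table function: returns a table (here a Vec of rows)
/// with columns x = a + b and y = b * c.
fn my_transform(a: i64, b: i64, c: i64) -> Vec<(i64, i64)> {
    vec![(a + b, b * c)]
}

/// Row-at-a-time LATERAL semantics: call the function once per outer row
/// and pair every output row with the outer row that produced it.
fn lateral_join(outer: &[Row]) -> Vec<(String, i64, i64)> {
    outer
        .iter()
        .flat_map(|r| {
            my_transform(r.a, r.b, r.c)
                .into_iter()
                .map(move |(x, y)| (r.id.clone(), x, y))
        })
        .collect()
}

fn main() {
    // Same data as create_dummy_table in the example above.
    let my_table = vec![
        Row { id: "r1".into(), a: 10, b: 5, c: 2 },
        Row { id: "r2".into(), a: 20, b: 6, c: 3 },
    ];
    for (id, x, y) in lateral_join(&my_table) {
        println!("{id} {x} {y}");
    }
}
```

With the two rows from create_dummy_table, this yields (r1, 15, 10) and (r2, 26, 18), which is the result the LATERAL my_transform(t1.a, t1.b, t1.c) query would be expected to return.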

Describe the solution you'd like

I would like to be able to access the values from another table from a TableFunction.

Here is a simple example of LATERAL and the range function that works today:

> SELECT *
  FROM range(3) t(i), LATERAL range(3) t2(j);
+---+---+
| i | j |
+---+---+
| 0 | 0 |
| 0 | 1 |
| 0 | 2 |
| 1 | 0 |
| 1 | 1 |
| 1 | 2 |
| 2 | 0 |
| 2 | 1 |
| 2 | 2 |
+---+---+
9 row(s) fetched.
Elapsed 0.002 seconds.

However, you can't refer to the previous relation from inside the LATERAL subquery (or in the argument to range):

> SELECT *
  FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);
This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "i", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "t" }), name: "i" })

I want some way to refer to other inputs in table expressions. Here is how it works in DuckDB:

D SELECT *
  FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);
┌───────┬───────┐
│   i   │   j   │
│ int64 │ int64 │
├───────┼───────┤
│     0 │     1 │
│     1 │     2 │
│     2 │     3 │
└───────┴───────┘

Describe alternatives you've considered

Here is the DuckDB explain plan:

D explain SELECT *
  FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);

┌─────────────────────────────┐
│┌───────────────────────────┐│
││       Physical Plan       ││
│└───────────────────────────┘│
└─────────────────────────────┘
┌───────────────────────────┐
│      LEFT_DELIM_JOIN      │
│    ────────────────────   │
│      Join Type: INNER     │
│                           │
│        Conditions:        ├──────────────┬─────────────────────────────────────────────────────────┐
│  i IS NOT DISTINCT FROM i │              │                                                         │
│                           │              │                                                         │
│          ~3 rows          │              │                                                         │
└─────────────┬─────────────┘              │                                                         │
┌─────────────┴─────────────┐┌─────────────┴─────────────┐                             ┌─────────────┴─────────────┐
│           RANGE           ││         HASH_JOIN         │                             │       HASH_GROUP_BY       │
│    ────────────────────   ││    ────────────────────   │                             │    ────────────────────   │
│      Function: RANGE      ││      Join Type: INNER     │                             │         Groups: #0        │
│                           ││                           │                             │                           │
│                           ││        Conditions:        ├──────────────┐              │                           │
│                           ││  i IS NOT DISTINCT FROM i │              │              │                           │
│                           ││                           │              │              │                           │
│          ~3 rows          ││          ~3 rows          │              │              │           ~1 row          │
└───────────────────────────┘└─────────────┬─────────────┘              │              └───────────────────────────┘
                             ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
                             │      COLUMN_DATA_SCAN     ││         PROJECTION        │
                             │    ────────────────────   ││    ────────────────────   │
                             │                           ││          (i + 1)          │
                             │                           ││             i             │
                             │                           ││                           │
                             │          ~3 rows          ││           ~1 row          │
                             └───────────────────────────┘└─────────────┬─────────────┘
                                                          ┌─────────────┴─────────────┐
                                                          │         DELIM_SCAN        │
                                                          │    ────────────────────   │
                                                          │       Delim Index: 1      │
                                                          │                           │
                                                          │           ~1 row          │
                                                          └───────────────────────────┘
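
As a rough illustration of the strategy this plan suggests, the sketch below evaluates the correlated subquery once per distinct outer value and then joins the results back to the outer rows. It is a simplified reading of the plan in plain Rust (standard library only); the names subquery and delim_join are hypothetical, the mapping of steps to DuckDB operators is an interpretation, and NULL handling (IS NOT DISTINCT FROM) is not modelled because the values are plain i64.

```rust
use std::collections::{BTreeSet, HashMap};

/// Evaluate `SELECT i + 1` for one value of the correlated column `i`.
fn subquery(i: i64) -> Vec<i64> {
    vec![i + 1]
}

/// Decorrelated evaluation, loosely following the plan above:
/// 1. collect the distinct outer values (compare HASH_GROUP_BY),
/// 2. run the subquery once per distinct value (compare DELIM_SCAN + PROJECTION),
/// 3. join the results back to the outer rows on `i` (compare the joins).
fn delim_join(outer: &[i64]) -> Vec<(i64, i64)> {
    let distinct: BTreeSet<i64> = outer.iter().copied().collect();
    let per_value: HashMap<i64, Vec<i64>> =
        distinct.into_iter().map(|i| (i, subquery(i))).collect();
    outer
        .iter()
        .flat_map(|i| per_value[i].iter().map(move |j| (*i, *j)))
        .collect()
}

fn main() {
    // range(3) produces 0, 1, 2 as the outer relation t(i).
    let outer: Vec<i64> = (0..3).collect();
    for (i, j) in delim_join(&outer) {
        println!("{i} {j}");
    }
}
```

The point of the design is that the subquery runs once per distinct value of i rather than once per outer row, which matters when the outer relation contains many duplicates.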

Additional context

Here are some good documents about user defined table functions in the Snowflake documentation:

Here are the docs about LATERAL join from DuckDB:
