From b9d956aebf14b3e80b18c32a0b84488ed8529a1e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 9 Apr 2026 11:44:56 -0400 Subject: [PATCH 01/11] feat: accept native Python types in function arguments instead of requiring lit() Update 47 functions in functions.py to accept native Python types (int, float, str) for arguments that are contextually literals, eliminating verbose lit() wrapping. For example, users can now write split_part(col("a"), ",", 2) instead of split_part(col("a"), lit(","), lit(2)). All changes are backward compatible. Co-Authored-By: Claude Opus 4.6 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 302 ++++++++++++++++++++ python/datafusion/functions.py | 452 +++++++++++++++++++----------- 2 files changed, 588 insertions(+), 166 deletions(-) create mode 100644 .ai/skills/make-pythonic/SKILL.md diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md new file mode 100644 index 000000000..87cc8751c --- /dev/null +++ b/.ai/skills/make-pythonic/SKILL.md @@ -0,0 +1,302 @@ + + +--- +name: make-pythonic +description: Audit and improve datafusion-python functions to accept native Python types (int, float, str, bool) instead of requiring explicit lit() or col() wrapping. Analyzes function signatures, checks upstream Rust implementations for type constraints, and applies the appropriate coercion pattern. +argument-hint: [scope] (e.g., "string functions", "datetime functions", "array functions", "math functions", "all", or a specific function name like "split_part") +--- + +# Make Python API Functions More Pythonic + +You are improving the datafusion-python API to feel more natural to Python users. The goal is to allow functions to accept native Python types (int, float, str, bool, etc.) for arguments that are contextually always or typically literal values, instead of requiring users to manually wrap them in `lit()`. 

**Core principle:** A Python user should be able to write `split_part(col("a"), ",", 2)` instead of `split_part(col("a"), lit(","), lit(2))` when the arguments are contextually obvious literals.

## How to Identify Candidates

The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions.py`.

For each function, determine if any parameter can accept native Python types by evaluating **two complementary signals**:

### Signal 1: Contextual Understanding

Some arguments are contextually always or almost always literal values based on what the function does:

| Context | Typical Arguments | Examples |
|---------|------------------|----------|
| **String position/count** | Character counts, indices, repetition counts | `left(str, n)`, `right(str, n)`, `repeat(str, n)`, `lpad(str, count, ...)` |
| **Delimiters/separators** | Fixed separator characters | `split_part(str, delim, idx)`, `concat_ws(sep, ...)` |
| **Search/replace patterns** | Literal search strings, replacements | `replace(str, from, to)`, `regexp_replace(str, pattern, replacement, flags)` |
| **Date/time parts** | Part names from a fixed set | `date_part(part, date)`, `date_trunc(part, date)` |
| **Rounding precision** | Decimal place counts | `round(val, places)`, `trunc(val, places)` |
| **Fill characters** | Padding characters | `lpad(str, count, fill)`, `rpad(str, count, fill)` |

### Signal 2: Upstream Rust Implementation

Check the Rust binding in `crates/core/src/functions.rs` and the upstream DataFusion function implementation to determine type constraints. The upstream source is cached locally at:

```
~/.cargo/registry/src/index.crates.io-*/datafusion-functions-<version>/src/
```

Check the DataFusion version in `crates/core/Cargo.toml` to find the right directory. Key subdirectories: `string/`, `datetime/`, `math/`, `regex/`.
+ +There are three concrete techniques to check, in order of signal strength: + +#### Technique 1: Check `invoke_with_args()` for literal-only enforcement (strongest signal) + +Some functions pattern-match on `ColumnarValue::Scalar` in their `invoke_with_args()` method and **return an error** if the argument is a column/array. This means the argument **must** be a literal — passing a column expression will fail at runtime. + +Example from `date_trunc.rs`: +```rust +let granularity_str = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity { + v.to_lowercase() +} else { + return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); +}; +``` + +**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type (e.g., `str`), not `Expr`. The function will error at runtime with a column expression anyway. + +#### Technique 2: Check the `Signature` for data type constraints + +Each function defines a `Signature::coercible(...)` that specifies what data types each argument accepts, using `Coercion` entries. This tells you the expected **data type** even if it doesn't enforce literal-only. + +Example from `repeat.rs`: +```rust +signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ], + Volatility::Immutable, +), +``` + +This tells you arg 2 (`n`) must be an integer type coerced to Int64. Use this to choose the correct Python type (e.g., `int` not `str` or `float`). 

Common mappings:

| Rust Type Constraint | Python Type |
|---------------------|-------------|
| `logical_int64()` / `TypeSignatureClass::Integer` | `int` |
| `logical_float64()` / `TypeSignatureClass::Numeric` | `int \| float` |
| `logical_string()` / `TypeSignatureClass::String` | `str` |
| `LogicalType::Boolean` | `bool` |

#### Technique 3: Check `return_field_from_args()` for `scalar_arguments` usage

Functions that inspect literal values at query planning time use `args.scalar_arguments.get(n)` in their `return_field_from_args()` method. This indicates the argument is **expected to be a literal** for optimal behavior (e.g., to determine output type precision), but may still work as a column.

Example from `round.rs`:
```rust
let decimal_places: Option<i32> = match args.scalar_arguments.get(1) {
    None => Some(0),
    Some(None) => None, // argument is not a literal (column)
    Some(Some(scalar)) if scalar.is_null() => Some(0),
    Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
};
```

**If you find this pattern:** The argument is **Category A** — accept native types AND `Expr`. It works as a column but is primarily used as a literal.

#### Decision flow

```
Is argument rejected at runtime if not a literal?
  (check invoke_with_args for ColumnarValue::Scalar-only match + exec_err!)
  → YES: Category B — accept only native type, no Expr
  → NO: Does the Signature constrain it to a specific data type?
      → YES: Category A — accept Expr | <native type>
      → NO: Leave as Expr only
```

## Coercion Categories

When making a function more pythonic, apply the correct coercion pattern based on **what the argument represents**:

### Category A: Arguments That Should Accept Native Types AND Expr

These are arguments that are *typically* literals but *could* be column references in advanced use cases. For these, accept a union type and coerce native types to `Expr.literal()`.
+ +**Type hint pattern:** `Expr | int`, `Expr | str`, `Expr | int | str`, etc. + +**When to use:** When the argument could plausibly come from a column in some use case (e.g., the repeat count might come from a column in a data-driven scenario). + +```python +def repeat(string: Expr, n: Expr | int) -> Expr: + """Repeats the ``string`` to ``n`` times. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["ha"]}) + >>> result = df.select( + ... dfn.functions.repeat(dfn.col("a"), 3).alias("r")) + >>> result.collect_column("r")[0].as_py() + 'hahaha' + """ + if not isinstance(n, Expr): + n = Expr.literal(n) + return Expr(f.repeat(string.expr, n.expr)) +``` + +### Category B: Arguments That Should ONLY Accept Specific Native Types + +These are arguments where an `Expr` never makes sense because the value must be a fixed literal known at query-planning time (not a per-row value). For these, accept only the native type(s) and wrap internally. + +**Type hint pattern:** `str`, `int`, `list[str]`, etc. (no `Expr` in the union) + +**When to use:** When the argument is from a fixed enumeration or is always a compile-time constant: +- Date/time part names (`"year"`, `"month"`, `"day"`, etc.) +- Regex flags (`"g"`, `"i"`, etc.) +- Values that the Rust implementation already accepts as native types + +```python +def date_part(part: str, date: Expr) -> Expr: + """Extracts a subfield from the date. + + Args: + part: The part of the date to extract. Must be one of "year", "month", + "day", "hour", "minute", "second", etc. + date: The date expression to extract from. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) + >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) + >>> result = df.select( + ... 
dfn.functions.date_part("year", dfn.col("a")).alias("y")) + >>> result.collect_column("y")[0].as_py() + 2021 + """ + part = Expr.literal(part) + return Expr(f.date_part(part.expr, date.expr)) +``` + +### Category C: Arguments That Should Accept str as Column Name + +In some contexts a string argument naturally refers to a column name rather than a literal. This is the pattern used by DataFrame methods. + +**Type hint pattern:** `Expr | str` + +**When to use:** Only when the string contextually means a column name (rare in `functions.py`, more common in DataFrame methods). + +```python +# Use _to_raw_expr() from expr.py for this pattern +from datafusion.expr import _to_raw_expr + +def some_function(column: Expr | str) -> Expr: + raw = _to_raw_expr(column) # str -> col(str) + return Expr(f.some_function(raw)) +``` + +**IMPORTANT:** In `functions.py`, string arguments almost never mean column names. Functions operate on expressions, and column references should use `col()`. Category C applies mainly to DataFrame methods and context APIs, not to scalar/aggregate/window functions. Do NOT convert string arguments to column expressions in `functions.py` unless there is a very clear reason to do so. + +## Implementation Steps + +For each function being updated: + +### Step 1: Analyze the Function + +1. Read the current Python function signature in `python/datafusion/functions.py` +2. Read the Rust binding in `crates/core/src/functions.rs` +3. Optionally check the upstream DataFusion docs for the function +4. Determine which category (A, B, or C) applies to each parameter + +### Step 2: Update the Python Function + +1. **Change the type hints** to accept native types (e.g., `Expr` -> `Expr | int`) +2. **Add coercion logic** at the top of the function body +3. **Update the docstring** examples to use the simpler calling convention +4. 
**Preserve backward compatibility** — existing code using `Expr` must still work + +### Step 3: Update Docstring Examples + +Per the project's CLAUDE.md rules: +- Every function must have doctest-style examples +- Optional parameters need examples both without and with the optional args, using keyword argument syntax +- Reuse the same input data across examples where possible + +**Update examples to demonstrate the pythonic calling convention:** + +```python +# BEFORE (old style - still works but verbose) +dfn.functions.left(dfn.col("a"), dfn.lit(3)) + +# AFTER (new style - shown in examples) +dfn.functions.left(dfn.col("a"), 3) +``` + +### Step 4: Run Tests + +After making changes, run the doctests to verify: +```bash +python -m pytest --doctest-modules python/datafusion/functions.py -v +``` + +## Coercion Helper Pattern + +When adding coercion to a function, use this inline pattern: + +```python +if not isinstance(arg, Expr): + arg = Expr.literal(arg) +``` + +Do NOT create a new helper function for this — the inline check is clear and explicit. The project intentionally uses `ensure_expr()` to reject non-Expr values in contexts where coercion is not wanted; the pythonic coercion is the opposite pattern and should be visually distinct. + +## What NOT to Change + +- **Do not change arguments that represent data columns.** If an argument is the primary data being operated on (e.g., the `string` in `left(string, n)` or the `array` in `array_sort(array)`), it should remain `Expr` only. Users should use `col()` for column references. +- **Do not change variadic `*args: Expr` parameters.** These represent multiple expressions and should stay as `Expr`. +- **Do not change arguments where the coercion is ambiguous.** If it is unclear whether a string should be a column name or a literal, leave it as `Expr` and let the user be explicit. 
+- **Do not change functions that are simple aliases.** If a function is just `return other_function(...)`, update the primary function only. +- **Do not change the Rust bindings.** All coercion happens in the Python layer. The Rust functions continue to accept `PyExpr`. + +## Priority Order + +When auditing functions, process them in this order: + +1. **Date/time functions** — `date_part`, `date_trunc`, `date_bin` — these have the clearest literal arguments +2. **String functions** — `left`, `right`, `repeat`, `lpad`, `rpad`, `split_part`, `substring`, `replace`, `regexp_replace`, `regexp_match`, `regexp_count` — common and verbose without coercion +3. **Math functions** — `round`, `trunc`, `power` — numeric literal arguments +4. **Array functions** — `array_slice`, `array_position`, `array_remove_n`, `array_replace_n`, `array_resize`, `array_element` — index and count arguments +5. **Other functions** — any remaining functions with literal arguments + +## Output Format + +For each function analyzed, report: + +``` +## [Function Name] + +**Current signature:** `function(arg1: Expr, arg2: Expr) -> Expr` +**Proposed signature:** `function(arg1: Expr, arg2: Expr | int) -> Expr` +**Category:** A (accepts native + Expr) +**Arguments changed:** +- `arg2`: Expr -> Expr | int (always a literal count) +**Rust binding:** Takes PyExpr, wraps to literal internally +**Status:** [Changed / Skipped / Needs Discussion] +``` + +If asked to implement (not just audit), make the changes directly and show a summary of what was updated. diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 841cd9c0b..feedfa7ec 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -363,49 +363,55 @@ def nullif(expr1: Expr, expr2: Expr) -> Expr: return Expr(f.nullif(expr1.expr, expr2.expr)) -def encode(expr: Expr, encoding: Expr) -> Expr: +def encode(expr: Expr, encoding: Expr | str) -> Expr: """Encode the ``input``, using the ``encoding``. 
encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.encode(dfn.col("a"), dfn.lit("base64")).alias("enc")) + ... dfn.functions.encode(dfn.col("a"), "base64").alias("enc")) >>> result.collect_column("enc")[0].as_py() 'aGVsbG8' """ + if not isinstance(encoding, Expr): + encoding = Expr.literal(encoding) return Expr(f.encode(expr.expr, encoding.expr)) -def decode(expr: Expr, encoding: Expr) -> Expr: +def decode(expr: Expr, encoding: Expr | str) -> Expr: """Decode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["aGVsbG8="]}) >>> result = df.select( - ... dfn.functions.decode(dfn.col("a"), dfn.lit("base64")).alias("dec")) + ... dfn.functions.decode(dfn.col("a"), "base64").alias("dec")) >>> result.collect_column("dec")[0].as_py() b'hello' """ + if not isinstance(encoding, Expr): + encoding = Expr.literal(encoding) return Expr(f.decode(expr.expr, encoding.expr)) -def array_to_string(expr: Expr, delimiter: Expr) -> Expr: +def array_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) >>> result = df.select( - ... dfn.functions.array_to_string(dfn.col("a"), dfn.lit(",")).alias("s")) + ... dfn.functions.array_to_string(dfn.col("a"), ",").alias("s")) >>> result.collect_column("s")[0].as_py() '1,2,3' """ + if not isinstance(delimiter, Expr): + delimiter = Expr.literal(delimiter) return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string()))) -def array_join(expr: Expr, delimiter: Expr) -> Expr: +def array_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. 
See Also: @@ -414,7 +420,7 @@ def array_join(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_to_string(expr: Expr, delimiter: Expr) -> Expr: +def list_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -423,7 +429,7 @@ def list_to_string(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_join(expr: Expr, delimiter: Expr) -> Expr: +def list_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -459,7 +465,7 @@ def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: return Expr(f.in_list(arg.expr, values, negated)) -def digest(value: Expr, method: Expr) -> Expr: +def digest(value: Expr, method: Expr | str) -> Expr: """Computes the binary hash of an expression using the specified algorithm. Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, @@ -469,24 +475,28 @@ def digest(value: Expr, method: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.digest(dfn.col("a"), dfn.lit("md5")).alias("d")) + ... dfn.functions.digest(dfn.col("a"), "md5").alias("d")) >>> len(result.collect_column("d")[0].as_py()) > 0 True """ + if not isinstance(method, Expr): + method = Expr.literal(method) return Expr(f.digest(value.expr, method.expr)) -def contains(string: Expr, search_str: Expr) -> Expr: +def contains(string: Expr, search_str: Expr | str) -> Expr: """Returns true if ``search_str`` is found within ``string`` (case-sensitive). Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the quick brown fox"]}) >>> result = df.select( - ... dfn.functions.contains(dfn.col("a"), dfn.lit("brown")).alias("c")) + ... 
dfn.functions.contains(dfn.col("a"), "brown").alias("c")) >>> result.collect_column("c")[0].as_py() True """ + if not isinstance(search_str, Expr): + search_str = Expr.literal(search_str) return Expr(f.contains(string.expr, search_str.expr)) @@ -949,17 +959,19 @@ def degrees(arg: Expr) -> Expr: return Expr(f.degrees(arg.expr)) -def ends_with(arg: Expr, suffix: Expr) -> Expr: +def ends_with(arg: Expr, suffix: Expr | str) -> Expr: """Returns true if the ``string`` ends with the ``suffix``, false otherwise. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abc","b","c"]}) >>> ends_with_df = df.select( - ... dfn.functions.ends_with(dfn.col("a"), dfn.lit("c")).alias("ends_with")) + ... dfn.functions.ends_with(dfn.col("a"), "c").alias("ends_with")) >>> ends_with_df.collect_column("ends_with")[0].as_py() True """ + if not isinstance(suffix, Expr): + suffix = Expr.literal(suffix) return Expr(f.ends_with(arg.expr, suffix.expr)) @@ -991,7 +1003,7 @@ def factorial(arg: Expr) -> Expr: return Expr(f.factorial(arg.expr)) -def find_in_set(string: Expr, string_list: Expr) -> Expr: +def find_in_set(string: Expr, string_list: Expr | str) -> Expr: """Find a string in a list of strings. Returns a value in the range of 1 to N if the string is in the string list @@ -1003,10 +1015,12 @@ def find_in_set(string: Expr, string_list: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["b"]}) >>> result = df.select( - ... dfn.functions.find_in_set(dfn.col("a"), dfn.lit("a,b,c")).alias("pos")) + ... 
dfn.functions.find_in_set(dfn.col("a"), "a,b,c").alias("pos")) >>> result.collect_column("pos")[0].as_py() 2 """ + if not isinstance(string_list, Expr): + string_list = Expr.literal(string_list) return Expr(f.find_in_set(string.expr, string_list.expr)) @@ -1138,31 +1152,35 @@ def least(*args: Expr) -> Expr: return Expr(f.least(*exprs)) -def left(string: Expr, n: Expr) -> Expr: +def left(string: Expr, n: Expr | int) -> Expr: """Returns the first ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat"]}) >>> left_df = df.select( - ... dfn.functions.left(dfn.col("a"), dfn.lit(3)).alias("left")) + ... dfn.functions.left(dfn.col("a"), 3).alias("left")) >>> left_df.collect_column("left")[0].as_py() 'the' """ + if not isinstance(n, Expr): + n = Expr.literal(n) return Expr(f.left(string.expr, n.expr)) -def levenshtein(string1: Expr, string2: Expr) -> Expr: +def levenshtein(string1: Expr, string2: Expr | str) -> Expr: """Returns the Levenshtein distance between the two given strings. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["kitten"]}) >>> result = df.select( - ... dfn.functions.levenshtein(dfn.col("a"), dfn.lit("sitting")).alias("d")) + ... dfn.functions.levenshtein(dfn.col("a"), "sitting").alias("d")) >>> result.collect_column("d")[0].as_py() 3 """ + if not isinstance(string2, Expr): + string2 = Expr.literal(string2) return Expr(f.levenshtein(string1.expr, string2.expr)) @@ -1179,18 +1197,20 @@ def ln(arg: Expr) -> Expr: return Expr(f.ln(arg.expr)) -def log(base: Expr, num: Expr) -> Expr: +def log(base: Expr | float, num: Expr) -> Expr: """Returns the logarithm of a number for a particular ``base``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [100.0]}) >>> result = df.select( - ... dfn.functions.log(dfn.lit(10.0), dfn.col("a")).alias("log") + ... dfn.functions.log(10.0, dfn.col("a")).alias("log") ... 
) >>> result.collect_column("log")[0].as_py() 2.0 """ + if not isinstance(base, Expr): + base = Expr.literal(base) return Expr(f.log(base.expr, num.expr)) @@ -1233,7 +1253,7 @@ def lower(arg: Expr) -> Expr: return Expr(f.lower(arg.expr)) -def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def lpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add left padding to a string. Extends the string to length length by prepending the characters fill (a @@ -1244,9 +1264,7 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat", "a hat"]}) >>> lpad_df = df.select( - ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(6) - ... ).alias("lpad")) + ... dfn.functions.lpad(dfn.col("a"), 6).alias("lpad")) >>> lpad_df.collect_column("lpad")[0].as_py() 'the ca' >>> lpad_df.collect_column("lpad")[1].as_py() @@ -1254,12 +1272,17 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(10), characters=dfn.lit(".") + ... dfn.col("a"), 10, characters="." ... ).alias("lpad")) >>> result.collect_column("lpad")[0].as_py() '...the cat' """ - characters = characters if characters is not None else Expr.literal(" ") + if not isinstance(count, Expr): + count = Expr.literal(count) + if characters is None: + characters = Expr.literal(" ") + elif not isinstance(characters, Expr): + characters = Expr.literal(characters) return Expr(f.lpad(string.expr, count.expr, characters.expr)) @@ -1354,7 +1377,10 @@ def octet_length(arg: Expr) -> Expr: def overlay( - string: Expr, substring: Expr, start: Expr, length: Expr | None = None + string: Expr, + substring: Expr | str, + start: Expr | int, + length: Expr | int | None = None, ) -> Expr: """Replace a substring with a new substring. 
@@ -1365,13 +1391,18 @@ def overlay( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcdef"]}) >>> result = df.select( - ... dfn.functions.overlay(dfn.col("a"), dfn.lit("XY"), dfn.lit(3), - ... dfn.lit(2)).alias("o")) + ... dfn.functions.overlay(dfn.col("a"), "XY", 3, 2).alias("o")) >>> result.collect_column("o")[0].as_py() 'abXYef' """ + if not isinstance(substring, Expr): + substring = Expr.literal(substring) + if not isinstance(start, Expr): + start = Expr.literal(start) if length is None: return Expr(f.overlay(string.expr, substring.expr, start.expr)) + if not isinstance(length, Expr): + length = Expr.literal(length) return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) @@ -1400,22 +1431,24 @@ def position(string: Expr, substring: Expr) -> Expr: return strpos(string, substring) -def power(base: Expr, exponent: Expr) -> Expr: +def power(base: Expr, exponent: Expr | float) -> Expr: """Returns ``base`` raised to the power of ``exponent``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [2.0]}) >>> result = df.select( - ... dfn.functions.power(dfn.col("a"), dfn.lit(3.0)).alias("pow") + ... dfn.functions.power(dfn.col("a"), 3.0).alias("pow") ... ) >>> result.collect_column("pow")[0].as_py() 8.0 """ + if not isinstance(exponent, Expr): + exponent = Expr.literal(exponent) return Expr(f.power(base.expr, exponent.expr)) -def pow(base: Expr, exponent: Expr) -> Expr: +def pow(base: Expr, exponent: Expr | float) -> Expr: """Returns ``base`` raised to the power of ``exponent``. See Also: @@ -1440,7 +1473,9 @@ def radians(arg: Expr) -> Expr: return Expr(f.radians(arg.expr)) -def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_like( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Find if any regular expression (regex) matches exist. 
Tests a string using a regular expression returning true if at least one match, @@ -1450,9 +1485,7 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello123"]}) >>> result = df.select( - ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("m") + ... dfn.functions.regexp_like(dfn.col("a"), "\\d+").alias("m") ... ) >>> result.collect_column("m")[0].as_py() True @@ -1461,19 +1494,24 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("HELLO"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "HELLO", flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() True """ + if not isinstance(regex, Expr): + regex = Expr.literal(regex) if flags is not None: + if not isinstance(flags, Expr): + flags = Expr.literal(flags) flags = flags.expr return Expr(f.regexp_like(string.expr, regex.expr, flags)) -def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_match( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Perform regular expression (regex) matching. Returns an array with each element containing the leftmost-first match of the @@ -1483,9 +1521,7 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(\\d+)") - ... ).alias("m") + ... dfn.functions.regexp_match(dfn.col("a"), "(\\d+)").alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['42'] @@ -1494,20 +1530,26 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(HELLO)"), - ... 
flags=dfn.lit("i"), + ... dfn.col("a"), "(HELLO)", flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['hello'] """ + if not isinstance(regex, Expr): + regex = Expr.literal(regex) if flags is not None: + if not isinstance(flags, Expr): + flags = Expr.literal(flags) flags = flags.expr return Expr(f.regexp_match(string.expr, regex.expr, flags)) def regexp_replace( - string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + replacement: Expr | str, + flags: Expr | str | None = None, ) -> Expr: r"""Replaces substring(s) matching a PCRE-like regular expression. @@ -1522,8 +1564,7 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["hello 42"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("XX") + ... dfn.col("a"), "\\d+", "XX" ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() @@ -1534,20 +1575,28 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["a1 b2 c3"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("X"), flags=dfn.lit("g"), + ... dfn.col("a"), "\\d+", "X", flags="g", ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() 'aX bX cX' """ + if not isinstance(pattern, Expr): + pattern = Expr.literal(pattern) + if not isinstance(replacement, Expr): + replacement = Expr.literal(replacement) if flags is not None: + if not isinstance(flags, Expr): + flags = Expr.literal(flags) flags = flags.expr return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) def regexp_count( - string: Expr, pattern: Expr, start: Expr | None = None, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, ) -> Expr: """Returns the number of matches in a string. 
@@ -1558,9 +1607,7 @@ def regexp_count( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcabc"]}) >>> result = df.select( - ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("abc") - ... ).alias("c")) + ... dfn.functions.regexp_count(dfn.col("a"), "abc").alias("c")) >>> result.collect_column("c")[0].as_py() 2 @@ -1569,25 +1616,31 @@ def regexp_count( >>> result = df.select( ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("ABC"), - ... start=dfn.lit(4), flags=dfn.lit("i"), + ... dfn.col("a"), "ABC", start=4, flags="i", ... ).alias("c")) >>> result.collect_column("c")[0].as_py() 1 """ + if not isinstance(pattern, Expr): + pattern = Expr.literal(pattern) + if start is not None: + if not isinstance(start, Expr): + start = Expr.literal(start) + start = start.expr if flags is not None: + if not isinstance(flags, Expr): + flags = Expr.literal(flags) flags = flags.expr - start = start.expr if start is not None else start return Expr(f.regexp_count(string.expr, pattern.expr, start, flags)) def regexp_instr( values: Expr, - regex: Expr, - start: Expr | None = None, - n: Expr | None = None, - flags: Expr | None = None, - sub_expr: Expr | None = None, + regex: Expr | str, + start: Expr | int | None = None, + n: Expr | int | None = None, + flags: Expr | str | None = None, + sub_expr: Expr | int | None = None, ) -> Expr: r"""Returns the position of a regular expression match in a string. @@ -1603,9 +1656,7 @@ def regexp_instr( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("pos") + ... dfn.functions.regexp_instr(dfn.col("a"), "\\d+").alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 7 @@ -1616,9 +1667,8 @@ def regexp_instr( >>> df = ctx.from_pydict({"a": ["abc ABC abc"]}) >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("abc"), - ... 
start=dfn.lit(2), n=dfn.lit(1), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "abc", + ... start=2, n=1, flags="i", ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() @@ -1628,56 +1678,65 @@ def regexp_instr( >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("(abc)"), - ... sub_expr=dfn.lit(1), + ... dfn.col("a"), "(abc)", sub_expr=1, ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 1 """ - start = start.expr if start is not None else None - n = n.expr if n is not None else None - flags = flags.expr if flags is not None else None - sub_expr = sub_expr.expr if sub_expr is not None else None + if not isinstance(regex, Expr): + regex = Expr.literal(regex) + + def _to_raw(val: Any) -> Any: + if val is None: + return None + if not isinstance(val, Expr): + val = Expr.literal(val) + return val.expr return Expr( f.regexp_instr( values.expr, regex.expr, - start, - n, - flags, - sub_expr, + _to_raw(start), + _to_raw(n), + _to_raw(flags), + _to_raw(sub_expr), ) ) -def repeat(string: Expr, n: Expr) -> Expr: +def repeat(string: Expr, n: Expr | int) -> Expr: """Repeats the ``string`` to ``n`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["ha"]}) >>> result = df.select( - ... dfn.functions.repeat(dfn.col("a"), dfn.lit(3)).alias("r")) + ... dfn.functions.repeat(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'hahaha' """ + if not isinstance(n, Expr): + n = Expr.literal(n) return Expr(f.repeat(string.expr, n.expr)) -def replace(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def replace(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces all occurrences of ``from_val`` with ``to_val`` in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.replace(dfn.col("a"), dfn.lit("world"), - ... dfn.lit("there")).alias("r")) + ... 
dfn.functions.replace(dfn.col("a"), "world", "there").alias("r")) >>> result.collect_column("r")[0].as_py() 'hello there' """ + if not isinstance(from_val, Expr): + from_val = Expr.literal(from_val) + if not isinstance(to_val, Expr): + to_val = Expr.literal(to_val) return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) @@ -1694,39 +1753,43 @@ def reverse(arg: Expr) -> Expr: return Expr(f.reverse(arg.expr)) -def right(string: Expr, n: Expr) -> Expr: +def right(string: Expr, n: Expr | int) -> Expr: """Returns the last ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) - >>> result = df.select(dfn.functions.right(dfn.col("a"), dfn.lit(3)).alias("r")) + >>> result = df.select(dfn.functions.right(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'llo' """ + if not isinstance(n, Expr): + n = Expr.literal(n) return Expr(f.right(string.expr, n.expr)) -def round(value: Expr, decimal_places: Expr | None = None) -> Expr: +def round(value: Expr, decimal_places: Expr | int | None = None) -> Expr: """Round the argument to the nearest integer. If the optional ``decimal_places`` is specified, round to the nearest number of decimal places. You can specify a negative number of decimal places. For example - ``round(lit(125.2345), lit(-2))`` would yield a value of ``100.0``. + ``round(lit(125.2345), -2)`` would yield a value of ``100.0``. 
Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) - >>> result = df.select(dfn.functions.round(dfn.col("a"), dfn.lit(2)).alias("r")) + >>> result = df.select(dfn.functions.round(dfn.col("a"), 2).alias("r")) >>> result.collect_column("r")[0].as_py() 1.57 """ if decimal_places is None: decimal_places = Expr.literal(0) + elif not isinstance(decimal_places, Expr): + decimal_places = Expr.literal(decimal_places) return Expr(f.round(value.expr, decimal_places.expr)) -def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def rpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add right padding to a string. Extends the string to length length by appending the characters fill (a space @@ -1736,11 +1799,16 @@ def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hi"]}) >>> result = df.select( - ... dfn.functions.rpad(dfn.col("a"), dfn.lit(5), dfn.lit("!")).alias("r")) + ... dfn.functions.rpad(dfn.col("a"), 5, "!").alias("r")) >>> result.collect_column("r")[0].as_py() 'hi!!!' """ - characters = characters if characters is not None else Expr.literal(" ") + if not isinstance(count, Expr): + count = Expr.literal(count) + if characters is None: + characters = Expr.literal(" ") + elif not isinstance(characters, Expr): + characters = Expr.literal(characters) return Expr(f.rpad(string.expr, count.expr, characters.expr)) @@ -1856,7 +1924,7 @@ def sinh(arg: Expr) -> Expr: return Expr(f.sinh(arg.expr)) -def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: +def split_part(string: Expr, delimiter: Expr | str, index: Expr | int) -> Expr: """Split a string and return one part. 
Splits a string based on a delimiter and picks out the desired field based @@ -1866,12 +1934,14 @@ def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a,b,c"]}) >>> result = df.select( - ... dfn.functions.split_part( - ... dfn.col("a"), dfn.lit(","), dfn.lit(2) - ... ).alias("s")) + ... dfn.functions.split_part(dfn.col("a"), ",", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'b' """ + if not isinstance(delimiter, Expr): + delimiter = Expr.literal(delimiter) + if not isinstance(index, Expr): + index = Expr.literal(index) return Expr(f.split_part(string.expr, delimiter.expr, index.expr)) @@ -1888,49 +1958,55 @@ def sqrt(arg: Expr) -> Expr: return Expr(f.sqrt(arg.expr)) -def starts_with(string: Expr, prefix: Expr) -> Expr: +def starts_with(string: Expr, prefix: Expr | str) -> Expr: """Returns true if string starts with prefix. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello_from_datafusion"]}) >>> result = df.select( - ... dfn.functions.starts_with(dfn.col("a"), dfn.lit("hello")).alias("sw")) + ... dfn.functions.starts_with(dfn.col("a"), "hello").alias("sw")) >>> result.collect_column("sw")[0].as_py() True """ + if not isinstance(prefix, Expr): + prefix = Expr.literal(prefix) return Expr(f.starts_with(string.expr, prefix.expr)) -def strpos(string: Expr, substring: Expr) -> Expr: +def strpos(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.strpos(dfn.col("a"), dfn.lit("llo")).alias("pos")) + ... 
dfn.functions.strpos(dfn.col("a"), "llo").alias("pos")) >>> result.collect_column("pos")[0].as_py() 3 """ + if not isinstance(substring, Expr): + substring = Expr.literal(substring) return Expr(f.strpos(string.expr, substring.expr)) -def substr(string: Expr, position: Expr) -> Expr: +def substr(string: Expr, position: Expr | int) -> Expr: """Substring from the ``position`` to the end. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.substr(dfn.col("a"), dfn.lit(3)).alias("s")) + ... dfn.functions.substr(dfn.col("a"), 3).alias("s")) >>> result.collect_column("s")[0].as_py() 'llo' """ + if not isinstance(position, Expr): + position = Expr.literal(position) return Expr(f.substr(string.expr, position.expr)) -def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: +def substr_index(string: Expr, delimiter: Expr | str, count: Expr | int) -> Expr: """Returns an indexed substring. The return will be the ``string`` from before ``count`` occurrences of @@ -1940,27 +2016,32 @@ def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a.b.c"]}) >>> result = df.select( - ... dfn.functions.substr_index(dfn.col("a"), dfn.lit("."), - ... dfn.lit(2)).alias("s")) + ... dfn.functions.substr_index(dfn.col("a"), ".", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'a.b' """ + if not isinstance(delimiter, Expr): + delimiter = Expr.literal(delimiter) + if not isinstance(count, Expr): + count = Expr.literal(count) return Expr(f.substr_index(string.expr, delimiter.expr, count.expr)) -def substring(string: Expr, position: Expr, length: Expr) -> Expr: +def substring(string: Expr, position: Expr | int, length: Expr | int) -> Expr: """Substring from the ``position`` with ``length`` characters. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... 
dfn.functions.substring( - ... dfn.col("a"), dfn.lit(1), dfn.lit(5) - ... ).alias("s")) + ... dfn.functions.substring(dfn.col("a"), 1, 5).alias("s")) >>> result.collect_column("s")[0].as_py() 'hello' """ + if not isinstance(position, Expr): + position = Expr.literal(position) + if not isinstance(length, Expr): + length = Expr.literal(length) return Expr(f.substring(string.expr, position.expr, length.expr)) @@ -2033,7 +2114,7 @@ def current_timestamp() -> Expr: return now() -def to_char(arg: Expr, formatter: Expr) -> Expr: +def to_char(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. For usage of ``formatter`` see the rust chrono package ``strftime`` package. @@ -2046,16 +2127,18 @@ def to_char(arg: Expr, formatter: Expr) -> Expr: >>> result = df.select( ... dfn.functions.to_char( ... dfn.functions.to_timestamp(dfn.col("a")), - ... dfn.lit("%Y/%m/%d"), + ... "%Y/%m/%d", ... ).alias("formatted") ... ) >>> result.collect_column("formatted")[0].as_py() '2021/01/01' """ + if not isinstance(formatter, Expr): + formatter = Expr.literal(formatter) return Expr(f.to_char(arg.expr, formatter.expr)) -def date_format(arg: Expr, formatter: Expr) -> Expr: +def date_format(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. See Also: @@ -2267,7 +2350,7 @@ def current_time() -> Expr: return Expr(f.current_time()) -def datepart(part: Expr, date: Expr) -> Expr: +def datepart(part: Expr | str, date: Expr) -> Expr: """Return a specified part of a date. See Also: @@ -2276,22 +2359,29 @@ def datepart(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_part(part: Expr, date: Expr) -> Expr: +def date_part(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. + Args: + part: The part of the date to extract. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. 
+ date: The date expression to extract from. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_part(dfn.lit("year"), dfn.col("a")).alias("y")) + ... dfn.functions.date_part("year", dfn.col("a")).alias("y")) >>> result.collect_column("y")[0].as_py() 2021 """ + if not isinstance(part, Expr): + part = Expr.literal(part) return Expr(f.date_part(part.expr, date.expr)) -def extract(part: Expr, date: Expr) -> Expr: +def extract(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. See Also: @@ -2300,25 +2390,30 @@ def extract(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_trunc(part: Expr, date: Expr) -> Expr: +def date_trunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. + Args: + part: The precision to truncate to. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to truncate. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_trunc( - ... dfn.lit("month"), dfn.col("a") - ... ).alias("t") + ... dfn.functions.date_trunc("month", dfn.col("a")).alias("t") ... ) >>> str(result.collect_column("t")[0].as_py()) '2021-07-01 00:00:00' """ + if not isinstance(part, Expr): + part = Expr.literal(part) return Expr(f.date_trunc(part.expr, date.expr)) -def datetrunc(part: Expr, date: Expr) -> Expr: +def datetrunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. 
See Also: @@ -2379,18 +2474,21 @@ def make_time(hour: Expr, minute: Expr, second: Expr) -> Expr: return Expr(f.make_time(hour.expr, minute.expr, second.expr)) -def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def translate(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces the characters in ``from_val`` with the counterpart in ``to_val``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.translate(dfn.col("a"), dfn.lit("helo"), - ... dfn.lit("HELO")).alias("t")) + ... dfn.functions.translate(dfn.col("a"), "helo", "HELO").alias("t")) >>> result.collect_column("t")[0].as_py() 'HELLO' """ + if not isinstance(from_val, Expr): + from_val = Expr.literal(from_val) + if not isinstance(to_val, Expr): + to_val = Expr.literal(to_val) return Expr(f.translate(string.expr, from_val.expr, to_val.expr)) @@ -2407,27 +2505,25 @@ def trim(arg: Expr) -> Expr: return Expr(f.trim(arg.expr)) -def trunc(num: Expr, precision: Expr | None = None) -> Expr: +def trunc(num: Expr, precision: Expr | int | None = None) -> Expr: """Truncate the number toward zero with optional precision. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a") - ... ).alias("t")) + ... dfn.functions.trunc(dfn.col("a")).alias("t")) >>> result.collect_column("t")[0].as_py() 1.0 >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a"), precision=dfn.lit(2) - ... ).alias("t")) + ... 
dfn.functions.trunc(dfn.col("a"), precision=2).alias("t")) >>> result.collect_column("t")[0].as_py() 1.56 """ if precision is not None: + if not isinstance(precision, Expr): + precision = Expr.literal(precision) return Expr(f.trunc(num.expr, precision.expr)) return Expr(f.trunc(num.expr)) @@ -2908,17 +3004,19 @@ def list_dims(array: Expr) -> Expr: return array_dims(array) -def array_element(array: Expr, n: Expr) -> Expr: +def array_element(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[10, 20, 30]]}) >>> result = df.select( - ... dfn.functions.array_element(dfn.col("a"), dfn.lit(2)).alias("result")) + ... dfn.functions.array_element(dfn.col("a"), 2).alias("result")) >>> result.collect_column("result")[0].as_py() 20 """ + if not isinstance(n, Expr): + n = Expr.literal(n) return Expr(f.array_element(array.expr, n.expr)) @@ -2944,7 +3042,7 @@ def list_empty(array: Expr) -> Expr: return array_empty(array) -def array_extract(array: Expr, n: Expr) -> Expr: +def array_extract(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -2953,7 +3051,7 @@ def array_extract(array: Expr, n: Expr) -> Expr: return array_element(array, n) -def list_element(array: Expr, n: Expr) -> Expr: +def list_element(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -2962,7 +3060,7 @@ def list_element(array: Expr, n: Expr) -> Expr: return array_element(array, n) -def list_extract(array: Expr, n: Expr) -> Expr: +def list_extract(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. 
See Also: @@ -3312,22 +3410,25 @@ def list_remove(array: Expr, element: Expr) -> Expr: return array_remove(array, element) -def array_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def array_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_remove_n(dfn.col("a"), dfn.lit(1), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_remove_n( + ... dfn.col("a"), dfn.lit(1), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 1] """ + if not isinstance(max, Expr): + max = Expr.literal(max) return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) -def list_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def list_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. See Also: @@ -3361,21 +3462,23 @@ def list_remove_all(array: Expr, element: Expr) -> Expr: return array_remove_all(array, element) -def array_repeat(element: Expr, count: Expr) -> Expr: +def array_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) >>> result = df.select( - ... dfn.functions.array_repeat(dfn.lit(3), dfn.lit(3)).alias("result")) + ... dfn.functions.array_repeat(dfn.lit(3), 3).alias("result")) >>> result.collect_column("result")[0].as_py() [3, 3, 3] """ + if not isinstance(count, Expr): + count = Expr.literal(count) return Expr(f.array_repeat(element.expr, count.expr)) -def list_repeat(element: Expr, count: Expr) -> Expr: +def list_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. 
See Also: @@ -3408,7 +3511,7 @@ def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: return array_replace(array, from_val, to_val) -def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3418,15 +3521,18 @@ def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Exp >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_replace_n(dfn.col("a"), dfn.lit(1), dfn.lit(9), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_replace_n( + ... dfn.col("a"), dfn.lit(1), dfn.lit(9), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [9, 2, 9, 1] """ + if not isinstance(max, Expr): + max = Expr.literal(max) return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) -def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3509,7 +3615,10 @@ def list_sort(array: Expr, descending: bool = False, null_first: bool = False) - def array_slice( - array: Expr, begin: Expr, end: Expr, stride: Expr | None = None + array: Expr, + begin: Expr | int, + end: Expr | int, + stride: Expr | int | None = None, ) -> Expr: """Returns a slice of the array. @@ -3517,9 +3626,7 @@ def array_slice( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) >>> result = df.select( - ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(2), dfn.lit(3) - ... ).alias("result")) + ... 
dfn.functions.array_slice(dfn.col("a"), 2, 3).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 3] @@ -3527,18 +3634,25 @@ def array_slice( >>> result = df.select( ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(1), dfn.lit(4), - ... stride=dfn.lit(2), + ... dfn.col("a"), 1, 4, stride=2, ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 3] """ + if not isinstance(begin, Expr): + begin = Expr.literal(begin) + if not isinstance(end, Expr): + end = Expr.literal(end) if stride is not None: + if not isinstance(stride, Expr): + stride = Expr.literal(stride) stride = stride.expr return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) -def list_slice(array: Expr, begin: Expr, end: Expr, stride: Expr | None = None) -> Expr: +def list_slice( + array: Expr, begin: Expr | int, end: Expr | int, stride: Expr | int | None = None +) -> Expr: """Returns a slice of the array. See Also: @@ -3630,7 +3744,7 @@ def list_except(array1: Expr, array2: Expr) -> Expr: return array_except(array1, array2) -def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def array_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will @@ -3640,15 +3754,16 @@ def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2]]}) >>> result = df.select( - ... dfn.functions.array_resize(dfn.col("a"), dfn.lit(4), - ... dfn.lit(0)).alias("result")) + ... 
dfn.functions.array_resize(dfn.col("a"), 4, dfn.lit(0)).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 2, 0, 0] """ + if not isinstance(size, Expr): + size = Expr.literal(size) return Expr(f.array_resize(array.expr, size.expr, value.expr)) -def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def list_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will be @@ -3802,7 +3917,7 @@ def list_zip(*arrays: Expr) -> Expr: def string_to_array( - string: Expr, delimiter: Expr, null_string: Expr | None = None + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None ) -> Expr: """Splits a string based on a delimiter and returns an array of parts. @@ -3812,9 +3927,7 @@ def string_to_array( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello,world"]}) >>> result = df.select( - ... dfn.functions.string_to_array( - ... dfn.col("a"), dfn.lit(","), - ... ).alias("result")) + ... dfn.functions.string_to_array(dfn.col("a"), ",").alias("result")) >>> result.collect_column("result")[0].as_py() ['hello', 'world'] @@ -3822,17 +3935,24 @@ def string_to_array( >>> result = df.select( ... dfn.functions.string_to_array( - ... dfn.col("a"), dfn.lit(","), null_string=dfn.lit("world"), + ... dfn.col("a"), ",", null_string="world", ... 
).alias("result")) >>> result.collect_column("result")[0].as_py() ['hello', None] """ - null_expr = null_string.expr if null_string is not None else None + if not isinstance(delimiter, Expr): + delimiter = Expr.literal(delimiter) + if null_string is not None: + if not isinstance(null_string, Expr): + null_string = Expr.literal(null_string) + null_expr = null_string.expr + else: + null_expr = None return Expr(f.string_to_array(string.expr, delimiter.expr, null_expr)) def string_to_list( - string: Expr, delimiter: Expr, null_string: Expr | None = None + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None ) -> Expr: """Splits a string based on a delimiter and returns an array of parts. From 3e50d86586a7e52c9f4f7247d63f727fb06af154 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 9 Apr 2026 11:49:20 -0400 Subject: [PATCH 02/11] fix: update alias function signatures to match pythonic primary functions Update instr and position (aliases of strpos) to accept Expr | str for the substring parameter, matching the updated primary function signature. Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index feedfa7ec..a9fbbf29d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1096,7 +1096,7 @@ def initcap(string: Expr) -> Expr: return Expr(f.initcap(string.expr)) -def instr(string: Expr, substring: Expr) -> Expr: +def instr(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. See Also: @@ -1422,7 +1422,7 @@ def pi() -> Expr: return Expr(f.pi()) -def position(string: Expr, substring: Expr) -> Expr: +def position(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. 
See Also: From 25e5ce34dedd932a0320eb0bd8b575a7dfa5c3c1 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 9 Apr 2026 11:50:19 -0400 Subject: [PATCH 03/11] docs: update make-pythonic skill to require alias type hint updates Alias functions that delegate to a primary function must have their type hints updated to match, even though coercion logic is only added to the primary. Added a new Step 3 to the implementation workflow for this. Co-Authored-By: Claude Opus 4.6 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index 87cc8751c..c65d4e2df 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -230,7 +230,11 @@ For each function being updated: 3. **Update the docstring** examples to use the simpler calling convention 4. **Preserve backward compatibility** — existing code using `Expr` must still work -### Step 3: Update Docstring Examples +### Step 3: Update Alias Type Hints + +After updating a primary function, find all alias functions that delegate to it (e.g., `instr` and `position` delegate to `strpos`). Update each alias's **parameter type hints** to match the primary function's new signature. Do not add coercion logic to aliases — the primary function handles that. 
+ +### Step 4: Update Docstring Examples (primary functions only) Per the project's CLAUDE.md rules: - Every function must have doctest-style examples @@ -247,7 +251,7 @@ dfn.functions.left(dfn.col("a"), dfn.lit(3)) dfn.functions.left(dfn.col("a"), 3) ``` -### Step 4: Run Tests +### Step 5: Run Tests After making changes, run the doctests to verify: ```bash @@ -270,7 +274,7 @@ Do NOT create a new helper function for this — the inline check is clear and e - **Do not change arguments that represent data columns.** If an argument is the primary data being operated on (e.g., the `string` in `left(string, n)` or the `array` in `array_sort(array)`), it should remain `Expr` only. Users should use `col()` for column references. - **Do not change variadic `*args: Expr` parameters.** These represent multiple expressions and should stay as `Expr`. - **Do not change arguments where the coercion is ambiguous.** If it is unclear whether a string should be a column name or a literal, leave it as `Expr` and let the user be explicit. -- **Do not change functions that are simple aliases.** If a function is just `return other_function(...)`, update the primary function only. +- **Do not add coercion logic to simple aliases.** If a function is just `return other_function(...)`, the primary function handles coercion. However, you **must update the alias's type hints** to match the primary function's signature so that type checkers and documentation accurately reflect what the alias accepts. - **Do not change the Rust bindings.** All coercion happens in the Python layer. The Rust functions continue to accept `PyExpr`. 
## Priority Order From 7515ebc029b6e5ae860f22dc0ef1e04bebbe99dc Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 9 Apr 2026 12:40:35 -0400 Subject: [PATCH 04/11] fix: address review feedback on pythonic skill and function signatures Update SKILL.md to prevent three classes of issues: clarify that float already accepts int per PEP 484 (avoiding redundant int | float that fails ruff PYI041), add backward-compat rule for Category B so existing Expr params aren't removed, and add guidance for inline coercion with many optional nullable params instead of local helpers. Replace regexp_instr's _to_raw() helper with inline coercion matching the pattern used throughout the rest of the file. Co-Authored-By: Claude Opus 4.6 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 48 ++++++++++++++++++++----------- python/datafusion/functions.py | 23 ++++++++------- 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index c65d4e2df..4eefbdb3b 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -104,6 +104,8 @@ Common mappings: | `logical_string()` / `TypeSignatureClass::String` | `str` | | `LogicalType::Boolean` | `bool` | +**Important:** In Python's type system (PEP 484), `float` already accepts `int` values, so `int | float` is redundant and will fail the `ruff` linter (rule PYI041). Use `float` alone when the Rust side accepts a float/numeric type — Python users can still pass integer literals like `log(10, col("a"))` or `power(col("a"), 3)` without issue. Only use `int` when the Rust side strictly requires an integer (e.g., `logical_int64()`). + #### Technique 3: Check `return_field_from_args()` for `scalar_arguments` usage Functions that inspect literal values at query planning time use `args.scalar_arguments.get(n)` in their `return_field_from_args()` method. 
This indicates the argument is **expected to be a literal** for optimal behavior (e.g., to determine output type precision), but may still work as a column. @@ -166,31 +168,30 @@ These are arguments where an `Expr` never makes sense because the value must be **Type hint pattern:** `str`, `int`, `list[str]`, etc. (no `Expr` in the union) -**When to use:** When the argument is from a fixed enumeration or is always a compile-time constant: -- Date/time part names (`"year"`, `"month"`, `"day"`, etc.) -- Regex flags (`"g"`, `"i"`, etc.) +**When to use:** When the argument is from a fixed enumeration or is always a compile-time constant, **AND** the parameter was not previously typed as `Expr`: +- Separator in `concat_ws` (already typed as `str` in the Rust binding) +- Index in `array_position` (already typed as `int` in the Rust binding) - Values that the Rust implementation already accepts as native types +**Backward compatibility rule:** If a parameter was previously typed as `Expr`, you **must** keep `Expr` in the union even if the Rust side requires a literal. Removing `Expr` would break existing user code like `date_part(lit("year"), col("a"))`. Use **Category A** instead — accept `Expr | str` — and let users who pass column expressions discover the runtime error from the Rust side. Never silently break backward compatibility. + ```python -def date_part(part: str, date: Expr) -> Expr: - """Extracts a subfield from the date. +def concat_ws(separator: str, *args: Expr) -> Expr: + """Concatenates the list ``args`` with the separator. - Args: - part: The part of the date to extract. Must be one of "year", "month", - "day", "hour", "minute", "second", etc. - date: The date expression to extract from. + ``separator`` is already typed as ``str`` in the Rust binding, so + there is no backward-compatibility concern. 
Examples: >>> ctx = dfn.SessionContext() - >>> df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) - >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) + >>> df = ctx.from_pydict({"a": ["hello"], "b": ["world"]}) >>> result = df.select( - ... dfn.functions.date_part("year", dfn.col("a")).alias("y")) - >>> result.collect_column("y")[0].as_py() - 2021 + ... dfn.functions.concat_ws("-", dfn.col("a"), dfn.col("b")).alias("c")) + >>> result.collect_column("c")[0].as_py() + 'hello-world' """ - part = Expr.literal(part) - return Expr(f.date_part(part.expr, date.expr)) + args = [arg.expr for arg in args] + return Expr(f.concat_ws(separator, args)) ``` ### Category C: Arguments That Should Accept str as Column Name @@ -269,6 +270,21 @@ if not isinstance(arg, Expr): Do NOT create a new helper function for this — the inline check is clear and explicit. The project intentionally uses `ensure_expr()` to reject non-Expr values in contexts where coercion is not wanted; the pythonic coercion is the opposite pattern and should be visually distinct. +**Functions with many optional nullable parameters:** For parameters typed as `Expr | | None`, combine the `None` check with the `isinstance` check inline. Repeat this for each parameter — do not factor it into a local helper function, even if the repetition feels verbose. Consistency across the file is more important than DRY within a single function. + +```python +# For each optional parameter: +if start is not None and not isinstance(start, Expr): + start = Expr.literal(start) + +# When passing to the Rust binding, extract .expr or pass None: +f.some_func( + values.expr, + start.expr if start is not None else None, + n.expr if n is not None else None, +) +``` + ## What NOT to Change - **Do not change arguments that represent data columns.** If an argument is the primary data being operated on (e.g., the `string` in `left(string, n)` or the `array` in `array_sort(array)`), it should remain `Expr` only. 
Users should use `col()` for column references. diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index a9fbbf29d..d2fcf8fa7 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1686,22 +1686,23 @@ def regexp_instr( """ if not isinstance(regex, Expr): regex = Expr.literal(regex) - - def _to_raw(val: Any) -> Any: - if val is None: - return None - if not isinstance(val, Expr): - val = Expr.literal(val) - return val.expr + if start is not None and not isinstance(start, Expr): + start = Expr.literal(start) + if n is not None and not isinstance(n, Expr): + n = Expr.literal(n) + if flags is not None and not isinstance(flags, Expr): + flags = Expr.literal(flags) + if sub_expr is not None and not isinstance(sub_expr, Expr): + sub_expr = Expr.literal(sub_expr) return Expr( f.regexp_instr( values.expr, regex.expr, - _to_raw(start), - _to_raw(n), - _to_raw(flags), - _to_raw(sub_expr), + start.expr if start is not None else None, + n.expr if n is not None else None, + flags.expr if flags is not None else None, + sub_expr.expr if sub_expr is not None else None, ) ) From 83ebb72b5adef880a0b02b61a7c4c1f17ef2208e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 07:14:01 -0400 Subject: [PATCH 05/11] refactor: add coerce_to_expr helpers and replace inline coercion patterns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce coerce_to_expr() and coerce_to_expr_or_none() in expr.py as the complement to ensure_expr() — where ensure_expr rejects non-Expr values, these helpers wrap them via Expr.literal(). Replaces ~60 inline isinstance checks in functions.py with single-line helper calls, and updates the make-pythonic skill to document the new pattern. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 48 ++++-- python/datafusion/expr.py | 41 +++++ python/datafusion/functions.py | 247 +++++++++++------------------- 3 files changed, 167 insertions(+), 169 deletions(-) diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index 4eefbdb3b..93fa9c2c9 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -261,28 +261,46 @@ python -m pytest --doctest-modules python/datafusion/functions.py -v ## Coercion Helper Pattern -When adding coercion to a function, use this inline pattern: +Use the coercion helpers from `datafusion.expr` to convert native Python values to `Expr`. These are the complement of `ensure_expr()` — where `ensure_expr` *rejects* non-`Expr` values, the coercion helpers *wrap* them via `Expr.literal()`. + +**For required parameters** use `coerce_to_expr`: ```python -if not isinstance(arg, Expr): - arg = Expr.literal(arg) +from datafusion.expr import coerce_to_expr + +def left(string: Expr, n: Expr | int) -> Expr: + n = coerce_to_expr(n) + return Expr(f.left(string.expr, n.expr)) ``` -Do NOT create a new helper function for this — the inline check is clear and explicit. The project intentionally uses `ensure_expr()` to reject non-Expr values in contexts where coercion is not wanted; the pythonic coercion is the opposite pattern and should be visually distinct. 
+**For optional nullable parameters** use `coerce_to_expr_or_none`: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none + +def regexp_count( + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, +) -> Expr: + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) +``` -**Functions with many optional nullable parameters:** For parameters typed as `Expr | | None`, combine the `None` check with the `isinstance` check inline. Repeat this for each parameter — do not factor it into a local helper function, even if the repetition feels verbose. Consistency across the file is more important than DRY within a single function. +Both helpers are defined in `python/datafusion/expr.py` alongside `ensure_expr`. Import them in `functions.py` via: ```python -# For each optional parameter: -if start is not None and not isinstance(start, Expr): - start = Expr.literal(start) - -# When passing to the Rust binding, extract .expr or pass None: -f.some_func( - values.expr, - start.expr if start is not None else None, - n.expr if n is not None else None, -) +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none ``` ## What NOT to Change diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 7cd74ecd5..e26bb85d8 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -221,6 +221,8 @@ "WindowExpr", "WindowFrame", "WindowFrameBound", + "coerce_to_expr", + "coerce_to_expr_or_none", "ensure_expr", "ensure_expr_list", ] @@ -233,6 +235,10 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr: higher level APIs consistently require explicit :func:`~datafusion.col` or :func:`~datafusion.lit` expressions. 
+ See Also: + :func:`coerce_to_expr` — the opposite behavior: *wraps* non-``Expr`` + values as literals instead of rejecting them. + Args: value: Candidate expression or other object. @@ -277,6 +283,41 @@ def _iter( return list(_iter(exprs)) +def coerce_to_expr(value: Any) -> Expr: + """Coerce a native Python value to an ``Expr`` literal, passing ``Expr`` through. + + This is the complement of :func:`ensure_expr`: where ``ensure_expr`` + *rejects* non-``Expr`` values, ``coerce_to_expr`` *wraps* them via + :meth:`Expr.literal` so that functions can accept native Python types + (``int``, ``float``, ``str``, ``bool``, etc.) alongside ``Expr``. + + Args: + value: An ``Expr`` instance (returned as-is) or a Python literal to wrap. + + Returns: + An ``Expr`` representing the value. + """ + if isinstance(value, Expr): + return value + return Expr.literal(value) + + +def coerce_to_expr_or_none(value: Any | None) -> Expr | None: + """Coerce a value to ``Expr`` or pass ``None`` through unchanged. + + Same as :func:`coerce_to_expr` but accepts ``None`` for optional parameters. + + Args: + value: An ``Expr`` instance, a Python literal to wrap, or ``None``. + + Returns: + An ``Expr`` representing the value, or ``None``. + """ + if value is None: + return None + return coerce_to_expr(value) + + def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. 
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index d2fcf8fa7..91dc16b8a 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -29,6 +29,8 @@ Expr, SortExpr, SortKey, + coerce_to_expr, + coerce_to_expr_or_none, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, sort_or_default, @@ -374,8 +376,7 @@ def encode(expr: Expr, encoding: Expr | str) -> Expr: >>> result.collect_column("enc")[0].as_py() 'aGVsbG8' """ - if not isinstance(encoding, Expr): - encoding = Expr.literal(encoding) + encoding = coerce_to_expr(encoding) return Expr(f.encode(expr.expr, encoding.expr)) @@ -390,8 +391,7 @@ def decode(expr: Expr, encoding: Expr | str) -> Expr: >>> result.collect_column("dec")[0].as_py() b'hello' """ - if not isinstance(encoding, Expr): - encoding = Expr.literal(encoding) + encoding = coerce_to_expr(encoding) return Expr(f.decode(expr.expr, encoding.expr)) @@ -406,8 +406,7 @@ def array_to_string(expr: Expr, delimiter: Expr | str) -> Expr: >>> result.collect_column("s")[0].as_py() '1,2,3' """ - if not isinstance(delimiter, Expr): - delimiter = Expr.literal(delimiter) + delimiter = coerce_to_expr(delimiter) return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string()))) @@ -479,8 +478,7 @@ def digest(value: Expr, method: Expr | str) -> Expr: >>> len(result.collect_column("d")[0].as_py()) > 0 True """ - if not isinstance(method, Expr): - method = Expr.literal(method) + method = coerce_to_expr(method) return Expr(f.digest(value.expr, method.expr)) @@ -495,8 +493,7 @@ def contains(string: Expr, search_str: Expr | str) -> Expr: >>> result.collect_column("c")[0].as_py() True """ - if not isinstance(search_str, Expr): - search_str = Expr.literal(search_str) + search_str = coerce_to_expr(search_str) return Expr(f.contains(string.expr, search_str.expr)) @@ -970,8 +967,7 @@ def ends_with(arg: Expr, suffix: Expr | str) -> Expr: >>> ends_with_df.collect_column("ends_with")[0].as_py() True """ - if not 
isinstance(suffix, Expr): - suffix = Expr.literal(suffix) + suffix = coerce_to_expr(suffix) return Expr(f.ends_with(arg.expr, suffix.expr)) @@ -1019,8 +1015,7 @@ def find_in_set(string: Expr, string_list: Expr | str) -> Expr: >>> result.collect_column("pos")[0].as_py() 2 """ - if not isinstance(string_list, Expr): - string_list = Expr.literal(string_list) + string_list = coerce_to_expr(string_list) return Expr(f.find_in_set(string.expr, string_list.expr)) @@ -1163,8 +1158,7 @@ def left(string: Expr, n: Expr | int) -> Expr: >>> left_df.collect_column("left")[0].as_py() 'the' """ - if not isinstance(n, Expr): - n = Expr.literal(n) + n = coerce_to_expr(n) return Expr(f.left(string.expr, n.expr)) @@ -1179,8 +1173,7 @@ def levenshtein(string1: Expr, string2: Expr | str) -> Expr: >>> result.collect_column("d")[0].as_py() 3 """ - if not isinstance(string2, Expr): - string2 = Expr.literal(string2) + string2 = coerce_to_expr(string2) return Expr(f.levenshtein(string1.expr, string2.expr)) @@ -1209,8 +1202,7 @@ def log(base: Expr | float, num: Expr) -> Expr: >>> result.collect_column("log")[0].as_py() 2.0 """ - if not isinstance(base, Expr): - base = Expr.literal(base) + base = coerce_to_expr(base) return Expr(f.log(base.expr, num.expr)) @@ -1277,12 +1269,8 @@ def lpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) >>> result.collect_column("lpad")[0].as_py() '...the cat' """ - if not isinstance(count, Expr): - count = Expr.literal(count) - if characters is None: - characters = Expr.literal(" ") - elif not isinstance(characters, Expr): - characters = Expr.literal(characters) + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.lpad(string.expr, count.expr, characters.expr)) @@ -1395,14 +1383,11 @@ def overlay( >>> result.collect_column("o")[0].as_py() 'abXYef' """ - if not isinstance(substring, Expr): - substring = Expr.literal(substring) - if not isinstance(start, Expr): - start = 
Expr.literal(start) + substring = coerce_to_expr(substring) + start = coerce_to_expr(start) if length is None: return Expr(f.overlay(string.expr, substring.expr, start.expr)) - if not isinstance(length, Expr): - length = Expr.literal(length) + length = coerce_to_expr(length) return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) @@ -1443,8 +1428,7 @@ def power(base: Expr, exponent: Expr | float) -> Expr: >>> result.collect_column("pow")[0].as_py() 8.0 """ - if not isinstance(exponent, Expr): - exponent = Expr.literal(exponent) + exponent = coerce_to_expr(exponent) return Expr(f.power(base.expr, exponent.expr)) @@ -1500,13 +1484,9 @@ def regexp_like( >>> result.collect_column("m")[0].as_py() True """ - if not isinstance(regex, Expr): - regex = Expr.literal(regex) - if flags is not None: - if not isinstance(flags, Expr): - flags = Expr.literal(flags) - flags = flags.expr - return Expr(f.regexp_like(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr(f.regexp_like(string.expr, regex.expr, flags.expr if flags else None)) def regexp_match( @@ -1536,13 +1516,9 @@ def regexp_match( >>> result.collect_column("m")[0].as_py() ['hello'] """ - if not isinstance(regex, Expr): - regex = Expr.literal(regex) - if flags is not None: - if not isinstance(flags, Expr): - flags = Expr.literal(flags) - flags = flags.expr - return Expr(f.regexp_match(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr(f.regexp_match(string.expr, regex.expr, flags.expr if flags else None)) def regexp_replace( @@ -1581,15 +1557,17 @@ def regexp_replace( >>> result.collect_column("r")[0].as_py() 'aX bX cX' """ - if not isinstance(pattern, Expr): - pattern = Expr.literal(pattern) - if not isinstance(replacement, Expr): - replacement = Expr.literal(replacement) - if flags is not None: - if not isinstance(flags, Expr): - flags = Expr.literal(flags) - flags 
= flags.expr - return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) + pattern = coerce_to_expr(pattern) + replacement = coerce_to_expr(replacement) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_replace( + string.expr, + pattern.expr, + replacement.expr, + flags.expr if flags else None, + ) + ) def regexp_count( @@ -1621,17 +1599,17 @@ def regexp_count( >>> result.collect_column("c")[0].as_py() 1 """ - if not isinstance(pattern, Expr): - pattern = Expr.literal(pattern) - if start is not None: - if not isinstance(start, Expr): - start = Expr.literal(start) - start = start.expr - if flags is not None: - if not isinstance(flags, Expr): - flags = Expr.literal(flags) - flags = flags.expr - return Expr(f.regexp_count(string.expr, pattern.expr, start, flags)) + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start else None, + flags.expr if flags else None, + ) + ) def regexp_instr( @@ -1684,16 +1662,11 @@ def regexp_instr( >>> result.collect_column("pos")[0].as_py() 1 """ - if not isinstance(regex, Expr): - regex = Expr.literal(regex) - if start is not None and not isinstance(start, Expr): - start = Expr.literal(start) - if n is not None and not isinstance(n, Expr): - n = Expr.literal(n) - if flags is not None and not isinstance(flags, Expr): - flags = Expr.literal(flags) - if sub_expr is not None and not isinstance(sub_expr, Expr): - sub_expr = Expr.literal(sub_expr) + regex = coerce_to_expr(regex) + start = coerce_to_expr_or_none(start) + n = coerce_to_expr_or_none(n) + flags = coerce_to_expr_or_none(flags) + sub_expr = coerce_to_expr_or_none(sub_expr) return Expr( f.regexp_instr( @@ -1718,8 +1691,7 @@ def repeat(string: Expr, n: Expr | int) -> Expr: >>> result.collect_column("r")[0].as_py() 'hahaha' """ - if not isinstance(n, Expr): - n = Expr.literal(n) + n = coerce_to_expr(n) 
return Expr(f.repeat(string.expr, n.expr)) @@ -1734,10 +1706,8 @@ def replace(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: >>> result.collect_column("r")[0].as_py() 'hello there' """ - if not isinstance(from_val, Expr): - from_val = Expr.literal(from_val) - if not isinstance(to_val, Expr): - to_val = Expr.literal(to_val) + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) @@ -1764,8 +1734,7 @@ def right(string: Expr, n: Expr | int) -> Expr: >>> result.collect_column("r")[0].as_py() 'llo' """ - if not isinstance(n, Expr): - n = Expr.literal(n) + n = coerce_to_expr(n) return Expr(f.right(string.expr, n.expr)) @@ -1783,10 +1752,7 @@ def round(value: Expr, decimal_places: Expr | int | None = None) -> Expr: >>> result.collect_column("r")[0].as_py() 1.57 """ - if decimal_places is None: - decimal_places = Expr.literal(0) - elif not isinstance(decimal_places, Expr): - decimal_places = Expr.literal(decimal_places) + decimal_places = coerce_to_expr(decimal_places if decimal_places is not None else 0) return Expr(f.round(value.expr, decimal_places.expr)) @@ -1804,12 +1770,8 @@ def rpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) >>> result.collect_column("r")[0].as_py() 'hi!!!' 
""" - if not isinstance(count, Expr): - count = Expr.literal(count) - if characters is None: - characters = Expr.literal(" ") - elif not isinstance(characters, Expr): - characters = Expr.literal(characters) + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.rpad(string.expr, count.expr, characters.expr)) @@ -1939,10 +1901,8 @@ def split_part(string: Expr, delimiter: Expr | str, index: Expr | int) -> Expr: >>> result.collect_column("s")[0].as_py() 'b' """ - if not isinstance(delimiter, Expr): - delimiter = Expr.literal(delimiter) - if not isinstance(index, Expr): - index = Expr.literal(index) + delimiter = coerce_to_expr(delimiter) + index = coerce_to_expr(index) return Expr(f.split_part(string.expr, delimiter.expr, index.expr)) @@ -1970,8 +1930,7 @@ def starts_with(string: Expr, prefix: Expr | str) -> Expr: >>> result.collect_column("sw")[0].as_py() True """ - if not isinstance(prefix, Expr): - prefix = Expr.literal(prefix) + prefix = coerce_to_expr(prefix) return Expr(f.starts_with(string.expr, prefix.expr)) @@ -1986,8 +1945,7 @@ def strpos(string: Expr, substring: Expr | str) -> Expr: >>> result.collect_column("pos")[0].as_py() 3 """ - if not isinstance(substring, Expr): - substring = Expr.literal(substring) + substring = coerce_to_expr(substring) return Expr(f.strpos(string.expr, substring.expr)) @@ -2002,8 +1960,7 @@ def substr(string: Expr, position: Expr | int) -> Expr: >>> result.collect_column("s")[0].as_py() 'llo' """ - if not isinstance(position, Expr): - position = Expr.literal(position) + position = coerce_to_expr(position) return Expr(f.substr(string.expr, position.expr)) @@ -2021,10 +1978,8 @@ def substr_index(string: Expr, delimiter: Expr | str, count: Expr | int) -> Expr >>> result.collect_column("s")[0].as_py() 'a.b' """ - if not isinstance(delimiter, Expr): - delimiter = Expr.literal(delimiter) - if not isinstance(count, Expr): - count = Expr.literal(count) + delimiter = 
coerce_to_expr(delimiter) + count = coerce_to_expr(count) return Expr(f.substr_index(string.expr, delimiter.expr, count.expr)) @@ -2039,10 +1994,8 @@ def substring(string: Expr, position: Expr | int, length: Expr | int) -> Expr: >>> result.collect_column("s")[0].as_py() 'hello' """ - if not isinstance(position, Expr): - position = Expr.literal(position) - if not isinstance(length, Expr): - length = Expr.literal(length) + position = coerce_to_expr(position) + length = coerce_to_expr(length) return Expr(f.substring(string.expr, position.expr, length.expr)) @@ -2134,8 +2087,7 @@ def to_char(arg: Expr, formatter: Expr | str) -> Expr: >>> result.collect_column("formatted")[0].as_py() '2021/01/01' """ - if not isinstance(formatter, Expr): - formatter = Expr.literal(formatter) + formatter = coerce_to_expr(formatter) return Expr(f.to_char(arg.expr, formatter.expr)) @@ -2377,8 +2329,7 @@ def date_part(part: Expr | str, date: Expr) -> Expr: >>> result.collect_column("y")[0].as_py() 2021 """ - if not isinstance(part, Expr): - part = Expr.literal(part) + part = coerce_to_expr(part) return Expr(f.date_part(part.expr, date.expr)) @@ -2409,8 +2360,7 @@ def date_trunc(part: Expr | str, date: Expr) -> Expr: >>> str(result.collect_column("t")[0].as_py()) '2021-07-01 00:00:00' """ - if not isinstance(part, Expr): - part = Expr.literal(part) + part = coerce_to_expr(part) return Expr(f.date_trunc(part.expr, date.expr)) @@ -2486,10 +2436,8 @@ def translate(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: >>> result.collect_column("t")[0].as_py() 'HELLO' """ - if not isinstance(from_val, Expr): - from_val = Expr.literal(from_val) - if not isinstance(to_val, Expr): - to_val = Expr.literal(to_val) + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.translate(string.expr, from_val.expr, to_val.expr)) @@ -2523,8 +2471,7 @@ def trunc(num: Expr, precision: Expr | int | None = None) -> Expr: 1.56 """ if precision is not None: - if not 
isinstance(precision, Expr): - precision = Expr.literal(precision) + precision = coerce_to_expr(precision) return Expr(f.trunc(num.expr, precision.expr)) return Expr(f.trunc(num.expr)) @@ -3016,8 +2963,7 @@ def array_element(array: Expr, n: Expr | int) -> Expr: >>> result.collect_column("result")[0].as_py() 20 """ - if not isinstance(n, Expr): - n = Expr.literal(n) + n = coerce_to_expr(n) return Expr(f.array_element(array.expr, n.expr)) @@ -3424,8 +3370,7 @@ def array_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: >>> result.collect_column("result")[0].as_py() [2, 1] """ - if not isinstance(max, Expr): - max = Expr.literal(max) + max = coerce_to_expr(max) return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) @@ -3474,8 +3419,7 @@ def array_repeat(element: Expr, count: Expr | int) -> Expr: >>> result.collect_column("result")[0].as_py() [3, 3, 3] """ - if not isinstance(count, Expr): - count = Expr.literal(count) + count = coerce_to_expr(count) return Expr(f.array_repeat(element.expr, count.expr)) @@ -3528,8 +3472,7 @@ def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) >>> result.collect_column("result")[0].as_py() [9, 2, 9, 1] """ - if not isinstance(max, Expr): - max = Expr.literal(max) + max = coerce_to_expr(max) return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) @@ -3640,15 +3583,12 @@ def array_slice( >>> result.collect_column("result")[0].as_py() [1, 3] """ - if not isinstance(begin, Expr): - begin = Expr.literal(begin) - if not isinstance(end, Expr): - end = Expr.literal(end) - if stride is not None: - if not isinstance(stride, Expr): - stride = Expr.literal(stride) - stride = stride.expr - return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) + begin = coerce_to_expr(begin) + end = coerce_to_expr(end) + stride = coerce_to_expr_or_none(stride) + return Expr( + f.array_slice(array.expr, begin.expr, end.expr, stride.expr if stride else None) + ) def list_slice( 
@@ -3759,8 +3699,7 @@ def array_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: >>> result.collect_column("result")[0].as_py() [1, 2, 0, 0] """ - if not isinstance(size, Expr): - size = Expr.literal(size) + size = coerce_to_expr(size) return Expr(f.array_resize(array.expr, size.expr, value.expr)) @@ -3941,15 +3880,15 @@ def string_to_array( >>> result.collect_column("result")[0].as_py() ['hello', None] """ - if not isinstance(delimiter, Expr): - delimiter = Expr.literal(delimiter) - if null_string is not None: - if not isinstance(null_string, Expr): - null_string = Expr.literal(null_string) - null_expr = null_string.expr - else: - null_expr = None - return Expr(f.string_to_array(string.expr, delimiter.expr, null_expr)) + delimiter = coerce_to_expr(delimiter) + null_string = coerce_to_expr_or_none(null_string) + return Expr( + f.string_to_array( + string.expr, + delimiter.expr, + null_string.expr if null_string else None, + ) + ) def string_to_list( From b2455152df7ac5f3db0cf284b46f9edc34be Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 07:36:46 -0400 Subject: [PATCH 06/11] docs: add aggregate function literal detection to make-pythonic skill Add Technique 1a to detect literal-only arguments in aggregate functions. Unlike scalar UDFs which enforce literals in invoke_with_args(), aggregate functions enforce them in accumulator() via get_scalar_value(), validate_percentile_expr(), or downcast_ref::<Literal>(). Without this technique, the skill would incorrectly classify arguments like approx_percentile_cont's percentile as Category A (Expr | float) when they should be Category B (float only). Updates the decision flow to branch on scalar vs aggregate before checking for literal enforcement.
Co-Authored-By: Claude Opus 4.6 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 72 ++++++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index 93fa9c2c9..78279dc9c 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -58,7 +58,13 @@ Check the Rust binding in `crates/core/src/functions.rs` and the upstream DataFu Check the DataFusion version in `crates/core/Cargo.toml` to find the right directory. Key subdirectories: `string/`, `datetime/`, `math/`, `regex/`. -There are three concrete techniques to check, in order of signal strength: +For **aggregate functions**, the upstream source is in a separate crate: + +``` +~/.cargo/registry/src/index.crates.io-*/datafusion-functions-aggregate-/src/ +``` + +There are four concrete techniques to check, in order of signal strength: #### Technique 1: Check `invoke_with_args()` for literal-only enforcement (strongest signal) @@ -75,6 +81,48 @@ let granularity_str = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = **If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type (e.g., `str`), not `Expr`. The function will error at runtime with a column expression anyway. +#### Technique 1a: Check `accumulator()` for literal-only enforcement (aggregate functions) + +Technique 1 applies to scalar UDFs. Aggregate functions do not have `invoke_with_args()` — instead, they enforce literal-only arguments in their `accumulator()` (or `create_accumulator()`) method, which runs at planning time before any data is processed. 
+ +Look for these patterns inside `accumulator()`: + +- `get_scalar_value(expr)` — evaluates the expression against an empty batch and errors if it's not a scalar +- `validate_percentile_expr(expr)` — specific helper used by percentile functions +- `downcast_ref::()` — checks that the physical expression is a literal constant + +Example from `approx_percentile_cont.rs`: +```rust +fn accumulator(&self, args: AccumulatorArgs) -> Result { + let percentile = + validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?; + // ... +} +``` + +Where `validate_percentile_expr` calls `get_scalar_value` and errors with `"must be a literal"`. + +Example from `string_agg.rs`: +```rust +fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() else { + return not_impl_err!( + "The second argument of the string_agg function must be a string literal" + ); + }; + // ... +} +``` + +**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression. + +Known aggregate functions with literal-only arguments: +- `approx_percentile_cont` — `percentile` (float), `num_centroids` (int) +- `approx_percentile_cont_with_weight` — `percentile` (float), `num_centroids` (int) +- `percentile_cont` — `percentile` (float) +- `string_agg` — `delimiter` (str) +- `nth_value` — `n` (int) + #### Technique 2: Check the `Signature` for data type constraints Each function defines a `Signature::coercible(...)` that specifies what data types each argument accepts, using `Coercion` entries. This tells you the expected **data type** even if it doesn't enforce literal-only. @@ -125,12 +173,22 @@ let decimal_places: Option = match args.scalar_arguments.get(1) { #### Decision flow ``` -Is argument rejected at runtime if not a literal? - (check invoke_with_args for ColumnarValue::Scalar-only match + exec_err!) 
- → YES: Category B — accept only native type, no Expr - → NO: Does the Signature constrain it to a specific data type? - → YES: Category A — accept Expr | - → NO: Leave as Expr only +Is the function a scalar UDF or an aggregate? + Scalar UDF: + Is argument rejected at runtime if not a literal? + (check invoke_with_args for ColumnarValue::Scalar-only match + exec_err!) + → YES: Category B — accept only native type, no Expr + → NO: continue below + Aggregate: + Is argument rejected at planning time if not a literal? + (check accumulator() for get_scalar_value / validate_percentile_expr / + downcast_ref::() + error) + → YES: Category B — accept only native type, no Expr + → NO: continue below + +Does the Signature constrain it to a specific data type? + → YES: Category A — accept Expr | + → NO: Leave as Expr only ``` ## Coercion Categories From 05f9ab91a3ef7aca59145a7027a87959cf449353 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 07:40:36 -0400 Subject: [PATCH 07/11] docs: add window function literal detection to make-pythonic skill Add Technique 1b to detect literal-only arguments in window functions. Window functions enforce literals in partition_evaluator() via get_scalar_value_from_args() / downcast_ref::(), not in invoke_with_args() (scalar) or accumulator() (aggregate). Updates the decision flow to branch on scalar vs aggregate vs window. Known window functions with literal-only arguments: ntile (n), lead/lag (offset, default_value), nth_value (n). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 45 +++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index 78279dc9c..a1bcb5895 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -64,7 +64,7 @@ For **aggregate functions**, the upstream source is in a separate crate: ~/.cargo/registry/src/index.crates.io-*/datafusion-functions-aggregate-/src/ ``` -There are four concrete techniques to check, in order of signal strength: +There are five concrete techniques to check, in order of signal strength: #### Technique 1: Check `invoke_with_args()` for literal-only enforcement (strongest signal) @@ -123,6 +123,41 @@ Known aggregate functions with literal-only arguments: - `string_agg` — `delimiter` (str) - `nth_value` — `n` (int) +#### Technique 1b: Check `partition_evaluator()` for literal-only enforcement (window functions) + +Window functions do not have `invoke_with_args()` or `accumulator()`. Instead, they enforce literal-only arguments in their `partition_evaluator()` method, which constructs the evaluator that processes each partition. + +The upstream source is in a separate crate: + +``` +~/.cargo/registry/src/index.crates.io-*/datafusion-functions-window-<version>/src/ +``` + +Look for `get_scalar_value_from_args()` calls inside `partition_evaluator()`. This helper (defined in the window crate's `utils.rs`) calls `downcast_ref::<Literal>()` and errors with `"There is only support Literal types for field at idx: {index} in Window Function"`. + +Example from `ntile.rs`: +```rust +fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, +) -> Result<Box<dyn PartitionEvaluator>> { + let scalar_n = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)? + .ok_or_else(|| { + exec_datafusion_err!("NTILE requires a positive integer") + })?; + // ...
+} +``` + +**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression. + +Known window functions with literal-only arguments: +- `ntile` — `n` (int) +- `lead` — `offset` (int), `default_value` (scalar) +- `lag` — `offset` (int), `default_value` (scalar) +- `nth_value` — `n` (int) + #### Technique 2: Check the `Signature` for data type constraints Each function defines a `Signature::coercible(...)` that specifies what data types each argument accepts, using `Coercion` entries. This tells you the expected **data type** even if it doesn't enforce literal-only. @@ -173,7 +208,7 @@ let decimal_places: Option = match args.scalar_arguments.get(1) { #### Decision flow ``` -Is the function a scalar UDF or an aggregate? +What kind of function is this? Scalar UDF: Is argument rejected at runtime if not a literal? (check invoke_with_args for ColumnarValue::Scalar-only match + exec_err!) @@ -185,6 +220,12 @@ Is the function a scalar UDF or an aggregate? downcast_ref::() + error) → YES: Category B — accept only native type, no Expr → NO: continue below + Window: + Is argument rejected at planning time if not a literal? + (check partition_evaluator() for get_scalar_value_from_args / + downcast_ref::() + error) + → YES: Category B — accept only native type, no Expr + → NO: continue below Does the Signature constrain it to a specific data type? → YES: Category A — accept Expr | From f6bed117abe2f85c8edbd31263a23b1ef63ea312 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 08:38:46 -0400 Subject: [PATCH 08/11] fix: use explicit None checks, widen numeric type hints, and add tests Replace 7 fragile truthiness checks (x.expr if x else None) with explicit is not None checks to prevent silent None when zero-valued literals are passed. Widen log/power/pow type hints to Expr | int | float with noqa: PYI041 for clarity. 
Add unit tests for coerce_to_expr helpers and integration tests for pythonic calling conventions. Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 33 ++++++++---- python/tests/test_expr.py | 45 ++++++++++++++++ python/tests/test_functions.py | 93 ++++++++++++++++++++++++++++++++++ 3 files changed, 161 insertions(+), 10 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 91dc16b8a..162e70c95 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1190,7 +1190,7 @@ def ln(arg: Expr) -> Expr: return Expr(f.ln(arg.expr)) -def log(base: Expr | float, num: Expr) -> Expr: +def log(base: Expr | int | float, num: Expr) -> Expr: # noqa: PYI041 """Returns the logarithm of a number for a particular ``base``. Examples: @@ -1416,7 +1416,7 @@ def position(string: Expr, substring: Expr | str) -> Expr: return strpos(string, substring) -def power(base: Expr, exponent: Expr | float) -> Expr: +def power(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. Examples: @@ -1432,7 +1432,7 @@ def power(base: Expr, exponent: Expr | float) -> Expr: return Expr(f.power(base.expr, exponent.expr)) -def pow(base: Expr, exponent: Expr | float) -> Expr: +def pow(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. 
See Also: @@ -1486,7 +1486,11 @@ def regexp_like( """ regex = coerce_to_expr(regex) flags = coerce_to_expr_or_none(flags) - return Expr(f.regexp_like(string.expr, regex.expr, flags.expr if flags else None)) + return Expr( + f.regexp_like( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) def regexp_match( @@ -1518,7 +1522,11 @@ def regexp_match( """ regex = coerce_to_expr(regex) flags = coerce_to_expr_or_none(flags) - return Expr(f.regexp_match(string.expr, regex.expr, flags.expr if flags else None)) + return Expr( + f.regexp_match( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) def regexp_replace( @@ -1565,7 +1573,7 @@ def regexp_replace( string.expr, pattern.expr, replacement.expr, - flags.expr if flags else None, + flags.expr if flags is not None else None, ) ) @@ -1606,8 +1614,8 @@ def regexp_count( f.regexp_count( string.expr, pattern.expr, - start.expr if start else None, - flags.expr if flags else None, + start.expr if start is not None else None, + flags.expr if flags is not None else None, ) ) @@ -3587,7 +3595,12 @@ def array_slice( end = coerce_to_expr(end) stride = coerce_to_expr_or_none(stride) return Expr( - f.array_slice(array.expr, begin.expr, end.expr, stride.expr if stride else None) + f.array_slice( + array.expr, + begin.expr, + end.expr, + stride.expr if stride is not None else None, + ) ) @@ -3886,7 +3899,7 @@ def string_to_array( f.string_to_array( string.expr, delimiter.expr, - null_string.expr if null_string else None, + null_string.expr if null_string is not None else None, ) ) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 1cf824a15..3ff87ca52 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -53,6 +53,8 @@ TransactionEnd, TransactionStart, Values, + coerce_to_expr, + coerce_to_expr_or_none, ensure_expr, ensure_expr_list, ) @@ -1026,6 +1028,49 @@ def test_ensure_expr_list_bytearray(): ensure_expr_list(bytearray(b"a")) +def 
test_coerce_to_expr_passes_expr_through(): + e = col("a") + result = coerce_to_expr(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + +def test_coerce_to_expr_wraps_int(): + result = coerce_to_expr(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_wraps_str(): + result = coerce_to_expr("hello") + assert isinstance(result, type(lit("hello"))) + + +def test_coerce_to_expr_wraps_float(): + result = coerce_to_expr(3.14) + assert isinstance(result, type(lit(3.14))) + + +def test_coerce_to_expr_wraps_bool(): + result = coerce_to_expr(True) # noqa: FBT003 + assert isinstance(result, type(lit(True))) + + +def test_coerce_to_expr_or_none_returns_none(): + assert coerce_to_expr_or_none(None) is None + + +def test_coerce_to_expr_or_none_wraps_value(): + result = coerce_to_expr_or_none(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_or_none_passes_expr_through(): + e = col("a") + result = coerce_to_expr_or_none(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + @pytest.mark.parametrize( "value", [ diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 11e94af1c..d9781b1fb 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -2099,3 +2099,96 @@ def test_gen_series_with_step(): f.gen_series(literal(1), literal(10), literal(3)).alias("v") ).collect() assert result[0].column(0)[0].as_py() == [1, 4, 7, 10] + + +class TestPythonicNativeTypes: + """Tests for accepting native Python types instead of requiring lit().""" + + def test_split_part_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select(f.split_part(column("a"), ",", 2).alias("s")).collect() + assert result[0].column(0)[0].as_py() == "b" + + def test_encode_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello"]}) + result = df.select(f.encode(column("a"), "base64").alias("e")).collect() + 
assert result[0].column(0)[0].as_py() == "aGVsbG8" + + def test_date_part_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_part("year", column("a")).alias("y")).collect() + assert result[0].column(0)[0].as_py() == 2021 + + def test_date_trunc_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_trunc("month", column("a")).alias("t")).collect() + assert str(result[0].column(0)[0].as_py()) == "2021-07-01 00:00:00" + + def test_left_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["the cat"]}) + result = df.select(f.left(column("a"), 3).alias("l")).collect() + assert result[0].column(0)[0].as_py() == "the" + + def test_round_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.567]}) + result = df.select(f.round(column("a"), 2).alias("r")).collect() + assert result[0].column(0)[0].as_py() == 1.57 + + def test_regexp_count_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["abcabc"]}) + result = df.select( + f.regexp_count(column("a"), "abc", start=4, flags="i").alias("c") + ).collect() + assert result[0].column(0)[0].as_py() == 1 + + def test_log_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [100.0]}) + result = df.select(f.log(10, column("a")).alias("l")).collect() + assert result[0].column(0)[0].as_py() == 2.0 + + def test_power_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [2.0]}) + result = df.select(f.power(column("a"), 3).alias("p")).collect() + assert result[0].column(0)[0].as_py() == 8.0 + + def test_array_slice_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) + result = df.select(f.array_slice(column("a"), 2, 3).alias("s")).collect() + assert 
result[0].column(0)[0].as_py() == [2, 3] + + def test_string_to_array_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,NA,world"]}) + result = df.select( + f.string_to_array(column("a"), ",", null_string="NA").alias("v") + ).collect() + assert result[0].column(0)[0].as_py() == ["hello", None, "world"] + + def test_regexp_replace_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a1 b2 c3"]}) + result = df.select( + f.regexp_replace(column("a"), r"\d+", "X", flags="g").alias("r") + ).collect() + assert result[0].column(0)[0].as_py() == "aX bX cX" + + def test_backward_compat_with_lit(self): + """Verify that existing code using lit() still works.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select( + f.split_part(column("a"), literal(","), literal(2)).alias("s") + ).collect() + assert result[0].column(0)[0].as_py() == "b" From 4fede05d4e5a8867b6d43560a363ed967eb21fc2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 08:43:38 -0400 Subject: [PATCH 09/11] chore: suppress FBT003 in tests and remove redundant noqa comments Add FBT003 (boolean positional value) to the per-file-ignores for python/tests/* in pyproject.toml, and remove the 6 now-redundant inline noqa: FBT003 comments across test_expr.py and test_context.py. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 1 + python/tests/test_context.py | 6 +++--- python/tests/test_expr.py | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 327199d1a..951f7adc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"] "ARG", "BLE001", "D", + "FBT003", "PD", "PLC0415", "PLR0913", diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 13c05a9e6..e0ebdbae5 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -964,12 +964,12 @@ def test_csv_read_options_builder_pattern(): options = ( CsvReadOptions() - .with_has_header(False) # noqa: FBT003 + .with_has_header(False) .with_delimiter("|") .with_quote("'") .with_schema_infer_max_records(2000) - .with_truncated_rows(True) # noqa: FBT003 - .with_newlines_in_values(True) # noqa: FBT003 + .with_truncated_rows(True) + .with_newlines_in_values(True) .with_file_extension(".tsv") ) assert options.has_header is False diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 3ff87ca52..0361f4c79 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -1051,7 +1051,7 @@ def test_coerce_to_expr_wraps_float(): def test_coerce_to_expr_wraps_bool(): - result = coerce_to_expr(True) # noqa: FBT003 + result = coerce_to_expr(True) assert isinstance(result, type(lit(True))) @@ -1075,8 +1075,8 @@ def test_coerce_to_expr_or_none_passes_expr_through(): "value", [ # Boolean - pa.scalar(True, type=pa.bool_()), # noqa: FBT003 - pa.scalar(False, type=pa.bool_()), # noqa: FBT003 + pa.scalar(True, type=pa.bool_()), + pa.scalar(False, type=pa.bool_()), # Integers - signed pa.scalar(127, type=pa.int8()), pa.scalar(-128, type=pa.int8()), From fca012281eb73d351ba0cafd67cb0d99a6cf511a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 08:50:24 -0400 Subject: [PATCH 10/11] docs: replace static 
function lists with discovery instructions in skill Replace hardcoded "Known aggregate/window functions with literal-only arguments" lists with instructions to discover them dynamically by searching the upstream crate source. Keeps a few examples as validation anchors so the agent knows its search is working correctly. Co-Authored-By: Claude Opus 4.6 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index a1bcb5895..57145ac6c 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -116,12 +116,7 @@ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> **If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression. -Known aggregate functions with literal-only arguments: -- `approx_percentile_cont` — `percentile` (float), `num_centroids` (int) -- `approx_percentile_cont_with_weight` — `percentile` (float), `num_centroids` (int) -- `percentile_cont` — `percentile` (float) -- `string_agg` — `delimiter` (str) -- `nth_value` — `n` (int) +To discover which aggregate functions have literal-only arguments, search the upstream aggregate crate for `get_scalar_value`, `validate_percentile_expr`, and `downcast_ref::()` inside `accumulator()` methods. For example, you should expect to find `approx_percentile_cont` (percentile) and `string_agg` (delimiter) among the results. #### Technique 1b: Check `partition_evaluator()` for literal-only enforcement (window functions) @@ -152,11 +147,7 @@ fn partition_evaluator( **If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression. 
-Known window functions with literal-only arguments: -- `ntile` — `n` (int) -- `lead` — `offset` (int), `default_value` (scalar) -- `lag` — `offset` (int), `default_value` (scalar) -- `nth_value` — `n` (int) +To discover which window functions have literal-only arguments, search the upstream window crate for `get_scalar_value_from_args` inside `partition_evaluator()` methods. For example, you should expect to find `ntile` (n) and `lead`/`lag` (offset, default_value) among the results. #### Technique 2: Check the `Signature` for data type constraints From e36eff29a6afb65df4aa2cf5a700f71887f2ca1b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 12:28:58 -0400 Subject: [PATCH 11/11] fix: make interrupt test reliable on Python 3.11 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PyThreadState_SetAsyncExc only delivers exceptions when the thread is executing Python bytecode, not while in native (Rust/C) code. The previous test had two issues causing flakiness on Python 3.11: 1. The interrupt fired before df.collect() entered the UDF, while the thread was still in native code where async exceptions are ignored. 2. time.sleep(2.0) is a single C call where async exceptions are not checked — they're only checked between bytecode instructions. Fix by adding a threading.Event so the interrupt waits until the UDF is actually executing Python code, and by sleeping in small increments so the eval loop has opportunities to check for pending exceptions. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- python/tests/test_dataframe.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index bb8e9685c..2b82e7ef9 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -3416,10 +3416,18 @@ def test_fill_null_all_null_column(ctx): assert result.column(1).to_pylist() == ["filled", "filled", "filled"] +_slow_udf_started = threading.Event() + + @udf([pa.int64()], pa.int64(), "immutable") def slow_udf(x: pa.Array) -> pa.Array: - # This must be longer than the check interval in wait_for_future - time.sleep(2.0) + _slow_udf_started.set() + # Sleep in small increments so Python's eval loop checks for pending + # async exceptions (like KeyboardInterrupt via PyThreadState_SetAsyncExc) + # between iterations. A single long time.sleep() is a C call where async + # exceptions are not checked on all Python versions (notably 3.11). + for _ in range(200): + time.sleep(0.01) return x @@ -3453,6 +3461,7 @@ def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 P if as_c_stream: reader = pa.RecordBatchReader.from_stream(df) + _slow_udf_started.clear() read_started = threading.Event() read_exception = [] read_thread_id = None @@ -3464,6 +3473,14 @@ def trigger_interrupt(): msg = f"Read operation did not start within {max_wait_time} seconds" raise RuntimeError(msg) + # For slow_query tests, wait until the UDF is actually executing Python + # bytecode before sending the interrupt. PyThreadState_SetAsyncExc only + # delivers exceptions when the thread is in the Python eval loop, not + # while in native (Rust/C) code. + if slow_query and not _slow_udf_started.wait(timeout=max_wait_time): + msg = f"UDF did not start within {max_wait_time} seconds" + raise RuntimeError(msg) + if read_thread_id is None: msg = "Cannot get read thread ID" raise RuntimeError(msg)