Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .ai/skills/check-upstream/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,18 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all"
- Python API: `python/datafusion/functions.py` (aggregate functions are mixed in with scalar functions)
- Rust bindings: `crates/core/src/functions.rs`

**Evaluated and not requiring separate Python exposure:**
- `count_distinct` — covered by `count(expr, distinct=True)`. Both forms call
`count_udaf` with `distinct: bool = true` and produce the same logical plan.
- `sum_distinct` — covered by `sum(expr, distinct=True)`.
- `avg_distinct` — covered by `avg(expr, distinct=True)`.

**How to check:**
1. Fetch the upstream aggregate function documentation page
2. Compare against aggregate functions in `python/datafusion/functions.py` (check `__all__` list and function definitions)
3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding
4. Report only functions missing from the Python API
4. Check against the "evaluated and not requiring exposure" list before flagging as a gap
5. Report only functions missing from the Python API

### 3. Window Functions

Expand Down
30 changes: 26 additions & 4 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4521,17 +4521,20 @@ def grouping(

def avg(
expression: Expr,
distinct: bool = False,
filter: Expr | None = None,
) -> Expr:
"""Returns the average value.

This aggregate function expects a numeric expression and will return a float.

If using the builder functions described in ref:`_aggregation` this function ignores
the options ``order_by``, ``null_treatment``, and ``distinct``.
the options ``order_by`` and ``null_treatment``.

Args:
expression: Values to combine into an array
distinct: If True, only distinct values are averaged. Equivalent to the
upstream ``avg_distinct`` shortcut.
filter: If provided, only compute against rows for which the filter is True

Examples:
Expand All @@ -4551,9 +4554,17 @@ def avg(
... ).alias("v")])
>>> result.collect_column("v")[0].as_py()
2.5

>>> df = ctx.from_pydict({"a": [1.0, 1.0, 2.0, 3.0]})
>>> result = df.aggregate(
... [], [dfn.functions.avg(
... dfn.col("a"), distinct=True,
... ).alias("v")])
>>> result.collect_column("v")[0].as_py()
2.0
"""
filter_raw = filter.expr if filter is not None else None
return Expr(f.avg(expression.expr, filter=filter_raw))
return Expr(f.avg(expression.expr, distinct=distinct, filter=filter_raw))


def corr(value_y: Expr, value_x: Expr, filter: Expr | None = None) -> Expr:
Expand Down Expand Up @@ -4838,17 +4849,20 @@ def min(expression: Expr, filter: Expr | None = None) -> Expr:

def sum(
expression: Expr,
distinct: bool = False,
filter: Expr | None = None,
) -> Expr:
"""Computes the sum of a set of numbers.

This aggregate function expects a numeric expression.

If using the builder functions described in ref:`_aggregation` this function ignores
the options ``order_by``, ``null_treatment``, and ``distinct``.
the options ``order_by`` and ``null_treatment``.

Args:
expression: Values to combine into an array
distinct: If True, only distinct values are summed. Equivalent to the
upstream ``sum_distinct`` shortcut.
filter: If provided, only compute against rows for which the filter is True

Examples:
Expand All @@ -4868,9 +4882,17 @@ def sum(
... ).alias("v")])
>>> result.collect_column("v")[0].as_py()
5

>>> df = ctx.from_pydict({"a": [1, 1, 2, 3]})
>>> result = df.aggregate(
... [], [dfn.functions.sum(
... dfn.col("a"), distinct=True,
... ).alias("v")])
>>> result.collect_column("v")[0].as_py()
6
"""
filter_raw = filter.expr if filter is not None else None
return Expr(f.sum(expression.expr, filter=filter_raw))
return Expr(f.sum(expression.expr, distinct=distinct, filter=filter_raw))


def stddev(expression: Expr, filter: Expr | None = None) -> Expr:
Expand Down
30 changes: 30 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,36 @@ def test_get_field(df):
assert result.column(1) == pa.array([4, 5, 6])


def test_sum_distinct_kwarg():
ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 1, 2, 3]})
distinct = (
df.aggregate([], [f.sum(column("a"), distinct=True).alias("v")])
.collect_column("v")[0]
.as_py()
)
total = (
df.aggregate([], [f.sum(column("a")).alias("v")]).collect_column("v")[0].as_py()
)
assert distinct == 6
assert total == 7


def test_avg_distinct_kwarg():
ctx = SessionContext()
df = ctx.from_pydict({"a": [1.0, 1.0, 2.0, 3.0]})
distinct = (
df.aggregate([], [f.avg(column("a"), distinct=True).alias("v")])
.collect_column("v")[0]
.as_py()
)
mean = (
df.aggregate([], [f.avg(column("a")).alias("v")]).collect_column("v")[0].as_py()
)
assert distinct == 2.0
assert mean == 1.75


def test_arrow_metadata():
ctx = SessionContext()
field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})
Expand Down
Loading