Skip to content

Commit 1d576a8

Browse files
authored
feat: Support Chart.transform_filter(*predicates, **constraints) (#3664)
1 parent b292ccf commit 1d576a8

File tree

5 files changed

+192
-47
lines changed

5 files changed

+192
-47
lines changed

altair/vegalite/v5/api.py

Lines changed: 110 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@
127127
NamedData,
128128
ParameterName,
129129
PointSelectionConfig,
130-
Predicate,
131130
PredicateComposition,
132131
ProjectionType,
133132
RepeatMapping,
@@ -542,12 +541,19 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
542541
"""
543542

544543

545-
_FieldEqualType: TypeAlias = Union[PrimitiveValue_T, Map, Parameter, SchemaBase]
546-
"""Permitted types for equality checks on field values:
544+
_FieldEqualType: TypeAlias = Union["IntoExpression", Parameter, SchemaBase]
545+
"""
546+
Permitted types for equality checks on field values.
547+
548+
Applies to the following context(s):
549+
550+
import altair as alt
547551
548-
- `datum.field == ...`
549-
- `FieldEqualPredicate(equal=...)`
550-
- `when(**constraints=...)`
552+
alt.datum.field == ...
553+
alt.FieldEqualPredicate(field="field", equal=...)
554+
alt.when(field=...)
555+
alt.when().then().when(field=...)
556+
alt.Chart.transform_filter(field=...)
551557
"""
552558

553559

@@ -2986,45 +2992,113 @@ def transform_extent(
29862992
"""
29872993
return self._add_transform(core.ExtentTransform(extent=extent, param=param))
29882994

2989-
# TODO: Update docstring
2990-
# # E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])}
29912995
def transform_filter(
29922996
self,
2993-
filter: str
2994-
| Expr
2995-
| Expression
2996-
| Predicate
2997-
| Parameter
2998-
| PredicateComposition
2999-
| dict[str, Predicate | str | list | bool],
3000-
**kwargs: Any,
2997+
predicate: Optional[_PredicateType] = Undefined,
2998+
*more_predicates: _ComposablePredicateType,
2999+
empty: Optional[bool] = Undefined,
3000+
**constraints: _FieldEqualType,
30013001
) -> Self:
30023002
"""
3003-
Add a :class:`FilterTransform` to the schema.
3003+
Add a :class:`FilterTransform` to the spec.
3004+
3005+
The resulting predicate is an ``&`` reduction over ``predicate`` and optional ``*``, ``**``, arguments.
30043006
30053007
Parameters
30063008
----------
3007-
filter : a filter expression or :class:`PredicateComposition`
3008-
The `filter` property must be one of the predicate definitions:
3009-
(1) a string or alt.expr expression
3010-
(2) a range predicate
3011-
(3) a selection predicate
3012-
(4) a logical operand combining (1)-(3)
3013-
(5) a Selection object
3009+
predicate
3010+
A selection or test predicate. ``str`` input will be treated as a test operand.
3011+
*more_predicates
3012+
Additional predicates, restricted to types supporting ``&``.
3013+
empty
3014+
For selection parameters, the predicate of empty selections returns ``True`` by default.
3015+
Override this behavior, with ``empty=False``.
30143016
3015-
Returns
3016-
-------
3017-
self : Chart object
3018-
returns chart to allow for chaining
3017+
.. note::
3018+
When ``predicate`` is a ``Parameter`` that is used more than once,
3019+
``self.transform_filter(..., empty=...)`` provides granular control for each occurrence.
3020+
**constraints
3021+
Specify `Field Equal Predicate`_'s.
3022+
Shortcut for ``alt.datum.field_name == value``, see examples for usage.
3023+
3024+
Warns
3025+
-----
3026+
AltairDeprecationWarning
3027+
If called using ``filter`` as a keyword argument.
3028+
3029+
See Also
3030+
--------
3031+
alt.when : Uses a similar syntax for defining conditional values.
3032+
3033+
Notes
3034+
-----
3035+
- Directly inspired by the syntax used in `polars.DataFrame.filter`_.
3036+
3037+
.. _Field Equal Predicate:
3038+
https://vega.github.io/vega-lite/docs/predicate.html#equal-predicate
3039+
.. _polars.DataFrame.filter:
3040+
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.filter.html
3041+
3042+
Examples
3043+
--------
3044+
Setting up a common chart::
3045+
3046+
import altair as alt
3047+
from altair import datum
3048+
from vega_datasets import data
3049+
3050+
source = data.population.url
3051+
chart = (
3052+
alt.Chart(source)
3053+
.mark_line()
3054+
.encode(
3055+
x="age:O",
3056+
y="sum(people):Q",
3057+
color=alt.Color("year:O").legend(symbolType="square"),
3058+
)
3059+
)
3060+
chart
3061+
3062+
Singular predicates can be expressed via ``datum``::
3063+
3064+
chart.transform_filter(datum.year <= 1980)
3065+
3066+
We can also use selection parameters directly::
3067+
3068+
selection = alt.selection_point(encodings=["color"], bind="legend")
3069+
chart.transform_filter(selection).add_params(selection)
3070+
3071+
Or a field predicate::
3072+
3073+
between_1950_60 = alt.FieldRangePredicate(field="year", range=[1950, 1960])
3074+
chart.transform_filter(between_1950_60) | chart.transform_filter(~between_1950_60)
3075+
3076+
Predicates can be composed together using logical operands::
3077+
3078+
chart.transform_filter(between_1950_60 | (datum.year == 1850))
3079+
3080+
Predicates passed as positional arguments will be reduced with ``&``::
3081+
3082+
chart.transform_filter(datum.year > 1980, datum.age != 90)
3083+
3084+
Using keyword-argument ``constraints`` can simplify compositions like::
3085+
3086+
verbose_composition = chart.transform_filter((datum.year == 2000) & (datum.sex == 1))
3087+
chart.transform_filter(year=2000, sex=1)
30193088
"""
3020-
if isinstance(filter, Parameter):
3021-
new_filter: dict[str, Any] = {"param": filter.name}
3022-
if "empty" in kwargs:
3023-
new_filter["empty"] = kwargs.pop("empty")
3024-
elif isinstance(filter.empty, bool):
3025-
new_filter["empty"] = filter.empty
3026-
filter = new_filter
3027-
return self._add_transform(core.FilterTransform(filter=filter, **kwargs))
3089+
if depr_filter := t.cast(Any, constraints.pop("filter", None)):
3090+
utils.deprecated_warn(
3091+
"Passing `filter` as a keyword is ambiguous.\n\n"
3092+
"Use a positional argument for `<5.5.0` behavior.\n"
3093+
"Or, `alt.datum['filter'] == ...` if referring to a column named 'filter'.",
3094+
version="5.5.0",
3095+
)
3096+
if utils.is_undefined(predicate):
3097+
predicate = depr_filter
3098+
else:
3099+
more_predicates = *more_predicates, depr_filter
3100+
cond = _parse_when(predicate, *more_predicates, empty=empty, **constraints)
3101+
return self._add_transform(core.FilterTransform(filter=cond.get("test", cond)))
30283102

30293103
def transform_flatten(
30303104
self,

doc/user_guide/transform/filter.rst

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ expressions and objects:
2020

2121
We'll show a brief example of each of these in the following sections
2222

23+
.. _filter-expression:
24+
2325
Filter Expression
2426
^^^^^^^^^^^^^^^^^
2527
A filter expression uses the `Vega expression`_ language, either specified
@@ -189,12 +191,26 @@ Then, we can *invert* this selection using ``~``:
189191
chart.transform_filter(~between_1950_60)
190192

191193
We can further refine our filter by *composing* multiple predicates together.
192-
In this case, using ``alt.datum``:
194+
In this case, using ``datum``:
195+
196+
.. altair-plot::
197+
198+
chart.transform_filter(~between_1950_60 & (datum.age <= 70))
199+
200+
When passing multiple predicates they will be reduced with ``&``:
193201

194202
.. altair-plot::
195203

196-
chart.transform_filter(~between_1950_60 & (alt.datum.age <= 70))
204+
chart.transform_filter(datum.year > 1980, datum.age != 90)
197205

206+
Using keyword-argument ``constraints`` can simplify our first example in :ref:`filter-expression`:
207+
208+
.. altair-plot::
209+
210+
alt.Chart(source).mark_area().encode(
211+
x="age:O",
212+
y="people:Q",
213+
).transform_filter(year=2000, sex=1)
198214

199215
Transform Options
200216
^^^^^^^^^^^^^^^^^

tests/examples_arguments_syntax/line_chart_with_cumsum_faceted.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
columns_sorted = ['Drought', 'Epidemic', 'Earthquake', 'Flood']
1313

1414
alt.Chart(source).transform_filter(
15-
{'and': [
16-
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), # Filter data to show only disasters in columns_sorted
17-
alt.FieldRangePredicate(field='Year', range=[1900, 2000]) # Filter data to show only 20th century
18-
]}
15+
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted),
16+
alt.FieldRangePredicate(field='Year', range=[1900, 2000])
1917
).transform_window(
2018
cumulative_deaths='sum(Deaths)', groupby=['Entity'] # Calculate cumulative sum of Deaths by Entity
2119
).mark_line().encode(

tests/examples_methods_syntax/line_chart_with_cumsum_faceted.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
columns_sorted = ['Drought', 'Epidemic', 'Earthquake', 'Flood']
1313

1414
alt.Chart(source).transform_filter(
15-
{'and': [
16-
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), # Filter data to show only disasters in columns_sorted
17-
alt.FieldRangePredicate(field='Year', range=[1900, 2000]) # Filter data to show only 20th century
18-
]}
15+
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted),
16+
alt.FieldRangePredicate(field='Year', range=[1900, 2000])
1917
).transform_window(
2018
cumulative_deaths='sum(Deaths)', groupby=['Entity'] # Calculate cumulative sum of Deaths by Entity
2119
).mark_line().encode(

tests/vegalite/v5/test_api.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import re
1111
import sys
1212
import tempfile
13+
import warnings
1314
from collections.abc import Mapping
1415
from datetime import date, datetime
1516
from importlib.metadata import version as importlib_version
@@ -85,7 +86,7 @@ def _make_chart_type(chart_type):
8586

8687

8788
@pytest.fixture
88-
def basic_chart():
89+
def basic_chart() -> alt.Chart:
8990
data = pd.DataFrame(
9091
{
9192
"a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
@@ -1247,6 +1248,64 @@ def test_predicate_composition() -> None:
12471248
assert actual_multi == expected_multi
12481249

12491250

1251+
def test_filter_transform_predicates(basic_chart) -> None:
1252+
lhs, rhs = alt.datum["b"] >= 30, alt.datum["b"] < 60
1253+
expected = [{"filter": lhs & rhs}]
1254+
actual = basic_chart.transform_filter(lhs, rhs).to_dict()["transform"]
1255+
assert actual == expected
1256+
1257+
1258+
def test_filter_transform_constraints(basic_chart) -> None:
1259+
lhs, rhs = alt.datum["a"] == "A", alt.datum["b"] == 30
1260+
expected = [{"filter": lhs & rhs}]
1261+
actual = basic_chart.transform_filter(a="A", b=30).to_dict()["transform"]
1262+
assert actual == expected
1263+
1264+
1265+
def test_filter_transform_predicates_constraints(basic_chart) -> None:
1266+
from functools import reduce
1267+
from operator import and_
1268+
1269+
predicates = (
1270+
alt.datum["a"] != "A",
1271+
alt.datum["a"] != "B",
1272+
alt.datum["a"] != "C",
1273+
alt.datum["b"] > 1,
1274+
alt.datum["b"] < 99,
1275+
)
1276+
constraints = {"b": 30, "a": "D"}
1277+
pred_constraints = *predicates, alt.datum["b"] == 30, alt.datum["a"] != "D"
1278+
expected = [{"filter": reduce(and_, pred_constraints)}]
1279+
actual = basic_chart.transform_filter(*predicates, **constraints).to_dict()[
1280+
"transform"
1281+
]
1282+
assert actual == expected
1283+
1284+
1285+
def test_filter_transform_errors(basic_chart) -> None:
1286+
NO_ARGS = r"At least one.+Undefined"
1287+
FILTER_KWARGS = r"ambiguous"
1288+
1289+
depr_filter = {"field": "year", "oneOf": [1955, 2000]}
1290+
expected = [{"filter": depr_filter}]
1291+
1292+
with pytest.raises(TypeError, match=NO_ARGS):
1293+
basic_chart.transform_filter()
1294+
with pytest.raises(TypeError, match=NO_ARGS):
1295+
basic_chart.transform_filter(empty=True)
1296+
with pytest.raises(TypeError, match=NO_ARGS):
1297+
basic_chart.transform_filter(empty=False)
1298+
1299+
with pytest.warns(alt.AltairDeprecationWarning, match=FILTER_KWARGS):
1300+
basic_chart.transform_filter(filter=depr_filter)
1301+
1302+
with warnings.catch_warnings():
1303+
warnings.filterwarnings("ignore", category=alt.AltairDeprecationWarning)
1304+
actual = basic_chart.transform_filter(filter=depr_filter).to_dict()["transform"]
1305+
1306+
assert actual == expected
1307+
1308+
12501309
def test_resolve_methods():
12511310
chart = alt.LayerChart().resolve_axis(x="shared", y="independent")
12521311
assert chart.resolve == alt.Resolve(

0 commit comments

Comments
 (0)