Skip to content

Commit

Permalink
AIP-72: Fix recursion bug with XComArg
Browse files Browse the repository at this point in the history
It fixes the following bug

```python
{"timestamp":"2024-12-20T10:38:56.890735","logger":"task","error_detail":
[{"exc_type":"RecursionError","exc_value":"maximum recursion depth exceeded in comparison","syntax_error":null,"is_cause":false,"frames":
[
	{"filename":"/opt/airflow/task_sdk/src/airflow/sdk/execution_time/task_runner.py","lineno":382,"name":"main"},
	{"filename":"/opt/airflow/task_sdk/src/airflow/sdk/execution_time/task_runner.py","lineno":317,"name":"run"},
	{"filename":"/opt/airflow/airflow/models/baseoperator.py","lineno":378,"name":"wrapper"},
	{"filename":"/opt/airflow/providers/src/airflow/providers/standard/operators/python.py","lineno":182,"name":"execute"},
	{"filename":"/opt/airflow/task_sdk/src/airflow/sdk/definitions/baseoperator.py","lineno":660,"name":"__setattr__"},
	{"filename":"/opt/airflow/task_sdk/src/airflow/sdk/definitions/baseoperator.py","lineno":1126,"name":"_set_xcomargs_dependency"},
	{"filename":"/opt/airflow/airflow/models/xcom_arg.py","lineno":132,"name":"apply_upstream_relationship"},
	{"filename":"/opt/airflow/airflow/models/xcom_arg.py","lineno":118,"name":"iter_xcom_references"},
	{"filename":"/opt/airflow/airflow/models/xcom_arg.py","lineno":121,"name":"iter_xcom_references"},
	{"filename":"/opt/airflow/airflow/models/xcom_arg.py","lineno":118,"name":"iter_xcom_references"},
	...
```

To reproduce just run `tutorial_dag` or the following minimal dag:

```python
import pendulum

from airflow.models.dag import DAG
from airflow.providers.standard.operators.python import PythonOperator

with DAG(
    "sdk_tutorial_dag",
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    tags=["example"],
) as dag:
    dag.doc_md = __doc__

    def extract(**kwargs):
        ti = kwargs["ti"]
        data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
        ti.xcom_push("order_data", data_string)

    extract_task = PythonOperator(
        task_id="extract",
        python_callable=extract,
    )

    extract_task
```
  • Loading branch information
kaxil committed Dec 20, 2024
1 parent f10f552 commit 1cf9ec3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 17 deletions.
7 changes: 0 additions & 7 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import collections.abc
import contextlib
import copy
import functools
import logging
from collections.abc import Collection, Iterable, Sequence
Expand Down Expand Up @@ -703,12 +702,6 @@ def get_outlet_defs(self):
extended/overridden by subclasses.
"""

def prepare_for_execution(self) -> BaseOperator:
"""Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
other = copy.copy(self)
other._lock_for_execution = True
return other

@prepare_lineage
def pre_execute(self, context: Any):
"""Execute right before self.execute() is called."""
Expand Down
8 changes: 7 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ def __deepcopy__(self, memo: dict[int, Any]):

def __getstate__(self):
state = dict(self.__dict__)
if self._log:
if "_log" in state:
del state["_log"]

return state
Expand Down Expand Up @@ -1219,6 +1219,12 @@ def get_serialized_fields(cls):

return cls.__serialized_fields

def prepare_for_execution(self) -> BaseOperator:
"""Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
other = copy.copy(self)
other._lock_for_execution = True
return other

def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Serialize; required by DAGNode."""
from airflow.serialization.enums import DagAttributeTypes
Expand Down
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: pre execute etc.
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task = ti.task.prepare_for_execution()
context = ti.get_template_context()
ti.task.execute(context) # type: ignore[attr-defined]
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
Expand Down
12 changes: 3 additions & 9 deletions task_sdk/tests/defintions/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,12 @@ def test_warnings_are_properly_propagated(self):
assert warning.filename == __file__

def test_setattr_performs_no_custom_action_at_execute_time(self, spy_agency):
from airflow.models.xcom_arg import XComArg

op = MockOperator(task_id="test_task")

op._lock_for_execution = True
# TODO: Task-SDK
# op_copy = op.prepare_for_execution()
op_copy = op

spy_agency.spy_on(XComArg.apply_upstream_relationship, call_original=False)
op_copy = op.prepare_for_execution()
spy_agency.spy_on(op._set_xcomargs_dependency, call_original=False)
op_copy.arg1 = "b"
assert XComArg.apply_upstream_relationship.called is False
assert op._set_xcomargs_dependency.called is False

def test_upstream_is_set_when_template_field_is_xcomarg(self):
with DAG("xcomargs_test", schedule=None):
Expand Down

0 comments on commit 1cf9ec3

Please sign in to comment.