Skip to content

Commit

Permalink
Fix Python-based decorators templating (apache#36103)
Browse files Browse the repository at this point in the history
Templating of Python-based decorators has been broken since
implementation. The decorators used template_fields definition
as defined originally in PythonOperator rather than the ones from
virtualenv because template fields were redefined in
_PythonDecoratedOperator class and they took precedence (MRU).

This PR add explicit copying of template_fields from the operators
that they are decorating.

Fixes: apache#36102
  • Loading branch information
potiuk authored Dec 7, 2023
1 parent 76d26f4 commit 3904206
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 4 deletions.
1 change: 1 addition & 0 deletions airflow/decorators/branch_external_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class _BranchExternalPythonDecoratedOperator(_PythonDecoratedOperator, BranchExternalPythonOperator):
"""Wraps a Python callable and captures args/kwargs when called for execution."""

template_fields = BranchExternalPythonOperator.template_fields
custom_operator_name: str = "@task.branch_external_python"


Expand Down
1 change: 1 addition & 0 deletions airflow/decorators/branch_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class _BranchPythonDecoratedOperator(_PythonDecoratedOperator, BranchPythonOperator):
"""Wraps a Python callable and captures args/kwargs when called for execution."""

template_fields = BranchPythonOperator.template_fields
custom_operator_name: str = "@task.branch"


Expand Down
1 change: 1 addition & 0 deletions airflow/decorators/branch_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class _BranchPythonVirtualenvDecoratedOperator(_PythonDecoratedOperator, BranchPythonVirtualenvOperator):
"""Wraps a Python callable and captures args/kwargs when called for execution."""

template_fields = BranchPythonVirtualenvOperator.template_fields
custom_operator_name: str = "@task.branch_virtualenv"


Expand Down
1 change: 1 addition & 0 deletions airflow/decorators/external_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class _PythonExternalDecoratedOperator(_PythonDecoratedOperator, ExternalPythonOperator):
"""Wraps a Python callable and captures args/kwargs when called for execution."""

template_fields = ExternalPythonOperator.template_fields
custom_operator_name: str = "@task.external_python"


Expand Down
1 change: 1 addition & 0 deletions airflow/decorators/python_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class _PythonVirtualenvDecoratedOperator(_PythonDecoratedOperator, PythonVirtualenvOperator):
"""Wraps a Python callable and captures args/kwargs when called for execution."""

template_fields = PythonVirtualenvOperator.template_fields
custom_operator_name: str = "@task.virtualenv"


Expand Down
1 change: 1 addition & 0 deletions airflow/decorators/short_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class _ShortCircuitDecoratedOperator(_PythonDecoratedOperator, ShortCircuitOperator):
"""Wraps a Python callable and captures args/kwargs when called for execution."""

template_fields = ShortCircuitOperator.template_fields
custom_operator_name: str = "@task.short_circuit"


Expand Down
1 change: 0 additions & 1 deletion airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,6 @@ def _do_render_template_fields(
f"{attr_name!r} is configured as a template field "
f"but {parent.task_type} does not have this attribute."
)

try:
if not value:
continue
Expand Down
9 changes: 6 additions & 3 deletions tests/decorators/test_branch_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ class Test_BranchPythonVirtualenvDecoratedOperator:
# possibilities. So we are increasing the timeout for this test to 3x of the default timeout
@pytest.mark.execution_timeout(180)
@pytest.mark.parametrize("branch_task_name", ["task_1", "task_2"])
def test_branch_one(self, dag_maker, branch_task_name):
def test_branch_one(self, dag_maker, branch_task_name, tmp_path):
requirements_file = tmp_path / "requirements.txt"
requirements_file.write_text("funcsigs==0.4")

@task
def dummy_f():
pass
Expand All @@ -57,14 +60,14 @@ def branch_operator():

else:

@task.branch_virtualenv(task_id="branching", requirements=["funcsigs"])
@task.branch_virtualenv(task_id="branching", requirements="requirements.txt")
def branch_operator():
import funcsigs

print(f"We successfully imported funcsigs version {funcsigs.__version__}")
return "task_2"

with dag_maker():
with dag_maker(template_searchpath=tmp_path.as_posix()):
branchoperator = branch_operator()
df = dummy_f()
task_1 = task_1()
Expand Down
14 changes: 14 additions & 0 deletions tests/decorators/test_external_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ def f():

ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

def test_with_templated_python(self, dag_maker, venv_python_with_dill):
# add template that produces empty string when rendered
templated_python_with_dill = venv_python_with_dill.as_posix() + "{{ '' }}"

@task.external_python(python=templated_python_with_dill, use_dill=True)
def f():
"""Import dill to double-check it is installed ."""
import dill # noqa: F401

with dag_maker():
ret = f()

ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

def test_no_dill_installed_raises_exception_when_use_dill(self, dag_maker, venv_python):
@task.external_python(python=venv_python, use_dill=True)
def f():
Expand Down
26 changes: 26 additions & 0 deletions tests/decorators/test_python_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,32 @@ def f():

ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

def test_with_requirements_file(self, dag_maker, tmp_path):
requirements_file = tmp_path / "requirements.txt"
requirements_file.write_text("funcsigs==0.4\nattrs==23.1.0")

@task.virtualenv(
system_site_packages=False,
requirements="requirements.txt",
python_version=PYTHON_VERSION,
use_dill=True,
)
def f():
import funcsigs

if funcsigs.__version__ != "0.4":
raise Exception

import attrs

if attrs.__version__ != "23.1.0":
raise Exception

with dag_maker(template_searchpath=tmp_path.as_posix()):
ret = f()

ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

def test_unpinned_requirements(self, dag_maker):
@task.virtualenv(
system_site_packages=False,
Expand Down

0 comments on commit 3904206

Please sign in to comment.