Skip to content

Commit 606b2ea

Browse files
committed
refactor(utils/decorators): rewrite remove task decorator to use ast
1 parent 476e77d commit 606b2ea

File tree

1 file changed

+31
-36
lines changed

1 file changed

+31
-36
lines changed

airflow/utils/decorators.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,55 +17,50 @@
1717
# under the License.
1818
from __future__ import annotations
1919

20+
import ast
2021
import sys
21-
from collections import deque
2222
from typing import Callable, TypeVar
2323

2424
T = TypeVar("T", bound=Callable)
2525

2626

27+
class _TaskDecoratorRemover(ast.NodeTransformer):
28+
def __init__(self, task_decorator_name):
29+
self.decorators_to_remove = {
30+
"setup",
31+
"teardown",
32+
"task.skip_if",
33+
"task.run_if",
34+
task_decorator_name,
35+
}
36+
37+
def visit_FunctionDef(self, node):
38+
node.decorator_list = [
39+
decorator for decorator in node.decorator_list if not self._is_task_decorator(decorator)
40+
]
41+
return self.generic_visit(node)
42+
43+
def _is_task_decorator(self, decorator):
44+
if isinstance(decorator, ast.Name):
45+
return decorator.id in self.decorators_to_remove
46+
elif isinstance(decorator, ast.Attribute):
47+
return f"{decorator.value.id}.{decorator.attr}" in self.decorators_to_remove
48+
elif isinstance(decorator, ast.Call):
49+
return self._is_task_decorator(decorator.func)
50+
return False
51+
52+
2753
def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
2854
"""
2955
Remove @task or similar decorators as well as @setup and @teardown.
3056
3157
:param python_source: The python source code
3258
:param task_decorator_name: the decorator name
33-
34-
TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse
3559
"""
36-
37-
def _remove_task_decorator(py_source, decorator_name):
38-
# if no line starts with @decorator_name, we can early exit
39-
for line in py_source.split("\n"):
40-
if line.startswith(decorator_name):
41-
break
42-
else:
43-
return python_source
44-
split = python_source.split(decorator_name, 1)
45-
before_decorator, after_decorator = split[0], split[1]
46-
if after_decorator[0] == "(":
47-
after_decorator = _balance_parens(after_decorator)
48-
if after_decorator[0] == "\n":
49-
after_decorator = after_decorator[1:]
50-
return before_decorator + after_decorator
51-
52-
decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name]
53-
for decorator in decorators:
54-
python_source = _remove_task_decorator(python_source, decorator)
55-
return python_source
56-
57-
58-
def _balance_parens(after_decorator):
59-
num_paren = 1
60-
after_decorator = deque(after_decorator)
61-
after_decorator.popleft()
62-
while num_paren:
63-
current = after_decorator.popleft()
64-
if current == "(":
65-
num_paren = num_paren + 1
66-
elif current == ")":
67-
num_paren = num_paren - 1
68-
return "".join(after_decorator)
60+
tree = ast.parse(python_source)
61+
remover = _TaskDecoratorRemover(task_decorator_name.strip("@"))
62+
mutated_tree = remover.visit(tree)
63+
return ast.unparse(mutated_tree)
6964

7065

7166
class _autostacklevel_warn:

0 commit comments

Comments
 (0)