Skip to content

Commit

Permalink
refactor(utils/decorators): rewrite remove task decorator to use ast
Browse files Browse the repository at this point in the history
  • Loading branch information
josix committed Oct 25, 2024
1 parent 476e77d commit d5e6684
Showing 1 changed file with 29 additions and 36 deletions.
65 changes: 29 additions & 36 deletions airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,55 +17,48 @@
# under the License.
from __future__ import annotations

import ast
import sys
from collections import deque
from typing import Callable, TypeVar

T = TypeVar("T", bound=Callable)


class _TaskDecoratorRemover(ast.NodeTransformer):
def __init__(self, task_decorator_name):
self.decorators_to_remove = {
"setup",
"teardown",
"task.skip_if",
"task.run_if",
task_decorator_name,
}

def visit_FunctionDef(self, node):
node.decorator_list = [d for d in node.decorator_list if not self._is_task_decorator(d)]
return self.generic_visit(node)

def _is_task_decorator(self, decorator):
if isinstance(decorator, ast.Name):
return decorator.id in self.decorators_to_remove
elif isinstance(decorator, ast.Attribute):
return f"{decorator.value.id}.{decorator.attr}" in self.decorators_to_remove
elif isinstance(decorator, ast.Call):
return self._is_task_decorator(decorator.func)
return False


def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
"""
Remove @task or similar decorators as well as @setup and @teardown.
:param python_source: The python source code
:param task_decorator_name: the decorator name
TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse
"""

def _remove_task_decorator(py_source, decorator_name):
# if no line starts with @decorator_name, we can early exit
for line in py_source.split("\n"):
if line.startswith(decorator_name):
break
else:
return python_source
split = python_source.split(decorator_name, 1)
before_decorator, after_decorator = split[0], split[1]
if after_decorator[0] == "(":
after_decorator = _balance_parens(after_decorator)
if after_decorator[0] == "\n":
after_decorator = after_decorator[1:]
return before_decorator + after_decorator

decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name]
for decorator in decorators:
python_source = _remove_task_decorator(python_source, decorator)
return python_source


def _balance_parens(after_decorator):
num_paren = 1
after_decorator = deque(after_decorator)
after_decorator.popleft()
while num_paren:
current = after_decorator.popleft()
if current == "(":
num_paren = num_paren + 1
elif current == ")":
num_paren = num_paren - 1
return "".join(after_decorator)
tree = ast.parse(python_source)
remover = _TaskDecoratorRemover(task_decorator_name.strip("@"))
tree = remover.visit(tree)
return ast.unparse(tree)


class _autostacklevel_warn:
Expand Down

0 comments on commit d5e6684

Please sign in to comment.