|
17 | 17 | # under the License.
|
18 | 18 | from __future__ import annotations
|
19 | 19 |
|
| 20 | +import ast |
20 | 21 | import sys
|
21 |
| -from collections import deque |
22 | 22 | from typing import Callable, TypeVar
|
23 | 23 |
|
24 | 24 | T = TypeVar("T", bound=Callable)
|
25 | 25 |
|
26 | 26 |
|
| 27 | +class _TaskDecoratorRemover(ast.NodeTransformer): |
| 28 | + def __init__(self, task_decorator_name): |
| 29 | + self.decorators_to_remove = { |
| 30 | + "setup", |
| 31 | + "teardown", |
| 32 | + "task.skip_if", |
| 33 | + "task.run_if", |
| 34 | + task_decorator_name, |
| 35 | + } |
| 36 | + |
| 37 | + def visit_FunctionDef(self, node): |
| 38 | + node.decorator_list = [ |
| 39 | + decorator for decorator in node.decorator_list if not self._is_task_decorator(decorator) |
| 40 | + ] |
| 41 | + return self.generic_visit(node) |
| 42 | + |
| 43 | + def _is_task_decorator(self, decorator): |
| 44 | + if isinstance(decorator, ast.Name): |
| 45 | + return decorator.id in self.decorators_to_remove |
| 46 | + elif isinstance(decorator, ast.Attribute): |
| 47 | + return f"{decorator.value.id}.{decorator.attr}" in self.decorators_to_remove |
| 48 | + elif isinstance(decorator, ast.Call): |
| 49 | + return self._is_task_decorator(decorator.func) |
| 50 | + return False |
| 51 | + |
| 52 | + |
27 | 53 | def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
|
28 | 54 | """
|
29 | 55 | Remove @task or similar decorators as well as @setup and @teardown.
|
30 | 56 |
|
31 | 57 | :param python_source: The python source code
|
32 | 58 | :param task_decorator_name: the decorator name
|
33 |
| -
|
34 |
| - TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse |
35 | 59 | """
|
36 |
| - |
37 |
| - def _remove_task_decorator(py_source, decorator_name): |
38 |
| - # if no line starts with @decorator_name, we can early exit |
39 |
| - for line in py_source.split("\n"): |
40 |
| - if line.startswith(decorator_name): |
41 |
| - break |
42 |
| - else: |
43 |
| - return python_source |
44 |
| - split = python_source.split(decorator_name, 1) |
45 |
| - before_decorator, after_decorator = split[0], split[1] |
46 |
| - if after_decorator[0] == "(": |
47 |
| - after_decorator = _balance_parens(after_decorator) |
48 |
| - if after_decorator[0] == "\n": |
49 |
| - after_decorator = after_decorator[1:] |
50 |
| - return before_decorator + after_decorator |
51 |
| - |
52 |
| - decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name] |
53 |
| - for decorator in decorators: |
54 |
| - python_source = _remove_task_decorator(python_source, decorator) |
55 |
| - return python_source |
56 |
| - |
57 |
| - |
58 |
| -def _balance_parens(after_decorator): |
59 |
| - num_paren = 1 |
60 |
| - after_decorator = deque(after_decorator) |
61 |
| - after_decorator.popleft() |
62 |
| - while num_paren: |
63 |
| - current = after_decorator.popleft() |
64 |
| - if current == "(": |
65 |
| - num_paren = num_paren + 1 |
66 |
| - elif current == ")": |
67 |
| - num_paren = num_paren - 1 |
68 |
| - return "".join(after_decorator) |
| 60 | + tree = ast.parse(python_source) |
| 61 | + remover = _TaskDecoratorRemover(task_decorator_name.strip("@")) |
| 62 | + mutated_tree = remover.visit(tree) |
| 63 | + return ast.unparse(mutated_tree) |
69 | 64 |
|
70 | 65 |
|
71 | 66 | class _autostacklevel_warn:
|
|
0 commit comments