diff --git a/CHANGELOG.md b/CHANGELOG.md index 51652009..b876101a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ # Changelog *[CalVer, YY.month.patch](https://calver.org/)* +## 24.5.1 +- Add ASYNC912: no checkpoints in with statement are guaranteed to run. +- ASYNC100 now properly treats async for comprehensions as checkpoints. +- ASYNC100 now supports autofixing on asyncio. + ## 24.4.2 - Add ASYNC119: yield in contextmanager in async generator. diff --git a/README.md b/README.md index 654d4067..751ada11 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ Note: 22X, 23X and 24X has not had asyncio-specific suggestions written. - **ASYNC910**: Exit or `return` from async function with no guaranteed checkpoint or exception since function definition. You might want to enable this on a codebase to make it easier to reason about checkpoints, and make the logic of ASYNC911 correct. - **ASYNC911**: Exit, `yield` or `return` from async iterable with no guaranteed checkpoint since possible function entry (yield or function definition) Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit). +- **ASYNC912**: Timeout/Cancelscope has no awaits that are guaranteed to run. If the scope has no checkpoints at all, then `ASYNC100` will be raised instead. ### Removed Warnings - **TRIOxxx**: All error codes are now renamed ASYNCxxx diff --git a/docs/rules.rst b/docs/rules.rst index f1eea69a..9fa52506 100644 --- a/docs/rules.rst +++ b/docs/rules.rst @@ -55,6 +55,7 @@ Optional rules disabled by default - **ASYNC910**: Exit or ``return`` from async function with no guaranteed checkpoint or exception since function definition. You might want to enable this on a codebase to make it easier to reason about checkpoints, and make the logic of ASYNC911 correct. - **ASYNC911**: Exit, ``yield`` or ``return`` from async iterable with no guaranteed checkpoint since possible function entry (yield or function definition) Checkpoints are ``await``, ``async for``, and ``async with`` (on one of enter/exit). +-- **ASYNC912**: A timeout/cancelscope has checkpoints, but they're not guaranteed to run. Similar to ASYNC100, but it does not warn on trivial cases where there is no checkpoint at all. It instead shares logic with ASYNC910 and ASYNC911 for parsing conditionals and branches. Removed rules ================ diff --git a/flake8_async/__init__.py b/flake8_async/__init__.py index 5b4235a7..b808ba5d 100644 --- a/flake8_async/__init__.py +++ b/flake8_async/__init__.py @@ -37,7 +37,7 @@ # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" -__version__ = "24.4.2" +__version__ = "24.5.1" # taken from https://github.com/Zac-HD/shed diff --git a/flake8_async/visitors/__init__.py b/flake8_async/visitors/__init__.py index bd858eca..0b05011f 100644 --- a/flake8_async/visitors/__init__.py +++ b/flake8_async/visitors/__init__.py @@ -30,7 +30,6 @@ from . import ( visitor2xx, visitor91x, - visitor100, visitor101, visitor102, visitor103_104, diff --git a/flake8_async/visitors/flake8asyncvisitor.py b/flake8_async/visitors/flake8asyncvisitor.py index 160bedf9..7bce4a9a 100644 --- a/flake8_async/visitors/flake8asyncvisitor.py +++ b/flake8_async/visitors/flake8asyncvisitor.py @@ -98,7 +98,11 @@ def error( ), "No error code defined, but class has multiple codes" error_code = next(iter(self.error_codes)) # don't emit an error if this code is disabled in a multi-code visitor - elif strip_error_subidentifier(error_code) not in self.options.enabled_codes: + elif ( + (ec_no_sub := strip_error_subidentifier(error_code)) + not in self.options.enabled_codes + and ec_no_sub not in self.options.autofix_codes + ): return self.__state.problems.append( @@ -217,7 +221,11 @@ def error( error_code = next(iter(self.error_codes)) # don't emit an error if this code is disabled in a multi-code visitor # TODO: write test for only one of 910/911 enabled/autofixed - elif strip_error_subidentifier(error_code) not in self.options.enabled_codes: + elif ( + (ec_no_sub := strip_error_subidentifier(error_code)) + not in self.options.enabled_codes + and ec_no_sub not in self.options.autofix_codes + ): return False # pragma: no cover if self.is_noqa(node, error_code): @@ -237,7 +245,7 @@ def error( return True def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: - if code is None: + if code is None: # pragma: no cover assert len(self.error_codes) == 1 code = next(iter(self.error_codes)) # this does not currently need to check for `noqa`s, as error() does that diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index f8521b3b..d33e0992 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -51,6 +51,7 @@ def error_class_cst(error_class: type[T_CST]) -> type[T_CST]: def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]: + """Default-disables all error codes in a class.""" assert error_class.error_codes # type: ignore[attr-defined] default_disabled_error_codes.extend( error_class.error_codes # type: ignore[attr-defined] @@ -58,6 +59,11 @@ def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]: return error_class +def disable_codes_by_default(*codes: str) -> None: + """Default-disables only specified codes.""" + default_disabled_error_codes.extend(codes) + + def utility_visitor(c: type[T]) -> type[T]: assert not hasattr(c, "error_codes") c.error_codes = {} @@ -317,30 +323,68 @@ class AttributeCall(NamedTuple): function: str +# the custom __or__ in libcst breaks pyright type checking. It's possible to use +# `Union` as a workaround ... except pyupgrade will automatically replace that. +# So we have to resort to specifying one of the base classes. +# See https://github.com/Instagram/LibCST/issues/1143 +def build_cst_matcher(attr: str) -> m.BaseExpression: + """Build a cst matcher structure with attributes&names matching a string `a.b.c`.""" + if "." not in attr: + return m.Name(value=attr) + body, tail = attr.rsplit(".") + return m.Attribute(value=build_cst_matcher(body), attr=m.Name(value=tail)) + + +def identifier_to_string(attr: cst.Name | cst.Attribute) -> str: + if isinstance(attr, cst.Name): + return attr.value + assert isinstance(attr.value, (cst.Attribute, cst.Name)) + return identifier_to_string(attr.value) + "." + attr.attr.value + + def with_has_call( node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio") ) -> list[AttributeCall]: + """Check if a with statement has a matching call, returning a list with matches. + + `names` specify the names of functions to match, `base` specifies the + library/module(s) the function must be in. + The list elements in the return value are named tuples with the matched node, + base and function. + + Examples_ + + `with_has_call(node, "bar", base="foo")` matches foo.bar. + `with_has_call(node, "bar", "bee", base=("foo", "a.b.c")` matches + `foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`. + + """ + if isinstance(base, str): + base = (base,) # pragma: no cover + + # build matcher, using SaveMatchedNode to save the base and the function name. + matcher = m.Call( + func=m.Attribute( + value=m.SaveMatchedNode( + m.OneOf(*(build_cst_matcher(b) for b in base)), name="base" + ), + attr=m.SaveMatchedNode( + oneof_names(*names), + name="function", + ), + ) + ) + res_list: list[AttributeCall] = [] for item in node.items: - if res := m.extract( - item.item, - m.Call( - func=m.Attribute( - value=m.SaveMatchedNode(m.Name(), name="library"), - attr=m.SaveMatchedNode( - oneof_names(*names), - name="function", - ), - ) - ), - ): + if res := m.extract(item.item, matcher): assert isinstance(item.item, cst.Call) - assert isinstance(res["library"], cst.Name) + assert isinstance(res["base"], (cst.Name, cst.Attribute)) assert isinstance(res["function"], cst.Name) - if res["library"].value not in base: - continue res_list.append( - AttributeCall(item.item, res["library"].value, res["function"].value) + AttributeCall( + item.item, identifier_to_string(res["base"]), res["function"].value + ) ) return res_list diff --git a/flake8_async/visitors/visitor100.py b/flake8_async/visitors/visitor100.py deleted file mode 100644 index 345f8926..00000000 --- a/flake8_async/visitors/visitor100.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Contains visitor for ASYNC100. - -A `with trio.fail_after(...):` or `with trio.move_on_after(...):` -context does not contain any `await` statements. This makes it pointless, as -the timeout can only be triggered by a checkpoint. -Checkpoints on Await, Async For and Async With -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import libcst as cst -import libcst.matchers as m - -from .flake8asyncvisitor import Flake8AsyncVisitor_cst -from .helpers import ( - AttributeCall, - error_class_cst, - flatten_preserving_comments, - with_has_call, -) - -if TYPE_CHECKING: - from collections.abc import Mapping - - -@error_class_cst -class Visitor100_libcst(Flake8AsyncVisitor_cst): - error_codes: Mapping[str, str] = { - "ASYNC100": ( - "{0}.{1} context contains no checkpoints, remove the context or add" - " `await {0}.lowlevel.checkpoint()`." - ), - } - - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self.has_checkpoint_stack: list[bool] = [] - self.node_dict: dict[cst.With, list[AttributeCall]] = {} - - def checkpoint(self) -> None: - # Set the whole stack to True. - self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack) - - def visit_With(self, node: cst.With) -> None: - if m.matches(node, m.With(asynchronous=m.Asynchronous())): - self.checkpoint() - if res := with_has_call( - node, "fail_after", "fail_at", "move_on_after", "move_on_at", "CancelScope" - ): - self.node_dict[node] = res - - self.has_checkpoint_stack.append(False) - else: - self.has_checkpoint_stack.append(True) - - def leave_With( - self, original_node: cst.With, updated_node: cst.With - ) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]: - if not self.has_checkpoint_stack.pop(): - autofix = len(updated_node.items) == 1 - for res in self.node_dict[original_node]: - autofix &= self.error( - res.node, res.base, res.function - ) and self.should_autofix(res.node) - - if autofix: - return flatten_preserving_comments(updated_node) - - return updated_node - - def visit_For(self, node: cst.For): - if node.asynchronous is not None: - self.checkpoint() - - def visit_Await(self, node: cst.Await | cst.Yield): - self.checkpoint() - - visit_Yield = visit_Await - - def visit_FunctionDef(self, node: cst.FunctionDef): - self.save_state(node, "has_checkpoint_stack", copy=True) - self.has_checkpoint_stack = [] - - def leave_FunctionDef( - self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef - ) -> cst.FunctionDef: - self.restore_state(original_node) - return updated_node diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 988bc3a9..bbfd696b 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -18,19 +18,32 @@ from ..base import Statement from .flake8asyncvisitor import Flake8AsyncVisitor_cst from .helpers import ( - disabled_by_default, + AttributeCall, + cancel_scope_names, + disable_codes_by_default, error_class_cst, + flatten_preserving_comments, fnmatch_qualified_name_cst, func_has_decorator, iter_guaranteed_once_cst, + with_has_call, ) if TYPE_CHECKING: from collections.abc import Mapping, Sequence +class ArtificialStatement(Statement): + """Statement that should not trigger 910/911 on function exit. + + Used by loops and `with` statements. + """ + + # Statement injected at the start of loops to track missed checkpoints. -ARTIFICIAL_STATEMENT = Statement("artificial", -1) +ARTIFICIAL_STATEMENT = ArtificialStatement("artificial", -1) +# There's no particular reason why loops use a globally instanced statement, but +# `with` does not - mostly just an artifact of them being implemented at different times. def func_empty_body(node: cst.FunctionDef) -> bool: @@ -231,8 +244,10 @@ def leave_Yield( leave_Return = leave_Yield # type: ignore +disable_codes_by_default("ASYNC910", "ASYNC911", "ASYNC912") + + @error_class_cst -@disabled_by_default class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors): error_codes: Mapping[str, str] = { "ASYNC910": ( @@ -243,6 +258,14 @@ class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors): "{0} from async iterable with no guaranteed checkpoint since {1.name} " "on line {1.lineno}." ), + "ASYNC912": ( + "CancelScope with no guaranteed checkpoint. This makes it potentially " + "impossible to cancel." + ), + "ASYNC100": ( + "{0}.{1} context contains no checkpoints, remove the context or add" + " `await {0}.lowlevel.checkpoint()`." + ), } def __init__(self, *args: Any, **kwargs: Any): @@ -256,15 +279,24 @@ def __init__(self, *args: Any, **kwargs: Any): self.loop_state = LoopState() self.try_state = TryState() + # ASYNC100 + self.has_checkpoint_stack: list[bool] = [] + self.node_dict: dict[cst.With, list[AttributeCall]] = {} + def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: + if code is None: # pragma: no branch + code = "ASYNC911" if self.has_yield else "ASYNC910" + return ( not self.noautofix - and super().should_autofix( - node, "ASYNC911" if self.has_yield else "ASYNC910" - ) + and super().should_autofix(node, code) and self.library != ("asyncio",) ) + def checkpoint(self) -> None: + self.uncheckpointed_statements = set() + self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack) + def checkpoint_statement(self) -> cst.SimpleStatementLine: return checkpoint_statement(self.library[0]) @@ -283,9 +315,11 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: "uncheckpointed_statements", "loop_state", "try_state", + "has_checkpoint_stack", copy=True, ) self.uncheckpointed_statements = set() + self.has_checkpoint_stack = [] self.has_yield = self.safe_decorator = False self.loop_state = LoopState() @@ -359,7 +393,7 @@ def check_function_exit( any_errors = False # raise the actual errors for statement in self.uncheckpointed_statements: - if statement == ARTIFICIAL_STATEMENT: + if isinstance(statement, ArtificialStatement): continue any_errors |= self.error_91x(original_node, statement) @@ -376,6 +410,7 @@ def leave_Return( self.add_statement = self.checkpoint_statement() # avoid duplicate error messages self.uncheckpointed_statements = set() + # we don't treat it as a checkpoint for ASYNC100 # return original node to avoid problems with identity equality assert original_node.deep_equals(updated_node) @@ -386,7 +421,7 @@ def error_91x( node: cst.Return | cst.FunctionDef | cst.Yield, statement: Statement, ) -> bool: - assert statement != ARTIFICIAL_STATEMENT + assert not isinstance(statement, ArtificialStatement) if isinstance(node, cst.FunctionDef): msg = "exit" @@ -407,7 +442,7 @@ def leave_Await( # so only set checkpoint after the await node # all nodes are now checkpointed - self.uncheckpointed_statements = set() + self.checkpoint() return updated_node # raising exception means we don't need to checkpoint so we can treat it as one @@ -419,9 +454,54 @@ def leave_Await( # missing-checkpoint warning when there might in fact be one (i.e. a false alarm). def visit_With_body(self, node: cst.With): if getattr(node, "asynchronous", None): - self.uncheckpointed_statements = set() - - leave_With_body = visit_With_body + self.checkpoint() + if res := ( + with_has_call(node, *cancel_scope_names) + or with_has_call( + node, "timeout", "timeout_at", base=("asyncio", "asyncio.timeouts") + ) + ): + pos = self.get_metadata(PositionProvider, node).start # pyright: ignore + line: int = pos.line # pyright: ignore + column: int = pos.column # pyright: ignore + self.uncheckpointed_statements.add( + ArtificialStatement("with", line, column) + ) + self.node_dict[node] = res + self.has_checkpoint_stack.append(False) + else: + self.has_checkpoint_stack.append(True) + + def leave_With(self, original_node: cst.With, updated_node: cst.With): + # Uses leave_With instead of leave_With_body because we need access to both + # original and updated node + # ASYNC100 + if not self.has_checkpoint_stack.pop(): + autofix = len(updated_node.items) == 1 + for res in self.node_dict[original_node]: + # bypass 910 & 911's should_autofix logic, which excludes asyncio + # (TODO: and uses self.noautofix ... which I don't remember what it's for) + autofix &= self.error( + res.node, res.base, res.function, error_code="ASYNC100" + ) and super().should_autofix(res.node, code="ASYNC100") + + if autofix: + return flatten_preserving_comments(updated_node) + # ASYNC912 + else: + pos = self.get_metadata( # pyright: ignore + PositionProvider, original_node + ).start # pyright: ignore + line: int = pos.line # pyright: ignore + column: int = pos.column # pyright: ignore + s = ArtificialStatement("with", line, column) + if s in self.uncheckpointed_statements: + self.uncheckpointed_statements.remove(s) + for res in self.node_dict[original_node]: + self.error(res.node, error_code="ASYNC912") + if getattr(original_node, "asynchronous", None): + self.checkpoint() + return updated_node # error if no checkpoint since earlier yield or function entry def leave_Yield( @@ -431,6 +511,10 @@ def leave_Yield( return updated_node self.has_yield = True + # Treat as a checkpoint for ASYNC100, since the context we yield to + # may checkpoint. + self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack) + if self.check_function_exit(original_node) and self.should_autofix( original_node ): @@ -605,7 +689,7 @@ def visit_While_body(self, node: cst.For | cst.While): # appropriate errors if the loop doesn't checkpoint if getattr(node, "asynchronous", None): - self.uncheckpointed_statements = set() + self.checkpoint() else: self.uncheckpointed_statements = {ARTIFICIAL_STATEMENT} @@ -651,7 +735,7 @@ def leave_While_body(self, node: cst.For | cst.While): # AsyncFor guarantees checkpoint on running out of iterable # so reset checkpoint state at end of loop. (but not state at break) if getattr(node, "asynchronous", None): - self.uncheckpointed_statements = set() + self.checkpoint() else: # enter orelse with worst case: # loop body might execute fully before entering orelse @@ -780,7 +864,7 @@ def visit_CompFor(self, node: cst.CompFor): # if async comprehension, checkpoint if node.asynchronous: - self.uncheckpointed_statements = set() + self.checkpoint() self.comp_unknown = False return False diff --git a/tests/autofix_files/async100.py b/tests/autofix_files/async100.py index db6733b5..bd3d2809 100644 --- a/tests/autofix_files/async100.py +++ b/tests/autofix_files/async100.py @@ -71,10 +71,11 @@ async def foo(): ... -# Seems like the inner context manager 'hides' the checkpoint. +# The outer cancelscope can get triggered in more complex cases, so +# to avoid false positives we don't raise a warning. async def does_contain_checkpoints(): - with trio.fail_after(1): # false-alarm ASYNC100 - with trio.CancelScope(): # or any other context manager + with trio.fail_after(1): + with trio.CancelScope(): await trio.sleep_forever() diff --git a/tests/autofix_files/async100_asyncio.py b/tests/autofix_files/async100_asyncio.py new file mode 100644 index 00000000..cfd0121e --- /dev/null +++ b/tests/autofix_files/async100_asyncio.py @@ -0,0 +1,22 @@ +# TRIO_NO_ERROR +# ANYIO_NO_ERROR +# BASE_LIBRARY asyncio + +# timeout[_at] re-exported in the main asyncio namespace in py3.11 +# mypy: disable-error-code=attr-defined +# AUTOFIX + +import asyncio +import asyncio.timeouts + + +async def foo(): + # error: 9, "asyncio", "timeout_at" + ... + # error: 9, "asyncio", "timeout" + ... + + # error: 9, "asyncio.timeouts", "timeout_at" + ... + # error: 9, "asyncio.timeouts", "timeout" + ... diff --git a/tests/autofix_files/async100_asyncio.py.diff b/tests/autofix_files/async100_asyncio.py.diff new file mode 100644 index 00000000..f083238a --- /dev/null +++ b/tests/autofix_files/async100_asyncio.py.diff @@ -0,0 +1,23 @@ +--- ++++ +@@ x,12 x,12 @@ + + + async def foo(): +- with asyncio.timeout_at(10): # error: 9, "asyncio", "timeout_at" +- ... +- with asyncio.timeout(10): # error: 9, "asyncio", "timeout" +- ... ++ # error: 9, "asyncio", "timeout_at" ++ ... ++ # error: 9, "asyncio", "timeout" ++ ... + +- with asyncio.timeouts.timeout_at(10): # error: 9, "asyncio.timeouts", "timeout_at" +- ... +- with asyncio.timeouts.timeout(10): # error: 9, "asyncio.timeouts", "timeout" +- ... ++ # error: 9, "asyncio.timeouts", "timeout_at" ++ ... ++ # error: 9, "asyncio.timeouts", "timeout" ++ ... diff --git a/tests/autofix_files/async910.py b/tests/autofix_files/async910.py index 0415a7a4..95922ab7 100644 --- a/tests/autofix_files/async910.py +++ b/tests/autofix_files/async910.py @@ -1,4 +1,5 @@ # AUTOFIX +# ASYNCIO_NO_AUTOFIX # mypy: disable-error-code="unreachable" from __future__ import annotations diff --git a/tests/autofix_files/async911.py b/tests/autofix_files/async911.py index 720a9811..a91a322a 100644 --- a/tests/autofix_files/async911.py +++ b/tests/autofix_files/async911.py @@ -1,4 +1,5 @@ # AUTOFIX +# ASYNCIO_NO_AUTOFIX from typing import Any import pytest diff --git a/tests/autofix_files/async91x_autofix.py b/tests/autofix_files/async91x_autofix.py index d7d99b3a..35fe6ff2 100644 --- a/tests/autofix_files/async91x_autofix.py +++ b/tests/autofix_files/async91x_autofix.py @@ -1,4 +1,6 @@ # AUTOFIX +# asyncio will raise the same errors, but does not have autofix available +# ASYNCIO_NO_AUTOFIX from __future__ import annotations """Docstring for file @@ -124,3 +126,21 @@ async def async_func(): ... break [... for i in range(5)] return + + +# TODO: issue 240 +async def livelocks(): + while True: + ... + + +# this will autofix 910 by adding a checkpoint outside the loop, which doesn't actually +# help, and the method still isn't guaranteed to checkpoint in case bar() always returns +# True. +async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno) + while bar(): + try: + await foo("1") # type: ignore[call-arg] + except TypeError: + ... + await trio.lowlevel.checkpoint() diff --git a/tests/autofix_files/async91x_autofix.py.diff b/tests/autofix_files/async91x_autofix.py.diff index 2c84b107..e6e6625d 100644 --- a/tests/autofix_files/async91x_autofix.py.diff +++ b/tests/autofix_files/async91x_autofix.py.diff @@ -78,3 +78,8 @@ yield # ASYNC911: 8, "yield", Statement("function definition", lineno-2) # ASYNC911: 8, "yield", Statement("yield", lineno) async def bar(): +@@ x,3 x,4 @@ + await foo("1") # type: ignore[call-arg] + except TypeError: + ... ++ await trio.lowlevel.checkpoint() diff --git a/tests/autofix_files/noqa_testing.py b/tests/autofix_files/noqa_testing.py index 9bb4e456..b55942c4 100644 --- a/tests/autofix_files/noqa_testing.py +++ b/tests/autofix_files/noqa_testing.py @@ -1,4 +1,7 @@ +# TODO: When was this file added? Why? + # AUTOFIX +# ASYNCIO_NO_AUTOFIX # ARG --enable=ASYNC911 import trio diff --git a/tests/eval_files/async100.py b/tests/eval_files/async100.py index c9a00b53..226e34d1 100644 --- a/tests/eval_files/async100.py +++ b/tests/eval_files/async100.py @@ -71,10 +71,11 @@ async def foo(): ... -# Seems like the inner context manager 'hides' the checkpoint. +# The outer cancelscope can get triggered in more complex cases, so +# to avoid false positives we don't raise a warning. async def does_contain_checkpoints(): - with trio.fail_after(1): # false-alarm ASYNC100 - with trio.CancelScope(): # or any other context manager + with trio.fail_after(1): + with trio.CancelScope(): await trio.sleep_forever() diff --git a/tests/eval_files/async100_asyncio.py b/tests/eval_files/async100_asyncio.py index 9dd743c8..494803ab 100644 --- a/tests/eval_files/async100_asyncio.py +++ b/tests/eval_files/async100_asyncio.py @@ -1,23 +1,22 @@ # TRIO_NO_ERROR # ANYIO_NO_ERROR # BASE_LIBRARY asyncio -# ASYNCIO_NO_ERROR # TODO + +# timeout[_at] re-exported in the main asyncio namespace in py3.11 +# mypy: disable-error-code=attr-defined +# AUTOFIX import asyncio import asyncio.timeouts async def foo(): - # py>=3.11 re-exports these in the main asyncio namespace - with asyncio.timeout_at(10): # type: ignore[attr-defined] - ... - with asyncio.timeout_at(10): # type: ignore[attr-defined] + with asyncio.timeout_at(10): # error: 9, "asyncio", "timeout_at" ... - with asyncio.timeout(10): # type: ignore[attr-defined] + with asyncio.timeout(10): # error: 9, "asyncio", "timeout" ... - with asyncio.timeouts.timeout_at(10): - ... - with asyncio.timeouts.timeout_at(10): + + with asyncio.timeouts.timeout_at(10): # error: 9, "asyncio.timeouts", "timeout_at" ... - with asyncio.timeouts.timeout(10): + with asyncio.timeouts.timeout(10): # error: 9, "asyncio.timeouts", "timeout" ... diff --git a/tests/eval_files/async910.py b/tests/eval_files/async910.py index 11600888..68aee89f 100644 --- a/tests/eval_files/async910.py +++ b/tests/eval_files/async910.py @@ -1,4 +1,5 @@ # AUTOFIX +# ASYNCIO_NO_AUTOFIX # mypy: disable-error-code="unreachable" from __future__ import annotations diff --git a/tests/eval_files/async911.py b/tests/eval_files/async911.py index 8a19c525..b6d256de 100644 --- a/tests/eval_files/async911.py +++ b/tests/eval_files/async911.py @@ -1,4 +1,5 @@ # AUTOFIX +# ASYNCIO_NO_AUTOFIX from typing import Any import pytest diff --git a/tests/eval_files/async912.py b/tests/eval_files/async912.py new file mode 100644 index 00000000..c2abf045 --- /dev/null +++ b/tests/eval_files/async912.py @@ -0,0 +1,184 @@ +# ASYNCIO_NO_ERROR +# ARG --enable=ASYNC100,ASYNC912 +# asyncio is tested in async912_asyncio. Cancelscopes in anyio are named the same +# as in trio, so they're also tested with this file. + +# ASYNC100 has autofixes, but ASYNC912 does not. This leaves us with the option +# of not testing both in the same file, or running with NOAUTOFIX. +# NOAUTOFIX + +from typing import TypeVar + +import trio + + +def bar() -> bool: + return False + + +async def foo(): + # trivial cases where there is absolutely no `await` only triggers ASYNC100 + with trio.move_on_after(0.1): # ASYNC100: 9, "trio", "move_on_after" + ... + with trio.move_on_at(0.1): # ASYNC100: 9, "trio", "move_on_at" + ... + with trio.fail_after(0.1): # ASYNC100: 9, "trio", "fail_after" + ... + with trio.fail_at(0.1): # ASYNC100: 9, "trio", "fail_at" + ... + with trio.CancelScope(0.1): # ASYNC100: 9, "trio", "CancelScope" + ... + + # conditional cases trigger ASYNC912 + with trio.move_on_after(0.1): # ASYNC912: 9 + if bar(): + await trio.lowlevel.checkpoint() + with trio.move_on_at(0.1): # ASYNC912: 9 + while bar(): + await trio.lowlevel.checkpoint() + with trio.fail_after(0.1): # ASYNC912: 9 + try: + await trio.lowlevel.checkpoint() + except: + ... + with trio.fail_at(0.1): # ASYNC912: 9 + if bar(): + await trio.lowlevel.checkpoint() + with trio.CancelScope(0.1): # ASYNC912: 9 + if bar(): + await trio.lowlevel.checkpoint() + # ASYNC912 generally shares the same logic as other 91x codes, check respective + # eval files for more comprehensive tests. + + # check we don't trigger on all context managers + with open(""): + ... + + # don't error with guaranteed checkpoint + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + with trio.move_on_after(0.1): + if bar(): + await trio.lowlevel.checkpoint() + else: + await trio.lowlevel.checkpoint() + + # both scopes error in nested cases + with trio.move_on_after(0.1): # ASYNC912: 9 + with trio.move_on_after(0.1): # ASYNC912: 13 + if bar(): + await trio.lowlevel.checkpoint() + + # We don't know which cancelscope will trigger first, so to avoid false + # alarms on tricky-but-valid cases we don't raise any error for the outer one. + with trio.move_on_after(0.1): + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + + with trio.move_on_after(0.1): + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + await trio.lowlevel.checkpoint() + + # check correct line gives error + # fmt: off + with ( + # a + # b + trio.move_on_after(0.1) # ASYNC912: 12 + # c + ): + if bar(): + await trio.lowlevel.checkpoint() + + with ( + open(""), + trio.move_on_at(5), # ASYNC912: 12 + open(""), + ): + if bar(): + await trio.lowlevel.checkpoint() + # fmt: on + + # error on each call with multiple matching calls in the same with + with ( + trio.move_on_after(0.1), # ASYNC912: 8 + trio.fail_at(5), # ASYNC912: 8 + ): + if bar(): + await trio.lowlevel.checkpoint() + + # wrapped calls do not raise errors + T = TypeVar("T") + + def customWrapper(a: T) -> T: + return a + + with customWrapper(trio.fail_at(10)): + ... + with (res := trio.fail_at(10)): + ... + # but saving with `as` does + with trio.fail_at(10) as res: # ASYNC912: 9 + if bar(): + await trio.lowlevel.checkpoint() + + +# TODO: issue #240 +async def livelocks(): + with trio.move_on_after(0.1): # should error + while True: + try: + await trio.sleep("1") # type: ignore + except TypeError: + pass + + +def condition() -> bool: + return True + + +async def livelocks_2(): + with trio.move_on_after(0.1): # ASYNC912: 9 + while condition(): + try: + await trio.sleep("1") # type: ignore + except TypeError: + pass + + +# TODO: add --async912-context-managers= +async def livelocks_3(): + import contextlib + + with trio.move_on_after(0.1): # should error + while True: + with contextlib.suppress(TypeError): + await trio.sleep("1") # type: ignore + + +# raises an error...? +with trio.move_on_after(10): # ASYNC100: 5, "trio", "move_on_after" + ... + + +# completely sync function ... is this something we care about? +def sync_func(): + with trio.move_on_after(10): + ... + + +async def check_yield_logic(): + # Does not raise any of async100 or async912, as the yield is treated + # as a checkpoint because the parent context may checkpoint. + with trio.move_on_after(1): + yield + with trio.move_on_after(1): + if bar(): + await trio.lowlevel.checkpoint() + yield diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py new file mode 100644 index 00000000..ef9200bf --- /dev/null +++ b/tests/eval_files/async912_asyncio.py @@ -0,0 +1,76 @@ +# ARG --enable=ASYNC100,ASYNC912 +# BASE_LIBRARY asyncio +# ANYIO_NO_ERROR +# TRIO_NO_ERROR + +# ASYNC100 supports autofix, but ASYNC912 doesn't, so we must run with NOAUTOFIX +# NOAUTOFIX + +# timeout[_at] re-exported in the main asyncio namespace in py3.11 +# mypy: disable-error-code=attr-defined + +import asyncio + +from typing import Any + + +def bar() -> bool: + return False + + +def customWrapper(a: object) -> object: ... + + +async def foo(): + # async100 + async with asyncio.timeout(10): # ASYNC100: 15, "asyncio", "timeout" + ... + async with asyncio.timeout_at(10): # ASYNC100: 15, "asyncio", "timeout_at" + ... + async with asyncio.timeouts.timeout( # ASYNC100: 15, "asyncio.timeouts", "timeout" + 10 + ): + ... + async with asyncio.timeouts.timeout_at( # ASYNC100: 15, "asyncio.timeouts", "timeout_at" + 10 + ): + ... + + # no errors + async with asyncio.timeout(10): + await foo() + async with asyncio.timeout_at(10): + await foo() + + # async912 + async with asyncio.timeout_at(10): # ASYNC912: 15 + if bar(): + await foo() + async with asyncio.timeout(10): # ASYNC912: 15 + if bar(): + await foo() + + async with asyncio.timeouts.timeout(10): # ASYNC912: 15 + if bar(): + await foo() + async with asyncio.timeouts.timeout_at(10): # ASYNC912: 15 + if bar(): + await foo() + + # double check that helper methods used by visitor don't trigger erroneously + timeouts: Any + timeout_at: Any + async with asyncio.timeout_at.timeouts(10): + ... + async with timeouts.asyncio.timeout_at(10): + ... + async with timeouts.timeout_at.asyncio(10): + ... + async with timeout_at.asyncio.timeouts(10): + ... + async with timeout_at.timeouts.asyncio(10): + ... + async with foo.timeout(10): + ... + async with asyncio.timeouts(10): + ... diff --git a/tests/eval_files/async91x_autofix.py b/tests/eval_files/async91x_autofix.py index de650311..7ce0a359 100644 --- a/tests/eval_files/async91x_autofix.py +++ b/tests/eval_files/async91x_autofix.py @@ -1,4 +1,6 @@ # AUTOFIX +# asyncio will raise the same errors, but does not have autofix available +# ASYNCIO_NO_AUTOFIX from __future__ import annotations """Docstring for file @@ -109,3 +111,20 @@ async def async_func(): ... break [... for i in range(5)] return + + +# TODO: issue 240 +async def livelocks(): + while True: + ... + + +# this will autofix 910 by adding a checkpoint outside the loop, which doesn't actually +# help, and the method still isn't guaranteed to checkpoint in case bar() always returns +# True. +async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno) + while bar(): + try: + await foo("1") # type: ignore[call-arg] + except TypeError: + ... diff --git a/tests/eval_files/noqa_testing.py b/tests/eval_files/noqa_testing.py index 1c6ea8f5..1a1a3440 100644 --- a/tests/eval_files/noqa_testing.py +++ b/tests/eval_files/noqa_testing.py @@ -1,4 +1,7 @@ +# TODO: When was this file added? Why? + # AUTOFIX +# ASYNCIO_NO_AUTOFIX # ARG --enable=ASYNC911 import trio diff --git a/tests/test_flake8_async.py b/tests/test_flake8_async.py index a345a5ed..cad9d2a2 100644 --- a/tests/test_flake8_async.py +++ b/tests/test_flake8_async.py @@ -109,20 +109,32 @@ def check_autofix( plugin: Plugin, unfixed_code: str, generate_autofix: bool, + magic_markers: MagicMarkers, library: str = "trio", - base_library: str = "trio", ): + base_library = magic_markers.BASE_LIBRARY # the source code after it's been visited by current transformers visited_code = plugin.module.code - if "# AUTOFIX" not in unfixed_code: - # if the file is specifically marked with NOAUTOFIX, that means it has visitors - # that will autofix with --autofix, but the file explicitly doesn't want to check - # the result of doing that. THIS IS DANGEROUS - if "# NOAUTOFIX" in unfixed_code: - print(f"eval file {test} marked with dangerous marker NOAUTOFIX") - else: - assert unfixed_code == visited_code + # if the file is specifically marked with NOAUTOFIX, that means it has visitors + # that will autofix with --autofix, but the file explicitly doesn't want to check + # the result of doing that. THIS IS DANGEROUS + assert not (magic_markers.AUTOFIX and magic_markers.NOAUTOFIX) + if magic_markers.NOAUTOFIX: + print(f"eval file {test} marked with dangerous marker NOAUTOFIX") + return + + if ( + # not marked for autofixing + not magic_markers.AUTOFIX + # file+library does not raise errors + or magic_markers.library_no_error(library) + # code raises errors on asyncio, but does not support autofixing for it + or (library == "asyncio" and magic_markers.ASYNCIO_NO_AUTOFIX) + ): + assert ( + unfixed_code == visited_code + ), "Code changed after visiting, but magic markers say it shouldn't change." return # the full generated source code, saved from a previous run @@ -196,9 +208,22 @@ class MagicMarkers: ANYIO_NO_ERROR: bool = False TRIO_NO_ERROR: bool = False ASYNCIO_NO_ERROR: bool = False + + AUTOFIX: bool = False + NOAUTOFIX: bool = False + + # File should not get modified when running with asyncio+autofix + ASYNCIO_NO_AUTOFIX: bool = False # eval file is written using this library, so no substitution is required BASE_LIBRARY: str = "trio" + def library_no_error(self, library: str) -> bool: + return { + "anyio": self.ANYIO_NO_ERROR, + "asyncio": self.ASYNCIO_NO_ERROR, + "trio": self.TRIO_NO_ERROR, + }[library] + def find_magic_markers( content: str, @@ -300,15 +325,14 @@ def test_eval( lib in message for lib in ("anyio", "asyncio", "trio") ) - # asyncio does not support autofix atm, so should not modify content - if autofix and not noqa and library != "asyncio": + if autofix and not noqa: check_autofix( test, plugin, content, generate_autofix, library=library, - base_library=magic_markers.BASE_LIBRARY, + magic_markers=magic_markers, ) else: # make sure content isn't modified @@ -452,6 +476,7 @@ def _parse_eval_file( "ASYNC116", "ASYNC117", "ASYNC118", + "ASYNC912", } @@ -479,7 +504,7 @@ def visit_AsyncFor(self, node: ast.AsyncFor): return self.replace_async(node, ast.For, node.target, node.iter) -@pytest.mark.parametrize(("test", "path"), test_files) +@pytest.mark.parametrize(("test", "path"), test_files, ids=[f[0] for f in test_files]) def test_noerror_on_sync_code(test: str, path: Path): if any(e in test for e in error_codes_ignored_when_checking_transformed_sync_code): return