From a6dc432db2df1dfae5919b737b4dd1f55ace859b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 12 Jan 2024 06:42:28 +0800 Subject: [PATCH] fix: improve diagnostics for invalid for loop annotation (#3721) improves diagnostic messages for invalid for loop annotations by fixing up the source location during `vyper/ast/parse.py`. propagates full AnnAssign node from pre_parser.py to get better location information --------- Co-authored-by: Charles Cooper --- .../features/iteration/test_for_in_list.py | 29 +++++++++ .../exceptions/test_syntax_exception.py | 12 ++++ tests/functional/syntax/test_for_range.py | 12 ++++ vyper/ast/parse.py | 62 ++++++++++++++----- vyper/ast/pre_parser.py | 41 ++++++------ 5 files changed, 120 insertions(+), 36 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 5c7b5c6b1b..7f5658e485 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -11,7 +11,9 @@ NamespaceCollision, StateAccessViolation, StructureException, + SyntaxException, TypeMismatch, + UnknownType, ) BASIC_FOR_LOOP_CODE = [ @@ -803,6 +805,33 @@ def test_for() -> int128: """, TypeMismatch, ), + ( + """ +@external +def foo(): + for i in [1, 2, 3]: + pass + """, + SyntaxException, + ), + ( + """ +@external +def foo(): + for i: $$$ in [1, 2, 3]: + pass + """, + SyntaxException, + ), + ( + """ +@external +def foo(): + for i: uint9 in [1, 2, 3]: + pass + """, + UnknownType, + ), ] BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] diff --git a/tests/functional/syntax/exceptions/test_syntax_exception.py b/tests/functional/syntax/exceptions/test_syntax_exception.py index 9ab9b6c677..53a9550a7d 100644 --- a/tests/functional/syntax/exceptions/test_syntax_exception.py +++ 
b/tests/functional/syntax/exceptions/test_syntax_exception.py @@ -86,6 +86,18 @@ def f(a:uint256,/): # test posonlyargs blocked def g(): self.f() """, + """ +@external +def foo(): + for i in range(0, 10): + pass + """, + """ +@external +def foo(): + for i: $$$ in range(0, 10): + pass + """, ] diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index 66981a90de..e807e12d41 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -8,6 +8,7 @@ StateAccessViolation, StructureException, TypeMismatch, + UnknownType, ) fail_list = [ @@ -235,6 +236,17 @@ def foo(): "Bound must be at least 1", "FOO", ), + ( + """ +@external +def foo(): + for i: DynArra[uint256, 3] in [1, 2, 3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?", + "DynArra", + ), ] for_code_regex = re.compile(r"for .+ in (.*):") diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index b657cf2245..b1b9a8d917 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -1,4 +1,5 @@ import ast as python_ast +import string import tokenize from decimal import Decimal from typing import Any, Dict, List, Optional, Union, cast @@ -150,7 +151,9 @@ def generic_visit(self, node): self.counter += 1 # Decorate every node with source end offsets - start = node.first_token.start if hasattr(node, "first_token") else (None, None) + start = (None, None) + if hasattr(node, "first_token"): + start = node.first_token.start end = (None, None) if hasattr(node, "last_token"): end = node.last_token.end @@ -224,9 +227,9 @@ def visit_For(self, node): Visit a For node, splicing in the loop variable annotation provided by the pre-parser """ - raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + annotation_tokens = self._for_loop_annotations.pop((node.lineno, node.col_offset)) - if not raw_annotation: + if not annotation_tokens: # a common case for people 
migrating to 0.4.0, provide a more # specific error message than "invalid type annotation" raise SyntaxException( @@ -238,25 +241,50 @@ def visit_For(self, node): node.col_offset, ) + self.generic_visit(node) + try: - annotation = python_ast.parse(raw_annotation, mode="eval") - # annotate with token and source code information. `first_token` - # and `last_token` attributes are accessed in `generic_visit`. - tokens = asttokens.ASTTokens(raw_annotation) - tokens.mark_tokens(annotation) + annotation_str = tokenize.untokenize(annotation_tokens).strip(string.whitespace + "\\") + annotation = python_ast.parse(annotation_str) except SyntaxError as e: raise SyntaxException( "invalid type annotation", self._source_code, node.lineno, node.col_offset ) from e - assert isinstance(annotation, python_ast.Expression) - annotation = annotation.body - - old_target = node.target - new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) - node.target = new_target + annotation = annotation.body[0] + og_target = node.target + + # annotate with token and source code information. `first_token` + # and `last_token` attributes are accessed in `generic_visit`. 
+ tokens = asttokens.ASTTokens(annotation_str) + tokens.mark_tokens(annotation) + + # decrease line offset by 1 because annotation is on the same line as `For` node + # but the spliced expression also starts at line 1 + adjustment = og_target.first_token.start[0] - 1, og_target.first_token.start[1] + + def _add_pair(x, y): + return x[0] + y[0], x[1] + y[1] + + for n in python_ast.walk(annotation): + # adjust all offsets + if hasattr(n, "first_token"): + n.first_token = n.first_token._replace( + start=_add_pair(n.first_token.start, adjustment), + end=_add_pair(n.first_token.end, adjustment), + startpos=n.first_token.startpos + og_target.first_token.startpos, + endpos=n.first_token.endpos + og_target.first_token.startpos, + ) + if hasattr(n, "last_token"): + n.last_token = n.last_token._replace( + start=_add_pair(n.last_token.start, adjustment), + end=_add_pair(n.last_token.end, adjustment), + startpos=n.last_token.startpos + og_target.first_token.startpos, + endpos=n.last_token.endpos + og_target.first_token.startpos, + ) - self.generic_visit(node) + node.target = annotation + node.target = self.generic_visit(node.target) return node @@ -418,8 +446,8 @@ def annotate_python_ast( source_code : str The originating source code of the AST. loop_var_annotations: dict, optional - A mapping of line numbers of `For` nodes to the type annotation of the iterator - extracted during pre-parsing. + A mapping of line/column offsets of `For` nodes to the tokens of the type annotation + of the iterator extracted during pre-parsing. modification_offsets : dict, optional A mapping of class names to their original class types. 
diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index c7e6f3698f..f7d2df208a 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -65,27 +65,32 @@ def __init__(self, code): def consume(self, token): # state machine: we can start slurping tokens soon if token.type == NAME and token.string == "for": - # note: self._state should be NOT_RUNNING here, but we don't sanity - # check here as that should be an error the parser will handle. + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) + self._current_annotation = [] + + assert self._state == ForParserState.NOT_RUNNING self._state = ForParserState.START_SOON self._current_for_loop = token.start + return False if self._state == ForParserState.NOT_RUNNING: return False - # state machine: start slurping tokens - if token.type == OP and token.string == ":": - self._state = ForParserState.RUNNING + if self._state == ForParserState.START_SOON: + # state machine: start slurping tokens - # sanity check -- this should never really happen, but if it does, - # try to raise an exception which pinpoints the source. - if self._current_annotation is not None: - raise SyntaxException( - "for loop parse error", self._code, token.start[0], token.start[1] - ) + self._current_annotation.append(token) - self._current_annotation = [] - return True # do not add ":" to tokens. + if token.type == OP and token.string == ":": + self._state = ForParserState.RUNNING + return True # do not add ":" to global tokens. 
+ + return False # add everything before ":" to tokens # state machine: end slurping tokens if token.type == NAME and token.string == "in": @@ -136,8 +141,9 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: Compilation settings based on the directives in the source code ModificationOffsets A mapping of class names to their original class types. - dict[tuple[int, int], str] - A mapping of line/column offsets of `For` nodes to the annotation of the for loop target + dict[tuple[int, int], list[TokenInfo]] + A mapping of line/column offsets of `For` nodes to the tokens of the annotation of the + for loop target str Reformatted python source string. """ @@ -220,9 +226,6 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: for_loop_annotations = {} for k, v in for_parser.annotations.items(): - v_source = untokenize(v) - # untokenize adds backslashes and whitespace, strip them. - v_source = v_source.replace("\\", "").strip() - for_loop_annotations[k] = v_source + for_loop_annotations[k] = v.copy() return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8")