Skip to content

Commit

Permalink
fix: improve diagnostics for invalid for loop annotation (vyperlang#3721)
Browse files Browse the repository at this point in the history

improves diagnostic messages for invalid for loop annotations by fixing
up the source location in `vyper/ast/parse.py`.

propagates the full AnnAssign node from `pre_parser.py` to get better location information

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
tserg and charles-cooper authored Jan 11, 2024
1 parent 06fa46a commit a6dc432
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 36 deletions.
29 changes: 29 additions & 0 deletions tests/functional/codegen/features/iteration/test_for_in_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
NamespaceCollision,
StateAccessViolation,
StructureException,
SyntaxException,
TypeMismatch,
UnknownType,
)

BASIC_FOR_LOOP_CODE = [
Expand Down Expand Up @@ -803,6 +805,33 @@ def test_for() -> int128:
""",
TypeMismatch,
),
(
"""
@external
def foo():
for i in [1, 2, 3]:
pass
""",
SyntaxException,
),
(
"""
@external
def foo():
for i: $$$ in [1, 2, 3]:
pass
""",
SyntaxException,
),
(
"""
@external
def foo():
for i: uint9 in [1, 2, 3]:
pass
""",
UnknownType,
),
]

BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE]
Expand Down
12 changes: 12 additions & 0 deletions tests/functional/syntax/exceptions/test_syntax_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ def f(a:uint256,/): # test posonlyargs blocked
def g():
self.f()
""",
"""
@external
def foo():
for i in range(0, 10):
pass
""",
"""
@external
def foo():
for i: $$$ in range(0, 10):
pass
""",
]


Expand Down
12 changes: 12 additions & 0 deletions tests/functional/syntax/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
StateAccessViolation,
StructureException,
TypeMismatch,
UnknownType,
)

fail_list = [
Expand Down Expand Up @@ -235,6 +236,17 @@ def foo():
"Bound must be at least 1",
"FOO",
),
(
"""
@external
def foo():
for i: DynArra[uint256, 3] in [1, 2, 3]:
pass
""",
UnknownType,
"No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?",
"DynArra",
),
]

for_code_regex = re.compile(r"for .+ in (.*):")
Expand Down
62 changes: 45 additions & 17 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast as python_ast
import string
import tokenize
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union, cast
Expand Down Expand Up @@ -150,7 +151,9 @@ def generic_visit(self, node):
self.counter += 1

# Decorate every node with source end offsets
start = node.first_token.start if hasattr(node, "first_token") else (None, None)
start = (None, None)
if hasattr(node, "first_token"):
start = node.first_token.start
end = (None, None)
if hasattr(node, "last_token"):
end = node.last_token.end
Expand Down Expand Up @@ -224,9 +227,9 @@ def visit_For(self, node):
Visit a For node, splicing in the loop variable annotation provided by
the pre-parser
"""
raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset))
annotation_tokens = self._for_loop_annotations.pop((node.lineno, node.col_offset))

if not raw_annotation:
if not annotation_tokens:
# a common case for people migrating to 0.4.0, provide a more
# specific error message than "invalid type annotation"
raise SyntaxException(
Expand All @@ -238,25 +241,50 @@ def visit_For(self, node):
node.col_offset,
)

self.generic_visit(node)

try:
annotation = python_ast.parse(raw_annotation, mode="eval")
# annotate with token and source code information. `first_token`
# and `last_token` attributes are accessed in `generic_visit`.
tokens = asttokens.ASTTokens(raw_annotation)
tokens.mark_tokens(annotation)
annotation_str = tokenize.untokenize(annotation_tokens).strip(string.whitespace + "\\")
annotation = python_ast.parse(annotation_str)
except SyntaxError as e:
raise SyntaxException(
"invalid type annotation", self._source_code, node.lineno, node.col_offset
) from e

assert isinstance(annotation, python_ast.Expression)
annotation = annotation.body

old_target = node.target
new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1)
node.target = new_target
annotation = annotation.body[0]
og_target = node.target

# annotate with token and source code information. `first_token`
# and `last_token` attributes are accessed in `generic_visit`.
tokens = asttokens.ASTTokens(annotation_str)
tokens.mark_tokens(annotation)

# decrease line offset by 1 because annotation is on the same line as `For` node
# but the spliced expression also starts at line 1
adjustment = og_target.first_token.start[0] - 1, og_target.first_token.start[1]

def _add_pair(x, y):
return x[0] + y[0], x[1] + y[1]

# The annotation was parsed from `annotation_str` on its own, so every token
# position recorded on it is relative to that string. Shift each position so
# it is relative to the loop target's first token instead.
pos_offset = og_target.first_token.startpos

for n in python_ast.walk(annotation):
    # adjust all offsets
    if hasattr(n, "first_token"):
        n.first_token = n.first_token._replace(
            start=_add_pair(n.first_token.start, adjustment),
            end=_add_pair(n.first_token.end, adjustment),
            startpos=n.first_token.startpos + pos_offset,
            # must be derived from `endpos`, not `startpos`; otherwise the
            # shifted first token collapses to zero width
            endpos=n.first_token.endpos + pos_offset,
        )
    if hasattr(n, "last_token"):
        n.last_token = n.last_token._replace(
            start=_add_pair(n.last_token.start, adjustment),
            end=_add_pair(n.last_token.end, adjustment),
            startpos=n.last_token.startpos + pos_offset,
            endpos=n.last_token.endpos + pos_offset,
        )

self.generic_visit(node)
node.target = annotation
node.target = self.generic_visit(node.target)

return node

Expand Down Expand Up @@ -418,8 +446,8 @@ def annotate_python_ast(
source_code : str
The originating source code of the AST.
loop_var_annotations: dict, optional
A mapping of line numbers of `For` nodes to the type annotation of the iterator
extracted during pre-parsing.
A mapping of line numbers of `For` nodes to the tokens of the type annotation
of the iterator extracted during pre-parsing.
modification_offsets : dict, optional
A mapping of class names to their original class types.
Expand Down
41 changes: 22 additions & 19 deletions vyper/ast/pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,32 @@ def __init__(self, code):
def consume(self, token):
# state machine: we can start slurping tokens soon
if token.type == NAME and token.string == "for":
# note: self._state should be NOT_RUNNING here, but we don't sanity
# check here as that should be an error the parser will handle.
# sanity check -- this should never really happen, but if it does,
# try to raise an exception which pinpoints the source.
if self._current_annotation is not None:
raise SyntaxException(
"for loop parse error", self._code, token.start[0], token.start[1]
)
self._current_annotation = []

assert self._state == ForParserState.NOT_RUNNING
self._state = ForParserState.START_SOON
self._current_for_loop = token.start
return False

if self._state == ForParserState.NOT_RUNNING:
return False

# state machine: start slurping tokens
if token.type == OP and token.string == ":":
self._state = ForParserState.RUNNING
if self._state == ForParserState.START_SOON:
# state machine: start slurping tokens

# sanity check -- this should never really happen, but if it does,
# try to raise an exception which pinpoints the source.
if self._current_annotation is not None:
raise SyntaxException(
"for loop parse error", self._code, token.start[0], token.start[1]
)
self._current_annotation.append(token)

self._current_annotation = []
return True # do not add ":" to tokens.
if token.type == OP and token.string == ":":
self._state = ForParserState.RUNNING
return True # do not add ":" to global tokens.

return False # add everything before ":" to tokens

# state machine: end slurping tokens
if token.type == NAME and token.string == "in":
Expand Down Expand Up @@ -136,8 +141,9 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]:
Compilation settings based on the directives in the source code
ModificationOffsets
A mapping of class names to their original class types.
dict[tuple[int, int], str]
A mapping of line/column offsets of `For` nodes to the annotation of the for loop target
dict[tuple[int, int], list[TokenInfo]]
A mapping of line/column offsets of `For` nodes to the tokens of the annotation of the
for loop target
str
Reformatted python source string.
"""
Expand Down Expand Up @@ -220,9 +226,6 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]:

for_loop_annotations = {}
for k, v in for_parser.annotations.items():
v_source = untokenize(v)
# untokenize adds backslashes and whitespace, strip them.
v_source = v_source.replace("\\", "").strip()
for_loop_annotations[k] = v_source
for_loop_annotations[k] = v.copy()

return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8")

0 comments on commit a6dc432

Please sign in to comment.