Skip to content

Commit

Permalink
feat[tool]: validate AST nodes early in the pipeline (vyperlang#3809)
Browse files Browse the repository at this point in the history
validate Vyper AST nodes as early as possible in the compilation
pipeline, during `parse_to_ast()`. this will make compilation fail
earlier in the pipeline for integrators which use the result of `-f ast`
but don't grab `-f annotated_ast`.
  • Loading branch information
charles-cooper authored Feb 27, 2024
1 parent f067a6d commit 391f3cc
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 36 deletions.
14 changes: 12 additions & 2 deletions tests/unit/ast/nodes/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,18 @@


def test_binary_becomes_bytes():
expected = vy_ast.parse_to_ast("foo: Bytes[1] = b'\x01'")
mutated = vy_ast.parse_to_ast("foo: Bytes[1] = 0b00000001")
expected = vy_ast.parse_to_ast(
"""
def x():
foo: Bytes[1] = b'\x01'
"""
)
mutated = vy_ast.parse_to_ast(
"""
def x():
foo: Bytes[1] = 0b00000001
"""
)

assert vy_ast.compare_nodes(expected, mutated)

Expand Down
4 changes: 1 addition & 3 deletions tests/unit/ast/nodes/test_hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def foo():

@pytest.mark.parametrize("code", code_invalid_checksum)
def test_invalid_checksum(code, dummy_input_bundle):
vyper_module = vy_ast.parse_to_ast(code)

with pytest.raises(InvalidLiteral):
vy_ast.validation.validate_literal_nodes(vyper_module)
vyper_module = vy_ast.parse_to_ast(code)
semantics.validate_semantics(vyper_module, dummy_input_bundle)
18 changes: 8 additions & 10 deletions tests/unit/semantics/types/test_pure_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,35 +76,33 @@ def test_valid_literals(build_node, type_):
sources = VALID_LITERALS[type_]
for source in sources:
node = build_node(source)

do_validate_node(type_, node)


@pytest.mark.parametrize("type_", TYPES.keys())
@pytest.mark.parametrize("source", INVALID_LITERALS)
def test_invalid_literals(build_node, type_, source):
node = build_node(source)
with pytest.raises((InvalidLiteral, OverflowException, UnexpectedNodeType)):
do_validate_node(type_, node)
# build_node throws; no need to run do_validate_node
build_node(source)


@pytest.mark.parametrize("type_,type_str", TYPES.items())
@pytest.mark.parametrize("source", INVALID_NODES + ["{}"])
def test_invalid_node(build_node, type_, type_str, source):
source = source.format(type_str)
node = build_node(source)

with pytest.raises((InvalidLiteral, UnexpectedNodeType)):
node = build_node(source)
do_validate_node(type_, node)


# no literal is a valid annotation
@pytest.mark.parametrize("type_", TYPES.keys())
@pytest.mark.parametrize("source", ALL_LITERALS)
def test_from_annotation_literal(build_node, type_, source):
node = build_node(source)

with pytest.raises(InvalidType):
with pytest.raises((InvalidType, InvalidLiteral, OverflowException)):
node = build_node(source)
type_from_annotation(node)


Expand All @@ -119,22 +117,22 @@ def _check_type_equals(type_, t):
@pytest.mark.parametrize("source", INVALID_NODES)
def test_invalid_annotations(build_node, type_, type_str, source):
source = source.format(type_str)
node = build_node(source)

with pytest.raises((StructureException, InvalidType)):
node = build_node(source)
t = type_from_annotation(node)
_check_type_equals(type_, t)


@pytest.mark.parametrize("type_", TYPES.keys())
@pytest.mark.parametrize("type_str", TYPES.values())
def test_from_annotation(build_node, type_, type_str):
node = build_node(type_str)

if type_str == TYPES[type_]:
node = build_node(type_str)
t = type_from_annotation(node)
_check_type_equals(type_, t)
else:
with pytest.raises(InvalidType):
node = build_node(type_str)
t = type_from_annotation(node)
_check_type_equals(type_, t)
6 changes: 5 additions & 1 deletion vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def get_node(
ast_struct: Union[dict, python_ast.AST], parent: Optional["VyperNode"] = None
) -> "VyperNode":
"""
Convert an AST structure to a vyper AST node.
Convert an AST structure to a vyper AST node. Entry point to constructing
vyper AST nodes.
This is a recursive call, all child nodes of the input value are also
converted to Vyper nodes.
Expand Down Expand Up @@ -130,6 +131,9 @@ def get_node(
f"enum will be deprecated in a future release, use flag instead. {pretty_printed_node}",
stacklevel=2,
)

node.validate()

return node


Expand Down
16 changes: 0 additions & 16 deletions vyper/ast/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,3 @@ def validate_call_args(
if key.arg in kwargs_seen:
raise ArgumentException(f"Duplicate keyword argument '{key.arg}'", key)
kwargs_seen.add(key.arg)


def validate_literal_nodes(vyper_module: vy_ast.Module) -> None:
"""
Individually validate Vyper AST nodes.
Recursively calls the `validate` method of each node to verify that
literal nodes do not contain invalid values.
Arguments
---------
vyper_module : vy_ast.Module
Top level Vyper AST node.
"""
for node in vyper_module.get_descendants(include_self=True):
node.validate()
4 changes: 0 additions & 4 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import vyper.builtins.interfaces
from vyper import ast as vy_ast
from vyper.ast.validation import validate_literal_nodes
from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
Expand Down Expand Up @@ -67,9 +66,6 @@ def validate_module_semantics_r(
assert isinstance(module_ast._metadata["type"], ModuleT)
return module_ast._metadata["type"]

# TODO: move this to parser or VyperNode construction
validate_literal_nodes(module_ast)

# validate semantics and annotate AST with type/semantics information
namespace = get_namespace()

Expand Down

0 comments on commit 391f3cc

Please sign in to comment.