diff --git a/tests/functional/codegen/modules/test_interface_imports.py b/tests/functional/codegen/modules/test_interface_imports.py new file mode 100644 index 0000000000..084ad26e6b --- /dev/null +++ b/tests/functional/codegen/modules/test_interface_imports.py @@ -0,0 +1,31 @@ +def test_import_interface_types(make_input_bundle, get_contract): + ifaces = """ +interface IFoo: + def foo() -> uint256: nonpayable + """ + + foo_impl = """ +import ifaces + +implements: ifaces.IFoo + +@external +def foo() -> uint256: + return block.number + """ + + contract = """ +import ifaces + +@external +def test_foo(s: ifaces.IFoo) -> bool: + assert s.foo() == block.number + return True + """ + + input_bundle = make_input_bundle({"ifaces.vy": ifaces}) + + foo = get_contract(foo_impl, input_bundle=input_bundle) + c = get_contract(contract, input_bundle=input_bundle) + + assert c.test_foo(foo.address) is True diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/modules/test_stateless_functions.py similarity index 100% rename from tests/functional/codegen/test_stateless_modules.py rename to tests/functional/codegen/modules/test_stateless_functions.py diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index a672ed7b88..ca96adca91 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -90,7 +90,7 @@ def foo(): nonpayable """ implements: self.x """, - StructureException, + InvalidType, ), ( """ diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 90365c63d5..fa1fb63673 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1372,8 +1372,8 @@ class ImplementsDecl(Stmt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not isinstance(self.annotation, Name): - raise StructureException("not an identifier", self.annotation) + if not isinstance(self.annotation, (Name, Attribute)): + raise StructureException("invalid implements", self.annotation) class If(Stmt): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 4a7e33e848..2972ed2917 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -383,8 +383,9 @@ def visit_ImportFrom(self, node): self._add_import(node, node.level, qualified_module_name, alias) def visit_InterfaceDef(self, node): - obj = InterfaceT.from_InterfaceDef(node) - self.namespace[node.name] = obj + interface_t = InterfaceT.from_InterfaceDef(node) + node._metadata["interface_type"] = interface_t + self.namespace[node.name] = interface_t def visit_StructDef(self, node): struct_t = StructT.from_StructDef(node) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 8f1a5cc0dc..f2c3d74525 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -39,6 +39,7 @@ def __init__(self, _id: str, functions: dict, events: dict, structs: dict) -> No self._helper = VyperType(events | structs) self._id = _id + self._helper._id = _id self.functions = functions self.events = events self.structs = structs @@ -267,6 +268,8 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): + _attribute_in_annotation = True + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -276,7 +279,10 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): # compute the interface, note this has the side effect of checking # for function collisions - self._helper = self.interface + _ = self.interface + + self._helper = VyperType() + self._helper._id = self._id for f in self.function_defs: # note: this checks for collisions @@ -289,6 +295,12 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for s in self.struct_defs: # add the type of the struct so it can be used in call position self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + self._helper.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + + for i in self.interface_defs: + # add the type of the interface so it can be used in call position + self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore + self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore for v in self.variable_decls: self.add_member(v.target.id, v.target._metadata["varinfo"]) @@ -322,6 +334,10 @@ def event_defs(self): def struct_defs(self): return self._module.get_children(vy_ast.StructDef) + @property + def interface_defs(self): + return self._module.get_children(vy_ast.InterfaceDef) + @property def import_stmts(self): return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index eb96375404..c82eb73afc 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -127,14 +127,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: except UndeclaredDefinition: raise InvalidType(err_msg, node) from None - interface = module_or_interface if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo - interface = module_or_interface.module_t.interface + module_or_interface = module_or_interface.module_t - if not interface._attribute_in_annotation: + if not isinstance(module_or_interface, VyperType): raise InvalidType(err_msg, node) - type_t = interface.get_type_member(node.attr, node) + if not module_or_interface._attribute_in_annotation: + raise InvalidType(err_msg, node) + + type_t = module_or_interface.get_type_member(node.attr, node) # type: ignore assert isinstance(type_t, TYPE_T) # sanity check return type_t.typedef