diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py index 5b484cd3..9ea8e2cd 100644 --- a/python/tvm_ffi/registry.py +++ b/python/tvm_ffi/registry.py @@ -275,7 +275,8 @@ def get_global_func_metadata(name: str) -> dict[str, Any]: Register a Python callable as a global FFI function. """ - return json.loads(get_global_func("ffi.GetGlobalFuncMetadata")(name) or "{}") + metadata_json = get_global_func("ffi.GetGlobalFuncMetadata")(name) + return json.loads(metadata_json) if metadata_json else {} def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None: diff --git a/tests/python/test_metadata.py b/tests/python/test_metadata.py index 116afd3d..4d91462d 100644 --- a/tests/python/test_metadata.py +++ b/tests/python/test_metadata.py @@ -17,7 +17,7 @@ from typing import Any import pytest -from tvm_ffi import get_global_func_metadata +from tvm_ffi import get_global_func_metadata, register_global_func, remove_global_func from tvm_ffi.core import TypeInfo, TypeSchema, _lookup_type_attr from tvm_ffi.testing import _SchemaAllTypes @@ -163,6 +163,16 @@ def test_metadata_global_func() -> None: assert metadata["str_attr"] == "hello" +def test_metadata_empty_for_python_func() -> None: + @register_global_func("test.python_func_no_metadata") + def simple_func(x: int) -> int: + return x + 1 + + metadata = get_global_func_metadata("test.python_func_no_metadata") + assert metadata == {} + remove_global_func("test.python_func_no_metadata") + + def test_metadata_field() -> None: type_info: TypeInfo = getattr(_SchemaAllTypes, "__tvm_ffi_type_info__") for field in type_info.fields: