diff --git a/docs/changelog.rst b/docs/changelog.rst index 2272a37e..afc4c74e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,6 +25,9 @@ Upcoming version (TBA) long-standing inconvenience. (PR `#778 `__). +- The type caster for ``std::function`` now properly identifies its type as + optional (the runtime behavior is unaffected; this only impacts stubs) + * ABI version 16. diff --git a/include/nanobind/stl/function.h b/include/nanobind/stl/function.h index b9623402..3a6b0a2f 100644 --- a/include/nanobind/stl/function.h +++ b/include/nanobind/stl/function.h @@ -50,9 +50,9 @@ struct type_caster> { std::conditional_t, void_type, Return>>; NB_TYPE_CASTER(std::function , - const_name(NB_TYPING_CALLABLE "[[") + + optional_name(const_name(NB_TYPING_CALLABLE "[[") + concat(make_caster::Name...) + const_name("], ") + - ReturnCaster::Name + const_name("]")) + ReturnCaster::Name + const_name("]"))) struct pyfunc_wrapper_t : pyfunc_wrapper { using pyfunc_wrapper::pyfunc_wrapper; diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index 3f29c604..5cc64208 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -182,7 +182,7 @@ def test_bytearray_resize(arg0: bytearray, arg1: int, /) -> None: ... def test_bytearray_size(arg: bytearray, /) -> int: ... -def test_call_1(arg: Callable[[int], int], /) -> object: ... +def test_call_1(arg: Callable[[int], int] | None, /) -> object: ... def test_call_2(arg: Callable[[int, int], None], /) -> object: ... diff --git a/tests/test_stl_ext.pyi.ref b/tests/test_stl_ext.pyi.ref index ac8f39f8..25a21feb 100644 --- a/tests/test_stl_ext.pyi.ref +++ b/tests/test_stl_ext.pyi.ref @@ -38,10 +38,10 @@ class FuncWrapper: def __init__(self) -> None: ... @property - def f(self) -> Callable[[], None]: ... + def f(self) -> Callable[[], None] | None: ... @f.setter - def f(self, arg: Callable[[], None], /) -> None: ... + def f(self, arg: Callable[[], None] | None, /) -> None: ... alive: int = ... """static read-only property""" @@ -87,7 +87,7 @@ def array_in(arg: Sequence[int], /) -> int: ... def array_out() -> list[int]: ... -def call_function(arg0: Callable[[int], int], arg1: int, /) -> int: ... +def call_function(arg0: Callable[[int], int] | None, arg1: int, /) -> int: ... def complex_array_double(arg: Sequence[complex], /) -> list[complex]: ... @@ -182,15 +182,15 @@ def return_copyable() -> Copyable: ... def return_copyable_ptr() -> Copyable: ... -def return_empty_function() -> Callable[[int], int]: ... +def return_empty_function() -> Callable[[int], int] | None: ... -def return_function() -> Callable[[int], int]: ... +def return_function() -> Callable[[int], int] | None: ... def return_movable() -> Movable: ... def return_movable_ptr() -> Movable: ... -def return_void_function(arg: Callable[[], None], /) -> Callable[[], None]: ... +def return_void_function(arg: Callable[[], None] | None, /) -> Callable[[], None] | None: ... def set_in_lvalue_ref(x: Set[str]) -> None: ...