Skip to content

Commit 4a3a6de

Browse files
wangxf123456rwgkpre-commit-ci[bot]
authored andcommitted
Add type_caster_std_function_specializations feature. (#4597)
* Allow specializations based on callback function return values. * clang-tidy auto fix * Add a test case for function specialization. * Add test for callback function that raises Python exception. * Fix test failures. * style: pre-commit fixes * Add `#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS` --------- Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7140586 commit 4a3a6de

File tree

4 files changed

+107
-34
lines changed

4 files changed

+107
-34
lines changed

include/pybind11/functional.h

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,55 @@
99

1010
#pragma once
1111

12+
#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS
13+
1214
#include "pybind11.h"
1315

1416
#include <functional>
1517

1618
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
1719
PYBIND11_NAMESPACE_BEGIN(detail)
20+
PYBIND11_NAMESPACE_BEGIN(type_caster_std_function_specializations)
21+
22+
// ensure GIL is held during functor destruction
23+
struct func_handle {
24+
function f;
25+
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
26+
// This triggers a syntax error under very special conditions (very weird indeed).
27+
explicit
28+
#endif
29+
func_handle(function &&f_) noexcept
30+
: f(std::move(f_)) {
31+
}
32+
func_handle(const func_handle &f_) { operator=(f_); }
33+
func_handle &operator=(const func_handle &f_) {
34+
gil_scoped_acquire acq;
35+
f = f_.f;
36+
return *this;
37+
}
38+
~func_handle() {
39+
gil_scoped_acquire acq;
40+
function kill_f(std::move(f));
41+
}
42+
};
43+
44+
// to emulate 'move initialization capture' in C++11
45+
struct func_wrapper_base {
46+
func_handle hfunc;
47+
explicit func_wrapper_base(func_handle &&hf) noexcept : hfunc(hf) {}
48+
};
49+
50+
template <typename Return, typename... Args>
51+
struct func_wrapper : func_wrapper_base {
52+
using func_wrapper_base::func_wrapper_base;
53+
Return operator()(Args... args) const {
54+
gil_scoped_acquire acq;
55+
// casts the returned object as a rvalue to the return type
56+
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
57+
}
58+
};
59+
60+
PYBIND11_NAMESPACE_END(type_caster_std_function_specializations)
1861

1962
template <typename Return, typename... Args>
2063
struct type_caster<std::function<Return(Args...)>> {
@@ -77,40 +120,8 @@ struct type_caster<std::function<Return(Args...)>> {
77120
// See PR #1413 for full details
78121
}
79122

80-
// ensure GIL is held during functor destruction
81-
struct func_handle {
82-
function f;
83-
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
84-
// This triggers a syntax error under very special conditions (very weird indeed).
85-
explicit
86-
#endif
87-
func_handle(function &&f_) noexcept
88-
: f(std::move(f_)) {
89-
}
90-
func_handle(const func_handle &f_) { operator=(f_); }
91-
func_handle &operator=(const func_handle &f_) {
92-
gil_scoped_acquire acq;
93-
f = f_.f;
94-
return *this;
95-
}
96-
~func_handle() {
97-
gil_scoped_acquire acq;
98-
function kill_f(std::move(f));
99-
}
100-
};
101-
102-
// to emulate 'move initialization capture' in C++11
103-
struct func_wrapper {
104-
func_handle hfunc;
105-
explicit func_wrapper(func_handle &&hf) noexcept : hfunc(std::move(hf)) {}
106-
Return operator()(Args... args) const {
107-
gil_scoped_acquire acq;
108-
// casts the returned object as a rvalue to the return type
109-
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
110-
}
111-
};
112-
113-
value = func_wrapper(func_handle(std::move(func)));
123+
value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
124+
type_caster_std_function_specializations::func_handle(std::move(func)));
114125
return true;
115126
}
116127

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ set(PYBIND11_TEST_FILES
158158
test_tagbased_polymorphic
159159
test_thread
160160
test_type_caster_pyobject_ptr
161+
test_type_caster_std_function_specializations
161162
test_union
162163
test_unnamed_namespace_a
163164
test_unnamed_namespace_b
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include <pybind11/functional.h>
2+
#include <pybind11/pybind11.h>
3+
4+
#include "pybind11_tests.h"
5+
6+
namespace py = pybind11;
7+
8+
namespace {
9+
10+
struct SpecialReturn {
11+
int value = 99;
12+
};
13+
14+
} // namespace
15+
16+
namespace pybind11 {
17+
namespace detail {
18+
namespace type_caster_std_function_specializations {
19+
20+
template <typename... Args>
21+
struct func_wrapper<SpecialReturn, Args...> : func_wrapper_base {
22+
using func_wrapper_base::func_wrapper_base;
23+
SpecialReturn operator()(Args... args) const {
24+
gil_scoped_acquire acq;
25+
SpecialReturn result;
26+
try {
27+
result = hfunc.f(std::forward<Args>(args)...).template cast<SpecialReturn>();
28+
} catch (error_already_set &) {
29+
result.value += 1;
30+
}
31+
result.value += 100;
32+
return result;
33+
}
34+
};
35+
36+
} // namespace type_caster_std_function_specializations
37+
} // namespace detail
38+
} // namespace pybind11
39+
40+
TEST_SUBMODULE(type_caster_std_function_specializations, m) {
41+
py::class_<SpecialReturn>(m, "SpecialReturn")
42+
.def(py::init<>())
43+
.def_readwrite("value", &SpecialReturn::value);
44+
m.def("call_callback_with_special_return",
45+
[](const std::function<SpecialReturn()> &func) { return func(); });
46+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from __future__ import annotations
2+
3+
from pybind11_tests import type_caster_std_function_specializations as m
4+
5+
6+
def test_callback_with_special_return():
7+
def return_special():
8+
return m.SpecialReturn()
9+
10+
def raise_exception():
11+
raise ValueError("called raise_exception.")
12+
13+
assert return_special().value == 99
14+
assert m.call_callback_with_special_return(return_special).value == 199
15+
assert m.call_callback_with_special_return(raise_exception).value == 200

0 commit comments

Comments
 (0)