Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add detail::type_caster_std_function_specializations #4597

Merged
merged 9 commits into from
Aug 9, 2024
79 changes: 45 additions & 34 deletions include/pybind11/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,55 @@

#pragma once

#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS

#include "pybind11.h"

#include <functional>

PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail)
PYBIND11_NAMESPACE_BEGIN(type_caster_std_function_specializations)

// ensure GIL is held during functor destruction
struct func_handle {
function f;
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
// This triggers a syntax error under very special conditions (very weird indeed).
explicit
#endif
func_handle(function &&f_) noexcept
: f(std::move(f_)) {
}
func_handle(const func_handle &f_) { operator=(f_); }
func_handle &operator=(const func_handle &f_) {
gil_scoped_acquire acq;
f = f_.f;
return *this;
}
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};

// to emulate 'move initialization capture' in C++11
struct func_wrapper_base {
func_handle hfunc;
explicit func_wrapper_base(func_handle &&hf) noexcept : hfunc(hf) {}
};

template <typename Return, typename... Args>
struct func_wrapper : func_wrapper_base {
using func_wrapper_base::func_wrapper_base;
Return operator()(Args... args) const {
rwgk marked this conversation as resolved.
Show resolved Hide resolved
gil_scoped_acquire acq;
// casts the returned object as a rvalue to the return type
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
}
};

PYBIND11_NAMESPACE_END(type_caster_std_function_specializations)

template <typename Return, typename... Args>
struct type_caster<std::function<Return(Args...)>> {
Expand Down Expand Up @@ -77,40 +120,8 @@ struct type_caster<std::function<Return(Args...)>> {
// See PR #1413 for full details
}

// ensure GIL is held during functor destruction
struct func_handle {
function f;
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
// This triggers a syntax error under very special conditions (very weird indeed).
explicit
#endif
func_handle(function &&f_) noexcept
: f(std::move(f_)) {
}
func_handle(const func_handle &f_) { operator=(f_); }
func_handle &operator=(const func_handle &f_) {
gil_scoped_acquire acq;
f = f_.f;
return *this;
}
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};

// to emulate 'move initialization capture' in C++11
struct func_wrapper {
func_handle hfunc;
explicit func_wrapper(func_handle &&hf) noexcept : hfunc(std::move(hf)) {}
Return operator()(Args... args) const {
gil_scoped_acquire acq;
// casts the returned object as a rvalue to the return type
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
}
};

value = func_wrapper(func_handle(std::move(func)));
value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
type_caster_std_function_specializations::func_handle(std::move(func)));
return true;
}

Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ set(PYBIND11_TEST_FILES
test_tagbased_polymorphic
test_thread
test_type_caster_pyobject_ptr
test_type_caster_std_function_specializations
test_union
test_unnamed_namespace_a
test_unnamed_namespace_b
Expand Down
46 changes: 46 additions & 0 deletions tests/test_type_caster_std_function_specializations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>

#include "pybind11_tests.h"

namespace py = pybind11;

namespace {

struct SpecialReturn {
int value = 99;
};

} // namespace

namespace pybind11 {
namespace detail {
namespace type_caster_std_function_specializations {

template <typename... Args>
struct func_wrapper<SpecialReturn, Args...> : func_wrapper_base {
using func_wrapper_base::func_wrapper_base;
SpecialReturn operator()(Args... args) const {
gil_scoped_acquire acq;
SpecialReturn result;
try {
result = hfunc.f(std::forward<Args>(args)...).template cast<SpecialReturn>();
} catch (error_already_set &) {
result.value += 1;
}
result.value += 100;
return result;
}
};

} // namespace type_caster_std_function_specializations
} // namespace detail
} // namespace pybind11

TEST_SUBMODULE(type_caster_std_function_specializations, m) {
py::class_<SpecialReturn>(m, "SpecialReturn")
.def(py::init<>())
.def_readwrite("value", &SpecialReturn::value);
m.def("call_callback_with_special_return",
[](const std::function<SpecialReturn()> &func) { return func(); });
}
15 changes: 15 additions & 0 deletions tests/test_type_caster_std_function_specializations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations

from pybind11_tests import type_caster_std_function_specializations as m


def test_callback_with_special_return():
def return_special():
return m.SpecialReturn()

def raise_exception():
raise ValueError("called raise_exception.")

assert return_special().value == 99
assert m.call_callback_with_special_return(return_special).value == 199
assert m.call_callback_with_special_return(raise_exception).value == 200