Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[smart_holder] Unique ptr deleter roundtrip tests and fix #4921

Merged
merged 13 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions include/pybind11/detail/smart_holder_type_casters.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,24 @@ struct shared_ptr_trampoline_self_life_support {
}
};

template <typename T,
typename D,
typename std::enable_if<std::is_default_constructible<D>::value, int>::type = 0>
inline std::unique_ptr<T, D> unique_with_deleter(T *raw_ptr, std::unique_ptr<D> &&deleter) {
if (deleter != nullptr) {
return std::unique_ptr<T, D>(raw_ptr, std::move(*deleter));
iwanders marked this conversation as resolved.
Show resolved Hide resolved
}
return std::unique_ptr<T, D>(raw_ptr);
}

template <typename T,
typename D,
typename std::enable_if<!std::is_default_constructible<D>::value, int>::type = 0>
inline std::unique_ptr<T, D> unique_with_deleter(T *raw_ptr, std::unique_ptr<D> &&deleter) {
assert(deleter != nullptr);
iwanders marked this conversation as resolved.
Show resolved Hide resolved
return std::unique_ptr<T, D>(raw_ptr, std::move(*deleter));
}

template <typename T>
struct smart_holder_type_caster_load {
using holder_type = pybindit::memory::smart_holder;
Expand Down Expand Up @@ -589,7 +607,7 @@ struct smart_holder_type_caster_load {
" std::unique_ptr.");
}
if (!have_holder()) {
return nullptr;
return unique_with_deleter<T, D>(nullptr, std::unique_ptr<D>());
}
throw_if_uninitialized_or_disowned_holder();
throw_if_instance_is_currently_owned_by_shared_ptr();
Expand All @@ -616,13 +634,30 @@ struct smart_holder_type_caster_load {
"instance cannot safely be transferred to C++.");
}

// Temporary variable to store the extracted deleter in.
std::unique_ptr<D> extracted_deleter;

auto *gd = std::get_deleter<pybindit::memory::guarded_delete>(holder().vptr);
if (gd && gd->use_del_fun) { // Note the ensure_compatible_rtti_uqp_del<T, D>() call above.
// In smart_holder_poc, a custom deleter is always stored in a guarded delete.
// The guarded delete's std::function<void(void*)> actually points at the
// custom_deleter type, so we can verify it is of the custom deleter type and
// finally extract its deleter.
using custom_deleter_D = pybindit::memory::custom_deleter<T, D>;
const auto &custom_deleter_ptr = gd->del_fun.template target<custom_deleter_D>();
assert(custom_deleter_ptr != nullptr);
// Now that we have confirmed the type of the deleter matches the desired return
// value we can extract the function.
extracted_deleter = std::unique_ptr<D>(new D(std::move(custom_deleter_ptr->deleter)));
}

// Critical transfer-of-ownership section. This must stay together.
if (self_life_support != nullptr) {
holder().disown();
} else {
holder().release_ownership();
}
auto result = std::unique_ptr<T, D>(raw_type_ptr);
auto result = unique_with_deleter<T, D>(raw_type_ptr, std::move(extracted_deleter));
if (self_life_support != nullptr) {
self_life_support->activate_life_support(load_impl.loaded_v_h);
} else {
Expand Down Expand Up @@ -1056,7 +1091,8 @@ struct smart_holder_type_caster<std::unique_ptr<T const, D>>
static handle
cast(std::unique_ptr<T const, D> &&src, return_value_policy policy, handle parent) {
return smart_holder_type_caster<std::unique_ptr<T, D>>::cast(
std::unique_ptr<T, D>(const_cast<T *>(src.release())), // Const2Mutbl
std::unique_ptr<T, D>(const_cast<T *>(src.release()),
std::move(src.get_deleter())), // Const2Mutbl
policy,
parent);
}
Expand Down
28 changes: 28 additions & 0 deletions tests/test_class_sh_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@ std::unique_ptr<atyp const, sddc> rtrn_udcp() { return std::unique_ptr<atyp cons
std::string pass_udmp(std::unique_ptr<atyp, sddm> obj) { return "pass_udmp:" + obj->mtxt; }
std::string pass_udcp(std::unique_ptr<atyp const, sddc> obj) { return "pass_udcp:" + obj->mtxt; }


struct custom_deleter
{
std::string delete_txt;
custom_deleter() = delete;
explicit custom_deleter(const std::string &delete_txt_) : delete_txt(delete_txt_) {}
iwanders marked this conversation as resolved.
Show resolved Hide resolved
void operator()(atyp* p) const
{
std::default_delete<atyp>()(p);
}
void operator()(const atyp* p) const
{
std::default_delete<const atyp>()(p);
}
};

std::unique_ptr<atyp, custom_deleter> rtrn_udmp_del() { return std::unique_ptr<atyp, custom_deleter>(new atyp{"rtrn_udmp_del"}, custom_deleter{"udmp_deleter"}); }
std::unique_ptr<atyp const, custom_deleter> rtrn_udcp_del() { return std::unique_ptr<atyp const, custom_deleter>(new atyp{"rtrn_udcp_del"}, custom_deleter{"udcp_deleter"}); }

std::string pass_udmp_del(std::unique_ptr<atyp, custom_deleter> obj) { return "pass_udmp_del:" + obj->mtxt + "," + obj.get_deleter().delete_txt; }
std::string pass_udcp_del(std::unique_ptr<atyp const, custom_deleter> obj) { return "pass_udcp_del:" + obj->mtxt + "," + obj.get_deleter().delete_txt; }

// clang-format on

// Helpers for testing.
Expand Down Expand Up @@ -130,6 +152,12 @@ TEST_SUBMODULE(class_sh_basic, m) {
m.def("pass_udmp", pass_udmp);
m.def("pass_udcp", pass_udcp);

m.def("rtrn_udmp_del", rtrn_udmp_del);
m.def("rtrn_udcp_del", rtrn_udcp_del);

m.def("pass_udmp_del", pass_udmp_del);
m.def("pass_udcp_del", pass_udcp_del);

py::classh<uconsumer>(m, "uconsumer")
.def(py::init<>())
.def("valid", &uconsumer::valid)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_class_sh_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ def test_load_with_rtrn_f(pass_f, rtrn_f, expected):
assert pass_f(rtrn_f()) == expected


@pytest.mark.parametrize(
("pass_f", "rtrn_f", "expected"),
[
(m.pass_udmp_del, m.rtrn_udmp_del, "pass_udmp_del:rtrn_udmp_del,udmp_deleter"),
(m.pass_udcp_del, m.rtrn_udcp_del, "pass_udcp_del:rtrn_udcp_del,udcp_deleter"),
],
)
def test_deleter_roundtrip(pass_f, rtrn_f, expected):
assert pass_f(rtrn_f()) == expected


@pytest.mark.parametrize(
("pass_f", "rtrn_f", "expected"),
[
Expand Down