diff --git a/include/pybind11/detail/smart_holder_type_casters.h b/include/pybind11/detail/smart_holder_type_casters.h index bd9483a025..bfe54c847a 100644 --- a/include/pybind11/detail/smart_holder_type_casters.h +++ b/include/pybind11/detail/smart_holder_type_casters.h @@ -464,6 +464,27 @@ struct shared_ptr_trampoline_self_life_support { } }; +template ::value, int>::type = 0> +inline std::unique_ptr unique_with_deleter(T *raw_ptr, std::unique_ptr &&deleter) { + if (deleter == nullptr) { + return std::unique_ptr(raw_ptr); + } + return std::unique_ptr(raw_ptr, std::move(*deleter)); +} + +template ::value, int>::type = 0> +inline std::unique_ptr unique_with_deleter(T *raw_ptr, std::unique_ptr &&deleter) { + if (deleter == nullptr) { + pybind11_fail("smart_holder_type_casters: deleter is not default constructible and no" + " instance available to return."); + } + return std::unique_ptr(raw_ptr, std::move(*deleter)); +} + template struct smart_holder_type_caster_load { using holder_type = pybindit::memory::smart_holder; @@ -589,7 +610,7 @@ struct smart_holder_type_caster_load { " std::unique_ptr."); } if (!have_holder()) { - return nullptr; + return unique_with_deleter(nullptr, std::unique_ptr()); } throw_if_uninitialized_or_disowned_holder(); throw_if_instance_is_currently_owned_by_shared_ptr(); @@ -616,13 +637,30 @@ struct smart_holder_type_caster_load { "instance cannot safely be transferred to C++."); } + // Temporary variable to store the extracted deleter in. + std::unique_ptr extracted_deleter; + + auto *gd = std::get_deleter(holder().vptr); + if (gd && gd->use_del_fun) { // Note the ensure_compatible_rtti_uqp_del() call above. + // In smart_holder_poc, a custom deleter is always stored in a guarded delete. + // The guarded delete's std::function actually points at the + // custom_deleter type, so we can verify it is of the custom deleter type and + // finally extract its deleter. + using custom_deleter_D = pybindit::memory::custom_deleter; + const auto &custom_deleter_ptr = gd->del_fun.template target(); + assert(custom_deleter_ptr != nullptr); + // Now that we have confirmed the type of the deleter matches the desired return + // value we can extract the function. + extracted_deleter = std::unique_ptr(new D(std::move(custom_deleter_ptr->deleter))); + } + // Critical transfer-of-ownership section. This must stay together. if (self_life_support != nullptr) { holder().disown(); } else { holder().release_ownership(); } - auto result = std::unique_ptr(raw_type_ptr); + auto result = unique_with_deleter(raw_type_ptr, std::move(extracted_deleter)); if (self_life_support != nullptr) { self_life_support->activate_life_support(load_impl.loaded_v_h); } else { @@ -1056,7 +1094,8 @@ struct smart_holder_type_caster> static handle cast(std::unique_ptr &&src, return_value_policy policy, handle parent) { return smart_holder_type_caster>::cast( - std::unique_ptr(const_cast(src.release())), // Const2Mutbl + std::unique_ptr(const_cast(src.release()), + std::move(src.get_deleter())), // Const2Mutbl policy, parent); } diff --git a/tests/test_class_sh_basic.cpp b/tests/test_class_sh_basic.cpp index 7582cfa044..a294f7bbb2 100644 --- a/tests/test_class_sh_basic.cpp +++ b/tests/test_class_sh_basic.cpp @@ -28,6 +28,43 @@ struct uconsumer { // unique_ptr consumer const std::unique_ptr &rtrn_cref() const { return held; } }; +/// Custom deleter that is default constructible. +struct custom_deleter { + std::string trace_txt; + + custom_deleter() = default; + explicit custom_deleter(const std::string &trace_txt_) : trace_txt(trace_txt_) {} + + custom_deleter(const custom_deleter &other) { trace_txt = other.trace_txt + "_CpCtor"; } + + custom_deleter &operator=(const custom_deleter &rhs) { + trace_txt = rhs.trace_txt + "_CpLhs"; + return *this; + } + + custom_deleter(custom_deleter &&other) noexcept { + trace_txt = other.trace_txt + "_MvCtorTo"; + other.trace_txt += "_MvCtorFrom"; + } + + custom_deleter &operator=(custom_deleter &&rhs) noexcept { + trace_txt = rhs.trace_txt + "_MvLhs"; + rhs.trace_txt += "_MvRhs"; + return *this; + } + + void operator()(atyp *p) const { std::default_delete()(p); } + void operator()(const atyp *p) const { std::default_delete()(p); } +}; +static_assert(std::is_default_constructible::value, ""); + +/// Custom deleter that is not default constructible. +struct custom_deleter_nd : custom_deleter { + custom_deleter_nd() = delete; + explicit custom_deleter_nd(const std::string &trace_txt_) : custom_deleter(trace_txt_) {} +}; +static_assert(!std::is_default_constructible::value, ""); + // clang-format off atyp rtrn_valu() { atyp obj{"rtrn_valu"}; return obj; } @@ -64,6 +101,18 @@ std::unique_ptr rtrn_udcp() { return std::unique_ptr obj) { return "pass_udmp:" + obj->mtxt; } std::string pass_udcp(std::unique_ptr obj) { return "pass_udcp:" + obj->mtxt; } +std::unique_ptr rtrn_udmp_del() { return std::unique_ptr(new atyp{"rtrn_udmp_del"}, custom_deleter{"udmp_deleter"}); } +std::unique_ptr rtrn_udcp_del() { return std::unique_ptr(new atyp{"rtrn_udcp_del"}, custom_deleter{"udcp_deleter"}); } + +std::string pass_udmp_del(std::unique_ptr obj) { return "pass_udmp_del:" + obj->mtxt + "," + obj.get_deleter().trace_txt; } +std::string pass_udcp_del(std::unique_ptr obj) { return "pass_udcp_del:" + obj->mtxt + "," + obj.get_deleter().trace_txt; } + +std::unique_ptr rtrn_udmp_del_nd() { return std::unique_ptr(new atyp{"rtrn_udmp_del_nd"}, custom_deleter_nd{"udmp_deleter_nd"}); } +std::unique_ptr rtrn_udcp_del_nd() { return std::unique_ptr(new atyp{"rtrn_udcp_del_nd"}, custom_deleter_nd{"udcp_deleter_nd"}); } + +std::string pass_udmp_del_nd(std::unique_ptr obj) { return "pass_udmp_del_nd:" + obj->mtxt + "," + obj.get_deleter().trace_txt; } +std::string pass_udcp_del_nd(std::unique_ptr obj) { return "pass_udcp_del_nd:" + obj->mtxt + "," + obj.get_deleter().trace_txt; } + // clang-format on // Helpers for testing. @@ -130,6 +179,18 @@ TEST_SUBMODULE(class_sh_basic, m) { m.def("pass_udmp", pass_udmp); m.def("pass_udcp", pass_udcp); + m.def("rtrn_udmp_del", rtrn_udmp_del); + m.def("rtrn_udcp_del", rtrn_udcp_del); + + m.def("pass_udmp_del", pass_udmp_del); + m.def("pass_udcp_del", pass_udcp_del); + + m.def("rtrn_udmp_del_nd", rtrn_udmp_del_nd); + m.def("rtrn_udcp_del_nd", rtrn_udcp_del_nd); + + m.def("pass_udmp_del_nd", pass_udmp_del_nd); + m.def("pass_udcp_del_nd", pass_udcp_del_nd); + py::classh(m, "uconsumer") .def(py::init<>()) .def("valid", &uconsumer::valid) diff --git a/tests/test_class_sh_basic.py b/tests/test_class_sh_basic.py index 268e793985..7be0747c2a 100644 --- a/tests/test_class_sh_basic.py +++ b/tests/test_class_sh_basic.py @@ -65,6 +65,35 @@ def test_load_with_rtrn_f(pass_f, rtrn_f, expected): assert pass_f(rtrn_f()) == expected +@pytest.mark.parametrize( + ("pass_f", "rtrn_f", "regex_expected"), + [ + ( + m.pass_udmp_del, + m.rtrn_udmp_del, + "pass_udmp_del:rtrn_udmp_del,udmp_deleter(_MvCtorTo)*_MvCtorTo", + ), + ( + m.pass_udcp_del, + m.rtrn_udcp_del, + "pass_udcp_del:rtrn_udcp_del,udcp_deleter(_MvCtorTo)*_MvCtorTo", + ), + ( + m.pass_udmp_del_nd, + m.rtrn_udmp_del_nd, + "pass_udmp_del_nd:rtrn_udmp_del_nd,udmp_deleter_nd(_MvCtorTo)*_MvCtorTo", + ), + ( + m.pass_udcp_del_nd, + m.rtrn_udcp_del_nd, + "pass_udcp_del_nd:rtrn_udcp_del_nd,udcp_deleter_nd(_MvCtorTo)*_MvCtorTo", + ), + ], +) +def test_deleter_roundtrip(pass_f, rtrn_f, regex_expected): + assert re.match(regex_expected, pass_f(rtrn_f())) + + @pytest.mark.parametrize( ("pass_f", "rtrn_f", "expected"), [