Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smart Pointers for non custom types #1985

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 57 additions & 25 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include "detail/internals.h"
#include <array>
#include <limits>
#include <memory>
#include <new>
#include <tuple>
#include <type_traits>

Expand Down Expand Up @@ -1451,29 +1453,58 @@ template <typename T1, typename T2> class type_caster<std::pair<T1, T2>>
template <typename... Ts> class type_caster<std::tuple<Ts...>>
: public tuple_caster<std::tuple, Ts...> {};

// When a value returned from a C++ function is being cast back to Python, we almost always want to
// force `policy = move`, regardless of the return value policy the function/method was declared
// with.
template <typename Return, typename SFINAE = void> struct return_value_policy_override {
static return_value_policy policy(return_value_policy p) { return p; }
};

template <typename Return> struct return_value_policy_override<Return,
detail::enable_if_t<std::is_base_of<type_caster_generic, make_caster<Return>>::value, void>> {
static return_value_policy policy(return_value_policy p) {
return !std::is_lvalue_reference<Return>::value &&
!std::is_pointer<Return>::value
? return_value_policy::move : p;
}
};

/// Helper class which abstracts away certain actions. Users can provide specializations for
/// custom holders, but it's only necessary if the type has a non-standard interface.
template <typename T>
template <typename type, typename holder_type>
struct holder_helper {
static auto get(const T &p) -> decltype(p.get()) { return p.get(); }
static auto get(const holder_type &p) -> decltype(p.get()) { return p.get(); }
static auto get(holder_type &p) -> decltype(p.get()) { return p.get(); }
static holder_type create(type && val) { return holder_type(new type(std::forward<type>(val))); }
};

/// Type caster for holder types like std::shared_ptr, etc.
template <typename type, typename holder_type>
struct copyable_holder_caster : public type_caster_base<type> {
public:
using base = type_caster_base<type>;
static_assert(std::is_base_of<base, type_caster<type>>::value,
"Holder classes are only supported for custom types");
using base::base;
using base::cast;
using base::typeinfo;
using base::value;

template <typename T = type, detail::enable_if_t<std::is_base_of<type_caster_base<T>, type_caster<T>>::value, int> = 0>
bool load(handle src, bool convert) {
return base::template load_impl<copyable_holder_caster<type, holder_type>>(src, convert);
}

template <typename T = type, detail::enable_if_t<!std::is_base_of<type_caster_base<T>, type_caster<T>>::value, int> = 0>
bool load(handle src, bool convert) {
using value_conv = make_caster<type>;
value_conv caster;
if (!caster.load(src, convert)) {
return false;
}
holder = holder_helper<type, holder_type>::create(std::forward<type>(cast_op<type&&>(std::move(caster))));
value = reinterpret_cast<void*>(holder_helper<type, holder_type>::get(holder));
return true;
}

explicit operator type*() { return this->value; }
explicit operator type&() { return *(this->value); }
explicit operator holder_type*() { return std::addressof(holder); }
Expand All @@ -1486,9 +1517,18 @@ struct copyable_holder_caster : public type_caster_base<type> {
explicit operator holder_type&() { return holder; }
#endif

template <typename T = type, detail::enable_if_t<std::is_base_of<type_caster_base<T>, type_caster<T>>::value, int> = 0>
static handle cast(const holder_type &src, return_value_policy, handle) {
const auto *ptr = holder_helper<holder_type>::get(src);
return type_caster_base<type>::cast_holder(ptr, &src);
const auto *ptr = holder_helper<type, holder_type>::get(src);
return type_caster_base<type>::cast_holder(ptr, std::addressof(src));
}

template <typename T = type, detail::enable_if_t<!std::is_base_of<type_caster_base<T>, type_caster<T>>::value, int> = 0>
static handle cast(const holder_type &src, return_value_policy policy, handle parent) {
policy = return_value_policy_override<type>::policy(policy);
using value_conv = make_caster<type>;
const auto *ptr = holder_helper<type, holder_type>::get(src);
return value_conv::cast(*ptr, policy, parent);
}

protected:
Expand Down Expand Up @@ -1541,13 +1581,21 @@ class type_caster<std::shared_ptr<T>> : public copyable_holder_caster<T, std::sh

template <typename type, typename holder_type>
struct move_only_holder_caster {
static_assert(std::is_base_of<type_caster_base<type>, type_caster<type>>::value,
"Holder classes are only supported for custom types");

template <typename T = type, detail::enable_if_t<std::is_base_of<type_caster_base<T>, type_caster<T>>::value, int> = 0>
static handle cast(holder_type &&src, return_value_policy, handle) {
auto *ptr = holder_helper<holder_type>::get(src);
auto *ptr = holder_helper<type, holder_type>::get(src);
return type_caster_base<type>::cast_holder(ptr, std::addressof(src));
}

template <typename T = type, detail::enable_if_t<!std::is_base_of<type_caster_base<T>, type_caster<T>>::value, int> = 0>
static handle cast(const holder_type &src, return_value_policy policy, handle parent) {
policy = return_value_policy_override<type>::policy(policy);
using value_conv = make_caster<type>;
const auto *ptr = holder_helper<type, holder_type>::get(src);
return value_conv::cast(*ptr, policy, parent);
}

static constexpr auto name = type_caster_base<type>::name;
};

Expand Down Expand Up @@ -1644,22 +1692,6 @@ template <typename type> using cast_is_temporary_value_reference = bool_constant
!std::is_same<intrinsic_t<type>, void>::value
>;

// When a value returned from a C++ function is being cast back to Python, we almost always want to
// force `policy = move`, regardless of the return value policy the function/method was declared
// with.
template <typename Return, typename SFINAE = void> struct return_value_policy_override {
static return_value_policy policy(return_value_policy p) { return p; }
};

template <typename Return> struct return_value_policy_override<Return,
detail::enable_if_t<std::is_base_of<type_caster_generic, make_caster<Return>>::value, void>> {
static return_value_policy policy(return_value_policy p) {
return !std::is_lvalue_reference<Return>::value &&
!std::is_pointer<Return>::value
? return_value_policy::move : p;
}
};

// Basic python -> C++ casting; throws if casting fails
template <typename T, typename SFINAE> type_caster<T, SFINAE> &load_type(type_caster<T, SFINAE> &conv, const handle &handle) {
if (!conv.load(handle, true)) {
Expand Down
2 changes: 1 addition & 1 deletion include/pybind11/detail/init.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
// derived type (through those holder's implicit conversion from derived class holder constructors).
template <typename Class>
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
auto *ptr = holder_helper<Holder<Class>>::get(holder);
auto *ptr = holder_helper<Class, Holder<Class>>::get(holder);
// If we need an alias, check that the held pointer is actually an alias instance
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
Expand Down
7 changes: 6 additions & 1 deletion tests/test_smart_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ PYBIND11_DECLARE_HOLDER_TYPE(T, ref<T>, true);
// Make pybind11 aware of the non-standard getter member function
namespace pybind11 { namespace detail {
template <typename T>
struct holder_helper<ref<T>> {
struct holder_helper<T, ref<T>> {
static const T *get(const ref<T> &p) { return p.get_ptr(); }
static T *get(ref<T> &p) { return p.get_ptr(); }
static ref<T> create(T && val) { return ref<T>(&val); }
};
}}

Expand Down Expand Up @@ -88,6 +90,9 @@ PYBIND11_DECLARE_HOLDER_TYPE(T, unique_ptr_with_addressof_operator<T>);
TEST_SUBMODULE(smart_ptr, m) {

// test_smart_ptr
m.def("cast_shared_int", []() { return std::make_shared<int>(1); });
m.def("cast_unique_int", []() { return std::unique_ptr<int>(new int(1)); });
m.def("load_shared_int", [](std::shared_ptr<int> x) { return *x == 1; });

// Object implementation in `object.h`
py::class_<Object, ref<Object>> obj(m, "Object");
Expand Down
11 changes: 11 additions & 0 deletions tests/test_smart_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@ def test_smart_ptr(capture):
assert cstats.move_assignments == 0


def test_shared_ptr():
v_int = m.cast_shared_int()
assert v_int == 1
assert m.load_shared_int(v_int)


def test_unique_ptr():
v_int = m.cast_unique_int()
assert v_int == 1


def test_smart_ptr_refcounting():
assert m.test_object1_refcounting()

Expand Down
11 changes: 11 additions & 0 deletions tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ TEST_SUBMODULE(stl, m) {
static std::vector<RValueCaster> lvv{2};
m.def("cast_ptr_vector", []() { return &lvv; });

// test_vector_shared
m.def("cast_vector_shared", []() { return std::make_shared<std::vector<int>>(std::vector<int>{1}); });
m.def("load_vector_shared", [](std::shared_ptr<std::vector<int>> v) { return v->at(0) == 1 && v->at(1) == 2; });

// test_vector_unique
m.def("cast_vector_unique", []() { return std::unique_ptr<std::vector<int>>(new std::vector<int>{1}); });

// test_deque
m.def("cast_deque", []() { return std::deque<int>{1}; });
m.def("load_deque", [](const std::deque<int> &v) { return v.at(0) == 1 && v.at(1) == 2; });
Expand Down Expand Up @@ -208,6 +215,7 @@ TEST_SUBMODULE(stl, m) {
result_type operator()(std::string) { return "std::string"; }
result_type operator()(double) { return "double"; }
result_type operator()(std::nullptr_t) { return "std::nullptr_t"; }
result_type operator()(std::shared_ptr<std::vector<int>>) { return "std::shared_ptr<std::vector<int>>"; }
};

// test_variant
Expand All @@ -221,6 +229,9 @@ TEST_SUBMODULE(stl, m) {
using V = variant<int, std::string>;
return py::make_tuple(V(5), V("Hello"));
});
m.def("load_variant_with_shared", [](variant<int, std::shared_ptr<std::vector<int>>> v) {
return py::detail::visit_helper<variant>::call(visitor(), v);
});
#endif

// #528: templated constructor
Expand Down
18 changes: 18 additions & 0 deletions tests/test_stl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ def test_vector(doc):
assert m.cast_ptr_vector() == ["lvalue", "lvalue"]


def test_vector_shared(doc):
"""std::shared_ptr<std::vector> <-> list"""
lst = m.cast_vector_shared()
assert lst == [1]
lst.append(2)
assert m.load_vector_shared(lst)
assert m.load_vector_shared(tuple(lst))


def test_vector_unique(doc):
"""std::unique_ptr<std::vector> -> list"""
lst = m.cast_vector_unique()
assert lst == [1]


def test_deque(doc):
"""std::deque <-> list"""
lst = m.cast_deque()
Expand Down Expand Up @@ -159,6 +174,9 @@ def test_variant(doc):
assert m.load_variant_2pass(1) == "int"
assert m.load_variant_2pass(1.0) == "double"

assert m.load_variant_with_shared(1) == "int"
assert m.load_variant_with_shared([1, 2]) == "std::shared_ptr<std::vector<int>>"

assert m.cast_variant() == (5, "Hello")

assert doc(m.load_variant) == "load_variant(arg0: Union[int, str, float, None]) -> str"
Expand Down
3 changes: 3 additions & 0 deletions tests/test_stl_binders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ TEST_SUBMODULE(stl_binders, m) {
// test_vector_int
py::bind_vector<std::vector<unsigned int>>(m, "VectorInt", py::buffer_protocol());

// test_vector_double_shared
py::bind_vector<std::vector<double>, std::shared_ptr<std::vector<double>>>(m, "VectorDoubleShared");

// test_vector_custom
py::class_<El>(m, "El")
.def(py::init<int>());
Expand Down
59 changes: 59 additions & 0 deletions tests/test_stl_binders.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,65 @@ def test_vector_int():
del v_int2[-1]
assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88])


def test_vector_double_shared():
v_dbl = m.VectorDoubleShared([0.0, 0.0])
assert len(v_dbl) == 2
assert bool(v_dbl) is True

# test construction from a generator
v_dbl1 = m.VectorDoubleShared(x for x in range(5))
assert v_dbl1 == m.VectorDoubleShared([0.0, 1.0, 2.0, 3.0, 4.0])

v_dbl2 = m.VectorDoubleShared([0.0, 0.0])
assert v_dbl == v_dbl2
v_dbl2[1] = 1
assert v_dbl != v_dbl2

v_dbl2.append(2)
v_dbl2.insert(0, 1.0)
v_dbl2.insert(0, 2.0)
v_dbl2.insert(0, 3.0)
v_dbl2.insert(6, 3.0)
assert str(v_dbl2) == "VectorDoubleShared[3, 2, 1, 0, 1, 2, 3]"
with pytest.raises(IndexError):
v_dbl2.insert(8, 4)

v_dbl.append(99.2)
v_dbl2[2:-2] = v_dbl
assert v_dbl2 == m.VectorDoubleShared([3, 2, 0, 0, 99.2, 2, 3])
del v_dbl2[1:3]
assert v_dbl2 == m.VectorDoubleShared([3, 0, 99.2, 2, 3])
del v_dbl2[0]
assert v_dbl2 == m.VectorDoubleShared([0, 99.2, 2, 3])

v_dbl2.extend(m.VectorDoubleShared([4, 5]))
assert v_dbl2 == m.VectorDoubleShared([0, 99.2, 2, 3, 4, 5])

v_dbl2.extend([6, 7])
assert v_dbl2 == m.VectorDoubleShared([0, 99.2, 2, 3, 4, 5, 6, 7])

# test error handling, and that the vector is unchanged
with pytest.raises(RuntimeError):
v_dbl2.extend([8, 'a'])

assert v_dbl2 == m.VectorDoubleShared([0, 99.2, 2, 3, 4, 5, 6, 7])

# test extending from a generator
v_dbl2.extend(x for x in range(5))
assert v_dbl2 == m.VectorDoubleShared([0, 99.2, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4])

# test negative indexing
assert v_dbl2[-1] == 4

# insert with negative index
v_dbl2.insert(-1, 88)
assert v_dbl2 == m.VectorDoubleShared([0, 99.2, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88, 4])

# delete negative index
del v_dbl2[-1]
assert v_dbl2 == m.VectorDoubleShared([0, 99.2, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88])

# related to the PyPy's buffer protocol.
@pytest.unsupported_on_pypy
def test_vector_buffer():
Expand Down