diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index d09a2cea2c..29638eb821 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -232,6 +232,46 @@ prevent many types of unsupported structures, it is still the user's responsibility to use only "plain" structures that can be safely manipulated as raw memory without violating invariants. +Scalar types +============ + +In some cases we may want to accept or return NumPy scalar values such as +``np.float32`` or ``np.float64``. We hope to be able to handle single-precision +and double-precision on the C-side. However, both are bound to Python's +double-precision builtin float by default, so they cannot be processed separately. +We used the ``py::buffer`` trick to implement the previous approach, which +will cause the readability of the code to drop significantly. + +Luckily, there's a helper type for this occasion - ``py::numpy_scalar``: + +.. code-block:: cpp + + m.def("add", [](py::numpy_scalar a, py::numpy_scalar b) { + return py::make_scalar(a + b); + }); + m.def("add", [](py::numpy_scalar a, py::numpy_scalar b) { + return py::make_scalar(a + b); + }); + +This type is trivially convertible to and from the type it wraps; currently +supported scalar types are NumPy arithmetic types: ``bool_``, ``int8``, +``int16``, ``int32``, ``int64``, ``uint8``, ``uint16``, ``uint32``, +``uint64``, ``float32``, ``float64``, ``complex64``, ``complex128``, all of +them mapping to respective C++ counterparts. + +.. note:: + + This is a strict type, it will only allows to specify NumPy type as input + arguments, and does not allow other types of input parameters (e.g., + ``py::numpy_scalar`` will not accept Python's builtin ``int`` ). + +.. note:: + + Native C types are mapped to NumPy types in a platform specific way: for + instance, ``char`` may be mapped to either ``np.int8`` or ``np.uint8`` + and ``long`` may use 4 or 8 bytes depending on the platform. Unless you + clearly understand the difference and your needs, please use ````. + Vectorizing functions ===================== diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 09894cf74f..6359943558 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -49,6 +49,9 @@ PYBIND11_WARNING_DISABLE_MSVC(4127) class dtype; // Forward declaration class array; // Forward declaration +template +struct numpy_scalar; // Forward declaration + PYBIND11_NAMESPACE_BEGIN(detail) template <> @@ -200,15 +203,15 @@ struct same_size { using as = bool_constant; }; -template +template constexpr int platform_lookup() { return -1; } // Lookup a type according to its size, and return a value corresponding to the NumPy typenum. -template +template constexpr int platform_lookup(int I, Ints... Is) { - return sizeof(Concrete) == sizeof(T) ? I : platform_lookup(Is...); + return sizeof(size) == sizeof(T) ? I : platform_lookup(Is...); } struct npy_api { @@ -249,15 +252,23 @@ struct npy_api { // `npy_common.h` defines the integer aliases. In order, it checks: // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR // and assigns the alias to the first matching size, so we should check in this order. - NPY_INT32_ - = platform_lookup(NPY_LONG_, NPY_INT_, NPY_SHORT_), - NPY_UINT32_ = platform_lookup( + NPY_INT32_ = platform_lookup<4, long, int, short>(NPY_LONG_, NPY_INT_, NPY_SHORT_), + NPY_UINT32_ = platform_lookup<4, unsigned long, unsigned int, unsigned short>( NPY_ULONG_, NPY_UINT_, NPY_USHORT_), - NPY_INT64_ - = platform_lookup(NPY_LONG_, NPY_LONGLONG_, NPY_INT_), - NPY_UINT64_ - = platform_lookup( + NPY_INT64_ = platform_lookup<8, long, long long, int>(NPY_LONG_, NPY_LONGLONG_, NPY_INT_), + NPY_UINT64_ = platform_lookup<8, unsigned long, unsigned long long, unsigned int>( NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_), + NPY_FLOAT32_ + = platform_lookup<4, double, float, long double>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_), + NPY_FLOAT64_ + = platform_lookup<8, double, float, long double>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_), + NPY_COMPLEX64_ + = platform_lookup<8, std::complex, std::complex, std::complex>( + NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_), + NPY_COMPLEX128_ + = platform_lookup<8, std::complex, std::complex, std::complex>( + NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_), + NPY_CHAR_ = std::is_signed::value ? NPY_BYTE_ : NPY_UBYTE_, }; unsigned int PyArray_RUNTIME_VERSION_; @@ -281,6 +292,7 @@ struct npy_api { unsigned int (*PyArray_GetNDArrayCFeatureVersion_)(); PyObject *(*PyArray_DescrFromType_)(int); + PyObject *(*PyArray_TypeObjectFromType_)(int); PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *, PyObject *, int, @@ -297,6 +309,8 @@ struct npy_api { PyTypeObject *PyVoidArrType_Type_; PyTypeObject *PyArrayDescr_Type_; PyObject *(*PyArray_DescrFromScalar_)(PyObject *); + PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *); + void (*PyArray_ScalarAsCtype_)(PyObject *, void *); PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *); int (*PyArray_DescrConverter_)(PyObject *, PyObject **); bool (*PyArray_EquivTypes_)(PyObject *, PyObject *); @@ -324,7 +338,10 @@ struct npy_api { API_PyArrayDescr_Type = 3, API_PyVoidArrType_Type = 39, API_PyArray_DescrFromType = 45, + API_PyArray_TypeObjectFromType = 46, API_PyArray_DescrFromScalar = 57, + API_PyArray_Scalar = 60, + API_PyArray_ScalarAsCtype = 62, API_PyArray_FromAny = 69, API_PyArray_Resize = 80, // CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82. @@ -362,7 +379,10 @@ struct npy_api { DECL_NPY_API(PyVoidArrType_Type); DECL_NPY_API(PyArrayDescr_Type); DECL_NPY_API(PyArray_DescrFromType); + DECL_NPY_API(PyArray_TypeObjectFromType); DECL_NPY_API(PyArray_DescrFromScalar); + DECL_NPY_API(PyArray_Scalar); + DECL_NPY_API(PyArray_ScalarAsCtype); DECL_NPY_API(PyArray_FromAny); DECL_NPY_API(PyArray_Resize); DECL_NPY_API(PyArray_CopyInto); @@ -384,6 +404,88 @@ struct npy_api { } }; +template +struct is_complex : std::false_type {}; +template +struct is_complex> : std::true_type {}; + +template +struct npy_format_descriptor_name; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name = const_name::value>( + const_name("bool"), + const_name::value>("int", "uint") + const_name()); +}; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name + = const_name < std::is_same::value + || std::is_same::value + > (const_name("float") + const_name(), const_name("longdouble")); +}; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name + = const_name < std::is_same::value + || std::is_same::value + > (const_name("complex") + const_name(), + const_name("longcomplex")); +}; + +template +struct numpy_scalar_info {}; + +#define DECL_NPY_SCALAR(ctype_, typenum_) \ + template <> \ + struct numpy_scalar_info { \ + static constexpr auto name = npy_format_descriptor_name::name; \ + static constexpr int typenum = npy_api::typenum_##_; \ + } + +// boolean type +DECL_NPY_SCALAR(bool, NPY_BOOL); + +// character types +DECL_NPY_SCALAR(char, NPY_CHAR); +DECL_NPY_SCALAR(signed char, NPY_BYTE); +DECL_NPY_SCALAR(unsigned char, NPY_UBYTE); + +// signed integer types +DECL_NPY_SCALAR(std::int16_t, NPY_SHORT); +DECL_NPY_SCALAR(std::int32_t, NPY_INT); +DECL_NPY_SCALAR(std::int64_t, NPY_LONG); +#if defined(__linux__) +DECL_NPY_SCALAR(long long, NPY_LONG); +#else +DECL_NPY_SCALAR(long, NPY_LONG); +#endif + +// unsigned integer types +DECL_NPY_SCALAR(std::uint16_t, NPY_USHORT); +DECL_NPY_SCALAR(std::uint32_t, NPY_UINT); +DECL_NPY_SCALAR(std::uint64_t, NPY_ULONG); +#if defined(__linux__) +DECL_NPY_SCALAR(unsigned long long, NPY_ULONG); +#else +DECL_NPY_SCALAR(unsigned long, NPY_ULONG); +#endif + +// floating point types +DECL_NPY_SCALAR(float, NPY_FLOAT); +DECL_NPY_SCALAR(double, NPY_DOUBLE); +DECL_NPY_SCALAR(long double, NPY_LONGDOUBLE); + +// complex types +DECL_NPY_SCALAR(std::complex, NPY_CFLOAT); +DECL_NPY_SCALAR(std::complex, NPY_CDOUBLE); +DECL_NPY_SCALAR(std::complex, NPY_CLONGDOUBLE); + +#undef DECL_NPY_SCALAR + inline PyArray_Proxy *array_proxy(void *ptr) { return reinterpret_cast(ptr); } inline const PyArray_Proxy *array_proxy(const void *ptr) { @@ -414,10 +516,6 @@ template struct is_std_array : std::false_type {}; template struct is_std_array> : std::true_type {}; -template -struct is_complex : std::false_type {}; -template -struct is_complex> : std::true_type {}; template struct array_info_scalar { @@ -631,8 +729,59 @@ template struct type_caster> : type_caster> {}; +template +struct type_caster> { + using value_type = T; + using type_info = numpy_scalar_info; + + PYBIND11_TYPE_CASTER(numpy_scalar, type_info::name); + + static handle &target_type() { + static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum); + return tp; + } + + static handle &target_dtype() { + static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum); + return tp; + } + + bool load(handle src, bool) { + if (isinstance(src, target_type())) { + npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value); + return true; + } + return false; + } + + static handle cast(numpy_scalar src, return_value_policy, handle) { + return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr); + } +}; + PYBIND11_NAMESPACE_END(detail) +template +struct numpy_scalar { + using value_type = T; + + value_type value; + + numpy_scalar() = default; + numpy_scalar(value_type value) : value(value) {} + + operator value_type() { return value; } + numpy_scalar &operator=(value_type value) { + this->value = value; + return *this; + } +}; + +template +numpy_scalar make_scalar(T value) { + return numpy_scalar(value); +} + class dtype : public object { public: PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_) @@ -1366,38 +1515,6 @@ struct compare_buffer_info::valu } }; -template -struct npy_format_descriptor_name; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = const_name::value>( - const_name("bool"), - const_name::value>("numpy.int", "numpy.uint") - + const_name()); -}; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = const_name < std::is_same::value - || std::is_same::value - || std::is_same::value - || std::is_same::value - > (const_name("numpy.float") + const_name(), - const_name("numpy.longdouble")); -}; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = const_name < std::is_same::value - || std::is_same::value - || std::is_same::value - || std::is_same::value - > (const_name("numpy.complex") - + const_name(), - const_name("numpy.longcomplex")); -}; - template struct npy_format_descriptor< T, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 03ff39138d..27aeddedef 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -137,6 +137,7 @@ set(PYBIND11_TEST_FILES test_multiple_inheritance test_numpy_array test_numpy_dtypes + test_numpy_scalars test_numpy_vectorize test_opaque_types test_operator_overloading diff --git a/tests/test_numpy_scalars.cpp b/tests/test_numpy_scalars.cpp new file mode 100644 index 0000000000..046a9c07a9 --- /dev/null +++ b/tests/test_numpy_scalars.cpp @@ -0,0 +1,52 @@ +/* + tests/test_numpy_scalars.cpp -- strict NumPy scalars + + Copyright (c) 2021 Steve R. Sun + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include + +#include "pybind11_tests.h" + +#include +#include + +namespace py = pybind11; + +template +struct add { + T x; + add(T x) : x(x) {} + T operator()(T y) const { return static_cast(x + y); } +}; + +template +void register_test(py::module &m, const char *name, F &&func) { + m.def((std::string("test_") + name).c_str(), + [=](py::numpy_scalar v) { + return std::make_tuple(name, py::make_scalar(static_cast(func(v.value)))); + }, + py::arg("x")); +} + +TEST_SUBMODULE(numpy_scalars, m) { + using cfloat = std::complex; + using cdouble = std::complex; + + register_test(m, "bool", [](bool x) { return !x; }); + register_test(m, "int8", add(-8)); + register_test(m, "int16", add(-16)); + register_test(m, "int32", add(-32)); + register_test(m, "int64", add(-64)); + register_test(m, "uint8", add(8)); + register_test(m, "uint16", add(16)); + register_test(m, "uint32", add(32)); + register_test(m, "uint64", add(64)); + register_test(m, "float32", add(0.125f)); + register_test(m, "float64", add(0.25f)); + register_test(m, "complex64", add({0, -0.125f})); + register_test(m, "complex128", add({0, -0.25f})); +} diff --git a/tests/test_numpy_scalars.py b/tests/test_numpy_scalars.py new file mode 100644 index 0000000000..52c2861a1c --- /dev/null +++ b/tests/test_numpy_scalars.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import sys + +import pytest + +from pybind11_tests import numpy_scalars as m + +np = pytest.importorskip("numpy") + +SCALAR_TYPES = { + np.bool_: False, + np.int8: -7, + np.int16: -15, + np.int32: -31, + np.int64: -63, + np.uint8: 9, + np.uint16: 17, + np.uint32: 33, + np.uint64: 65, + np.single: 1.125, + np.double: 1.25, + np.complex64: 1 - 0.125j, + np.complex128: 1 - 0.25j, +} +ALL_TYPES = [int, bool, float, bytes, str] + list(SCALAR_TYPES) + + +def type_name(tp): + try: + return tp.__name__.rstrip("_") + except BaseException: + # no numpy + return str(tp) + + +@pytest.fixture(scope="module", params=list(SCALAR_TYPES), ids=type_name) +def scalar_type(request): + return request.param + + +def expected_signature(tp): + s = "str" if sys.version_info[0] >= 3 else "unicode" + t = type_name(tp) + return f"test_{t}(x: {t}) -> tuple[{s}, {t}]\n" + + +def test_numpy_scalars(scalar_type): + expected = SCALAR_TYPES[scalar_type] + name = type_name(scalar_type) + func = getattr(m, "test_" + name) + assert func.__doc__ == expected_signature(scalar_type) + for tp in ALL_TYPES: + value = tp(1) + if tp is scalar_type: + result = func(value) + assert result[0] == name + assert isinstance(result[1], tp) + assert result[1] == tp(expected) + else: + with pytest.raises(TypeError): + func(value)