Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rath3t committed Aug 21, 2024
1 parent d21cee3 commit e0be5db
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
24 changes: 12 additions & 12 deletions include/pybind11/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "pybind11.h"

#include <functional>
#include <iostream>

PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail)
Expand Down Expand Up @@ -129,24 +130,23 @@ struct type_caster<std::function<Return(Args...)>> {
// See PR #1413 for full details
} else {
// Check number of arguments of Python function
auto argCountFromFuncCode = [&](handle &obj) {
// This is faster then doing import inspect and
// inspect.signature(obj).parameters

object argCount = obj.attr("co_argcount");
return argCount.template cast<size_t>();
auto get_argument_count = [](const handle &obj) -> size_t {
// Faster then `import inspect` and `inspect.signature(obj).parameters`
return obj.attr("co_argcount").cast<size_t>();
};
size_t argCount = 0;

handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__");
handle empty;
object codeAttr = getattr(src, "__code__", empty);

if (codeAttr) {
argCount = argCountFromFuncCode(codeAttr);
argCount = get_argument_count(codeAttr);
} else {
handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__");
object callAttr = getattr(src, "__call__", empty);

if (callAttr) {
handle codeAttr2 = PyObject_GetAttrString(callAttr.ptr(), "__code__");
argCount = argCountFromFuncCode(codeAttr2)
- 1; // we have to remove the self argument
object codeAttr2 = getattr(callAttr, "__code__");
argCount = get_argument_count(codeAttr2) - 1; // removing the self argument
} else {
// No __code__ or __call__ attribute, this is not a proper Python function
return false;
Expand Down
3 changes: 2 additions & 1 deletion tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def __call__(self, a):
return a

assert m.dummy_function_overloaded_std_func_arg(f) == 9
assert m.dummy_function_overloaded_std_func_arg(A()) == 9
a = A()
assert m.dummy_function_overloaded_std_func_arg(a) == 9
assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9

def f2(a, b):
Expand Down

0 comments on commit e0be5db

Please sign in to comment.