Skip to content

Commit

Permalink
Add support for unary functions (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Oct 31, 2023
1 parent f88bfba commit e6842e5
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 11 deletions.
11 changes: 10 additions & 1 deletion docs/guide/classes.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ Also note the shorthand signatures:

```zig
const binaryfunc = fn(*Self, object) !object;
const unaryfunc = fn(*Self) !object;
const inquiry = fn(*Self) !bool;
```

### Type Methods
Expand Down Expand Up @@ -210,7 +212,6 @@ The remaining mapping methods are yet to be implemented.
| `__gt__` | `#!zig fn(*Self, object) !bool` |
| `__ge__` | `#!zig fn(*Self, object) !bool` |


!!! note

By default, `__ne__` will delegate to the negation of `__eq__` if it is defined.
Expand Down Expand Up @@ -258,6 +259,14 @@ to implement the full comparison logic in a single `__richcompare__` function.
| `__ifloordiv__` | `binaryfunc` |
| `__matmul__` | `binaryfunc` |
| `__imatmul__` | `binaryfunc` |
| `__neg__` | `unaryfunc` |
| `__pos__` | `unaryfunc` |
| `__abs__` | `unaryfunc` |
| `__invert__` | `unaryfunc` |
| `__int__` | `unaryfunc` |
| `__float__` | `unaryfunc` |
| `__index__` | `unaryfunc` |
| `__bool__` | `inquiry` |

!!! note

Expand Down
45 changes: 45 additions & 0 deletions example/operators.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,48 @@ class Ops:
"""
...
def num(self, /): ...

class UnaryOps:
def __init__(self, num, /):
pass
def __neg__(self, /):
"""
-self
"""
...
def __pos__(self, /):
"""
+self
"""
...
def __abs__(self, /):
"""
abs(self)
"""
...
def __bool__(self, /):
"""
True if self else False
"""
...
def __invert__(self, /):
"""
~self
"""
...
def __int__(self, /):
"""
int(self)
"""
...
def __float__(self, /):
"""
float(self)
"""
...
def __index__(self, /):
"""
Return self converted to an integer, if self is suitable for use as an index into a list.
"""
...
def num(self, /): ...
47 changes: 47 additions & 0 deletions example/operators.zig
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,53 @@ pub const Ops = py.class(struct {
});
// --8<-- [end:all]

pub const UnaryOps = py.class(struct {
const Self = @This();

num: i64,

pub fn __init__(self: *Self, args: struct { num: i64 }) !void {
self.num = args.num;
}

pub fn num(self: *const Self) i64 {
return self.num;
}

pub fn __neg__(self: *Self) !py.PyLong {
return py.PyLong.create(-self.num);
}

pub fn __pos__(self: *Self) !*Self {
py.incref(self);
return self;
}

pub fn __abs__(self: *Self) !*Self {
return py.init(Self, .{ .num = @as(i64, @intCast(std.math.absCast(self.num))) });
}

pub fn __invert__(self: *Self) !*Self {
return py.init(Self, .{ .num = ~self.num });
}

pub fn __int__(self: *Self) !py.PyLong {
return py.PyLong.create(self.num);
}

pub fn __float__(self: *Self) !py.PyFloat {
return py.PyFloat.create(@as(f64, @floatFromInt(self.num)));
}

pub fn __index__(self: *Self) !py.PyLong {
return py.PyLong.create(self.num);
}

pub fn __bool__(self: *Self) !bool {
return self.num == 1;
}
});

// --8<-- [start:ops]
pub const Operator = py.class(struct {
const Self = @This();
Expand Down
30 changes: 20 additions & 10 deletions pydust/src/functions.zig
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ pub const Signature = struct {
}
};

pub const UnaryOperators = std.ComptimeStringMap(c_int, .{
.{ "__neg__", ffi.Py_nb_negative },
.{ "__pos__", ffi.Py_nb_positive },
.{ "__abs__", ffi.Py_nb_absolute },
.{ "__invert__", ffi.Py_nb_invert },
.{ "__int__", ffi.Py_nb_int },
.{ "__float__", ffi.Py_nb_float },
.{ "__index__", ffi.Py_nb_index },
});

pub const BinaryOperators = std.ComptimeStringMap(c_int, .{
.{ "__add__", ffi.Py_nb_add },
.{ "__iadd__", ffi.Py_nb_inplace_add },
Expand Down Expand Up @@ -72,7 +82,6 @@ pub const BinaryOperators = std.ComptimeStringMap(c_int, .{
.{ "__imatmul__", ffi.Py_nb_inplace_matrix_multiply },
.{ "__getitem__", ffi.Py_mp_subscript },
});
pub const NBinaryOperators = BinaryOperators.kvs.len;

// TODO(marko): Move this somewhere.
fn keys(comptime stringMap: type) [stringMap.kvs.len][]const u8 {
Expand All @@ -93,19 +102,20 @@ pub const compareFuncs = .{
};

const reservedNames = .{
"__new__",
"__init__",
"__len__",
"__del__",
"__bool__",
"__buffer__",
"__str__",
"__repr__",
"__release_buffer__",
"__del__",
"__hash__",
"__init__",
"__iter__",
"__len__",
"__new__",
"__next__",
"__hash__",
"__release_buffer__",
"__repr__",
"__richcompare__",
} ++ compareFuncs ++ keys(BinaryOperators);
"__str__",
} ++ compareFuncs ++ keys(BinaryOperators) ++ keys(UnaryOperators);

/// Parse the arguments of a Zig function into a Pydust function siganture.
pub fn parseSignature(comptime name: []const u8, comptime func: Type.Fn, comptime SelfTypes: []const type) Signature {
Expand Down
43 changes: 43 additions & 0 deletions pydust/src/pytypes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
}};
}

if (@hasDecl(definition, "__bool__")) {
slots_ = slots_ ++ .{ffi.PyType_Slot{
.slot = ffi.Py_nb_bool,
.pfunc = @ptrCast(@constCast(&nb_bool)),
}};
}

if (richcmp.hasCompare) {
slots_ = slots_ ++ .{ffi.PyType_Slot{
.slot = ffi.Py_tp_richcompare,
Expand All @@ -232,6 +239,16 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
}
}

for (funcs.UnaryOperators.kvs) |kv| {
if (@hasDecl(definition, kv.key)) {
const op = UnaryOperator(definition, kv.key);
slots_ = slots_ ++ .{ffi.PyType_Slot{
.slot = kv.value,
.pfunc = @ptrCast(@constCast(&op.call)),
}};
}
}

slots_ = slots_ ++ .{ffi.PyType_Slot{
.slot = ffi.Py_tp_methods,
.pfunc = @ptrCast(@constCast(&methods.pydefs)),
Expand Down Expand Up @@ -369,6 +386,12 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
const result = tramp.coerceError(definition.__call__(self, call_args.argsStruct)) catch return null;
return (py.createOwned(result) catch return null).py;
}

fn nb_bool(pyself: *ffi.PyObject) callconv(.C) c_int {
const self: *PyTypeStruct(definition) = @ptrCast(pyself);
const result = tramp.coerceError(definition.__bool__(&self.state)) catch return -1;
return @intCast(@intFromBool(result));
}
};
}

Expand Down Expand Up @@ -593,6 +616,26 @@ fn BinaryOperator(
};
}

fn UnaryOperator(
comptime definition: type,
comptime op: []const u8,
) type {
return struct {
fn call(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject {
const func = @field(definition, op);
const typeInfo = @typeInfo(@TypeOf(func)).Fn;

if (typeInfo.params.len != 1) @compileError(op ++ " must take exactly one parameter");

// TODO(ngates): do we want to trampoline the self argument?
const self: *PyTypeStruct(definition) = @ptrCast(pyself);

const result = tramp.coerceError(func(&self.state)) catch return null;
return (py.createOwned(result) catch return null).py;
}
};
}

fn EqualsOperator(
comptime definition: type,
comptime op: []const u8,
Expand Down
23 changes: 23 additions & 0 deletions test/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ def test_iops(iop, expected):
assert ops.num() == expected


@pytest.mark.parametrize(
"op,expected",
[
(operator.pos, -3),
(operator.neg, 3),
(operator.invert, 2),
(operator.index, -3),
(operator.abs, 3),
(bool, False),
(int, -3),
(float, -3.0),
],
)
def test_unaryops(op, expected):
ops = operators.UnaryOps(-3)
res = op(ops)

if isinstance(res, operators.UnaryOps):
assert res.num() == expected
else:
assert res == expected


def test_divmod():
ops = operators.Ops(3)
other = operators.Ops(2)
Expand Down

0 comments on commit e6842e5

Please sign in to comment.