From e6842e52971922aace48795d6e4133dc9d6f6cd3 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 31 Oct 2023 17:35:22 +0000 Subject: [PATCH] Add support for unary functions (#235) --- docs/guide/classes.md | 11 +++++++++- example/operators.pyi | 45 ++++++++++++++++++++++++++++++++++++++ example/operators.zig | 47 ++++++++++++++++++++++++++++++++++++++++ pydust/src/functions.zig | 30 ++++++++++++++++--------- pydust/src/pytypes.zig | 43 ++++++++++++++++++++++++++++++++++++ test/test_operators.py | 23 ++++++++++++++++++++ 6 files changed, 188 insertions(+), 11 deletions(-) diff --git a/docs/guide/classes.md b/docs/guide/classes.md index 982de63c..ee422db2 100644 --- a/docs/guide/classes.md +++ b/docs/guide/classes.md @@ -167,6 +167,8 @@ Also note the shorthand signatures: ```zig const binaryfunc = fn(*Self, object) !object; +const unaryfunc = fn(*Self) !object; +const inquiry = fn(*Self) !bool; ``` ### Type Methods @@ -210,7 +212,6 @@ The remaining mapping methods are yet to be implemented. | `__gt__` | `#!zig fn(*Self, object) !bool` | | `__ge__` | `#!zig fn(*Self, object) !bool` | - !!! note By default, `__ne__` will delegate to the negation of `__eq__` if it is defined. @@ -258,6 +259,14 @@ to implement the full comparison logic in a single `__richcompare__` function. | `__ifloordiv__` | `binaryfunc` | | `__matmul__` | `binaryfunc` | | `__imatmul__` | `binaryfunc` | +| `__neg__` | `unaryfunc` | +| `__pos__` | `unaryfunc` | +| `__abs__` | `unaryfunc` | +| `__invert__` | `unaryfunc` | +| `__int__` | `unaryfunc` | +| `__float__` | `unaryfunc` | +| `__index__` | `unaryfunc` | +| `__bool__` | `inquiry` | !!! note diff --git a/example/operators.pyi b/example/operators.pyi index cf1c5fe6..9b2b1fb0 100644 --- a/example/operators.pyi +++ b/example/operators.pyi @@ -326,3 +326,48 @@ class Ops: """ ... def num(self, /): ... + +class UnaryOps: + def __init__(self, num, /): + pass + def __neg__(self, /): + """ + -self + """ + ... + def __pos__(self, /): + """ + +self + """ + ... + def __abs__(self, /): + """ + abs(self) + """ + ... + def __bool__(self, /): + """ + True if self else False + """ + ... + def __invert__(self, /): + """ + ~self + """ + ... + def __int__(self, /): + """ + int(self) + """ + ... + def __float__(self, /): + """ + float(self) + """ + ... + def __index__(self, /): + """ + Return self converted to an integer, if self is suitable for use as an index into a list. + """ + ... + def num(self, /): ... diff --git a/example/operators.zig b/example/operators.zig index 4d3d4648..09d6f12c 100644 --- a/example/operators.zig +++ b/example/operators.zig @@ -165,6 +165,53 @@ pub const Ops = py.class(struct { }); // --8<-- [end:all] +pub const UnaryOps = py.class(struct { + const Self = @This(); + + num: i64, + + pub fn __init__(self: *Self, args: struct { num: i64 }) !void { + self.num = args.num; + } + + pub fn num(self: *const Self) i64 { + return self.num; + } + + pub fn __neg__(self: *Self) !py.PyLong { + return py.PyLong.create(-self.num); + } + + pub fn __pos__(self: *Self) !*Self { + py.incref(self); + return self; + } + + pub fn __abs__(self: *Self) !*Self { + return py.init(Self, .{ .num = @as(i64, @intCast(std.math.absCast(self.num))) }); + } + + pub fn __invert__(self: *Self) !*Self { + return py.init(Self, .{ .num = ~self.num }); + } + + pub fn __int__(self: *Self) !py.PyLong { + return py.PyLong.create(self.num); + } + + pub fn __float__(self: *Self) !py.PyFloat { + return py.PyFloat.create(@as(f64, @floatFromInt(self.num))); + } + + pub fn __index__(self: *Self) !py.PyLong { + return py.PyLong.create(self.num); + } + + pub fn __bool__(self: *Self) !bool { + return self.num == 1; + } +}); + // --8<-- [start:ops] pub const Operator = py.class(struct { const Self = @This(); diff --git a/pydust/src/functions.zig b/pydust/src/functions.zig index 6bb3f04d..55ef9166 100644 --- a/pydust/src/functions.zig +++ b/pydust/src/functions.zig @@ -42,6 +42,16 @@ pub const Signature = struct { } }; +pub const UnaryOperators = std.ComptimeStringMap(c_int, .{ + .{ "__neg__", ffi.Py_nb_negative }, + .{ "__pos__", ffi.Py_nb_positive }, + .{ "__abs__", ffi.Py_nb_absolute }, + .{ "__invert__", ffi.Py_nb_invert }, + .{ "__int__", ffi.Py_nb_int }, + .{ "__float__", ffi.Py_nb_float }, + .{ "__index__", ffi.Py_nb_index }, +}); + pub const BinaryOperators = std.ComptimeStringMap(c_int, .{ .{ "__add__", ffi.Py_nb_add }, .{ "__iadd__", ffi.Py_nb_inplace_add }, @@ -72,7 +82,6 @@ pub const BinaryOperators = std.ComptimeStringMap(c_int, .{ .{ "__imatmul__", ffi.Py_nb_inplace_matrix_multiply }, .{ "__getitem__", ffi.Py_mp_subscript }, }); -pub const NBinaryOperators = BinaryOperators.kvs.len; // TODO(marko): Move this somewhere. fn keys(comptime stringMap: type) [stringMap.kvs.len][]const u8 { @@ -93,19 +102,20 @@ pub const compareFuncs = .{ }; const reservedNames = .{ - "__new__", - "__init__", - "__len__", - "__del__", + "__bool__", "__buffer__", - "__str__", - "__repr__", - "__release_buffer__", + "__del__", + "__hash__", + "__init__", "__iter__", + "__len__", + "__new__", "__next__", - "__hash__", + "__release_buffer__", + "__repr__", "__richcompare__", -} ++ compareFuncs ++ keys(BinaryOperators); + "__str__", +} ++ compareFuncs ++ keys(BinaryOperators) ++ keys(UnaryOperators); /// Parse the arguments of a Zig function into a Pydust function siganture. pub fn parseSignature(comptime name: []const u8, comptime func: Type.Fn, comptime SelfTypes: []const type) Signature { diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 1a95b0dd..5ab80428 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -215,6 +215,13 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type { }}; } + if (@hasDecl(definition, "__bool__")) { + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = ffi.Py_nb_bool, + .pfunc = @ptrCast(@constCast(&nb_bool)), + }}; + } + if (richcmp.hasCompare) { slots_ = slots_ ++ .{ffi.PyType_Slot{ .slot = ffi.Py_tp_richcompare, @@ -232,6 +239,16 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type { } } + for (funcs.UnaryOperators.kvs) |kv| { + if (@hasDecl(definition, kv.key)) { + const op = UnaryOperator(definition, kv.key); + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = kv.value, + .pfunc = @ptrCast(@constCast(&op.call)), + }}; + } + } + slots_ = slots_ ++ .{ffi.PyType_Slot{ .slot = ffi.Py_tp_methods, .pfunc = @ptrCast(@constCast(&methods.pydefs)), @@ -369,6 +386,12 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type { const result = tramp.coerceError(definition.__call__(self, call_args.argsStruct)) catch return null; return (py.createOwned(result) catch return null).py; } + + fn nb_bool(pyself: *ffi.PyObject) callconv(.C) c_int { + const self: *PyTypeStruct(definition) = @ptrCast(pyself); + const result = tramp.coerceError(definition.__bool__(&self.state)) catch return -1; + return @intCast(@intFromBool(result)); + } }; } @@ -593,6 +616,26 @@ fn BinaryOperator( }; } +fn UnaryOperator( + comptime definition: type, + comptime op: []const u8, +) type { + return struct { + fn call(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { + const func = @field(definition, op); + const typeInfo = @typeInfo(@TypeOf(func)).Fn; + + if (typeInfo.params.len != 1) @compileError(op ++ " must take exactly one parameter"); + + // TODO(ngates): do we want to trampoline the self argument? + const self: *PyTypeStruct(definition) = @ptrCast(pyself); + + const result = tramp.coerceError(func(&self.state)) catch return null; + return (py.createOwned(result) catch return null).py; + } + }; +} + fn EqualsOperator( comptime definition: type, comptime op: []const u8, diff --git a/test/test_operators.py b/test/test_operators.py index 941bf8c2..9a3aacaa 100644 --- a/test/test_operators.py +++ b/test/test_operators.py @@ -71,6 +71,29 @@ def test_iops(iop, expected): assert ops.num() == expected +@pytest.mark.parametrize( + "op,expected", + [ + (operator.pos, -3), + (operator.neg, 3), + (operator.invert, 2), + (operator.index, -3), + (operator.abs, 3), + (bool, False), + (int, -3), + (float, -3.0), + ], +) +def test_unaryops(op, expected): + ops = operators.UnaryOps(-3) + res = op(ops) + + if isinstance(res, operators.UnaryOps): + assert res.num() == expected + else: + assert res == expected + + def test_divmod(): ops = operators.Ops(3) other = operators.Ops(2)