Skip to content

Commit

Permalink
Add py.unchecked for casting PyObject to Pydust class (#230)
Browse files Browse the repository at this point in the history
TODO(ngates): we should store PyType objects on module state and then
auto-traverse them. See #229

Fixes #226, #227, #228
  • Loading branch information
gatesn authored Oct 30, 2023
1 parent 42632fb commit 1ae977b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 4 deletions.
26 changes: 26 additions & 0 deletions pydust/src/conversions.zig
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

const py = @import("./pydust.zig");
const tramp = @import("./trampoline.zig");
const pytypes = @import("./pytypes.zig");
const State = @import("./discovery.zig").State;

/// Zig PyObject-like -> ffi.PyObject. Convert a Zig PyObject-like value into a py.PyObject.
/// e.g. py.PyObject, py.PyTuple, ffi.PyObject, etc.
Expand All @@ -38,6 +40,30 @@ pub inline fn as(comptime T: type, obj: anytype) py.PyError!T {
return tramp.Trampoline(T).unwrap(object(obj));
}

/// Python -> Pydust. Perform a checked cast from a PyObject to a given PyDust class type.
pub inline fn checked(comptime T: type, obj: py.PyObject) py.PyError!T {
const definition = State.getDefinition(@typeInfo(T).Pointer.child);
if (definition.type != .class) {
@compileError("Can only perform checked cast into a PyDust class type");
}

// TODO(ngates): to perform fast type checking, we need to store our PyType on the parent module.
// See how the Python JSON module did this: https://github.com/python/cpython/commit/33f15a16d40cb8010a8c758952cbf88d7912ee2d#diff-efe183ae0b85e5b8d9bbbc588452dd4de80b39fd5c5174ee499ba554217a39edR1814
// For now, we perform a slow import/isinstance check by using the `as` conversion.
return as(T, obj);
}

/// Python -> Pydust. Perform an unchecked cast from a PyObject to a given PyDust class type.
pub inline fn unchecked(comptime T: type, obj: py.PyObject) T {
const Definition = @typeInfo(T).Pointer.child;
const definition = State.getDefinition(Definition);
if (definition.type != .class) {
@compileError("Can only perform unchecked cast into a PyDust class type. Found " ++ @typeName(Definition));
}
const instance: *pytypes.PyTypeStruct(Definition) = @ptrCast(@alignCast(obj.py));
return &instance.state;
}

const testing = @import("std").testing;
const expect = testing.expect;

Expand Down
21 changes: 19 additions & 2 deletions pydust/src/functions.zig
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ pub const Signature = struct {
pub fn supportsKwargs(comptime self: @This()) bool {
return self.nkwargs > 0 or self.varkwargsIdx != null;
}

pub fn isModuleMethod(comptime self: @This()) bool {
if (self.selfParam) |Self| {
return State.getDefinition(@typeInfo(Self).Pointer.child).type == .module;
}
return false;
}
};

pub const BinaryOperators = std.ComptimeStringMap(c_int, .{
Expand Down Expand Up @@ -257,7 +264,8 @@ pub fn wrap(comptime definition: type, comptime func: anytype, comptime sig: Sig
}

inline fn internal(pyself: py.PyObject, pyargs: []py.PyObject) PyError!py.PyObject {
const self = if (sig.selfParam) |Self| try py.as(Self, pyself) else null;
const self = if (sig.selfParam) |Self| try castSelf(Self, pyself) else null;

if (sig.argsParam) |Args| {
const args = try unwrapArgs(Args, pyargs, py.Kwargs.init(py.allocator));
const result = if (sig.selfParam) |_| func(self, args) else func(args);
Expand Down Expand Up @@ -302,10 +310,19 @@ pub fn wrap(comptime definition: type, comptime func: anytype, comptime sig: Sig
pykwargs: py.Kwargs,
) PyError!py.PyObject {
const args = try unwrapArgs(sig.argsParam.?, pyargs, pykwargs);
const self = if (sig.selfParam) |Self| try py.as(Self, pyself) else null;
const self = if (sig.selfParam) |Self| try castSelf(Self, pyself) else null;
const result = if (sig.selfParam) |_| func(self, args) else func(args);
return py.createOwned(tramp.coerceError(result));
}

inline fn castSelf(comptime Self: type, pyself: py.PyObject) !Self {
if (comptime sig.isModuleMethod()) {
const mod = py.PyModule{ .obj = pyself };
return try mod.getState(@typeInfo(Self).Pointer.child);
} else {
return py.unchecked(Self, pyself);
}
}
};
}

Expand Down
2 changes: 1 addition & 1 deletion pydust/src/pytypes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ fn RichCompare(comptime definition: type) type {
const CompareOpArg = typeInfo.params[2].type.?;
if (CompareOpArg != py.CompareOp) @compileError("Third parameter of __richcompare__ must be a py.CompareOp");

const self = py.as(Self, pyself) catch return null;
const self = py.unchecked(Self, .{ .py = pyself });
const otherArg = tramp.Trampoline(Other).unwrap(.{ .py = pyother }) catch return null;
const opEnum: py.CompareOp = @enumFromInt(op);

Expand Down
2 changes: 1 addition & 1 deletion pydust/src/types/module.zig
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub const PyModule = extern struct {
return .{ .obj = .{ .py = ffi.PyImport_ImportModule(name) orelse return PyError.PyRaised } };
}

pub fn getState(self: PyModule, comptime state: type) !*state {
pub fn getState(self: PyModule, comptime ModState: type) !*ModState {
const statePtr = ffi.PyModule_GetState(self.obj.py) orelse return PyError.PyRaised;
return @ptrCast(@alignCast(statePtr));
}
Expand Down

0 comments on commit 1ae977b

Please sign in to comment.