diff --git a/pydust/src/conversions.zig b/pydust/src/conversions.zig index 29b5d968..0c8ce03d 100644 --- a/pydust/src/conversions.zig +++ b/pydust/src/conversions.zig @@ -12,6 +12,8 @@ const py = @import("./pydust.zig"); const tramp = @import("./trampoline.zig"); +const pytypes = @import("./pytypes.zig"); +const State = @import("./discovery.zig").State; /// Zig PyObject-like -> ffi.PyObject. Convert a Zig PyObject-like value into a py.PyObject. /// e.g. py.PyObject, py.PyTuple, ffi.PyObject, etc. @@ -38,6 +40,30 @@ pub inline fn as(comptime T: type, obj: anytype) py.PyError!T { return tramp.Trampoline(T).unwrap(object(obj)); } +/// Python -> Pydust. Perform a checked cast from a PyObject to a given PyDust class type. +pub inline fn checked(comptime T: type, obj: py.PyObject) py.PyError!T { + const definition = State.getDefinition(@typeInfo(T).Pointer.child); + if (definition.type != .class) { + @compileError("Can only perform checked cast into a PyDust class type"); + } + + // TODO(ngates): to perform fast type checking, we need to store our PyType on the parent module. + // See how the Python JSON module did this: https://github.com/python/cpython/commit/33f15a16d40cb8010a8c758952cbf88d7912ee2d#diff-efe183ae0b85e5b8d9bbbc588452dd4de80b39fd5c5174ee499ba554217a39edR1814 + // For now, we perform a slow import/isinstance check by using the `as` conversion. + return as(T, obj); +} + +/// Python -> Pydust. Perform an unchecked cast from a PyObject to a given PyDust class type. +pub inline fn unchecked(comptime T: type, obj: py.PyObject) T { + const Definition = @typeInfo(T).Pointer.child; + const definition = State.getDefinition(Definition); + if (definition.type != .class) { + @compileError("Can only perform unchecked cast into a PyDust class type. Found " ++ @typeName(Definition)); + } + const instance: *pytypes.PyTypeStruct(Definition) = @ptrCast(@alignCast(obj.py)); + return &instance.state; +} + const testing = @import("std").testing; const expect = testing.expect; diff --git a/pydust/src/functions.zig b/pydust/src/functions.zig index 58fe4941..090f2a67 100644 --- a/pydust/src/functions.zig +++ b/pydust/src/functions.zig @@ -33,6 +33,13 @@ pub const Signature = struct { pub fn supportsKwargs(comptime self: @This()) bool { return self.nkwargs > 0 or self.varkwargsIdx != null; } + + pub fn isModuleMethod(comptime self: @This()) bool { + if (self.selfParam) |Self| { + return State.getDefinition(@typeInfo(Self).Pointer.child).type == .module; + } + return false; + } }; pub const BinaryOperators = std.ComptimeStringMap(c_int, .{ @@ -257,7 +264,8 @@ pub fn wrap(comptime definition: type, comptime func: anytype, comptime sig: Sig } inline fn internal(pyself: py.PyObject, pyargs: []py.PyObject) PyError!py.PyObject { - const self = if (sig.selfParam) |Self| try py.as(Self, pyself) else null; + const self = if (sig.selfParam) |Self| try castSelf(Self, pyself) else null; + if (sig.argsParam) |Args| { const args = try unwrapArgs(Args, pyargs, py.Kwargs.init(py.allocator)); const result = if (sig.selfParam) |_| func(self, args) else func(args); @@ -302,10 +310,19 @@ pub fn wrap(comptime definition: type, comptime func: anytype, comptime sig: Sig pykwargs: py.Kwargs, ) PyError!py.PyObject { const args = try unwrapArgs(sig.argsParam.?, pyargs, pykwargs); - const self = if (sig.selfParam) |Self| try py.as(Self, pyself) else null; + const self = if (sig.selfParam) |Self| try castSelf(Self, pyself) else null; const result = if (sig.selfParam) |_| func(self, args) else func(args); return py.createOwned(tramp.coerceError(result)); } + + inline fn castSelf(comptime Self: type, pyself: py.PyObject) !Self { + if (comptime sig.isModuleMethod()) { + const mod = py.PyModule{ .obj = pyself }; + return try mod.getState(@typeInfo(Self).Pointer.child); + } else { + return py.unchecked(Self, pyself); + } + } }; } diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 421955a3..7175c478 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -666,7 +666,7 @@ fn RichCompare(comptime definition: type) type { const CompareOpArg = typeInfo.params[2].type.?; if (CompareOpArg != py.CompareOp) @compileError("Third parameter of __richcompare__ must be a py.CompareOp"); - const self = py.as(Self, pyself) catch return null; + const self = py.unchecked(Self, .{ .py = pyself }); const otherArg = tramp.Trampoline(Other).unwrap(.{ .py = pyother }) catch return null; const opEnum: py.CompareOp = @enumFromInt(op); diff --git a/pydust/src/types/module.zig b/pydust/src/types/module.zig index 97a2639f..81d4cb65 100644 --- a/pydust/src/types/module.zig +++ b/pydust/src/types/module.zig @@ -31,7 +31,7 @@ pub const PyModule = extern struct { return .{ .obj = .{ .py = ffi.PyImport_ImportModule(name) orelse return PyError.PyRaised } }; } - pub fn getState(self: PyModule, comptime state: type) !*state { + pub fn getState(self: PyModule, comptime ModState: type) !*ModState { const statePtr = ffi.PyModule_GetState(self.obj.py) orelse return PyError.PyRaised; return @ptrCast(@alignCast(statePtr)); }