diff --git a/pydust/src/functions.zig b/pydust/src/functions.zig index 58fe4941..9d27c7e9 100644 --- a/pydust/src/functions.zig +++ b/pydust/src/functions.zig @@ -315,17 +315,9 @@ pub fn unwrapArgs(comptime Args: type, pyargs: py.Args, pykwargs: py.Kwargs) !Ar var args: Args = undefined; const s = @typeInfo(Args).Struct; - var varargsFieldIdx: usize = undefined; - var varkwargsFieldIdx: usize = undefined; var argIdx: usize = 0; - inline for (s.fields, 0..) |field, fieldIdx| { - if (field.type == py.Args) { - // Variadic args - varargsFieldIdx = fieldIdx; - } else if (field.type == py.Kwargs) { - // Variadic kwargs - varkwargsFieldIdx = fieldIdx; - } else if (field.default_value) |def_value| { + inline for (s.fields) |field| { + if (field.default_value) |def_value| { // We have a kwarg. if (kwargs.fetchRemove(field.name)) |entry| { @field(args, field.name) = try py.as(field.type, entry.value); @@ -334,7 +326,7 @@ pub fn unwrapArgs(comptime Args: type, pyargs: py.Args, pykwargs: py.Kwargs) !Ar const defaultValue: *field.type = @alignCast(@ptrCast(@constCast(def_value))); @field(args, field.name) = defaultValue.*; } - } else { + } else if (field.type != py.Args and field.type != py.Kwargs) { // Otherwise, we have a regular argument. if (argIdx >= pyargs.len) { return py.TypeError.raiseFmt("Expected {d} arg{s}", .{ @@ -366,19 +358,6 @@ pub fn unwrapArgs(comptime Args: type, pyargs: py.Args, pykwargs: py.Kwargs) !Ar return args; } -pub fn deinitArgs(comptime Args: type, args: Args) void { - const s = @typeInfo(Args).Struct; - inline for (s.fields) |field| { - if (field.type == py.Args) { - py.allocator.free(@field(args, field.name)); - } - if (field.type == py.Kwargs) { - var kwargs: py.Kwargs = @field(args, field.name); - kwargs.deinit(); - } - } -} - pub fn Methods(comptime definition: type) type { const empty = ffi.PyMethodDef{ .ml_name = null, .ml_meth = null, .ml_flags = 0, .ml_doc = null }; diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 421955a3..c60eaa1e 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -272,9 +272,9 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type { const kwargs = if (pykwargs) |pk| py.PyDict.unchecked(.{ .py = pk }) else null; const init_args = tramp.Trampoline(Args).unwrapCallArgs(args, kwargs) catch return -1; - defer funcs.deinitArgs(Args, init_args); + defer init_args.deinit(); - tramp.coerceError(definition.__init__(self, init_args)) catch return -1; + tramp.coerceError(definition.__init__(self, init_args.argsStruct)) catch return -1; } else if (sig.selfParam) |_| { tramp.coerceError(definition.__init__(self)) catch return -1; } else { @@ -364,9 +364,9 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type { const self = tramp.Trampoline(sig.selfParam.?).unwrap(py.PyObject{ .py = pyself }) catch return null; const call_args = tramp.Trampoline(sig.argsParam.?).unwrapCallArgs(args, kwargs) catch return null; - defer funcs.deinitArgs(sig.argsParam.?, call_args); + defer call_args.deinit(); - const result = tramp.coerceError(definition.__call__(self, call_args)) catch return null; + const result = tramp.coerceError(definition.__call__(self, call_args.argsStruct)) catch return null; return (py.createOwned(result) catch return null).py; } }; diff --git a/pydust/src/trampoline.zig b/pydust/src/trampoline.zig index dc308988..acb497bd 100644 --- a/pydust/src/trampoline.zig +++ b/pydust/src/trampoline.zig @@ -275,27 +275,53 @@ pub fn Trampoline(comptime T: type) type { // Unwrap the call args into a Pydust argument struct, borrowing references to the Python objects // but instantiating the args slice and kwargs map containers. - // The caller is responsible for invoking funcs.deinitArgs on the returned struct. - pub inline fn unwrapCallArgs(pyargs: ?py.PyTuple, pykwargs: ?py.PyDict) PyError!T { - var kwargs = py.Kwargs.init(py.allocator); - if (pykwargs) |kw| { - var iter = kw.itemsIterator(); - while (iter.next()) |item| { - const key: []const u8 = try (try py.PyString.checked(item.k)).asSlice(); - try kwargs.put(key, item.v); + // The caller is responsible for invoking deinit on the returned struct. + pub inline fn unwrapCallArgs(pyargs: ?py.PyTuple, pykwargs: ?py.PyDict) PyError!ZigCallArgs { + return ZigCallArgs.unwrap(pyargs, pykwargs); + } + + const ZigCallArgs = struct { + argsStruct: T, + allPosArgs: []py.PyObject, + + pub fn unwrap(pyargs: ?py.PyTuple, pykwargs: ?py.PyDict) PyError!@This() { + var kwargs = py.Kwargs.init(py.allocator); + if (pykwargs) |kw| { + var iter = kw.itemsIterator(); + while (iter.next()) |item| { + const key: []const u8 = try (try py.PyString.checked(item.k)).asSlice(); + try kwargs.put(key, item.v); + } } - } - const args = try py.allocator.alloc(py.PyObject, if (pyargs) |a| a.length() else 0); - defer py.allocator.free(args); - if (pyargs) |a| { - for (0..a.length()) |i| { - args[i] = try a.getItem(py.PyObject, i); + const args = try py.allocator.alloc(py.PyObject, if (pyargs) |a| a.length() else 0); + if (pyargs) |a| { + for (0..a.length()) |i| { + args[i] = try a.getItem(py.PyObject, i); + } } + + return .{ .argsStruct = try funcs.unwrapArgs(T, args, kwargs), .allPosArgs = args }; } - return funcs.unwrapArgs(T, args, kwargs); - } + pub fn deinit(self: @This()) void { + if (comptime funcs.varArgsIdx(T)) |idx| { + py.allocator.free(self.allPosArgs[0..idx]); + } else { + py.allocator.free(self.allPosArgs); + } + + inline for (@typeInfo(T).Struct.fields) |field| { + if (field.type == py.Args) { + py.allocator.free(@field(self.argsStruct, field.name)); + } + if (field.type == py.Kwargs) { + var kwargs: py.Kwargs = @field(self.argsStruct, field.name); + kwargs.deinit(); + } + } + } + }; }; }