Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CallArgs struct #232

Merged
merged 5 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 3 additions & 24 deletions pydust/src/functions.zig
Original file line number Diff line number Diff line change
Expand Up @@ -315,17 +315,9 @@ pub fn unwrapArgs(comptime Args: type, pyargs: py.Args, pykwargs: py.Kwargs) !Ar
var args: Args = undefined;

const s = @typeInfo(Args).Struct;
var varargsFieldIdx: usize = undefined;
var varkwargsFieldIdx: usize = undefined;
var argIdx: usize = 0;
inline for (s.fields, 0..) |field, fieldIdx| {
if (field.type == py.Args) {
// Variadic args
varargsFieldIdx = fieldIdx;
} else if (field.type == py.Kwargs) {
// Variadic kwargs
varkwargsFieldIdx = fieldIdx;
} else if (field.default_value) |def_value| {
inline for (s.fields) |field| {
if (field.default_value) |def_value| {
// We have a kwarg.
if (kwargs.fetchRemove(field.name)) |entry| {
@field(args, field.name) = try py.as(field.type, entry.value);
Expand All @@ -334,7 +326,7 @@ pub fn unwrapArgs(comptime Args: type, pyargs: py.Args, pykwargs: py.Kwargs) !Ar
const defaultValue: *field.type = @alignCast(@ptrCast(@constCast(def_value)));
@field(args, field.name) = defaultValue.*;
}
} else {
} else if (field.type != py.Args and field.type != py.Kwargs) {
// Otherwise, we have a regular argument.
if (argIdx >= pyargs.len) {
return py.TypeError.raiseFmt("Expected {d} arg{s}", .{
Expand Down Expand Up @@ -366,19 +358,6 @@ pub fn unwrapArgs(comptime Args: type, pyargs: py.Args, pykwargs: py.Kwargs) !Ar
return args;
}

pub fn deinitArgs(comptime Args: type, args: Args) void {
const s = @typeInfo(Args).Struct;
inline for (s.fields) |field| {
if (field.type == py.Args) {
py.allocator.free(@field(args, field.name));
}
if (field.type == py.Kwargs) {
var kwargs: py.Kwargs = @field(args, field.name);
kwargs.deinit();
}
}
}

pub fn Methods(comptime definition: type) type {
const empty = ffi.PyMethodDef{ .ml_name = null, .ml_meth = null, .ml_flags = 0, .ml_doc = null };

Expand Down
8 changes: 4 additions & 4 deletions pydust/src/pytypes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
const kwargs = if (pykwargs) |pk| py.PyDict.unchecked(.{ .py = pk }) else null;

const init_args = tramp.Trampoline(Args).unwrapCallArgs(args, kwargs) catch return -1;
defer funcs.deinitArgs(Args, init_args);
defer init_args.deinit();

tramp.coerceError(definition.__init__(self, init_args)) catch return -1;
tramp.coerceError(definition.__init__(self, init_args.argsStruct)) catch return -1;
} else if (sig.selfParam) |_| {
tramp.coerceError(definition.__init__(self)) catch return -1;
} else {
Expand Down Expand Up @@ -364,9 +364,9 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {

const self = tramp.Trampoline(sig.selfParam.?).unwrap(py.PyObject{ .py = pyself }) catch return null;
const call_args = tramp.Trampoline(sig.argsParam.?).unwrapCallArgs(args, kwargs) catch return null;
defer funcs.deinitArgs(sig.argsParam.?, call_args);
defer call_args.deinit();

const result = tramp.coerceError(definition.__call__(self, call_args)) catch return null;
const result = tramp.coerceError(definition.__call__(self, call_args.argsStruct)) catch return null;
return (py.createOwned(result) catch return null).py;
}
};
Expand Down
58 changes: 42 additions & 16 deletions pydust/src/trampoline.zig
Original file line number Diff line number Diff line change
Expand Up @@ -275,27 +275,53 @@ pub fn Trampoline(comptime T: type) type {

// Unwrap the call args into a Pydust argument struct, borrowing references to the Python objects
// but instantiating the args slice and kwargs map containers.
// The caller is responsible for invoking funcs.deinitArgs on the returned struct.
pub inline fn unwrapCallArgs(pyargs: ?py.PyTuple, pykwargs: ?py.PyDict) PyError!T {
var kwargs = py.Kwargs.init(py.allocator);
if (pykwargs) |kw| {
var iter = kw.itemsIterator();
while (iter.next()) |item| {
const key: []const u8 = try (try py.PyString.checked(item.k)).asSlice();
try kwargs.put(key, item.v);
// The caller is responsible for invoking deinit on the returned struct.
pub inline fn unwrapCallArgs(pyargs: ?py.PyTuple, pykwargs: ?py.PyDict) PyError!ZigCallArgs {
return ZigCallArgs.unwrap(pyargs, pykwargs);
}

const ZigCallArgs = struct {
argsStruct: T,
allPosArgs: []py.PyObject,

pub fn unwrap(pyargs: ?py.PyTuple, pykwargs: ?py.PyDict) PyError!@This() {
var kwargs = py.Kwargs.init(py.allocator);
if (pykwargs) |kw| {
var iter = kw.itemsIterator();
while (iter.next()) |item| {
const key: []const u8 = try (try py.PyString.checked(item.k)).asSlice();
try kwargs.put(key, item.v);
}
}
}

const args = try py.allocator.alloc(py.PyObject, if (pyargs) |a| a.length() else 0);
defer py.allocator.free(args);
if (pyargs) |a| {
for (0..a.length()) |i| {
args[i] = try a.getItem(py.PyObject, i);
const args = try py.allocator.alloc(py.PyObject, if (pyargs) |a| a.length() else 0);
if (pyargs) |a| {
for (0..a.length()) |i| {
args[i] = try a.getItem(py.PyObject, i);
}
}

return .{ .argsStruct = try funcs.unwrapArgs(T, args, kwargs), .allPosArgs = args };
}

return funcs.unwrapArgs(T, args, kwargs);
}
pub fn deinit(self: @This()) void {
if (comptime funcs.varArgsIdx(T)) |idx| {
py.allocator.free(self.allPosArgs[0..idx]);
} else {
py.allocator.free(self.allPosArgs);
}

inline for (@typeInfo(T).Struct.fields) |field| {
if (field.type == py.Args) {
py.allocator.free(@field(self.argsStruct, field.name));
}
if (field.type == py.Kwargs) {
var kwargs: py.Kwargs = @field(self.argsStruct, field.name);
kwargs.deinit();
}
}
}
};
};
}

Expand Down