diff --git a/.gitignore b/.gitignore index e73c965..3389c86 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -zig-cache/ +.zig-cache/ zig-out/ diff --git a/build.zig b/build.zig index 0a034cc..9f5e053 100644 --- a/build.zig +++ b/build.zig @@ -1,20 +1,41 @@ const std = @import("std"); -pub fn build(b: *std.build.Builder) void { - // Standard target options allows the person running `zig build` to choose - // what target to build for. Here we do not override the defaults, which - // means any target is allowed, and the default is native. Other options - // for restricting supported target set are available. +// Although this function looks imperative, note that its job is to +// declaratively construct a build graph that will be executed by an external +// runner. +pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); - // Standard release options allow the person running `zig build` to select - // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. - const mode = b.standardReleaseOptions(); + const lib = b.addStaticLibrary(.{ + .name = "zigpb", + // In this case the main source file is merely a path, however, in more + // complicated build scripts, this could be a generated file. + .root_source_file = b.path("src/root.zig"), + .target = target, + .optimize = optimize, + }); - const exe_tests = b.addTest("test.zig"); - exe_tests.setTarget(target); - exe_tests.setBuildMode(mode); + const module = b.addModule("zigpb", .{ + .root_source_file = b.path("src/root.zig"), + }); + b.installArtifact(lib); + + // Creates a step for unit testing. This only builds the test executable + // but does not run it. + const lib_unit_tests = b.addTest(.{ + .root_source_file = b.path("src/test.zig"), + .target = target, + .optimize = optimize, + }); + lib_unit_tests.root_module.addImport("zigpb", module); + + const run_lib_unit_tests = b.addRunArtifact(lib_unit_tests); + + // Similar to creating the run step earlier, this exposes a `test` step to + // the `zig build --help` menu, providing a way for the user to request + // running the unit tests. const test_step = b.step("test", "Run unit tests"); - test_step.dependOn(&exe_tests.step); + test_step.dependOn(&run_lib_unit_tests.step); } diff --git a/build.zig.zon b/build.zig.zon new file mode 100644 index 0000000..4f9ae1c --- /dev/null +++ b/build.zig.zon @@ -0,0 +1,10 @@ +.{ + .name = "zigpb", + .version = "0.0.0", + .dependencies = .{}, + .paths = .{ + "src", + "build.zig", + "build.zig.zon", + }, +} diff --git a/protobuf.zig b/src/root.zig similarity index 72% rename from protobuf.zig rename to src/root.zig index c220619..aa0a796 100644 --- a/protobuf.zig +++ b/src/root.zig @@ -25,23 +25,6 @@ const FieldDescriptor = struct { encoding: FieldEncoding, }; -/// Convenience wrapper for constructing protobuf maps. -pub fn Map(comptime K: type, comptime V: type) type { - return std.HashMapUnmanaged(K, V, struct { - pub fn hash(_: @This(), key: K) u64 { - var hasher = std.hash.Wyhash.init(0); - std.hash.autoHashStrat(&hasher, key, .Deep); - return hasher.final(); - } - pub fn eql(_: @This(), a: K, b: K) bool { - return if (comptime std.meta.trait.isSlice(K)) - std.mem.eql(std.meta.Child(K), a, b) - else - a == b; - } - }, std.hash_map.default_max_load_percentage); -} - /// A Protobuf wire type - all data is encoded as one of these. const WireType = enum(u3) { varint, @@ -61,16 +44,16 @@ fn encodeVarInt(w: anytype, val: u64) !void { var x = val; while (x != 0) { - const part: u8 = @truncate(u7, x); + const part: u8 = @truncate(x & 0x7f); x >>= 7; - const next: u8 = @boolToInt(x != 0); + const next: u8 = @intFromBool(x != 0); try w.writeByte(next << 7 | part); } } /// Encode a field tag, composed of a field number and associated wire type. fn encodeTag(w: anytype, field_num: u29, wire_type: WireType) !void { - const wire = @enumToInt(wire_type); + const wire = @intFromEnum(wire_type); const val = @as(u32, wire) | @as(u32, field_num) << 3; return encodeVarInt(w, val); } @@ -82,18 +65,18 @@ fn encodeTag(w: anytype, field_num: u29, wire_type: WireType) !void { fn encodeSingleScalar(w: anytype, val: anytype, comptime desc: FieldDescriptor, comptime encode_default: bool, comptime override_default: ?@TypeOf(val), comptime include_tag: bool) !void { const T = @TypeOf(val); - if (@typeInfo(T) == .Enum) { + if (@typeInfo(T) == .@"enum") { if (desc.encoding != .default) @compileError("Enum types must use FieldEncoding.default"); - const Tag = @typeInfo(T).Enum.tag_type; + const Tag = @typeInfo(T).@"enum".tag_type; if (@bitSizeOf(Tag) > 32) @compileError("Enum types must have a tag type of no more than 32 bits"); - const Tag32 = if (@typeInfo(Tag).Int.signedness == .signed) i32 else u32; - const ival: Tag32 = @enumToInt(val); + const Tag32 = if (@typeInfo(Tag).int.signedness == .signed) i32 else u32; + const ival: Tag32 = @intFromEnum(val); return encodeSingleScalar( w, ival, .{ .field_num = desc.field_num, .encoding = .varint }, encode_default, - if (override_default) |x| @enumToInt(x) else null, + if (override_default) |x| @intFromEnum(x) else null, include_tag, ); } @@ -103,7 +86,7 @@ fn encodeSingleScalar(w: anytype, val: anytype, comptime desc: FieldDescriptor, if (desc.encoding != .default) @compileError("Boolean types must use FieldEncoding.default"); if (!encode_default and val == (override_default orelse false)) return; if (include_tag) try encodeTag(w, desc.field_num, .varint); - try w.writeByte(@boolToInt(val)); + try w.writeByte(@intFromBool(val)); }, u32, u64, i32, i64 => { @@ -115,25 +98,25 @@ fn encodeSingleScalar(w: anytype, val: anytype, comptime desc: FieldDescriptor, u64, i64 => .i64, else => unreachable, }); - try w.writeIntLittle(T, val); + try w.writeInt(T, val, .little); }, .varint => { if (include_tag) try encodeTag(w, desc.field_num, .varint); const val64: u64 = switch (T) { u32, u64 => val, - i32 => @bitCast(u64, @as(i64, val)), // sign-extend - i64 => @bitCast(u64, val), + i32 => @bitCast(@as(i64, val)), // sign-extend + i64 => @bitCast(val), else => unreachable, }; try encodeVarInt(w, val64); }, .zigzag => { - if (@typeInfo(T).Int.signedness != .signed) @compileError("Only signed integral types can use FieldEncoding.zigzag"); + if (@typeInfo(T).int.signedness != .signed) @compileError("Only signed integral types can use FieldEncoding.zigzag"); if (include_tag) try encodeTag(w, desc.field_num, .varint); if (val >= 0) { - try encodeVarInt(w, @intCast(u64, val) * 2); + try encodeVarInt(w, @as(u64, @intCast(val)) * 2); } else { - try encodeVarInt(w, @intCast(u64, -val - 1) * 2 + 1); + try encodeVarInt(w, @as(u64, @intCast(-val - 1)) * 2 + 1); } }, else => @compileError("Integral types must use FieldEncoding.fixed, FieldEncoding.varint, or FieldEncoding.zigzag"), @@ -146,15 +129,14 @@ fn encodeSingleScalar(w: anytype, val: anytype, comptime desc: FieldDescriptor, if (!encode_default and val == (override_default orelse 0)) return; if (T == f32) { if (include_tag) try encodeTag(w, desc.field_num, .i32); - try w.writeIntLittle(u32, @bitCast(u32, val)); + try w.writeInt(u32, @as(u32, @bitCast(val)), .little); } else { if (include_tag) try encodeTag(w, desc.field_num, .i64); - try w.writeIntLittle(u64, @bitCast(u64, val)); + try w.writeInt(u64, @as(u64, @bitCast(val)), .little); } }, []u8, []const u8 => { - if (override_default != null) @compileError("Cannot override default for []u8"); if (!encode_default and val.len == 0) return; switch (desc.encoding) { .string, .bytes => { @@ -173,16 +155,16 @@ fn encodeSingleScalar(w: anytype, val: anytype, comptime desc: FieldDescriptor, /// Encode a single value of scalar or submessage type. 'map's are not included here since /// they're sugar for a 'repeated' submessage (and cannot themselves be repeated), meaning they are /// really multiple values. -fn encodeSingleValue(w: anytype, ally: std.mem.Allocator, val: anytype, comptime desc: FieldDescriptor, comptime encode_default: bool, comptime override_default: ?@TypeOf(val)) !void { +fn encodeSingleValue(w: anytype, allocator: std.mem.Allocator, val: anytype, comptime desc: FieldDescriptor, comptime encode_default: bool, comptime override_default: ?@TypeOf(val)) !void { const T = @TypeOf(val); - if (@typeInfo(T) == .Struct) { + if (@typeInfo(T) == .@"struct") { if (desc.encoding != .default) @compileError("Sub-messages must use FieldEncoding.default"); - var buf = std.ArrayList(u8).init(ally); + var buf = std.ArrayList(u8).init(allocator); defer buf.deinit(); - try encodeMessage(buf.writer(), ally, val); + try encode(allocator, buf.writer(), val); try encodeTag(w, desc.field_num, .len); try encodeVarInt(w, buf.items.len); @@ -196,7 +178,7 @@ fn encodeSingleValue(w: anytype, ally: std.mem.Allocator, val: anytype, comptime /// writer. 'field_name' is used only for error messages. fn encodeAnyField( w: anytype, - ally: std.mem.Allocator, + allocator: std.mem.Allocator, val: anytype, comptime desc_opt: ?FieldDescriptor, comptime field_name: []const u8, @@ -205,12 +187,12 @@ fn encodeAnyField( const T = @TypeOf(val); // Nicer error message if you forgot to make your union optional - if (@typeInfo(T) == .Union) { + if (@typeInfo(T) == .@"union") { @compileError("Only optional unions can be encoded"); } - if (@typeInfo(T) == .Optional and - @typeInfo(std.meta.Child(T)) == .Union) + if (@typeInfo(T) == .optional and + @typeInfo(std.meta.Child(T)) == .@"union") { // oneof const U = std.meta.Child(T); @@ -222,7 +204,7 @@ fn encodeAnyField( const sub_desc = comptime pb_desc.getField(@tagName(tag)) orelse @compileError("Mising descriptor for field '" ++ @typeName(U) ++ "." ++ @tagName(tag) ++ "'"); - try encodeSingleValue(w, ally, payload, sub_desc, true, null); + try encodeSingleValue(w, allocator, payload, sub_desc, true, null); }, } } @@ -233,14 +215,16 @@ fn encodeAnyField( const desc = desc_opt orelse @compileError("Missing descriptor for field '" ++ field_name ++ "'"); if (desc.encoding == .repeat) { - for (val) |x| { - try encodeSingleValue(w, ally, x, .{ + for (val.items) |x| { + try encodeSingleValue(w, allocator, x, .{ .field_num = desc.field_num, .encoding = desc.encoding.repeat.*, }, true, null); } } else if (desc.encoding == .repeat_pack) { - var buf = std.ArrayList(u8).init(ally); + if (val.items.len == 0) return; + + var buf = std.ArrayList(u8).init(allocator); defer buf.deinit(); for (val.items) |x| { @@ -256,7 +240,7 @@ fn encodeAnyField( } else if (desc.encoding == .map) { var it = val.iterator(); while (it.next()) |pair| { - try encodeSingleValue(w, ally, struct { + try encodeSingleValue(w, allocator, struct { k: std.meta.FieldType(T.KV, .key), v: std.meta.FieldType(T.KV, .value), const pb_desc = .{ @@ -268,33 +252,33 @@ fn encodeAnyField( .encoding = .default, }, true, null); } - } else if (@typeInfo(T) == .Optional) { + } else if (@typeInfo(T) == .optional) { if (val) |x| { - try encodeSingleValue(w, ally, x, desc, true, null); + try encodeSingleValue(w, allocator, x, desc, true, null); } } else { - try encodeSingleValue(w, ally, val, desc, false, field_default); + try encodeSingleValue(w, allocator, val, desc, false, field_default); } } /// Encode an entire Protobuf message 'msg' into the given writer. Only temporary allocations are /// performed, all of which are cleaned up before this function returns. -pub fn encodeMessage(w: anytype, ally: std.mem.Allocator, msg: anytype) !void { +pub fn encode(allocator: std.mem.Allocator, writer: anytype, msg: anytype) !void { const Msg = @TypeOf(msg); const pb_desc = comptime getPbDesc(Msg) orelse @compileError("Message type '" ++ @typeName(Msg) ++ "' must have a pb_desc decl"); validateDescriptors(Msg); - inline for (@typeInfo(Msg).Struct.fields) |field| { + inline for (@typeInfo(Msg).@"struct".fields) |field| { const desc: ?FieldDescriptor = comptime pb_desc.getField(field.name); const default: ?field.type = if (field.default_value) |ptr| - @ptrCast(*const field.type, ptr).* + @as(*const field.type, @alignCast(@ptrCast(ptr))).* else null; - try encodeAnyField(w, ally, @field(msg, field.name), desc, @typeName(Msg) ++ "." ++ field.name, default); + try encodeAnyField(writer, allocator, @field(msg, field.name), desc, @typeName(Msg) ++ "." ++ field.name, default); } } @@ -304,7 +288,7 @@ fn validateDescriptors(comptime Msg: type) void { comptime { var seen_field_nums: []const u29 = &.{}; validateDescriptorsInner(Msg, &seen_field_nums); - for (seen_field_nums) |x, i| { + for (seen_field_nums, 0..) |x, i| { for (seen_field_nums[i + 1 ..]) |y| { if (x == y) { @compileError(std.fmt.comptimePrint("Duplicate field number {} in type '{s}'", .{ x, @typeName(Msg) })); @@ -325,10 +309,10 @@ fn validateDescriptorsInner(comptime Msg: type, comptime seen_field_nums: *[]con } for (std.meta.fields(Msg)) |field| { - if (@typeInfo(field.type) == .Struct and comptime getPbDesc(field.type) != null) { + if (@typeInfo(field.type) == .@"struct" and comptime getPbDesc(field.type) != null) { validateDescriptors(field.type); - } else if (@typeInfo(field.type) == .Optional and - @typeInfo(std.meta.Child(field.type)) == .Union and + } else if (@typeInfo(field.type) == .optional and + @typeInfo(std.meta.Child(field.type)) == .@"union" and comptime getPbDesc(std.meta.Child(field.type)) != null) { validateDescriptorsInner(std.meta.Child(field.type), seen_field_nums); @@ -336,46 +320,34 @@ fn validateDescriptorsInner(comptime Msg: type, comptime seen_field_nums: *[]con } } -/// A small wrapper around a decoded message. You must call 'deinit' once you're done with the -/// message to free all its allocated memory. -pub fn Decoded(comptime Msg: type) type { - return struct { - msg: Msg, - arena: std.heap.ArenaAllocator, - - const Self = @This(); - - pub fn deinit(self: Self) void { - self.arena.deinit(); - } - }; -} - -fn initDefault(comptime Msg: type, arena: std.mem.Allocator) Msg { +fn initDefault(comptime Msg: type) Msg { var result: Msg = undefined; inline for (comptime std.meta.fields(Msg)) |field| { - if (comptime std.meta.trait.isSlice(field.type)) { - @field(result, field.name) = &.{}; - continue; + switch (@typeInfo(field.type)) { + .pointer => |info| if (info.size == .Slice) { + @field(result, field.name) = &.{}; + continue; + }, + else => {}, } const default: ?field.type = if (field.default_value) |ptr| - @ptrCast(*const field.type, ptr).* + @as(*const field.type, @alignCast(@ptrCast(ptr))).* else null; @field(result, field.name) = switch (@typeInfo(field.type)) { - .Optional => default orelse null, - .Int, .Float => default orelse 0, - .Enum => |e| default orelse if (e.is_exhaustive) + .optional => default orelse null, + .int, .float => default orelse 0, + .@"enum" => |e| default orelse if (e.is_exhaustive) comptime std.meta.intToEnum(field.type, 0) catch @compileError("Enum '" ++ @typeName(field.type) ++ "' has no 0 default") else - @intToEnum(field.type, 0), - .Bool => default orelse false, - .Struct => if (comptime getPbDesc(field.type) != null) - initDefault(field.type, arena) + @enumFromInt(0), + .bool => default orelse false, + .@"struct" => if (comptime getPbDesc(field.type) != null) + initDefault(field.type) else field.type{}, else => @compileError("Type '" ++ @typeName(field.type) ++ "' cannot be deserialized"), @@ -390,7 +362,7 @@ fn decodeVarInt(r: anytype) !u64 { var x: u64 = 0; while (true) { const b = try r.readByte(); - x |= @as(u64, @truncate(u7, b)) << shift; + x |= @as(u64, b & 0x7f) << shift; if (b >> 7 == 0) break; shift += 7; } @@ -400,7 +372,7 @@ fn decodeVarInt(r: anytype) !u64 { fn skipField(r: anytype, wire_type: WireType, field_num: u29) !void { switch (wire_type) { .varint => _ = try decodeVarInt(r), - .i64 => _ = try r.readIntLittle(u64), + .i64 => _ = try r.readInt(u64, .little), .len => { const len = try decodeVarInt(r); try r.skipBytes(len, .{}); @@ -408,7 +380,7 @@ fn skipField(r: anytype, wire_type: WireType, field_num: u29) !void { .sgroup => { while (true) { const tag = try decodeVarInt(r); - const sub_wire = std.meta.intToEnum(WireType, @truncate(u3, tag)) catch return error.MalformedInput; + const sub_wire = std.meta.intToEnum(WireType, tag & 7) catch return error.MalformedInput; const sub_num = std.math.cast(u29, tag >> 3) orelse return error.MalformedInput; if (sub_wire == .egroup and sub_num == field_num) { break; @@ -417,22 +389,18 @@ fn skipField(r: anytype, wire_type: WireType, field_num: u29) !void { } }, .egroup => return error.MalformedInput, - .i32 => _ = try r.readIntLittle(u32), + .i32 => _ = try r.readInt(u32, .little), } } -fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: anytype, arena: std.mem.Allocator, wire_type: WireType) !T { - if (@typeInfo(T) == .Enum) { +fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: anytype, allocator: std.mem.Allocator, wire_type: WireType) !T { + if (@typeInfo(T) == .@"enum") { if (encoding != .default) @compileError("Enum types must use FieldEncoding.default"); - const Tag = @typeInfo(T).Enum.tag_type; + const Tag = @typeInfo(T).@"enum".tag_type; if (@bitSizeOf(Tag) > 32) @compileError("Enum types must have a tag type of no more than 32 bits"); - const Tag32 = if (@typeInfo(Tag).Int.signedness == .signed) i32 else u32; - const ival = try decodeSingleScalar(Tag32, .varint, r, arena, wire_type); - if (@typeInfo(T).Enum.is_exhaustive) { - return std.meta.intToEnum(T, ival) catch return error.UnknownEnumTag; - } else { - return @intToEnum(T, ival); - } + const Tag32 = if (@typeInfo(Tag).int.signedness == .signed) i32 else u32; + const ival = try decodeSingleScalar(Tag32, .varint, r, allocator, wire_type); + return try std.meta.intToEnum(T, ival); } switch (T) { @@ -440,7 +408,7 @@ fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: any if (encoding != .default) @compileError("Boolean types must use FieldEncoding.default"); if (wire_type != .varint) return error.MalformedInput; const x = try decodeVarInt(r); - return @truncate(u32, x) != 0; + return @as(u32, @truncate(x)) != 0; }, u32, u64, i32, i64 => { @@ -451,7 +419,7 @@ fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: any u64, i64 => if (wire_type != .i64) return error.MalformedInput, else => unreachable, } - return r.readIntLittle(T); + return r.readInt(T, .little); }, .varint => { if (wire_type != .varint) return error.MalformedInput; @@ -460,17 +428,17 @@ fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: any u64, i64 => u64, else => unreachable, }; - return @bitCast(T, @truncate(Unsigned, try decodeVarInt(r))); + return @bitCast(@as(Unsigned, @truncate(try decodeVarInt(r)))); }, .zigzag => { - if (@typeInfo(T).Int.signedness != .signed) @compileError("Only signed integral types can use FieldEncoding.zigzag"); + if (@typeInfo(T).int.signedness != .signed) @compileError("Only signed integral types can use FieldEncoding.zigzag"); if (wire_type != .varint) return error.MalformedInput; const raw = try decodeVarInt(r); const val = if (raw % 2 == 1) - -@intCast(i64, raw / 2) - 1 + -@as(i64, @intCast(raw / 2)) - 1 else - @intCast(i64, raw / 2); - return @truncate(T, val); + @as(i64, @intCast(raw / 2)); + return @truncate(val); }, else => @compileError("Integral types must use FieldEncoding.fixed, FieldEncoding.varint, or FieldEncoding.zigzag"), } @@ -480,10 +448,10 @@ fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: any if (encoding != .default) @compileError("Floating types must use FieldEncoding.default"); if (T == f32) { if (wire_type != .i32) return error.MalformedInput; - return @bitCast(f32, try r.readIntLittle(u32)); + return @bitCast(try r.readInt(u32, .little)); } else { if (wire_type != .i64) return error.MalformedInput; - return @bitCast(f64, try r.readIntLittle(u64)); + return @bitCast(try r.readInt(u64, .little)); } }, @@ -491,7 +459,7 @@ fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: any if (encoding != .string and encoding != .bytes) @compileError("[]u8 must use FieldEncoding.string or FieldEncoding.bytes"); if (wire_type != .len) return error.MalformedInput; const len = try decodeVarInt(r); - const buf = try arena.alloc(u8, len); + const buf = try allocator.alloc(u8, len); try r.readNoEof(buf); return buf; }, @@ -501,30 +469,30 @@ fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: any } /// Decodes a value of scalar or submessage type, returning the result. -fn decodeSingleValue(comptime T: type, comptime encoding: FieldEncoding, r: anytype, arena: std.mem.Allocator, wire_type: WireType) !T { - if (@typeInfo(T) == .Struct) { +fn decodeSingleValue(comptime T: type, comptime encoding: FieldEncoding, r: anytype, allocator: std.mem.Allocator, wire_type: WireType) !T { + if (@typeInfo(T) == .@"struct") { if (encoding != .default) @compileError("Sub-messages must use FieldEncoding.default"); if (wire_type != .len) return error.MalformedInput; const len = try decodeVarInt(r); var lr = std.io.limitedReader(r, len); - return decodeMessageInner(T, lr.reader(), arena); + return decode(T, allocator, lr.reader()); } else { - return decodeSingleScalar(T, encoding, r, arena, wire_type); + return decodeSingleScalar(T, encoding, r, allocator, wire_type); } } /// Attempts to decode a field of any type, modifying the result location as necessary (either /// overwriting the value or appending data). Returns true if this message corresponded to the given /// field (and was decoded). -fn maybeDecodeAnyField(comptime T: type, comptime desc_opt: ?FieldDescriptor, comptime field_name: []const u8, r: anytype, arena: std.mem.Allocator, wire_type: WireType, field_num: u29, result: *T) !bool { +fn maybeDecodeAnyField(comptime T: type, comptime desc_opt: ?FieldDescriptor, comptime field_name: []const u8, r: anytype, allocator: std.mem.Allocator, wire_type: WireType, field_num: u29, result: *T) !bool { // Nicer error message if you forgot to make your union optional - if (@typeInfo(T) == .Union) { + if (@typeInfo(T) == .@"union") { @compileError("Only optional unions can be decoded"); } - if (@typeInfo(T) == .Optional and @typeInfo(std.meta.Child(T)) == .Union) { + if (@typeInfo(T) == .optional and @typeInfo(std.meta.Child(T)) == .@"union") { if (desc_opt != null) @compileError("Union must not have a field descriptor"); - if (try maybeDecodeOneOf(std.meta.Child(T), r, arena, wire_type, field_num)) |val| { + if (try maybeDecodeOneOf(std.meta.Child(T), r, allocator, wire_type, field_num)) |val| { result.* = val; return true; } else { @@ -539,7 +507,7 @@ fn maybeDecodeAnyField(comptime T: type, comptime desc_opt: ?FieldDescriptor, co if (desc.encoding == .repeat or desc.encoding == .repeat_pack) { const Elem = std.meta.Child(T.Slice); const scalar_elem = switch (@typeInfo(Elem)) { - .Int, .Bool, .Float => true, + .int, .bool, .float => true, else => Elem == []u8 or Elem == []const u8, }; if (desc.encoding == .repeat_pack and !scalar_elem) { @@ -573,8 +541,8 @@ fn maybeDecodeAnyField(comptime T: type, comptime desc_opt: ?FieldDescriptor, co else => undefined, }; - while (decodeSingleScalar(Elem, child_enc, lr.reader(), arena, expect_wire)) |elem| { - try result.*.append(arena, elem); + while (decodeSingleScalar(Elem, child_enc, lr.reader(), allocator, expect_wire)) |elem| { + try result.*.append(allocator, elem); } else |err| switch (err) { error.EndOfStream => {}, else => |e| return e, @@ -584,8 +552,8 @@ fn maybeDecodeAnyField(comptime T: type, comptime desc_opt: ?FieldDescriptor, co } } - const elem = try decodeSingleScalar(Elem, child_enc, r, arena, wire_type); - try result.*.append(arena, elem); + const elem = try decodeSingleValue(Elem, child_enc, r, allocator, wire_type); + try result.*.append(allocator, elem); } else if (desc.encoding == .map) { const val = try decodeSingleValue(struct { k: std.meta.FieldType(T.KV, .key), @@ -594,18 +562,18 @@ fn maybeDecodeAnyField(comptime T: type, comptime desc_opt: ?FieldDescriptor, co .k = .{ 1, desc.encoding.map[0] }, .v = .{ 2, desc.encoding.map[1] }, }; - }, .default, r, arena, wire_type); - try result.put(arena, val.k, val.v); - } else if (@typeInfo(T) == .Optional) { - result.* = try decodeSingleValue(std.meta.Child(T), desc.encoding, r, arena, wire_type); + }, .default, r, allocator, wire_type); + try result.put(allocator, val.k, val.v); + } else if (@typeInfo(T) == .optional) { + result.* = try decodeSingleValue(std.meta.Child(T), desc.encoding, r, allocator, wire_type); } else { - result.* = try decodeSingleValue(T, desc.encoding, r, arena, wire_type); + result.* = try decodeSingleValue(T, desc.encoding, r, allocator, wire_type); } return true; } -fn maybeDecodeOneOf(comptime U: type, r: anytype, arena: std.mem.Allocator, wire_type: WireType, field_num: u29) !?U { +fn maybeDecodeOneOf(comptime U: type, r: anytype, allocator: std.mem.Allocator, wire_type: WireType, field_num: u29) !?U { const pb_desc = comptime getPbDesc(U) orelse @compileError("Union '" ++ @typeName(U) ++ "' must have a pb_desc decl"); @@ -614,7 +582,7 @@ fn maybeDecodeOneOf(comptime U: type, r: anytype, arena: std.mem.Allocator, wire @compileError("Missing descriptor for field '" ++ @typeName(U) ++ "." ++ field.name ++ "'"); if (desc.field_num == field_num) { - const payload = try decodeSingleValue(field.type, desc.encoding, r, arena, wire_type); + const payload = try decodeSingleValue(field.type, desc.encoding, r, allocator, wire_type); return @unionInit(U, field.name, payload); } } @@ -622,24 +590,24 @@ fn maybeDecodeOneOf(comptime U: type, r: anytype, arena: std.mem.Allocator, wire return null; } -fn decodeMessageInner(comptime Msg: type, r: anytype, arena: std.mem.Allocator) !Msg { +pub fn decode(comptime Msg: type, allocator: std.mem.Allocator, reader: anytype) !Msg { const pb_desc = comptime getPbDesc(Msg) orelse @compileError("Message type '" ++ @typeName(Msg) ++ "' must have a pb_desc decl"); validateDescriptors(Msg); - var result = initDefault(Msg, arena); + var result = initDefault(Msg); - while (decodeVarInt(r)) |tag| { - const wire_type = std.meta.intToEnum(WireType, @truncate(u3, tag)) catch return error.MalformedInput; + while (decodeVarInt(reader)) |tag| { + const wire_type = std.meta.intToEnum(WireType, tag & 7) catch return error.MalformedInput; const field_num = std.math.cast(u29, tag >> 3) orelse return error.MalformedInput; inline for (std.meta.fields(Msg)) |field| { const desc_opt: ?FieldDescriptor = comptime pb_desc.getField(field.name); - if (try maybeDecodeAnyField(field.type, desc_opt, @typeName(Msg) ++ "." ++ field.name, r, arena, wire_type, field_num, &@field(result, field.name))) { + if (try maybeDecodeAnyField(field.type, desc_opt, @typeName(Msg) ++ "." ++ field.name, reader, allocator, wire_type, field_num, &@field(result, field.name))) { break; } } else { - try skipField(r, wire_type, field_num); + try skipField(reader, wire_type, field_num); } } else |err| switch (err) { error.EndOfStream => {}, @@ -649,21 +617,12 @@ fn decodeMessageInner(comptime Msg: type, r: anytype, arena: std.mem.Allocator) return result; } -pub fn decodeMessage(comptime Msg: type, r: anytype, ally: std.mem.Allocator) !Decoded(Msg) { - var arena = std.heap.ArenaAllocator.init(ally); - errdefer arena.deinit(); - - return .{ - .msg = try decodeMessageInner(Msg, r, arena.allocator()), - .arena = arena, - }; -} - const PbDesc = struct { const Entry = struct { []const u8, FieldDescriptor }; fields: []const Entry, fn getField(self: PbDesc, name: []const u8) ?FieldDescriptor { + @setEvalBranchQuota(100 * self.fields.len); for (self.fields) |f| { if (std.mem.eql(u8, f[0], name)) return f[1]; } @@ -673,7 +632,6 @@ const PbDesc = struct { // Directly making a pb_desc with fields of type FieldDescriptor is quite inconvenient, so instead // we'll take big literals in the same shape and parse them into the real descriptors. - fn getPbDesc(comptime T: type) ?PbDesc { comptime { if (!@hasDecl(T, "pb_desc")) return null; @@ -691,29 +649,30 @@ fn getPbDesc(comptime T: type) ?PbDesc { } fn createFieldDesc(comptime desc: anytype, comptime field_name: []const u8) FieldDescriptor { - if (!std.meta.trait.isTuple(@TypeOf(desc))) { - @compileError("Bad descriptor format for field '" ++ field_name ++ "'"); + switch (@typeInfo(@TypeOf(desc))) { + .@"struct" => |info| if (info.is_tuple) return .{ + .field_num = desc[0], + .encoding = createFieldEncoding(desc[1], field_name), + }, + else => {}, } - return .{ - .field_num = desc[0], - .encoding = createFieldEncoding(desc[1], field_name), - }; + @compileError("Bad descriptor format for field '" ++ field_name ++ "'"); } fn createFieldEncoding(comptime enc: anytype, comptime field_name: []const u8) FieldEncoding { if (@TypeOf(enc) == FieldEncoding) { return enc; - } else if (@TypeOf(enc) == @Type(.EnumLiteral)) { + } else if (@TypeOf(enc) == @Type(.enum_literal)) { // try to match with an encoding type for (std.meta.fields(FieldEncoding)) |field| { if (std.mem.eql(u8, @tagName(enc), field.name)) { return @field(FieldEncoding, field.name); } } - } else if (@typeInfo(@TypeOf(enc)) == .Struct) { + } else if (@typeInfo(@TypeOf(enc)) == .@"struct") { // nested encoding types - const fields = @typeInfo(@TypeOf(enc)).Struct.fields; + const fields = @typeInfo(@TypeOf(enc)).@"struct".fields; if (fields.len == 1) { const tag = fields[0].name; const val = @field(enc, tag); @@ -724,10 +683,13 @@ fn createFieldEncoding(comptime enc: anytype, comptime field_name: []const u8) F const child = createFieldEncoding(val, field_name); return .{ .repeat_pack = &child }; } else if (std.mem.eql(u8, tag, "map")) { - if (std.meta.trait.isTuple(@TypeOf(val)) and val.len == 2) { - const child0 = createFieldEncoding(val[0], field_name); - const child1 = createFieldEncoding(val[1], field_name); - return .{ .map = &[2]FieldEncoding{ child0, child1 } }; + switch (@typeInfo(@TypeOf(val))) { + .@"struct" => |info| if (info.is_tuple and val.len == 2) { + const child0 = createFieldEncoding(val[0], field_name); + const child1 = createFieldEncoding(val[1], field_name); + return .{ .map = &[2]FieldEncoding{ child0, child1 } }; + }, + else => {}, } } } diff --git a/test.zig b/src/test.zig similarity index 56% rename from test.zig rename to src/test.zig index bc11205..a0abf6b 100644 --- a/test.zig +++ b/src/test.zig @@ -1,13 +1,31 @@ const std = @import("std"); -const pb = @import("protobuf.zig"); +const pb = @import("root.zig"); + +fn unhex(comptime hex: []const u8) *const [hex.len / 2]u8 { + if (hex.len & 1 != 0) @compileError("invalid hex"); + + return comptime blk: { + var out: [hex.len / 2]u8 = undefined; + var i: usize = 0; + + while (i < hex.len) : (i += 2) { + const hi = std.fmt.charToDigit(hex[i], 16) catch @compileError("invalid hex"); + const lo = std.fmt.charToDigit(hex[i + 1], 16) catch @compileError("invalid hex"); + out[i / 2] = (hi << 4) | lo; + } + + const final = out; + break :blk &final; + }; +} fn expectEqualMessages(comptime T: type, expected: T, actual: T) !void { - if (@typeInfo(T) == .Optional) { + if (@typeInfo(T) == .optional) { try std.testing.expectEqual(expected == null, actual == null); return expectEqualMessages(std.meta.Child(T), expected.?, actual.?); } - if (@typeInfo(T) == .Union) { + if (@typeInfo(T) == .@"union") { try std.testing.expectEqual(std.meta.activeTag(expected), std.meta.activeTag(actual)); switch (expected) { inline else => |val, tag| { @@ -16,7 +34,7 @@ fn expectEqualMessages(comptime T: type, expected: T, actual: T) !void { } } - if (@typeInfo(T) == .Struct) { + if (@typeInfo(T) == .@"struct") { if (@hasDecl(T, "pb_desc")) { inline for (comptime std.meta.fields(T)) |field| { try expectEqualMessages(field.type, @field(expected, field.name), @field(actual, field.name)); @@ -41,23 +59,23 @@ fn expectEqualMessages(comptime T: type, expected: T, actual: T) !void { } switch (@typeInfo(T)) { - .Int, .Float, .Enum => try std.testing.expectEqual(expected, actual), + .int, .float, .@"enum" => try std.testing.expectEqual(expected, actual), else => @compileError("Cannot test equality of type '" ++ @typeName(T) ++ "'"), } } fn initMessage(comptime T: type, comptime val: anytype, arena: std.mem.Allocator) !T { - if (@typeInfo(T) == .Optional) { - if (@typeInfo(@TypeOf(val)) == .Optional) { + if (@typeInfo(T) == .optional) { + if (@typeInfo(@TypeOf(val)) == .optional) { return if (val) |x| try initMessage(std.meta.Child(T), x, arena) else null; } else { return try initMessage(std.meta.Child(T), val, arena); } } - if (@typeInfo(T) == .Union) { - if (@typeInfo(@TypeOf(val)) != .Struct) @compileError("Expected struct literal to initialize union"); - const fields = @typeInfo(@TypeOf(val)).Struct.fields; + if (@typeInfo(T) == .@"union") { + if (@typeInfo(@TypeOf(val)) != .@"struct") @compileError("Expected struct literal to initialize union"); + const fields = @typeInfo(@TypeOf(val)).@"struct".fields; if (fields.len != 1) @compileError("Expected single-element struct to initialize union"); return @unionInit(T, fields[0].name, try initMessage( std.meta.TagPayload(T, @field(std.meta.Tag(T), fields[0].name)), @@ -66,10 +84,14 @@ fn initMessage(comptime T: type, comptime val: anytype, arena: std.mem.Allocator )); } - if (@typeInfo(T) == .Struct) { + if (@typeInfo(T) == .@"struct") { if (@hasDecl(T, "pb_desc")) { var result: T = undefined; inline for (comptime std.meta.fields(T)) |field| { + if (field.default_value) |ptr| { + @field(result, field.name) = @as(*const field.type, @alignCast(@ptrCast(ptr))).*; + continue; + } @field(result, field.name) = try initMessage(field.type, @field(val, field.name), arena); } return result; @@ -105,13 +127,30 @@ fn testEncodeDecode(comptime Msg: type, comptime val: anytype) !void { var buf = std.ArrayList(u8).init(std.testing.allocator); defer buf.deinit(); - try pb.encodeMessage(buf.writer(), std.testing.allocator, msg); + try pb.encode(std.testing.allocator, buf.writer(), msg); + + var fbs = std.io.fixedBufferStream(buf.items); + const decoded = try pb.decode(Msg, arena.allocator(), fbs.reader()); + + try expectEqualMessages(Msg, msg, decoded); +} + +fn testEncodeDecodeHex(comptime Msg: type, comptime val: anytype, comptime hex: []const u8) !void { + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + defer arena.deinit(); + + const msg = try initMessage(Msg, val, arena.allocator()); + + var buf = std.ArrayList(u8).init(std.testing.allocator); + defer buf.deinit(); + + try pb.encode(std.testing.allocator, buf.writer(), msg); + try std.testing.expectEqualSlices(u8, unhex(hex), buf.items); var fbs = std.io.fixedBufferStream(buf.items); - const decoded = try pb.decodeMessage(Msg, fbs.reader(), std.testing.allocator); - defer decoded.deinit(); + const decoded = try pb.decode(Msg, arena.allocator(), fbs.reader()); - try expectEqualMessages(Msg, msg, decoded.msg); + try expectEqualMessages(Msg, msg, decoded); } test { @@ -120,7 +159,7 @@ test { single2: u32, opt: ?u64, rep: std.ArrayListUnmanaged(i32), - map: pb.Map([]const u8, f32), + map: std.StringHashMapUnmanaged(f32), options: ?union(enum) { foo: u32, bar: []const u8, @@ -172,3 +211,75 @@ test { .en = .val2, }); } + +test "repeated message" { + const SubMsg = struct { + num: u32, + + pub const pb_desc = .{ + .num = .{ 1, .varint }, + }; + }; + + const Msg = struct { + list: std.ArrayListUnmanaged(SubMsg), + + pub const pb_desc = .{ + .list = .{ 1, .{ .repeat = .default } }, + }; + }; + + try testEncodeDecode(Msg, .{ + .list = .{ .{ .num = 69 }, .{ .num = 0 }, .{ .num = 42 } }, + }); +} + +test "empty sub message" { + const SubMsg = struct { + num: u32, + + pub const pb_desc = .{ + .num = .{ 1, .varint }, + }; + }; + + const Msg = struct { + sub: SubMsg, + + pub const pb_desc = .{ + .sub = .{ 1, .default }, + }; + }; + + try testEncodeDecodeHex(Msg, .{ .sub = .{ .num = 0 } }, "0a00"); +} + +test "packed u32" { + const Msg = struct { + list: std.ArrayListUnmanaged(u32), + + pub const pb_desc = .{ + .list = .{ 1, .{ .repeat_pack = .varint } }, + }; + }; + + try testEncodeDecodeHex(Msg, .{ .list = .{ 1, 2, 3, 4 } }, "0a0401020304"); +} + +test "empty default" { + const Msg = struct { + int: u32 = 0, + str: []const u8 = &[0]u8{}, + list: std.ArrayListUnmanaged(u32) = .empty, + map: std.AutoHashMapUnmanaged(u32, u32) = .empty, + + pub const pb_desc = .{ + .int = .{ 1, .varint }, + .str = .{ 2, .string }, + .list = .{ 3, .{ .repeat_pack = .varint } }, + .map = .{ 4, .{ .map = .{ .varint, .varint } } }, + }; + }; + + try testEncodeDecodeHex(Msg, .{}, ""); +}