From a8abf89dfaf7056740b5ee80354d0222a69da775 Mon Sep 17 00:00:00 2001 From: Chuck Daniels Date: Thu, 22 Aug 2024 19:28:08 -0400 Subject: [PATCH] Do not encode UNSET values in struct arrays. Fixes #723 --- msgspec/_core.c | 30 +++++++++++++++++++++++++++--- tests/test_common.py | 11 +++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/msgspec/_core.c b/msgspec/_core.c index 0f287814..ff740bfd 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -12702,17 +12702,36 @@ mpack_encode_struct_array( int tagged = tag_value != NULL; PyObject *fields = struct_type->struct_encode_fields; Py_ssize_t nfields = PyTuple_GET_SIZE(fields); - Py_ssize_t len = nfields + tagged; + Py_ssize_t len = nfields + tagged, actual_len = len; if (Py_EnterRecursiveCall(" while serializing an object")) return -1; + Py_ssize_t header_offset = self->output_len; if (mpack_encode_array_header(self, len, "structs") < 0) goto cleanup; if (tagged) { if (mpack_encode(self, tag_value) < 0) goto cleanup; } for (Py_ssize_t i = 0; i < nfields; i++) { PyObject *val = Struct_get_index(obj, i); - if (val == NULL || mpack_encode(self, val) < 0) goto cleanup; + if (val == UNSET) { + actual_len--; + } else if (val == NULL || mpack_encode(self, val) < 0) { + goto cleanup; + } + } + if (MS_UNLIKELY(actual_len != len)) { + /* Fixup the header length after we know how many fields were + * actually written */ + char *header_loc = self->output_buffer_raw + header_offset; + if (len < 16) { + *header_loc = MP_FIXARRAY | actual_len; + } else if (len < (1 << 16)) { + *header_loc++ = MP_ARRAY16; + _msgspec_store16(header_loc, (uint16_t)actual_len); + } else { + *header_loc++ = MP_ARRAY32; + _msgspec_store32(header_loc, (uint32_t)actual_len); + } } status = 0; cleanup: @@ -14100,11 +14119,16 @@ json_encode_struct_array( for (Py_ssize_t i = 0; i < nfields; i++) { PyObject *val = Struct_get_index(obj, i); if (val == NULL) goto cleanup; + if (val == UNSET) continue; if (json_encode(self, val) < 0) goto cleanup; if (ms_write(self, ",", 1) < 0) goto cleanup; } /* Overwrite trailing comma with ] */ - *(self->output_buffer_raw + self->output_len - 1) = ']'; + if (*(self->output_buffer_raw + self->output_len - 1) == ',') { + *(self->output_buffer_raw + self->output_len - 1) = ']'; + } else { + if (ms_write(self, "]", 1) < 0) goto cleanup; + } status = 0; cleanup: Py_LeaveRecursiveCall(); diff --git a/tests/test_common.py b/tests/test_common.py index c5bd92fa..68694883 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -4061,6 +4061,17 @@ class Ex(Struct, omit_defaults=True): sol = proto.encode(y) assert res == sol + def test_unset_encode_struct_array_like(self, proto): + class Ex(Struct, array_like=True): + x: Union[int, UnsetType] = UNSET + y: Union[int, UnsetType] = UNSET + z: int = 0 + + for x, y in [(Ex(), [0]), (Ex(x=1), [1, 0]), (Ex(y=2), [2, 0])]: + res = proto.encode(x) + sol = proto.encode(y) + assert res == sol + class TestOrder: def test_encoder_order_attribute(self, proto):