From e980fd0867de5022f8761f96f29e848d8656d5ab Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Fri, 5 Jun 2026 14:10:35 -0700 Subject: [PATCH] GH-45946: [C++][Parquet] Variant decoding --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/extension/CMakeLists.txt | 3 +- cpp/src/arrow/extension/meson.build | 5 +- cpp/src/arrow/extension/variant_internal.cc | 1020 ++++++++ cpp/src/arrow/extension/variant_internal.h | 347 +++ .../arrow/extension/variant_internal_test.cc | 2128 +++++++++++++++++ cpp/src/arrow/extension/variant_test_util.h | 137 ++ cpp/src/arrow/meson.build | 1 + 8 files changed, 3640 insertions(+), 2 deletions(-) create mode 100644 cpp/src/arrow/extension/variant_internal.cc create mode 100644 cpp/src/arrow/extension/variant_internal.h create mode 100644 cpp/src/arrow/extension/variant_internal_test.cc create mode 100644 cpp/src/arrow/extension/variant_test_util.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 45cd7e838121..530d3e5ff3b8 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -391,6 +391,7 @@ set(ARROW_SRCS extension/bool8.cc extension/json.cc extension/parquet_variant.cc + extension/variant_internal.cc extension/uuid.cc pretty_print.cc record_batch.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index ae52bc32a998..582825027c74 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc) +set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc + variant_internal_test.cc) if(ARROW_JSON) list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc) diff --git a/cpp/src/arrow/extension/meson.build b/cpp/src/arrow/extension/meson.build index 84dafe4bbe32..6c6d3a7b67a8 100644 --- a/cpp/src/arrow/extension/meson.build +++ b/cpp/src/arrow/extension/meson.build @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc'] +canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc', 'variant_internal_test.cc'] if needs_json canonical_extension_tests += [ @@ -40,5 +40,8 @@ install_headers( 'parquet_variant.h', 'uuid.h', 'variable_shape_tensor.h', + # variant_internal.h: public API for variant binary encoding/decoding. + # "internal" refers to the binary encoding internals, not visibility. + 'variant_internal.h', ], ) diff --git a/cpp/src/arrow/extension/variant_internal.cc b/cpp/src/arrow/extension/variant_internal.cc new file mode 100644 index 000000000000..2ee3fd09ba4a --- /dev/null +++ b/cpp/src/arrow/extension/variant_internal.cc @@ -0,0 +1,1020 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant_internal.h" + +#include + +#include "arrow/util/endian.h" +#include "arrow/util/logging_internal.h" + +namespace arrow::extension::variant_internal { + +namespace { + +// --------------------------------------------------------------------------- +// Helpers for reading little-endian integers of variable size (1-4 bytes) +// --------------------------------------------------------------------------- + +/// \brief Read an unsigned integer of 1-4 bytes in little-endian order. +/// +/// On big-endian platforms, FromLittleEndian byte-swaps the full 32-bit +/// word after memcpy; the mask then discards any bytes beyond num_bytes. +/// +/// \param[in] data Pointer to the bytes (must have at least num_bytes valid) +/// \param[in] num_bytes Number of bytes to read (1, 2, 3, or 4) +/// \return The decoded unsigned integer value +inline uint32_t ReadUnsignedLE(const uint8_t* data, int32_t num_bytes) { + uint32_t result = 0; + std::memcpy(&result, data, num_bytes); + result = bit_util::FromLittleEndian(result); + if (num_bytes < 4) { + result &= (static_cast(1) << (num_bytes * 8)) - 1; + } + return result; +} + +/// \brief Validate that an offset array is monotonically non-decreasing +/// and within the buffer bounds. +Status ValidateOffsets(const std::vector& offsets, int64_t data_length) { + for (size_t i = 1; i < offsets.size(); ++i) { + if (offsets[i] < offsets[i - 1]) { + return Status::Invalid( + "Variant metadata: string offsets are not monotonically " + "non-decreasing at index ", + i); + } + } + if (!offsets.empty() && offsets.back() > static_cast(data_length)) { + return Status::Invalid("Variant metadata: last string offset ", offsets.back(), + " exceeds data length ", data_length); + } + return Status::OK(); +} + +// --------------------------------------------------------------------------- +// Value decoding helpers +// --------------------------------------------------------------------------- + +/// \brief Decode a single variant value at the given offset and invoke +/// the visitor. Returns the number of bytes consumed. +/// +/// This is the core recursive function. +Status DecodeValueAt(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, VariantVisitor* visitor, int64_t* bytes_consumed, + int32_t depth); + +/// \brief Decode a primitive value at data[offset]. +Status DecodePrimitive(const uint8_t* data, int64_t length, int64_t offset, + uint8_t header, VariantVisitor* visitor, int64_t* bytes_consumed) { + auto primitive_type = GetPrimitiveType(header); + int64_t pos = offset + 1; // skip header byte + + auto check_remaining = [&](int64_t needed) -> Status { + if (pos + needed > length) { + return Status::Invalid("Variant value: truncated primitive at offset ", offset, + ", need ", needed, " bytes but only ", length - pos, + " remaining"); + } + return Status::OK(); + }; + + switch (primitive_type) { + case PrimitiveType::kNull: + ARROW_RETURN_NOT_OK(visitor->Null()); + *bytes_consumed = 1; + return Status::OK(); + + case PrimitiveType::kTrue: + ARROW_RETURN_NOT_OK(visitor->Bool(true)); + *bytes_consumed = 1; + return Status::OK(); + + case PrimitiveType::kFalse: + ARROW_RETURN_NOT_OK(visitor->Bool(false)); + *bytes_consumed = 1; + return Status::OK(); + + case PrimitiveType::kInt8: { + ARROW_RETURN_NOT_OK(check_remaining(1)); + auto value = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Int8(value)); + *bytes_consumed = 2; + return Status::OK(); + } + + case PrimitiveType::kInt16: { + ARROW_RETURN_NOT_OK(check_remaining(2)); + int16_t value; + std::memcpy(&value, data + pos, 2); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int16(value)); + *bytes_consumed = 3; + return Status::OK(); + } + + case PrimitiveType::kInt32: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + int32_t value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int32(value)); + *bytes_consumed = 5; + return Status::OK(); + } + + case PrimitiveType::kInt64: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int64(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kFloat: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + float value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Float(value)); + *bytes_consumed = 5; + return Status::OK(); + } + + case PrimitiveType::kDouble: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + double value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Double(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kDecimal4: { + // Spec: 1 byte scale in range [0, 38], followed by 4 bytes LE unscaled value. + // Note: scale is not validated during decode to remain lenient with + // forward-compatible data. The encoder validates scale <= 38. + ARROW_RETURN_NOT_OK(check_remaining(5)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal4(data + pos + 1, scale)); + *bytes_consumed = 6; + return Status::OK(); + } + + case PrimitiveType::kDecimal8: { + // Spec: 1 byte scale, followed by 8 bytes LE unscaled value + ARROW_RETURN_NOT_OK(check_remaining(9)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal8(data + pos + 1, scale)); + *bytes_consumed = 10; + return Status::OK(); + } + + case PrimitiveType::kDecimal16: { + // Spec: 1 byte scale, followed by 16 bytes LE unscaled value + ARROW_RETURN_NOT_OK(check_remaining(17)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal16(data + pos + 1, scale)); + *bytes_consumed = 18; + return Status::OK(); + } + + case PrimitiveType::kDate: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + int32_t value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Date(value)); + *bytes_consumed = 5; + return Status::OK(); + } + + case PrimitiveType::kTimestampMicros: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampMicros(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kTimestampMicrosNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampMicrosNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kBinary: { + // 4-byte length prefix + data + ARROW_RETURN_NOT_OK(check_remaining(4)); + uint32_t bin_length; + std::memcpy(&bin_length, data + pos, 4); + bin_length = bit_util::FromLittleEndian(bin_length); + ARROW_RETURN_NOT_OK(check_remaining(4 + static_cast(bin_length))); + auto view = + std::string_view(reinterpret_cast(data + pos + 4), bin_length); + ARROW_RETURN_NOT_OK(visitor->Binary(view)); + *bytes_consumed = 1 + 4 + static_cast(bin_length); + return Status::OK(); + } + + case PrimitiveType::kString: { + // 4-byte length prefix + data + ARROW_RETURN_NOT_OK(check_remaining(4)); + uint32_t str_length; + std::memcpy(&str_length, data + pos, 4); + str_length = bit_util::FromLittleEndian(str_length); + ARROW_RETURN_NOT_OK(check_remaining(4 + static_cast(str_length))); + auto view = + std::string_view(reinterpret_cast(data + pos + 4), str_length); + ARROW_RETURN_NOT_OK(visitor->String(view)); + *bytes_consumed = 1 + 4 + static_cast(str_length); + return Status::OK(); + } + + case PrimitiveType::kTimeNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimeNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kTimestampNanos: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampNanos(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kTimestampNanosNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampNanosNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kUUID: { + // UUID is 16 bytes in big-endian order + ARROW_RETURN_NOT_OK(check_remaining(16)); + ARROW_RETURN_NOT_OK(visitor->UUID(data + pos)); + *bytes_consumed = 17; + return Status::OK(); + } + + default: + return Status::Invalid("Variant value: unknown primitive type ", + static_cast(primitive_type)); + } +} + +/// \brief Decode a short string (basic_type == 1). The length is encoded +/// in bits 2-7 of the header byte (max 63 bytes). +Status DecodeShortString(const uint8_t* data, int64_t length, int64_t offset, + uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed) { + int32_t str_len = (header >> 2) & 0x3F; + int64_t pos = offset + 1; + if (pos + str_len > length) { + return Status::Invalid("Variant value: truncated short string at offset ", offset, + ", need ", str_len, " bytes but only ", length - pos, + " remaining"); + } + auto view = std::string_view(reinterpret_cast(data + pos), str_len); + ARROW_RETURN_NOT_OK(visitor->String(view)); + *bytes_consumed = 1 + str_len; + return Status::OK(); +} + +/// \brief Decode an object value (basic_type == 2). +/// +/// Object layout per spec: +/// header (1 byte): +/// bits 0-1: basic_type = 2 +/// bits 2-3: field_offset_size_minus_one +/// bits 4-5: field_id_size_minus_one +/// bit 6: is_large (0 → 1-byte num_elements, 1 → 4-byte) +/// num_elements: 1 or 4 bytes (unsigned LE) +/// field_ids: num_elements × field_id_size bytes +/// field_offsets: (num_elements + 1) × field_offset_size bytes +/// field values: concatenated variant values +Status DecodeObject(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed, int32_t depth) { + // Variant Encoding Spec: object value_header layout (bits 2-7 of full byte): + // bits 2-3 (type_info bits 0-1): field_offset_size_minus_one + // bits 4-5 (type_info bits 2-3): field_id_size_minus_one + // bit 6 (type_info bit 4): is_large (0 = 1-byte num_elements, 1 = 4-byte) + // bit 7 (type_info bit 5): unused + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + int32_t field_id_size = ((type_info >> 2) & 0x03) + 1; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + + int64_t pos = offset + 1; // skip header + + // Read num_fields + if (pos + num_fields_size > length) { + return Status::Invalid("Variant value: truncated object num_fields at offset ", + offset); + } + auto num_fields = static_cast(ReadUnsignedLE(data + pos, num_fields_size)); + pos += num_fields_size; + + // Read field IDs + int64_t field_ids_size = static_cast(num_fields) * field_id_size; + if (pos + field_ids_size > length) { + return Status::Invalid("Variant value: truncated object field_ids at offset ", + offset); + } + // TODO: Consider using a stack-allocated small_vector (e.g. arrow::internal:: + // SmallVector) for field_ids and value_offsets to avoid heap allocation for + // the common case of small objects (< 16 fields). Acceptable for a + // correctness-first implementation; optimize if profiling shows pressure. + std::vector field_ids(num_fields); + // NOTE: Per spec, field IDs must be in lexicographic order of corresponding + // key names. We do not validate this ordering here for performance; see + // FindObjectField which relies on this invariant for binary search. + for (int32_t i = 0; i < num_fields; ++i) { + field_ids[i] = ReadUnsignedLE(data + pos, field_id_size); + pos += field_id_size; + } + + // Read value offsets (num_fields + 1 entries) + int64_t offsets_size = (static_cast(num_fields) + 1) * field_offset_size; + if (pos + offsets_size > length) { + return Status::Invalid("Variant value: truncated object offsets at offset ", offset); + } + std::vector value_offsets(num_fields + 1); + for (int32_t i = 0; i <= num_fields; ++i) { + value_offsets[i] = ReadUnsignedLE(data + pos, field_offset_size); + pos += field_offset_size; + } + + // Note: per spec, object field offsets are NOT required to be + // monotonically increasing because field values may be stored + // in a different order than field IDs. + + // The field data starts at pos + int64_t data_start = pos; + int64_t total_data_size = static_cast(value_offsets[num_fields]); + + if (data_start + total_data_size > length) { + return Status::Invalid("Variant value: object data exceeds buffer at offset ", + offset); + } + + // Validate each field offset is within the data region. + // Unlike arrays, object offsets need not be monotonic, but each must + // point within the valid data area. + for (int32_t i = 0; i < num_fields; ++i) { + if (value_offsets[i] > static_cast(total_data_size)) { + return Status::Invalid("Variant value: object field offset ", value_offsets[i], + " at index ", i, " exceeds data size ", total_data_size); + } + } + + ARROW_RETURN_NOT_OK(visitor->StartObject(num_fields)); + + for (int32_t i = 0; i < num_fields; ++i) { + // Resolve field name from metadata dictionary + auto field_id = field_ids[i]; + if (field_id >= metadata.strings.size()) { + return Status::Invalid("Variant value: field_id ", field_id, + " exceeds metadata dictionary size ", + metadata.strings.size()); + } + ARROW_RETURN_NOT_OK(visitor->FieldName(metadata.strings[field_id])); + + // Decode the field value. Pass data_start + total_data_size as the effective + // length to restrict field value decoding within this object's data region. + // NOTE: We do not validate that consumed bytes match the expected field size + // (value_offsets[i+1] - value_offsets[i]) because object offsets are not + // required to be monotonic, making per-field size inference unreliable. + // TODO: Consider optional strict validation for untrusted input. + int64_t field_offset = data_start + value_offsets[i]; + int64_t consumed = 0; + ARROW_RETURN_NOT_OK(DecodeValueAt(metadata, data, data_start + total_data_size, + field_offset, visitor, &consumed, depth)); + } + + ARROW_RETURN_NOT_OK(visitor->EndObject()); + + *bytes_consumed = (data_start - offset) + total_data_size; + return Status::OK(); +} + +/// \brief Decode an array value (basic_type == 3). +/// +/// Array layout per spec: +/// header (1 byte): +/// bits 0-1: basic_type = 3 +/// bits 2-3: field_offset_size_minus_one +/// bit 4: is_large (0 → 1-byte num_elements, 1 → 4-byte) +/// bits 5-7: unused +/// num_elements: 1 or 4 bytes (unsigned LE) +/// field_offsets: (num_elements + 1) × field_offset_size bytes +/// element values: concatenated variant values +Status DecodeArray(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed, int32_t depth) { + // Variant Encoding Spec: array value_header layout (bits 2-7 of full byte): + // bits 2-3 (type_info bits 0-1): field_offset_size_minus_one + // bit 4 (type_info bit 2): is_large (0 = 1-byte num_elements, 1 = 4-byte) + // bits 5-7 (type_info bits 3-5): unused + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t num_elements_size = is_large ? 4 : 1; + + int64_t pos = offset + 1; // skip header + + // Read num_elements + if (pos + num_elements_size > length) { + return Status::Invalid("Variant value: truncated array num_elements at offset ", + offset); + } + auto num_elements = static_cast(ReadUnsignedLE(data + pos, num_elements_size)); + pos += num_elements_size; + + // Read offsets (num_elements + 1 entries) + int64_t offsets_size = (static_cast(num_elements) + 1) * field_offset_size; + if (pos + offsets_size > length) { + return Status::Invalid("Variant value: truncated array offsets at offset ", offset); + } + // TODO: Consider stack-allocated small_vector for the common case of small arrays. + std::vector value_offsets(num_elements + 1); + for (int32_t i = 0; i <= num_elements; ++i) { + value_offsets[i] = ReadUnsignedLE(data + pos, field_offset_size); + pos += field_offset_size; + } + + // Validate value offsets are monotonically non-decreasing + for (int32_t i = 1; i <= num_elements; ++i) { + if (value_offsets[i] < value_offsets[i - 1]) { + return Status::Invalid( + "Variant value: array value offsets are not monotonically " + "non-decreasing at index ", + i); + } + } + + // The element data starts at pos + int64_t data_start = pos; + int64_t total_data_size = static_cast(value_offsets[num_elements]); + + if (data_start + total_data_size > length) { + return Status::Invalid("Variant value: array data exceeds buffer at offset ", offset); + } + + ARROW_RETURN_NOT_OK(visitor->StartArray(num_elements)); + + for (int32_t i = 0; i < num_elements; ++i) { + // Pass data_start + total_data_size as the effective length to restrict + // element value decoding within this array's data region. + // NOTE: consumed bytes are not validated against expected element size + // (value_offsets[i+1] - value_offsets[i]). Monotonicity of offsets is + // already validated above, but we do not check that each element exactly + // fills its allocated slot. TODO: Consider optional strict validation. + int64_t elem_offset = data_start + value_offsets[i]; + int64_t consumed = 0; + ARROW_RETURN_NOT_OK(DecodeValueAt(metadata, data, data_start + total_data_size, + elem_offset, visitor, &consumed, depth)); + } + + ARROW_RETURN_NOT_OK(visitor->EndArray()); + + *bytes_consumed = (data_start - offset) + total_data_size; + return Status::OK(); +} + +Status DecodeValueAt(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, VariantVisitor* visitor, int64_t* bytes_consumed, + int32_t depth) { + if (offset >= length) { + return Status::Invalid("Variant value: offset ", offset, + " is at or beyond buffer length ", length); + } + if (depth > kMaxNestingDepth) { + return Status::Invalid("Variant value: nesting depth exceeds maximum of ", + kMaxNestingDepth); + } + + uint8_t header = data[offset]; + auto basic_type = GetBasicType(header); + + switch (basic_type) { + case BasicType::kPrimitive: + return DecodePrimitive(data, length, offset, header, visitor, bytes_consumed); + case BasicType::kShortString: + return DecodeShortString(data, length, offset, header, visitor, bytes_consumed); + case BasicType::kObject: + return DecodeObject(metadata, data, length, offset, header, visitor, bytes_consumed, + depth + 1); + case BasicType::kArray: + return DecodeArray(metadata, data, length, offset, header, visitor, bytes_consumed, + depth + 1); + default: + return Status::Invalid("Variant value: unknown basic type ", + static_cast(basic_type)); + } +} + +} // namespace + +// --------------------------------------------------------------------------- +// Public API implementations +// --------------------------------------------------------------------------- + +int32_t PrimitiveValueSize(PrimitiveType primitive_type) { + switch (primitive_type) { + case PrimitiveType::kNull: + case PrimitiveType::kTrue: + case PrimitiveType::kFalse: + return 0; + case PrimitiveType::kInt8: + return 1; + case PrimitiveType::kInt16: + return 2; + case PrimitiveType::kInt32: + case PrimitiveType::kFloat: + case PrimitiveType::kDate: + return 4; + case PrimitiveType::kInt64: + case PrimitiveType::kDouble: + case PrimitiveType::kTimestampMicros: + case PrimitiveType::kTimestampMicrosNTZ: + case PrimitiveType::kTimeNTZ: + case PrimitiveType::kTimestampNanos: + case PrimitiveType::kTimestampNanosNTZ: + return 8; + case PrimitiveType::kDecimal4: + return 5; // 1 byte scale + 4 bytes value + case PrimitiveType::kDecimal8: + return 9; // 1 byte scale + 8 bytes value + case PrimitiveType::kDecimal16: + return 17; // 1 byte scale + 16 bytes value + case PrimitiveType::kUUID: + return 16; + case PrimitiveType::kBinary: + case PrimitiveType::kString: + return -1; // variable length + default: + return -1; + } +} + +Result DecodeMetadata(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant metadata: buffer is null or empty"); + } + + // Variant Encoding Spec §2: Metadata encoding + // Header byte: bits 0-3 = version, bit 4 = sorted, bit 5 = reserved, + // bits 6-7 = offset_size-1 + uint8_t header = data[0]; + uint8_t version = header & 0x0F; + if (version != kVariantVersion) { + return Status::Invalid("Variant metadata: unsupported version ", + static_cast(version), ", expected ", + static_cast(kVariantVersion)); + } + + // Bit 5 is reserved and must be zero in version 1 + if ((header >> 5) & 0x01) { + return Status::Invalid("Variant metadata: reserved bit 5 is set in header"); + } + + bool is_sorted = ((header >> 4) & 0x01) != 0; + int32_t offset_size = ((header >> 6) & 0x03) + 1; + + int64_t pos = 1; + + // Read dictionary size + if (pos + offset_size > length) { + return Status::Invalid("Variant metadata: truncated dictionary size at byte ", pos); + } + auto dict_size = static_cast(ReadUnsignedLE(data + pos, offset_size)); + pos += offset_size; + + // Read string offsets: (dict_size + 1) offsets + int64_t offsets_bytes = static_cast(dict_size + 1) * offset_size; + if (pos + offsets_bytes > length) { + return Status::Invalid("Variant metadata: truncated string offsets, need ", + offsets_bytes, " bytes at position ", pos, + " but buffer length is ", length); + } + + std::vector offsets(dict_size + 1); + for (int32_t i = 0; i <= dict_size; ++i) { + offsets[i] = ReadUnsignedLE(data + pos, offset_size); + pos += offset_size; + } + + // Validate offsets + int64_t string_data_length = length - pos; + ARROW_RETURN_NOT_OK(ValidateOffsets(offsets, string_data_length)); + + // Extract string views + std::vector strings(dict_size); + for (int32_t i = 0; i < dict_size; ++i) { + auto start = static_cast(offsets[i]); + auto end = static_cast(offsets[i + 1]); + strings[i] = + std::string_view(reinterpret_cast(data + pos + start), end - start); + } + + VariantMetadata result; + result.version = version; + result.is_sorted = is_sorted; + result.offset_size = offset_size; + result.strings = std::move(strings); + return result; +} + +Status DecodeVariantValue(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, VariantVisitor* visitor) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant value: buffer is null or empty"); + } + DCHECK_NE(visitor, nullptr); + int64_t bytes_consumed = 0; + return DecodeValueAt(metadata, data, length, 0, visitor, &bytes_consumed, /*depth=*/0); +} + +Result GetValueBasicType(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant value: buffer is null or empty"); + } + return GetBasicType(data[0]); +} + +Result GetObjectFieldCount(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant value: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kObject) { + return Status::Invalid("Variant value: expected object but got basic type ", + static_cast(GetBasicType(header))); + } + // type_info bit 4 = is_large (bit 6 of full byte) + uint8_t type_info = (header >> 2) & 0x3F; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + if (1 + num_fields_size > length) { + return Status::Invalid("Variant value: truncated object header"); + } + return static_cast(ReadUnsignedLE(data + 1, num_fields_size)); +} + +Result GetArrayElementCount(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant value: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kArray) { + return Status::Invalid("Variant value: expected array but got basic type ", + static_cast(GetBasicType(header))); + } + // type_info bit 2 = is_large (bit 4 of full byte) + uint8_t type_info = (header >> 2) & 0x3F; + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t num_elements_size = is_large ? 4 : 1; + if (1 + num_elements_size > length) { + return Status::Invalid("Variant value: truncated array header"); + } + return static_cast(ReadUnsignedLE(data + 1, num_elements_size)); +} + +Result ValueSize(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("ValueSize: buffer is null or empty"); + } + + uint8_t header = data[0]; + auto basic_type = GetBasicType(header); + uint8_t type_info = (header >> 2) & 0x3F; + + switch (basic_type) { + case BasicType::kShortString: + return 1 + static_cast(type_info); + + case BasicType::kObject: { + // type_info bit 4 = is_large (bit 6 of full byte) + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t sz_bytes = is_large ? 4 : 1; + if (1 + sz_bytes > length) { + return Status::Invalid("ValueSize: truncated object header"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, sz_bytes)); + int32_t id_size = ((type_info >> 2) & 0x03) + 1; + int32_t offset_size = (type_info & 0x03) + 1; + int64_t id_start = 1 + sz_bytes; + int64_t offset_start = id_start + num_elements * id_size; + int64_t data_start = offset_start + (num_elements + 1) * offset_size; + // Last offset = total data size + int64_t last_offset_pos = offset_start + num_elements * offset_size; + if (last_offset_pos + offset_size > length) { + return Status::Invalid("ValueSize: truncated object offsets"); + } + auto total_data = + static_cast(ReadUnsignedLE(data + last_offset_pos, offset_size)); + return data_start + total_data; + } + + case BasicType::kArray: { + // type_info bit 2 = is_large (bit 4 of full byte) + // Note: Go's valueSize() in arrow-go (prior to fix PR) incorrectly + // used (typeInfo >> 4) for arrays, which reads bit 6 — the object's + // is_large position. The spec places array is_large at bit 4 of the + // full header byte. See: apache/arrow-go#839. + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t sz_bytes = is_large ? 4 : 1; + if (1 + sz_bytes > length) { + return Status::Invalid("ValueSize: truncated array header"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, sz_bytes)); + int32_t offset_size = (type_info & 0x03) + 1; + int64_t offset_start = 1 + sz_bytes; + int64_t data_start = offset_start + (num_elements + 1) * offset_size; + // Last offset = total data size + int64_t last_offset_pos = offset_start + num_elements * offset_size; + if (last_offset_pos + offset_size > length) { + return Status::Invalid("ValueSize: truncated array offsets"); + } + auto total_data = + static_cast(ReadUnsignedLE(data + last_offset_pos, offset_size)); + return data_start + total_data; + } + + case BasicType::kPrimitive: { + auto ptype = static_cast(type_info); + int32_t payload_size = PrimitiveValueSize(ptype); + if (payload_size >= 0) { + return 1 + static_cast(payload_size); + } + // Variable-length: Binary or String (4-byte length prefix) + if (1 + 4 > length) { + return Status::Invalid("ValueSize: truncated variable-length header"); + } + uint32_t var_len; + std::memcpy(&var_len, data + 1, 4); + var_len = bit_util::FromLittleEndian(var_len); + return 1 + 4 + static_cast(var_len); + } + + default: + return Status::Invalid("ValueSize: unknown basic type"); + } +} + +Status FindObjectField(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, std::string_view field_name, int64_t* field_offset, + int64_t* field_size) { + *field_offset = -1; + *field_size = 0; + + if (data == nullptr || length < 1) { + return Status::Invalid("FindObjectField: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kObject) { + return Status::Invalid("FindObjectField: not an object"); + } + + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + int32_t field_id_size = ((type_info >> 2) & 0x03) + 1; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + + if (1 + num_fields_size > length) { + return Status::Invalid("FindObjectField: truncated header"); + } + auto num_fields = static_cast(ReadUnsignedLE(data + 1, num_fields_size)); + + int64_t id_start = 1 + num_fields_size; + int64_t offset_start = id_start + static_cast(num_fields) * field_id_size; + int64_t data_start = + offset_start + (static_cast(num_fields) + 1) * field_offset_size; + + if (data_start > length) { + return Status::Invalid("FindObjectField: truncated object"); + } + + // Per spec, field IDs are in lexicographic order of corresponding keys. + // Use binary search for large objects, linear scan for small ones. + // NOTE: If the input violates this ordering invariant (malformed data), + // binary search may return incorrect results. We do not validate sorting + // here for performance; callers should use DecodeVariantValue() for full + // validation of untrusted input. + constexpr int32_t kBinarySearchThreshold = 32; + + // Note: get_key_at returns an empty string_view for out-of-range field IDs. + // For the binary search path, this could theoretically misorder comparisons, + // but out-of-range IDs indicate a malformed variant. The function will simply + // not find the requested key and return field_offset=-1 (not found), which is + // a safe degradation for corrupted data. + auto get_key_at = [&](int32_t i) -> std::string_view { + auto id = ReadUnsignedLE(data + id_start + i * field_id_size, field_id_size); + if (id < metadata.strings.size()) { + return metadata.strings[id]; + } + return {}; + }; + + auto get_value_offset = [&](int32_t i) -> int64_t { + return data_start + + static_cast(ReadUnsignedLE( + data + offset_start + i * field_offset_size, field_offset_size)); + }; + + int32_t found_index = -1; + + if (num_fields < kBinarySearchThreshold) { + // Linear scan for small objects + for (int32_t i = 0; i < num_fields; ++i) { + if (get_key_at(i) == field_name) { + found_index = i; + break; + } + } + } else { + // Binary search for large objects (keys are in lex order). + // Note: int32_t is used deliberately for lo/hi to avoid unsigned + // underflow when hi = mid - 1 and mid == 0. The Go implementation + // (ObjectValue.ValueByKey) uses uint32 which wraps to MaxUint32. + int32_t lo = 0, hi = num_fields - 1; + while (lo <= hi) { + int32_t mid = lo + (hi - lo) / 2; + auto key = get_key_at(mid); + if (key == field_name) { + found_index = mid; + break; + } else if (key < field_name) { + lo = mid + 1; + } else { + hi = mid - 1; + } + } + } + + if (found_index >= 0) { + *field_offset = get_value_offset(found_index); + ARROW_ASSIGN_OR_RAISE(auto size, + ValueSize(data + *field_offset, length - *field_offset)); + *field_size = size; + } + + return Status::OK(); +} + +Status GetArrayElement(const uint8_t* data, int64_t length, int32_t index, + int64_t* element_offset, int64_t* element_size) { + if (data == nullptr || length < 1) { + return Status::Invalid("GetArrayElement: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kArray) { + return Status::Invalid("GetArrayElement: not an array"); + } + + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t num_elements_size = is_large ? 4 : 1; + + if (1 + num_elements_size > length) { + return Status::Invalid("GetArrayElement: truncated header"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, num_elements_size)); + + if (index < 0 || index >= num_elements) { + return Status::Invalid("GetArrayElement: index ", index, " out of range [0, ", + num_elements, ")"); + } + + int64_t offset_start = 1 + num_elements_size; + int64_t data_start = + offset_start + (static_cast(num_elements) + 1) * field_offset_size; + + auto elem_offset = static_cast( + ReadUnsignedLE(data + offset_start + index * field_offset_size, field_offset_size)); + *element_offset = data_start + elem_offset; + ARROW_ASSIGN_OR_RAISE(auto size, + ValueSize(data + *element_offset, length - *element_offset)); + *element_size = size; + return Status::OK(); +} + +Status GetObjectFieldAt(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, int32_t index, std::string_view* field_name, + int64_t* field_offset, int64_t* field_size) { + if (data == nullptr || length < 1) { + return Status::Invalid("GetObjectFieldAt: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kObject) { + return Status::Invalid("GetObjectFieldAt: not an object"); + } + + uint8_t type_info = (header >> 2) & 0x3F; + int32_t obj_offset_size = (type_info & 0x03) + 1; + int32_t field_id_size = ((type_info >> 2) & 0x03) + 1; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + + if (1 + num_fields_size > length) { + return Status::Invalid("GetObjectFieldAt: truncated header"); + } + auto num_fields = static_cast(ReadUnsignedLE(data + 1, num_fields_size)); + + if (index < 0 || index >= num_fields) { + return Status::Invalid("GetObjectFieldAt: index ", index, " out of range [0, ", + num_fields, ")"); + } + + int64_t id_start = 1 + num_fields_size; + int64_t offset_start = id_start + static_cast(num_fields) * field_id_size; + int64_t data_start = + offset_start + (static_cast(num_fields) + 1) * obj_offset_size; + + // Get field name from dictionary + auto field_id = ReadUnsignedLE(data + id_start + index * field_id_size, field_id_size); + if (field_id >= metadata.strings.size()) { + return Status::Invalid("GetObjectFieldAt: field_id ", field_id, + " exceeds dictionary size ", metadata.strings.size()); + } + *field_name = metadata.strings[field_id]; + + // Get field value offset + auto value_offset = static_cast( + ReadUnsignedLE(data + offset_start + index * obj_offset_size, obj_offset_size)); + *field_offset = data_start + value_offset; + ARROW_ASSIGN_OR_RAISE(auto size, + ValueSize(data + *field_offset, length - *field_offset)); + *field_size = size; + return Status::OK(); +} + +int32_t FindMetadataKey(const VariantMetadata& metadata, std::string_view key) { + if (metadata.is_sorted) { + // Binary search on sorted dictionary + int32_t lo = 0; + int32_t hi = static_cast(metadata.strings.size()) - 1; + while (lo <= hi) { + int32_t mid = lo + (hi - lo) / 2; + int cmp = metadata.strings[mid].compare(key); + if (cmp == 0) { + return mid; + } else if (cmp < 0) { + lo = mid + 1; + } else { + hi = mid - 1; + } + } + return -1; + } + + // Linear scan for unsorted dictionary + for (int32_t i = 0; i < static_cast(metadata.strings.size()); ++i) { + if (metadata.strings[i] == key) { + return i; + } + } + return -1; +} + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_internal.h b/cpp/src/arrow/extension/variant_internal.h new file mode 100644 index 000000000000..6f19e9fe6769 --- /dev/null +++ b/cpp/src/arrow/extension/variant_internal.h @@ -0,0 +1,347 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow::extension::variant_internal { + +/// \file variant_internal.h +/// \brief Utilities for Variant binary encoding/decoding. +/// +/// Implements parsing logic per the Variant Encoding Spec: +/// https://github.com/apache/parquet-format/blob/master/VariantEncoding.md +/// +/// The "internal" in the filename refers to the binary encoding internals +/// of the Variant type, not the visibility of this header. This header is +/// installed and provides the public C++ API for working with Variant +/// binary data (independent of the VariantExtensionType in parquet_variant.h). + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +/// Variant encoding spec version 1. +constexpr uint8_t kVariantVersion = 1; + +/// Maximum nesting depth for recursive value decoding. +/// Prevents stack overflow on deeply nested (possibly malicious) input. +constexpr int32_t kMaxNestingDepth = 128; + +// --------------------------------------------------------------------------- +// Enumerations +// --------------------------------------------------------------------------- + +/// \brief Basic type codes from bits 0-1 of the value header byte. +/// +/// Variant Encoding Spec §3: "Value encoding" +enum class BasicType : uint8_t { + kPrimitive = 0, + kShortString = 1, + kObject = 2, + kArray = 3, +}; + +/// \brief Primitive type codes from bits 2-7 when basic_type == kPrimitive. +/// +/// Variant Encoding Spec §3.1: "Primitive types" +enum class PrimitiveType : uint8_t { + kNull = 0, + kTrue = 1, + kFalse = 2, + kInt8 = 3, + kInt16 = 4, + kInt32 = 5, + kInt64 = 6, + kDouble = 7, + kDecimal4 = 8, + kDecimal8 = 9, + kDecimal16 = 10, + kDate = 11, + kTimestampMicros = 12, + kTimestampMicrosNTZ = 13, + kFloat = 14, + kBinary = 15, + kString = 16, + kTimeNTZ = 17, + kTimestampNanos = 18, + kTimestampNanosNTZ = 19, + kUUID = 20, +}; + +// --------------------------------------------------------------------------- +// Metadata +// --------------------------------------------------------------------------- + +/// \brief Parsed variant metadata (string dictionary). +/// +/// The metadata buffer contains a header byte followed by a dictionary of +/// interned strings. String views reference the raw buffer and are valid +/// only as long as the underlying buffer is alive. +struct ARROW_EXPORT VariantMetadata { + /// Spec version (must be kVariantVersion). + uint8_t version = 0; + + /// Whether the dictionary strings are sorted lexicographically. + bool is_sorted = false; + + /// Number of bytes used for each offset (1, 2, 3, or 4). + int32_t offset_size = 0; + + /// Dictionary of interned strings. Views into the raw metadata buffer. + std::vector strings; +}; + +/// \brief Decode a variant metadata buffer. +/// +/// Parses the header byte and string dictionary from the raw metadata +/// buffer. The returned VariantMetadata contains string_views that +/// reference the input buffer directly (zero-copy). +/// +/// \param[in] data Pointer to the metadata buffer (must not be null) +/// \param[in] length Length of the metadata buffer in bytes +/// \return Parsed VariantMetadata on success, Status::Invalid on +/// malformed input +/// +/// \note The input buffer must outlive the returned VariantMetadata. +ARROW_EXPORT Result DecodeMetadata(const uint8_t* data, int64_t length); + +// --------------------------------------------------------------------------- +// Value header utilities +// --------------------------------------------------------------------------- + +/// \brief Extract the basic type from a value header byte. +/// +/// \param[in] header The first byte of a variant value +/// \return The BasicType (bits 0-1) +inline BasicType GetBasicType(uint8_t header) { + return static_cast(header & 0x03); +} + +/// \brief Extract the primitive type from a value header byte. +/// +/// Only valid when GetBasicType(header) == BasicType::kPrimitive. +/// +/// \param[in] header The first byte of a variant value +/// \return The PrimitiveType (bits 2-7) +inline PrimitiveType GetPrimitiveType(uint8_t header) { + return static_cast((header >> 2) & 0x3F); +} + +/// \brief Get the byte size of a primitive value (excluding header). +/// +/// \param[in] primitive_type The primitive type code +/// \return Number of bytes for the value payload, or -1 for +/// variable-length types (Binary, String) +ARROW_EXPORT int32_t PrimitiveValueSize(PrimitiveType primitive_type); + +// --------------------------------------------------------------------------- +// Value decoding +// --------------------------------------------------------------------------- + +/// \brief Visitor interface for variant value decoding. +/// +/// Implement this interface to receive callbacks during variant value +/// traversal. The visitor pattern avoids materializing a tree of objects, +/// which is important when scanning millions of rows. +/// +/// All methods return Status::OK() to continue traversal, or any error +/// Status to abort. +/// +/// \note String values passed to String() and FieldName() are raw bytes from +/// the variant buffer without UTF-8 validation. Per spec, all strings +/// must be valid UTF-8, but validation is the responsibility of a +/// higher-level consumer (e.g., when materializing to Arrow StringArray). +class ARROW_EXPORT VariantVisitor { + public: + virtual ~VariantVisitor() = default; + + /// @name Primitive value callbacks + /// @{ + virtual Status Null() = 0; + virtual Status Bool(bool value) = 0; + virtual Status Int8(int8_t value) = 0; + virtual Status Int16(int16_t value) = 0; + virtual Status Int32(int32_t value) = 0; + virtual Status Int64(int64_t value) = 0; + virtual Status Float(float value) = 0; + virtual Status Double(double value) = 0; + virtual Status Decimal4(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Decimal8(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Decimal16(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Date(int32_t days_since_epoch) = 0; + virtual Status TimestampMicros(int64_t micros_since_epoch) = 0; + virtual Status TimestampMicrosNTZ(int64_t micros_since_epoch) = 0; + virtual Status String(std::string_view value) = 0; + virtual Status Binary(std::string_view value) = 0; + virtual Status TimeNTZ(int64_t micros_since_midnight) = 0; + virtual Status TimestampNanos(int64_t nanos_since_epoch) = 0; + virtual Status TimestampNanosNTZ(int64_t nanos_since_epoch) = 0; + virtual Status UUID(const uint8_t* bytes) = 0; + /// @} + + /// @name Container callbacks + /// @{ + + /// \brief Called at the start of an object with the number of fields. + virtual Status StartObject(int32_t num_fields) = 0; + + /// \brief Called for each object field name, before the field value. + virtual Status FieldName(std::string_view name) = 0; + + /// \brief Called after all fields of an object have been visited. + virtual Status EndObject() = 0; + + /// \brief Called at the start of an array with the number of elements. + virtual Status StartArray(int32_t num_elements) = 0; + + /// \brief Called after all elements of an array have been visited. + virtual Status EndArray() = 0; + /// @} +}; + +/// \brief Decode a variant value buffer using a visitor. +/// +/// Recursively traverses the variant value, calling the appropriate +/// visitor methods for each element. Objects and arrays trigger +/// Start/End pairs with nested visits for their contents. +/// +/// \param[in] metadata Parsed metadata (for resolving string dictionary) +/// \param[in] data Pointer to the value buffer +/// \param[in] length Length of the value buffer in bytes +/// \param[in] visitor Callback interface for decoded values +/// \return Status::OK on success, Status::Invalid on malformed input +/// +/// \note The data buffer must remain valid for the duration of the call. +ARROW_EXPORT Status DecodeVariantValue(const VariantMetadata& metadata, + const uint8_t* data, int64_t length, + VariantVisitor* visitor); + +/// \brief Get the basic type of a variant value without full decoding. +/// +/// \param[in] data Pointer to the value buffer +/// \param[in] length Length of the value buffer in bytes +/// \return The BasicType of the value, or Status::Invalid if the +/// buffer is empty +ARROW_EXPORT Result GetValueBasicType(const uint8_t* data, int64_t length); + +/// \brief Get the number of fields in a variant object. +/// +/// \param[in] data Pointer to the value buffer (must start with an object) +/// \param[in] length Length of the value buffer in bytes +/// \return The number of fields, or Status::Invalid if not an object +ARROW_EXPORT Result GetObjectFieldCount(const uint8_t* data, int64_t length); + +/// \brief Get the number of elements in a variant array. +/// +/// \param[in] data Pointer to the value buffer (must start with an array) +/// \param[in] length Length of the value buffer in bytes +/// \return The number of elements, or Status::Invalid if not an array +ARROW_EXPORT Result GetArrayElementCount(const uint8_t* data, int64_t length); + +// --------------------------------------------------------------------------- +// Value size computation +// --------------------------------------------------------------------------- + +/// \brief Compute the total byte size of a variant value (header + data). +/// +/// Determines how many bytes a variant value occupies by examining +/// its header and (for containers/variable-length types) reading +/// size information. Does NOT recursively validate the contents. +/// +/// \param[in] data Pointer to the start of a variant value +/// \param[in] length Maximum bytes available +/// \return Total byte count of the value, or Status::Invalid if truncated +ARROW_EXPORT Result ValueSize(const uint8_t* data, int64_t length); + +// --------------------------------------------------------------------------- +// Random access utilities +// --------------------------------------------------------------------------- + +/// \brief Find an object field by name and return the offset/size of its value. +/// +/// Searches the field IDs in the object, resolving each against the +/// metadata dictionary. Per spec, field IDs are in lexicographic order +/// of their corresponding key names, enabling binary search for large +/// objects (>=32 fields). For smaller objects, linear scan is used. +/// +/// \param[in] metadata Parsed metadata (for resolving field IDs to names) +/// \param[in] data Pointer to the object value buffer +/// \param[in] length Length of the value buffer in bytes +/// \param[in] field_name The field name to search for +/// \param[out] field_offset Set to the byte offset of the field's value +/// within data, or -1 if not found +/// \param[out] field_size Set to the byte size of the field's value, +/// or 0 if not found +/// \return Status::OK if search completed (field may or may not exist), +/// Status::Invalid if the buffer is malformed +ARROW_EXPORT Status FindObjectField(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, std::string_view field_name, + int64_t* field_offset, int64_t* field_size); + +/// \brief Get the i-th element of a variant array by index (O(1) access). +/// +/// Uses the offset table for random access without traversing preceding +/// elements. +/// +/// \param[in] data Pointer to the array value buffer +/// \param[in] length Length of the value buffer in bytes +/// \param[in] index Zero-based element index +/// \param[out] element_offset Set to the byte offset of the element within data +/// \param[out] element_size Set to the byte size of the element +/// \return Status::OK on success, Status::Invalid if not an array or +/// index is out of range +ARROW_EXPORT Status GetArrayElement(const uint8_t* data, int64_t length, int32_t index, + int64_t* element_offset, int64_t* element_size); + +/// \brief Get the i-th field of a variant object by position. +/// +/// Returns both the field name (resolved from metadata) and a pointer +/// to the field's value. +/// +/// \param[in] metadata Parsed metadata +/// \param[in] data Pointer to the object value buffer +/// \param[in] length Length of the value buffer in bytes +/// \param[in] index Zero-based field index +/// \param[out] field_name Set to the field's key name +/// \param[out] field_offset Set to the byte offset of the field's value +/// \param[out] field_size Set to the byte size of the field's value +/// \return Status::OK on success, Status::Invalid if not an object or +/// index is out of range +ARROW_EXPORT Status GetObjectFieldAt(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, int32_t index, + std::string_view* field_name, int64_t* field_offset, + int64_t* field_size); + +/// \brief Find the dictionary ID for a given key name. +/// +/// Uses binary search if the metadata is sorted, otherwise linear scan. +/// +/// \param[in] metadata Parsed metadata +/// \param[in] key The key to search for +/// \return The dictionary ID if found, or -1 if not present +ARROW_EXPORT int32_t FindMetadataKey(const VariantMetadata& metadata, + std::string_view key); + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_internal_test.cc b/cpp/src/arrow/extension/variant_internal_test.cc new file mode 100644 index 000000000000..9cfdb665aa82 --- /dev/null +++ b/cpp/src/arrow/extension/variant_internal_test.cc @@ -0,0 +1,2128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant_internal.h" +#include "arrow/extension/variant_test_util.h" + +#include +#include +#include +#include +#include + +#include "arrow/testing/gtest_util.h" + +namespace arrow::extension::variant_internal { + +// =========================================================================== +// Test helpers +// =========================================================================== + +/// \brief Build a metadata buffer from a list of strings. +/// +/// Uses offset_size=1, version=1, sorted flag as specified. +std::vector BuildMetadataBuffer(const std::vector& strings, + bool sorted = false, int32_t offset_size = 1) { + std::vector buffer; + + // Header byte: version=1, sorted flag, offset_size + uint8_t header = kVariantVersion; + if (sorted) { + header |= (1 << 4); + } + header |= static_cast((offset_size - 1) << 6); + buffer.push_back(header); + + // Dictionary size + auto dict_size = static_cast(strings.size()); + for (int32_t b = 0; b < offset_size; ++b) { + buffer.push_back(static_cast((dict_size >> (b * 8)) & 0xFF)); + } + + // Compute string offsets + std::vector offsets(dict_size + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < dict_size; ++i) { + offsets[i + 1] = offsets[i] + static_cast(strings[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= dict_size; ++i) { + for (int32_t b = 0; b < offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write string data + for (const auto& s : strings) { + buffer.insert(buffer.end(), s.begin(), s.end()); + } + + return buffer; +} + +/// \brief Build a primitive value header byte. +uint8_t PrimitiveHeader(PrimitiveType type) { + return static_cast(BasicType::kPrimitive) | (static_cast(type) << 2); +} + +/// \brief Build a short string value buffer. +std::vector BuildShortString(const std::string& s) { + std::vector buffer; + auto len = static_cast(s.size()); + uint8_t header = static_cast(BasicType::kShortString) | (len << 2); + buffer.push_back(header); + buffer.insert(buffer.end(), s.begin(), s.end()); + return buffer; +} + +/// \brief Build an object value buffer. +/// +/// \param field_ids Dictionary indices for each field name +/// \param field_values Serialized variant values for each field +/// \param field_id_size Bytes per field ID (1-4) +/// \param field_offset_size Bytes per offset (1-4) +std::vector BuildObject(const std::vector& field_ids, + const std::vector>& field_values, + int32_t field_id_size = 1, + int32_t field_offset_size = 1) { + auto num_fields = static_cast(field_ids.size()); + bool is_large = (num_fields > 255); + + std::vector buffer; + + // Header per spec: basic_type=2 in bits 0-1, + // bits 2-3: field_offset_size-1 + // bits 4-5: field_id_size-1 + // bit 6: is_large + uint8_t header = static_cast(BasicType::kObject); + header |= static_cast((field_offset_size - 1) << 2); + header |= static_cast((field_id_size - 1) << 4); + if (is_large) { + header |= (1 << 6); + } + buffer.push_back(header); + + // num_fields: 1 byte or 4 bytes depending on is_large + int32_t num_fields_size = is_large ? 4 : 1; + for (int32_t b = 0; b < num_fields_size; ++b) { + buffer.push_back(static_cast((num_fields >> (b * 8)) & 0xFF)); + } + + // field_ids + for (auto fid : field_ids) { + for (int32_t b = 0; b < field_id_size; ++b) { + buffer.push_back(static_cast((fid >> (b * 8)) & 0xFF)); + } + } + + // Compute offsets + std::vector offsets(num_fields + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < num_fields; ++i) { + offsets[i + 1] = offsets[i] + static_cast(field_values[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= num_fields; ++i) { + for (int32_t b = 0; b < field_offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write field value data + for (const auto& fv : field_values) { + buffer.insert(buffer.end(), fv.begin(), fv.end()); + } + + return buffer; +} + +/// \brief Build an array value buffer. +/// +/// \param elements Serialized variant values for each element +/// \param field_offset_size Bytes per offset (1-4) +std::vector BuildArray(const std::vector>& elements, + int32_t field_offset_size = 1) { + auto num_elements = static_cast(elements.size()); + bool is_large = (num_elements > 255); + + std::vector buffer; + + // Header per spec: basic_type=3 in bits 0-1, + // bits 2-3: field_offset_size-1 + // bit 4: is_large + uint8_t header = static_cast(BasicType::kArray); + header |= static_cast((field_offset_size - 1) << 2); + if (is_large) { + header |= (1 << 4); + } + buffer.push_back(header); + + // num_elements: 1 byte or 4 bytes depending on is_large + int32_t num_elements_size = is_large ? 4 : 1; + for (int32_t b = 0; b < num_elements_size; ++b) { + buffer.push_back(static_cast((num_elements >> (b * 8)) & 0xFF)); + } + + // Compute offsets + std::vector offsets(num_elements + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < num_elements; ++i) { + offsets[i + 1] = offsets[i] + static_cast(elements[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= num_elements; ++i) { + for (int32_t b = 0; b < field_offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write element data + for (const auto& elem : elements) { + buffer.insert(buffer.end(), elem.begin(), elem.end()); + } + + return buffer; +} + +// =========================================================================== +// Metadata decoding tests +// =========================================================================== + +class VariantMetadataTest : public ::testing::Test {}; + +TEST_F(VariantMetadataTest, EmptyDictionary) { + auto buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.version, 1); + ASSERT_FALSE(metadata.is_sorted); + ASSERT_EQ(metadata.offset_size, 1); + ASSERT_EQ(metadata.strings.size(), 0); +} + +TEST_F(VariantMetadataTest, SingleString) { + auto buf = BuildMetadataBuffer({"hello"}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 1); + ASSERT_EQ(metadata.strings[0], "hello"); +} + +TEST_F(VariantMetadataTest, MultipleStrings) { + auto buf = BuildMetadataBuffer({"name", "age", "scores"}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], "name"); + ASSERT_EQ(metadata.strings[1], "age"); + ASSERT_EQ(metadata.strings[2], "scores"); +} + +TEST_F(VariantMetadataTest, SortedFlag) { + auto buf = BuildMetadataBuffer({"age", "name", "score"}, true); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_TRUE(metadata.is_sorted); +} + +TEST_F(VariantMetadataTest, OffsetSize2) { + auto buf = BuildMetadataBuffer({"key1", "key2"}, false, 2); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 2); + ASSERT_EQ(metadata.strings.size(), 2); + ASSERT_EQ(metadata.strings[0], "key1"); + ASSERT_EQ(metadata.strings[1], "key2"); +} + +TEST_F(VariantMetadataTest, OffsetSize4) { + auto buf = BuildMetadataBuffer({"a", "bb", "ccc"}, false, 4); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 4); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], "a"); + ASSERT_EQ(metadata.strings[1], "bb"); + ASSERT_EQ(metadata.strings[2], "ccc"); +} + +TEST_F(VariantMetadataTest, EmptyStrings) { + auto buf = BuildMetadataBuffer({"", "nonempty", ""}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], ""); + ASSERT_EQ(metadata.strings[1], "nonempty"); + ASSERT_EQ(metadata.strings[2], ""); +} + +// Error cases + +TEST_F(VariantMetadataTest, NullBuffer) { + ASSERT_RAISES(Invalid, DecodeMetadata(nullptr, 0)); +} + +TEST_F(VariantMetadataTest, EmptyBuffer) { + uint8_t data = 0; + ASSERT_RAISES(Invalid, DecodeMetadata(&data, 0)); +} + +TEST_F(VariantMetadataTest, UnsupportedVersion) { + // Version 2 (unsupported) + uint8_t data[] = {0x02, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, TruncatedDictionarySize) { + // Header says offset_size=2 (bits 6-7 = 01), but only 1 byte follows + uint8_t data[] = {0x41, 0x00}; // version=1, offset_size=2 + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, TruncatedStringOffsets) { + // Claims dict_size=5 but buffer is too short for offsets + uint8_t data[] = {0x01, 0x05, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, OffsetSize3) { + auto buf = BuildMetadataBuffer({"foo", "bar"}, false, 3); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 3); + ASSERT_EQ(metadata.strings.size(), 2); + ASSERT_EQ(metadata.strings[0], "foo"); + ASSERT_EQ(metadata.strings[1], "bar"); +} + +TEST_F(VariantMetadataTest, ReservedBit5Set) { + // Header with bit 5 set: 0x21 = version=1, bit5=1 + uint8_t data[] = {0x21, 0x00, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, NonMonotonicStringOffsets) { + // Manually construct metadata where string offsets are NOT monotonically + // non-decreasing. ValidateOffsets should reject this. + // Header: version=1, offset_size=1 + // dict_size=2, offsets=[0, 5, 3] — 3 < 5, non-monotonic + // String data: "helloabc" (8 bytes, but offsets claim 3 as last) + uint8_t data[] = { + 0x01, // header: version=1, offset_size=1 + 0x02, // dict_size = 2 + 0x00, 0x05, 0x03, // offsets: [0, 5, 3] — non-monotonic + 'h', 'e', 'l', 'l', 'o', 'a', 'b', 'c'}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +// =========================================================================== +// Primitive value decoding tests +// =========================================================================== + +class VariantPrimitiveTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantPrimitiveTest, DecodeNull) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Null"); +} + +TEST_F(VariantPrimitiveTest, DecodeTrue) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kTrue)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Bool(true)"); +} + +TEST_F(VariantPrimitiveTest, DecodeFalse) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kFalse)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Bool(false)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt8) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x2A}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(42)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt8Negative) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0xD6}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(-42)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt16) { + // 300 = 0x012C in little-endian: 0x2C, 0x01 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt16), 0x2C, 0x01}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(300)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt32) { + // 100000 = 0x000186A0 in LE: A0 86 01 00 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0xA0, 0x86, 0x01, 0x00}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(100000)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt32Max) { + int32_t val = std::numeric_limits::max(); + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kInt32); + std::memcpy(data + 1, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt64) { + int64_t val = 1234567890123LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kInt64); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int64(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeFloat) { + float val = 3.14f; + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kFloat); + std::memcpy(data + 1, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + // Float string representation may vary; just check it starts with Float( + ASSERT_TRUE(visitor.events[0].find("Float(") == 0); +} + +TEST_F(VariantPrimitiveTest, DecodeDouble) { + double val = 2.718281828459045; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kDouble); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Double(") == 0); +} + +TEST_F(VariantPrimitiveTest, DecodeDate) { + // Days since epoch: 19000 (approximately 2022-01-01) + int32_t days = 19000; + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kDate); + std::memcpy(data + 1, &days, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Date(19000)"); +} + +TEST_F(VariantPrimitiveTest, DecodeTimestampMicros) { + int64_t micros = 1654041600000000LL; // some timestamp + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampMicros); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampMicros(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeTimestampMicrosNTZ) { + int64_t micros = 1654041600000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampMicrosNTZ); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampMicrosNTZ(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal4) { + // Spec layout: 1 byte scale, then 4 bytes LE unscaled value + uint8_t data[6]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal4); + data[1] = 2; // scale = 2 + int32_t val = 12345; + std::memcpy(data + 2, &val, 4); // unscaled value + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal4(scale=2)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal4MaxScale) { + // Scale at maximum per spec: 38 + uint8_t data[6]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal4); + data[1] = 38; // scale = 38 (maximum per spec) + int32_t val = 12345; + std::memcpy(data + 2, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal4(scale=38)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal8) { + // Spec layout: 1 byte scale, then 8 bytes LE unscaled value + uint8_t data[10]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal8); + data[1] = 5; // scale = 5 + int64_t val = 123456789012345LL; + std::memcpy(data + 2, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal8(scale=5)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal16) { + // Spec layout: 1 byte scale, then 16 bytes LE unscaled value + uint8_t data[18]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal16); + data[1] = 10; // scale = 10 + std::memset(data + 2, 0, 16); + data[2] = 0x01; // low byte = 1 + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal16(scale=10)"); +} + +TEST_F(VariantPrimitiveTest, DecodeLongString) { + // Long string: primitive type kString with 4-byte length prefix + std::string test_str = "hello world, this is a long string"; + auto str_len = static_cast(test_str.size()); + + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + // 4-byte little-endian length + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((str_len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), test_str.begin(), test_str.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hello world, this is a long string\")"); +} + +TEST_F(VariantPrimitiveTest, DecodeBinary) { + std::vector bin_bytes = {0x00, 0x01, 0x02, 0x03}; + auto bin_len = static_cast(bin_bytes.size()); + + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((bin_len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), bin_bytes.begin(), bin_bytes.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "Binary(len=4)"); +} + +// Truncation errors + +TEST_F(VariantPrimitiveTest, TruncatedInt32) { + // Only 2 bytes after header, but Int32 needs 4 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0x00, 0x00}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); +} + +TEST_F(VariantPrimitiveTest, EmptyValueBuffer) { + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeVariantValue(empty_metadata_, nullptr, 0, &visitor)); +} + +// =========================================================================== +// Short string tests +// =========================================================================== + +class VariantShortStringTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantShortStringTest, EmptyShortString) { + auto data = BuildShortString(""); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"\")"); +} + +TEST_F(VariantShortStringTest, SimpleShortString) { + auto data = BuildShortString("hi"); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hi\")"); +} + +TEST_F(VariantShortStringTest, MaxLengthShortString) { + // Maximum short string is 63 bytes + std::string max_str(63, 'x'); + auto data = BuildShortString(max_str); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"" + max_str + "\")"); +} + +TEST_F(VariantShortStringTest, TruncatedShortString) { + // Header says length=10 but buffer only has 3 bytes total + uint8_t data[] = {static_cast(BasicType::kShortString) | (10 << 2), 'a', 'b'}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Object decoding tests +// =========================================================================== + +class VariantObjectTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "age", "scores"}; + } +}; + +TEST_F(VariantObjectTest, EmptyObject) { + auto data = BuildObject({}, {}); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 2); + ASSERT_EQ(visitor.events[0], "StartObject(0)"); + ASSERT_EQ(visitor.events[1], "EndObject"); +} + +TEST_F(VariantObjectTest, SingleField) { + // Object with one field: name -> "Alice" (short string) + auto value = BuildShortString("Alice"); + auto data = BuildObject({0}, {value}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 4); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "EndObject"); +} + +TEST_F(VariantObjectTest, MultipleFields) { + // Object: {name: "Bob", age: 30} + auto name_val = BuildShortString("Bob"); + // age: Int32(30) + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + + auto data = BuildObject({0, 1}, {name_val, age_val}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 6); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Bob\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(30)"); + ASSERT_EQ(visitor.events[5], "EndObject"); +} + +TEST_F(VariantObjectTest, InvalidFieldId) { + // field_id=99 exceeds dictionary size of 3 + auto value = BuildShortString("oops"); + auto data = BuildObject({99}, {value}); + + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeVariantValue(metadata_, data.data(), + static_cast(data.size()), &visitor)); +} + +TEST_F(VariantObjectTest, ThreeByteOffsetSize) { + // Exercises value decoding with 3-byte field_offset_size and field_id_size. + // Object with 2 fields: {name: "test", age: 42} + auto name_val = BuildShortString("test"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto data = BuildObject({0, 1}, {name_val, age_val}, + /*field_id_size=*/3, /*field_offset_size=*/3); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 6); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"test\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(42)"); + ASSERT_EQ(visitor.events[5], "EndObject"); +} + +// =========================================================================== +// Array decoding tests +// =========================================================================== + +class VariantArrayTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantArrayTest, EmptyArray) { + auto data = BuildArray({}); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 2); + ASSERT_EQ(visitor.events[0], "StartArray(0)"); + ASSERT_EQ(visitor.events[1], "EndArray"); +} + +TEST_F(VariantArrayTest, SingleElement) { + std::vector elem = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto data = BuildArray({elem}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 3); + ASSERT_EQ(visitor.events[0], "StartArray(1)"); + ASSERT_EQ(visitor.events[1], "Int32(42)"); + ASSERT_EQ(visitor.events[2], "EndArray"); +} + +TEST_F(VariantArrayTest, HeterogeneousElements) { + // Array with mixed types: [42, "hello", true] + std::vector int_elem = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto str_elem = BuildShortString("hello"); + std::vector bool_elem = {PrimitiveHeader(PrimitiveType::kTrue)}; + + auto data = BuildArray({int_elem, str_elem, bool_elem}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 5); + ASSERT_EQ(visitor.events[0], "StartArray(3)"); + ASSERT_EQ(visitor.events[1], "Int32(42)"); + ASSERT_EQ(visitor.events[2], "String(\"hello\")"); + ASSERT_EQ(visitor.events[3], "Bool(true)"); + ASSERT_EQ(visitor.events[4], "EndArray"); +} + +TEST_F(VariantArrayTest, LargeArrayIsLargeFlag) { + // Build an array with 256 elements to exercise is_large=true (4-byte + // num_elements). Each element is a Null primitive (1 byte each). + // Use field_offset_size=2 since total data (256 bytes) exceeds 1-byte max. + std::vector> elements; + elements.reserve(256); + for (int i = 0; i < 256; ++i) { + elements.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildArray(elements, /*field_offset_size=*/2); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + // StartArray(256) + 256 Nulls + EndArray = 258 events + ASSERT_EQ(visitor.events.size(), 258); + ASSERT_EQ(visitor.events[0], "StartArray(256)"); + ASSERT_EQ(visitor.events[1], "Null"); + ASSERT_EQ(visitor.events[256], "Null"); + ASSERT_EQ(visitor.events[257], "EndArray"); +} + +// =========================================================================== +// Nested structure tests +// =========================================================================== + +class VariantNestedTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "scores", "inner"}; + } +}; + +TEST_F(VariantNestedTest, ObjectWithNestedArray) { + // {name: "Alice", scores: [95, 87]} + auto name_val = BuildShortString("Alice"); + + // scores array: [Int32(95), Int32(87)] + std::vector score1 = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + std::vector score2 = {PrimitiveHeader(PrimitiveType::kInt32), 87, 0, 0, 0}; + auto scores_val = BuildArray({score1, score2}); + + auto data = BuildObject({0, 1}, {name_val, scores_val}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + // Expected events: + // StartObject(2), FieldName("name"), String("Alice"), + // FieldName("scores"), StartArray(2), Int32(95), Int32(87), EndArray, + // EndObject + ASSERT_EQ(visitor.events.size(), 9); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"scores\")"); + ASSERT_EQ(visitor.events[4], "StartArray(2)"); + ASSERT_EQ(visitor.events[5], "Int32(95)"); + ASSERT_EQ(visitor.events[6], "Int32(87)"); + ASSERT_EQ(visitor.events[7], "EndArray"); + ASSERT_EQ(visitor.events[8], "EndObject"); +} + +TEST_F(VariantNestedTest, NestedObjects) { + // {inner: {name: "deep"}} + auto deep_name = BuildShortString("deep"); + auto inner_obj = BuildObject({0}, {deep_name}); + auto data = BuildObject({2}, {inner_obj}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + ASSERT_EQ(visitor.events.size(), 7); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"inner\")"); + ASSERT_EQ(visitor.events[2], "StartObject(1)"); + ASSERT_EQ(visitor.events[3], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[4], "String(\"deep\")"); + ASSERT_EQ(visitor.events[5], "EndObject"); + ASSERT_EQ(visitor.events[6], "EndObject"); +} + +TEST_F(VariantNestedTest, ArrayOfObjects) { + // [{name: "a"}, {name: "b"}] + auto val_a = BuildShortString("a"); + auto obj_a = BuildObject({0}, {val_a}); + + auto val_b = BuildShortString("b"); + auto obj_b = BuildObject({0}, {val_b}); + + auto data = BuildArray({obj_a, obj_b}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + ASSERT_EQ(visitor.events.size(), 10); + ASSERT_EQ(visitor.events[0], "StartArray(2)"); + ASSERT_EQ(visitor.events[1], "StartObject(1)"); + ASSERT_EQ(visitor.events[2], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[3], "String(\"a\")"); + ASSERT_EQ(visitor.events[4], "EndObject"); + ASSERT_EQ(visitor.events[5], "StartObject(1)"); + ASSERT_EQ(visitor.events[6], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[7], "String(\"b\")"); + ASSERT_EQ(visitor.events[8], "EndObject"); + ASSERT_EQ(visitor.events[9], "EndArray"); +} + +// =========================================================================== +// Recursion depth limit test +// =========================================================================== + +class VariantDepthTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"x"}; + } +}; + +TEST_F(VariantDepthTest, ExceedsMaxNestingDepth) { + // Build a deeply nested array: [[[[...]]]] + // Each level wraps the inner in a 1-element array with offset_size=2 + // to allow buffers larger than 255 bytes. + std::vector inner = {PrimitiveHeader(PrimitiveType::kNull)}; + + // Wrap 130 times (exceeds kMaxNestingDepth=128) + for (int i = 0; i < 130; ++i) { + inner = BuildArray({inner}, /*field_offset_size=*/2); + } + + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(metadata_, inner.data(), + static_cast(inner.size()), &visitor)); +} + +TEST_F(VariantDepthTest, AtMaxNestingDepthSucceeds) { + // Build 50 levels of nesting — well within kMaxNestingDepth=128 + // and within offset_size=1 limits (each level adds ~4 bytes). + std::vector inner = {PrimitiveHeader(PrimitiveType::kNull)}; + + for (int i = 0; i < 50; ++i) { + inner = BuildArray({inner}); + } + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, inner.data(), + static_cast(inner.size()), &visitor)); +} + +// =========================================================================== +// Utility function tests +// =========================================================================== + +class VariantUtilTest : public ::testing::Test {}; + +TEST_F(VariantUtilTest, GetValueBasicTypePrimitive) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0, 0, 0, 0}; + ASSERT_OK_AND_ASSIGN(auto bt, GetValueBasicType(data, sizeof(data))); + ASSERT_EQ(bt, BasicType::kPrimitive); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeShortString) { + auto data = BuildShortString("test"); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kShortString); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeObject) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"key"}; + auto val = BuildShortString("val"); + auto data = BuildObject({0}, {val}); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kObject); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeArray) { + auto data = BuildArray({}); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kArray); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeEmptyBuffer) { + ASSERT_RAISES(Invalid, GetValueBasicType(nullptr, 0)); +} + +TEST_F(VariantUtilTest, GetObjectFieldCount) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"a", "b", "c"}; + auto v1 = BuildShortString("x"); + auto v2 = BuildShortString("y"); + auto data = BuildObject({0, 1}, {v1, v2}); + ASSERT_OK_AND_ASSIGN( + auto count, GetObjectFieldCount(data.data(), static_cast(data.size()))); + ASSERT_EQ(count, 2); +} + +TEST_F(VariantUtilTest, GetArrayElementCount) { + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + std::vector e3 = {PrimitiveHeader(PrimitiveType::kFalse)}; + auto data = BuildArray({e1, e2, e3}); + ASSERT_OK_AND_ASSIGN( + auto count, GetArrayElementCount(data.data(), static_cast(data.size()))); + ASSERT_EQ(count, 3); +} + +TEST_F(VariantUtilTest, PrimitiveValueSizes) { + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kNull), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTrue), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kFalse), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt8), 1); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt16), 2); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt32), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt64), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kFloat), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDouble), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDate), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampMicros), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampMicrosNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimeNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampNanos), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampNanosNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kUUID), 16); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal4), 5); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal8), 9); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal16), 17); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kBinary), -1); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kString), -1); +} + +// =========================================================================== +// Integration: Metadata + Value decoding together +// =========================================================================== + +class VariantIntegrationTest : public ::testing::Test {}; + +TEST_F(VariantIntegrationTest, FullRoundTrip) { + // Build a complete variant: {name: "Alice", age: 30, scores: [95, 87]} + auto meta_buf = BuildMetadataBuffer({"name", "age", "scores"}); + + auto name_val = BuildShortString("Alice"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + std::vector s1 = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + std::vector s2 = {PrimitiveHeader(PrimitiveType::kInt32), 87, 0, 0, 0}; + auto scores_val = BuildArray({s1, s2}); + + auto value_buf = BuildObject({0, 1, 2}, {name_val, age_val, scores_val}); + + // Decode metadata + ASSERT_OK_AND_ASSIGN( + auto metadata, + DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + ASSERT_EQ(metadata.strings.size(), 3); + + // Decode value + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_buf.data(), + static_cast(value_buf.size()), &visitor)); + + // Verify full event sequence + ASSERT_EQ(visitor.events.size(), 11); + ASSERT_EQ(visitor.events[0], "StartObject(3)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(30)"); + ASSERT_EQ(visitor.events[5], "FieldName(\"scores\")"); + ASSERT_EQ(visitor.events[6], "StartArray(2)"); + ASSERT_EQ(visitor.events[7], "Int32(95)"); + ASSERT_EQ(visitor.events[8], "Int32(87)"); + ASSERT_EQ(visitor.events[9], "EndArray"); + ASSERT_EQ(visitor.events[10], "EndObject"); +} + +// =========================================================================== +// Visitor early abort test +// =========================================================================== + +/// \brief A visitor that aborts after receiving a specific number of events. +class AbortingVisitor : public VariantVisitor { + public: + int32_t abort_after; + int32_t count = 0; + + explicit AbortingVisitor(int32_t abort_after) : abort_after(abort_after) {} + + Status MaybeAbort() { + ++count; + if (count >= abort_after) { + return Status::Cancelled("Visitor aborted after ", count, " events"); + } + return Status::OK(); + } + + Status Null() override { return MaybeAbort(); } + Status Bool(bool /*value*/) override { return MaybeAbort(); } + Status Int8(int8_t /*value*/) override { return MaybeAbort(); } + Status Int16(int16_t /*value*/) override { return MaybeAbort(); } + Status Int32(int32_t /*value*/) override { return MaybeAbort(); } + Status Int64(int64_t /*value*/) override { return MaybeAbort(); } + Status Float(float /*value*/) override { return MaybeAbort(); } + Status Double(double /*value*/) override { return MaybeAbort(); } + Status Decimal4(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Decimal8(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Decimal16(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Date(int32_t /*days*/) override { return MaybeAbort(); } + Status TimestampMicros(int64_t /*micros*/) override { return MaybeAbort(); } + Status TimestampMicrosNTZ(int64_t /*micros*/) override { return MaybeAbort(); } + Status String(std::string_view /*value*/) override { return MaybeAbort(); } + Status Binary(std::string_view /*value*/) override { return MaybeAbort(); } + Status TimeNTZ(int64_t /*micros*/) override { return MaybeAbort(); } + Status TimestampNanos(int64_t /*nanos*/) override { return MaybeAbort(); } + Status TimestampNanosNTZ(int64_t /*nanos*/) override { return MaybeAbort(); } + Status UUID(const uint8_t* /*bytes*/) override { return MaybeAbort(); } + Status StartObject(int32_t /*num_fields*/) override { return MaybeAbort(); } + Status FieldName(std::string_view /*name*/) override { return MaybeAbort(); } + Status EndObject() override { return MaybeAbort(); } + Status StartArray(int32_t /*num_elements*/) override { return MaybeAbort(); } + Status EndArray() override { return MaybeAbort(); } +}; + +class VariantAbortTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "age"}; + } +}; + +TEST_F(VariantAbortTest, VisitorAbortsEarly) { + // Object: {name: "Alice", age: 30} + auto name_val = BuildShortString("Alice"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto data = BuildObject({0, 1}, {name_val, age_val}); + + // Abort after 3 events (StartObject, FieldName, String) + // Should NOT reach the second field + AbortingVisitor visitor(3); + auto status = DecodeVariantValue(metadata_, data.data(), + static_cast(data.size()), &visitor); + ASSERT_TRUE(status.IsCancelled()); + ASSERT_EQ(visitor.count, 3); +} + +TEST_F(VariantAbortTest, VisitorAbortsOnFirstEvent) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + AbortingVisitor visitor(1); + auto status = DecodeVariantValue(metadata_, data, sizeof(data), &visitor); + ASSERT_TRUE(status.IsCancelled()); +} + +// =========================================================================== +// Spec-conformance test with hardcoded byte sequences +// =========================================================================== + +class VariantSpecTest : public ::testing::Test {}; + +TEST_F(VariantSpecTest, HandcraftedNullValue) { + // Variant Encoding Spec: Null is basic_type=0, primitive_type=0 + // Header byte: 0x00 (bits 0-1=00 for primitive, bits 2-7=000000 for null) + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; // v1, 0 strings, offset[0]=0 + uint8_t value_bytes[] = {0x00}; // null + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Null"); +} + +TEST_F(VariantSpecTest, HandcraftedInt32Value) { + // Int32(42): basic_type=0, primitive_type=5 + // Header: (5 << 2) | 0 = 0x14 + // Value: 42 as LE int32 = 2A 00 00 00 + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[] = {0x14, 0x2A, 0x00, 0x00, 0x00}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(42)"); +} + +TEST_F(VariantSpecTest, HandcraftedShortString) { + // Short string "hello": basic_type=1, length=5 + // Header: (5 << 2) | 1 = 0x15 + // Followed by 5 bytes of UTF-8 "hello" + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[] = {0x15, 'h', 'e', 'l', 'l', 'o'}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hello\")"); +} + +TEST_F(VariantSpecTest, HandcraftedSimpleObject) { + // Object {"a": 1} with metadata dictionary ["a"] + // + // Metadata: version=1, sorted=false, offset_size=1 + // header=0x01, dict_size=0x01, offsets=[0x00, 0x01], data="a" + uint8_t metadata_bytes[] = {0x01, 0x01, 0x00, 0x01, 'a'}; + // + // Value: object with 1 field + // header: basic_type=2, field_id_size=1(bits2-3=00), + // offset_size=1(bits4-5=00), num_fields_size=1(bits6-7=00) + // = 0x02 + // num_fields: 0x01 + // field_ids: [0x00] (index into metadata for "a") + // offsets: [0x00, 0x05] (field 0 at offset 0, total size 5) + // field value: Int32(1) = header 0x14 + LE bytes 01 00 00 00 + uint8_t value_bytes[] = { + 0x02, // object header + 0x01, // num_fields = 1 + 0x00, // field_id[0] = 0 + 0x00, 0x05, // offsets: [0, 5] + 0x14, 0x01, 0x00, 0x00, 0x00 // Int32(1) + }; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events.size(), 4); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"a\")"); + ASSERT_EQ(visitor.events[2], "Int32(1)"); + ASSERT_EQ(visitor.events[3], "EndObject"); +} + +TEST_F(VariantSpecTest, HandcraftedTrueAndFalse) { + // True: basic_type=0, primitive_type=1 → header = (1<<2)|0 = 0x04 + // False: basic_type=0, primitive_type=2 → header = (2<<2)|0 = 0x08 + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + + uint8_t true_bytes[] = {0x04}; + uint8_t false_bytes[] = {0x08}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + + RecordingVisitor v1; + ASSERT_OK(DecodeVariantValue(metadata, true_bytes, sizeof(true_bytes), &v1)); + ASSERT_EQ(v1.events[0], "Bool(true)"); + + RecordingVisitor v2; + ASSERT_OK(DecodeVariantValue(metadata, false_bytes, sizeof(false_bytes), &v2)); + ASSERT_EQ(v2.events[0], "Bool(false)"); +} + +TEST_F(VariantSpecTest, HandcraftedDouble) { + // Double: basic_type=0, primitive_type=7 → header = (7<<2)|0 = 0x1C + // Value: 3.14 as IEEE 754 double LE + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[9]; + value_bytes[0] = 0x1C; + double val = 3.14; + std::memcpy(value_bytes + 1, &val, 8); + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Double(") == 0); +} + +// =========================================================================== +// ValueSize tests +// =========================================================================== + +class VariantValueSizeTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeTest, NullSize) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 1); +} + +TEST_F(VariantValueSizeTest, Int32Size) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0, 0, 0, 0}; + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 5); +} + +TEST_F(VariantValueSizeTest, ShortStringSize) { + auto data = BuildShortString("hello"); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 6); // 1 header + 5 chars +} + +TEST_F(VariantValueSizeTest, ObjectSize) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"key"}; + auto val = BuildShortString("val"); + auto data = BuildObject({0}, {val}); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeTest, ArraySize) { + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + auto data = BuildArray({e1, e2}); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeTest, UUIDSize) { + uint8_t data[17]; + data[0] = PrimitiveHeader(PrimitiveType::kUUID); + std::memset(data + 1, 0, 16); + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 17); +} + +// =========================================================================== +// Random access tests +// =========================================================================== + +class VariantRandomAccessTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // Sorted lexicographically for binary search + metadata_.strings = {"age", "name", "score"}; + } +}; + +TEST_F(VariantRandomAccessTest, FindObjectFieldExists) { + // Object: {age: 30, name: "Alice", score: 95} + // field_ids must be in lex order of keys: age=0, name=1, score=2 + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto name_val = BuildShortString("Alice"); + std::vector score_val = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + auto data = BuildObject({0, 1, 2}, {age_val, name_val, score_val}); + + int64_t offset = -1, size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "name", &offset, &size)); + ASSERT_GT(offset, 0); + ASSERT_EQ(size, 6); // short string "Alice" = 1 + 5 + + // Verify we can decode just that field + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data() + offset, size, &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"Alice\")"); +} + +TEST_F(VariantRandomAccessTest, FindObjectFieldNotFound) { + auto val = BuildShortString("x"); + auto data = BuildObject({0}, {val}); + + int64_t offset = -1, size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "nonexistent", &offset, &size)); + ASSERT_EQ(offset, -1); + ASSERT_EQ(size, 0); +} + +TEST_F(VariantRandomAccessTest, GetArrayElementFirst) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0, e1}); + + int64_t offset = 0, size = 0; + ASSERT_OK( + GetArrayElement(data.data(), static_cast(data.size()), 0, &offset, &size)); + ASSERT_EQ(size, 5); // Int32 = 5 bytes + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data() + offset, size, &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(42)"); +} + +TEST_F(VariantRandomAccessTest, GetArrayElementLast) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0, e1}); + + int64_t offset = 0, size = 0; + ASSERT_OK( + GetArrayElement(data.data(), static_cast(data.size()), 1, &offset, &size)); + ASSERT_EQ(size, 1); // Null = 1 byte +} + +TEST_F(VariantRandomAccessTest, GetArrayElementOutOfRange) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0}); + + int64_t offset = 0, size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + 5, &offset, &size)); +} + +TEST_F(VariantRandomAccessTest, GetObjectFieldAtByIndex) { + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto name_val = BuildShortString("Bob"); + auto data = BuildObject({0, 1}, {age_val, name_val}); + + std::string_view name; + int64_t offset = 0, size = 0; + ASSERT_OK(GetObjectFieldAt(metadata_, data.data(), static_cast(data.size()), 1, + &name, &offset, &size)); + ASSERT_EQ(name, "name"); + ASSERT_EQ(size, 4); // short string "Bob" = 1 + 3 +} + +TEST_F(VariantRandomAccessTest, GetObjectFieldAtOutOfRange) { + auto val = BuildShortString("x"); + auto data = BuildObject({0}, {val}); + + std::string_view name; + int64_t offset = 0, size = 0; + ASSERT_RAISES( + Invalid, GetObjectFieldAt(metadata_, data.data(), static_cast(data.size()), + 99, &name, &offset, &size)); +} + +// =========================================================================== +// FindMetadataKey tests +// =========================================================================== + +class VariantFindMetadataKeyTest : public ::testing::Test {}; + +TEST_F(VariantFindMetadataKeyTest, SortedFound) { + VariantMetadata meta; + meta.is_sorted = true; + meta.strings = {"age", "name", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "name"), 1); + ASSERT_EQ(FindMetadataKey(meta, "age"), 0); + ASSERT_EQ(FindMetadataKey(meta, "score"), 2); +} + +TEST_F(VariantFindMetadataKeyTest, SortedNotFound) { + VariantMetadata meta; + meta.is_sorted = true; + meta.strings = {"age", "name", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "missing"), -1); +} + +TEST_F(VariantFindMetadataKeyTest, UnsortedFound) { + VariantMetadata meta; + meta.is_sorted = false; + meta.strings = {"name", "age", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "age"), 1); +} + +TEST_F(VariantFindMetadataKeyTest, UnsortedNotFound) { + VariantMetadata meta; + meta.is_sorted = false; + meta.strings = {"name", "age"}; + ASSERT_EQ(FindMetadataKey(meta, "missing"), -1); +} + +// =========================================================================== +// ValueSize regression tests (Go bug: array is_large bit position) +// =========================================================================== + +class VariantValueSizeRegressionTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeRegressionTest, LargeArrayIsLargeBit) { + // Build a large array with 300 elements (>255) to trigger is_large=true. + // This verifies the is_large bit is read at bit 2 of type_info (bit 4 of + // full byte), NOT bit 4 of type_info (bit 6 of full byte) which was the + // Go bug (apache/arrow-go#839). + std::vector> elements; + elements.reserve(300); + for (int i = 0; i < 300; ++i) { + elements.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildArray(elements, /*field_offset_size=*/2); + + // Verify the header byte is correctly structured + uint8_t header = data[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kArray); + // is_large should be set at bit 4 of the full header byte + ASSERT_TRUE(((header >> 4) & 0x01) != 0); + + // ValueSize must return the total size of the buffer + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeRegressionTest, SmallArrayIsLargeFalse) { + // Array with 3 elements — is_large=false + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + std::vector e3 = {PrimitiveHeader(PrimitiveType::kFalse)}; + auto data = BuildArray({e1, e2, e3}); + + // Verify is_large is NOT set + uint8_t header = data[0]; + ASSERT_FALSE(((header >> 4) & 0x01) != 0); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeRegressionTest, LargeObjectIsLargeBit) { + // Object with 300 fields to trigger is_large=true (bit 6 of full byte) + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 300; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = + BuildObject(field_ids, values, /*field_id_size=*/2, /*field_offset_size=*/2); + + // Verify is_large is set at bit 6 of the full header byte + uint8_t header = data[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kObject); + ASSERT_TRUE(((header >> 6) & 0x01) != 0); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +// =========================================================================== +// Additional primitive decoding tests +// =========================================================================== + +class VariantPrimitiveExtraTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantPrimitiveExtraTest, DecodeTimeNTZ) { + int64_t micros = 43200000000LL; // 12:00:00 in microseconds + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimeNTZ); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimeNTZ(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeTimestampNanos) { + int64_t nanos = 1654041600000000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampNanos); + std::memcpy(data + 1, &nanos, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampNanos(" + std::to_string(nanos) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeTimestampNanosNTZ) { + int64_t nanos = 1654041600000000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampNanosNTZ); + std::memcpy(data + 1, &nanos, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampNanosNTZ(" + std::to_string(nanos) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeUUID) { + uint8_t data[17]; + data[0] = PrimitiveHeader(PrimitiveType::kUUID); + // Fill UUID with recognizable pattern (big-endian per spec) + for (int i = 0; i < 16; ++i) { + data[1 + i] = static_cast(i + 1); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "UUID"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt8Boundaries) { + // INT8_MIN = -128 + { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x80}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(-128)"); + } + // INT8_MAX = 127 + { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x7F}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(127)"); + } +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt16Boundaries) { + // INT16_MIN = -32768 + { + int16_t val = std::numeric_limits::min(); + uint8_t data[3]; + data[0] = PrimitiveHeader(PrimitiveType::kInt16); + std::memcpy(data + 1, &val, 2); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(-32768)"); + } + // INT16_MAX = 32767 + { + int16_t val = std::numeric_limits::max(); + uint8_t data[3]; + data[0] = PrimitiveHeader(PrimitiveType::kInt16); + std::memcpy(data + 1, &val, 2); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(32767)"); + } +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt64Min) { + int64_t val = std::numeric_limits::min(); + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kInt64); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int64(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeEmptyBinary) { + // Binary with zero length + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + uint32_t len = 0; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "Binary(len=0)"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeEmptyLongString) { + // Long string with zero length + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + uint32_t len = 0; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"\")"); +} + +// =========================================================================== +// Object with non-monotonic offsets (spec-compliant) +// =========================================================================== + +class VariantObjectNonMonotonicTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // Sorted lexicographically + metadata_.strings = {"a", "b", "c"}; + } +}; + +TEST_F(VariantObjectNonMonotonicTest, NonMonotonicObjectOffsets) { + // Per spec: "field IDs and offsets must be listed in the order of the + // corresponding field names, sorted lexicographically" but "the actual + // value entries do not need to be in any particular order" and "the + // field_offset values may not be monotonically increasing." + // + // Construct: {a: 1, b: 2, c: 3} where values are stored as [3, 1, 2] + // in the data area but offsets point to them in key-sorted order. + std::vector val_a = {PrimitiveHeader(PrimitiveType::kInt8), 1}; + std::vector val_b = {PrimitiveHeader(PrimitiveType::kInt8), 2}; + std::vector val_c = {PrimitiveHeader(PrimitiveType::kInt8), 3}; + + // Data area stores: val_c (2 bytes) | val_a (2 bytes) | val_b (2 bytes) + // Offsets: a->2, b->4, c->0, end->6 + uint8_t header = static_cast(BasicType::kObject); // offset_size=1, id_size=1 + std::vector data; + data.push_back(header); + data.push_back(3); // num_fields = 3 + data.push_back(0); // field_id[0] = 0 ("a") + data.push_back(1); // field_id[1] = 1 ("b") + data.push_back(2); // field_id[2] = 2 ("c") + data.push_back(2); // offset[0] = 2 (val_a starts at byte 2) + data.push_back(4); // offset[1] = 4 (val_b starts at byte 4) + data.push_back(0); // offset[2] = 0 (val_c starts at byte 0) + data.push_back(6); // offset[3] = 6 (total data size) + // Data area: val_c, val_a, val_b + data.insert(data.end(), val_c.begin(), val_c.end()); + data.insert(data.end(), val_a.begin(), val_a.end()); + data.insert(data.end(), val_b.begin(), val_b.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + // Field iteration order follows field_ids (sorted by key): a, b, c + ASSERT_EQ(visitor.events.size(), 8); + ASSERT_EQ(visitor.events[0], "StartObject(3)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"a\")"); + ASSERT_EQ(visitor.events[2], "Int8(1)"); + ASSERT_EQ(visitor.events[3], "FieldName(\"b\")"); + ASSERT_EQ(visitor.events[4], "Int8(2)"); + ASSERT_EQ(visitor.events[5], "FieldName(\"c\")"); + ASSERT_EQ(visitor.events[6], "Int8(3)"); + ASSERT_EQ(visitor.events[7], "EndObject"); +} + +TEST_F(VariantObjectNonMonotonicTest, FindFieldWithNonMonotonicOffsets) { + // Same layout as above: values stored out-of-order + uint8_t header = static_cast(BasicType::kObject); + std::vector data; + data.push_back(header); + data.push_back(3); + data.push_back(0); + data.push_back(1); + data.push_back(2); + data.push_back(2); // a -> offset 2 + data.push_back(4); // b -> offset 4 + data.push_back(0); // c -> offset 0 + data.push_back(6); // end = 6 + // Data: [Int8(3), Int8(1), Int8(2)] + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(3); + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(1); + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(2); + + // FindObjectField should find "c" at offset 0 of data area + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "c", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); + ASSERT_EQ(field_size, 2); // Int8 = 2 bytes + + // Decode the value at that offset and verify it's 3 (val_c) + RecordingVisitor v; + ASSERT_OK(DecodeVariantValue(metadata_, data.data() + field_offset, field_size, &v)); + ASSERT_EQ(v.events[0], "Int8(3)"); +} + +// =========================================================================== +// ValueSize for variable-length primitives +// =========================================================================== + +class VariantValueSizeVarLenTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeVarLenTest, LongStringSize) { + // Long string "hello" (5 chars): header(1) + length(4) + data(5) = 10 + std::string s = "hello"; + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + auto len = static_cast(s.size()); + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), s.begin(), s.end()); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 10); +} + +TEST_F(VariantValueSizeVarLenTest, BinarySize) { + // Binary with 4 bytes: header(1) + length(4) + data(4) = 9 + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + uint32_t len = 4; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + data.push_back(0x00); + data.push_back(0x01); + data.push_back(0x02); + data.push_back(0x03); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 9); +} + +TEST_F(VariantValueSizeVarLenTest, TruncatedLongString) { + // Only header byte, no length field + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kString)}; + ASSERT_RAISES(Invalid, ValueSize(data, sizeof(data))); +} + +// =========================================================================== +// Unknown/invalid type tests +// =========================================================================== + +class VariantUnknownTypeTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantUnknownTypeTest, UnknownPrimitiveType) { + // Primitive type ID 25 (beyond kUUID=20) should produce an error. + // Header: (25 << 2) | 0 = 0x64 + uint8_t data[] = {0x64}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); +} + +TEST_F(VariantUnknownTypeTest, UnknownPrimitiveTypeValueSize) { + // ValueSize on an unknown primitive type should still return a value + // (PrimitiveValueSize returns -1, triggering variable-length path). + // With only 1 byte, variable-length path requires 5 bytes → truncated. + uint8_t data[] = {0x64}; + ASSERT_RAISES(Invalid, ValueSize(data, sizeof(data))); +} + +// =========================================================================== +// Array non-monotonic offset rejection test +// =========================================================================== + +class VariantArrayNonMonotonicTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantArrayNonMonotonicTest, RejectsNonMonotonicOffsets) { + // Manually craft an array with 2 elements where offsets go [0, 3, 1] + // (non-monotonic: 1 < 3). This should be rejected. + // header: basic_type=3, offset_size=1, is_large=false → 0x03 + // num_elements: 2 + // offsets: [0, 3, 1] — non-monotonic + // data: 3 bytes of nulls + uint8_t data[] = { + 0x03, // array header: basic_type=3, offset_size=1, is_large=false + 0x02, // num_elements = 2 + 0x00, 0x03, 0x01, // offsets: [0, 3, 1] — non-monotonic! + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + }; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Object field offset out-of-bounds test +// =========================================================================== + +class VariantObjectOffsetBoundsTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"key"}; + } +}; + +TEST_F(VariantObjectOffsetBoundsTest, FieldOffsetExceedsDataSize) { + // Object with 1 field where field_offset[0] = 99 (beyond total_data_size). + // header: basic_type=2, offset_size=1, id_size=1, is_large=false → 0x02 + // num_fields: 1 + // field_ids: [0] + // offsets: [99, 2] — field 0 at offset 99, total=2 + // data: 2 bytes (Null) + uint8_t data[] = { + 0x02, // object header + 0x01, // num_fields = 1 + 0x00, // field_id[0] = 0 + 0x63, 0x02, // offsets: [99, 2] — 99 > total_data_size(2) + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + }; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Empty metadata with various offset sizes +// =========================================================================== + +class VariantMetadataOffsetSizeTest : public ::testing::Test {}; + +TEST_F(VariantMetadataOffsetSizeTest, EmptyDictionaryOffsetSize4) { + // Valid metadata with 0 strings but offset_size=4. + auto buf = BuildMetadataBuffer({}, false, 4); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.version, 1); + ASSERT_EQ(metadata.offset_size, 4); + ASSERT_EQ(metadata.strings.size(), 0); +} + +// =========================================================================== +// FindObjectField with binary search (large object >= 32 fields) +// =========================================================================== + +class VariantFindFieldBinarySearchTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + // Backing storage for string_views in metadata (must outlive metadata_). + // Do NOT modify key_storage_ after SetUp(); reallocation invalidates + // the string_views stored in metadata_.strings. + std::vector key_storage_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // 40 keys in sorted order to trigger binary search path + key_storage_.reserve(40); + for (int i = 0; i < 40; ++i) { + std::string key = "k" + std::string(i < 10 ? "0" : "") + std::to_string(i); + key_storage_.emplace_back(key); + } + for (const auto& k : key_storage_) { + metadata_.strings.push_back(k); + } + } +}; + +TEST_F(VariantFindFieldBinarySearchTest, FindMiddleField) { + // Build object with 40 fields, all null values + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + // Search for "k20" (middle of the sorted range) + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k20", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); + ASSERT_EQ(field_size, 1); // Null = 1 byte +} + +TEST_F(VariantFindFieldBinarySearchTest, FindFirstField) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k00", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); +} + +TEST_F(VariantFindFieldBinarySearchTest, FindLastField) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k39", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); +} + +TEST_F(VariantFindFieldBinarySearchTest, NotFoundInLargeObject) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "zzz", &field_offset, &field_size)); + ASSERT_EQ(field_offset, -1); +} + +// =========================================================================== +// GetArrayElement middle index +// =========================================================================== + +class VariantGetArrayElementExtraTest : public ::testing::Test {}; + +TEST_F(VariantGetArrayElementExtraTest, MiddleElement) { + // Array of [Int32(10), Int32(20), Int32(30)] + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 10, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kInt32), 20, 0, 0, 0}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto data = BuildArray({e0, e1, e2}); + + int64_t elem_offset = 0, elem_size = 0; + ASSERT_OK(GetArrayElement(data.data(), static_cast(data.size()), 1, + &elem_offset, &elem_size)); + ASSERT_EQ(elem_size, 5); // Int32 = 5 bytes + + // Decode the middle element + VariantMetadata meta; + meta.version = 1; + RecordingVisitor v; + ASSERT_OK(DecodeVariantValue(meta, data.data() + elem_offset, elem_size, &v)); + ASSERT_EQ(v.events[0], "Int32(20)"); +} + +TEST_F(VariantGetArrayElementExtraTest, EmptyArrayOutOfRange) { + auto data = BuildArray({}); + int64_t elem_offset = 0, elem_size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + 0, &elem_offset, &elem_size)); +} + +// =========================================================================== +// Additional error case tests (missing coverage) +// =========================================================================== + +class VariantErrorCaseTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantErrorCaseTest, MetadataVersionZero) { + // Version 0 is not supported (only version 1 is valid per spec) + uint8_t data[] = {0x00, 0x00, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetObjectFieldCountOnArray) { + // Calling GetObjectFieldCount on an array value should produce an error + auto data = BuildArray({}); + ASSERT_RAISES(Invalid, + GetObjectFieldCount(data.data(), static_cast(data.size()))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementCountOnObject) { + // Calling GetArrayElementCount on an object value should produce an error + auto data = BuildObject({}, {}); + ASSERT_RAISES(Invalid, + GetArrayElementCount(data.data(), static_cast(data.size()))); +} + +TEST_F(VariantErrorCaseTest, GetObjectFieldCountOnPrimitive) { + // Calling GetObjectFieldCount on a primitive should produce an error + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_RAISES(Invalid, GetObjectFieldCount(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementCountOnPrimitive) { + // Calling GetArrayElementCount on a primitive should produce an error + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_RAISES(Invalid, GetArrayElementCount(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, MetadataStringOffsetExceedsBuffer) { + // Metadata where the last string offset claims more data than the buffer + // contains. This exercises the ValidateOffsets check for offsets.back() > + // data_length. + // Header: version=1, offset_size=1 + // dict_size=1, offsets=[0, 100] — but only 3 bytes of string data + uint8_t data[] = { + 0x01, // header: version=1, offset_size=1 + 0x01, // dict_size = 1 + 0x00, 0x64, // offsets: [0, 100] — 100 exceeds available string data + 'a', 'b', 'c'}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementNegativeIndex) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0}); + int64_t elem_offset = 0, elem_size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + -1, &elem_offset, &elem_size)); +} + +TEST_F(VariantErrorCaseTest, FindObjectFieldOnNonObject) { + // Calling FindObjectField on an array should produce an error + auto data = BuildArray({}); + int64_t field_offset = -1, field_size = 0; + ASSERT_RAISES(Invalid, + FindObjectField(empty_metadata_, data.data(), + static_cast(data.size()), "key", + &field_offset, &field_size)); +} + +// TODO: Add fuzz targets for DecodeMetadata and DecodeVariantValue to exercise +// adversarial/malformed input. Fuzz tests in Arrow are typically registered as +// separate executables under cpp/src/arrow/testing/fuzzing/ — see GH-45948. + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_test_util.h b/cpp/src/arrow/extension/variant_test_util.h new file mode 100644 index 000000000000..9e20947697d7 --- /dev/null +++ b/cpp/src/arrow/extension/variant_test_util.h @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +// This file is for tests only and is not installed as a public header. + +#include +#include + +#include "arrow/extension/variant_internal.h" + +namespace arrow::extension::variant_internal { + +/// \brief A visitor that records all callbacks as a vector of strings +/// for easy assertion in tests. +class RecordingVisitor : public VariantVisitor { + public: + std::vector events; + + Status Null() override { + events.push_back("Null"); + return Status::OK(); + } + Status Bool(bool value) override { + events.push_back(std::string("Bool(") + (value ? "true" : "false") + ")"); + return Status::OK(); + } + Status Int8(int8_t value) override { + events.push_back("Int8(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int16(int16_t value) override { + events.push_back("Int16(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int32(int32_t value) override { + events.push_back("Int32(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int64(int64_t value) override { + events.push_back("Int64(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Float(float value) override { + events.push_back("Float(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Double(double value) override { + events.push_back("Double(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Decimal4(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal4(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Decimal8(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal8(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Decimal16(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal16(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Date(int32_t days) override { + events.push_back("Date(" + std::to_string(days) + ")"); + return Status::OK(); + } + Status TimestampMicros(int64_t micros) override { + events.push_back("TimestampMicros(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status TimestampMicrosNTZ(int64_t micros) override { + events.push_back("TimestampMicrosNTZ(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status String(std::string_view value) override { + events.push_back("String(\"" + std::string(value) + "\")"); + return Status::OK(); + } + Status Binary(std::string_view value) override { + events.push_back("Binary(len=" + std::to_string(value.size()) + ")"); + return Status::OK(); + } + Status TimeNTZ(int64_t micros) override { + events.push_back("TimeNTZ(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status TimestampNanos(int64_t nanos) override { + events.push_back("TimestampNanos(" + std::to_string(nanos) + ")"); + return Status::OK(); + } + Status TimestampNanosNTZ(int64_t nanos) override { + events.push_back("TimestampNanosNTZ(" + std::to_string(nanos) + ")"); + return Status::OK(); + } + Status UUID(const uint8_t* /*bytes*/) override { + events.push_back("UUID"); + return Status::OK(); + } + Status StartObject(int32_t num_fields) override { + events.push_back("StartObject(" + std::to_string(num_fields) + ")"); + return Status::OK(); + } + Status FieldName(std::string_view name) override { + events.push_back("FieldName(\"" + std::string(name) + "\")"); + return Status::OK(); + } + Status EndObject() override { + events.push_back("EndObject"); + return Status::OK(); + } + Status StartArray(int32_t num_elements) override { + events.push_back("StartArray(" + std::to_string(num_elements) + ")"); + return Status::OK(); + } + Status EndArray() override { + events.push_back("EndArray"); + return Status::OK(); + } +}; + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/meson.build b/cpp/src/arrow/meson.build index 4b8faebecfd7..dc16985255c2 100644 --- a/cpp/src/arrow/meson.build +++ b/cpp/src/arrow/meson.build @@ -142,6 +142,7 @@ arrow_components = { 'extension/bool8.cc', 'extension/json.cc', 'extension/parquet_variant.cc', + 'extension/variant_internal.cc', 'extension/uuid.cc', 'pretty_print.cc', 'record_batch.cc',