diff --git a/Cargo.toml b/Cargo.toml index 7e7cae206a3f..b9dee624723f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ members = [ "arrow-schema", "arrow-select", "arrow-string", + "arrow-variant", "parquet", "parquet_derive", "parquet_derive_test", diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml index a65c0c9ca8e6..2f3d7db7d19e 100644 --- a/arrow-array/Cargo.toml +++ b/arrow-array/Cargo.toml @@ -54,6 +54,7 @@ all-features = true [features] ffi = ["arrow-schema/ffi", "arrow-data/ffi"] force_validate = [] +canonical_extension_types = ["arrow-schema/canonical_extension_types"] [dev-dependencies] rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"] } diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index e41a3a1d719a..e64a2826a08f 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -1271,4 +1271,4 @@ mod tests { let expected: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); assert_eq!(array, expected); } -} +} \ No newline at end of file diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index 314c8f7a3515..04aa50d623ec 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -40,9 +40,9 @@ serde = { version = "1.0", default-features = false, features = [ ], optional = true } bitflags = { version = "2.0.0", default-features = false, optional = true } serde_json = { version = "1.0", optional = true } - +base64 = { version = "0.21", optional = true } [features] -canonical_extension_types = ["dep:serde", "dep:serde_json"] +canonical_extension_types = ["dep:serde", "dep:serde_json", "dep:base64"] # Enable ffi support ffi = ["bitflags"] serde = ["dep:serde"] diff --git a/arrow-schema/src/extension/canonical/mod.rs b/arrow-schema/src/extension/canonical/mod.rs index 3d66299ca885..8a79501f218f 100644 --- a/arrow-schema/src/extension/canonical/mod.rs +++ b/arrow-schema/src/extension/canonical/mod.rs @@ -37,6 +37,8 @@ mod 
uuid; pub use uuid::Uuid; mod variable_shape_tensor; pub use variable_shape_tensor::{VariableShapeTensor, VariableShapeTensorMetadata}; +mod variant; +pub use variant::Variant; use crate::{ArrowError, Field}; @@ -77,6 +79,9 @@ pub enum CanonicalExtensionType { /// /// Bool8(Bool8), + + /// The extension type for `Variant`. + Variant(Variant), } impl TryFrom<&Field> for CanonicalExtensionType { @@ -93,6 +98,7 @@ impl TryFrom<&Field> for CanonicalExtensionType { Uuid::NAME => value.try_extension_type::().map(Into::into), Opaque::NAME => value.try_extension_type::().map(Into::into), Bool8::NAME => value.try_extension_type::().map(Into::into), + Variant::NAME => value.try_extension_type::().map(Into::into), _ => Err(ArrowError::InvalidArgumentError(format!("Unsupported canonical extension type: {name}"))), }, // Name missing the expected prefix @@ -140,3 +146,9 @@ impl From for CanonicalExtensionType { CanonicalExtensionType::Bool8(value) } } + +impl From for CanonicalExtensionType { + fn from(value: Variant) -> Self { + CanonicalExtensionType::Variant(value) + } +} diff --git a/arrow-schema/src/extension/canonical/variant.rs b/arrow-schema/src/extension/canonical/variant.rs new file mode 100644 index 000000000000..caf6c96519fd --- /dev/null +++ b/arrow-schema/src/extension/canonical/variant.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Variant +//! +//! + +use base64::engine::Engine as _; +use base64::engine::general_purpose::STANDARD; +use crate::{extension::ExtensionType, ArrowError, DataType}; + +/// The extension type for `Variant`. +/// +/// Extension name: `arrow.variant`. +/// +/// The storage type of this extension is **Struct containing two binary fields**: +/// - metadata: Binary field containing the variant metadata +/// - value: Binary field containing the serialized variant data +/// +/// A Variant is a flexible structure that can store **Primitives, Arrays, or Objects**. +/// +/// Both metadata and value fields are required. +/// +/// +#[derive(Debug, Clone, PartialEq)] +pub struct Variant { + metadata: Vec, // Required binary metadata + value: Vec, // Required binary value +} + +impl Variant { + /// Creates a new `Variant` with metadata and value. + pub fn new(metadata: Vec, value: Vec) -> Self { + Self { metadata, value } + } + + /// Creates a Variant representing an empty structure. + pub fn empty() -> Result { + Err(ArrowError::InvalidArgumentError( + "Variant cannot be empty because metadata and value are required".to_owned(), + )) + } + + /// Returns the metadata as a byte array. + pub fn metadata(&self) -> &[u8] { + &self.metadata + } + + /// Returns the value as an byte array. + pub fn value(&self) -> &[u8] { + &self.value + } + + /// Sets the value of the Variant. 
+ pub fn set_value(mut self, value: Vec) -> Self { + self.value = value; + self + } +} + +impl ExtensionType for Variant { + const NAME: &'static str = "arrow.variant"; + + type Metadata = Vec; + + fn metadata(&self) -> &Self::Metadata { + &self.metadata + } + + fn serialize_metadata(&self) -> Option { + Some(STANDARD.encode(&self.metadata)) // Encode metadata as STANDARD string + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + match metadata { + Some(meta) => STANDARD.decode(meta) + .map_err(|_| ArrowError::InvalidArgumentError("Invalid Variant metadata".to_owned())), + None => Ok(Vec::new()), // Default to empty metadata if None + } + } + + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { + match data_type { + DataType::Struct(fields) => { + if fields.len() != 2 { + return Err(ArrowError::InvalidArgumentError( + "Variant struct must have exactly two fields".to_owned(), + )); + } + + let metadata_field = fields.iter() + .find(|f| f.name() == "metadata") + .ok_or_else(|| ArrowError::InvalidArgumentError( + "Variant struct must have a field named 'metadata'".to_owned(), + ))?; + + let value_field = fields.iter() + .find(|f| f.name() == "value") + .ok_or_else(|| ArrowError::InvalidArgumentError( + "Variant struct must have a field named 'value'".to_owned(), + ))?; + + match (metadata_field.data_type(), value_field.data_type()) { + (DataType::Binary, DataType::Binary) | + (DataType::LargeBinary, DataType::LargeBinary) => Ok(()), + _ => Err(ArrowError::InvalidArgumentError( + "Variant struct fields must both be Binary or LargeBinary".to_owned(), + )), + } + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "Variant data type mismatch, expected Struct, found {data_type}" + ))), + } + } + + fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { + let variant = Self { metadata, value: vec![0] }; + variant.supports_data_type(data_type)?; + Ok(variant) + } +} + +#[cfg(test)] +mod tests { + 
#[cfg(feature = "canonical_extension_types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_NAME_KEY}, + Field, DataType, + }; + + use super::*; + + #[test] + fn variant_metadata_encoding_decoding() { + let metadata = b"variant_metadata".to_vec(); + let encoded = STANDARD.encode(&metadata); + let decoded = Variant::deserialize_metadata(Some(&encoded)).unwrap(); + assert_eq!(metadata, decoded); + } + + #[test] + fn variant_metadata_invalid_decoding() { + let result = Variant::deserialize_metadata(Some("invalid_base64")); + assert!(result.is_err()); + } + + #[test] + fn variant_metadata_none_decoding() { + let decoded = Variant::deserialize_metadata(None).unwrap(); + assert!(decoded.is_empty()); + } + + #[test] + fn variant_supports_valid_data_types() { + // Test with actual binary data + let metadata = vec![0x01, 0x02, 0x03, 0x04, 0x05]; + let value = vec![0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F]; + let variant = Variant::new(metadata.clone(), value.clone()); + + // Test with Binary fields + let struct_type = DataType::Struct(vec![ + Field::new("metadata", DataType::Binary, false), + Field::new("value", DataType::Binary, false) + ].into()); + assert!(variant.supports_data_type(&struct_type).is_ok()); + + // Test with LargeBinary fields + let struct_type = DataType::Struct(vec![ + Field::new("metadata", DataType::LargeBinary, false), + Field::new("value", DataType::LargeBinary, false) + ].into()); + assert!(variant.supports_data_type(&struct_type).is_ok()); + + // Test with invalid type + let result = Variant::try_new(&DataType::Utf8, metadata); + assert!(result.is_err()); + if let Err(ArrowError::InvalidArgumentError(msg)) = result { + assert!(msg.contains("Variant data type mismatch")); + } + } + + #[test] + fn variant_creation_and_access() { + // Test with actual binary data + let metadata = vec![0x01, 0x02, 0x03, 0x04, 0x05]; + let value = vec![0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F]; + let variant = 
Variant::new(metadata.clone(), value.clone()); + assert_eq!(variant.metadata(), &metadata); + assert_eq!(variant.value(), &value); + } + + #[test] + fn variant_field_extension() { + let struct_type = DataType::Struct(vec![ + Field::new("metadata", DataType::Binary, false), + Field::new("value", DataType::Binary, false) + ].into()); + + // Test with actual binary data + let metadata = vec![0x01, 0x02, 0x03, 0x04, 0x05]; + let value = vec![0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F]; + let variant = Variant::new(metadata.clone(), value.clone()); + + let mut field = Field::new("", struct_type, false); + field.try_with_extension_type(variant.clone()).unwrap(); + + assert_eq!( + field.metadata().get(EXTENSION_TYPE_NAME_KEY), + Some(&"arrow.variant".to_owned()) + ); + + #[cfg(feature = "canonical_extension_types")] + { + let recovered = field.try_canonical_extension_type().unwrap(); + if let CanonicalExtensionType::Variant(recovered_variant) = recovered { + assert_eq!(recovered_variant.metadata(), variant.metadata()); + } else { + panic!("Expected Variant type"); + } + } + } +} diff --git a/arrow-variant/Cargo.toml b/arrow-variant/Cargo.toml new file mode 100644 index 000000000000..cc34d9da904a --- /dev/null +++ b/arrow-variant/Cargo.toml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-variant" +version = { workspace = true } +description = "JSON to Arrow Variant conversion utilities" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = ["arrow"] +include = [ + "src/**/*.rs", + "Cargo.toml", +] +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_variant" +path = "src/lib.rs" + +[features] +default = [] + +[dependencies] +arrow-array = { workspace = true, features = ["canonical_extension_types"] } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true, optional = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true, features = ["canonical_extension_types"] } +serde = { version = "1.0", default-features = false } +serde_json = { version = "1.0", default-features = false, features = ["std"] } + +[dev-dependencies] +arrow-cast = { workspace = true } \ No newline at end of file diff --git a/arrow-variant/src/decoder/mod.rs b/arrow-variant/src/decoder/mod.rs new file mode 100644 index 000000000000..d51288a10809 --- /dev/null +++ b/arrow-variant/src/decoder/mod.rs @@ -0,0 +1,981 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder module for converting Variant binary format to JSON values +#[allow(unused_imports)] +use serde_json::{json, Value, Map}; +use std::str; +use crate::error::Error; +use crate::encoder::{VariantBasicType, VariantPrimitiveType}; +#[allow(unused_imports)] +use std::collections::HashMap; + + +/// Decodes a Variant binary value to a JSON value +pub fn decode_value(value: &[u8], keys: &[String]) -> Result { + println!("Decoding value of length: {}", value.len()); + let mut pos = 0; + let result = decode_value_internal(value, &mut pos, keys)?; + println!("Decoded value: {:?}", result); + Ok(result) +} + +/// Extracts the basic type from a header byte +fn get_basic_type(header: u8) -> VariantBasicType { + match header & 0x03 { + 0 => VariantBasicType::Primitive, + 1 => VariantBasicType::ShortString, + 2 => VariantBasicType::Object, + 3 => VariantBasicType::Array, + _ => unreachable!(), + } +} + +/// Extracts the primitive type from a header byte +fn get_primitive_type(header: u8) -> VariantPrimitiveType { + match (header >> 2) & 0x3F { + 0 => VariantPrimitiveType::Null, + 1 => VariantPrimitiveType::BooleanTrue, + 2 => VariantPrimitiveType::BooleanFalse, + 3 => VariantPrimitiveType::Int8, + 4 => VariantPrimitiveType::Int16, + 5 => VariantPrimitiveType::Int32, + 6 => VariantPrimitiveType::Int64, + 7 => VariantPrimitiveType::Double, + 8 => VariantPrimitiveType::Decimal4, + 9 => VariantPrimitiveType::Decimal8, + 10 => VariantPrimitiveType::Decimal16, + 11 => VariantPrimitiveType::Date, + 12 => VariantPrimitiveType::Timestamp, + 13 => 
VariantPrimitiveType::TimestampNTZ, + 14 => VariantPrimitiveType::Float, + 15 => VariantPrimitiveType::Binary, + 16 => VariantPrimitiveType::String, + 17 => VariantPrimitiveType::TimeNTZ, + 18 => VariantPrimitiveType::TimestampNanos, + 19 => VariantPrimitiveType::TimestampNTZNanos, + 20 => VariantPrimitiveType::Uuid, + _ => unreachable!(), + } +} + +/// Extracts object header information +fn get_object_header_info(header: u8) -> (bool, u8, u8) { + let header = (header >> 2) & 0x3F; // Get header bits + let is_large = (header >> 4) & 0x01 != 0; // is_large from bit 4 + let id_size = ((header >> 2) & 0x03) + 1; // field_id_size from bits 2-3 + let offset_size = (header & 0x03) + 1; // offset_size from bits 0-1 + (is_large, id_size, offset_size) +} + +/// Extracts array header information +fn get_array_header_info(header: u8) -> (bool, u8) { + let header = (header >> 2) & 0x3F; // Get header bits + let is_large = (header >> 2) & 0x01 != 0; // is_large from bit 2 + let offset_size = (header & 0x03) + 1; // offset_size from bits 0-1 + (is_large, offset_size) +} + +/// Reads an unsigned integer of the specified size +fn read_unsigned(data: &[u8], pos: &mut usize, size: u8) -> Result { + if *pos + (size as usize - 1) >= data.len() { + return Err(Error::VariantRead(format!("Unexpected end of data for {} byte unsigned integer", size))); + } + + let mut value = 0usize; + for i in 0..size { + value |= (data[*pos + i as usize] as usize) << (8 * i); + } + *pos += size as usize; + + Ok(value) +} + +/// Internal recursive function to decode a value at the current position +fn decode_value_internal(data: &[u8], pos: &mut usize, keys: &[String]) -> Result { + if *pos >= data.len() { + return Err(Error::VariantRead("Unexpected end of data".to_string())); + } + + let header = data[*pos]; + println!("Decoding at position {}: header byte = 0x{:02X}", *pos, header); + *pos += 1; + + match get_basic_type(header) { + VariantBasicType::Primitive => { + match get_primitive_type(header) { + 
VariantPrimitiveType::Null => Ok(Value::Null), + VariantPrimitiveType::BooleanTrue => Ok(Value::Bool(true)), + VariantPrimitiveType::BooleanFalse => Ok(Value::Bool(false)), + VariantPrimitiveType::Int8 => decode_int8(data, pos), + VariantPrimitiveType::Int16 => decode_int16(data, pos), + VariantPrimitiveType::Int32 => decode_int32(data, pos), + VariantPrimitiveType::Int64 => decode_int64(data, pos), + VariantPrimitiveType::Double => decode_double(data, pos), + VariantPrimitiveType::Decimal4 => decode_decimal4(data, pos), + VariantPrimitiveType::Decimal8 => decode_decimal8(data, pos), + VariantPrimitiveType::Decimal16 => decode_decimal16(data, pos), + VariantPrimitiveType::Date => decode_date(data, pos), + VariantPrimitiveType::Timestamp => decode_timestamp(data, pos), + VariantPrimitiveType::TimestampNTZ => decode_timestamp_ntz(data, pos), + VariantPrimitiveType::Float => decode_float(data, pos), + VariantPrimitiveType::Binary => decode_binary(data, pos), + VariantPrimitiveType::String => decode_long_string(data, pos), + VariantPrimitiveType::TimeNTZ => decode_time_ntz(data, pos), + VariantPrimitiveType::TimestampNanos => decode_timestamp_nanos(data, pos), + VariantPrimitiveType::TimestampNTZNanos => decode_timestamp_ntz_nanos(data, pos), + VariantPrimitiveType::Uuid => decode_uuid(data, pos), + } + }, + VariantBasicType::ShortString => { + let len = (header >> 2) & 0x3F; + println!("Short string with length: {}", len); + if *pos + len as usize > data.len() { + return Err(Error::VariantRead("Unexpected end of data for short string".to_string())); + } + + let string_bytes = &data[*pos..*pos + len as usize]; + *pos += len as usize; + + let string = str::from_utf8(string_bytes) + .map_err(|e| Error::InvalidMetadata(format!("Invalid UTF-8 string: {}", e)))?; + + Ok(Value::String(string.to_string())) + }, + VariantBasicType::Object => { + let (is_large, id_size, offset_size) = get_object_header_info(header); + println!("Object header: is_large={}, id_size={}, 
offset_size={}", is_large, id_size, offset_size); + + // Read number of elements + let num_elements = if is_large { + read_unsigned(data, pos, 4)? + } else { + read_unsigned(data, pos, 1)? + }; + println!("Object has {} elements", num_elements); + + // Read field IDs + let mut field_ids = Vec::with_capacity(num_elements); + for _ in 0..num_elements { + field_ids.push(read_unsigned(data, pos, id_size)?); + } + println!("Field IDs: {:?}", field_ids); + + // Read offsets + let mut offsets = Vec::with_capacity(num_elements + 1); + for _ in 0..=num_elements { + offsets.push(read_unsigned(data, pos, offset_size)?); + } + println!("Offsets: {:?}", offsets); + + // Create object and save position after offsets + let mut obj = Map::new(); + let base_pos = *pos; + + // Process each field + for i in 0..num_elements { + let field_id = field_ids[i]; + if field_id >= keys.len() { + return Err(Error::VariantRead(format!("Field ID out of range: {}", field_id))); + } + + let field_name = &keys[field_id]; + let start_offset = offsets[i]; + let end_offset = offsets[i + 1]; + + println!("Field {}: {} (ID: {}), range: {}..{}", i, field_name, field_id, base_pos + start_offset, base_pos + end_offset); + + if base_pos + end_offset > data.len() { + return Err(Error::VariantRead("Unexpected end of data for object field".to_string())); + } + + // Create a slice just for this field and decode it + let field_data = &data[base_pos + start_offset..base_pos + end_offset]; + let mut field_pos = 0; + let value = decode_value_internal(field_data, &mut field_pos, keys)?; + + obj.insert(field_name.clone(), value); + } + + // Update position to end of object data + *pos = base_pos + offsets[num_elements]; + Ok(Value::Object(obj)) + }, + VariantBasicType::Array => { + let (is_large, offset_size) = get_array_header_info(header); + println!("Array header: is_large={}, offset_size={}", is_large, offset_size); + + // Read number of elements + let num_elements = if is_large { + read_unsigned(data, pos, 4)? 
+ } else { + read_unsigned(data, pos, 1)? + }; + println!("Array has {} elements", num_elements); + + // Read offsets + let mut offsets = Vec::with_capacity(num_elements + 1); + for _ in 0..=num_elements { + offsets.push(read_unsigned(data, pos, offset_size)?); + } + println!("Offsets: {:?}", offsets); + + // Create array and save position after offsets + let mut array = Vec::with_capacity(num_elements); + let base_pos = *pos; + + // Process each element + for i in 0..num_elements { + let start_offset = offsets[i]; + let end_offset = offsets[i + 1]; + + println!("Element {}: range: {}..{}", i, base_pos + start_offset, base_pos + end_offset); + + if base_pos + end_offset > data.len() { + return Err(Error::VariantRead("Unexpected end of data for array element".to_string())); + } + + // Create a slice just for this element and decode it + let elem_data = &data[base_pos + start_offset..base_pos + end_offset]; + let mut elem_pos = 0; + let value = decode_value_internal(elem_data, &mut elem_pos, keys)?; + + array.push(value); + } + + // Update position to end of array data + *pos = base_pos + offsets[num_elements]; + Ok(Value::Array(array)) + }, + } +} + +/// Decodes a null value +#[allow(dead_code)] +fn decode_null() -> Result { + Ok(Value::Null) +} + +/// Decodes a primitive value +#[allow(dead_code)] +fn decode_primitive(data: &[u8], pos: &mut usize) -> Result { + if *pos >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for primitive".to_string())); + } + + // Read the primitive type header + let header = data[*pos]; + *pos += 1; + + // Extract primitive type ID + let type_id = header & 0x1F; + + // Decode based on primitive type + match type_id { + 0 => decode_null(), + 1 => Ok(Value::Bool(true)), + 2 => Ok(Value::Bool(false)), + 3 => decode_int8(data, pos), + 4 => decode_int16(data, pos), + 5 => decode_int32(data, pos), + 6 => decode_int64(data, pos), + 7 => decode_double(data, pos), + 8 => decode_decimal4(data, pos), + 9 => 
decode_decimal8(data, pos), + 10 => decode_decimal16(data, pos), + 11 => decode_date(data, pos), + 12 => decode_timestamp(data, pos), + 13 => decode_timestamp_ntz(data, pos), + 14 => decode_float(data, pos), + 15 => decode_binary(data, pos), + 16 => decode_long_string(data, pos), + 17 => decode_time_ntz(data, pos), + 18 => decode_timestamp_nanos(data, pos), + 19 => decode_timestamp_ntz_nanos(data, pos), + 20 => decode_uuid(data, pos), + _ => Err(Error::InvalidMetadata(format!("Unknown primitive type ID: {}", type_id))) + } +} + +/// Decodes a short string value +#[allow(dead_code)] +fn decode_short_string(data: &[u8], pos: &mut usize) -> Result { + if *pos >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for short string length".to_string())); + } + + // Read the string length (1 byte) + let len = data[*pos] as usize; + *pos += 1; + + // Read the string bytes + if *pos + len > data.len() { + return Err(Error::VariantRead("Unexpected end of data for short string content".to_string())); + } + + let string_bytes = &data[*pos..*pos + len]; + *pos += len; + + // Convert to UTF-8 string + let string = str::from_utf8(string_bytes) + .map_err(|e| Error::InvalidMetadata(format!("Invalid UTF-8 string: {}", e)))?; + + Ok(Value::String(string.to_string())) +} + +/// Decodes an int8 value +fn decode_int8(data: &[u8], pos: &mut usize) -> Result { + if *pos >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for int8".to_string())); + } + + let value = data[*pos] as i8 as i64; + *pos += 1; + + Ok(Value::Number(serde_json::Number::from(value))) +} + +/// Decodes an int16 value +fn decode_int16(data: &[u8], pos: &mut usize) -> Result { + if *pos + 1 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for int16".to_string())); + } + + let mut buf = [0u8; 2]; + buf.copy_from_slice(&data[*pos..*pos+2]); + *pos += 2; + + let value = i16::from_le_bytes(buf) as i64; + Ok(Value::Number(serde_json::Number::from(value))) +} + 
+/// Decodes an int32 value +fn decode_int32(data: &[u8], pos: &mut usize) -> Result { + if *pos + 3 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for int32".to_string())); + } + + let mut buf = [0u8; 4]; + buf.copy_from_slice(&data[*pos..*pos+4]); + *pos += 4; + + let value = i32::from_le_bytes(buf) as i64; + Ok(Value::Number(serde_json::Number::from(value))) +} + +/// Decodes an int64 value +fn decode_int64(data: &[u8], pos: &mut usize) -> Result { + if *pos + 7 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for int64".to_string())); + } + + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[*pos..*pos+8]); + *pos += 8; + + let value = i64::from_le_bytes(buf); + Ok(Value::Number(serde_json::Number::from(value))) +} + +/// Decodes a double value +fn decode_double(data: &[u8], pos: &mut usize) -> Result { + if *pos + 7 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for double".to_string())); + } + + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[*pos..*pos+8]); + *pos += 8; + + let value = f64::from_le_bytes(buf); + + // Create a Number from the float + let number = serde_json::Number::from_f64(value) + .ok_or_else(|| Error::InvalidMetadata(format!("Invalid float value: {}", value)))?; + + Ok(Value::Number(number)) +} + +/// Decodes a decimal4 value +fn decode_decimal4(data: &[u8], pos: &mut usize) -> Result { + if *pos + 4 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for decimal4".to_string())); + } + + // Read scale (1 byte) + let scale = data[*pos] as i32; + *pos += 1; + + // Read unscaled value (3 bytes) + let mut buf = [0u8; 4]; + buf[0] = data[*pos]; + buf[1] = data[*pos + 1]; + buf[2] = data[*pos + 2]; + buf[3] = 0; // Sign extend + *pos += 3; + + let unscaled = i32::from_le_bytes(buf); + + // Convert to decimal string + let decimal = format!("{}.{}", unscaled, scale); + + Ok(Value::String(decimal)) +} + +/// Decodes a decimal8 value +fn 
decode_decimal8(data: &[u8], pos: &mut usize) -> Result { + if *pos + 8 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for decimal8".to_string())); + } + + // Read scale (1 byte) + let scale = data[*pos] as i32; + *pos += 1; + + // Read unscaled value (7 bytes) + let mut buf = [0u8; 8]; + buf[0..7].copy_from_slice(&data[*pos..*pos+7]); + buf[7] = 0; // Sign extend + *pos += 7; + + let unscaled = i64::from_le_bytes(buf); + + // Convert to decimal string + let decimal = format!("{}.{}", unscaled, scale); + + Ok(Value::String(decimal)) +} + +/// Decodes a decimal16 value +fn decode_decimal16(data: &[u8], pos: &mut usize) -> Result { + if *pos + 16 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for decimal16".to_string())); + } + + // Read scale (1 byte) + let scale = data[*pos] as i32; + *pos += 1; + + // Read unscaled value (15 bytes) + let mut buf = [0u8; 16]; + buf[0..15].copy_from_slice(&data[*pos..*pos+15]); + buf[15] = 0; // Sign extend + *pos += 15; + + // Convert to decimal string (simplified for now) + let decimal = format!("decimal16.{}", scale); + + Ok(Value::String(decimal)) +} + +/// Decodes a date value +fn decode_date(data: &[u8], pos: &mut usize) -> Result { + if *pos + 3 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for date".to_string())); + } + + let mut buf = [0u8; 4]; + buf.copy_from_slice(&data[*pos..*pos+4]); + *pos += 4; + + let days = i32::from_le_bytes(buf); + + // Convert to ISO date string (simplified) + let date = format!("date-{}", days); + + Ok(Value::String(date)) +} + +/// Decodes a timestamp value +fn decode_timestamp(data: &[u8], pos: &mut usize) -> Result { + if *pos + 7 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for timestamp".to_string())); + } + + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[*pos..*pos+8]); + *pos += 8; + + let micros = i64::from_le_bytes(buf); + + // Convert to ISO timestamp string (simplified) + let 
timestamp = format!("timestamp-{}", micros); + + Ok(Value::String(timestamp)) +} + +/// Decodes a timestamp without timezone value +fn decode_timestamp_ntz(data: &[u8], pos: &mut usize) -> Result { + if *pos + 7 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for timestamp_ntz".to_string())); + } + + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[*pos..*pos+8]); + *pos += 8; + + let micros = i64::from_le_bytes(buf); + + // Convert to ISO timestamp string (simplified) + let timestamp = format!("timestamp_ntz-{}", micros); + + Ok(Value::String(timestamp)) +} + +/// Decodes a float value +fn decode_float(data: &[u8], pos: &mut usize) -> Result { + if *pos + 3 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for float".to_string())); + } + + let mut buf = [0u8; 4]; + buf.copy_from_slice(&data[*pos..*pos+4]); + *pos += 4; + + let value = f32::from_le_bytes(buf); + + // Create a Number from the float + let number = serde_json::Number::from_f64(value as f64) + .ok_or_else(|| Error::InvalidMetadata(format!("Invalid float value: {}", value)))?; + + Ok(Value::Number(number)) +} + +/// Decodes a binary value +fn decode_binary(data: &[u8], pos: &mut usize) -> Result { + if *pos + 3 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for binary length".to_string())); + } + + // Read the binary length (4 bytes) + let mut buf = [0u8; 4]; + buf.copy_from_slice(&data[*pos..*pos+4]); + *pos += 4; + + let len = u32::from_le_bytes(buf) as usize; + + // Read the binary bytes + if *pos + len > data.len() { + return Err(Error::VariantRead("Unexpected end of data for binary content".to_string())); + } + + let binary_bytes = &data[*pos..*pos + len]; + *pos += len; + + // Convert to hex string instead of base64 + let hex = binary_bytes.iter() + .map(|b| format!("{:02x}", b)) + .collect::>() + .join(""); + + Ok(Value::String(format!("binary:{}", hex))) +} + +/// Decodes a string value +fn decode_long_string(data: &[u8], 
pos: &mut usize) -> Result { + if *pos + 3 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for string length".to_string())); + } + + // Read the string length (4 bytes) + let mut buf = [0u8; 4]; + buf.copy_from_slice(&data[*pos..*pos+4]); + *pos += 4; + + let len = u32::from_le_bytes(buf) as usize; + + // Read the string bytes + if *pos + len > data.len() { + return Err(Error::VariantRead("Unexpected end of data for string content".to_string())); + } + + let string_bytes = &data[*pos..*pos + len]; + *pos += len; + + // Convert to UTF-8 string + let string = str::from_utf8(string_bytes) + .map_err(|e| Error::InvalidMetadata(format!("Invalid UTF-8 string: {}", e)))?; + + Ok(Value::String(string.to_string())) +} + +/// Decodes a time without timezone value +fn decode_time_ntz(data: &[u8], pos: &mut usize) -> Result { + if *pos + 7 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for time_ntz".to_string())); + } + + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[*pos..*pos+8]); + *pos += 8; + + let micros = i64::from_le_bytes(buf); + + // Convert to ISO time string (simplified) + let time = format!("time_ntz-{}", micros); + + Ok(Value::String(time)) +} + +/// Decodes a timestamp with timezone (nanos) value +fn decode_timestamp_nanos(data: &[u8], pos: &mut usize) -> Result { + if *pos + 7 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for timestamp_nanos".to_string())); + } + + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[*pos..*pos+8]); + *pos += 8; + + let nanos = i64::from_le_bytes(buf); + + // Convert to ISO timestamp string (simplified) + let timestamp = format!("timestamp_nanos-{}", nanos); + + Ok(Value::String(timestamp)) +} + +/// Decodes a timestamp without timezone (nanos) value +fn decode_timestamp_ntz_nanos(data: &[u8], pos: &mut usize) -> Result { + if *pos + 7 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for timestamp_ntz_nanos".to_string())); + } 
+ + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[*pos..*pos+8]); + *pos += 8; + + let nanos = i64::from_le_bytes(buf); + + // Convert to ISO timestamp string (simplified) + let timestamp = format!("timestamp_ntz_nanos-{}", nanos); + + Ok(Value::String(timestamp)) +} + +/// Decodes a UUID value +fn decode_uuid(data: &[u8], pos: &mut usize) -> Result { + if *pos + 15 >= data.len() { + return Err(Error::VariantRead("Unexpected end of data for uuid".to_string())); + } + + let mut buf = [0u8; 16]; + buf.copy_from_slice(&data[*pos..*pos+16]); + *pos += 16; + + // Convert to UUID string (simplified) + let uuid = format!("uuid-{:?}", buf); + + Ok(Value::String(uuid)) +} + +/// Decodes a Variant binary to a JSON value using the given metadata +pub fn decode_json(binary: &[u8], metadata: &[u8]) -> Result { + let keys = parse_metadata_keys(metadata)?; + decode_value(binary, &keys) +} + +/// Parses metadata to extract the key list +fn parse_metadata_keys(metadata: &[u8]) -> Result, Error> { + if metadata.is_empty() { + return Err(Error::InvalidMetadata("Empty metadata".to_string())); + } + + // Parse header + let header = metadata[0]; + let version = header & 0x0F; + let _sorted = (header >> 4) & 0x01 != 0; + let offset_size_minus_one = (header >> 6) & 0x03; + let offset_size = (offset_size_minus_one + 1) as usize; + + if version != 1 { + return Err(Error::InvalidMetadata(format!("Unsupported version: {}", version))); + } + + if metadata.len() < 1 + offset_size { + return Err(Error::InvalidMetadata("Metadata too short for dictionary size".to_string())); + } + + // Parse dictionary_size + let mut dictionary_size = 0u32; + for i in 0..offset_size { + dictionary_size |= (metadata[1 + i] as u32) << (8 * i); + } + + // Parse offsets + let offset_start = 1 + offset_size; + let offset_end = offset_start + (dictionary_size as usize + 1) * offset_size; + + if metadata.len() < offset_end { + return Err(Error::InvalidMetadata("Metadata too short for offsets".to_string())); + } + + 
let mut offsets = Vec::with_capacity(dictionary_size as usize + 1); + for i in 0..=dictionary_size { + let offset_pos = offset_start + (i as usize * offset_size); + let mut offset = 0u32; + for j in 0..offset_size { + offset |= (metadata[offset_pos + j] as u32) << (8 * j); + } + offsets.push(offset as usize); + } + + // Parse dictionary strings + let mut keys = Vec::with_capacity(dictionary_size as usize); + for i in 0..dictionary_size as usize { + let start = offset_end + offsets[i]; + let end = offset_end + offsets[i + 1]; + + if end > metadata.len() { + return Err(Error::InvalidMetadata("Invalid string offset".to_string())); + } + + let key = str::from_utf8(&metadata[start..end]) + .map_err(|e| Error::InvalidMetadata(format!("Invalid UTF-8: {}", e)))? + .to_string(); + + keys.push(key); + } + + Ok(keys) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::metadata::create_metadata; + use crate::encoder::encode_json; + + fn encode_and_decode(value: Value) -> Result { + // Create metadata for this value + let metadata = create_metadata(&value, false)?; + + // Parse metadata to get key mapping + let keys = parse_metadata_keys(&metadata)?; + let key_mapping: HashMap = keys.iter() + .enumerate() + .map(|(i, k)| (k.clone(), i)) + .collect(); + + // Encode to binary + let binary = encode_json(&value, &key_mapping)?; + + // Decode back to value + decode_value(&binary, &keys) + } + + #[test] + fn test_decode_primitives() -> Result<(), Error> { + // Test null + let null_value = Value::Null; + let decoded = encode_and_decode(null_value.clone())?; + assert_eq!(decoded, null_value); + + // Test boolean + let true_value = Value::Bool(true); + let decoded = encode_and_decode(true_value.clone())?; + assert_eq!(decoded, true_value); + + let false_value = Value::Bool(false); + let decoded = encode_and_decode(false_value.clone())?; + assert_eq!(decoded, false_value); + + // Test integer + let int_value = json!(42); + let decoded = encode_and_decode(int_value.clone())?; + 
assert_eq!(decoded, int_value); + + // Test float + let float_value = json!(3.14159); + let decoded = encode_and_decode(float_value.clone())?; + assert_eq!(decoded, float_value); + + // Test string + let string_value = json!("Hello, World!"); + let decoded = encode_and_decode(string_value.clone())?; + assert_eq!(decoded, string_value); + + Ok(()) + } + + #[test] + fn test_decode_array() -> Result<(), Error> { + let array_value = json!([1, 2, 3, 4, 5]); + let decoded = encode_and_decode(array_value.clone())?; + assert_eq!(decoded, array_value); + + let mixed_array = json!([1, "text", true, null]); + let decoded = encode_and_decode(mixed_array.clone())?; + assert_eq!(decoded, mixed_array); + + let nested_array = json!([[1, 2], [3, 4]]); + let decoded = encode_and_decode(nested_array.clone())?; + assert_eq!(decoded, nested_array); + + Ok(()) + } + + #[test] + fn test_decode_object() -> Result<(), Error> { + let object_value = json!({"name": "John", "age": 30}); + let decoded = encode_and_decode(object_value.clone())?; + assert_eq!(decoded, object_value); + + let complex_object = json!({ + "name": "John", + "age": 30, + "is_active": true, + "email": null + }); + let decoded = encode_and_decode(complex_object.clone())?; + assert_eq!(decoded, complex_object); + + let nested_object = json!({ + "person": { + "name": "John", + "age": 30 + }, + "company": { + "name": "ACME Inc.", + "location": "New York" + } + }); + let decoded = encode_and_decode(nested_object.clone())?; + assert_eq!(decoded, nested_object); + + Ok(()) + } + + #[test] + fn test_decode_complex() -> Result<(), Error> { + let complex_value = json!({ + "name": "John Doe", + "age": 30, + "is_active": true, + "scores": [95, 87, 92], + "null_value": null, + "address": { + "street": "123 Main St", + "city": "Anytown", + "zip": 12345 + }, + "contacts": [ + { + "type": "email", + "value": "john@example.com" + }, + { + "type": "phone", + "value": "555-1234" + } + ] + }); + + let decoded = 
encode_and_decode(complex_value.clone())?; + assert_eq!(decoded, complex_value); + + Ok(()) + } + + #[test] + fn test_decode_null_function() { + let result = decode_null().unwrap(); + assert_eq!(result, Value::Null); + } + + #[test] + fn test_decode_primitive_function() -> Result<(), Error> { + // Test with null type + let mut pos = 0; + let data = [0x00]; // Null type + let result = decode_primitive(&data, &mut pos)?; + assert_eq!(result, Value::Null); + + // Test with boolean true + let mut pos = 0; + let data = [0x01]; // Boolean true + let result = decode_primitive(&data, &mut pos)?; + assert_eq!(result, Value::Bool(true)); + + // Test with boolean false + let mut pos = 0; + let data = [0x02]; // Boolean false + let result = decode_primitive(&data, &mut pos)?; + assert_eq!(result, Value::Bool(false)); + + // Test with int8 + let mut pos = 0; + let data = [0x03, 42]; // Int8 type, value 42 + let result = decode_primitive(&data, &mut pos)?; + assert_eq!(result, json!(42)); + + // Test with string + let mut pos = 0; + let data = [0x10, 0x05, 0x00, 0x00, 0x00, 0x48, 0x65, 0x6C, 0x6C, 0x6F]; + // String type, length 5, "Hello" + let result = decode_primitive(&data, &mut pos)?; + assert_eq!(result, json!("Hello")); + + Ok(()) + } + + #[test] + fn test_decode_short_string_function() -> Result<(), Error> { + let mut pos = 0; + let data = [0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F]; // Length 5, "Hello" + let result = decode_short_string(&data, &mut pos)?; + assert_eq!(result, json!("Hello")); + + // Test with empty string + let mut pos = 0; + let data = [0x00]; // Length 0, "" + let result = decode_short_string(&data, &mut pos)?; + assert_eq!(result, json!("")); + + // Test with error case - unexpected end of data + let mut pos = 0; + let data = [0x05, 0x48, 0x65]; // Length 5 but only 3 bytes available + let result = decode_short_string(&data, &mut pos); + assert!(result.is_err()); + + Ok(()) + } + + #[test] + fn test_decode_string_function() -> Result<(), Error> { + let mut 
pos = 0; + let data = [0x05, 0x00, 0x00, 0x00, 0x48, 0x65, 0x6C, 0x6C, 0x6F]; + // Length 5, "Hello" + let result = decode_long_string(&data, &mut pos)?; + assert_eq!(result, json!("Hello")); + + // Test with empty string + let mut pos = 0; + let data = [0x00, 0x00, 0x00, 0x00]; // Length 0, "" + let result = decode_long_string(&data, &mut pos)?; + assert_eq!(result, json!("")); + + // Test with error case - unexpected end of data + let mut pos = 0; + let data = [0x05, 0x00, 0x00, 0x00, 0x48, 0x65]; + // Length 5 but only 2 bytes available + let result = decode_long_string(&data, &mut pos); + assert!(result.is_err()); + + Ok(()) + } +} \ No newline at end of file diff --git a/arrow-variant/src/encoder/mod.rs b/arrow-variant/src/encoder/mod.rs new file mode 100644 index 000000000000..b2cdab366634 --- /dev/null +++ b/arrow-variant/src/encoder/mod.rs @@ -0,0 +1,689 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Encoder module for converting JSON values to Variant binary format + +use serde_json::Value; +use std::collections::HashMap; +use crate::error::Error; + +/// Variant basic types as defined in the Arrow Variant specification +/// +/// Basic Type ID Description +/// Primitive 0 One of the primitive types +/// Short string 1 A string with a length less than 64 bytes +/// Object 2 A collection of (string-key, variant-value) pairs +/// Array 3 An ordered sequence of variant values +pub enum VariantBasicType { + /// Primitive type (0) + Primitive = 0, + /// Short string (1) + ShortString = 1, + /// Object (2) + Object = 2, + /// Array (3) + Array = 3, +} + +/// Variant primitive types as defined in the Arrow Variant specification +/// +/// Equivalence Class Variant Physical Type Type ID Equivalent Parquet Type Binary format +/// NullType null 0 UNKNOWN none +/// Boolean boolean (True) 1 BOOLEAN none +/// Boolean boolean (False) 2 BOOLEAN none +/// Exact Numeric int8 3 INT(8, signed) 1 byte +/// Exact Numeric int16 4 INT(16, signed) 2 byte little-endian +/// Exact Numeric int32 5 INT(32, signed) 4 byte little-endian +/// Exact Numeric int64 6 INT(64, signed) 8 byte little-endian +/// Double double 7 DOUBLE IEEE little-endian +/// Exact Numeric decimal4 8 DECIMAL(precision, scale) 1 byte scale in range [0, 38], followed by little-endian unscaled value +/// Exact Numeric decimal8 9 DECIMAL(precision, scale) 1 byte scale in range [0, 38], followed by little-endian unscaled value +/// Exact Numeric decimal16 10 DECIMAL(precision, scale) 1 byte scale in range [0, 38], followed by little-endian unscaled value +/// Date date 11 DATE 4 byte little-endian +/// Timestamp timestamp 12 TIMESTAMP(isAdjustedToUTC=true, MICROS) 8-byte little-endian +/// TimestampNTZ timestamp without time zone 13 TIMESTAMP(isAdjustedToUTC=false, MICROS) 8-byte little-endian +/// Float float 14 FLOAT IEEE little-endian +/// Binary binary 15 BINARY 4 byte little-endian size, followed by bytes +/// String 
string 16 STRING 4 byte little-endian size, followed by UTF-8 encoded bytes +/// TimeNTZ time without time zone 17 TIME(isAdjustedToUTC=false, MICROS) 8-byte little-endian +/// Timestamp timestamp with time zone 18 TIMESTAMP(isAdjustedToUTC=true, NANOS) 8-byte little-endian +/// TimestampNTZ timestamp without time zone 19 TIMESTAMP(isAdjustedToUTC=false, NANOS) 8-byte little-endian +/// UUID uuid 20 UUID 16-byte big-endian +pub enum VariantPrimitiveType { + /// Null type (0) + Null = 0, + /// Boolean true (1) + BooleanTrue = 1, + /// Boolean false (2) + BooleanFalse = 2, + /// 8-bit signed integer (3) + Int8 = 3, + /// 16-bit signed integer (4) + Int16 = 4, + /// 32-bit signed integer (5) + Int32 = 5, + /// 64-bit signed integer (6) + Int64 = 6, + /// 64-bit floating point (7) + Double = 7, + /// 32-bit decimal (8) + Decimal4 = 8, + /// 64-bit decimal (9) + Decimal8 = 9, + /// 128-bit decimal (10) + Decimal16 = 10, + /// Date (11) + Date = 11, + /// Timestamp with timezone (12) + Timestamp = 12, + /// Timestamp without timezone (13) + TimestampNTZ = 13, + /// 32-bit floating point (14) + Float = 14, + /// Binary data (15) + Binary = 15, + /// UTF-8 string (16) + String = 16, + /// Time without timezone (17) + TimeNTZ = 17, + /// Timestamp with timezone (nanos) (18) + TimestampNanos = 18, + /// Timestamp without timezone (nanos) (19) + TimestampNTZNanos = 19, + /// UUID (20) + Uuid = 20, +} + +/// Creates a header byte for a primitive type value +/// +/// The header byte contains: +/// - Basic type (2 bits) in the lower bits +/// - Type ID (6 bits) in the upper bits +fn primitive_header(type_id: u8) -> u8 { + (type_id << 2) | VariantBasicType::Primitive as u8 +} + +/// Creates a header byte for a short string value +/// +/// The header byte contains: +/// - Basic type (2 bits) in the lower bits +/// - String length (6 bits) in the upper bits +fn short_str_header(size: u8) -> u8 { + (size << 2) | VariantBasicType::ShortString as u8 +} + +/// Creates a header byte for 
an object value +/// +/// The header byte contains: +/// - Basic type (2 bits) in the lower bits +/// - is_large (1 bit) at position 6 +/// - field_id_size_minus_one (2 bits) at positions 4-5 +/// - field_offset_size_minus_one (2 bits) at positions 2-3 +fn object_header(is_large: bool, id_size: u8, offset_size: u8) -> u8 { + ((is_large as u8) << 6) | + ((id_size - 1) << 4) | + ((offset_size - 1) << 2) | + VariantBasicType::Object as u8 +} + +/// Creates a header byte for an array value +/// +/// The header byte contains: +/// - Basic type (2 bits) in the lower bits +/// - is_large (1 bit) at position 4 +/// - field_offset_size_minus_one (2 bits) at positions 2-3 +fn array_header(is_large: bool, offset_size: u8) -> u8 { + ((is_large as u8) << 4) | + ((offset_size - 1) << 2) | + VariantBasicType::Array as u8 +} + +/// Encodes a null value +fn encode_null(output: &mut Vec) { + output.push(primitive_header(VariantPrimitiveType::Null as u8)); +} + +/// Encodes a boolean value +fn encode_boolean(value: bool, output: &mut Vec) { + if value { + output.push(primitive_header(VariantPrimitiveType::BooleanTrue as u8)); + } else { + output.push(primitive_header(VariantPrimitiveType::BooleanFalse as u8)); + } +} + +/// Encodes an integer value, choosing the smallest sufficient type +fn encode_integer(value: i64, output: &mut Vec) { + if value >= -128 && value <= 127 { + // Int8 + output.push(primitive_header(VariantPrimitiveType::Int8 as u8)); + output.push(value as u8); + } else if value >= -32768 && value <= 32767 { + // Int16 + output.push(primitive_header(VariantPrimitiveType::Int16 as u8)); + output.extend_from_slice(&(value as i16).to_le_bytes()); + } else if value >= -2147483648 && value <= 2147483647 { + // Int32 + output.push(primitive_header(VariantPrimitiveType::Int32 as u8)); + output.extend_from_slice(&(value as i32).to_le_bytes()); + } else { + // Int64 + output.push(primitive_header(VariantPrimitiveType::Int64 as u8)); + 
output.extend_from_slice(&value.to_le_bytes()); + } +} + +/// Encodes a float value +fn encode_float(value: f64, output: &mut Vec) { + output.push(primitive_header(VariantPrimitiveType::Double as u8)); + output.extend_from_slice(&value.to_le_bytes()); +} + +/// Encodes a string value +fn encode_string(value: &str, output: &mut Vec) { + let bytes = value.as_bytes(); + let len = bytes.len(); + + if len < 64 { + // Short string format - encode length in header + let header = short_str_header(len as u8); + output.push(header); + output.extend_from_slice(bytes); + } else { + // Long string format (using primitive string type) + let header = primitive_header(VariantPrimitiveType::String as u8); + output.push(header); + + // Write length as 4-byte little-endian + output.extend_from_slice(&(len as u32).to_le_bytes()); + + // Write string bytes + output.extend_from_slice(bytes); + } +} + +/// Encodes an array value +fn encode_array(array: &[Value], output: &mut Vec, key_mapping: &HashMap) -> Result<(), Error> { + let len = array.len(); + + // Determine if we need large size encoding + let is_large = len > 255; + + // First pass to calculate offsets and collect encoded values + let mut temp_outputs = Vec::with_capacity(len); + let mut offsets = Vec::with_capacity(len + 1); + offsets.push(0); + + let mut max_offset = 0; + for value in array { + let mut temp_output = Vec::new(); + encode_value(value, &mut temp_output, key_mapping)?; + max_offset += temp_output.len(); + offsets.push(max_offset); + temp_outputs.push(temp_output); + } + + // Determine minimum offset size + let offset_size = if max_offset <= 255 { 1 } + else if max_offset <= 65535 { 2 } + else { 3 }; + + // Write array header + output.push(array_header(is_large, offset_size)); + + // Write length as 1 or 4 bytes + if is_large { + output.extend_from_slice(&(len as u32).to_le_bytes()); + } else { + output.push(len as u8); + } + + // Write offsets + for offset in &offsets { + match offset_size { + 1 => 
output.push(*offset as u8), + 2 => output.extend_from_slice(&(*offset as u16).to_le_bytes()), + 3 => { + output.push((*offset & 0xFF) as u8); + output.push(((*offset >> 8) & 0xFF) as u8); + output.push(((*offset >> 16) & 0xFF) as u8); + }, + _ => unreachable!(), + } + } + + // Write values + for temp_output in temp_outputs { + output.extend_from_slice(&temp_output); + } + + Ok(()) +} + +/// Encodes an object value +fn encode_object(obj: &serde_json::Map, output: &mut Vec, key_mapping: &HashMap) -> Result<(), Error> { + let len = obj.len(); + + // Determine if we need large size encoding + let is_large = len > 255; + + // Collect and sort fields by key + let mut fields: Vec<_> = obj.iter().collect(); + fields.sort_by(|a, b| a.0.cmp(b.0)); + + // First pass to calculate offsets and collect encoded values + let mut field_ids = Vec::with_capacity(len); + let mut temp_outputs = Vec::with_capacity(len); + let mut offsets = Vec::with_capacity(len + 1); + offsets.push(0); + + let mut data_size = 0; + for (key, value) in &fields { + let field_id = key_mapping.get(key.as_str()) + .ok_or_else(|| Error::VariantCreation(format!("Key not found in mapping: {}", key)))?; + field_ids.push(*field_id); + + let mut temp_output = Vec::new(); + encode_value(value, &mut temp_output, key_mapping)?; + data_size += temp_output.len(); + offsets.push(data_size); + temp_outputs.push(temp_output); + } + + // Determine minimum sizes needed - use size 1 for empty objects + let id_size = if field_ids.is_empty() { 1 } + else if field_ids.iter().max().unwrap() <= &255 { 1 } + else if field_ids.iter().max().unwrap() <= &65535 { 2 } + else if field_ids.iter().max().unwrap() <= &16777215 { 3 } + else { 4 }; + + let offset_size = if data_size <= 255 { 1 } + else if data_size <= 65535 { 2 } + else { 3 }; + + // Write object header + output.push(object_header(is_large, id_size, offset_size)); + + // Write length as 1 or 4 bytes + if is_large { + output.extend_from_slice(&(len as u32).to_le_bytes()); + } 
else { + output.push(len as u8); + } + + // Write field IDs + for id in &field_ids { + match id_size { + 1 => output.push(*id as u8), + 2 => output.extend_from_slice(&(*id as u16).to_le_bytes()), + 3 => { + output.push((*id & 0xFF) as u8); + output.push(((*id >> 8) & 0xFF) as u8); + output.push(((*id >> 16) & 0xFF) as u8); + }, + 4 => output.extend_from_slice(&(*id as u32).to_le_bytes()), + _ => unreachable!(), + } + } + + // Write offsets + for offset in &offsets { + match offset_size { + 1 => output.push(*offset as u8), + 2 => output.extend_from_slice(&(*offset as u16).to_le_bytes()), + 3 => { + output.push((*offset & 0xFF) as u8); + output.push(((*offset >> 8) & 0xFF) as u8); + output.push(((*offset >> 16) & 0xFF) as u8); + }, + 4 => output.extend_from_slice(&(*offset as u32).to_le_bytes()), + _ => unreachable!(), + } + } + + // Write values + for temp_output in temp_outputs { + output.extend_from_slice(&temp_output); + } + + Ok(()) +} + +/// Encodes a JSON value to Variant binary format +pub fn encode_value(value: &Value, output: &mut Vec, key_mapping: &HashMap) -> Result<(), Error> { + match value { + Value::Null => encode_null(output), + Value::Bool(b) => encode_boolean(*b, output), + Value::Number(n) => { + if let Some(i) = n.as_i64() { + encode_integer(i, output); + } else if let Some(f) = n.as_f64() { + encode_float(f, output); + } else { + return Err(Error::VariantCreation("Unsupported number format".to_string())); + } + }, + Value::String(s) => encode_string(s, output), + Value::Array(a) => encode_array(a, output, key_mapping)?, + Value::Object(o) => encode_object(o, output, key_mapping)?, + } + + Ok(()) +} + +/// Encodes a JSON value to a complete Variant binary value +pub fn encode_json(json: &Value, key_mapping: &HashMap) -> Result, Error> { + let mut output = Vec::new(); + encode_value(json, &mut output, key_mapping)?; + Ok(output) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn setup_key_mapping() -> HashMap { + let 
mut mapping = HashMap::new(); + mapping.insert("name".to_string(), 0); + mapping.insert("age".to_string(), 1); + mapping.insert("active".to_string(), 2); + mapping.insert("scores".to_string(), 3); + mapping.insert("address".to_string(), 4); + mapping.insert("street".to_string(), 5); + mapping.insert("city".to_string(), 6); + mapping.insert("zip".to_string(), 7); + mapping.insert("tags".to_string(), 8); + mapping + } + + #[test] + fn test_encode_integers() { + // Test Int8 + let mut output = Vec::new(); + encode_integer(42, &mut output); + assert_eq!(output, vec![primitive_header(VariantPrimitiveType::Int8 as u8), 42]); + + // Test Int16 + output.clear(); + encode_integer(1000, &mut output); + assert_eq!(output, vec![primitive_header(VariantPrimitiveType::Int16 as u8), 232, 3]); + + // Test Int32 + output.clear(); + encode_integer(100000, &mut output); + let mut expected = vec![primitive_header(VariantPrimitiveType::Int32 as u8)]; + expected.extend_from_slice(&(100000i32).to_le_bytes()); + assert_eq!(output, expected); + + // Test Int64 + output.clear(); + encode_integer(3000000000, &mut output); + let mut expected = vec![primitive_header(VariantPrimitiveType::Int64 as u8)]; + expected.extend_from_slice(&(3000000000i64).to_le_bytes()); + assert_eq!(output, expected); + } + + #[test] + fn test_encode_float() { + let mut output = Vec::new(); + encode_float(3.14159, &mut output); + let mut expected = vec![primitive_header(VariantPrimitiveType::Double as u8)]; + expected.extend_from_slice(&(3.14159f64).to_le_bytes()); + assert_eq!(output, expected); + } + + #[test] + fn test_encode_string() { + let mut output = Vec::new(); + + // Test short string + let short_str = "Hello"; + encode_string(short_str, &mut output); + + // Check header byte + assert_eq!(output[0], short_str_header(short_str.len() as u8)); + + // Check string content + assert_eq!(&output[1..], short_str.as_bytes()); + + // Test longer string + output.clear(); + let long_str = "This is a longer string that 
definitely won't fit in the small format because it needs to be at least 64 bytes long to test the long string format"; + encode_string(long_str, &mut output); + + // Check header byte + assert_eq!(output[0], primitive_header(VariantPrimitiveType::String as u8)); + + // Check length bytes + assert_eq!(&output[1..5], &(long_str.len() as u32).to_le_bytes()); + + // Check string content + assert_eq!(&output[5..], long_str.as_bytes()); + } + + #[test] + fn test_encode_array() -> Result<(), Error> { + let key_mapping = setup_key_mapping(); + let json = json!([1, "text", true, null]); + + let mut output = Vec::new(); + encode_array(json.as_array().unwrap(), &mut output, &key_mapping)?; + + // Validate array header + assert_eq!(output[0], array_header(false, 1)); + assert_eq!(output[1], 4); // 4 elements + + // Array should contain encoded versions of the 4 values + Ok(()) + } + + #[test] + fn test_encode_object() -> Result<(), Error> { + let key_mapping = setup_key_mapping(); + let json = json!({ + "name": "John", + "age": 30, + "active": true + }); + + let mut output = Vec::new(); + encode_object(json.as_object().unwrap(), &mut output, &key_mapping)?; + + // Verify header byte + // - basic_type = 2 (Object) + // - is_large = 0 (3 elements < 255) + // - field_id_size_minus_one = 0 (max field_id = 2 < 255) + // - field_offset_size_minus_one = 0 (offset_size = 1, small offsets) + assert_eq!(output[0], 0b00000010); // Object header + + // Verify num_elements (1 byte) + assert_eq!(output[1], 3); + + // Verify field_ids (in lexicographical order: active, age, name) + assert_eq!(output[2], 2); // active + assert_eq!(output[3], 1); // age + assert_eq!(output[4], 0); // name + + // Test empty object + let empty_obj = json!({}); + output.clear(); + encode_object(empty_obj.as_object().unwrap(), &mut output, &key_mapping)?; + + // Verify header byte for empty object + assert_eq!(output[0], 0b00000010); // Object header with minimum sizes + assert_eq!(output[1], 0); // Zero elements 
+ + // Test case 2: Object with large values requiring larger offsets + let obj = json!({ + "name": "This is a very long string that will definitely require more than 255 bytes to encode. Let me add some more text to make sure it exceeds the limit. The string needs to be long enough to trigger the use of 2-byte offsets. Adding more content to ensure we go over the threshold. This is just padding text to make the string longer. Almost there, just a bit more to go. And finally, some more text to push us over the edge.", + "age": 30, + "active": true + }); + + output.clear(); + encode_object(obj.as_object().unwrap(), &mut output, &key_mapping)?; + + // Verify header byte + // - basic_type = 2 (Object) + // - is_large = 0 (3 elements < 255) + // - field_id_size_minus_one = 0 (max field_id = 2 < 255) + // - field_offset_size_minus_one = 1 (offset_size = 2, large offsets) + assert_eq!(output[0], 0b00000110); // Object header with 2-byte offsets + + // Test case 3: Object with nested objects + let obj = json!({ + "name": "John", + "address": { + "street": "123 Main St", + "city": "New York", + "zip": "10001" + }, + "scores": [95, 87, 92] + }); + + output.clear(); + encode_object(obj.as_object().unwrap(), &mut output, &key_mapping)?; + + // Verify header byte + // - basic_type = 2 (Object) + // - is_large = 0 (3 elements < 255) + // - field_id_size_minus_one = 0 (max field_id < 255) + // - field_offset_size_minus_one = 0 (offset_size = 1, determined by data size) + assert_eq!(output[0], 0b00000010); // Object header with 1-byte offsets + + // Verify num_elements (1 byte) + assert_eq!(output[1], 3); + + // Verify field_ids (in lexicographical order: address, name, scores) + assert_eq!(output[2], 4); // address + assert_eq!(output[3], 0); // name + assert_eq!(output[4], 3); // scores + + Ok(()) + } + + #[test] + fn test_encode_null() { + let mut output = Vec::new(); + encode_null(&mut output); + assert_eq!(output, vec![primitive_header(VariantPrimitiveType::Null as u8)]); + 
+ // Test that the encoded value can be decoded correctly + let keys = Vec::::new(); + let result = crate::decoder::decode_value(&output, &keys).unwrap(); + assert!(result.is_null()); + } + + #[test] + fn test_encode_boolean() { + // Test true + let mut output = Vec::new(); + encode_boolean(true, &mut output); + assert_eq!(output, vec![primitive_header(VariantPrimitiveType::BooleanTrue as u8)]); + + // Test that the encoded value can be decoded correctly + let keys = Vec::::new(); + let result = crate::decoder::decode_value(&output, &keys).unwrap(); + assert_eq!(result, serde_json::json!(true)); + + // Test false + output.clear(); + encode_boolean(false, &mut output); + assert_eq!(output, vec![primitive_header(VariantPrimitiveType::BooleanFalse as u8)]); + + // Test that the encoded value can be decoded correctly + let result = crate::decoder::decode_value(&output, &keys).unwrap(); + assert_eq!(result, serde_json::json!(false)); + } + + #[test] + fn test_object_encoding() { + let key_mapping = setup_key_mapping(); + let json = json!({ + "name": "John", + "age": 30, + "active": true + }); + + let mut output = Vec::new(); + encode_object(json.as_object().unwrap(), &mut output, &key_mapping).unwrap(); + + // Verify header byte + // - basic_type = 2 (Object) + // - is_large = 0 (3 elements < 255) + // - field_id_size_minus_one = 0 (max field_id = 2 < 255) + // - field_offset_size_minus_one = 0 (offset_size = 1, small offsets) + assert_eq!(output[0], 0b00000010); // Object header + + // Verify num_elements (1 byte) + assert_eq!(output[1], 3); + + // Verify field_ids (in lexicographical order: active, age, name) + assert_eq!(output[2], 2); // active + assert_eq!(output[3], 1); // age + assert_eq!(output[4], 0); // name + + // Test case 2: Object with large values requiring larger offsets + let obj = json!({ + "name": "This is a very long string that will definitely require more than 255 bytes to encode. Let me add some more text to make sure it exceeds the limit. 
The string needs to be long enough to trigger the use of 2-byte offsets. Adding more content to ensure we go over the threshold. This is just padding text to make the string longer. Almost there, just a bit more to go. And finally, some more text to push us over the edge.", + "age": 30, + "active": true + }); + + output.clear(); + encode_object(obj.as_object().unwrap(), &mut output, &key_mapping).unwrap(); + + // Verify header byte + // - basic_type = 2 (Object) + // - is_large = 0 (3 elements < 255) + // - field_id_size_minus_one = 0 (max field_id = 2 < 255) + // - field_offset_size_minus_one = 1 (offset_size = 2, large offsets) + assert_eq!(output[0], 0b00000110); // Object header with 2-byte offsets + + + // Test case 3: Object with nested objects + let obj = json!({ + "name": "John", + "address": { + "street": "123 Main St", + "city": "New York", + "zip": "10001" + }, + "scores": [95, 87, 92] + }); + + output.clear(); + encode_object(obj.as_object().unwrap(), &mut output, &key_mapping).unwrap(); + + // Verify header byte + // - basic_type = 2 (Object) + // - is_large = 0 (3 elements < 255) + // - field_id_size_minus_one = 0 (max field_id < 255) + // - field_offset_size_minus_one = 0 (offset_size = 1, determined by data size) + assert_eq!(output[0], 0b00000010); // Object header with 1-byte offsets + + // Verify num_elements (1 byte) + assert_eq!(output[1], 3); + + // Verify field_ids (in lexicographical order: address, name, scores) + assert_eq!(output[2], 4); // address + assert_eq!(output[3], 0); // name + assert_eq!(output[4], 3); // scores + + } +} \ No newline at end of file diff --git a/arrow-variant/src/error.rs b/arrow-variant/src/error.rs new file mode 100644 index 000000000000..081203e60288 --- /dev/null +++ b/arrow-variant/src/error.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Error types for the arrow-variant crate + +use arrow_schema::ArrowError; +use std::error::Error as StdError; +use std::fmt::{Display, Formatter, Result as FmtResult}; + +/// Error type for operations in this crate +#[derive(Debug)] +pub enum Error { + /// Error when parsing metadata + InvalidMetadata(String), + + /// Error when parsing JSON + JsonParse(serde_json::Error), + + /// Error when creating a Variant + VariantCreation(String), + + /// Error when reading a Variant + VariantRead(String), + + /// Error when creating a VariantArray + VariantArrayCreation(ArrowError), + + /// Error for empty input + EmptyInput, +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Error::InvalidMetadata(msg) => write!(f, "Invalid metadata: {}", msg), + Error::JsonParse(err) => write!(f, "JSON parse error: {}", err), + Error::VariantCreation(msg) => write!(f, "Failed to create Variant: {}", msg), + Error::VariantRead(msg) => write!(f, "Failed to read Variant: {}", msg), + Error::VariantArrayCreation(err) => write!(f, "Failed to create VariantArray: {}", err), + Error::EmptyInput => write!(f, "Empty input"), + } + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + 
Error::JsonParse(err) => Some(err), + Error::VariantArrayCreation(err) => Some(err), + _ => None, + } + } +} + +impl From for Error { + fn from(err: serde_json::Error) -> Self { + Error::JsonParse(err) + } +} + +impl From for Error { + fn from(err: ArrowError) -> Self { + Error::VariantArrayCreation(err) + } +} + +impl From for ArrowError { + fn from(err: Error) -> Self { + ArrowError::ExternalError(Box::new(err)) + } +} \ No newline at end of file diff --git a/arrow-variant/src/integration.rs b/arrow-variant/src/integration.rs new file mode 100644 index 000000000000..5d4b845099a6 --- /dev/null +++ b/arrow-variant/src/integration.rs @@ -0,0 +1,250 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Integration tests and utilities for the arrow-variant crate + +use arrow_array::{Array, StructArray}; +use arrow_schema::extension::Variant; +use serde_json::{json, Value}; + +use crate::error::Error; +use crate::reader::{from_json, from_json_array}; +use crate::writer::{to_json, to_json_array}; + +/// Creates a test Variant from a JSON value +pub fn create_test_variant(json_value: Value) -> Result<Variant, Error> { + let json_str = json_value.to_string(); + from_json(&json_str) +} + +/// Creates a test StructArray with variant data from a list of JSON values +pub fn create_test_variant_array(json_values: Vec<Value>) -> Result<StructArray, Error> { + let json_strings: Vec<String> = json_values.into_iter().map(|v| v.to_string()).collect(); + let str_refs: Vec<&str> = json_strings.iter().map(|s| s.as_str()).collect(); + from_json_array(&str_refs) +} + +/// Validates that a JSON value can be roundtripped through Variant +pub fn validate_variant_roundtrip(json_value: Value) -> Result<(), Error> { + let json_str = json_value.to_string(); + + // Convert JSON to Variant + let variant = from_json(&json_str)?; + + // Convert Variant back to JSON + let result_json = to_json(&variant)?; + + // Parse both to Value to compare them structurally + let original: Value = serde_json::from_str(&json_str).unwrap(); + let result: Value = serde_json::from_str(&result_json).unwrap(); + + assert_eq!(original, result, "Original and result JSON should be equal"); + + Ok(()) +} + +/// Creates a sample Variant with a complex JSON structure +pub fn create_sample_variant() -> Result<Variant, Error> { + let json = json!({ + "name": "John Doe", + "age": 30, + "is_active": true, + "scores": [95, 87, 92], + "null_value": null, + "address": { + "street": "123 Main St", + "city": "Anytown", + "zip": 12345 + } + }); + + create_test_variant(json) +} + +/// Creates a sample StructArray with variant data containing multiple entries +pub fn create_sample_variant_array() -> Result<StructArray, Error> { + let json_values = vec![ + json!({ + "name": "John", + "age": 30, + "tags": ["developer", 
"rust"] + }), + json!({ + "name": "Jane", + "age": 28, + "tags": ["designer", "ui/ux"] + }), + json!({ + "name": "Bob", + "age": 22, + "tags": ["intern", "student"] + }) + ]; + + create_test_variant_array(json_values) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_roundtrip_primitives() -> Result<(), Error> { + // Test null + validate_variant_roundtrip(json!(null))?; + + // Test boolean + validate_variant_roundtrip(json!(true))?; + validate_variant_roundtrip(json!(false))?; + + // Test integers + validate_variant_roundtrip(json!(0))?; + validate_variant_roundtrip(json!(42))?; + validate_variant_roundtrip(json!(-1000))?; + validate_variant_roundtrip(json!(12345678))?; + + // Test floating point + validate_variant_roundtrip(json!(3.14159))?; + validate_variant_roundtrip(json!(-2.71828))?; + + // Test string + validate_variant_roundtrip(json!("Hello, World!"))?; + validate_variant_roundtrip(json!(""))?; + + Ok(()) + } + + #[test] + fn test_roundtrip_arrays() -> Result<(), Error> { + // Empty array + validate_variant_roundtrip(json!([]))?; + + // Homogeneous arrays + validate_variant_roundtrip(json!([1, 2, 3, 4, 5]))?; + validate_variant_roundtrip(json!(["a", "b", "c"]))?; + validate_variant_roundtrip(json!([true, false, true]))?; + + // Heterogeneous arrays + validate_variant_roundtrip(json!([1, "text", true, null]))?; + + // Nested arrays + validate_variant_roundtrip(json!([[1, 2], [3, 4], [5, 6]]))?; + + Ok(()) + } + + #[test] + fn test_roundtrip_objects() -> Result<(), Error> { + // Empty object + validate_variant_roundtrip(json!({}))?; + + + // Simple object + validate_variant_roundtrip(json!({"name": "John", "age": 30}))?; + + // Object with different types + validate_variant_roundtrip(json!({ + "name": "John", + "age": 30, + "is_active": true, + "email": null + }))?; + + // Nested object + validate_variant_roundtrip(json!({ + "person": { + "name": "John", + "age": 30 + }, + "company": { + "name": "ACME Inc.", + 
"location": "New York" + } + }))?; + + Ok(()) + } + + #[test] + fn test_roundtrip_complex() -> Result<(), Error> { + // Complex nested structure + validate_variant_roundtrip(json!({ + "name": "John Doe", + "age": 30, + "is_active": true, + "scores": [95, 87, 92], + "null_value": null, + "address": { + "street": "123 Main St", + "city": "Anytown", + "zip": 12345 + }, + "contacts": [ + { + "type": "email", + "value": "john@example.com" + }, + { + "type": "phone", + "value": "555-1234" + } + ], + "metadata": { + "created_at": "2023-01-01", + "updated_at": null, + "tags": ["user", "customer"] + } + }))?; + + Ok(()) + } + + #[test] + fn test_variant_array_roundtrip() -> Result<(), Error> { + // Create variant array + let variant_array = create_sample_variant_array()?; + + // Convert to JSON strings + let json_strings = to_json_array(&variant_array)?; + + // Convert back to variant array + let str_refs: Vec<&str> = json_strings.iter().map(|s| s.as_str()).collect(); + let round_trip_array = from_json_array(&str_refs)?; + + // Verify length + assert_eq!(variant_array.len(), round_trip_array.len()); + + // Convert both arrays to JSON and compare + let original_json = to_json_array(&variant_array)?; + let result_json = to_json_array(&round_trip_array)?; + + for (i, (original, result)) in original_json.iter().zip(result_json.iter()).enumerate() { + let original_value: Value = serde_json::from_str(original).unwrap(); + let result_value: Value = serde_json::from_str(result).unwrap(); + + assert_eq!( + original_value, + result_value, + "JSON values at index {} should be equal", + i + ); + } + + Ok(()) + } +} \ No newline at end of file diff --git a/arrow-variant/src/lib.rs b/arrow-variant/src/lib.rs new file mode 100644 index 000000000000..54ef099180c3 --- /dev/null +++ b/arrow-variant/src/lib.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Transfer data between the Arrow Variant binary format and JSON. +//! +//! The Arrow Variant extension type stores data as two binary values: +//! metadata and value. This crate provides utilities to convert between +//! JSON and the Variant binary format. +//! +//! # Example +//! +//! ```rust +//! use arrow_variant::{from_json, to_json}; +//! use arrow_schema::extension::Variant; +//! +//! // Convert JSON to Variant +//! let json_str = r#"{"key":"value"}"#; +//! let variant = from_json(json_str).unwrap(); +//! +//! // Convert Variant back to JSON +//! let json = to_json(&variant).unwrap(); +//! assert_eq!(json_str, json); +//! 
``` + +#![deny(rustdoc::broken_intra_doc_links)] +#![warn(missing_docs)] + +pub mod error; +pub mod reader; +pub mod writer; +/// Encoder module for converting JSON to Variant binary format +pub mod encoder; +/// Decoder module for converting Variant binary format to JSON +pub mod decoder; +/// Utilities for working with variant as struct arrays +pub mod variant_utils; + +pub use error::Error; +pub use reader::{from_json, from_json_array, from_json_value, from_json_value_array}; +pub use writer::{to_json, to_json_array, to_json_value, to_json_value_array}; +pub use encoder::{encode_value, encode_json, VariantBasicType, VariantPrimitiveType}; +pub use decoder::{decode_value, decode_json}; +pub use variant_utils::{create_variant_array, get_variant, validate_struct_array, create_empty_variant_array}; + +/// Utility functions for working with Variant metadata +pub mod metadata; +pub use metadata::{create_metadata, parse_metadata}; + +/// Integration utilities and tests +pub mod integration; +pub use integration::{create_test_variant, create_test_variant_array, validate_variant_roundtrip}; + +/// Converts a JSON string to a Variant and back +/// +/// # Examples +/// +/// ``` +/// use arrow_variant::{from_json, to_json}; +/// +/// let json_str = r#"{"key":"value"}"#; +/// let variant = from_json(json_str).unwrap(); +/// let result = to_json(&variant).unwrap(); +/// assert_eq!(json_str, result); +/// ``` +pub fn validate_json_roundtrip(json_str: &str) -> Result<(), Error> { + let variant = from_json(json_str)?; + let result = to_json(&variant)?; + assert_eq!(json_str, result); + Ok(()) +} \ No newline at end of file diff --git a/arrow-variant/src/metadata.rs b/arrow-variant/src/metadata.rs new file mode 100644 index 000000000000..d9ee7558cabd --- /dev/null +++ b/arrow-variant/src/metadata.rs @@ -0,0 +1,433 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for working with Variant metadata + +use crate::error::Error; +use serde_json::Value; +use std::collections::HashMap; +use arrow_array::{ + Array, ArrayRef, BinaryArray, StructArray, +}; +use arrow_array::builder::{BinaryBuilder, LargeBinaryBuilder}; + +/// Creates a metadata binary vector for a JSON value according to the Arrow Variant specification +/// +/// Metadata format: +/// - header: 1 byte ( | << 4 | ( << 6)) +/// - dictionary_size: `offset_size` bytes (unsigned little-endian) +/// - offsets: `dictionary_size + 1` entries of `offset_size` bytes each +/// - bytes: UTF-8 encoded dictionary string values +/// +/// # Arguments +/// +/// * `json_value` - The JSON value to create metadata for +/// * `sort_keys` - If true, keys will be sorted lexicographically; if false, keys will be used in their original order +pub fn create_metadata(json_value: &Value, sort_keys: bool) -> Result, Error> { + // Extract all keys from the JSON value (including nested) + let keys = extract_all_keys(json_value); + + // Convert keys to a vector and optionally sort them + let mut keys: Vec<_> = keys.into_iter().collect(); + if sort_keys { + keys.sort(); + } + + // Calculate the total size of all dictionary strings + let mut dictionary_string_size = 0u32; + for key in &keys { + 
        dictionary_string_size += key.len() as u32; + } + + // Determine the minimum integer size required for offsets + // The largest offset is the one-past-the-end value, which is total string size + let max_size = std::cmp::max(dictionary_string_size, (keys.len() + 1) as u32); + let offset_size = get_min_integer_size(max_size as usize); + let offset_size_minus_one = offset_size - 1; + + // Set sorted_strings based on whether keys are sorted in metadata + let sorted_strings = if sort_keys { 1 } else { 0 }; + + // Create header: version=1, sorted_strings based on parameter, offset_size based on calculation + let header = 0x01 | (sorted_strings << 4) | ((offset_size_minus_one as u8) << 6); + + // Start building the metadata + let mut metadata = Vec::new(); + metadata.push(header); + + // Add dictionary_size (this is the number of keys) + // Write the dictionary size using the calculated offset_size + for i in 0..offset_size { + metadata.push(((keys.len() >> (8 * i)) & 0xFF) as u8); + } + + // Pre-calculate offsets and prepare bytes + let mut bytes = Vec::new(); + let mut offsets = Vec::with_capacity(keys.len() + 1); + let mut current_offset = 0u32; + + offsets.push(current_offset); + + for key in keys { + bytes.extend_from_slice(key.as_bytes()); + current_offset += key.len() as u32; + offsets.push(current_offset); + } + + // Add all offsets using the calculated offset_size + for offset in &offsets { + for i in 0..offset_size { + metadata.push(((*offset >> (8 * i)) & 0xFF) as u8); + } + } + + // Add dictionary bytes + metadata.extend_from_slice(&bytes); + + Ok(metadata) +} + +/// Determines the minimum integer size required to represent a value +fn get_min_integer_size(value: usize) -> usize { + if value <= 255 { + 1 + } else if value <= 65535 { + 2 + } else if value <= 16777215 { + 3 + } else { + 4 + } +} + +/// Extracts all keys from a JSON value, including nested objects +fn extract_all_keys(json_value: &Value) -> Vec<String> { + let mut keys = Vec::new(); + + match json_value 
{ + Value::Object(map) => { + for (key, value) in map { + keys.push(key.clone()); + keys.extend(extract_all_keys(value)); + } + } + Value::Array(arr) => { + for value in arr { + keys.extend(extract_all_keys(value)); + } + } + _ => {} // No keys for primitive values + } + + keys +} + +/// Parses metadata binary into a map of keys to their indices +pub fn parse_metadata(metadata: &[u8]) -> Result, Error> { + if metadata.is_empty() { + return Err(Error::InvalidMetadata("Empty metadata".to_string())); + } + + // Parse header + let header = metadata[0]; + let version = header & 0x0F; + let _sorted_strings = (header >> 4) & 0x01 != 0; + let offset_size_minus_one = (header >> 6) & 0x03; + let offset_size = (offset_size_minus_one + 1) as usize; + + if version != 1 { + return Err(Error::InvalidMetadata(format!("Unsupported version: {}", version))); + } + + if metadata.len() < 1 + offset_size { + return Err(Error::InvalidMetadata("Metadata too short for dictionary size".to_string())); + } + + // Parse dictionary_size + let mut dictionary_size = 0u32; + for i in 0..offset_size { + dictionary_size |= (metadata[1 + i] as u32) << (8 * i); + } + + // Parse offsets + let offset_start = 1 + offset_size; + let offset_end = offset_start + (dictionary_size as usize + 1) * offset_size; + + if metadata.len() < offset_end { + return Err(Error::InvalidMetadata("Metadata too short for offsets".to_string())); + } + + let mut offsets = Vec::with_capacity(dictionary_size as usize + 1); + for i in 0..=dictionary_size { + let offset_pos = offset_start + (i as usize * offset_size); + let mut offset = 0u32; + for j in 0..offset_size { + offset |= (metadata[offset_pos + j] as u32) << (8 * j); + } + offsets.push(offset as usize); + } + + // Parse dictionary strings + let mut result = HashMap::new(); + for i in 0..dictionary_size as usize { + let start = offset_end + offsets[i]; + let end = offset_end + offsets[i + 1]; + + if end > metadata.len() { + return Err(Error::InvalidMetadata("Invalid string 
offset".to_string())); + } + + let key = std::str::from_utf8(&metadata[start..end]) + .map_err(|e| Error::InvalidMetadata(format!("Invalid UTF-8: {}", e)))? + .to_string(); + + result.insert(key, i); + } + + Ok(result) +} + +/// Creates simple metadata for testing purposes +/// +/// This creates valid metadata with a single key "key" +pub fn create_test_metadata() -> Vec { + vec![ + 0x01, // header: version=1, sorted=0, offset_size=1 + 0x01, // dictionary_size = 1 + 0x00, // offset 0 + 0x03, // offset 3 + b'k', b'e', b'y' // dictionary bytes + ] +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_simple_object() { + let value = json!({ + "a": 1, + "b": 2, + "c": 3 + }); + + let metadata = create_metadata(&value, false).unwrap(); + + // Header: version=1, sorted_strings=0, offset_size=1 (1 byte) + assert_eq!(metadata[0], 0x01); + + // Dictionary size: 3 keys + assert_eq!(metadata[1], 3); + + // Offsets: [0, 1, 2, 3] (1 byte each) + assert_eq!(metadata[2], 0); // First offset + assert_eq!(metadata[3], 1); // Second offset + assert_eq!(metadata[4], 2); // Third offset + assert_eq!(metadata[5], 3); // One-past-the-end offset + + // Dictionary bytes: "abc" + assert_eq!(&metadata[6..9], b"abc"); + } + + #[test] + fn test_normal_object() { + let value = json!({ + "a": 1, + "b": 2, + "c": 3 + }); + + let metadata = create_metadata(&value, false).unwrap(); + + // Header: version=1, sorted_strings=0, offset_size=1 (1 byte) + assert_eq!(metadata[0], 0x01); + + // Dictionary size: 3 keys + assert_eq!(metadata[1], 3); + + // Offsets: [0, 1, 2, 3] (1 byte each) + assert_eq!(metadata[2], 0); // First offset + assert_eq!(metadata[3], 1); // Second offset + assert_eq!(metadata[4], 2); // Third offset + assert_eq!(metadata[5], 3); // One-past-the-end offset + + // Dictionary bytes: "abc" + assert_eq!(&metadata[6..9], b"abc"); + } + + #[test] + fn test_complex_object() { + let value = json!({ + "first_name": "John", + "last_name": "Smith", + 
"email": "john.smith@example.com" + }); + + let metadata = create_metadata(&value, false).unwrap(); + + // Header: version=1, sorted_strings=0, offset_size=1 (1 byte) + assert_eq!(metadata[0], 0x01); + + // Dictionary size: 3 keys + assert_eq!(metadata[1], 3); + + // Offsets: [0, 5, 15, 24] (1 byte each) + assert_eq!(metadata[2], 0); // First offset for "email" + assert_eq!(metadata[3], 5); // Second offset for "first_name" + assert_eq!(metadata[4], 15); // Third offset for "last_name" + assert_eq!(metadata[5], 24); // One-past-the-end offset + + // Dictionary bytes: "emailfirst_namelast_name" + assert_eq!(&metadata[6..30], b"emailfirst_namelast_name"); + } + + #[test] + fn test_nested_object() { + let value = json!({ + "a": { + "b": 1, + "c": 2 + }, + "d": 3 + }); + + let metadata = create_metadata(&value, false).unwrap(); + + // Header: version=1, sorted_strings=0, offset_size=1 (1 byte) + assert_eq!(metadata[0], 0x01); + + // Dictionary size: 4 keys (a, b, c, d) + assert_eq!(metadata[1], 4); + + // Offsets: [0, 1, 2, 3, 4] (1 byte each) + assert_eq!(metadata[2], 0); // First offset + assert_eq!(metadata[3], 1); // Second offset + assert_eq!(metadata[4], 2); // Third offset + assert_eq!(metadata[5], 3); // Fourth offset + assert_eq!(metadata[6], 4); // One-past-the-end offset + + // Dictionary bytes: "abcd" + assert_eq!(&metadata[7..11], b"abcd"); + } + + #[test] + fn test_nested_array() { + let value = json!({ + "a": [1, 2, 3], + "b": 4 + }); + + let metadata = create_metadata(&value, false).unwrap(); + + // Header: version=1, sorted_strings=0, offset_size=1 (1 byte) + assert_eq!(metadata[0], 0x01); + + // Dictionary size: 2 keys (a, b) + assert_eq!(metadata[1], 2); + + // Offsets: [0, 1, 2] (1 byte each) + assert_eq!(metadata[2], 0); // First offset + assert_eq!(metadata[3], 1); // Second offset + assert_eq!(metadata[4], 2); // One-past-the-end offset + + // Dictionary bytes: "ab" + assert_eq!(&metadata[5..7], b"ab"); + } + + #[test] + fn test_complex_nested() 
{ + let value = json!({ + "a": { + "b": [1, 2, 3], + "c": 4 + }, + "d": 5 + }); + + let metadata = create_metadata(&value, false).unwrap(); + + // Header: version=1, sorted_strings=0, offset_size=1 (1 byte) + assert_eq!(metadata[0], 0x01); + + // Dictionary size: 4 keys (a, b, c, d) + assert_eq!(metadata[1], 4); + + // Offsets: [0, 1, 2, 3, 4] (1 byte each) + assert_eq!(metadata[2], 0); // First offset + assert_eq!(metadata[3], 1); // Second offset + assert_eq!(metadata[4], 2); // Third offset + assert_eq!(metadata[5], 3); // Fourth offset + assert_eq!(metadata[6], 4); // One-past-the-end offset + + // Dictionary bytes: "abcd" + assert_eq!(&metadata[7..11], b"abcd"); + } + + #[test] + fn test_sorted_keys() { + let value = json!({ + "c": 3, + "a": 1, + "b": 2 + }); + + let metadata = create_metadata(&value, true).unwrap(); + + // Header: version=1, sorted_strings=1, offset_size=1 (1 byte) + assert_eq!(metadata[0], 0x11); + + // Dictionary size: 3 keys + assert_eq!(metadata[1], 3); + + // Offsets: [0, 1, 2, 3] (1 byte each) + assert_eq!(metadata[2], 0); // First offset + assert_eq!(metadata[3], 1); // Second offset + assert_eq!(metadata[4], 2); // Third offset + assert_eq!(metadata[5], 3); // One-past-the-end offset + + // Dictionary bytes: "abc" (sorted) + assert_eq!(&metadata[6..9], b"abc"); + } + + #[test] + fn test_sorted_complex_object() { + let value = json!({ + "first_name": "John", + "email": "john.smith@example.com", + "last_name": "Smith" + }); + + let metadata = create_metadata(&value, true).unwrap(); + + // Header: version=1, sorted_strings=1, offset_size=1 (1 byte) + assert_eq!(metadata[0], 0x11); + + // Dictionary size: 3 keys + assert_eq!(metadata[1], 3); + + // Offsets: [0, 5, 15, 24] (1 byte each) + assert_eq!(metadata[2], 0); // First offset for "email" + assert_eq!(metadata[3], 5); // Second offset for "first_name" + assert_eq!(metadata[4], 15); // Third offset for "last_name" + assert_eq!(metadata[5], 24); // One-past-the-end offset + + // 
Dictionary bytes: "emailfirst_namelast_name" + assert_eq!(&metadata[6..30], b"emailfirst_namelast_name"); + } +} \ No newline at end of file diff --git a/arrow-variant/src/reader/mod.rs b/arrow-variant/src/reader/mod.rs new file mode 100644 index 000000000000..b4b64a9051c0 --- /dev/null +++ b/arrow-variant/src/reader/mod.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Reading JSON and converting to Variant +//! 
+use arrow_array::{Array, StructArray}; +use arrow_schema::extension::Variant; +use serde_json::Value; +use crate::error::Error; +use crate::metadata::{create_metadata, parse_metadata}; +use crate::encoder::encode_json; +use crate::variant_utils::create_variant_array; +#[allow(unused_imports)] +use crate::decoder::decode_value; +#[allow(unused_imports)] +use std::collections::HashMap; + +/// Converts a JSON string to a Variant +/// +/// # Example +/// +/// ``` +/// use arrow_variant::from_json; +/// +/// let json_str = r#"{"name": "John", "age": 30, "city": "New York"}"#; +/// let variant = from_json(json_str).unwrap(); +/// +/// // Access variant metadata and value +/// println!("Metadata length: {}", variant.metadata().len()); +/// println!("Value length: {}", variant.value().len()); +/// ``` +pub fn from_json(json_str: &str) -> Result { + // Parse the JSON string + let value: Value = serde_json::from_str(json_str)?; + + // Use the value-based function + from_json_value(&value) +} + +/// Converts an array of JSON strings to a StructArray with variant extension type +/// +/// # Example +/// +/// ``` +/// use arrow_variant::from_json_array; +/// use arrow_array::array::Array; +/// +/// let json_strings = vec![ +/// r#"{"name": "John", "age": 30}"#, +/// r#"{"name": "Jane", "age": 28}"#, +/// ]; +/// +/// let variant_array = from_json_array(&json_strings).unwrap(); +/// assert_eq!(variant_array.len(), 2); +/// ``` +pub fn from_json_array(json_strings: &[&str]) -> Result { + if json_strings.is_empty() { + return Err(Error::EmptyInput); + } + + // Parse each JSON string to a Value + let values: Result, _> = json_strings + .iter() + .map(|json_str| serde_json::from_str::(json_str).map_err(Error::from)) + .collect(); + + // Convert the values to a StructArray with variant extension type + from_json_value_array(&values?) 
+} + +/// Converts a JSON Value object directly to a Variant +/// +/// # Example +/// +/// ``` +/// use arrow_variant::from_json_value; +/// use serde_json::json; +/// +/// let value = json!({"name": "John", "age": 30, "city": "New York"}); +/// let variant = from_json_value(&value).unwrap(); +/// +/// // Access variant metadata and value +/// println!("Metadata length: {}", variant.metadata().len()); +/// println!("Value length: {}", variant.value().len()); +/// ``` +pub fn from_json_value(value: &Value) -> Result { + // Create metadata from the JSON value + let metadata = create_metadata(value, false)?; + + // Parse the metadata to get a key-to-id mapping + let key_mapping = parse_metadata(&metadata)?; + + // Encode the JSON value to binary format + let value_bytes = encode_json(value, &key_mapping)?; + + // Create the Variant with metadata and value + Ok(Variant::new(metadata, value_bytes)) +} + +/// Converts an array of JSON Value objects to a StructArray with variant extension type +/// +/// # Example +/// +/// ``` +/// use arrow_variant::from_json_value_array; +/// use serde_json::json; +/// use arrow_array::array::Array; +/// +/// let values = vec![ +/// json!({"name": "John", "age": 30}), +/// json!({"name": "Jane", "age": 28}), +/// ]; +/// +/// let variant_array = from_json_value_array(&values).unwrap(); +/// assert_eq!(variant_array.len(), 2); +/// ``` +pub fn from_json_value_array(values: &[Value]) -> Result { + if values.is_empty() { + return Err(Error::EmptyInput); + } + + // Convert each JSON value to a Variant + let variants: Result, _> = values + .iter() + .map(|value| from_json_value(value)) + .collect(); + + let variants = variants?; + + // Create a StructArray with the variants + create_variant_array(variants) + .map_err(|e| Error::VariantArrayCreation(e)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::variant_utils::get_variant; + + #[test] + fn test_from_json() { + let json_str = r#"{"name": "John", "age": 30}"#; + let variant = 
from_json(json_str).unwrap(); + + // Verify the metadata has the expected keys + assert!(!variant.metadata().is_empty()); + + // Verify the value is not empty + assert!(!variant.value().is_empty()); + + // Verify the first byte is an object header + // Object type (2) with default sizes + assert_eq!(variant.value()[0], 0b00000010); + } + + #[test] + fn test_from_json_array() { + let json_strings = vec![ + r#"{"name": "John", "age": 30}"#, + r#"{"name": "Jane", "age": 28}"#, + ]; + + let variant_array = from_json_array(&json_strings).unwrap(); + + // Verify array length + assert_eq!(variant_array.len(), 2); + + // Verify the values are properly encoded + for i in 0..variant_array.len() { + let variant = get_variant(&variant_array, i).unwrap(); + assert!(!variant.value().is_empty()); + // First byte should be an object header + assert_eq!(variant.value()[0], 0b00000010); + } + } + + #[test] + fn test_from_json_error() { + let invalid_json = r#"{"name": "John", "age": }"#; // Missing value + let result = from_json(invalid_json); + assert!(result.is_err()); + } + + #[test] + fn test_complex_json() { + let json_str = r#"{ + "name": "John", + "age": 30, + "active": true, + "scores": [85, 90, 78], + "address": { + "street": "123 Main St", + "city": "Anytown", + "zip": 12345 + }, + "tags": ["developer", "rust"] + }"#; + + let variant = from_json(json_str).unwrap(); + + // Verify the metadata has the expected keys + assert!(!variant.metadata().is_empty()); + + // Verify the value is not empty + assert!(!variant.value().is_empty()); + + // Verify the first byte is an object header + // Object type (2) with default sizes + assert_eq!(variant.value()[0], 0b00000010); + } +} \ No newline at end of file diff --git a/arrow-variant/src/variant_utils.rs b/arrow-variant/src/variant_utils.rs new file mode 100644 index 000000000000..3191fda027e8 --- /dev/null +++ b/arrow-variant/src/variant_utils.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// 
or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for working with Variant as a StructArray + +use arrow_array::{Array, ArrayRef, BinaryArray, StructArray}; +use arrow_array::builder::BinaryBuilder; +use arrow_schema::{ArrowError, DataType, Field}; +use arrow_schema::extension::Variant; +use std::sync::Arc; + +/// Validate that a struct array can be used as a variant array +pub fn validate_struct_array(array: &StructArray) -> Result<(), ArrowError> { + // Check that the struct has both metadata and value fields + let fields = array.fields(); + + if fields.len() != 2 { + return Err(ArrowError::InvalidArgumentError( + "Variant struct must have exactly two fields".to_string(), + )); + } + + let metadata_field = fields + .iter() + .find(|f| f.name() == "metadata") + .ok_or_else(|| { + ArrowError::InvalidArgumentError( + "Variant struct must have a field named 'metadata'".to_string(), + ) + })?; + + let value_field = fields + .iter() + .find(|f| f.name() == "value") + .ok_or_else(|| { + ArrowError::InvalidArgumentError( + "Variant struct must have a field named 'value'".to_string(), + ) + })?; + + // Check field types + match (metadata_field.data_type(), value_field.data_type()) { + (DataType::Binary, DataType::Binary) | (DataType::LargeBinary, DataType::LargeBinary) => { 
+ Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Variant struct fields must both be Binary or LargeBinary".to_string(), + )), + } +} + +/// Extract a Variant object from a struct array at the given index +pub fn get_variant(array: &StructArray, index: usize) -> Result { + // Verify index is valid + if index >= array.len() { + return Err(ArrowError::InvalidArgumentError( + "Index out of bounds".to_string(), + )); + } + + // Skip if null + if array.is_null(index) { + return Err(ArrowError::InvalidArgumentError( + "Cannot extract variant from null value".to_string(), + )); + } + + // Get metadata and value columns + let metadata_array = array + .column_by_name("metadata") + .ok_or_else(|| ArrowError::InvalidArgumentError("Missing metadata field".to_string()))?; + + let value_array = array + .column_by_name("value") + .ok_or_else(|| ArrowError::InvalidArgumentError("Missing value field".to_string()))?; + + // Extract binary data + let metadata = extract_binary_data(metadata_array, index)?; + let value = extract_binary_data(value_array, index)?; + + Ok(Variant::new(metadata, value)) +} + +/// Extract binary data from a binary array at the specified index +fn extract_binary_data(array: &ArrayRef, index: usize) -> Result, ArrowError> { + match array.data_type() { + DataType::Binary => { + let binary_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::InvalidArgumentError("Failed to downcast binary array".to_string()) + })?; + Ok(binary_array.value(index).to_vec()) + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported binary type: {}", + array.data_type() + ))), + } +} + +/// Create a variant struct array from a collection of variants +pub fn create_variant_array( + variants: Vec +) -> Result { + if variants.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Cannot create variant array from empty variants".to_string(), + )); + } + + // Create binary builders for metadata and value + let mut 
metadata_builder = BinaryBuilder::new(); + let mut value_builder = BinaryBuilder::new(); + + // Add variants to builders + for variant in &variants { + metadata_builder.append_value(variant.metadata()); + value_builder.append_value(variant.value()); + } + + // Create arrays + let metadata_array = metadata_builder.finish(); + let value_array = value_builder.finish(); + + // Create fields + let fields = vec![ + Field::new("metadata", DataType::Binary, false), + Field::new("value", DataType::Binary, false), + ]; + + // Create arrays vector + let arrays: Vec = vec![Arc::new(metadata_array), Arc::new(value_array)]; + + // Build struct array + let struct_array = StructArray::try_new(fields.into(), arrays, None)?; + + Ok(struct_array) +} + +/// Create an empty variant struct array with given capacity +pub fn create_empty_variant_array(capacity: usize) -> Result { + // Create binary builders for metadata and value + let mut metadata_builder = BinaryBuilder::with_capacity(capacity, 0); + let mut value_builder = BinaryBuilder::with_capacity(capacity, 0); + + // Create arrays + let metadata_array = metadata_builder.finish(); + let value_array = value_builder.finish(); + + // Create fields + let fields = vec![ + Field::new("metadata", DataType::Binary, false), + Field::new("value", DataType::Binary, false), + ]; + + // Create arrays vector + let arrays: Vec = vec![Arc::new(metadata_array), Arc::new(value_array)]; + + // Build struct array + StructArray::try_new(fields.into(), arrays, None) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::Array; + use crate::metadata::create_test_metadata; + + #[test] + fn test_variant_array_creation() { + // Create metadata and value for each variant + let metadata = create_test_metadata(); + + // Create variants with different values + let variants = vec![ + Variant::new(metadata.clone(), b"null".to_vec()), + Variant::new(metadata.clone(), b"true".to_vec()), + Variant::new(metadata.clone(), b"{\"a\": 1}".to_vec()), + ]; + + 
// Create a VariantArray + let variant_array = create_variant_array(variants.clone()).unwrap(); + + // Access variants from the array + assert_eq!(variant_array.len(), 3); + + let retrieved = get_variant(&variant_array, 0).unwrap(); + assert_eq!(retrieved.metadata(), &metadata); + assert_eq!(retrieved.value(), b"null"); + + let retrieved = get_variant(&variant_array, 1).unwrap(); + assert_eq!(retrieved.metadata(), &metadata); + assert_eq!(retrieved.value(), b"true"); + } + + #[test] + fn test_validate_struct_array() { + // Create metadata and value for each variant + let metadata = create_test_metadata(); + + // Create variants with different values + let variants = vec![ + Variant::new(metadata.clone(), b"null".to_vec()), + Variant::new(metadata.clone(), b"true".to_vec()), + ]; + + // Create a VariantArray + let variant_array = create_variant_array(variants.clone()).unwrap(); + + // Validate it + assert!(validate_struct_array(&variant_array).is_ok()); + } + + #[test] + fn test_get_variant_error() { + // Create an empty array + let empty_array = create_empty_variant_array(0).unwrap(); + + // Should error when trying to get a variant from an empty array + let result = get_variant(&empty_array, 0); + assert!(result.is_err()); + } +} \ No newline at end of file diff --git a/arrow-variant/src/writer/mod.rs b/arrow-variant/src/writer/mod.rs new file mode 100644 index 000000000000..c258878b9b6d --- /dev/null +++ b/arrow-variant/src/writer/mod.rs @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Writing Variant data to JSON + +use arrow_array::{Array, StructArray}; +use arrow_schema::extension::Variant; +use serde_json::Value; +use crate::error::Error; +use crate::decoder::decode_json; +use crate::variant_utils::get_variant; + +/// Converts a Variant to a JSON Value +/// +/// # Examples +/// +/// ``` +/// use arrow_variant::reader::from_json; +/// use arrow_variant::writer::to_json_value; +/// use serde_json::json; +/// +/// let json_str = r#"{"name":"John","age":30}"#; +/// let variant = from_json(json_str).unwrap(); +/// let value = to_json_value(&variant).unwrap(); +/// assert_eq!(value, json!({"name":"John","age":30})); +/// ``` +pub fn to_json_value(variant: &Variant) -> Result { + // Decode the variant binary data to a JSON value + decode_json(variant.value(), variant.metadata()) +} + +/// Converts a StructArray with variant extension type to an array of JSON Values +/// +/// # Example +/// +/// ``` +/// use arrow_variant::{from_json_array, to_json_value_array}; +/// use serde_json::json; +/// +/// let json_strings = vec![ +/// r#"{"name": "John", "age": 30}"#, +/// r#"{"name": "Jane", "age": 28}"#, +/// ]; +/// +/// let variant_array = from_json_array(&json_strings).unwrap(); +/// let values = to_json_value_array(&variant_array).unwrap(); +/// assert_eq!(values, vec![ +/// json!({"name": "John", "age": 30}), +/// json!({"name": "Jane", "age": 28}) +/// ]); +/// ``` +pub fn to_json_value_array(variant_array: &StructArray) -> Result, Error> { + let mut result = Vec::with_capacity(variant_array.len()); + for i in 
0..variant_array.len() { + if variant_array.is_null(i) { + result.push(Value::Null); + continue; + } + + let variant = get_variant(variant_array, i) + .map_err(|e| Error::VariantRead(e.to_string()))?; + result.push(to_json_value(&variant)?); + } + Ok(result) +} + +/// Converts a Variant to a JSON string +/// +/// # Examples +/// +/// ``` +/// use arrow_variant::reader::from_json; +/// use arrow_variant::writer::to_json; +/// +/// let json_str = r#"{"name":"John","age":30}"#; +/// let variant = from_json(json_str).unwrap(); +/// let result = to_json(&variant).unwrap(); +/// assert_eq!(serde_json::to_string_pretty(&serde_json::from_str::(json_str).unwrap()).unwrap(), +/// serde_json::to_string_pretty(&serde_json::from_str::(&result).unwrap()).unwrap()); +/// ``` +pub fn to_json(variant: &Variant) -> Result { + // Use the value-based function and convert to string + let value = to_json_value(variant)?; + Ok(value.to_string()) +} + +/// Converts a StructArray with variant extension type to an array of JSON strings +/// +/// # Example +/// +/// ``` +/// use arrow_variant::{from_json_array, to_json_array}; +/// +/// let json_strings = vec![ +/// r#"{"name": "John", "age": 30}"#, +/// r#"{"name": "Jane", "age": 28}"#, +/// ]; +/// +/// let variant_array = from_json_array(&json_strings).unwrap(); +/// let json_array = to_json_array(&variant_array).unwrap(); +/// +/// // Note that the output JSON strings may have different formatting +/// // but they are semantically equivalent +/// ``` +pub fn to_json_array(variant_array: &StructArray) -> Result, Error> { + // Use the value-based function and convert each value to a string + to_json_value_array(variant_array).map(|values| + values.into_iter().map(|v| v.to_string()).collect() + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::reader::from_json; + use serde_json::json; + + #[test] + fn test_to_json() { + let json_str = r#"{"name":"John","age":30}"#; + let variant = from_json(json_str).unwrap(); + + let result = 
to_json(&variant).unwrap(); + + // Parse both to Value to compare them structurally + let original: Value = serde_json::from_str(json_str).unwrap(); + let result_value: Value = serde_json::from_str(&result).unwrap(); + + assert_eq!(original, result_value); + } + + #[test] + fn test_to_json_array() { + let json_strings = vec![ + r#"{"name":"John","age":30}"#, + r#"{"name":"Jane","age":28}"#, + ]; + + // Create variant array from JSON strings + let variant_array = crate::reader::from_json_array(&json_strings).unwrap(); + + // Convert back to JSON + let result = to_json_array(&variant_array).unwrap(); + + // Verify the result + assert_eq!(result.len(), 2); + + // Parse both to Value to compare them structurally + for (i, (original, result)) in json_strings.iter().zip(result.iter()).enumerate() { + let original_value: Value = serde_json::from_str(original).unwrap(); + let result_value: Value = serde_json::from_str(result).unwrap(); + + assert_eq!( + original_value, + result_value, + "JSON values at index {} should be equal", + i + ); + } + } + + #[test] + fn test_roundtrip() { + let complex_json = json!({ + "array": [1, 2, 3], + "nested": {"a": true, "b": null}, + "string": "value" + }); + + let complex_str = complex_json.to_string(); + + let variant = from_json(&complex_str).unwrap(); + let json = to_json(&variant).unwrap(); + + // Parse both to Value to compare them structurally + let original: Value = serde_json::from_str(&complex_str).unwrap(); + let result: Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(original, result); + } + + #[test] + fn test_special_characters() { + // Test with JSON containing special characters + let special_json = json!({ + "unicode": "こんにちは世界", // Hello world in Japanese + "escaped": "Line 1\nLine 2\t\"quoted\"", + "emoji": "🚀🌟⭐" + }); + + let special_str = special_json.to_string(); + + let variant = from_json(&special_str).unwrap(); + let json = to_json(&variant).unwrap(); + + // Parse both to Value to compare them 
structurally + let original: Value = serde_json::from_str(&special_str).unwrap(); + let result: Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(original, result); + } +} \ No newline at end of file diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 1d2737a0c629..1d8e39379e41 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -87,7 +87,7 @@ tokio = { version = "1.0", default-features = false, features = ["macros", "rt-m rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"] } object_store = { version = "0.12.0", default-features = false, features = ["azure", "fs"] } sysinfo = { version = "0.34.0", default-features = false, features = ["system"] } - +arrow-variant = { path = "../arrow-variant", features = ["default"] } [package.metadata.docs.rs] all-features = true @@ -98,7 +98,7 @@ lz4 = ["lz4_flex"] # Enable arrow reader/writer APIs arrow = ["base64", "arrow-array", "arrow-buffer", "arrow-cast", "arrow-data", "arrow-schema", "arrow-select", "arrow-ipc"] # Enable support for arrow canonical extension types -arrow_canonical_extension_types = ["arrow-schema?/canonical_extension_types"] +arrow_canonical_extension_types = ["arrow-schema?/canonical_extension_types", "arrow-array?/canonical_extension_types"] # Enable CLI tools cli = ["json", "base64", "clap", "arrow-csv", "serde"] # Enable JSON APIs diff --git a/parquet/src/arrow/array_reader/builder.rs b/parquet/src/arrow/array_reader/builder.rs index 945f62526a7e..d0e38a72e4ea 100644 --- a/parquet/src/arrow/array_reader/builder.rs +++ b/parquet/src/arrow/array_reader/builder.rs @@ -27,7 +27,7 @@ use crate::arrow::array_reader::{ FixedSizeListArrayReader, ListArrayReader, MapArrayReader, NullArrayReader, PrimitiveArrayReader, RowGroups, StructArrayReader, }; -use crate::arrow::schema::{ParquetField, ParquetFieldType}; +use crate::arrow::schema::{parquet_to_arrow_field, ParquetField, ParquetFieldType}; use crate::arrow::ProjectionMask; use 
crate::basic::Type as PhysicalType; use crate::data_type::{BoolType, DoubleType, FloatType, Int32Type, Int64Type, Int96Type}; diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index 8bbe175dafb8..a945bd86b191 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -4583,4 +4583,255 @@ mod tests { assert_eq!(c0.len(), c1.len()); c0.iter().zip(c1.iter()).for_each(|(l, r)| assert_eq!(l, r)); } + + #[test] + #[cfg(feature = "arrow_canonical_extension_types")] + fn test_variant_roundtrip() -> Result<()> { + // This test demonstrates a 9x2 table with: + // - Column 1 (variant_data): Variant type containing different JSON-like values + // - Column 2 (int_data): Simple integers from 100 to 900 + // + // Table structure: + // | Variant | Int | + // |--------------------------------|----------| + // | null | 100 | + // | true | 200 | + // | false | 300 | + // | 12 | 400 | + // | -9876543210 | 500 | + // | 4.5678E123 | 600 | + // | "string value" | 700 | + // | {"a":1,"b":{"e":-4,"f":5.5}} | 800 | + // | [1,-2,4.5,-6.7,"str",true] | 900 | + + use arrow_array::{Int32Array, RecordBatch, Array, StructArray}; + use arrow_schema::{DataType, Field, Schema}; + use arrow_schema::extension::Variant; + use arrow_variant::variant_utils::{create_variant_array, get_variant}; + use bytes::Bytes; + use std::sync::Arc; + use crate::arrow::arrow_writer::ArrowWriter; + + // Value metadata - needs to follow the spec format + let value_metadata = vec![0x01, 0x01, 0x00, 0x03, b'k', b'e', b'y']; // [header, size, offsets, "key"] + let variant_type = Variant::new(value_metadata.clone(), vec![]); + let sample_json_values = vec![ + "null", + "true", + "false", + "12", + "-9876543210", + "4.5678E123", + "\"string value\"", + "{\"a\": 1, \"b\": {\"e\": -4, \"f\": 5.5}, \"c\": true}", + "[1, -2, 4.5, -6.7, \"str\", true]"]; + + let original_variants: Vec = sample_json_values + .iter() + .map(|json| 
Variant::new(value_metadata.clone(), json.as_bytes().to_vec())) + .collect(); + + // Create variant struct array from variants + let variant_array = create_variant_array(original_variants.clone()) + .expect("Failed to create variant array"); + + let int_array = Int32Array::from(vec![100, 200, 300, 400, 500, 600, 700, 800, 900]); + + // Use the fields from variant_array to create the struct type + let struct_type = DataType::Struct(variant_array.fields().clone()); + let fields = variant_array.fields(); + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name(), "metadata"); + assert_eq!(fields[0].data_type(), &DataType::Binary); + assert_eq!(fields[1].name(), "value"); + assert_eq!(fields[1].data_type(), &DataType::Binary); + + // Add extension type information with try_with_extension_type + let mut variant_field = Field::new("variant_data", struct_type, false); + variant_field.try_with_extension_type(variant_type.clone()).unwrap(); + println!("Variant field: {:#?}", variant_field); + + let schema = Schema::new(vec![ + variant_field, + Field::new("int_data", DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(variant_array), Arc::new(int_array)] + )?; + + let mut buffer = Vec::with_capacity(1024); + let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), None)?; + writer.write(&batch)?; + writer.close()?; + + let builder = ParquetRecordBatchReaderBuilder::try_new(Bytes::from(buffer.clone()))?; + let mut reader = builder.build()?; + let out = reader.next().unwrap()?; + + let schema = out.schema(); + let field = schema.field(0).clone(); + + assert!(field.metadata().contains_key("ARROW:extension:name")); + assert_eq!(field.metadata().get("ARROW:extension:name").unwrap(), "arrow.variant"); + + // Get the struct array from the output + let output_variant_array = out.column(0).as_any().downcast_ref::().unwrap(); + + // Verify each variant + for i in 0..original_variants.len() { + let variant = 
get_variant(output_variant_array, i).expect("Failed to get variant"); + assert_eq!(variant.metadata(), original_variants[i].metadata()); + assert_eq!(variant.value(), original_variants[i].value()); + } + + let int_array = out.column(1).as_any().downcast_ref::().unwrap(); + for i in 0..9 { + assert_eq!(int_array.value(i), (i as i32 + 1) * 100); + } + + Ok(()) + } + + + // #[test] + // #[cfg(feature = "arrow_canonical_extension_types")] + // fn test_read_unshredded_variant() -> Result<(), Box> { + // use arrow_array::{Array, StructArray, BinaryArray}; + // use arrow_schema::extension::Variant; + // use bytes::Bytes; + // use std::fs::File; + // use std::io::Read; + // use std::sync::Arc; + // use crate::arrow::arrow_reader::ParquetRecordBatchReader; + // use crate::util::test_common::file_util::get_test_file; + // use arrow_variant::writer::to_json; + // use arrow_variant::reader::from_json; + // // Get the primitive.parquet test file + // // The file contains id integer and var variant with variant being integer 34 and unshredded + // let test_file = get_test_file("primitive.parquet"); + + // let mut reader = ParquetRecordBatchReader::try_new(test_file, 1024)?; + // let batch = reader.next().expect("Expected to read a batch")?.clone(); + + // println!("Batch schema: {:#?}", batch.schema()); + // println!("Batch rows: {}", batch.num_rows()); + + // let variant_col = batch.column_by_name("var") + // .expect("Column 'var' not found in Parquet file"); + + // println!("Variant column type: {:#?}", variant_col.data_type()); + + // let struct_array = variant_col.as_any().downcast_ref::() + // .expect("Expected variant column to be a struct array"); + + // let metadata_field = struct_array.column_by_name("metadata") + // .expect("metadata field not found in variant column"); + // let value_field = struct_array.column_by_name("value") + // .expect("value field not found in variant column"); + + // let metadata_binary = metadata_field.as_any().downcast_ref::() + // 
.expect("Expected metadata to be a binary array"); + // let value_binary = value_field.as_any().downcast_ref::() + // .expect("Expected value to be a binary array"); + + // let metadata = metadata_binary.value(0); + // let value = value_binary.value(0); + + // let variant = Variant::new(metadata.to_vec(), value.to_vec()); + // println!("Metadata bytes: {:?}", variant.metadata()); + // println!("Value bytes: {:?}", variant.value()); + // let json_str = to_json(&variant)?; + // println!("JSON: {}", json_str); + + // assert_eq!(json_str, "34", "Expected JSON value to be 34, got {}", json_str); + + // let json_value = "34"; + // let variant_from_json = from_json(&json_value)?; + + // println!("Metadata bytes: {:?}", variant_from_json.metadata()); + // println!("Value bytes: {:?}", variant_from_json.value()); + + // assert_eq!(variant.metadata(), variant_from_json.metadata(), "Metadata bytes do not match"); + // assert_eq!(variant.value(), variant_from_json.value(), "Value bytes do not match"); + + // Ok(()) + // } + + #[test] + #[cfg(feature = "arrow_canonical_extension_types")] + fn test_json_variant_parquet_roundtrip() -> Result<()> { + use arrow_array::{RecordBatch, Array, StructArray}; + use arrow_schema::{DataType, Field, Schema}; + use arrow_schema::extension::Variant; + use bytes::Bytes; + use std::sync::Arc; + use crate::arrow::arrow_writer::ArrowWriter; + use arrow_variant::reader::from_json_value_array; + use arrow_variant::variant_utils::get_variant; + use serde_json::{json, Value}; + + // Create sample JSON values + let json_values = vec![ + json!(null), + json!(42), + json!(-123.456), + json!(true), + json!("hello world"), + json!({"name": "Alice", "age": 30, "active": true}), + json!([1, 2, 3, {"key": "value"}]), + json!({"nested": {"a": 1, "b": [true, false]}, "list": [1, 2, 3]}) + ]; + + // Convert JSON values to StructArray with variant extension type + let variant_array = from_json_value_array(&json_values) + .expect("Failed to create StructArray from 
JSON values"); + + // Create schema with variant field + let struct_type = variant_array.data_type().clone(); + let mut variant_field = Field::new("json_data", struct_type, true); + + // Create a Variant instance for extension type + let variant_type = Variant::new(vec![], vec![]); + variant_field.try_with_extension_type(variant_type).unwrap(); + + let schema = Schema::new(vec![variant_field]); + + // Create record batch + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(variant_array)] + )?; + + // Write to parquet + let mut buffer = Vec::with_capacity(1024); + let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), None)?; + writer.write(&batch)?; + writer.close()?; + + // Read back from parquet + let builder = ParquetRecordBatchReaderBuilder::try_new(Bytes::from(buffer))?; + let mut reader = builder.build()?; + let result_batch = reader.next().unwrap()?; + + // Verify the schema + let schema = result_batch.schema(); + let field = schema.field(0).clone(); + + assert!(field.metadata().contains_key("ARROW:extension:name")); + assert_eq!(field.metadata().get("ARROW:extension:name").unwrap(), "arrow.variant"); + + // Get the struct array from the output + let output_variant_array = result_batch.column(0).as_any().downcast_ref::().unwrap(); + + // Verify each variant + for i in 0..json_values.len() { + let variant = get_variant(output_variant_array, i).expect("Failed to get variant"); + assert!(!variant.metadata().is_empty(), "Variant metadata should not be empty"); + assert!(!variant.value().is_empty(), "Variant value should not be empty"); + } + + Ok(()) + } } diff --git a/parquet/src/arrow/arrow_writer/byte_array.rs b/parquet/src/arrow/arrow_writer/byte_array.rs index 2d23ad8510f9..7a5ffc8744a0 100644 --- a/parquet/src/arrow/arrow_writer/byte_array.rs +++ b/parquet/src/arrow/arrow_writer/byte_array.rs @@ -591,4 +591,4 @@ where max = max.max(val); } Some((min.as_ref().to_vec().into(), max.as_ref().to_vec().into())) -} +} \ No newline 
at end of file diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 89c42f5eaf92..87e9299b92ff 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use arrow_ipc::writer; #[cfg(feature = "arrow_canonical_extension_types")] -use arrow_schema::extension::{Json, Uuid}; +use arrow_schema::extension::{Json, Uuid, Variant}; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ @@ -396,8 +396,15 @@ pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result ret.try_with_extension_type(Uuid)?, - LogicalType::Json => ret.try_with_extension_type(Json::default())?, + LogicalType::Uuid => ret = ret.with_extension_type(Uuid), + LogicalType::Json => ret = ret.with_extension_type(Json::default()), + LogicalType::Variant { specification_version } => { + // For Variant type, we need to create a struct with two binary fields + let metadata_field = Field::new("metadata", DataType::Binary, false); + let value_field = Field::new("value", DataType::Binary, false); + let struct_type = DataType::Struct(Fields::from(vec![metadata_field, value_field])); + ret = Field::new(parquet_column.name(), struct_type, field.nullable).with_extension_type(Variant::new(vec![], vec![])); + } _ => {} } } @@ -709,6 +716,31 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { if fields.is_empty() { return Err(arrow_err!("Parquet does not support writing empty structs",)); } + + #[cfg(feature = "arrow_canonical_extension_types")] + if let Ok(variant) = field.try_extension_type::() { + // Verify we have a struct with exactly two fields + if let DataType::Struct(fields) = field.data_type() { + let metadata_field = Type::primitive_type_builder("metadata", PhysicalType::BYTE_ARRAY) + .with_repetition(Repetition::REQUIRED) + .build()?; + let value_field = Type::primitive_type_builder("value", PhysicalType::BYTE_ARRAY) + 
.with_repetition(Repetition::REQUIRED) + .build()?; + + return Ok(Type::group_type_builder(name) + .with_fields(vec![Arc::new(metadata_field), Arc::new(value_field)]) + .with_logical_type(Some(LogicalType::Variant { + specification_version: None, + })) + .with_repetition(repetition) + .with_id(id) + .build()?); + } else { + return Err(arrow_err!("Variant data type must be a Struct, found {:?}", field.data_type())); + } + } + // recursively convert children to types/nodes let fields = fields .iter() @@ -2260,4 +2292,107 @@ mod tests { Ok(()) } + + #[test] + #[cfg(feature = "arrow_canonical_extension_types")] + fn arrow_variant_to_parquet_variant() -> Result<()> { + // Create a sample Variant for testing + let metadata = vec![1, 2, 3]; + let value = vec![4, 5, 6]; + + // Create Arrow schema with a Struct field that has Variant extension type + let struct_fields = vec![ + Field::new("metadata", DataType::Binary, false), + Field::new("value", DataType::Binary, false) + ]; + let field = Field::new("variant", DataType::Struct(struct_fields.into()), false) + .with_extension_type(Variant::new(metadata.clone(), value.clone())); + + let arrow_schema = Schema::new(vec![field]); + + // Convert Arrow schema to Parquet schema + let parquet_schema = ArrowSchemaConverter::new().convert(&arrow_schema)?; + + // Get the parquet schema and compare against expected schema + let message_type = parquet_schema.root_schema(); + + // The variant group should be a child of the root + let variant_field = &message_type.get_fields()[0]; + + // Check that it's a group + assert!(variant_field.is_group()); + + // Check logical type + assert_eq!( + variant_field.get_basic_info().logical_type(), + Some(LogicalType::Variant { + specification_version: None, + }) + ); + + // Check the fields + let fields = variant_field.get_fields(); + assert_eq!(fields.len(), 2); + + // Check metadata field + assert_eq!(fields[0].name(), "metadata"); + assert_eq!(fields[0].get_physical_type(), 
PhysicalType::BYTE_ARRAY); + + // Check value field + assert_eq!(fields[1].name(), "value"); + assert_eq!(fields[1].get_physical_type(), PhysicalType::BYTE_ARRAY); + + Ok(()) + } + + #[test] + #[cfg(feature = "arrow_canonical_extension_types")] + fn parquet_variant_to_arrow() -> Result<()> { + // Create a Parquet schema with Variant type + let metadata = vec![1, 2, 3]; + let value = vec![4, 5, 6]; + + // Build the Parquet schema manually + let metadata_field = Type::primitive_type_builder("metadata", PhysicalType::BYTE_ARRAY) + .with_repetition(Repetition::REQUIRED) + .build()?; + + let value_field = Type::primitive_type_builder("value", PhysicalType::BYTE_ARRAY) + .with_repetition(Repetition::REQUIRED) + .build()?; + + let variant_field = Type::group_type_builder("variant") + .with_fields(vec![Arc::new(metadata_field), Arc::new(value_field)]) + .with_logical_type(Some(LogicalType::Variant { + specification_version: None, + })) + .with_repetition(Repetition::REQUIRED) + .build()?; + + let message_type = Type::group_type_builder("schema") + .with_fields(vec![Arc::new(variant_field)]) + .build()?; + + let schema_descriptor = SchemaDescriptor::new(Arc::new(message_type)); + + // Get both columns (metadata and value) + let metadata_column = schema_descriptor.column(0); + let value_column = schema_descriptor.column(1); + + // Convert each column to Arrow field + let metadata_arrow_field = parquet_to_arrow_field(&metadata_column)?; + let value_arrow_field = parquet_to_arrow_field(&value_column)?; + + // Verify the fields + assert_eq!(metadata_arrow_field.name(), "metadata"); + assert_eq!(metadata_arrow_field.data_type(), &DataType::Binary); + assert!(!metadata_arrow_field.is_nullable()); + + assert_eq!(value_arrow_field.name(), "value"); + assert_eq!(value_arrow_field.data_type(), &DataType::Binary); + assert!(!value_arrow_field.is_nullable()); + + Ok(()) + } + } diff --git a/parquet/src/arrow/schema/primitive.rs b/parquet/src/arrow/schema/primitive.rs index 
f1fed8f2a557..4106b68fdadf 100644 --- a/parquet/src/arrow/schema/primitive.rs +++ b/parquet/src/arrow/schema/primitive.rs @@ -334,4 +334,4 @@ fn from_fixed_len_byte_array( } _ => Ok(DataType::FixedSizeBinary(type_length)), } -} +} \ No newline at end of file diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index 99f122fe4c3e..8956c246a354 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -29,7 +29,7 @@ use crate::errors::{ParquetError, Result}; // Re-export crate::format types used in this module pub use crate::format::{ BsonType, DateType, DecimalType, EnumType, IntType, JsonType, ListType, MapType, NullType, - StringType, TimeType, TimeUnit, TimestampType, UUIDType, + StringType, TimeType, TimeUnit, TimestampType, UUIDType,VariantType }; // ---------------------------------------------------------------------- @@ -228,6 +228,11 @@ pub enum LogicalType { Uuid, /// A 16-bit floating point number. Float16, + /// A variant type. + Variant { + /// The version of the variant specification that the variant was written with. + specification_version: Option, + }, } // ---------------------------------------------------------------------- @@ -579,6 +584,7 @@ impl ColumnOrder { LogicalType::Unknown => SortOrder::UNDEFINED, LogicalType::Uuid => SortOrder::UNSIGNED, LogicalType::Float16 => SortOrder::SIGNED, + LogicalType::Variant { .. 
} => SortOrder::UNDEFINED, // TODO: consider variant sort order }, // Fall back to converted type None => Self::get_converted_sort_order(converted_type, physical_type), @@ -804,7 +810,7 @@ impl From for Option { ConvertedType::INT_64 => Some(parquet::ConvertedType::INT_64), ConvertedType::JSON => Some(parquet::ConvertedType::JSON), ConvertedType::BSON => Some(parquet::ConvertedType::BSON), - ConvertedType::INTERVAL => Some(parquet::ConvertedType::INTERVAL), + ConvertedType::INTERVAL => Some(parquet::ConvertedType::INTERVAL) } } } @@ -841,6 +847,9 @@ impl From for LogicalType { parquet::LogicalType::BSON(_) => LogicalType::Bson, parquet::LogicalType::UUID(_) => LogicalType::Uuid, parquet::LogicalType::FLOAT16(_) => LogicalType::Float16, + parquet::LogicalType::VARIANT(t) => LogicalType::Variant { + specification_version: t.specification_version, + }, } } } @@ -882,6 +891,11 @@ impl From for parquet::LogicalType { LogicalType::Bson => parquet::LogicalType::BSON(Default::default()), LogicalType::Uuid => parquet::LogicalType::UUID(Default::default()), LogicalType::Float16 => parquet::LogicalType::FLOAT16(Default::default()), + LogicalType::Variant { specification_version } => parquet::LogicalType::VARIANT(VariantType { + specification_version: Some(0), + }), + + } } } @@ -933,7 +947,8 @@ impl From> for ConvertedType { LogicalType::Bson => ConvertedType::BSON, LogicalType::Uuid | LogicalType::Float16 | LogicalType::Unknown => { ConvertedType::NONE - } + }, + LogicalType::Variant { .. 
} => ConvertedType::NONE, }, None => ConvertedType::NONE, } @@ -1182,6 +1197,9 @@ impl str::FromStr for LogicalType { "Interval parquet logical type not yet supported" )), "FLOAT16" => Ok(LogicalType::Float16), + "VARIANT" => Ok(LogicalType::Variant { + specification_version: Some(0), + }), other => Err(general_err!("Invalid parquet logical type {}", other)), } } @@ -1315,7 +1333,7 @@ mod tests { assert_eq!(ConvertedType::JSON.to_string(), "JSON"); assert_eq!(ConvertedType::BSON.to_string(), "BSON"); assert_eq!(ConvertedType::INTERVAL.to_string(), "INTERVAL"); - assert_eq!(ConvertedType::DECIMAL.to_string(), "DECIMAL") + assert_eq!(ConvertedType::DECIMAL.to_string(), "DECIMAL"); } #[test] @@ -1416,7 +1434,7 @@ mod tests { assert_eq!( ConvertedType::try_from(Some(parquet::ConvertedType::DECIMAL)).unwrap(), ConvertedType::DECIMAL - ) + ); } #[test] @@ -1511,7 +1529,7 @@ mod tests { assert_eq!( Some(parquet::ConvertedType::DECIMAL), ConvertedType::DECIMAL.into() - ) + ); } #[test] @@ -1683,7 +1701,7 @@ mod tests { .parse::() .unwrap(), ConvertedType::DECIMAL - ) + ); } #[test] @@ -1835,6 +1853,12 @@ mod tests { ConvertedType::from(Some(LogicalType::Unknown)), ConvertedType::NONE ); + assert_eq!( + ConvertedType::from(Some(LogicalType::Variant { + specification_version: Some(0), + })), + ConvertedType::NONE + ); } #[test] @@ -2218,7 +2242,13 @@ mod tests { check_sort_order(signed, SortOrder::SIGNED); // Undefined comparison - let undefined = vec![LogicalType::List, LogicalType::Map]; + let undefined = vec![ + LogicalType::List, + LogicalType::Map, + LogicalType::Variant { + specification_version: Some(0), + }, + ]; check_sort_order(undefined, SortOrder::UNDEFINED); } diff --git a/parquet/src/format.rs b/parquet/src/format.rs index 287d08b7a95c..05ec9bc51b03 100644 --- a/parquet/src/format.rs +++ b/parquet/src/format.rs @@ -1834,6 +1834,69 @@ impl crate::thrift::TSerializable for BsonType { } } +// +// VariantType +// + +/// Embedded Variant logical type annotation 
+#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct VariantType { + pub specification_version: Option, +} + +impl VariantType { + pub fn new(specification_version: F1) -> VariantType where F1: Into> { + VariantType { + specification_version: specification_version.into(), + } + } + pub fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { + i_prot.read_struct_begin()?; + let mut f_1: Option = None; + loop { + let field_ident = i_prot.read_field_begin()?; + if field_ident.field_type == TType::Stop { + break; + } + let field_id = field_id(&field_ident)?; + match field_id { + 1 => { + let val = i_prot.read_i8()?; + f_1 = Some(val); + }, + _ => { + i_prot.skip(field_ident.field_type)?; + }, + }; + i_prot.read_field_end()?; + } + i_prot.read_struct_end()?; + let ret = VariantType { + specification_version: f_1, + }; + Ok(ret) + } + pub fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { + let struct_ident = TStructIdentifier::new("VariantType"); + o_prot.write_struct_begin(&struct_ident)?; + if let Some(fld_var) = self.specification_version { + o_prot.write_field_begin(&TFieldIdentifier::new("specification_version", TType::I08, 1))?; + o_prot.write_i8(fld_var)?; + o_prot.write_field_end()? 
+    }
+    o_prot.write_field_stop()?;
+    o_prot.write_struct_end()
+  }
+}
+
+impl Default for VariantType {
+  fn default() -> Self {
+    VariantType{
+      specification_version: Some(0),
+    }
+  }
+}
 
 //
 // LogicalType
 //
@@ -1854,6 +1917,7 @@ pub enum LogicalType {
   BSON(BsonType),
   UUID(UUIDType),
   FLOAT16(Float16Type),
+  VARIANT(VariantType),
 }
 
 impl crate::thrift::TSerializable for LogicalType {
@@ -1966,6 +2030,13 @@
         }
         received_field_count += 1;
       },
+      16 => {
+        let val = VariantType::read_from_in_protocol(i_prot)?;
+        if ret.is_none() {
+          ret = Some(LogicalType::VARIANT(val));
+        }
+        received_field_count += 1;
+      },
       _ => {
         i_prot.skip(field_ident.field_type)?;
         received_field_count += 1;
@@ -2070,6 +2141,11 @@
         f.write_to_out_protocol(o_prot)?;
         o_prot.write_field_end()?;
       },
+      LogicalType::VARIANT(ref f) => {
+        o_prot.write_field_begin(&TFieldIdentifier::new("VARIANT", TType::Struct, 16))?;
+        f.write_to_out_protocol(o_prot)?;
+        o_prot.write_field_end()?;
+      }
     }
     o_prot.write_field_stop()?;
     o_prot.write_struct_end()
@@ -5478,4 +5554,3 @@ impl crate::thrift::TSerializable for FileCryptoMetaData {
     o_prot.write_struct_end()
   }
 }
-
diff --git a/parquet/src/schema/printer.rs b/parquet/src/schema/printer.rs
index 44c742fca66e..198af3fc0411 100644
--- a/parquet/src/schema/printer.rs
+++ b/parquet/src/schema/printer.rs
@@ -327,6 +327,7 @@ fn print_logical_and_converted(
             LogicalType::Map => "MAP".to_string(),
             LogicalType::Float16 => "FLOAT16".to_string(),
             LogicalType::Unknown => "UNKNOWN".to_string(),
+            LogicalType::Variant { .. } => "VARIANT".to_string(), // TODO: add support for variant
         },
         None => {
             // Also print converted type if it is available
diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs
index 68492e19f437..2d691a5b4315 100644
--- a/parquet/src/schema/types.rs
+++ b/parquet/src/schema/types.rs
@@ -413,6 +413,16 @@ impl<'a> PrimitiveTypeBuilder<'a> {
                     self.name
                 ))
             }
+            (LogicalType::Variant { .. }, PhysicalType::BYTE_ARRAY) => {
+            }
+            (LogicalType::Variant { .. }, _) => {
+                return Err(general_err!(
+                    "{:?} can only be applied to a BYTE_ARRAY type for field '{}'",
+                    logical_type,
+                    self.name
+                ));
+            }
+
             (a, b) => {
                 return Err(general_err!(
                     "Cannot annotate {:?} from {} for field '{}'",
@@ -1737,6 +1747,22 @@ mod tests {
             "Parquet error: UUID cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(16) field"
         );
     }
+
+        // TODO Test that Variant cannot be applied to primitive types
+        // result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
+        //     .with_repetition(Repetition::REQUIRED)
+        //     .with_logical_type(Some(LogicalType::Variant {
+        //         metadata: vec![1, 2, 3],
+        //         value: vec![0]
+        //     }))
+        //     .build();
+        // assert!(result.is_err());
+        // if let Err(e) = result {
+        //     assert_eq!(
+        //         format!("{e}"),
+        //         "Parquet error: Variant { metadata: [1, 2, 3], value: [0] } cannot be applied to a primitive type for field 'foo'"
+        //     );
+        // }
     }
 
     #[test]
@@ -1773,6 +1799,45 @@ mod tests {
         assert_eq!(tp.get_fields().len(), 2);
         assert_eq!(tp.get_fields()[0].name(), "f1");
         assert_eq!(tp.get_fields()[1].name(), "f2");
+
+
+        // Test Variant
+        let metadata = Type::primitive_type_builder("metadata", PhysicalType::BYTE_ARRAY)
+            .with_repetition(Repetition::REQUIRED)
+            .build()
+            .unwrap();
+
+        let value = Type::primitive_type_builder("value", PhysicalType::BYTE_ARRAY)
+            .with_repetition(Repetition::REQUIRED)
+            .build()
+            .unwrap();
+
+        let fields = vec![Arc::new(metadata), Arc::new(value)];
+        let result = Type::group_type_builder("variant")
+            .with_repetition(Repetition::OPTIONAL) // The whole variant is optional
+            .with_logical_type(Some(LogicalType::Variant {
+                specification_version: Some(0),
+            }))
+            .with_fields(fields)
+            .with_id(Some(2))
+            .build();
+        assert!(result.is_ok());
+
+        let tp = result.unwrap();
+        let basic_info = tp.get_basic_info();
+        assert!(tp.is_group());
+        assert!(!tp.is_primitive());
+        assert_eq!(basic_info.repetition(), Repetition::OPTIONAL);
+        assert_eq!(
+            basic_info.logical_type(),
+            Some(LogicalType::Variant {
+                specification_version: Some(0),
+            })
+        );
+        assert_eq!(basic_info.id(), 2);
+        assert_eq!(tp.get_fields().len(), 2);
+        assert_eq!(tp.get_fields()[0].name(), "metadata");
+        assert_eq!(tp.get_fields()[1].name(), "value");
     }
 
     #[test]
@@ -1855,13 +1920,31 @@ mod tests {
             .build()?;
         fields.push(Arc::new(bag));
 
+        // Add a Variant type
+        let metadata = Type::primitive_type_builder("metadata", PhysicalType::BYTE_ARRAY)
+            .with_repetition(Repetition::REQUIRED)
+            .build()?;
+
+        let value = Type::primitive_type_builder("value", PhysicalType::BYTE_ARRAY)
+            .with_repetition(Repetition::REQUIRED)
+            .build()?;
+
+        let variant = Type::group_type_builder("variant")
+            .with_repetition(Repetition::OPTIONAL)
+            .with_logical_type(Some(LogicalType::Variant {
+                specification_version: Some(0),
+            }))
+            .with_fields(vec![Arc::new(metadata), Arc::new(value)])
+            .build()?;
+        fields.push(Arc::new(variant));
+
         let schema = Type::group_type_builder("schema")
             .with_repetition(Repetition::REPEATED)
             .with_fields(fields)
             .build()?;
         let descr = SchemaDescriptor::new(Arc::new(schema));
 
-        let nleaves = 6;
+        let nleaves = 8;
         assert_eq!(descr.num_columns(), nleaves);
 
         // mdef mrep
@@ -1872,9 +1955,13 @@ mod tests {
         // repeated group records 2 1
         // required int64 item1 2 1
         // optional boolean item2 3 1
-        // repeated int32 item3 3 2
-        let ex_max_def_levels = [0, 1, 1, 2, 3, 3];
-        let ex_max_rep_levels = [0, 0, 1, 1, 1, 2];
+        // repeated int32 item3 3 2
+        // optional group variant 1 0
+        // required byte_array metadata 1 0
+        // required byte_array value 1 0
+
+        let ex_max_def_levels = [0, 1, 1, 2, 3, 3, 1, 1];
+        let ex_max_rep_levels = [0, 0, 1, 1, 1, 2, 0, 0];
 
         for i in 0..nleaves {
             let col = descr.column(i);
@@ -1888,11 +1975,16 @@ mod tests {
         assert_eq!(descr.column(3).path().string(), "bag.records.item1");
         assert_eq!(descr.column(4).path().string(), "bag.records.item2");
         assert_eq!(descr.column(5).path().string(), "bag.records.item3");
+        assert_eq!(descr.column(6).path().string(), "variant.metadata");
+        assert_eq!(descr.column(7).path().string(), "variant.value");
 
         assert_eq!(descr.get_column_root(0).name(), "a");
         assert_eq!(descr.get_column_root(3).name(), "bag");
         assert_eq!(descr.get_column_root(4).name(), "bag");
         assert_eq!(descr.get_column_root(5).name(), "bag");
+        assert_eq!(descr.get_column_root(6).name(), "variant");
+        assert_eq!(descr.get_column_root(7).name(), "variant");
+
         Ok(())
     }
 
@@ -2341,4 +2433,20 @@ mod tests {
         let result_schema = from_thrift(&thrift_schema).unwrap();
         assert_eq!(result_schema, Arc::new(expected_schema));
     }
+
+    #[test]
+    fn test_schema_type_thrift_conversion_variant() {
+        let message_type = "
+    message variant_test {
+        OPTIONAL group variant_field (VARIANT) {
+            REQUIRED BYTE_ARRAY metadata;
+            REQUIRED BYTE_ARRAY value;
+        }
+    }
+    ";
+        let expected_schema = parse_message_type(message_type).unwrap();
+        let thrift_schema = to_thrift(&expected_schema).unwrap();
+        let result_schema = from_thrift(&thrift_schema).unwrap();
+        assert_eq!(result_schema, Arc::new(expected_schema));
+    }
 }