Add Map support to arrow-avro #7451

Status: Open · wants to merge 3 commits into main · showing changes from all commits.

44 changes: 40 additions & 4 deletions arrow-avro/src/codec.rs
@@ -17,7 +17,7 @@

use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName};
use arrow_schema::{
-    ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit,
+    ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit,
};
use std::borrow::Cow;
use std::collections::HashMap;
@@ -45,6 +45,19 @@ pub struct AvroDataType {
}

impl AvroDataType {
/// Create a new [`AvroDataType`] with the given parts.
pub fn new(
codec: Codec,
metadata: HashMap<String, String>,
nullability: Option<Nullability>,
) -> Self {
AvroDataType {
codec,
metadata,
nullability,
}
}

/// Returns an arrow [`Field`] with the given name
pub fn field_with_name(&self, name: &str) -> Field {
let d = self.codec.data_type();
@@ -162,6 +175,8 @@ pub enum Codec {
List(Arc<AvroDataType>),
/// Represents Avro record type, maps to Arrow's Struct data type
Struct(Arc<[AvroField]>),
/// Represents Avro map type, maps to Arrow's Map data type
Map(Arc<AvroDataType>),
/// Represents Avro duration logical type, maps to Arrow's Interval(IntervalUnit::MonthDayNano) data type
Interval,
}
@@ -192,6 +207,22 @@ impl Codec {
DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME)))
}
Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()),
Self::Map(value_type) => {
let val_dt = value_type.codec.data_type();
let val_field = Field::new("value", val_dt, value_type.nullability.is_some())
.with_metadata(value_type.metadata.clone());
DataType::Map(
Arc::new(Field::new(
"entries",
DataType::Struct(Fields::from(vec![
Field::new("key", DataType::Utf8, false),
val_field,
])),
false,
)),
false,
)
}
}
}
}
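
For reference, a minimal sketch (an editor's example, not part of the diff) of the Arrow type this arm produces for an Avro map with plain string values:

    let value = AvroDataType::new(Codec::Utf8, Default::default(), None);
    let dt = Codec::Map(Arc::new(value)).data_type();
    // Map(Field("entries", Struct<key: Utf8, value: Utf8>, nullable = false), false)
    // Keys are always non-null Utf8 per Arrow's map layout; the value child is
    // non-nullable here because the inner AvroDataType carries no nullability,
    // and the trailing false means keys are not guaranteed sorted.
    assert!(matches!(dt, DataType::Map(_, false)));
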
@@ -321,9 +352,14 @@ fn make_data_type<'a>(
ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!(
"Enum of {e:?} not currently supported"
))),
-            ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!(
-                "Map of {m:?} not currently supported"
-            ))),
ComplexType::Map(m) => {
let val = make_data_type(&m.values, namespace, resolver)?;
Ok(AvroDataType {
nullability: None,
                // [inline review thread on this line]
                // klion26 (Member): Do we need to set the nullability to val.nullability?
                // Author: @klion26 I was attempting to follow the same pattern used in
                // the other types, such as ComplexType::Record(r), ComplexType::Array(a),
                // and ComplexType::Fixed(f). It seemed from reading the code that
                // nullability would only be set in the Schema::Union(f) branch of
                // make_data_type.
metadata: m.attributes.field_metadata(),
codec: Codec::Map(Arc::new(val)),
})
}
},
Schema::Type(t) => {
let mut field =
            // ... (remainder of make_data_type and the rest of codec.rs unchanged)
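
On the review question above: a minimal sketch (an editor's illustration, not code from this PR) of how nullability reaches this arm when the map's values are a union. For an Avro schema {"type": "map", "values": ["null", "string"]}, make_data_type resolves the value schema first, so the nullability lands on the inner value type rather than on the map itself:

    // Hypothetical stand-in for what the Schema::Union branch would produce
    // for ["null", "string"] map values.
    let value = AvroDataType::new(Codec::Utf8, Default::default(), Some(Nullability::NullFirst));
    let dt = Codec::Map(Arc::new(value)).data_type();
    // The entries' "value" child is now nullable, because Codec::data_type()
    // checks value_type.nullability.is_some(); the map itself stays non-nullable.
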
198 changes: 197 additions & 1 deletion arrow-avro/src/reader/record.rs
@@ -26,6 +26,7 @@ use arrow_buffer::*;
use arrow_schema::{
ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef,
};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::io::Read;
use std::sync::Arc;
@@ -94,6 +95,13 @@ enum Decoder {
String(OffsetBufferBuilder<i32>, Vec<u8>),
List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
Record(Fields, Vec<Decoder>),
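    /// Avro map: the Arrow entries field, key offsets, per-map entry-count
    /// offsets, raw UTF-8 key bytes, and the value decoder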
Map(
FieldRef,
OffsetBufferBuilder<i32>,
OffsetBufferBuilder<i32>,
Vec<u8>,
Box<Decoder>,
),
Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
}

@@ -145,6 +153,25 @@ impl Decoder {
}
Self::Record(arrow_fields.into(), encodings)
}
Codec::Map(child) => {
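                // The entries' "value" child is always marked nullable here;
                // flush() below builds the entries StructArray with the same
                // layout so the schema and the arrays agree.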
let val_field = child.field_with_name("value").with_nullable(true);
let map_field = Arc::new(ArrowField::new(
"entries",
DataType::Struct(Fields::from(vec![
ArrowField::new("key", DataType::Utf8, false),
val_field,
])),
false,
));
let val_dec = Self::try_new(child)?;
Self::Map(
map_field,
OffsetBufferBuilder::new(DEFAULT_CAPACITY),
OffsetBufferBuilder::new(DEFAULT_CAPACITY),
Vec::with_capacity(DEFAULT_CAPACITY),
Box::new(val_dec),
)
}
};

Ok(match data_type.nullability() {
@@ -175,6 +202,9 @@ impl Decoder {
e.append_null();
}
Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
Self::Map(_, _koff, moff, _, _) => {
moff.push_length(0);
}
Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
}
}
@@ -208,6 +238,15 @@ impl Decoder {
encoding.decode(buf)?;
}
}
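            // An Avro map is written as a sequence of blocks terminated by a
            // zero count; read_map_blocks returns the total entry count so the
            // map offsets record this row's length.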
Self::Map(_, koff, moff, kdata, valdec) => {
let newly_added = read_map_blocks(buf, |cur| {
let kb = cur.get_bytes()?;
koff.push_length(kb.len());
kdata.extend_from_slice(kb);
valdec.decode(cur)
})?;
moff.push_length(newly_added);
}
Self::Nullable(nullability, nulls, e) => {
let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst);
nulls.append(is_valid);
@@ -245,7 +284,6 @@ impl Decoder {
),
Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)),
Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)),

Self::Binary(offsets, values) => {
let offsets = flush_offsets(offsets);
let values = flush_values(values).into();
@@ -268,10 +306,89 @@
.collect::<Result<Vec<_>, _>>()?;
Arc::new(StructArray::new(fields.clone(), arrays, nulls))
}
Self::Map(map_field, k_off, m_off, kdata, valdec) => {
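                // Rebuild the entries from the accumulated key bytes/offsets and
                // the flushed value array; the map offsets then delimit each row.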
let moff = flush_offsets(m_off);
let koff = flush_offsets(k_off);
let kd = flush_values(kdata).into();
let val_arr = valdec.flush(None)?;
let key_arr = StringArray::new(koff, kd, None);
if key_arr.len() != val_arr.len() {
return Err(ArrowError::InvalidArgumentError(format!(
"Map keys length ({}) != map values length ({})",
key_arr.len(),
val_arr.len()
)));
}
let final_len = moff.len() - 1;
if let Some(n) = &nulls {
if n.len() != final_len {
return Err(ArrowError::InvalidArgumentError(format!(
"Map array null buffer length {} != final map length {final_len}",
n.len()
)));
}
}
let entries_struct = StructArray::new(
Fields::from(vec![
Arc::new(ArrowField::new("key", DataType::Utf8, false)),
Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)),
]),
vec![Arc::new(key_arr), val_arr],
None,
);
let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false);
Arc::new(map_arr)
}
})
}
}

fn read_map_blocks(
buf: &mut AvroCursor,
decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
) -> Result<usize, ArrowError> {
read_blockwise_items(buf, true, decode_entry)
}

fn read_blockwise_items(
buf: &mut AvroCursor,
read_size_after_negative: bool,
mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
) -> Result<usize, ArrowError> {
let mut total = 0usize;
loop {
        // Read the block count:
        //   positive -> that many items follow
        //   negative -> abs(count) items follow, preceded by the block's
        //               size in bytes (when read_size_after_negative is set)
        // See: https://avro.apache.org/docs/1.11.1/specification/#maps
let block_count = buf.get_long()?;
match block_count.cmp(&0) {
Ordering::Equal => break,
Ordering::Less => {
// If block_count is negative, read the absolute value of count,
// then read the block size as a long and discard
let count = (-block_count) as usize;
if read_size_after_negative {
let _size_in_bytes = buf.get_long()?;
}
for _ in 0..count {
decode_fn(buf)?;
}
total += count;
}
Ordering::Greater => {
// If block_count is positive, decode that many items
let count = block_count as usize;
                for _ in 0..count {
decode_fn(buf)?;
}
total += count;
}
}
}
Ok(total)
}
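
// Worked example (editor's note, not part of this PR): a single-entry map
// {"hello": "world"}, written with a negative block count, decodes through
// read_blockwise_items as:
//
//   long(-1)        negative count: one entry follows
//   long(12)        block size in bytes, read and discarded
//   bytes("hello")  key   (varint length 5, then 5 bytes)
//   bytes("world")  value (varint length 5, then 5 bytes)
//   long(0)         end of blocks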

#[inline]
fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
    std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
}

@@ -291,3 +408,82 @@ fn flush_primitive<T: ArrowPrimitiveType>(
}

const DEFAULT_CAPACITY: usize = 1024;

#[cfg(test)]
mod tests {
use super::*;
    use arrow_array::{Array, MapArray, StringArray, StructArray};

    fn encode_avro_long(value: i64) -> Vec<u8> {
        // ZigZag-encode, then emit as a little-endian base-128 varint. The cast
        // to u64 makes the shift below logical; an arithmetic shift on i64
        // would loop forever for negative inputs.
        let mut buf = Vec::new();
        let mut v = ((value << 1) ^ (value >> 63)) as u64;
        while v & !0x7F != 0 {
            buf.push(((v & 0x7F) | 0x80) as u8);
            v >>= 7;
        }
        buf.push(v as u8);
        buf
    }

fn encode_avro_bytes(bytes: &[u8]) -> Vec<u8> {
let mut buf = encode_avro_long(bytes.len() as i64);
buf.extend_from_slice(bytes);
buf
}

fn avro_from_codec(codec: Codec) -> AvroDataType {
AvroDataType::new(codec, Default::default(), None)
}

#[test]
fn test_map_decoding_one_entry() {
let value_type = avro_from_codec(Codec::Utf8);
let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
let mut decoder = Decoder::try_new(&map_type).unwrap();
// Encode a single map with one entry: {"hello": "world"}
let mut data = Vec::new();
data.extend_from_slice(&encode_avro_long(1));
data.extend_from_slice(&encode_avro_bytes(b"hello")); // key
data.extend_from_slice(&encode_avro_bytes(b"world")); // value
data.extend_from_slice(&encode_avro_long(0));
let mut cursor = AvroCursor::new(&data);
decoder.decode(&mut cursor).unwrap();
let array = decoder.flush(None).unwrap();
let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
assert_eq!(map_arr.len(), 1); // one map
assert_eq!(map_arr.value_length(0), 1);
let entries = map_arr.value(0);
let struct_entries = entries.as_any().downcast_ref::<StructArray>().unwrap();
assert_eq!(struct_entries.len(), 1);
let key_arr = struct_entries
.column_by_name("key")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let val_arr = struct_entries
.column_by_name("value")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(key_arr.value(0), "hello");
assert_eq!(val_arr.value(0), "world");
}

#[test]
fn test_map_decoding_empty() {
let value_type = avro_from_codec(Codec::Utf8);
let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
let mut decoder = Decoder::try_new(&map_type).unwrap();
let data = encode_avro_long(0);
decoder.decode(&mut AvroCursor::new(&data)).unwrap();
let array = decoder.flush(None).unwrap();
let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
assert_eq!(map_arr.len(), 1);
assert_eq!(map_arr.value_length(0), 0);
}
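
A natural follow-up test (an editor's sketch built on the helpers above, not part of this PR) would exercise the negative block-count path of read_blockwise_items inside the same module:

    #[test]
    fn test_map_decoding_negative_block_count() {
        let value_type = avro_from_codec(Codec::Utf8);
        let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
        let mut decoder = Decoder::try_new(&map_type).unwrap();
        // One block holding {"hello": "world"}: count -1, then the block's size
        // in bytes (6 per length-prefixed string), the entry, and the end marker.
        let mut data = Vec::new();
        data.extend_from_slice(&encode_avro_long(-1));
        data.extend_from_slice(&encode_avro_long(12));
        data.extend_from_slice(&encode_avro_bytes(b"hello"));
        data.extend_from_slice(&encode_avro_bytes(b"world"));
        data.extend_from_slice(&encode_avro_long(0));
        decoder.decode(&mut AvroCursor::new(&data)).unwrap();
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1);
        assert_eq!(map_arr.value_length(0), 1);
    }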
}