diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 9679c89b4807..21a1af62c51a 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -335,11 +335,13 @@ mod variable;
 /// └───────┴───────────────┴───────┴─────────┴───────┘
 /// ```
 ///
-/// ## List Encoding
+/// ## List and Map Encoding
 ///
 /// Lists are encoded by first encoding all child elements to the row format.
 ///
-/// A list value is then encoded as the concatenation of each of the child elements,
+/// The Map encoding is the same, the only difference being that the child elements are key-value pairs.
+///
+/// A list/map value is then encoded as the concatenation of each of the child elements,
 /// separately encoded using the variable length encoding described above, followed
 /// by the variable length encoding of an empty byte array.
 ///
@@ -547,6 +549,8 @@ enum Codec {
     Struct(RowConverter, OwnedRow),
     /// A row converter for the child field
     List(RowConverter),
+    /// A row converter for the entries of a map (the keys and values, not the wrapping struct)
+    Map(RowConverter),
     /// A row converter for the values array of a run-end encoded array
     RunEndEncoded(RowConverter),
     /// Row converters for each union field (indexed by field position)
@@ -639,6 +643,33 @@ impl Codec {
                 let converter = RowConverter::new(vec![field])?;
                 Ok(Self::List(converter))
             }
+            DataType::Map(f, _) => {
+                // The encoded contents will be inverted if descending is set to true.
+                // As such we set `descending` to false and negate nulls_first if it
+                // is set to true
+                let options = SortOptions {
+                    descending: false,
+                    nulls_first: sort_field.options.nulls_first != sort_field.options.descending,
+                };
+
+                let DataType::Struct(fields) = f.data_type() else {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "expected struct field in map, got {:?}",
+                        f.data_type()
+                    )));
+                };
+
+                // For Map we unwrap the intermediate struct type to avoid going through the
+                // Struct codec, improving performance
+                let fields = fields
+                    .iter()
+                    .map(|struct_field| {
+                        SortField::new_with_options(struct_field.data_type().clone(), options)
+                    })
+                    .collect::<Vec<_>>();
+                assert_eq!(fields.len(), 2);
+                let converter = RowConverter::new(fields)?;
+                Ok(Self::Map(converter))
+            }
             DataType::FixedSizeList(f, _) => {
                 let field = SortField::new_with_options(f.data_type().clone(), sort_field.options);
                 let converter = RowConverter::new(vec![field])?;
@@ -761,6 +792,22 @@ impl Codec {
                 let rows = converter.convert_columns(&[values])?;
                 Ok(Encoder::List(rows))
             }
+            Codec::Map(converter) => {
+                let map_array = as_map_array(array);
+
+                let first_offset = map_array.offsets()[0] as usize;
+                let last_offset = map_array.offsets()[map_array.offsets().len() - 1] as usize;
+
+                // entries can include more data than referenced in the MapArray, so only encode
+                // the referenced entries
+                let sliced_entries = map_array
+                    .entries()
+                    .slice(first_offset, last_offset - first_offset);
+
+                // the converter for the map operates on the keys and values, not the wrapping struct
+                let rows = converter.convert_columns(sliced_entries.columns())?;
+                Ok(Encoder::Map(rows))
+            }
             Codec::RunEndEncoded(converter) => {
                 let values = match array.data_type() {
                     DataType::RunEndEncoded(r, _) => match r.data_type() {
@@ -807,6 +854,7 @@ impl Codec {
             Codec::Dictionary(converter, nulls) => converter.size() + nulls.data.len(),
             Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(),
             Codec::List(converter) => converter.size(),
+            Codec::Map(converter) => converter.size(),
             Codec::RunEndEncoded(converter) => converter.size(),
             Codec::Union(converters, _, null_rows) => {
                 converters.iter().map(|c| c.size()).sum::<usize>()
@@ -830,6 +878,8 @@ enum Encoder<'a> {
     Struct(Rows, Row<'a>),
     /// The row encoding of the child array
     List(Rows),
+    /// The row encoding of the map entries
+    Map(Rows),
     /// The row encoding of the values array
     RunEndEncoded(Rows),
     /// The row encoding of each union field's child array, type_ids buffer, offsets buffer (for Dense), and mode
@@ -897,7 +947,8 @@ impl RowConverter {
             | DataType::LargeList(f)
             | DataType::ListView(f)
             | DataType::LargeListView(f)
-            | DataType::FixedSizeList(f, _) => Self::supports_datatype(f.data_type()),
+            | DataType::FixedSizeList(f, _)
+            | DataType::Map(f, _) => Self::supports_datatype(f.data_type()),
             DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())),
             DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()),
             DataType::Union(fs, _mode) => fs
@@ -1751,6 +1802,9 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker {
                 ),
                 _ => unreachable!(),
             },
+            Encoder::Map(rows) => {
+                list::compute_lengths(tracker.materialized(), rows, as_map_array(array))
+            }
             Encoder::RunEndEncoded(rows) => match array.data_type() {
                 DataType::RunEndEncoded(r, _) => match r.data_type() {
                     DataType::Int16 => run::compute_lengths(
@@ -1992,6 +2046,7 @@ fn encode_column(
                 }
                 _ => unreachable!(),
             },
+            Encoder::Map(rows) => list::encode(data, offsets, rows, opts, as_map_array(column)),
             Encoder::RunEndEncoded(rows) => match column.data_type() {
                 DataType::RunEndEncoded(r, _) => match r.data_type() {
                     DataType::Int16 => {
@@ -2140,12 +2195,12 @@ unsafe fn decode_column(
             Arc::new(StructArray::from(unsafe { builder.build_unchecked() }))
         }
         Codec::List(converter) => match &field.data_type {
-            DataType::List(_) => {
-                Arc::new(unsafe { list::decode::<i32>(converter, rows, field, validate_utf8) }?)
-            }
-            DataType::LargeList(_) => {
-                Arc::new(unsafe { list::decode::<i64>(converter, rows, field, validate_utf8) }?)
-            }
+            DataType::List(_) => Arc::new(unsafe {
+                list::decode::<GenericListArray<i32>>(converter, rows, field, validate_utf8)
+            }?),
+            DataType::LargeList(_) => Arc::new(unsafe {
+                list::decode::<GenericListArray<i64>>(converter, rows, field, validate_utf8)
+            }?),
             DataType::ListView(_) => Arc::new(unsafe {
                 list::decode_list_view::<i32>(converter, rows, field, validate_utf8)
             }?),
@@ -2163,6 +2218,9 @@ unsafe fn decode_column(
             }?),
             _ => unreachable!(),
         },
+        Codec::Map(converter) => {
+            Arc::new(unsafe { list::decode::<MapArray>(converter, rows, field, validate_utf8) }?)
+        }
         Codec::RunEndEncoded(converter) => match &field.data_type {
             DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
                 DataType::Int16 => Arc::new(unsafe {
@@ -4041,6 +4099,480 @@ mod tests {
         assert_eq!(&back[1], &second);
     }
 
+    #[test]
+    fn test_single_map() {
+        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
+        // Entry 0: {"hello": 1, "world": 2}
+        builder.keys().append_value("hello");
+        builder.values().append_value(1);
+        builder.keys().append_value("world");
+        builder.values().append_value(2);
+        builder.append(true).unwrap();
+        // Entry 1: {"foo": 3}
+        builder.keys().append_value("foo");
+        builder.values().append_value(3);
+        builder.append(true).unwrap();
+        // Entry 2: {} (empty map)
+        builder.append(true).unwrap();
+        // Entry 3: null (with masked data)
+        builder.keys().append_value("masked_key");
+        builder.values().append_value(999);
+        builder.append(false).unwrap();
+        // Entry 4: {"bar": null}
+        builder.keys().append_value("bar");
+        builder.values().append_null();
+        builder.append(true).unwrap();
+        // Entry 5: null (with different masked data)
+        builder.keys().append_value("other_masked");
+        builder.values().append_value(0);
+        builder.append(false).unwrap();
+        // Entry 6: {"a": 10, "b": 20, "c": 30}
+        builder.keys().append_value("a");
+        builder.values().append_value(10);
+        builder.keys().append_value("b");
+        builder.values().append_value(20);
+        builder.keys().append_value("c");
+        builder.values().append_value(30);
+        builder.append(true).unwrap();
+
+        let map = Arc::new(builder.finish()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+        assert_eq!(rows.row(3), rows.row(5)); // null = null (different masked values)
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+
+        let sliced_map = map.slice(1, 5);
+        let rows_on_sliced = converter
+            .convert_columns(&[Arc::clone(&sliced_map)])
+            .unwrap();
+
+        let back = converter.convert_rows(&rows_on_sliced).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &sliced_map);
+    }
+
+    #[test]
+    fn two_maps_with_different_keys_order_should_still_match() {
+        // { "hello": 1, "world": 2 }
+        let map_1 = {
+            let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
+
+            builder.keys().append_value("hello");
+            builder.values().append_value(1);
+
+            builder.keys().append_value("world");
+            builder.values().append_value(2);
+
+            builder.append(true).unwrap();
+
+            Arc::new(builder.finish()) as ArrayRef
+        };
+
+        // { "world": 2, "hello": 1 }
+        let map_2 = {
+            let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
+
+            builder.keys().append_value("world");
+            builder.values().append_value(2);
+
+            builder.keys().append_value("hello");
+            builder.values().append_value(1);
+
+            builder.append(true).unwrap();
+
+            Arc::new(builder.finish()) as ArrayRef
+        };
+
+        let converter = RowConverter::new(vec![SortField::new(map_1.data_type().clone())]).unwrap();
+
+        let map_1_rows = converter.convert_columns(&[Arc::clone(&map_1)]).unwrap();
+        let map_2_rows = converter.convert_columns(&[Arc::clone(&map_2)]).unwrap();
+
+        assert_eq!(map_1_rows.row(0), map_2_rows.row(0));
+
+        // TODO - what should the expected round-tripped array be?
+        // If two maps encode to the same rows they will decode to the same output;
+        // this should be documented if we decide to go that path
+    }
+
+    #[test]
+    fn test_nested_map() {
+        // Map<String, Map<String, Int32>>
+        let mut builder = MapBuilder::new(
+            None,
+            StringBuilder::new(),
+            MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()),
+        );
+
+        // Entry 0: {"outer1": {"inner_a": 1, "inner_b": 2}, "outer2": {"inner_c": 3}}
+        builder.keys().append_value("outer1");
+        builder.values().keys().append_value("inner_a");
+        builder.values().values().append_value(1);
+        builder.values().keys().append_value("inner_b");
+        builder.values().values().append_value(2);
+        builder.values().append(true).unwrap();
+        builder.keys().append_value("outer2");
+        builder.values().keys().append_value("inner_c");
+        builder.values().values().append_value(3);
+        builder.values().append(true).unwrap();
+        builder.append(true).unwrap();
+
+        // Entry 1: {"x": {}} (inner map is empty)
+        builder.keys().append_value("x");
+        builder.values().append(true).unwrap();
+        builder.append(true).unwrap();
+
+        // Entry 2: {"y": null} (inner map is null)
+        builder.keys().append_value("y");
+        builder.values().keys().append_value("masked"); // MASKED
+        builder.values().values().append_value(0); // MASKED
+        builder.values().append(false).unwrap();
+        builder.append(true).unwrap();
+
+        // Entry 3: null (outer map is null)
+        builder.keys().append_value("masked_outer"); // MASKED
+        builder.values().keys().append_value("masked_inner"); // MASKED
+        builder.values().values().append_value(0); // MASKED
+        builder.values().append(true).unwrap(); // MASKED
+        builder.append(false).unwrap();
+
+        // Entry 4: {} (outer map is empty)
+        builder.append(true).unwrap();
+
+        let map = Arc::new(builder.finish()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+
+        let sliced_map = map.slice(1, 3);
+        let rows_on_sliced = converter
+            .convert_columns(&[Arc::clone(&sliced_map)])
+            .unwrap();
+
+        let back = converter.convert_rows(&rows_on_sliced).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &sliced_map);
+    }
+
+    #[test]
+    fn test_single_map_with_non_nullable_keys() {
+        // Use `with_keys_field` on `MapBuilder` to mark the keys as non-nullable
+        let key_field = Arc::new(Field::new("keys", DataType::Utf8, false));
+        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new())
+            .with_keys_field(key_field);
+        // Entry 0: {"a": 1, "b": 2}
+        builder.keys().append_value("a");
+        builder.values().append_value(1);
+        builder.keys().append_value("b");
+        builder.values().append_value(2);
+        builder.append(true).unwrap();
+        // Entry 1: {"c": null}
+        builder.keys().append_value("c");
+        builder.values().append_null();
+        builder.append(true).unwrap();
+        // Entry 2: {}
+        builder.append(true).unwrap();
+        // Entry 3: null
+        builder.keys().append_value("masked"); // MASKED
+        builder.values().append_value(0); // MASKED
+        builder.append(false).unwrap();
+
+        let map = Arc::new(builder.finish()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
+    #[test]
+    fn test_single_with_non_nullable_values() {
+        // Use `with_values_field` on `MapBuilder` to mark the values as non-nullable
+        let value_field = Arc::new(Field::new("values", DataType::Int32, false));
+        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new())
+            .with_values_field(value_field);
+        // Entry 0: {"a": 1, "b": 2}
+        builder.keys().append_value("a");
+        builder.values().append_value(1);
+        builder.keys().append_value("b");
+        builder.values().append_value(2);
+        builder.append(true).unwrap();
+        // Entry 1: {"c": 3}
+        builder.keys().append_value("c");
+        builder.values().append_value(3);
+        builder.append(true).unwrap();
+        // Entry 2: {}
+        builder.append(true).unwrap();
+        // Entry 3: null
+        builder.keys().append_value("masked"); // MASKED
+        builder.values().append_value(0); // MASKED
+        builder.append(false).unwrap();
+
+        let map = Arc::new(builder.finish()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
+    #[test]
+    fn test_single_with_non_nullable_map_but_with_nullable_keys() {
+        let keys = Arc::new(StringArray::from(vec![
+            Some("a"),
+            None,
+            Some("c"),
+            Some("d"),
+            None,
+        ])) as ArrayRef;
+
+        let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef;
+
+        // Entry 0 = [0..2) -> {"a": 1, null: 2}
+        // Entry 1 = [2..3) -> {"c": 3}
+        // Entry 2 = [3..5) -> {"d": 4, null: 5}
+        let offsets = OffsetBuffer::new(vec![0, 2, 3, 5].into());
+
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, true)), // nullable keys
+            Arc::new(Field::new("values", DataType::Int32, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, values], None);
+
+        let map = Arc::new(MapArray::new(struct_field, offsets, entries, None, false)) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
+    #[test]
+    fn test_single_with_non_nullable_map_but_with_nullable_values() {
+        // Map column is non-nullable, but values are nullable
+        let value_field = Arc::new(Field::new("values", DataType::Int32, true));
+        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new())
+            .with_values_field(value_field);
+
+        // Entry 0: {"a": 1, "b": null}
+        builder.keys().append_value("a");
+        builder.values().append_value(1);
+        builder.keys().append_value("b");
+        builder.values().append_null();
+        builder.append(true).unwrap();
+        // Entry 1: {"c": null, "d": null}
+        builder.keys().append_value("c");
+        builder.values().append_null();
+        builder.keys().append_value("d");
+        builder.values().append_null();
+        builder.append(true).unwrap();
+        // Entry 2: {}
+        builder.append(true).unwrap();
+        // Entry 3: {"e": 5}
+        builder.keys().append_value("e");
+        builder.values().append_value(5);
+        builder.append(true).unwrap();
+
+        let map = Arc::new(builder.finish()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
+    #[test]
+    fn test_map_all_nulls() {
+        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
+        // All entries are null
+        builder.keys().append_value("m1"); // MASKED
+        builder.values().append_value(1); // MASKED
+        builder.append(false).unwrap();
+        builder.keys().append_value("m2"); // MASKED
+        builder.values().append_value(2); // MASKED
+        builder.append(false).unwrap();
+        builder.keys().append_value("m3"); // MASKED
+        builder.values().append_value(3); // MASKED
+        builder.append(false).unwrap();
+
+        let map = Arc::new(builder.finish()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        // All null rows should be equal
+        assert_eq!(rows.row(0), rows.row(1));
+        assert_eq!(rows.row(1), rows.row(2));
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
+    #[test]
+    fn test_map_all_empty() {
+        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
+        // All entries are empty maps
+        builder.append(true).unwrap();
+        builder.append(true).unwrap();
+        builder.append(true).unwrap();
+
+        let map = Arc::new(builder.finish()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        // All empty maps should be equal
+        assert_eq!(rows.row(0), rows.row(1));
+        assert_eq!(rows.row(1), rows.row(2));
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
+    #[test]
+    fn test_map_empty_array() {
+        // Zero-length map array
+        let builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
+        let map = Arc::new(builder.finish_cloned()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
+    #[test]
+    fn test_map_nullable_keys_and_nullable_values() {
+        // Both keys and values are nullable
+        let keys = Arc::new(StringArray::from(vec![
+            Some("a"),
+            None,
+            Some("c"),
+            Some("d"),
+        ])) as ArrayRef;
+
+        let values = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, None])) as ArrayRef;
+
+        // Entry 0 = [0..2) -> {"a": 1, null: 2}
+        // Entry 1 = [2..4) -> {"c": null, "d": null}
+        let offsets = OffsetBuffer::new(vec![0, 2, 4].into());
+
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, true)),
+            Arc::new(Field::new("values", DataType::Int32, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, values], None);
+
+        let map = Arc::new(MapArray::new(struct_field, offsets, entries, None, false)) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
+    #[test]
+    fn test_map_non_nullable_keys_and_non_nullable_values() {
+        // Both keys and values are non-nullable
+        let key_field = Arc::new(Field::new("keys", DataType::Utf8, false));
+        let value_field = Arc::new(Field::new("values", DataType::Int32, false));
+        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new())
+            .with_keys_field(key_field)
+            .with_values_field(value_field);
+
+        // Entry 0: {"a": 1, "b": 2}
+        builder.keys().append_value("a");
+        builder.values().append_value(1);
+        builder.keys().append_value("b");
+        builder.values().append_value(2);
+        builder.append(true).unwrap();
+        // Entry 1: {} (empty)
+        builder.append(true).unwrap();
+        // Entry 2: null
+        builder.keys().append_value("m"); // MASKED
+        builder.values().append_value(0); // MASKED
+        builder.append(false).unwrap();
+        // Entry 3: {"c": 3}
+        builder.keys().append_value("c");
+        builder.values().append_value(3);
+        builder.append(true).unwrap();
+
+        let map = Arc::new(builder.finish()) as ArrayRef;
+        let d = map.data_type().clone();
+
+        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&back[0], &map);
+    }
+
     fn generate_primitive_array<K>(
         rng: &mut impl RngCore,
         len: usize,
@@ -4055,6 +4587,51 @@ mod tests {
             .collect()
     }
 
+    #[derive(Clone, Copy)]
+    enum GenerateAllUniqueNullBehavior {
+        AllowUpToSingleNull { valid_percent: f64 },
+        AllValid,
+    }
+
+    fn generate_all_unique_primitive_array<K>(
+        rng: &mut impl RngCore,
+        len: usize,
+        null_behavior: GenerateAllUniqueNullBehavior,
+    ) -> PrimitiveArray<K>
+    where
+        K: ArrowPrimitiveType,
+        K::Native: Hash + Eq,
+        StandardUniform: Distribution<K::Native>,
+    {
+        let mut seen = std::collections::HashSet::new();
+        (0..len)
+            .map(|_| {
+                let mut value;
+
+                loop {
+                    match null_behavior {
+                        GenerateAllUniqueNullBehavior::AllValid => {
+                            value = Some(rng.random());
+                        }
+                        GenerateAllUniqueNullBehavior::AllowUpToSingleNull { valid_percent } => {
+                            value = if !seen.contains(&None) {
+                                rng.random_bool(valid_percent).then(|| rng.random())
+                            } else {
+                                Some(rng.random())
+                            };
+                        }
+                    }
+
+                    if seen.insert(value) {
+                        break;
+                    }
+                }
+
+                value
+            })
+            .collect()
+    }
+
     fn generate_boolean_array(
         rng: &mut impl RngCore,
         len: usize,
@@ -4262,6 +4839,76 @@ mod tests {
         )
     }
 
+    fn generate_map<R: RngCore, KeysFn, ValuesFn>(
+        rng: &mut R,
+        len: usize,
+        valid_percent: f64,
+        gen_keys: KeysFn,
+        gen_values: ValuesFn,
+    ) -> MapArray
+    where
+        KeysFn: FnOnce(&mut R, usize) -> ArrayRef,
+        ValuesFn: FnOnce(&mut R, usize) -> ArrayRef,
+    {
+        let offsets = OffsetBuffer::<i32>::from_lengths((0..len).map(|_| rng.random_range(0..10)));
+        let entries_len = offsets.last().unwrap().to_usize().unwrap();
+        let keys = gen_keys(rng, entries_len);
+        let values = gen_values(rng, entries_len);
+        let nulls = NullBuffer::from_iter((0..len).map(|_| rng.random_bool(valid_percent)));
+        let field = Arc::new(Field::new_map(
+            "",
+            "entries",
+            Field::new("keys", keys.data_type().clone(), true),
+            Field::new("values", values.data_type().clone(), true),
+            false,
+            true,
+        ));
+        let DataType::Map(struct_field, _) = field.data_type() else {
+            unreachable!();
+        };
+
+        let DataType::Struct(fields) = struct_field.data_type() else {
+            unreachable!();
+        };
+
+        let entries = StructArray::new(fields.clone(), vec![keys, values], None);
+
+        let map_array = MapArray::new(struct_field.clone(), offsets, entries, Some(nulls), false);
+
+        assert_valid_map(&map_array);
+
+        map_array
+    }
+
+    /// Assert that the map is valid; this includes unique map keys
+    ///
+    /// Can be removed once https://github.com/apache/arrow-rs/issues/9475 is resolved and `MapArray`
+    /// no longer allows creating invalid map arrays with duplicate keys.
+    ///
+    /// # Unique map keys
+    /// According to the [arrow specification][1], the keys of a map must be unique, so this asserts exactly that
+    ///
+    /// [1]: https://github.com/apache/arrow/blob/cbe2618431e413f12aa16aeba88b3a98914f194b/format/Schema.fbs#L124
+    fn assert_valid_map(array: &MapArray) {
+        let keys_arrow_row_converter =
+            RowConverter::new(vec![SortField::new(array.key_type().clone())]).unwrap();
+
+        array
+            .iter()
+            .enumerate()
+            .flat_map(|(index, entry)| entry.map(|entry| (index, Arc::clone(entry.column(0)))))
+            .for_each(|(entry_index, keys)| {
+                let keys_as_rows = keys_arrow_row_converter
+                    .convert_columns(&[Arc::clone(&keys)])
+                    .expect("should be able to convert keys");
+
+                for i in 0..keys_as_rows.num_rows() {
+                    for j in (i + 1)..keys_as_rows.num_rows() {
+                        if keys_as_rows.row(i) == keys_as_rows.row(j) {
+                            let key_i = keys.slice(i, 1);
+                            let key_j = keys.slice(j, 1);
+
+                            assert_ne!(
+                                keys_as_rows.row(i),
+                                keys_as_rows.row(j),
+                                "map keys should be unique, but key {i} and key {j} are equal in entry {entry_index}. key {i} value is {key_i:?} and key {j} value is {key_j:?}"
+                            );
+                        }
+                    }
+                }
+            })
+    }
+
     fn generate_nulls(rng: &mut impl RngCore, len: usize) -> Option<NullBuffer> {
         Some(NullBuffer::from_iter(
             (0..len).map(|_| rng.random_bool(0.8)),
@@ -4360,6 +5007,55 @@ mod tests {
         GenericListArray::<O>::new(field, new_offsets, new_values, nulls)
     }
 
+    fn change_underline_null_values_for_map_array(array: &MapArray) -> MapArray {
+        let (field, offsets, entries, nulls, ordered) = array.clone().into_parts();
+        assert!(
+            !ordered,
+            "can't replace underlying null values for an ordered map array as this can violate the ordering"
+        );
+
+        let (new_entries, new_offsets) = {
+            let concat_values = offsets
+                .windows(2)
+                .zip(nulls.as_ref().unwrap().iter())
+                .map(|(start_and_end, is_valid)| {
+                    let start = start_and_end[0].as_usize();
+                    let end = start_and_end[1].as_usize();
+                    if is_valid {
+                        return (start, end - start);
+                    }
+
+                    // If we reached the end, take one element less
+                    if end == entries.len() {
+                        (start, (end - start).saturating_sub(1))
+                    } else {
+                        // The keys may no longer be unique
+                        (start, end - start + 1)
+                    }
+                })
+                .map(|(start, length)| entries.slice(start, length))
+                .collect::<Vec<_>>();
+
+            let new_offsets = OffsetBuffer::from_lengths(concat_values.iter().map(|s| s.len()));
+
+            let new_values = {
+                let values = concat_values
+                    .iter()
+                    .map(|a| a as &dyn Array)
+                    .collect::<Vec<_>>();
+                arrow_select::concat::concat(&values).expect("should be able to concat")
+            };
+
+            (new_values.as_struct().clone(), new_offsets)
+        };
+
+        let new_map = MapArray::new(field, new_offsets, new_entries, nulls, ordered);
+
+        assert_valid_map(&new_map);
+
+        new_map
+    }
+
     fn change_underline_null_values(array: &ArrayRef) -> ArrayRef {
         if array.null_count() == 0 {
             return Arc::clone(array);
@@ -4390,6 +5086,9 @@ mod tests {
             DataType::LargeList(_) => {
                 Arc::new(change_underline_null_values_for_list_array(array.as_list::<i64>()))
             }
+            DataType::Map(_, _) => {
+                Arc::new(change_underline_null_values_for_map_array(array.as_map()))
+            }
             _ => {
                 Arc::clone(array)
             }
         }
     }
 
     fn generate_column(rng: &mut (impl RngCore + Clone), len: usize) -> ArrayRef {
-        match rng.random_range(0..23) {
+        match rng.random_range(0..25) {
             0 => Arc::new(generate_primitive_array::<Int32Type>(rng, len, 0.8)),
             1 => Arc::new(generate_primitive_array::<UInt32Type>(rng, len, 0.8)),
             2 => Arc::new(generate_primitive_array::<Int64Type>(rng, len, 0.8)),
@@ -4466,6 +5165,35 @@ mod tests {
                 })
                 .slice(500, len),
             ),
+            23 => Arc::new(generate_map(
+                rng,
+                len,
+                0.9,
+                |rng, keys_len| {
+                    // Need to generate all-unique keys, or at least make sure that within each
+                    // map every key is unique, so we generate up to a single null
+                    Arc::new(generate_all_unique_primitive_array::<Int32Type>(
+                        rng,
+                        keys_len,
+                        GenerateAllUniqueNullBehavior::AllowUpToSingleNull { valid_percent: 0.8 },
+                    ))
+                },
+                |rng, values_len| Arc::new(generate_strings::<i32>(rng, values_len, 0.7)),
+            )),
+            24 => Arc::new(generate_map(
+                rng,
+                len,
+                0.9,
+                |rng, keys_len| {
+                    // Need to generate all-unique keys, or at least make sure that within each
+                    // map every key is unique
+                    Arc::new(generate_all_unique_primitive_array::<Int64Type>(
+                        rng,
+                        keys_len,
+                        GenerateAllUniqueNullBehavior::AllValid,
+                    ))
+                },
+                |rng, values_len| Arc::new(generate_strings::<i32>(rng, values_len, 0.7)),
+            )),
             _ => unreachable!(),
         }
     }
@@ -4766,51 +5494,6 @@ mod tests {
         assert_eq!(rows.row(0).cmp(&rows.row(1)), Ordering::Less);
     }
 
-    #[test]
-    fn map_should_be_marked_as_unsupported() {
-        let map_data_type = Field::new_map(
-            "map",
-            "entries",
-            Field::new("key", DataType::Utf8, false),
-            Field::new("value", DataType::Utf8, true),
-            false,
-            true,
-        )
-        .data_type()
-        .clone();
-
-        let is_supported = RowConverter::supports_fields(&[SortField::new(map_data_type)]);
-
-        assert!(!is_supported, "Map should not be supported");
-    }
-
-    #[test]
-    fn should_fail_to_create_row_converter_for_unsupported_map_type() {
-        let map_data_type = Field::new_map(
-            "map",
-            "entries",
-            Field::new("key", DataType::Utf8, false),
-            Field::new("value", DataType::Utf8, true),
-            false,
-            true,
-        )
-        .data_type()
-        .clone();
-
-        let converter = RowConverter::new(vec![SortField::new(map_data_type)]);
-
-        match converter {
-            Err(ArrowError::NotYetImplemented(message)) => {
-                assert!(
-                    message.contains("Row format support not yet implemented for"),
-                    "Expected NotYetImplemented error for map data type, got: {message}",
-                );
-            }
-            Err(e) => panic!("Expected NotYetImplemented error, got: {e}"),
-            Ok(_) => panic!("Expected NotYetImplemented error for map data type"),
-        }
-    }
-
     #[test]
     fn test_values_buffer_smaller_when_utf8_validation_disabled() {
         fn get_values_buffer_len(col: ArrayRef) -> (usize, usize) {
@@ -5622,33 +6305,337 @@ mod tests {
         assert_eq!(&outer_struct, &back[0]);
     }
 
-    // Test Map - verify it's not supported (as per current implementation)
-    // https://github.com/apache/arrow-rs/issues/7879
+    // Test Map with various combinations of nulls and empty maps
     #[test]
-    fn test_map_null_not_supported() {
-        // Map with Null values
-        let map_data_type = Field::new_map(
-            "map",
+    fn test_map_null_variations() {
+        // Map with Null values: [{a: NULL}, {}, {b: NULL, c: NULL}]
+        let keys = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+        let null_values = Arc::new(NullArray::new(3)) as ArrayRef;
+
+        let offsets = OffsetBuffer::new(vec![0, 1, 1, 3].into());
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", DataType::Null, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
             "entries",
-            Field::new("key", DataType::Utf8, false),
-            Field::new("value", DataType::Null, true),
+            DataType::Struct(entries_fields.clone().into()),
             false,
-            true,
-        )
-        .data_type()
-        .clone();
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, null_values], None);
 
-        // Currently Map is not supported by RowConverter
-        let result = RowConverter::new(vec![SortField::new(map_data_type)]);
-        assert!(
-            result.is_err(),
-            "Map should not be supported by RowConverter"
+        let map: ArrayRef = Arc::new(MapArray::new(
+            struct_field.clone(),
+            offsets,
+            entries,
+            None,
+            false,
+        ));
+
+        let converter = RowConverter::new(vec![SortField::new(map.data_type().clone())]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&map, &back[0]);
+
+        // Map with Null values and null map entries: [{a: NULL}, null, {b: NULL, c: NULL}]
+        let keys = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+        let null_values = Arc::new(NullArray::new(3)) as ArrayRef;
+
+        let offsets = OffsetBuffer::new(vec![0, 1, 1, 3].into());
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", DataType::Null, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, null_values], None);
+
+        let map: ArrayRef = Arc::new(MapArray::new(
+            struct_field.clone(),
+            offsets,
+            entries,
+            Some(vec![true, false, true].into()),
+            false,
+        ));
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&map, &back[0]);
+
+        // Empty map array with Null value type
+        let keys = Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef;
+        let null_values = Arc::new(NullArray::new(0)) as ArrayRef;
+
+        let offsets = OffsetBuffer::new(vec![0i32].into());
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", DataType::Null, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, null_values], None);
+
+        let map: ArrayRef = Arc::new(MapArray::new(struct_field, offsets, entries, None, false));
+
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&map, &back[0]);
+    }
+
+    // Test Map with descending order
+    #[test]
+    fn test_map_null_descending() {
+        // [{a: NULL}, {}, {b: NULL, c: NULL}]
+        let keys = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+        let null_values = Arc::new(NullArray::new(3)) as ArrayRef;
+
+        let offsets = OffsetBuffer::new(vec![0, 1, 1, 3].into());
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", DataType::Null, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, null_values], None);
+
+        let map: ArrayRef = Arc::new(MapArray::new(struct_field, offsets, entries, None, false));
+
+        let options = SortOptions::default().with_descending(true);
+        let field = SortField::new_with_options(map.data_type().clone(), options);
+        let converter = RowConverter::new(vec![field]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&map, &back[0]);
+    }
+
+    // Test Map - both keys and values are Null type
+    #[test]
+    fn test_map_null_keys_and_null_values() {
+        let null_keys = Arc::new(NullArray::new(3)) as ArrayRef;
+        let null_values = Arc::new(NullArray::new(3)) as ArrayRef;
+
+        let offsets = OffsetBuffer::new(vec![0, 1, 1, 3].into());
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Null, true)),
+            Arc::new(Field::new("values", DataType::Null, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![null_keys, null_values], None);
+
+        let map: ArrayRef = Arc::new(MapArray::new(struct_field, offsets, entries, None, false));
+
+        let converter = RowConverter::new(vec![SortField::new(map.data_type().clone())]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&map, &back[0]);
+    }
+
+    // Test Map all empty maps
+    #[test]
+    fn test_map_null_all_empty() {
+        let keys = Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef;
+        let null_values = Arc::new(NullArray::new(0)) as ArrayRef;
+
+        let offsets = OffsetBuffer::new(vec![0, 0, 0, 0].into());
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", DataType::Null, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, null_values], None);
+
+        let map: ArrayRef = Arc::new(MapArray::new(struct_field, offsets, entries, None, false));
+
+        let converter = RowConverter::new(vec![SortField::new(map.data_type().clone())]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+
+        // All empty maps should be equal
+        assert_eq!(rows.row(0), rows.row(1));
+        assert_eq!(rows.row(1), rows.row(2));
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&map, &back[0]);
+    }
+
+    // Test Map<String, Map<String, Null>> - nested map with Null leaf values
+    #[test]
+    fn test_nested_map_null() {
+        // Inner map entries: {a: NULL, b: NULL, c: NULL}
+        let inner_keys = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+        let inner_null_values = Arc::new(NullArray::new(3)) as ArrayRef;
+
+        let inner_entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", DataType::Null, true)),
+        ];
+        let inner_struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(inner_entries_fields.clone().into()),
+            false,
+        ));
+        let inner_entries = StructArray::new(
+            inner_entries_fields.clone().into(),
+            vec![inner_keys, inner_null_values],
+            None,
         );
+
+        // Inner maps: [{a: NULL}, {b: NULL, c: NULL}]
+        let inner_map = Arc::new(MapArray::new(
+            inner_struct_field.clone(),
+            OffsetBuffer::new(vec![0, 1, 3].into()),
+            inner_entries,
+            None,
+            false,
+        )) as ArrayRef;
+
+        // Outer map entries
+        let outer_keys = Arc::new(StringArray::from(vec!["x", "y"])) as ArrayRef;
+
+        let inner_map_type = DataType::Map(inner_struct_field.clone(), false);
+        let outer_entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", inner_map_type, true)),
+        ];
+        let outer_struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(outer_entries_fields.clone().into()),
+            false,
+        ));
+        let outer_entries = StructArray::new(
+            outer_entries_fields.into(),
+            vec![outer_keys, inner_map],
+            None,
+        );
+
+        // Outer map: [{x: {a: NULL}}, {y: {b: NULL, c: NULL}}]
+        let map: ArrayRef = Arc::new(MapArray::new(
+            outer_struct_field,
+            OffsetBuffer::new(vec![0, 1, 2].into()),
+            outer_entries,
+            None,
+            false,
+        ));
+
+        let converter = RowConverter::new(vec![SortField::new(map.data_type().clone())]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&map, &back[0]);
+    }
+
+    // Test List<Map<String, Null>> - list containing maps with Null values
+    #[test]
+    fn test_list_of_map_null() {
+        // Map entries: {a: NULL, b: NULL, c: NULL}
+        let keys = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+        let null_values = Arc::new(NullArray::new(3)) as ArrayRef;
+
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", DataType::Null, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, null_values], None);
+
+        // Maps: [{a: NULL}, {}, {b: NULL, c: NULL}]
+        let map_array = Arc::new(MapArray::new(
+            struct_field.clone(),
+            OffsetBuffer::new(vec![0, 1, 1, 3].into()),
+            entries,
+            None,
+            false,
+        )) as ArrayRef;
+
+        let map_type = DataType::Map(struct_field, false);
+        // List of maps: [[{a: NULL}], [{}, {b: NULL, c: NULL}]]
+        let list: ArrayRef = Arc::new(ListArray::new(
+            Arc::new(Field::new_list_field(map_type, true)),
+            OffsetBuffer::new(vec![0, 1, 3].into()),
+            map_array,
+            None,
+        ));
+
+        let converter = RowConverter::new(vec![SortField::new(list.data_type().clone())]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(&list, &back[0]);
+    }
+
+    // Test Map<String, List<Null>> - map with list values containing Null
+    #[test]
+    fn test_map_of_list_null() {
+        // Inner list values: [NULL, NULL, NULL]
+        let null_array = Arc::new(NullArray::new(3)) as ArrayRef;
+        // Lists: [[NULL], [], [NULL, NULL]]
+        let list_array = Arc::new(ListArray::new(
+            Arc::new(Field::new_list_field(DataType::Null, true)),
+            OffsetBuffer::from_lengths(vec![1, 0, 2]),
+            null_array,
+            None,
+        )) as ArrayRef;
+
+        let keys = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+
+        let list_type = list_array.data_type().clone();
+        let entries_fields = vec![
+            Arc::new(Field::new("keys", DataType::Utf8, false)),
+            Arc::new(Field::new("values", list_type, true)),
+        ];
+        let struct_field = Arc::new(Field::new(
+            "entries",
+            DataType::Struct(entries_fields.clone().into()),
+            false,
+        ));
+        let entries = StructArray::new(entries_fields.into(), vec![keys, list_array], None);
+
+        // Map: [{a: [NULL], b: [], c: [NULL, NULL]}]
+        let map: ArrayRef = Arc::new(MapArray::new(
+            struct_field,
+            OffsetBuffer::new(vec![0, 3].into()),
+            entries,
+            None,
+            false,
+        ));
+
+        let converter = RowConverter::new(vec![SortField::new(map.data_type().clone())]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        back[0].to_data().validate_full().unwrap();
+        assert_eq!(&map, &back[0]);
+    }
 }
diff --git a/arrow-row/src/list.rs b/arrow-row/src/list.rs
index 843101673f05..7be1116f73fa 100644
--- a/arrow-row/src/list.rs
+++ b/arrow-row/src/list.rs
@@ -17,51 +17,144 @@
 use crate::{LengthTracker, RowConverter, Rows, SortField, fixed, null_sentinel};
 use arrow_array::{
-    Array, FixedSizeListArray, GenericListArray, GenericListViewArray, OffsetSizeTrait,
-    new_null_array,
+    Array, ArrayRef, FixedSizeListArray, GenericListArray, GenericListViewArray, MapArray,
+    OffsetSizeTrait, StructArray, new_null_array,
 };
 use arrow_buffer::{
-    ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, ScalarBuffer,
+    ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
 };
 use arrow_data::ArrayDataBuilder;
-use arrow_schema::{ArrowError, DataType, SortOptions};
+use arrow_schema::{ArrowError, DataType, Fields, SortOptions};
 use std::{ops::Range, sync::Arc};
 
-pub fn compute_lengths<O: OffsetSizeTrait>(
+pub(crate) trait GenericListArrayOrMap: Array {
+    type Offset: OffsetSizeTrait;
+
+    fn offsets(&self) -> &[Self::Offset];
+
+    unsafe fn from_parts_unchecked(
+        data_type: DataType,
+        offsets: Vec<Self::Offset>,
+        children: Vec<ArrayRef>,
+        null_buffer: Option<NullBuffer>,
+    ) -> Self
+    where
+        Self: Sized;
+}
+
+impl<O: OffsetSizeTrait> GenericListArrayOrMap for GenericListArray<O> {
+    type Offset = O;
+
+    fn offsets(&self) -> &[Self::Offset] {
+        self.value_offsets()
+    }
+
+    unsafe fn from_parts_unchecked(
+        data_type: DataType,
+        offsets: Vec<Self::Offset>,
+        children: Vec<ArrayRef>,
+        null_buffer: Option<NullBuffer>,
+    ) -> Self
+    where
+        Self: Sized,
+    {
+        let field = match data_type {
+            DataType::List(inner_field) | DataType::LargeList(inner_field) => inner_field,
+            _ => unreachable!(),
+        };
+
+        let child = children
+            .into_iter()
+            .next()
+            .expect("List arrays must have exactly one child array");
+
+        // SAFETY: Caller must ensure offsets are valid and correctly correspond to the children and null buffer.
+        // The benefit here is to avoid validating that the offsets are monotonically increasing
+        let offset_buffer = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
+        GenericListArray::<O>::new(field, offset_buffer, child, null_buffer)
+    }
+}
+
+impl GenericListArrayOrMap for MapArray {
+    type Offset = i32;
+
+    fn offsets(&self) -> &[Self::Offset] {
+        self.value_offsets()
+    }
+
+    unsafe fn from_parts_unchecked(
+        data_type: DataType,
+        offsets: Vec<Self::Offset>,
+        children: Vec<ArrayRef>,
+        null_buffer: Option<NullBuffer>,
+    ) -> Self
+    where
+        Self: Sized,
+    {
+        let DataType::Map(entries_field, ordered) = data_type else {
+            unreachable!("data type must be Map for MapArray");
+        };
+
+        assert_eq!(
+            children.len(),
+            2,
+            "Map arrays must have exactly two child arrays for keys and values"
+        );
+
+        let DataType::Struct(fields) = entries_field.data_type() else {
+            unreachable!("Map entry type must be Struct");
+        };
+
+        let entries = StructArray::new(
+            fields.clone(),
+            children,
+            // Entries StructArray cannot have a NullBuffer since nulls are represented at the Map level
+            None,
+        );
+
+        // SAFETY: Caller must ensure offsets are valid and correctly correspond to the children and null buffer.
+        // The benefit here is to avoid validating that the offsets are monotonically increasing
+        let offset_buffer = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
+
+        MapArray::new(entries_field, offset_buffer, entries, null_buffer, ordered)
+    }
+}
+
+pub(crate) fn compute_lengths<L: GenericListArrayOrMap>(
     lengths: &mut [usize],
     rows: &Rows,
-    array: &GenericListArray<O>,
+    array: &L,
 ) {
-    let shift = array.value_offsets()[0].as_usize();
+    let shift = array.offsets()[0].as_usize();
     lengths
         .iter_mut()
-        .zip(array.value_offsets().windows(2))
+        .zip(array.offsets().windows(2))
         .enumerate()
         .for_each(|(idx, (length, offsets))| {
             let start = offsets[0].as_usize() - shift;
             let end = offsets[1].as_usize() - shift;
             let range = array.is_valid(idx).then_some(start..end);
-            *length += list_element_encoded_len(rows, range);
+            *length += list_like_element_encoded_len(rows, range);
         });
 }
 
-/// Encodes the provided `GenericListArray` to `out` with the provided `SortOptions`
+/// Encodes the provided [`GenericListArrayOrMap`] to `out` with the provided `SortOptions`
 ///
 /// `rows` should contain the encoded child elements
-pub fn encode<O: OffsetSizeTrait>(
+pub(crate) fn encode<L: GenericListArrayOrMap>(
     data: &mut [u8],
     offsets: &mut [usize],
     rows: &Rows,
     opts: SortOptions,
-    array: &GenericListArray<O>,
+    array: &L,
 ) {
-    let shift = array.value_offsets()[0].as_usize();
+    let shift = array.offsets()[0].as_usize();
     offsets
         .iter_mut()
         .skip(1)
-        .zip(array.value_offsets().windows(2))
+        .zip(array.offsets().windows(2))
         .enumerate()
         .for_each(|(idx, (offset, offsets))| {
             let start = offsets[0].as_usize() - shift;
@@ -99,19 +192,19 @@ fn encode_one(
 /// # Safety
 ///
 /// `rows` must contain valid data for the provided `converter`
-pub unsafe fn decode<O: OffsetSizeTrait>(
+pub(crate) unsafe fn decode<ListLikeImpl: GenericListArrayOrMap>(
     converter: &RowConverter,
     rows: &mut [&[u8]],
     field: &SortField,
     validate_utf8: bool,
-) -> Result<GenericListArray<O>, ArrowError> {
+) -> Result<ListLikeImpl, ArrowError> {
     let opts = field.options;
 
     let mut values_bytes = 0;
 
     let mut offset = 0;
     let mut offsets = Vec::with_capacity(rows.len() + 1);
-    offsets.push(O::usize_as(0));
+    offsets.push(ListLikeImpl::Offset::usize_as(0));
 
     for row in rows.iter_mut() {
         let mut row_offset = 0;
@@ -120,22 +213,30 @@ pub unsafe fn decode(
                 values_bytes += x.len();
             });
             if decoded <= 1 {
-                offsets.push(O::usize_as(offset));
+                offsets.push(ListLikeImpl::Offset::usize_as(offset));
                 break;
             }
             row_offset += decoded;
             offset += 1;
         }
     }
-    O::from_usize(offset).expect("overflow");
+    ListLikeImpl::Offset::from_usize(offset).expect("overflow");
 
     let mut null_count = 0;
-    let nulls = MutableBuffer::collect_bool(rows.len(), |x| {
+    let nulls = BooleanBuffer::collect_bool(rows.len(), |x| {
         let valid = rows[x][0] != null_sentinel(opts);
         null_count += !valid as usize;
         valid
     });
+    let nulls = if null_count > 0 {
+        // SAFETY: null_count was computed correctly when building the nulls buffer above.
+        // Perf benefit: avoids computing the null count again
+        Some(unsafe { NullBuffer::new_unchecked(nulls, null_count) })
+    } else {
+        None
+    };
+
     let mut values_offsets = Vec::with_capacity(offset);
     let mut values_bytes = Vec::with_capacity(values_bytes);
     for row in rows.iter_mut() {
@@ -167,37 +268,66 @@ pub unsafe fn decode(
         })
         .collect();
 
-    let child = unsafe { converter.convert_raw(&mut child_rows, validate_utf8) }?;
-    assert_eq!(child.len(), 1);
-
-    let child_data = child[0].to_data();
+    let children = unsafe { converter.convert_raw(&mut child_rows, validate_utf8) }?;
 
     // Since RowConverter flattens certain data types (i.e. Dictionary),
     // we need to use updated data type instead of original field
     let corrected_type = match &field.data_type {
-        DataType::List(inner_field) => DataType::List(Arc::new(
-            inner_field
+        DataType::List(inner_field) => {
+            assert_eq!(children.len(), 1);
+            DataType::List(Arc::new(
+                inner_field
+                    .as_ref()
+                    .clone()
+                    .with_data_type(children[0].data_type().clone()),
+            ))
+        }
+        DataType::LargeList(inner_field) => {
+            assert_eq!(children.len(), 1);
+            DataType::LargeList(Arc::new(
+                inner_field
+                    .as_ref()
+                    .clone()
+                    .with_data_type(children[0].data_type().clone()),
+            ))
+        }
+        DataType::Map(inner_field, ordered) => {
+            let DataType::Struct(entries_field) = inner_field.data_type() else {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "Expected Map entry type to be Struct, found: {}",
+                    inner_field.data_type()
+                )));
+            };
+            assert_eq!(
+                children.len(),
+                2,
+                "Map arrays must have exactly two child arrays for keys and values"
+            );
+            let key_field = entries_field[0]
                 .as_ref()
                 .clone()
-                .with_data_type(child_data.data_type().clone()),
-        )),
-        DataType::LargeList(inner_field) => DataType::LargeList(Arc::new(
-            inner_field
+                .with_data_type(children[0].data_type().clone());
+            let value_field = entries_field[1]
                 .as_ref()
                 .clone()
-                .with_data_type(child_data.data_type().clone()),
-        )),
+                .with_data_type(children[1].data_type().clone());
+
+            let entries_fields = Fields::from(vec![key_field, value_field]);
+
+            DataType::Map(
+                Arc::new(
+                    inner_field
+                        .as_ref()
+                        .clone()
+                        .with_data_type(DataType::Struct(entries_fields)),
+                ),
+                *ordered,
+            )
+        }
         _ => unreachable!(),
     };
 
-    let builder = ArrayDataBuilder::new(corrected_type)
-        .len(rows.len())
-        .null_count(null_count)
-        .null_bit_buffer(Some(nulls.into()))
-        .add_buffer(Buffer::from_vec(offsets))
-        .add_child_data(child_data);
-
-    Ok(GenericListArray::from(unsafe { builder.build_unchecked() }))
+    Ok(unsafe { ListLikeImpl::from_parts_unchecked(corrected_type, offsets, children, nulls) })
 }
 
 pub fn compute_lengths_fixed_size_list(
@@ -317,13 +447,13 @@ pub unsafe fn decode_fixed_size_list(
     }))
 }
 
-/// Computes the encoded length for a single list element given its child rows.
+/// Computes the encoded length for a single list/map element given its child rows.
 ///
-/// This is used by list types (List, LargeList, ListView, LargeListView) to determine
-/// the encoded length of a list element. For null elements, returns 1 (null sentinel only).
+/// This is used by list types (List, LargeList, ListView, LargeListView) and by Map to determine
+/// the encoded length of a list element/map entry. For null elements, returns 1 (null sentinel only).
 /// For valid elements, returns 1 + the sum of padded lengths for each child row.
 #[inline]
-fn list_element_encoded_len(rows: &Rows, range: Option<Range<usize>>) -> usize {
+fn list_like_element_encoded_len(rows: &Rows, range: Option<Range<usize>>) -> usize {
     match range {
         None => 1,
         Some(range) => {
@@ -358,7 +488,7 @@ pub fn compute_lengths_list_view(
             };
             start..start + size
         });
-        *length += list_element_encoded_len(rows, range);
+        *length += list_like_element_encoded_len(rows, range);
     });
 }
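
Example usage: with this change, `DataType::Map` columns go through the row format like any other supported type. The following standalone sketch is not part of the patch; it exercises the new support purely through the public API used in the tests above (`RowConverter`, `SortField`, `MapBuilder`), and the sort expectation assumes the default ascending, nulls-first options.

```rust
use std::sync::Arc;

use arrow_array::builder::{Int32Builder, MapBuilder, StringBuilder};
use arrow_array::ArrayRef;
use arrow_row::{RowConverter, SortField};

fn main() {
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
    // Row 0: {"b": 2}
    builder.keys().append_value("b");
    builder.values().append_value(2);
    builder.append(true).unwrap();
    // Row 1: {"a": 1}
    builder.keys().append_value("a");
    builder.values().append_value(1);
    builder.append(true).unwrap();

    let map = Arc::new(builder.finish()) as ArrayRef;

    // Previously this returned a NotYetImplemented error; with this change it succeeds
    let converter = RowConverter::new(vec![SortField::new(map.data_type().clone())]).unwrap();
    let rows = converter.convert_columns(&[Arc::clone(&map)]).unwrap();

    // Row comparisons give a total order over map values (keys are encoded first)
    let mut indices: Vec<usize> = (0..rows.num_rows()).collect();
    indices.sort_by(|&a, &b| rows.row(a).cmp(&rows.row(b)));
    assert_eq!(indices, vec![1, 0]); // {"a": 1} sorts before {"b": 2}

    // And the rows round-trip back to the original array
    let back = converter.convert_rows(&rows).unwrap();
    assert_eq!(&back[0], &map);
}
```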