diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 786cf9212d04..c728b8853a12 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -161,7 +161,7 @@ use crate::reader::run_end_array::RunEndEncodedArrayDecoder; use crate::reader::string_array::StringArrayDecoder; use crate::reader::string_view_array::StringViewArrayDecoder; use crate::reader::struct_array::StructArrayDecoder; -use crate::reader::tape::{Tape, TapeDecoder}; +use crate::reader::tape::{Tape, TapeDecoder, TapeDecoderOptions}; use crate::reader::timestamp_array::TimestampArrayDecoder; mod binary_array; @@ -181,12 +181,14 @@ mod tape; mod timestamp_array; /// A builder for [`Reader`] and [`Decoder`] +#[derive(Clone)] pub struct ReaderBuilder { batch_size: usize, coerce_primitive: bool, strict_mode: bool, is_field: bool, struct_mode: StructMode, + flatten_top_level_arrays: bool, schema: SchemaRef, } @@ -207,6 +209,7 @@ impl ReaderBuilder { strict_mode: false, is_field: false, struct_mode: Default::default(), + flatten_top_level_arrays: false, schema, } } @@ -248,6 +251,7 @@ impl ReaderBuilder { strict_mode: false, is_field: true, struct_mode: Default::default(), + flatten_top_level_arrays: false, schema: Arc::new(Schema::new([field.into()])), } } @@ -288,6 +292,26 @@ impl ReaderBuilder { } } + /// Sets whether to flatten top-level arrays. + /// + /// * When `true`, each element of a top-level array will be treated as its own row. + /// * When `false` (the default), the entire top-level array will be treated as one row. + /// + /// For example, consider this input file: + /// ```text + /// [{ "a": 1 }, { "a": 2 }, { "b": 3 }] + /// [{ "a": 4 }, { "a": 5 }, { "b": 6 }] + /// ``` + /// + /// By default, this would be parsed as two rows, each an array containing three elements. + /// With this option set to `true`, however, this would be parsed as six rows. + pub fn with_flatten(self, flatten_top_level_arrays: bool) -> Self { + Self { + flatten_top_level_arrays, + ..self + } + } + /// Create a [`Reader`] with the provided [`BufRead`] pub fn build(self, reader: R) -> Result, ArrowError> { Ok(Reader { @@ -319,7 +343,11 @@ impl ReaderBuilder { Ok(Decoder { decoder, is_field: self.is_field, - tape_decoder: TapeDecoder::new(self.batch_size, num_fields), + tape_decoder: TapeDecoder::new(TapeDecoderOptions { + batch_size: self.batch_size, + num_fields, + flatten_top_level_arrays: self.flatten_top_level_arrays, + }), batch_size: self.batch_size, schema: self.schema, }) @@ -813,7 +841,9 @@ mod tests { use std::io::{BufReader, Cursor, Seek}; use arrow_array::cast::AsArray; - use arrow_array::{Array, BooleanArray, Float64Array, ListArray, StringArray, StringViewArray}; + use arrow_array::{ + Array, BooleanArray, Float64Array, Int32Array, ListArray, StringArray, StringViewArray, + }; use arrow_buffer::{ArrowNativeType, Buffer}; use arrow_cast::display::{ArrayFormatter, FormatOptions}; use arrow_data::ArrayDataBuilder; @@ -828,13 +858,21 @@ mod tests { strict_mode: bool, schema: SchemaRef, ) -> Vec { + let config = ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .with_strict_mode(strict_mode) + .with_coerce_primitive(coerce_primitive); + do_read_config(buf, config) + } + + fn do_read_config(buf: &str, builder: ReaderBuilder) -> Vec { let mut unbuffered = vec![]; // Test with different batch sizes to test for boundary conditions - for batch_size in [1, 3, 100, batch_size] { - unbuffered = ReaderBuilder::new(schema.clone()) + for batch_size in [1, 3, 100, builder.batch_size] { + unbuffered = builder + .clone() .with_batch_size(batch_size) - .with_coerce_primitive(coerce_primitive) .build(Cursor::new(buf.as_bytes())) .unwrap() .collect::, _>>() @@ -846,10 +884,9 @@ mod tests { // Test with different buffer sizes to test for boundary conditions for b in [1, 3, 5] { - let buffered = ReaderBuilder::new(schema.clone()) + let buffered = builder + .clone() .with_batch_size(batch_size) - .with_coerce_primitive(coerce_primitive) - .with_strict_mode(strict_mode) .build(BufReader::with_capacity(b, Cursor::new(buf.as_bytes()))) .unwrap() .collect::, _>>() @@ -2975,4 +3012,27 @@ mod tests { assert_eq!(run_array.len(), 3); assert_eq!(run_array.run_ends().values(), &[2i16, 3]); } + + #[test] + fn test_flatten_top_level_arrays() { + let buf = r#" + [ + {"a": 1}, + {"a": 2} + ] + {"a": 3} + [{"a": 4}, {"a": 5}, {"a": 6}] + "#; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = do_read_config(buf, ReaderBuilder::new(schema).with_flatten(true)); + assert_eq!(batches.len(), 1); + + let col = batches[0].column(0); + let col = col.as_any().downcast_ref::().unwrap(); + assert_eq!(col.len(), 6); + for i in 0..6 { + assert_eq!(col.value(i), (i as i32) + 1); + } + } } diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index 89ee3f778765..2a91e67b028b 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -216,6 +216,13 @@ impl<'a> Tape<'a> { /// States based on #[derive(Debug, Copy, Clone)] enum DecoderState { + /// Decoding a top-level list, where each element is + /// treated as an individual row + /// + /// This can only appear as the first element on the stack, + /// and it is valid to flush a batch when this is the only + /// state on the stack + TopLevelList, /// Decoding an object /// /// Contains index of start [`TapeElement::StartObject`] @@ -242,6 +249,7 @@ enum DecoderState { impl DecoderState { fn as_str(&self) -> &'static str { match self { + DecoderState::TopLevelList => "list", DecoderState::Object(_) => "object", DecoderState::List(_) => "list", DecoderState::String => "string", @@ -295,7 +303,9 @@ macro_rules! next { } /// Implements a state machine for decoding JSON to a tape +#[derive(Debug)] pub struct TapeDecoder { + /// The decoded elements elements: Vec, /// The number of rows decoded, including any in progress if `!stack.is_empty()` @@ -304,6 +314,12 @@ pub struct TapeDecoder { /// Number of rows to read per batch batch_size: usize, + /// Whether to flatten top-level arrays into the stream of JSON rows, + /// meaning that each of their elements will be treated as an individual row + /// + /// When `false` (the default), the entire top-level array will be treated as one row + flatten_top_level_arrays: bool, + /// A buffer of parsed string data /// /// Note: if part way through a record, i.e. `stack` is not empty, @@ -317,10 +333,26 @@ pub struct TapeDecoder { stack: Vec, } +/// Configuration for a `TapeDecoder`. +#[derive(Clone, Copy, Debug)] +pub struct TapeDecoderOptions { + /// The batch size + pub batch_size: usize, + /// The estimated number of fields in each row + pub num_fields: usize, + /// Whether to flatten top-level arrays + pub flatten_top_level_arrays: bool, +} + impl TapeDecoder { - /// Create a new [`TapeDecoder`] with the provided batch size - /// and an estimated number of fields in each row - pub fn new(batch_size: usize, num_fields: usize) -> Self { + /// Create a new [`TapeDecoder`] with the provided options + pub fn new(options: TapeDecoderOptions) -> Self { + let TapeDecoderOptions { + batch_size, + num_fields, + flatten_top_level_arrays, + } = options; + let tokens_per_row = 2 + num_fields * 2; let mut offsets = Vec::with_capacity(batch_size * (num_fields * 2) + 1); offsets.push(0); @@ -332,6 +364,7 @@ impl TapeDecoder { offsets, elements, batch_size, + flatten_top_level_arrays, cur_row: 0, bytes: Vec::with_capacity(num_fields * 2 * 8), stack: Vec::with_capacity(10), @@ -346,18 +379,52 @@ impl TapeDecoder { Some(l) => l, None => { iter.skip_whitespace(); - if iter.is_empty() || self.cur_row >= self.batch_size { + if self.cur_row >= self.batch_size { break; } - // Start of row - self.cur_row += 1; - self.stack.push(DecoderState::Value); + match iter.peek() { + Some(b'[') if self.flatten_top_level_arrays => { + // Consume the `[` without writing it + iter.next(); + self.stack.push(DecoderState::TopLevelList); + } + Some(_) => { + // Start of row + self.cur_row += 1; + self.stack.push(DecoderState::Value); + } + None => break, + } + + // There is now a top-most state to process self.stack.last_mut().unwrap() } }; match state { + // Decoding a top-level list + DecoderState::TopLevelList => { + iter.advance_until(|b| !json_whitespace(b) && b != b','); + if self.cur_row >= self.batch_size { + break; + } + + match iter.peek() { + Some(b']') => { + // Consume the `]` without writing it + iter.next(); + self.stack.pop(); + continue; + } + Some(_) => { + // Start of row + self.cur_row += 1; + self.stack.push(DecoderState::Value); + } + None => break, + } + } // Decoding an object DecoderState::Object(start_idx) => { iter.advance_until(|b| !json_whitespace(b) && b != b','); @@ -554,16 +621,20 @@ impl TapeDecoder { /// True if the decoder is part way through decoding a row. If so, calling [`Self::finish`] /// would return an error. pub fn has_partial_row(&self) -> bool { - !self.stack.is_empty() + !matches!(self.stack.last(), None | Some(DecoderState::TopLevelList)) } /// Finishes the current [`Tape`] pub fn finish(&self) -> Result, ArrowError> { - if let Some(b) = self.stack.last() { - return Err(ArrowError::JsonError(format!( - "Truncated record whilst reading {}", - b.as_str() - ))); + match self.stack.last() { + None => {} + Some(DecoderState::TopLevelList) => {} + Some(state) => { + return Err(ArrowError::JsonError(format!( + "Truncated record whilst reading {}", + state.as_str() + ))); + } } if self.offsets.len() >= u32::MAX as usize { @@ -607,7 +678,7 @@ impl TapeDecoder { /// Clears this [`TapeDecoder`] in preparation to read the next batch pub fn clear(&mut self) { - assert!(self.stack.is_empty()); + assert!(!self.has_partial_row()); self.cur_row = 0; self.bytes.clear(); @@ -744,6 +815,15 @@ fn parse_hex(b: u8) -> Result { mod tests { use super::*; + /// Helper method to create a `TapeDecoder` with sensible defaults + fn tape_decoder() -> TapeDecoder { + TapeDecoder::new(TapeDecoderOptions { + batch_size: 16, + num_fields: 2, + flatten_top_level_arrays: false, + }) + } + #[test] fn test_sizes() { assert_eq!(std::mem::size_of::(), 8); @@ -767,7 +847,7 @@ mod tests { {"a": ["", "foo", ["bar", "c"]], "b": {"1": []}, "c": {"2": [1, 2, 3]} } "#; - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); decoder.decode(a.as_bytes()).unwrap(); assert!(!decoder.has_partial_row()); assert_eq!(decoder.num_buffered_rows(), 7); @@ -877,21 +957,21 @@ mod tests { #[test] fn test_invalid() { // Test invalid - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); let err = decoder.decode(b"hello").unwrap_err().to_string(); assert_eq!( err, "Json error: Encountered unexpected 'h' whilst parsing value" ); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); let err = decoder.decode(b"{\"hello\": }").unwrap_err().to_string(); assert_eq!( err, "Json error: Encountered unexpected '}' whilst parsing value" ); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); let err = decoder .decode(b"{\"hello\": [ false, tru ]}") .unwrap_err() @@ -901,7 +981,7 @@ mod tests { "Json error: Encountered unexpected ' ' whilst parsing literal" ); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); let err = decoder .decode(b"{\"hello\": \"\\ud8\"}") .unwrap_err() @@ -912,7 +992,7 @@ mod tests { ); // Missing surrogate pair - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); let err = decoder .decode(b"{\"hello\": \"\\ud83d\"}") .unwrap_err() @@ -923,40 +1003,40 @@ mod tests { ); // Test truncation - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); decoder.decode(b"{\"he").unwrap(); assert!(decoder.has_partial_row()); assert_eq!(decoder.num_buffered_rows(), 1); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading string"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); decoder.decode(b"{\"hello\" : ").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading value"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); decoder.decode(b"{\"hello\" : [").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading list"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); decoder.decode(b"{\"hello\" : tru").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading true"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); decoder.decode(b"{\"hello\" : nu").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading null"); // Test invalid UTF-8 - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); decoder.decode(b"{\"hello\" : \"world\xFF\"}").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Encountered non-UTF-8 data"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); decoder.decode(b"{\"\xe2\" : \"\x96\xa1\"}").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Encountered truncated UTF-8 sequence"); @@ -964,12 +1044,197 @@ mod tests { #[test] fn test_invalid_surrogates() { - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); let res = decoder.decode(b"{\"test\": \"\\ud800\\ud801\"}"); assert!(res.is_err()); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = tape_decoder(); let res = decoder.decode(b"{\"test\": \"\\udc00\\udc01\"}"); assert!(res.is_err()); } + + #[test] + fn test_flatten_top_level_arrays() { + let input = r#" + [ + {"hello": "world", "foo": 2, "bar": 45}, + {"a": true, "b": false, "c": null} + ] + [ + {"a": "b", "object": {"nested": "hello", "foo": 23}}, + {"a": ["", "foo", ["bar", "c"]]}, + {"hello": "world", "foo": 2, "bar": 27} + ]"#; + const TOTAL_ROWS: usize = 5; + + // Check that regular decoding returns two rows + let mut decoder = tape_decoder(); + decoder.decode(input.as_bytes()).unwrap(); + assert!(!decoder.has_partial_row()); + assert_eq!(decoder.num_buffered_rows(), 2); + + let expected = TestCase::new() + // {"hello": "world", "foo": 2, "bar": 45} + .start_object(6) + .string("hello") + .string("world") + .string("foo") + .number("2") + .string("bar") + .number("45") + .end_object(6) + // {"a": true, "b": false, "c": null} + .start_object(6) + .string("a") + .r#true() + .string("b") + .r#false() + .string("c") + .null() + .end_object(6) + // {"a": "b", "object": {"nested": "hello", "foo": 23}} + .start_object(9) + .string("a") + .string("b") + .string("object") + .start_object(4) + .string("nested") + .string("hello") + .string("foo") + .number("23") + .end_object(4) + .end_object(9) + // {"a": ["", "foo", ["bar", "c"]]} + .start_object(9) + .string("a") + .start_list(6) + .string("") + .string("foo") + .start_list(2) + .string("bar") + .string("c") + .end_list(2) + .end_list(6) + .end_object(9) + // {"hello": "world", "foo": 2, "bar": 27} + .start_object(6) + .string("hello") + .string("world") + .string("foo") + .number("2") + .string("bar") + .number("27") + .end_object(6); + + // Check that decoding with `flatten_top_level_arrays` yields rows correctly, + // and respects the configured batch size + for batch_size in [1, 2, 3, 4, 8] { + dbg!(batch_size); + + let mut decoder = TapeDecoder::new(TapeDecoderOptions { + batch_size, + num_fields: 2, + flatten_top_level_arrays: true, + }); + decoder.decode(input.as_bytes()).unwrap(); + assert!(!decoder.has_partial_row()); + assert_eq!(decoder.num_buffered_rows(), batch_size.min(TOTAL_ROWS)); + + let finished = decoder.finish().unwrap(); + assert!(!decoder.has_partial_row()); + assert_eq!(decoder.num_buffered_rows(), batch_size.min(TOTAL_ROWS)); // didn't call clear() yet + assert_eq!( + finished.elements, + &expected.elements[..finished.elements.len()] + ); + assert_eq!( + finished.strings, + &expected.strings[..finished.strings.len()] + ); + assert_eq!( + finished.string_offsets, + &expected.string_offsets[..finished.string_offsets.len()] + ); + + decoder.clear(); + assert!(!decoder.has_partial_row()); + assert_eq!(decoder.num_buffered_rows(), 0); + } + } + + /// The expected elements, strings and string offsets for a test case + struct TestCase { + elements: Vec, + strings: String, + string_offsets: Vec, + } + + impl TestCase { + fn new() -> Self { + Self { + elements: vec![TapeElement::Null], + strings: String::new(), + string_offsets: vec![0], + } + } + + fn start_object(mut self, len: usize) -> Self { + let end_idx = (self.elements.len() + len + 1) as u32; + self.elements.push(TapeElement::StartObject(end_idx)); + self + } + + fn end_object(mut self, len: usize) -> Self { + let start_idx = (self.elements.len() - len - 1) as u32; + self.elements.push(TapeElement::EndObject(start_idx)); + self + } + + fn start_list(mut self, len: usize) -> Self { + let end_idx = (self.elements.len() + len + 1) as u32; + self.elements.push(TapeElement::StartList(end_idx)); + self + } + + fn end_list(mut self, len: usize) -> Self { + let start_idx = (self.elements.len() - len - 1) as u32; + self.elements.push(TapeElement::EndList(start_idx)); + self + } + + fn string(mut self, raw: &str) -> Self { + let idx = (self.string_offsets.len() - 1) as u32; + let start = self.strings.len(); + let end = start + raw.len(); + self.elements.push(TapeElement::String(idx)); + self.strings.push_str(raw); + self.string_offsets.push(end); + self + } + + fn number(mut self, raw: &str) -> Self { + let idx = (self.string_offsets.len() - 1) as u32; + let start = self.strings.len(); + let end = start + raw.len(); + self.elements.push(TapeElement::Number(idx)); + self.strings.push_str(raw); + self.string_offsets.push(end); + self + } + + fn r#true(mut self) -> Self { + self.elements.push(TapeElement::True); + self + } + + fn r#false(mut self) -> Self { + self.elements.push(TapeElement::False); + self + } + + fn null(mut self) -> Self { + self.elements.push(TapeElement::Null); + self + } + } }