diff --git a/parquet/src/record/reader.rs b/parquet/src/record/reader.rs index a6b8d2d54cd7..bee49c70d81d 100644 --- a/parquet/src/record/reader.rs +++ b/parquet/src/record/reader.rs @@ -437,11 +437,17 @@ impl Reader { fn read_field(&mut self) -> Result { let field = match *self { Reader::PrimitiveReader(_, ref mut column) => { + if !column.has_next() { + return Err(general_err!("Unexpected end of column data")); + } let value = column.current_value()?; column.read_next()?; value } Reader::OptionReader(def_level, ref mut reader) => { + if !reader.has_next() { + return Err(general_err!("Unexpected end of column data")); + } if reader.current_def_level() > def_level { reader.read_field()? } else { @@ -465,6 +471,9 @@ impl Reader { Field::Group(row) } Reader::RepeatedReader(_, def_level, rep_level, ref mut reader) => { + if !reader.has_next() { + return Err(general_err!("Unexpected end of column data")); + } let mut elements = Vec::new(); loop { if reader.current_def_level() > def_level { @@ -488,6 +497,9 @@ impl Reader { Field::ListInternal(make_list(elements)) } Reader::KeyValueReader(_, def_level, rep_level, ref mut keys, ref mut values) => { + if !keys.has_next() { + return Err(general_err!("Unexpected end of column data")); + } let mut pairs = Vec::new(); loop { if keys.current_def_level() > def_level { @@ -1912,4 +1924,57 @@ mod tests { ),]]; assert_eq!(rows, expected_rows); } + + fn assert_err_on_overcount(file_name: &str, proj_schema: Option) { + let file = get_test_file(file_name); + let file_reader = SerializedFileReader::new(file).unwrap(); + let metadata = file_reader.metadata(); + let row_group_reader = file_reader.get_row_group(0).unwrap(); + let actual_rows = row_group_reader.metadata().num_rows() as usize; + + let descr = match proj_schema { + Some(schema) => Arc::new(SchemaDescriptor::new(Arc::new(schema))), + None => metadata.file_metadata().schema_descr_ptr(), + }; + let reader = TreeBuilder::new().build(descr, &*row_group_reader).unwrap(); + let iter = ReaderIter::new(reader, actual_rows + 1).unwrap(); + + let rows: Vec> = iter.collect(); + assert_eq!(rows.len(), actual_rows + 1); + for row in &rows[..actual_rows] { + assert!(row.is_ok(), "Expected Ok row, got: {:?}", row); + } + let err = rows[actual_rows].as_ref().unwrap_err(); + assert!( + err.to_string().contains("Unexpected end of column data"), + "Unexpected error message: {}", + err + ); + } + + #[test] + fn test_reader_iter_returns_error_when_num_records_exceeds_data() { + assert_err_on_overcount("nulls.snappy.parquet", None); + } + + #[test] + fn test_reader_iter_returns_error_for_repeated_field_when_num_records_exceeds_data() { + assert_err_on_overcount("repeated_primitive_no_list.parquet", None); + } + + #[test] + fn test_reader_iter_returns_error_for_map_field_when_num_records_exceeds_data() { + let schema = parse_message_type( + "message schema { + REQUIRED group my_map (MAP) { + REPEATED group key_value { + REQUIRED INT32 key; + OPTIONAL INT32 value; + } + } + }", + ) + .unwrap(); + assert_err_on_overcount("map_no_value.parquet", Some(schema)); + } } diff --git a/parquet/src/record/triplet.rs b/parquet/src/record/triplet.rs index 8244dfb12823..b4d39bbbd9c9 100644 --- a/parquet/src/record/triplet.rs +++ b/parquet/src/record/triplet.rs @@ -263,6 +263,9 @@ impl TypedTripletIter { /// If field is required, then maximum definition level is returned. #[inline] fn current_def_level(&self) -> i16 { + if !self.has_next { + return 0; + } match self.def_levels { Some(ref vec) => vec[self.curr_triplet_index], None => self.max_def_level, @@ -273,6 +276,9 @@ impl TypedTripletIter { /// If field is required, then maximum repetition level is returned. #[inline] fn current_rep_level(&self) -> i16 { + if !self.has_next { + return 0; + } match self.rep_levels { Some(ref vec) => vec[self.curr_triplet_index], None => self.max_rep_level, @@ -315,6 +321,7 @@ impl TypedTripletIter { // No more values or levels to read if records_read == 0 && values_read == 0 && levels_read == 0 { + self.curr_triplet_index = 0; self.has_next = false; return Ok(false); } @@ -561,4 +568,41 @@ mod tests { assert_eq!(def_levels, expected_def_levels); assert_eq!(rep_levels, expected_rep_levels); } + + fn open_triplet_iter(file_name: &str, path: &[&str], batch_size: usize) -> TripletIter { + let column_path = ColumnPath::from(path.iter().map(|x| x.to_string()).collect::>()); + let file = get_test_file(file_name); + let file_reader = SerializedFileReader::new(file).unwrap(); + let metadata = file_reader.metadata(); + let schema = metadata.file_metadata().schema_descr(); + let row_group_reader = file_reader.get_row_group(0).unwrap(); + for i in 0..schema.num_columns() { + let descr = schema.column(i); + if descr.path() == &column_path { + let reader = row_group_reader.get_column_reader(i).unwrap(); + return TripletIter::new(descr.clone(), reader, batch_size); + } + } + panic!("Column {column_path:?} not found in {file_name}"); + } + + #[test] + fn test_current_def_level_safe_after_exhaustion() { + let mut iter = open_triplet_iter("nulls.snappy.parquet", &["b_struct", "b_c_int"], 256); + while let Ok(true) = iter.read_next() {} + assert!(!iter.has_next()); + assert_eq!(iter.current_def_level(), 0); + } + + #[test] + fn test_current_rep_level_safe_after_exhaustion() { + let mut iter = open_triplet_iter( + "nested_lists.snappy.parquet", + &["a", "list", "element", "list", "element", "list", "element"], + 256, + ); + while let Ok(true) = iter.read_next() {} + assert!(!iter.has_next()); + assert_eq!(iter.current_rep_level(), 0); + } }