Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions parquet/src/record/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,17 @@ impl Reader {
fn read_field(&mut self) -> Result<Field> {
let field = match *self {
Reader::PrimitiveReader(_, ref mut column) => {
if !column.has_next() {
return Err(general_err!("Unexpected end of column data"));
}
let value = column.current_value()?;
column.read_next()?;
value
}
Reader::OptionReader(def_level, ref mut reader) => {
if !reader.has_next() {
return Err(general_err!("Unexpected end of column data"));
}
if reader.current_def_level() > def_level {
reader.read_field()?
} else {
Expand All @@ -465,6 +471,9 @@ impl Reader {
Field::Group(row)
}
Reader::RepeatedReader(_, def_level, rep_level, ref mut reader) => {
if !reader.has_next() {
return Err(general_err!("Unexpected end of column data"));
}
let mut elements = Vec::new();
loop {
if reader.current_def_level() > def_level {
Expand All @@ -488,6 +497,9 @@ impl Reader {
Field::ListInternal(make_list(elements))
}
Reader::KeyValueReader(_, def_level, rep_level, ref mut keys, ref mut values) => {
if !keys.has_next() {
return Err(general_err!("Unexpected end of column data"));
}
let mut pairs = Vec::new();
loop {
if keys.current_def_level() > def_level {
Expand Down Expand Up @@ -1912,4 +1924,57 @@ mod tests {
),]];
assert_eq!(rows, expected_rows);
}

/// Asserts that requesting one more record than row group 0 of `file_name`
/// contains produces exactly one trailing error.
///
/// The row reader is built from `proj_schema` when one is supplied, otherwise
/// from the file's own schema. Every row up to the row group's reported row
/// count must be `Ok`; the single extra row must be an `Err` whose message
/// contains "Unexpected end of column data".
fn assert_err_on_overcount(file_name: &str, proj_schema: Option<Type>) {
    let file_reader = SerializedFileReader::new(get_test_file(file_name)).unwrap();
    let row_group_reader = file_reader.get_row_group(0).unwrap();
    let num_rows = row_group_reader.metadata().num_rows() as usize;

    // A projection, when given, replaces the schema stored in the file metadata.
    let descr = proj_schema.map_or_else(
        || file_reader.metadata().file_metadata().schema_descr_ptr(),
        |schema| Arc::new(SchemaDescriptor::new(Arc::new(schema))),
    );
    let reader = TreeBuilder::new().build(descr, &*row_group_reader).unwrap();

    // Deliberately ask for one record beyond what the row group holds.
    let results: Vec<Result<Row>> = ReaderIter::new(reader, num_rows + 1).unwrap().collect();
    assert_eq!(results.len(), num_rows + 1);

    // The length assertion above guarantees there is a last element to split off.
    let (overcount, valid) = results.split_last().unwrap();
    valid
        .iter()
        .for_each(|row| assert!(row.is_ok(), "Expected Ok row, got: {:?}", row));

    let err = overcount.as_ref().unwrap_err();
    assert!(
        err.to_string().contains("Unexpected end of column data"),
        "Unexpected error message: {}",
        err
    );
}

/// Over-count check against `nulls.snappy.parquet`, read with the file's own
/// schema (no projection).
#[test]
fn test_reader_iter_returns_error_when_num_records_exceeds_data() {
    let file_name = "nulls.snappy.parquet";
    let proj_schema = None;
    assert_err_on_overcount(file_name, proj_schema);
}

/// Over-count check against `repeated_primitive_no_list.parquet`, read with
/// the file's own schema (no projection).
#[test]
fn test_reader_iter_returns_error_for_repeated_field_when_num_records_exceeds_data() {
    let file_name = "repeated_primitive_no_list.parquet";
    let proj_schema = None;
    assert_err_on_overcount(file_name, proj_schema);
}

/// Over-count check against `map_no_value.parquet`, read through an explicit
/// projection that declares `my_map` as a MAP group with a required INT32 key
/// and an optional INT32 value.
#[test]
fn test_reader_iter_returns_error_for_map_field_when_num_records_exceeds_data() {
    // Projection schema text handed to the message-type parser.
    let message_type = "message schema {
REQUIRED group my_map (MAP) {
REPEATED group key_value {
REQUIRED INT32 key;
OPTIONAL INT32 value;
}
}
}";
    let projection = parse_message_type(message_type).unwrap();
    assert_err_on_overcount("map_no_value.parquet", Some(projection));
}
}
44 changes: 44 additions & 0 deletions parquet/src/record/triplet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ impl<T: DataType> TypedTripletIter<T> {
/// If field is required, then maximum definition level is returned.
#[inline]
fn current_def_level(&self) -> i16 {
if !self.has_next {
return 0;
}
match self.def_levels {
Some(ref vec) => vec[self.curr_triplet_index],
None => self.max_def_level,
Expand All @@ -273,6 +276,9 @@ impl<T: DataType> TypedTripletIter<T> {
/// If field is required, then maximum repetition level is returned.
#[inline]
fn current_rep_level(&self) -> i16 {
if !self.has_next {
return 0;
}
match self.rep_levels {
Some(ref vec) => vec[self.curr_triplet_index],
None => self.max_rep_level,
Expand Down Expand Up @@ -315,6 +321,7 @@ impl<T: DataType> TypedTripletIter<T> {

// No more values or levels to read
if records_read == 0 && values_read == 0 && levels_read == 0 {
self.curr_triplet_index = 0;
self.has_next = false;
return Ok(false);
}
Expand Down Expand Up @@ -561,4 +568,41 @@ mod tests {
assert_eq!(def_levels, expected_def_levels);
assert_eq!(rep_levels, expected_rep_levels);
}

/// Opens test file `file_name` and returns a `TripletIter` over the column of
/// row group 0 whose path equals `path`, reading with the given `batch_size`.
///
/// # Panics
///
/// Panics if the file or row group cannot be opened, or if no column in the
/// file schema has the requested path.
fn open_triplet_iter(file_name: &str, path: &[&str], batch_size: usize) -> TripletIter {
    let parts: Vec<String> = path.iter().map(|segment| segment.to_string()).collect();
    let column_path = ColumnPath::from(parts);

    let file_reader = SerializedFileReader::new(get_test_file(file_name)).unwrap();
    let schema = file_reader.metadata().file_metadata().schema_descr();
    let row_group_reader = file_reader.get_row_group(0).unwrap();

    // Locate the first column whose path matches; a miss is a test setup bug.
    let index = (0..schema.num_columns())
        .find(|&i| schema.column(i).path() == &column_path)
        .unwrap_or_else(|| panic!("Column {column_path:?} not found in {file_name}"));

    let descr = schema.column(index);
    let reader = row_group_reader.get_column_reader(index).unwrap();
    TripletIter::new(descr.clone(), reader, batch_size)
}

/// After every triplet of the `b_struct.b_c_int` column has been consumed,
/// the iterator must report no further data and a definition level of 0.
#[test]
fn test_current_def_level_safe_after_exhaustion() {
    let mut iter = open_triplet_iter("nulls.snappy.parquet", &["b_struct", "b_c_int"], 256);
    // Unwrap instead of matching on `Ok(true)`: a read error must fail the test
    // rather than silently end the loop and leave the exhausted state untested.
    while iter.read_next().unwrap() {}
    assert!(!iter.has_next());
    assert_eq!(iter.current_def_level(), 0);
}

/// After every triplet of the nested-list leaf column has been consumed,
/// the iterator must report no further data and a repetition level of 0.
#[test]
fn test_current_rep_level_safe_after_exhaustion() {
    let mut iter = open_triplet_iter(
        "nested_lists.snappy.parquet",
        &["a", "list", "element", "list", "element", "list", "element"],
        256,
    );
    // Unwrap instead of matching on `Ok(true)`: a read error must fail the test
    // rather than silently end the loop and leave the exhausted state untested.
    while iter.read_next().unwrap() {}
    assert!(!iter.has_next());
    assert_eq!(iter.current_rep_level(), 0);
}
}
Loading