Skip to content

Commit 5623bcf

Browse files
authored
Fix panics in array_union (#287)
* Drop rust-toolchain
* Fix panics in array_union
* Fix the chrono
1 parent 0123a16 commit 5623bcf

File tree

4 files changed

+107
-125
lines changed

4 files changed

+107
-125
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ arrow-string = { version = "52.0.0", default-features = false }
8383
async-trait = "0.1.73"
8484
bigdecimal = "=0.4.1"
8585
bytes = "1.4"
86-
chrono = { version = "0.4.34", default-features = false }
86+
chrono = { version = ">=0.4.34,<0.4.40", default-features = false }
8787
ctor = "0.2.0"
8888
dashmap = "5.5.0"
8989
datafusion = { path = "datafusion/core", version = "39.0.0", default-features = false }

datafusion/functions-array/src/set_ops.rs

Lines changed: 96 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@
1717

1818
//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.
1919
20-
use crate::make_array::{empty_array_type, make_array_inner};
2120
use crate::utils::make_scalar_function;
22-
use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait};
21+
use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait};
2322
use arrow::buffer::OffsetBuffer;
2423
use arrow::compute;
2524
use arrow::datatypes::{DataType, Field, FieldRef};
2625
use arrow::row::{RowConverter, SortField};
26+
use arrow_array::{new_null_array, LargeListArray, ListArray};
2727
use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null};
2828
use datafusion_common::cast::{as_large_list_array, as_list_array};
29-
use datafusion_common::{exec_err, internal_err, Result};
29+
use datafusion_common::{exec_err, internal_err, plan_err, Result};
3030
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
3131
use itertools::Itertools;
3232
use std::any::Any;
@@ -89,7 +89,8 @@ impl ScalarUDFImpl for ArrayUnion {
8989

9090
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
9191
match (&arg_types[0], &arg_types[1]) {
92-
(&Null, dt) => Ok(dt.clone()),
92+
(Null, Null) => Ok(DataType::new_list(Null, true)),
93+
(Null, dt) => Ok(dt.clone()),
9394
(dt, Null) => Ok(dt.clone()),
9495
(dt, _) => Ok(dt.clone()),
9596
}
@@ -134,9 +135,10 @@ impl ScalarUDFImpl for ArrayIntersect {
134135

135136
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
136137
match (arg_types[0].clone(), arg_types[1].clone()) {
137-
(Null, Null) | (Null, _) => Ok(Null),
138-
(_, Null) => Ok(empty_array_type()),
139-
(dt, _) => Ok(dt),
138+
(Null, Null) => Ok(DataType::new_list(Null, true)),
139+
(Null, dt) => Ok(dt.clone()),
140+
(dt, Null) => Ok(dt.clone()),
141+
(dt, _) => Ok(dt.clone()),
140142
}
141143
}
142144

@@ -179,19 +181,13 @@ impl ScalarUDFImpl for ArrayDistinct {
179181

180182
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
181183
match &arg_types[0] {
182-
List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new(
183-
"item",
184-
field.data_type().clone(),
185-
true,
186-
)))),
187-
LargeList(field) => Ok(LargeList(Arc::new(Field::new(
188-
"item",
189-
field.data_type().clone(),
190-
true,
191-
)))),
192-
_ => exec_err!(
193-
"Not reachable, data_type should be List, LargeList or FixedSizeList"
194-
),
184+
List(field) | FixedSizeList(field, _) => {
185+
Ok(DataType::new_list(field.data_type().clone(), true))
186+
}
187+
LargeList(field) => {
188+
Ok(DataType::new_large_list(field.data_type().clone(), true))
189+
}
190+
arg_type => plan_err!("{} does not support type {arg_type}", self.name()),
195191
}
196192
}
197193

@@ -211,22 +207,18 @@ fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
211207
return exec_err!("array_distinct needs one argument");
212208
}
213209

214-
// handle null
215-
if args[0].data_type() == &Null {
216-
return Ok(args[0].clone());
217-
}
218-
219-
// handle for list & largelist
220-
match args[0].data_type() {
210+
let array = &args[0];
211+
match array.data_type() {
212+
Null => Ok(Arc::clone(array)),
221213
List(field) => {
222-
let array = as_list_array(&args[0])?;
214+
let array = as_list_array(array)?;
223215
general_array_distinct(array, field)
224216
}
225217
LargeList(field) => {
226-
let array = as_large_list_array(&args[0])?;
218+
let array = as_large_list_array(array)?;
227219
general_array_distinct(array, field)
228220
}
229-
array_type => exec_err!("array_distinct does not support type '{array_type:?}'"),
221+
arg_type => exec_err!("array_distinct does not support type '{arg_type:?}'"),
230222
}
231223
}
232224

@@ -251,80 +243,69 @@ fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
251243
field: Arc<Field>,
252244
set_op: SetOp,
253245
) -> Result<ArrayRef> {
254-
if matches!(l.value_type(), Null) {
246+
if l.is_empty() || l.value_type().is_null() {
255247
let field = Arc::new(Field::new("item", r.value_type(), true));
256248
return general_array_distinct::<OffsetSize>(r, &field);
257-
} else if matches!(r.value_type(), Null) {
249+
} else if r.is_empty() || r.value_type().is_null() {
258250
let field = Arc::new(Field::new("item", l.value_type(), true));
259251
return general_array_distinct::<OffsetSize>(l, &field);
260252
}
261253

262-
// Handle empty array at rhs case
263-
// array_union(arr, []) -> arr;
264-
// array_intersect(arr, []) -> [];
265-
if r.value_length(0).is_zero() {
266-
if set_op == SetOp::Union {
267-
return Ok(Arc::new(l.clone()) as ArrayRef);
268-
} else {
269-
return Ok(Arc::new(r.clone()) as ArrayRef);
270-
}
271-
}
272-
273254
if l.value_type() != r.value_type() {
274-
return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'");
255+
return internal_err!(
256+
"{set_op} is not implemented for {} and {}",
257+
l.data_type(),
258+
r.data_type()
259+
);
275260
}
276261

277-
let dt = l.value_type();
278-
279262
let mut offsets = vec![OffsetSize::usize_as(0)];
280263
let mut new_arrays = vec![];
281-
282-
let converter = RowConverter::new(vec![SortField::new(dt)])?;
264+
let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;
283265
for (first_arr, second_arr) in l.iter().zip(r.iter()) {
284-
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
285-
let l_values = converter.convert_columns(&[first_arr])?;
286-
let r_values = converter.convert_columns(&[second_arr])?;
287-
288-
let l_iter = l_values.iter().sorted().dedup();
289-
let values_set: HashSet<_> = l_iter.clone().collect();
290-
let mut rows = if set_op == SetOp::Union {
291-
l_iter.collect::<Vec<_>>()
292-
} else {
293-
vec![]
294-
};
295-
for r_val in r_values.iter().sorted().dedup() {
296-
match set_op {
297-
SetOp::Union => {
298-
if !values_set.contains(&r_val) {
299-
rows.push(r_val);
300-
}
301-
}
302-
SetOp::Intersect => {
303-
if values_set.contains(&r_val) {
304-
rows.push(r_val);
305-
}
306-
}
307-
}
308-
}
266+
let l_values = if let Some(first_arr) = first_arr {
267+
converter.convert_columns(&[first_arr])?
268+
} else {
269+
converter.convert_columns(&[])?
270+
};
271+
272+
let r_values = if let Some(second_arr) = second_arr {
273+
converter.convert_columns(&[second_arr])?
274+
} else {
275+
converter.convert_columns(&[])?
276+
};
277+
278+
let l_iter = l_values.iter().sorted().dedup();
279+
let values_set: HashSet<_> = l_iter.clone().collect();
280+
let mut rows = if set_op == SetOp::Union {
281+
l_iter.collect()
282+
} else {
283+
vec![]
284+
};
309285

310-
let last_offset = match offsets.last().copied() {
311-
Some(offset) => offset,
312-
None => return internal_err!("offsets should not be empty"),
313-
};
314-
offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
315-
let arrays = converter.convert_rows(rows)?;
316-
let array = match arrays.first() {
317-
Some(array) => array.clone(),
318-
None => {
319-
return internal_err!("{set_op}: failed to get array from rows");
320-
}
321-
};
322-
new_arrays.push(array);
286+
for r_val in r_values.iter().sorted().dedup() {
287+
match set_op {
288+
SetOp::Union if !values_set.contains(&r_val) => rows.push(r_val),
289+
SetOp::Intersect if values_set.contains(&r_val) => rows.push(r_val),
290+
_ => (),
291+
}
323292
}
293+
294+
let last_offset = match offsets.last() {
295+
Some(offset) => *offset,
296+
None => return internal_err!("offsets should not be empty"),
297+
};
298+
299+
offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
300+
let arrays = converter.convert_rows(rows)?;
301+
new_arrays.push(match arrays.first() {
302+
Some(array) => Arc::clone(array),
303+
None => return internal_err!("{set_op}: failed to get array from rows"),
304+
});
324305
}
325306

326307
let offsets = OffsetBuffer::new(offsets.into());
327-
let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
308+
let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect();
328309
let values = compute::concat(&new_arrays_ref)?;
329310
let arr = GenericListArray::<OffsetSize>::try_new(field, offsets, values, None)?;
330311
Ok(Arc::new(arr))
@@ -335,38 +316,60 @@ fn general_set_op(
335316
array2: &ArrayRef,
336317
set_op: SetOp,
337318
) -> Result<ArrayRef> {
319+
fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result<ArrayRef> {
320+
let field = Arc::new(Field::new_list_field(data_type.clone(), true));
321+
let values = new_null_array(data_type, len);
322+
if large {
323+
Ok(Arc::new(LargeListArray::try_new(
324+
field,
325+
OffsetBuffer::new_zeroed(len),
326+
values,
327+
None,
328+
)?))
329+
} else {
330+
Ok(Arc::new(ListArray::try_new(
331+
field,
332+
OffsetBuffer::new_zeroed(len),
333+
values,
334+
None,
335+
)?))
336+
}
337+
}
338+
338339
match (array1.data_type(), array2.data_type()) {
340+
(Null, Null) => Ok(Arc::new(ListArray::new_null(
341+
Arc::new(Field::new_list_field(Null, true)),
342+
array1.len(),
343+
))),
339344
(Null, List(field)) => {
340345
if set_op == SetOp::Intersect {
341-
return Ok(new_empty_array(&Null));
346+
return empty_array(field.data_type(), array1.len(), false);
342347
}
343348
let array = as_list_array(&array2)?;
344349
general_array_distinct::<i32>(array, field)
345350
}
346351

347352
(List(field), Null) => {
348353
if set_op == SetOp::Intersect {
349-
return make_array_inner(&[]);
354+
return empty_array(field.data_type(), array1.len(), false);
350355
}
351356
let array = as_list_array(&array1)?;
352357
general_array_distinct::<i32>(array, field)
353358
}
354359
(Null, LargeList(field)) => {
355360
if set_op == SetOp::Intersect {
356-
return Ok(new_empty_array(&Null));
361+
return empty_array(field.data_type(), array1.len(), true);
357362
}
358363
let array = as_large_list_array(&array2)?;
359364
general_array_distinct::<i64>(array, field)
360365
}
361366
(LargeList(field), Null) => {
362367
if set_op == SetOp::Intersect {
363-
return make_array_inner(&[]);
368+
return empty_array(field.data_type(), array1.len(), true);
364369
}
365370
let array = as_large_list_array(&array1)?;
366371
general_array_distinct::<i64>(array, field)
367372
}
368-
(Null, Null) => Ok(new_empty_array(&Null)),
369-
370373
(List(field), List(_)) => {
371374
let array1 = as_list_array(&array1)?;
372375
let array2 = as_list_array(&array2)?;

datafusion/sqllogictest/test_files/array.slt

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3820,21 +3820,24 @@ select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5,
38203820
statement ok
38213821
CREATE TABLE arrays_with_repeating_elements_for_union
38223822
AS VALUES
3823-
([1], [2]),
3823+
([0, 1, 1], []),
3824+
([1, 1], [2]),
38243825
([2, 3], [3]),
38253826
([3], [3, 4])
38263827
;
38273828

38283829
query ?
38293830
select array_union(column1, column2) from arrays_with_repeating_elements_for_union;
38303831
----
3832+
[0, 1]
38313833
[1, 2]
38323834
[2, 3]
38333835
[3, 4]
38343836

38353837
query ?
38363838
select array_union(arrow_cast(column1, 'LargeList(Int64)'), arrow_cast(column2, 'LargeList(Int64)')) from arrays_with_repeating_elements_for_union;
38373839
----
3840+
[0, 1]
38383841
[1, 2]
38393842
[2, 3]
38403843
[3, 4]
@@ -3854,15 +3857,11 @@ select array_union(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList
38543857
[]
38553858

38563859
# array_union scalar function #7
3857-
query ?
3860+
query error DataFusion error: Internal error: array_union is not implemented for
38583861
select array_union([[null]], []);
3859-
----
3860-
[[]]
38613862

3862-
query ?
3863+
query error DataFusion error: Internal error: array_union is not implemented for
38633864
select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([], 'LargeList(Int64)'));
3864-
----
3865-
[[]]
38663865

38673866
# array_union scalar function #8
38683867
query ?
@@ -5530,12 +5529,12 @@ select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null)
55305529
query ?
55315530
select array_intersect(null, [1, 1, 2, 2, 3, 3]);
55325531
----
5533-
NULL
5532+
[]
55345533

55355534
query ?
55365535
select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'));
55375536
----
5538-
NULL
5537+
[]
55395538

55405539
query ?
55415540
select array_intersect([], null);
@@ -5560,12 +5559,12 @@ select array_intersect(arrow_cast([], 'LargeList(Int64)'), null);
55605559
query ?
55615560
select array_intersect(null, []);
55625561
----
5563-
NULL
5562+
[]
55645563

55655564
query ?
55665565
select array_intersect(null, arrow_cast([], 'LargeList(Int64)'));
55675566
----
5568-
NULL
5567+
[]
55695568

55705569
query ?
55715570
select array_intersect(null, null);

rust-toolchain.toml

Lines changed: 0 additions & 20 deletions
This file was deleted.

0 commit comments

Comments
 (0)