Commit ba228e1

fix: handle cast to dictionary vector introduced by case when (apache#2044)
1 parent a907d7d commit ba228e1

File tree: 2 files changed, 70 additions and 10 deletions


native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 60 additions & 10 deletions
@@ -21,7 +21,7 @@ use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{DictionaryArray, StringArray, StructArray};
 use arrow::compute::can_cast_types;
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
 use arrow::{
     array::{
         cast::AsArray,
@@ -41,7 +41,8 @@ use arrow::{
 };
 use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
 use datafusion::common::{
-    cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
+    cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
+    ScalarValue,
 };
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::ColumnarValue;
@@ -867,6 +868,40 @@ pub fn spark_cast(
     }
 }

+// copied from datafusion common scalar/mod.rs
+fn dict_from_values<K: ArrowDictionaryKeyType>(
+    values_array: ArrayRef,
+) -> datafusion::common::Result<ArrayRef> {
+    // Create a key array with `size` elements of 0..array_len for all
+    // non-null value elements
+    let key_array: PrimitiveArray<K> = (0..values_array.len())
+        .map(|index| {
+            if values_array.is_valid(index) {
+                let native_index = K::Native::from_usize(index).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Can not create index of type {} from value {}",
+                        K::DATA_TYPE,
+                        index
+                    ))
+                })?;
+                Ok(Some(native_index))
+            } else {
+                Ok(None)
+            }
+        })
+        .collect::<datafusion::common::Result<Vec<_>>>()?
+        .into_iter()
+        .collect();
+
+    // create a new DictionaryArray
+    //
+    // Note: this path could be made faster by using the ArrayData
+    // APIs and skipping validation, if it ever comes up in
+    // performance traces.
+    let dict_array = DictionaryArray::<K>::try_new(key_array, values_array)?;
+    Ok(Arc::new(dict_array))
+}
+
 fn cast_array(
     array: ArrayRef,
     to_type: &DataType,
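The helper above gives every valid row its own key (0..len) and keeps null rows as null keys, so the resulting dictionary is maximally redundant but type-correct. A minimal sketch of how it behaves on a flat string array with a null, assuming it runs inside this module where the private helper and the arrow-rs types are in scope:

use std::sync::Arc;
use arrow::array::{Array, ArrayRef, StringArray};
use arrow::datatypes::{DataType, Int32Type};

fn demo() -> datafusion::common::Result<()> {
    let flat: ArrayRef = Arc::new(StringArray::from(vec![Some("one"), None, Some("")]));
    let dict = dict_from_values::<Int32Type>(flat)?;
    // Keys mirror the row index, with the null row becoming a null key:
    // keys = [0, null, 2]; values stay ["one", null, ""].
    assert_eq!(
        dict.data_type(),
        &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))
    );
    Ok(())
}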
@@ -896,18 +931,33 @@ fn cast_array(
             .downcast_ref::<DictionaryArray<Int32Type>>()
             .expect("Expected a dictionary array");

-            let casted_dictionary = DictionaryArray::<Int32Type>::new(
-                dict_array.keys().clone(),
-                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
-            );
-
             let casted_result = match to_type {
-                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
-                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
+                Dictionary(_, to_value_type) => {
+                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                        dict_array.keys().clone(),
+                        cast_array(Arc::clone(dict_array.values()), to_value_type, cast_options)?,
+                    );
+                    Arc::new(casted_dictionary.clone())
+                }
+                _ => {
+                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                        dict_array.keys().clone(),
+                        cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
+                    );
+                    take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?
+                }
             };
             return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
         }
-        _ => array,
+        _ => {
+            if let Dictionary(_, _) = to_type {
+                let dict_array = dict_from_values::<Int32Type>(array)?;
+                let casted_result = cast_array(dict_array, to_type, cast_options)?;
+                return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
+            } else {
+                array
+            }
+        }
     };
     let from_type = array.data_type();
     let eval_mode = cast_options.eval_mode;
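Two things change in cast_array's dictionary handling here. When the target is itself a dictionary, the values are now cast to the target's value type (to_value_type) rather than to the whole Dictionary type, with the keys reused untouched. When the target is flat, the cast values are materialized with take, as before. A sketch of that take-based flattening using plain arrow-rs kernels (illustrative of the technique, not Comet's exact call site):

use std::sync::Arc;
use arrow::array::{Array, ArrayRef, DictionaryArray, Int32Array, StringArray};
use arrow::compute::take;
use arrow::datatypes::Int32Type;

fn demo() -> Result<(), arrow::error::ArrowError> {
    // Dictionary<Int32, Utf8>: 3 rows over 2 distinct values.
    let keys = Int32Array::from(vec![0, 0, 1]);
    let values: ArrayRef = Arc::new(StringArray::from(vec!["one", ""]));
    let dict = DictionaryArray::<Int32Type>::try_new(keys, values)?;

    // Flatten: gather the (already cast) values with the original keys.
    let flat = take(dict.values().as_ref(), dict.keys(), None)?;
    assert_eq!(flat.len(), 3);
    Ok(())
}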

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 10 additions & 0 deletions
@@ -1063,6 +1063,16 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withNulls(gen.generateLongs(dataSize)).toDF("a")
   }

+  // https://github.com/apache/datafusion-comet/issues/2038
+  test("test implicit cast to dictionary with case when and dictionary type") {
+    withSQLConf("parquet.enable.dictionary" -> "true") {
+      withParquetTable((0 until 10000).map(i => (i < 5000, "one")), "tbl") {
+        val df = spark.sql("select case when (_1 = true) then _2 else '' end as aaa from tbl")
+        checkSparkAnswerAndOperator(df)
+      }
+    }
+  }
+
   private def generateDecimalsPrecision10Scale2(): DataFrame = {
     val values = Seq(
       BigDecimal("-99999999.999"),
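In this test, parquet dictionary encoding makes the scan of _2 produce a dictionary-encoded string vector, while the else '' branch of the CASE is a plain string literal, so the plan inserts a cast from a flat Utf8 array to Dictionary(Int32, Utf8). Before this commit that cast fell through `_ => array` and returned the input un-cast. A conceptual trace of the new fallthrough path (a sketch: SparkCastOptions as the type of cast_options and the surrounding wiring are assumptions, and cast_array is private to this module):

use std::sync::Arc;
use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::{DataType, Int32Type};

// Hypothetical driver; `cast_options` stands in for whatever options
// value cast_array actually receives from spark_cast.
fn demo(cast_options: &SparkCastOptions) -> datafusion::common::Result<ArrayRef> {
    // The flat Utf8 array produced by the `else ''` branch.
    let flat: ArrayRef = Arc::new(StringArray::from(vec![""; 4]));
    // New fallthrough: wrap as a dictionary, then re-enter the dictionary arm.
    let as_dict = dict_from_values::<Int32Type>(flat)?;
    let to_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    cast_array(as_dict, &to_type, cast_options)
}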
