@@ -21,7 +21,7 @@ use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{DictionaryArray, StringArray, StructArray};
 use arrow::compute::can_cast_types;
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
 use arrow::{
     array::{
         cast::AsArray,
@@ -41,7 +41,8 @@ use arrow::{
 };
 use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
 use datafusion::common::{
-    cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
+    cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
+    ScalarValue,
 };
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::ColumnarValue;
@@ -867,6 +868,40 @@ pub fn spark_cast(
     }
 }
 
+// copied from datafusion common scalar/mod.rs
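+// Dictionary-encodes `values_array` by pairing each element with a key equal to
+// its index (null slots get null keys); values are not deduplicated.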
+fn dict_from_values<K: ArrowDictionaryKeyType>(
+    values_array: ArrayRef,
+) -> datafusion::common::Result<ArrayRef> {
+    // Create a key array with `size` elements of 0..array_len for all
+    // non-null value elements
+    let key_array: PrimitiveArray<K> = (0..values_array.len())
+        .map(|index| {
+            if values_array.is_valid(index) {
+                let native_index = K::Native::from_usize(index).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Can not create index of type {} from value {}",
+                        K::DATA_TYPE,
+                        index
+                    ))
+                })?;
+                Ok(Some(native_index))
+            } else {
+                Ok(None)
+            }
+        })
+        .collect::<datafusion::common::Result<Vec<_>>>()?
+        .into_iter()
+        .collect();
+
+    // create a new DictionaryArray
+    //
+    // Note: this path could be made faster by using the ArrayData
+    // APIs and skipping validation, if it ever comes up in
+    // performance traces.
+    let dict_array = DictionaryArray::<K>::try_new(key_array, values_array)?;
+    Ok(Arc::new(dict_array))
+}
+
 fn cast_array(
     array: ArrayRef,
     to_type: &DataType,
@@ -896,18 +931,33 @@ fn cast_array(
                 .downcast_ref::<DictionaryArray<Int32Type>>()
                 .expect("Expected a dictionary array");
 
-            let casted_dictionary = DictionaryArray::<Int32Type>::new(
-                dict_array.keys().clone(),
-                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
-            );
-
             let casted_result = match to_type {
-                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
-                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
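+                // Dictionary target: cast only the dictionary values to the
+                // target value type and reuse the existing keys.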
+                Dictionary(_, to_value_type) => {
+                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                        dict_array.keys().clone(),
+                        cast_array(Arc::clone(dict_array.values()), to_value_type, cast_options)?,
+                    );
+                    Arc::new(casted_dictionary.clone())
+                }
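+                // Non-dictionary target: cast the values, then expand to a flat
+                // array by taking the casted values through the original keys.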
+                _ => {
+                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                        dict_array.keys().clone(),
+                        cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
+                    );
+                    take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?
+                }
             };
             return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
         }
-        _ => array,
+        _ => {
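+            // Non-dictionary input cast to a dictionary type: dictionary-encode
+            // the input first, then cast the resulting dictionary array.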
+            if let Dictionary(_, _) = to_type {
+                let dict_array = dict_from_values::<Int32Type>(array)?;
+                let casted_result = cast_array(dict_array, to_type, cast_options)?;
+                return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
+            } else {
+                array
+            }
+        }
     };
     let from_type = array.data_type();
     let eval_mode = cast_options.eval_mode;