Skip to content

Commit f826b65

Browse files
authored
feat: support concat for strings (#2604)
1 parent 2c307b7 commit f826b65

File tree

50 files changed

+6898
-7689
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+6898
-7689
lines changed

docs/source/user-guide/latest/configs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ These settings can be used to determine which parts of the plan are accelerated
201201
| `spark.comet.expression.CheckOverflow.enabled` | Enable Comet acceleration for `CheckOverflow` | true |
202202
| `spark.comet.expression.Chr.enabled` | Enable Comet acceleration for `Chr` | true |
203203
| `spark.comet.expression.Coalesce.enabled` | Enable Comet acceleration for `Coalesce` | true |
204+
| `spark.comet.expression.Concat.enabled` | Enable Comet acceleration for `Concat` | true |
204205
| `spark.comet.expression.ConcatWs.enabled` | Enable Comet acceleration for `ConcatWs` | true |
205206
| `spark.comet.expression.Contains.enabled` | Enable Comet acceleration for `Contains` | true |
206207
| `spark.comet.expression.Cos.enabled` | Enable Comet acceleration for `Cos` | true |

fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,11 @@ object Meta {
126126
SparkTypeOneOf(
127127
Seq(
128128
SparkStringType,
129-
SparkNumericType,
130-
SparkBinaryType,
131129
SparkArrayType(
132130
SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))),
133131
SparkTypeOneOf(
134132
Seq(
135133
SparkStringType,
136-
SparkNumericType,
137-
SparkBinaryType,
138134
SparkArrayType(
139135
SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))))),
140136
createFunctionWithInputTypes("concat_ws", Seq(SparkStringType, SparkStringType)),

native/core/src/execution/jni_api.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use datafusion_spark::function::hash::sha1::SparkSha1;
4747
use datafusion_spark::function::hash::sha2::SparkSha2;
4848
use datafusion_spark::function::math::expm1::SparkExpm1;
4949
use datafusion_spark::function::string::char::CharFunc;
50+
use datafusion_spark::function::string::concat::SparkConcat;
5051
use futures::poll;
5152
use futures::stream::StreamExt;
5253
use jni::objects::JByteBuffer;
@@ -317,20 +318,23 @@ fn prepare_datafusion_session_context(
317318
let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));
318319

319320
datafusion::functions_nested::register_all(&mut session_ctx)?;
321+
register_datafusion_spark_function(&session_ctx);
322+
// Must be the last one to override existing functions with the same name
323+
datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;
324+
325+
Ok(session_ctx)
326+
}
320327

321-
// register UDFs from datafusion-spark crate
328+
// register UDFs from datafusion-spark crate
329+
fn register_datafusion_spark_function(session_ctx: &SessionContext) {
322330
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
323331
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
324332
session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
325333
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default()));
326334
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
327335
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));
328336
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default()));
329-
330-
// Must be the last one to override existing functions with the same name
331-
datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;
332-
333-
Ok(session_ctx)
337+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
334338
}
335339

336340
/// Prepares arrow arrays for output.

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
164164
classOf[BitLength] -> CometScalarFunction("bit_length"),
165165
classOf[Chr] -> CometScalarFunction("char"),
166166
classOf[ConcatWs] -> CometScalarFunction("concat_ws"),
167+
classOf[Concat] -> CometConcat,
167168
classOf[Contains] -> CometScalarFunction("contains"),
168169
classOf[EndsWith] -> CometScalarFunction("ends_with"),
169170
classOf[InitCap] -> CometInitCap,

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
2525
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
2626

2727
import org.apache.comet.CometConf
@@ -113,6 +113,18 @@ object CometSubstring extends CometExpressionSerde[Substring] {
113113
}
114114
}
115115

116+
object CometConcat extends CometScalarFunction[Concat]("concat") {
117+
val unsupportedReason = "CONCAT supports only string input parameters"
118+
119+
override def getSupportLevel(expr: Concat): SupportLevel = {
120+
if (expr.children.forall(_.dataType == DataTypes.StringType)) {
121+
Compatible()
122+
} else {
123+
Incompatible(Some(unsupportedReason))
124+
}
125+
}
126+
}
127+
116128
object CometLike extends CometExpressionSerde[Like] {
117129

118130
override def convert(expr: Like, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {

0 commit comments

Comments (0)