diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 7db9ade5746..f668195abc7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -24,7 +24,7 @@ import scala.annotation.tailrec import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar} +import ai.rapids.cudf.{ast, BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar, Table} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -1242,27 +1242,26 @@ case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + def multiOrsAst: ast.AstExpression = { + (1 until targets.length) + .foldLeft(new ast.ColumnReference(0).asInstanceOf[ast.AstExpression]) { (acc, id) => + new ast.BinaryOperation(ast.BinaryOperator.NULL_LOGICAL_OR, acc, new ast.ColumnReference(id)) + } + } + override def doColumnar(input: GpuColumnVector): ColumnVector = { val targetsBytes = targets.map(t => t.getBytes).toArray val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => input.getBase.stringContains(targetsCv) } - var ret: ColumnVector = null - withResource(boolCvs) { _ => - boolCvs.indices.foreach { i => - if (ret == null) { - ret = boolCvs(i) - boolCvs(i) = null - } else { - val tmp = ret.or(boolCvs(i)) - ret.close() - ret = tmp - boolCvs(i).close() - boolCvs(i) = null - } + val boolTable = withResource(boolCvs) { _ => + new Table(boolCvs: _*) + } + withResource(boolTable) { _ => + withResource(multiOrsAst.compile()) { compiledAst => + compiledAst.computeColumn(boolTable) } } - ret } }