From cedcd58adf9669aac27021be36d660c3c91a6c9c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 3 Dec 2024 14:36:44 +0800 Subject: [PATCH 1/8] use multiple contains in rlike rewrite Signed-off-by: Haoyang Li --- .../com/nvidia/spark/rapids/RegexParser.scala | 11 +++++---- .../spark/sql/rapids/stringFunctions.scala | 24 +++++++++---------- .../RegularExpressionRewriteSuite.scala | 16 +++++++------ 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 89fd5bf9191..2b0b46f55ea 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ListBuffer import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars import com.nvidia.spark.rapids.RegexParser.toReadableString +import org.apache.spark.unsafe.types.UTF8String + /** * Regular expression parser based on a Pratt Parser design. * @@ -1988,7 +1990,7 @@ object RegexOptimizationType { case class Contains(literal: String) extends RegexOptimizationType case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int) extends RegexOptimizationType - case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType + case class MultipleContains(literals: Seq[UTF8String]) extends RegexOptimizationType case object NoOptimization extends RegexOptimizationType } @@ -2057,16 +2059,17 @@ object RegexRewrite { } } - private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = { + private def getMultipleContainsLiterals(ast: RegexAST): Seq[UTF8String] = { ast match { case RegexGroup(_, term, _) => getMultipleContainsLiterals(term) case RegexChoice(RegexSequence(parts), ls) if isLiteralString(parts) => { getMultipleContainsLiterals(ls) match { case Seq() => Seq.empty - case literals => RegexCharsToString(parts) +: literals + case literals => UTF8String.fromString(RegexCharsToString(parts)) +: literals } } - case RegexSequence(parts) if (isLiteralString(parts)) => Seq(RegexCharsToString(parts)) + case RegexSequence(parts) if (isLiteralString(parts)) => + Seq(UTF8String.fromString(RegexCharsToString(parts))) case _ => Seq.empty } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 79db87f1736..276909dfece 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1202,7 +1202,9 @@ class GpuRLikeMeta( } case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType)) case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType)) - case MultipleContains(ls) => GpuMultipleContains(lhs, ls) + case MultipleContains(ls) => { + GpuContainsAny(lhs, ls) + } case PrefixRange(s, length, start, end) => GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end) case _ => throw new IllegalStateException("Unexpected optimization type") @@ -1233,7 +1235,7 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String) override def dataType: DataType = BooleanType } -case class GpuMultipleContains(input: Expression, searchList: Seq[String]) +case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerantShim { override def dataType: DataType = BooleanType @@ -1243,17 +1245,15 @@ case class GpuMultipleContains(input: Expression, searchList: Seq[String]) override def inputTypes: Seq[AbstractDataType] = Seq(StringType) override def doColumnar(input: GpuColumnVector): ColumnVector = { - assert(searchList.length > 1) - val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar => - input.getBase.stringContains(searchScalar) + val targetsBytes = targets.map(t => t.getBytes).toArray + val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => + input.getBase.stringContains(targetsCv) } - searchList.tail.foldLeft(accInit) { (acc, search) => - val containsSearch = withResource(Scalar.fromString(search)) { searchScalar => - input.getBase.stringContains(searchScalar) - } - withResource(acc) { _ => - withResource(containsSearch) { _ => - acc.or(containsSearch) + // boolCvs is a sequence of ColumnVectors, we need to OR them together + boolCvs.reduce { + (cv1, cv2) => withResource(cv1) { cv1 => + withResource(cv2) { cv2 => + cv1.or(cv2) } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala index a55815b95ef..12e12fd957f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala @@ -17,10 +17,12 @@ package com.nvidia.spark.rapids import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.unsafe.types.UTF8String + class RegularExpressionRewriteSuite extends AnyFunSuite { - private def verifyRewritePattern(patterns: Seq[String], excepted: Seq[RegexOptimizationType]): - Unit = { + private def verifyRewritePattern(patterns: Seq[String], + excepted: Seq[RegexOptimizationType]): Unit = { val results = patterns.map { pattern => val ast = new RegexParser(pattern).parse() RegexRewrite.matchSimplePattern(ast) @@ -87,11 +89,11 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { "(火花|急流)" ) val excepted = Seq( - MultipleContains(Seq("abc", "def")), - MultipleContains(Seq("abc", "def", "ghi")), - MultipleContains(Seq("abc", "def")), - MultipleContains(Seq("abc", "def")), - MultipleContains(Seq("火花", "急流")) + MultipleContains(Seq("abc", "def").map(UTF8String.fromString)), + MultipleContains(Seq("abc", "def", "ghi").map(UTF8String.fromString)), + MultipleContains(Seq("abc", "def").map(UTF8String.fromString)), + MultipleContains(Seq("abc", "def").map(UTF8String.fromString)), + MultipleContains(Seq("火花", "急流").map(UTF8String.fromString)) ) verifyRewritePattern(patterns, excepted) } From 14900ddfc2f525d9ee7faa330cd10adb840a7476 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Dec 2024 16:13:12 +0800 Subject: [PATCH 2/8] memory leak Signed-off-by: Haoyang Li --- .../org/apache/spark/sql/rapids/stringFunctions.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 276909dfece..ecf27dde992 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1249,13 +1249,11 @@ case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => input.getBase.stringContains(targetsCv) } - // boolCvs is a sequence of ColumnVectors, we need to OR them together - boolCvs.reduce { - (cv1, cv2) => withResource(cv1) { cv1 => - withResource(cv2) { cv2 => - cv1.or(cv2) - } + withResource(boolCvs) { _ => + val falseCv = withResource(Scalar.fromBool(false)) { falseScalar => + ColumnVector.fromScalar(falseScalar, input.getRowCount.toInt) } + boolCvs.foldLeft(falseCv)((l, r) => withResource(l) { _ => l.or(r)}) } } } From ee65a32ce36abc77c0630609e4b264d3c2be108b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Dec 2024 16:19:31 +0800 Subject: [PATCH 3/8] address comment Signed-off-by: Haoyang Li --- .../scala/org/apache/spark/sql/rapids/stringFunctions.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index ecf27dde992..17372e4a568 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1202,9 +1202,7 @@ class GpuRLikeMeta( } case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType)) case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType)) - case MultipleContains(ls) => { - GpuContainsAny(lhs, ls) - } + case MultipleContains(ls) => GpuContainsAny(lhs, ls) case PrefixRange(s, length, start, end) => GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end) case _ => throw new IllegalStateException("Unexpected optimization type") From 24e75a22ade1f37363bb5211db3f02bcf6cff315 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 6 Dec 2024 13:23:52 +0800 Subject: [PATCH 4/8] save a temp columnvector Signed-off-by: Haoyang Li --- .../org/apache/spark/sql/rapids/stringFunctions.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 17372e4a568..10960cbe2ad 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1247,11 +1247,9 @@ case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => input.getBase.stringContains(targetsCv) } - withResource(boolCvs) { _ => - val falseCv = withResource(Scalar.fromBool(false)) { falseScalar => - ColumnVector.fromScalar(falseScalar, input.getRowCount.toInt) - } - boolCvs.foldLeft(falseCv)((l, r) => withResource(l) { _ => l.or(r)}) + withResource(boolCvs.tail) { _ => + // boolCvs.head and intermediate values are closed within the withResource in the lambda + boolCvs.tail.foldLeft(boolCvs.head)((l, r) => withResource(l) { _ => l.or(r)}) } } } From ca1ba0cb5a38b17f6eace79be4d965167fcd44a8 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Dec 2024 15:22:39 +0800 Subject: [PATCH 5/8] foreach Signed-off-by: Haoyang Li --- .../spark/sql/rapids/stringFunctions.scala | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 10960cbe2ad..b026aa2ae9d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1247,10 +1247,26 @@ case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => input.getBase.stringContains(targetsCv) } - withResource(boolCvs.tail) { _ => - // boolCvs.head and intermediate values are closed within the withResource in the lambda - boolCvs.tail.foldLeft(boolCvs.head)((l, r) => withResource(l) { _ => l.or(r)}) + // withResource(boolCvs.tail) { _ => + // // boolCvs.head and intermediate values are closed within the withResource in the lambda + // boolCvs.tail.foldLeft(boolCvs.head)((l, r) => withResource(l) { _ => l.or(r)}) + // } + var ret: ColumnVector = null + closeOnExcept(boolCvs) { _ => + boolCvs.indices.foreach { i => + if (ret == null) { + ret = boolCvs(i) + boolCvs(i) = null + } else { + val tmp = ret.or(boolCvs(i)) + ret.close() + ret = tmp + boolCvs(i).close() + boolCvs(i) = null + } + } } + ret } } From c84e05e5a6ea5d82886e09f689ebfe3937e4edb5 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Dec 2024 15:46:41 +0800 Subject: [PATCH 6/8] foldLeft again Signed-off-by: Haoyang Li --- .../spark/sql/rapids/stringFunctions.scala | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index b026aa2ae9d..b27d30506f8 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1247,26 +1247,15 @@ case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => input.getBase.stringContains(targetsCv) } - // withResource(boolCvs.tail) { _ => - // // boolCvs.head and intermediate values are closed within the withResource in the lambda - // boolCvs.tail.foldLeft(boolCvs.head)((l, r) => withResource(l) { _ => l.or(r)}) - // } - var ret: ColumnVector = null - closeOnExcept(boolCvs) { _ => - boolCvs.indices.foreach { i => - if (ret == null) { - ret = boolCvs(i) - boolCvs(i) = null - } else { - val tmp = ret.or(boolCvs(i)) - ret.close() - ret = tmp - boolCvs(i).close() - boolCvs(i) = null + closeOnExcept(boolCvs.tail) { _ => + boolCvs.tail.foldLeft(boolCvs.head) { + (l, r) => withResource(l) { _ => + withResource(r) { _ => + l.or(r) + } } } } - ret } } From 4f33192b366cf314a1b672c54376ff51845c3751 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Dec 2024 23:28:19 +0800 Subject: [PATCH 7/8] withResource Signed-off-by: Haoyang Li --- .../spark/sql/rapids/stringFunctions.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index b27d30506f8..7db9ade5746 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1247,15 +1247,22 @@ case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => input.getBase.stringContains(targetsCv) } - closeOnExcept(boolCvs.tail) { _ => - boolCvs.tail.foldLeft(boolCvs.head) { - (l, r) => withResource(l) { _ => - withResource(r) { _ => - l.or(r) - } + var ret: ColumnVector = null + withResource(boolCvs) { _ => + boolCvs.indices.foreach { i => + if (ret == null) { + ret = boolCvs(i) + boolCvs(i) = null + } else { + val tmp = ret.or(boolCvs(i)) + ret.close() + ret = tmp + boolCvs(i).close() + boolCvs(i) = null } } } + ret } } From 613cb1ed7a41b8a7affc499956944ff0a9abd8a3 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Dec 2024 17:52:37 +0800 Subject: [PATCH 8/8] Use AST Signed-off-by: Haoyang Li --- .../spark/sql/rapids/stringFunctions.scala | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 7db9ade5746..f668195abc7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -24,7 +24,7 @@ import scala.annotation.tailrec import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar} +import ai.rapids.cudf.{ast, BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar, Table} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -1242,27 +1242,26 @@ case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + def multiOrsAst: ast.AstExpression = { + (1 until targets.length) + .foldLeft(new ast.ColumnReference(0).asInstanceOf[ast.AstExpression]) { (acc, id) => + new ast.BinaryOperation(ast.BinaryOperator.NULL_LOGICAL_OR, acc, new ast.ColumnReference(id)) + } + } + override def doColumnar(input: GpuColumnVector): ColumnVector = { val targetsBytes = targets.map(t => t.getBytes).toArray val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => input.getBase.stringContains(targetsCv) } - var ret: ColumnVector = null - withResource(boolCvs) { _ => - boolCvs.indices.foreach { i => - if (ret == null) { - ret = boolCvs(i) - boolCvs(i) = null - } else { - val tmp = ret.or(boolCvs(i)) - ret.close() - ret = tmp - boolCvs(i).close() - boolCvs(i) = null - } + val boolTable = withResource(boolCvs) { _ => + new Table(boolCvs: _*) + } + withResource(boolTable) { _ => + withResource(multiOrsAst.compile()) { compiledAst => + compiledAst.computeColumn(boolTable) } } - ret } }