Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use faster multi-contains in rlike regex rewrite #11810

Merged
merged 8 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import scala.collection.mutable.ListBuffer
import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars
import com.nvidia.spark.rapids.RegexParser.toReadableString

import org.apache.spark.unsafe.types.UTF8String

/**
* Regular expression parser based on a Pratt Parser design.
*
Expand Down Expand Up @@ -1988,7 +1990,7 @@ object RegexOptimizationType {
case class Contains(literal: String) extends RegexOptimizationType
case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int)
extends RegexOptimizationType
case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType
case class MultipleContains(literals: Seq[UTF8String]) extends RegexOptimizationType
case object NoOptimization extends RegexOptimizationType
}

Expand Down Expand Up @@ -2057,16 +2059,17 @@ object RegexRewrite {
}
}

private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = {
private def getMultipleContainsLiterals(ast: RegexAST): Seq[UTF8String] = {
ast match {
case RegexGroup(_, term, _) => getMultipleContainsLiterals(term)
case RegexChoice(RegexSequence(parts), ls) if isLiteralString(parts) => {
getMultipleContainsLiterals(ls) match {
case Seq() => Seq.empty
case literals => RegexCharsToString(parts) +: literals
case literals => UTF8String.fromString(RegexCharsToString(parts)) +: literals
}
}
case RegexSequence(parts) if (isLiteralString(parts)) => Seq(RegexCharsToString(parts))
case RegexSequence(parts) if (isLiteralString(parts)) =>
Seq(UTF8String.fromString(RegexCharsToString(parts)))
case _ => Seq.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ class GpuRLikeMeta(
}
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType))
case MultipleContains(ls) => GpuMultipleContains(lhs, ls)
case MultipleContains(ls) => GpuContainsAny(lhs, ls)
case PrefixRange(s, length, start, end) =>
GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)
case _ => throw new IllegalStateException("Unexpected optimization type")
Expand Down Expand Up @@ -1233,7 +1233,7 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String)
override def dataType: DataType = BooleanType
}

case class GpuMultipleContains(input: Expression, searchList: Seq[String])
case class GpuContainsAny(input: Expression, targets: Seq[UTF8String])
extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerantShim {

override def dataType: DataType = BooleanType
Expand All @@ -1243,17 +1243,16 @@ case class GpuMultipleContains(input: Expression, searchList: Seq[String])
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

override def doColumnar(input: GpuColumnVector): ColumnVector = {
assert(searchList.length > 1)
val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar =>
input.getBase.stringContains(searchScalar)
val targetsBytes = targets.map(t => t.getBytes).toArray
val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv =>
input.getBase.stringContains(targetsCv)
}
searchList.tail.foldLeft(accInit) { (acc, search) =>
val containsSearch = withResource(Scalar.fromString(search)) { searchScalar =>
input.getBase.stringContains(searchScalar)
}
withResource(acc) { _ =>
withResource(containsSearch) { _ =>
acc.or(containsSearch)
closeOnExcept(boolCvs.tail) { _ =>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not work. We will get double frees if there is an exception. boolCvs.tail is not updated when the or happens. That is why I wrote the code that I did.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a complex safeMap so that we can have it be safe and at the same time make the resulting code simple.

/**
* safeMap: safeMap implementation that is leveraged by other type-specific implicits.
*
* safeMap has the added safety net that as you produce AutoCloseable values they are
* tracked, and if an exception were to occur within the maps's body, it will make every
* attempt to close each produced value.
*
* Note: safeMap will close in case of errors, without any knowledge of whether it should
* or not.
* Use safeMap only in these circumstances if `fn` increases the reference count,
* producing an AutoCloseable, and nothing else is tracking these references:
* a) seq.safeMap(x => {...; x.incRefCount; x})
* b) seq.safeMap(x => GpuColumnVector.from(...))
*
* Usage of safeMap chained with other maps is a bit confusing:
*
* seq.map(GpuColumnVector.from).safeMap(couldThrow)
*
* Will close the column vectors produced from couldThrow up until the time where safeMap
* throws.
*
* The correct pattern of usage in cases like this is:
*
* val closeTheseLater = seq.safeMap(GpuColumnVector.from)
* closeTheseLater.safeMap{ x =>
* var success = false
* try {
* val res = couldThrow(x.incRefCount())
* success = true
* res // return a ref count of 2
* } finally {
* if (!success) {
* // in case of an error, we close x as part of normal error handling
* // the exception will be caught by the safeMap, and it will close all
* // AutoCloseables produced before x
* // - Sequence looks like: [2, 2, 2, ..., 2] + x, which has also has a refcount of 2
* x.close() // x now has a ref count of 1, the rest of the sequence has 2s
* }
* }
* } // safeMap cleaned, and now everything has 1s for ref counts (as they were before)
*
* closeTheseLater.safeClose() // go from 1 to 0 in all things inside closeTheseLater
*
* @param in the Seq[A] to map on
* @param fn a function that takes A, and produces B (a subclass of AutoCloseable)
* @tparam A the type of the elements in Seq
* @tparam B the type of the elements produced in the safeMap (should be subclasses of
* AutoCloseable)
* @tparam Repr the type of the input collection (needed by builder)
* @tparam That the type of the output collection (needed by builder)
* @return a sequence of B, in the success case
*/
protected def safeMap[B <: AutoCloseable, That](
in: collection.SeqLike[A, Repr],
fn: A => B)
(implicit bf: CanBuildFrom[Repr, B, That]): That = {
def builder: mutable.Builder[B, That] = {
val b = bf(in.asInstanceOf[Repr])
b.sizeHint(in)
b
}
val b = builder
for (x <- in) {
var success = false
try {
b += fn(x)
success = true
} finally {
if (!success) {
val res = b.result() // can be a SeqLike or an Array
res match {
// B is erased at this point, even if ClassTag is used
// @ unchecked suppresses a warning that the type of B
// was eliminated due to erasure. That said B is AutoCloseble
// and SeqLike[AutoCloseable, _] is defined
case b: collection.SeqLike[B @ unchecked, _] => b.safeClose()
case a: Array[AutoCloseable] => a.safeClose()
}
}
}
}
b.result()
}
}

Could we just add a safeReduce into implicits.scala instead of trying so hard to make this small piece of code simple?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind it is just a helper function safeReduceAndClose. We don't need to jump through hoops to make the code simple if inherently what we want to do is not simple.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not work. We will get double frees if there is an exception. boolCvs.tail is not updated when the or happens. That is why I wrote the code that I did.

Yes you are right, I missed that point and thought an outer withResource would lead to double closes. Update back to withResource for now.

boolCvs.tail.foldLeft(boolCvs.head) {
(l, r) => withResource(l) { _ =>
withResource(r) { _ =>
l.or(r)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package com.nvidia.spark.rapids

import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.unsafe.types.UTF8String

class RegularExpressionRewriteSuite extends AnyFunSuite {

private def verifyRewritePattern(patterns: Seq[String], excepted: Seq[RegexOptimizationType]):
Unit = {
private def verifyRewritePattern(patterns: Seq[String],
excepted: Seq[RegexOptimizationType]): Unit = {
val results = patterns.map { pattern =>
val ast = new RegexParser(pattern).parse()
RegexRewrite.matchSimplePattern(ast)
Expand Down Expand Up @@ -87,11 +89,11 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
"(火花|急流)"
)
val excepted = Seq(
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def", "ghi")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("火花", "急流"))
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def", "ghi").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("火花", "急流").map(UTF8String.fromString))
)
verifyRewritePattern(patterns, excepted)
}
Expand Down
Loading