Skip to content

Commit

Permalink
Use AST
Browse files Browse the repository at this point in the history
Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven committed Dec 16, 2024
1 parent 4f33192 commit 613cb1e
Showing 1 changed file with 14 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar}
import ai.rapids.cudf.{ast, BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
Expand Down Expand Up @@ -1242,27 +1242,26 @@ case class GpuContainsAny(input: Expression, targets: Seq[UTF8String])

override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

def multiOrsAst: ast.AstExpression = {
(1 until targets.length)
.foldLeft(new ast.ColumnReference(0).asInstanceOf[ast.AstExpression]) { (acc, id) =>
new ast.BinaryOperation(ast.BinaryOperator.NULL_LOGICAL_OR, acc, new ast.ColumnReference(id))
}
}

override def doColumnar(input: GpuColumnVector): ColumnVector = {
val targetsBytes = targets.map(t => t.getBytes).toArray
val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv =>
input.getBase.stringContains(targetsCv)
}
var ret: ColumnVector = null
withResource(boolCvs) { _ =>
boolCvs.indices.foreach { i =>
if (ret == null) {
ret = boolCvs(i)
boolCvs(i) = null
} else {
val tmp = ret.or(boolCvs(i))
ret.close()
ret = tmp
boolCvs(i).close()
boolCvs(i) = null
}
val boolTable = withResource(boolCvs) { _ =>
new Table(boolCvs: _*)
}
withResource(boolTable) { _ =>
withResource(multiOrsAst.compile()) { compiledAst =>
compiledAst.computeColumn(boolTable)
}
}
ret
}
}

Expand Down

0 comments on commit 613cb1e

Please sign in to comment.