Commit 20af57c
TongWei1105 authored and wangyum committed
[SPARK-54147][SQL] Set OMP_NUM_THREADS to spark.task.cpus by default in BaseScriptTransformationExec
### What changes were proposed in this pull request?

Set OMP_NUM_THREADS to spark.task.cpus by default in BaseScriptTransformationExec.

### Why are the changes needed?

When the TRANSFORM function is used to invoke a Python script that relies on packages such as PyTorch or NumPy, those libraries by default start a number of intra-op threads equal to the number of available CPU cores on the node. This can lead to CPU overload.

```
ADD ARCHIVE s3://example-bucket/udf/emotion/emotion_predict.zip;
ADD ARCHIVE s3://example-bucket/udf/emotion/python_env.zip;

INSERT OVERWRITE TABLE demo_db.text_emotion_result PARTITION (dt = 'XXX')
SELECT TRANSFORM(
  id, title, content
)
USING './python_env.zip/python_env/bin/python emotion_predict.zip/emotion_predict/predict.py'
AS (id, title, content, emotion_label, emotion_score)
FROM (
  SELECT /*+ REPARTITION(1000) */
    id, title, content
  FROM demo_db.text_input_data
  WHERE dt = 'XXX'
) src;
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually.

Closes #52850 from TongWei1105/SPARK-54147.

Authored-by: TongWei1105 <[email protected]>
Signed-off-by: Yuming Wang <[email protected]>
1 parent 3f41adc commit 20af57c

1 file changed: +4 -0 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,10 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
     val path = System.getenv("PATH") + File.pathSeparator +
       SparkFiles.getRootDirectory()
     builder.environment().put("PATH", path)
+    // if OMP_NUM_THREADS is not explicitly set, override it with the value of "spark.task.cpus"
+    if (System.getenv("OMP_NUM_THREADS") == null) {
+      builder.environment().put("OMP_NUM_THREADS", conf.getConfString("spark.task.cpus", "1"))
+    }
 
     val proc = builder.start()
     val inputStream = proc.getInputStream
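
The defaulting rule in this change can be tried outside Spark. The following is a minimal standalone sketch, not the Spark implementation: `OmpThreadsDefault`, `applyDefault`, and the `parentEnv` lookup parameter are hypothetical names introduced here for illustration, and the `"2"` passed in `main` stands in for the value that the patch reads from `spark.task.cpus`. It relies on the fact that `ProcessBuilder.environment()` starts as a copy of the parent process environment, so an `OMP_NUM_THREADS` that is already set is inherited unchanged by the child.

```scala
import java.util.{Map => JMap}

object OmpThreadsDefault {
  // Hypothetical helper (not part of Spark): writes OMP_NUM_THREADS into the
  // child environment only when the parent lookup reports it as unset,
  // mirroring the null check on System.getenv in the patch.
  def applyDefault(
      childEnv: JMap[String, String],
      parentEnv: String => Option[String],
      taskCpus: String): Unit = {
    if (parentEnv("OMP_NUM_THREADS").isEmpty) {
      childEnv.put("OMP_NUM_THREADS", taskCpus)
    }
  }

  def main(args: Array[String]): Unit = {
    // The command is a placeholder; the process is configured but never started.
    val builder = new ProcessBuilder("python", "predict.py")
    // "2" is an assumed stand-in for conf.getConfString("spark.task.cpus", "1").
    applyDefault(builder.environment(), key => Option(System.getenv(key)), "2")
    println(builder.environment().get("OMP_NUM_THREADS"))
  }
}
```

Passing the parent lookup as a function keeps the rule testable without mutating the real process environment. With this behavior, a value exported explicitly for the executor still takes precedence over the `spark.task.cpus` default.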
