refine QueryPlan, RapidsMeta and test suites for HybridScan
sperlingxx committed Nov 20, 2024
1 parent 4d59969 commit 74b6075
Showing 12 changed files with 252 additions and 161 deletions.
12 changes: 7 additions & 5 deletions integration_tests/run_pyspark_from_build.sh
@@ -364,16 +364,18 @@ EOF
fi
export PYSP_TEST_spark_rapids_memory_gpu_allocSize=${PYSP_TEST_spark_rapids_memory_gpu_allocSize:-'1536m'}

if [[ "$VELOX_TEST" -eq 1 ]]; then
if [ -z "${VELOX_JARS}" ]; then
echo "Error: Environment VELOX_JARS is not set."
# Turn on $LOAD_HYBRID_BACKEND and set up the file path of the hybrid backend jars to activate
# the hybrid backend while running the subsequent integration tests.
if [[ "$LOAD_HYBRID_BACKEND" -eq 1 ]]; then
if [ -z "${HYBRID_BACKEND_JARS}" ]; then
echo "Error: Environment HYBRID_BACKEND_JARS is not set."
exit 1
fi
export PYSP_TEST_spark_jars="${PYSP_TEST_spark_jars},${VELOX_JARS//:/,}"
export PYSP_TEST_spark_jars="${PYSP_TEST_spark_jars},${HYBRID_BACKEND_JARS//:/,}"
export PYSP_TEST_spark_memory_offHeap_enabled=true
export PYSP_TEST_spark_memory_offHeap_size=512M
export PYSP_TEST_spark_rapids_sql_hybrid_load=true
export PYSP_TEST_spark_gluten_loadLibFromJar=true
export PYSP_TEST_spark_rapids_sql_loadVelox=true
fi

SPARK_SHELL_SMOKE_TEST="${SPARK_SHELL_SMOKE_TEST:-0}"
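The hunk above configures the hybrid backend purely through PYSP_TEST_-prefixed environment variables, which run_pyspark_from_build.sh forwards to the Spark session used by the tests. A minimal PySpark sketch of the equivalent session-level settings follows; the jar paths are hypothetical placeholders, and the usual spark-rapids plugin settings are assumed to already be in place.

```python
from pyspark.sql import SparkSession

# Hypothetical jar locations; in practice they come from $HYBRID_BACKEND_JARS.
hybrid_backend_jars = "/opt/jars/rapids-hybrid-backend.jar,/opt/jars/gluten-velox-bundle.jar"

spark = (SparkSession.builder
         .config("spark.jars", hybrid_backend_jars)        # PYSP_TEST_spark_jars
         .config("spark.memory.offHeap.enabled", "true")   # PYSP_TEST_spark_memory_offHeap_enabled
         .config("spark.memory.offHeap.size", "512M")      # PYSP_TEST_spark_memory_offHeap_size
         .config("spark.rapids.sql.hybrid.load", "true")   # PYSP_TEST_spark_rapids_sql_hybrid_load
         .config("spark.gluten.loadLibFromJar", "true")    # PYSP_TEST_spark_gluten_loadLibFromJar
         .getOrCreate())
```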
1 change: 1 addition & 0 deletions integration_tests/src/main/python/marks.py
@@ -35,3 +35,4 @@
pyarrow_test = pytest.mark.pyarrow_test
datagen_overrides = pytest.mark.datagen_overrides
tz_sensitive_test = pytest.mark.tz_sensitive_test
hybrid_test = pytest.mark.hybrid_test
79 changes: 79 additions & 0 deletions integration_tests/src/main/python/parquet_test.py
@@ -1588,3 +1588,82 @@ def setup_table(spark):
with_cpu_session(lambda spark: setup_table(spark))
assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.parquet(data_path).select("p"),
conf={"spark.rapids.sql.columnSizeBytes": "100"})

"""
VeloxScan:
1. DecimalType is NOT fully supported
2. TimestampType can NOT be the KeyType of MapType
3. NestedMap is disabled because it may produce incorrect result (usually occurring when table is very small)
"""
velox_gens = [
[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen,
TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)),
ArrayGen(byte_gen),
ArrayGen(long_gen), ArrayGen(string_gen), ArrayGen(date_gen),
ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
ArrayGen(ArrayGen(byte_gen)),
StructGen([['child0', ArrayGen(byte_gen)], ['child1', byte_gen], ['child2', float_gen],
['child3', decimal_gen_64bit]]),
ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]]))
],
[MapGen(f(nullable=False), f()) for f in [
BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, DateGen]
],
[simple_string_to_string_map_gen,
MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen),
max_length=10),
MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
# TODO: It seems that Velox Parquet Scan can NOT handle nested Map correctly
# MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)
],
]

@pytest.mark.skipif(not is_hybrid_backend_loaded(), reason="HybridScan specialized tests")
@pytest.mark.parametrize('parquet_gens', velox_gens, ids=idfn)
@pytest.mark.parametrize('gen_rows', [20, 100, 512, 1024, 4096], ids=idfn)
@hybrid_test
def test_parquet_read_round_trip_hybrid(spark_tmp_path, parquet_gens, gen_rows):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark: gen_df(spark, gen_list, length=gen_rows).write.parquet(data_path),
conf=rebase_write_corrected_conf)

assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path),
conf={
'spark.sql.sources.useV1SourceList': 'parquet',
'spark.rapids.sql.parquet.useHybridReader': 'true',
})

# Create scenarios in which the CoalesceConverter will coalesce several input batches, by adjusting
# reader_batch_size and coalesced_batch_size, to test whether the CoalesceConverter functions
# correctly when coalescing is needed.
@pytest.mark.skipif(not is_hybrid_backend_loaded(), reason="HybridScan specialized tests")
@pytest.mark.parametrize('reader_batch_size', [512, 1024, 2048], ids=idfn)
@pytest.mark.parametrize('coalesced_batch_size', [1 << 25, 1 << 27], ids=idfn)
@pytest.mark.parametrize('gen_rows', [8192, 10000], ids=idfn)
@hybrid_test
def test_parquet_read_round_trip_hybrid_multiple_batches(spark_tmp_path,
reader_batch_size,
coalesced_batch_size,
gen_rows):
gens = []
for g in velox_gens:
gens.extend(g)

gen_list = [('_c' + str(i), gen) for i, gen in enumerate(gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark: gen_df(spark, gen_list, length=gen_rows).write.parquet(data_path),
conf=rebase_write_corrected_conf)

assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path),
conf={
'spark.sql.sources.useV1SourceList': 'parquet',
'spark.rapids.sql.parquet.useHybridReader': 'true',
'spark.gluten.sql.columnar.maxBatchSize': reader_batch_size,
'spark.rapids.sql.batchSizeBytes': coalesced_batch_size,
})
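The new hybrid_test mark registered in marks.py lets these HybridScan tests be selected on their own. A hypothetical sketch of invoking them programmatically is below; in practice run_pyspark_from_build.sh drives pytest, so a marker expression would normally be passed through that script instead.

```python
import pytest

# Select only tests decorated with @hybrid_test; assumes the marker is declared in
# the project's pytest configuration so it is not reported as unknown.
pytest.main(["-m", "hybrid_test", "integration_tests/src/main/python/parquet_test.py"])
```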
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/spark_session.py
@@ -327,3 +327,6 @@ def is_hive_available():
if is_at_least_precommit_run():
return True
return _spark.conf.get("spark.sql.catalogImplementation") == "hive"

def is_hybrid_backend_loaded():
return _spark.conf.get("spark.rapids.sql.hybrid.load") == "true"
Original file line number Diff line number Diff line change
@@ -39,7 +39,6 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.filecache.FileCache
import com.nvidia.spark.rapids.jni.{DateTimeRebase, ParquetFooter}
import com.nvidia.spark.rapids.shims.{ColumnDefaultValuesShims, GpuParquetCrypto, GpuTypeShims, ParquetLegacyNanoAsLongShims, ParquetSchemaClipShims, ParquetStringPredShims, ShimFilePartitionReaderFactory, SparkShimImpl}
import com.nvidia.spark.rapids.velox.VeloxBackendApis
import org.apache.commons.io.IOUtils
import org.apache.commons.io.output.{CountingOutputStream, NullOutputStream}
import org.apache.hadoop.conf.Configuration
22 changes: 10 additions & 12 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -28,6 +28,7 @@ import com.nvidia.spark.rapids.lore.{LoreId, OutputLoreId}
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.rapids.hybrid.HybridBackend
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.RapidsPrivateUtil
@@ -1684,21 +1685,18 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.booleanConf
.createWithDefault(false)

val PARQUET_VELOX_READER = conf("spark.rapids.sql.parquet.useVelox")
.doc("Use velox to do ParquetScan (on CPUs)")
val HYBRID_PARQUET_READER = conf("spark.rapids.sql.parquet.useHybridReader")
.doc("Use HybridScan to read Parquet data via CPUs")
.internal()
.booleanConf
.createWithDefault(true)

object VeloxFilterPushdownType extends Enumeration {
val ALL_SUPPORTED, NONE, UNCHANGED = Value
}
.createWithDefault(false)

val LOAD_VELOX = conf("spark.rapids.sql.loadVelox")
.doc("Load Velox (through Gluten) as a spark driver plugin")
// spark.rapids.sql.hybrid.loadBackend is defined in HybridPluginWrapper of spark-rapids-private
val LOAD_HYBRID_BACKEND = conf(HybridBackend.LOAD_BACKEND_KEY)
.doc("Load the hybrid backend as an extra plugin of spark-rapids at launch time")
.startupOnly()
.booleanConf
.createWithDefault(true)
.createWithDefault(false)

val HASH_AGG_REPLACE_MODE = conf("spark.rapids.sql.hashAgg.replaceMode")
.doc("Only when hash aggregate exec has these modes (\"all\" by default): " +
@@ -2832,9 +2830,9 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val avroDebugDumpAlways: Boolean = get(AVRO_DEBUG_DUMP_ALWAYS)

lazy val parquetVeloxReader: Boolean = get(PARQUET_VELOX_READER)
lazy val useHybridParquetReader: Boolean = get(HYBRID_PARQUET_READER)

lazy val loadVelox: Boolean = get(LOAD_VELOX)
lazy val loadHybridBackend: Boolean = get(LOAD_HYBRID_BACKEND)

lazy val hashAggReplaceMode: String = get(HASH_AGG_REPLACE_MODE)

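The renamed configs split the feature into a startup-only switch (LOAD_HYBRID_BACKEND, keyed by HybridBackend.LOAD_BACKEND_KEY; the integration-test harness above sets spark.rapids.sql.hybrid.load) and a per-session switch (HYBRID_PARQUET_READER); the useHybridScan check later in this commit requires the former whenever the latter is enabled. A minimal sketch, assuming a session where the plugin and the hybrid backend were already loaded at launch time, with a hypothetical Parquet path:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # hybrid backend must already be loaded at launch

# Runtime (internal) switch: route eligible Parquet scans through the hybrid CPU reader.
spark.conf.set("spark.rapids.sql.parquet.useHybridReader", "true")

df = spark.read.parquet("/data/example.parquet")  # hypothetical path
df.show()
```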
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@ import com.nvidia.spark.rapids.{AcquireFailed, GpuColumnVector, GpuMetric, GpuSe
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.hybrid.{CoalesceBatchConverter => NativeConverter}
import com.nvidia.spark.rapids.hybrid.HybridJniWrapper
import com.nvidia.spark.rapids.hybrid.RapidsHostColumn

import org.apache.spark.TaskContext
@@ -43,8 +42,6 @@ class CoalesceConvertIterator(veloxIter: Iterator[ColumnarBatch],
private val converterMetrics = Map(
"C2COutputSize" -> GpuMetric.unwrap(metrics("C2COutputSize")))

@transient private lazy val runtime: HybridJniWrapper = HybridJniWrapper.getOrCreate()

override def hasNext(): Boolean = {
// either the converter or the upstream iterator holds data
val ret = withResource(new NvtxWithMetrics("VeloxC2CHasNext", NvtxColor.WHITE,
@@ -101,7 +98,6 @@
}
}
if (converterImpl.isEmpty) {
require(runtime != null, "Please setRuntime before fetching the iterator")
val converter = NativeConverter(
veloxIter.next(),
targetBatchSizeInBytes, schema, converterMetrics
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rapids.hybrid

import com.nvidia.spark.rapids.hybrid.HybridPluginWrapper

object HybridBackend {
val LOAD_BACKEND_KEY: String = HybridPluginWrapper.LOAD_BACKEND_KEY
}
Original file line number Diff line number Diff line change
@@ -32,19 +32,22 @@ import org.apache.spark.sql.rapids.GpuDataSourceScanExec
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* The SparkPlan for HybridParquetScan, which does the same job as GpuFileSourceScanExec but in a
* different way. Therefore, this class is similar to GpuFileSourceScanExec in many ways.
*/
case class HybridFileSourceScanExec(originPlan: FileSourceScanExec
)(@transient val rapidsConf: RapidsConf)
extends GpuDataSourceScanExec with GpuExec {
import GpuMetric._

require(originPlan.relation.fileFormat.isInstanceOf[ParquetFileFormat],
require(originPlan.relation.fileFormat.getClass == classOf[ParquetFileFormat],
"HybridScan only supports ParquetFormat")

override def relation: BaseRelation = originPlan.relation

override def tableIdentifier: Option[TableIdentifier] = originPlan.tableIdentifier


// All expressions are on the CPU.
override def gpuExpressions: Seq[Expression] = Nil

Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "320"}
{"spark": "321"}
{"spark": "321cdh"}
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._

import org.apache.spark.rapids.hybrid.HybridFileSourceScanExec
import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types.{BinaryType, MapType}

class HybridFileSourceScanExecMeta(plan: FileSourceScanExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends SparkPlanMeta[FileSourceScanExec](plan, conf, parent, rule) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with the native counterpart
// if possible. Instead of regarding filters as childExprs of the current Meta, we create
// a new meta for SubqueryBroadcastExec. The reason is that the native replacement of
// FileSourceScan is independent of the replacement of the partitionFilters.
private lazy val partitionFilters = {
val convertBroadcast = (bc: SubqueryBroadcastExec) => {
val meta = GpuOverrides.wrapAndTagPlan(bc, conf)
meta.tagForExplain()
meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec]
}
wrapped.partitionFilters.map { filter =>
filter.transformDown {
case dpe@DynamicPruningExpression(inSub: InSubqueryExec) =>
inSub.plan match {
case bc: SubqueryBroadcastExec =>
dpe.copy(inSub.copy(plan = convertBroadcast(bc)))
case reuse@ReusedSubqueryExec(bc: SubqueryBroadcastExec) =>
dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc))))
case _ =>
dpe
}
}
}
}

// partition filters and data filters are not run on the GPU
override val childExprs: Seq[ExprMeta[_]] = Seq.empty

override def tagPlanForGpu(): Unit = {
val cls = wrapped.relation.fileFormat.getClass
if (cls != classOf[ParquetFileFormat]) {
willNotWorkOnGpu(s"unsupported file format: ${cls.getCanonicalName}")
}
}

override def convertToGpu(): GpuExec = {
// Modifies the original plan to support DPP
val fixedExec = wrapped.copy(partitionFilters = partitionFilters)
HybridFileSourceScanExec(fixedExec)(conf)
}
}

object HybridFileSourceScanExecMeta {

// Determines whether to use HybridScan or GpuScan
def useHybridScan(conf: RapidsConf, fsse: FileSourceScanExec): Boolean = {
val isEnabled = if (conf.useHybridParquetReader) {
require(conf.loadHybridBackend,
"Hybrid backend was NOT loaded during the launch of spark-rapids plugin")
true
} else {
false
}
// For the time being, only support reading Parquet
lazy val isParquet = fsse.relation.fileFormat.getClass == classOf[ParquetFileFormat]
// Check if data types of all fields are supported by HybridParquetReader
lazy val allSupportedTypes = fsse.requiredSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, {
// For the time being, the native backend may return incorrect results over nestedMap
case MapType(kt, vt, _) if kt.isInstanceOf[MapType] || vt.isInstanceOf[MapType] => false
// For the time being, BinaryType is not supported yet
case _: BinaryType => false
case _ => true
})
}
// TODO: support BucketedScan
lazy val noBucketedScan = !fsse.bucketedScan
// TODO: support working along with Alluxio
lazy val noAlluxio = !AlluxioCfgUtils.enabledAlluxioReplacementAlgoConvertTime(conf)

isEnabled && isParquet && allSupportedTypes && noBucketedScan && noAlluxio
}

}
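To make the type restriction in useHybridScan concrete, here is an illustrative sketch (not part of the commit) of schemas on either side of the check: a map nested inside a map or a BinaryType column keeps the scan on the regular GPU path, while flat maps and primitive types stay eligible, provided the other conditions (Parquet format, no bucketed scan, no Alluxio replacement) also hold.

```python
from pyspark.sql.types import (BinaryType, IntegerType, MapType, StringType,
                               StructField, StructType)

# Rejected by the type check: nested maps and binary columns fall back to the regular GPU scan.
falls_back_to_gpu_scan = StructType([
    StructField("nested_map", MapType(StringType(), MapType(StringType(), IntegerType()))),
    StructField("payload", BinaryType()),
])

# Still eligible for HybridScan: flat maps and primitive columns.
eligible_for_hybrid_scan = StructType([
    StructField("id", IntegerType()),
    StructField("tags", MapType(StringType(), StringType())),
])
```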
Original file line number Diff line number Diff line change
@@ -41,7 +41,15 @@ object ScanExecShims {
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
TypeSig.ARRAY + TypeSig.BINARY + TypeSig.DECIMAL_128).nested(),
TypeSig.all),
(fsse, conf, p, r) => new FileSourceScanExecMeta(fsse, conf, p, r)),
(fsse, conf, p, r) => {
// TODO: support DataSourceV2 in HybridScan
// TODO: HybridScan only supports Spark 32X for now.
if (HybridFileSourceScanExecMeta.useHybridScan(conf, fsse)) {
new HybridFileSourceScanExecMeta(fsse, conf, p, r)
} else {
new FileSourceScanExecMeta(fsse, conf, p, r)
}
}),
GpuOverrides.exec[BatchScanExec](
"The backend for most file input",
ExecChecks(
