diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/databricks_azure.py b/user_tools/src/spark_rapids_pytools/cloud_api/databricks_azure.py index c746e26b3..ce76f9cfd 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/databricks_azure.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/databricks_azure.py @@ -88,6 +88,9 @@ def create_saving_estimator(self, def create_local_submission_job(self, job_prop, ctxt) -> Any: return DBAzureLocalRapidsJob(prop_container=job_prop, exec_ctxt=ctxt) + def create_distributed_submission_job(self, job_prop, ctxt) -> Any: + pass + def validate_job_submission_args(self, submission_args: dict) -> dict: pass diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py b/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py index c5e01c5e5..6bf33eb4c 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py @@ -130,6 +130,9 @@ def create_saving_estimator(self, def create_local_submission_job(self, job_prop, ctxt) -> Any: return DataprocLocalRapidsJob(prop_container=job_prop, exec_ctxt=ctxt) + def create_distributed_submission_job(self, job_prop, ctxt) -> Any: + pass + def validate_job_submission_args(self, submission_args: dict) -> dict: pass diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/dataproc_gke.py b/user_tools/src/spark_rapids_pytools/cloud_api/dataproc_gke.py index 4a36c8cae..364c9c8e9 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/dataproc_gke.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/dataproc_gke.py @@ -92,6 +92,9 @@ def create_saving_estimator(self, def create_local_submission_job(self, job_prop, ctxt) -> Any: return DataprocGkeLocalRapidsJob(prop_container=job_prop, exec_ctxt=ctxt) + def create_distributed_submission_job(self, job_prop, ctxt) -> Any: + pass + @dataclass class DataprocGkeCMDDriver(DataprocCMDDriver): diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/emr.py b/user_tools/src/spark_rapids_pytools/cloud_api/emr.py index 2ae6d713d..11849beed 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/emr.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/emr.py @@ -115,6 +115,9 @@ def create_saving_estimator(self, def create_local_submission_job(self, job_prop, ctxt) -> Any: return EmrLocalRapidsJob(prop_container=job_prop, exec_ctxt=ctxt) + def create_distributed_submission_job(self, job_prop, ctxt) -> Any: + pass + def generate_cluster_configuration(self, render_args: dict): image_version = self.configs.get_value_silent('clusterInference', 'defaultImage') render_args['IMAGE'] = f'"{image_version}"' diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/onprem.py b/user_tools/src/spark_rapids_pytools/cloud_api/onprem.py index cbbefc7c2..1ae80452d 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/onprem.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/onprem.py @@ -19,7 +19,6 @@ from typing import Any, List, Optional from spark_rapids_tools import CspEnv -from spark_rapids_pytools.rapids.rapids_job import RapidsLocalJob from spark_rapids_pytools.cloud_api.sp_types import PlatformBase, ClusterBase, ClusterNode, \ CMDDriverBase, ClusterGetAccessor, GpuDevice, \ GpuHWInfo, NodeHWInfo, SparkNodeType, SysInfo @@ -27,6 +26,7 @@ from spark_rapids_pytools.common.sys_storage import StorageDriver from spark_rapids_pytools.pricing.dataproc_pricing import DataprocPriceProvider from spark_rapids_pytools.pricing.price_provider import SavingsEstimator +from 
spark_rapids_pytools.rapids.rapids_job import RapidsLocalJob, RapidsDistributedJob @dataclass @@ -49,6 +49,9 @@ def _install_storage_driver(self): def create_local_submission_job(self, job_prop, ctxt) -> Any: return OnPremLocalRapidsJob(prop_container=job_prop, exec_ctxt=ctxt) + def create_distributed_submission_job(self, job_prop, ctxt) -> RapidsDistributedJob: + return OnPremDistributedRapidsJob(prop_container=job_prop, exec_ctxt=ctxt) + def _construct_cluster_from_props(self, cluster: str, props: str = None, is_inferred: bool = False, is_props_file: bool = False): return OnPremCluster(self, is_inferred=is_inferred).set_connection(cluster_id=cluster, props=props) @@ -154,6 +157,15 @@ class OnPremLocalRapidsJob(RapidsLocalJob): job_label = 'onpremLocal' +# pylint: disable=abstract-method +@dataclass +class OnPremDistributedRapidsJob(RapidsDistributedJob): + """ + Implementation of a RAPIDS job that runs on a distributed cluster + """ + job_label = 'onprem.distributed' + + @dataclass class OnPremNode(ClusterNode): """Implementation of Onprem cluster node.""" diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py b/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py index 949ac29d9..2c566f166 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py @@ -22,11 +22,11 @@ from logging import Logger from typing import Type, Any, List, Callable, Union, Optional, final, Dict -from spark_rapids_tools import EnumeratedType, CspEnv from spark_rapids_pytools.common.prop_manager import AbstractPropertiesContainer, JSONPropertiesContainer, \ get_elem_non_safe from spark_rapids_pytools.common.sys_storage import StorageDriver, FSUtil from spark_rapids_pytools.common.utilities import ToolLogging, SysCmd, Utils, TemplateGenerator +from spark_rapids_tools import EnumeratedType, CspEnv class DeployMode(EnumeratedType): @@ -884,6 +884,9 @@ def create_saving_estimator(self, def create_local_submission_job(self, job_prop, ctxt) -> Any: raise NotImplementedError + def create_distributed_submission_job(self, job_prop, ctxt) -> Any: + raise NotImplementedError + def load_platform_configs(self): config_file_name = f'{CspEnv.tostring(self.type_id).lower()}-configs.json' config_path = Utils.resource_path(config_file_name) diff --git a/user_tools/src/spark_rapids_pytools/rapids/qualification.py b/user_tools/src/spark_rapids_pytools/rapids/qualification.py index 921dcc9b5..3b8e97ed8 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/qualification.py +++ b/user_tools/src/spark_rapids_pytools/rapids/qualification.py @@ -29,7 +29,7 @@ from spark_rapids_pytools.common.sys_storage import FSUtil from spark_rapids_pytools.common.utilities import Utils, TemplateGenerator from spark_rapids_pytools.rapids.rapids_tool import RapidsJarTool -from spark_rapids_tools.enums import QualFilterApp, QualEstimationModel +from spark_rapids_tools.enums import QualFilterApp, QualEstimationModel, SubmissionMode from spark_rapids_tools.storagelib import CspFs from spark_rapids_tools.tools.additional_heuristics import AdditionalHeuristics from spark_rapids_tools.tools.cluster_config_recommender import ClusterConfigRecommender @@ -153,6 +153,17 @@ def _process_estimation_model_args(self) -> None: estimation_model_args = QualEstimationModel.create_default_model_args(selected_model) self.ctxt.set_ctxt('estimationModelArgs', estimation_model_args) + def _process_submission_mode_arg(self) -> None: + """ + Process the value provided by 
`--submission_mode` argument. + """ + submission_mode_arg = self.wrapper_options.get('submissionMode') + if submission_mode_arg is None or not submission_mode_arg: + submission_mode = SubmissionMode.get_default() + else: + submission_mode = SubmissionMode.fromstring(submission_mode_arg) + self.ctxt.set_ctxt('submissionMode', submission_mode) + def _process_custom_args(self) -> None: """ Qualification tool processes extra arguments: @@ -181,6 +192,7 @@ def _process_custom_args(self) -> None: self._process_estimation_model_args() self._process_offline_cluster_args() self._process_eventlogs_args() + self._process_submission_mode_arg() # This is noise to dump everything # self.logger.debug('%s custom arguments = %s', self.pretty_name(), self.ctxt.props['wrapperCtx']) @@ -375,7 +387,7 @@ def create_stdout_table_pprinter(total_apps: pd.DataFrame, df = self._read_qualification_output_file('summaryReport') # 1. Operations related to XGboost modelling - if self.ctxt.get_ctxt('estimationModelArgs')['xgboostEnabled']: + if not df.empty and self.ctxt.get_ctxt('estimationModelArgs')['xgboostEnabled']: try: df = self.__update_apps_with_prediction_info(df, self.ctxt.get_ctxt('estimationModelArgs')) diff --git a/user_tools/src/spark_rapids_pytools/rapids/rapids_job.py b/user_tools/src/spark_rapids_pytools/rapids/rapids_job.py index a77a94b00..c8cc43d1c 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/rapids_job.py +++ b/user_tools/src/spark_rapids_pytools/rapids/rapids_job.py @@ -17,12 +17,13 @@ import os from dataclasses import dataclass, field from logging import Logger -from typing import List, Optional +from typing import List, Optional, Union from spark_rapids_pytools.common.prop_manager import JSONPropertiesContainer from spark_rapids_pytools.common.utilities import ToolLogging, Utils from spark_rapids_pytools.rapids.tool_ctxt import ToolContext from spark_rapids_tools.storagelib import LocalPath +from spark_rapids_tools_distributed.jar_cmd_args import JarCmdArgs @dataclass @@ -38,6 +39,8 @@ def _init_fields(self): self.props['sparkConfArgs'] = {} if self.get_value_silent('platformArgs') is None: self.props['platformArgs'] = {} + if self.get_value_silent('distributedToolsConfigs') is None: + self.props['distributedToolsConfigs'] = {} def get_jar_file(self): return self.get_value('rapidsArgs', 'jarFile') @@ -48,6 +51,9 @@ def get_jar_main_class(self): def get_rapids_args(self): return self.get_value('rapidsArgs', 'jarArgs') + def get_distribution_tools_configs(self): + return self.get_value('distributedToolsConfigs') + @dataclass class RapidsJob: @@ -90,10 +96,10 @@ def _build_rapids_args(self): rapids_arguments.extend(extra_rapids_args) return rapids_arguments - def _build_submission_cmd(self) -> list: + def _build_submission_cmd(self) -> Union[list, JarCmdArgs]: raise NotImplementedError - def _submit_job(self, cmd_args: list) -> str: + def _submit_job(self, cmd_args: Union[list, JarCmdArgs]) -> str: raise NotImplementedError def _print_job_output(self, job_output: str): @@ -125,13 +131,6 @@ def run_job(self): self._cleanup_temp_log4j_files() return job_output - -@dataclass -class RapidsLocalJob(RapidsJob): - """ - Implementation of a RAPIDS job that runs local on a machine. - """ - def _get_hadoop_classpath(self) -> Optional[str]: """ Gets the Hadoop's configuration directory from the environment variables. 
@@ -202,6 +201,13 @@ def _build_jvm_args(self): vm_args.append(val) return vm_args + +@dataclass +class RapidsLocalJob(RapidsJob): + """ + Implementation of a RAPIDS job that runs local on a machine. + """ + def _build_submission_cmd(self) -> list: # env vars are added later as a separate dictionary classpath_arr = self._build_classpath() @@ -218,3 +224,32 @@ def _submit_job(self, cmd_args: list) -> str: out_std = self.exec_ctxt.platform.cli.run_sys_cmd(cmd=cmd_args, env_vars=env_args) return out_std + + +@dataclass +class RapidsDistributedJob(RapidsJob): + """ + Implementation of a RAPIDS job that runs distributed on a cluster. + """ + + def _build_submission_cmd(self) -> JarCmdArgs: + classpath_arr = self._build_classpath() + hadoop_cp = self._get_hadoop_classpath() + jvm_args_arr = self._build_jvm_args() + jar_main_class = self.prop_container.get_jar_main_class() + jar_output_dir_args = self._get_persistent_rapids_args() + extra_rapids_args = self.prop_container.get_rapids_args() + return JarCmdArgs(jvm_args_arr, classpath_arr, hadoop_cp, jar_main_class, + jar_output_dir_args, extra_rapids_args) + + def _build_classpath(self) -> List[str]: + """ + Only the Spark RAPIDS Tools JAR file is needed for the classpath. + Assumption: Each worker node should have the Spark Jars pre-installed. + TODO: Ship the Spark JARs to the cluster to avoid version mismatch issues. + """ + return ['-cp', self.prop_container.get_jar_file()] + + def _submit_job(self, cmd_args: JarCmdArgs) -> None: + # TODO: Support for submitting the Tools JAR to a Spark cluster + raise NotImplementedError('Distributed job submission is not yet supported') diff --git a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py index ee4338ccd..c8180a71b 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py +++ b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py @@ -39,6 +39,7 @@ from spark_rapids_pytools.rapids.tool_ctxt import ToolContext from spark_rapids_tools import CspEnv from spark_rapids_tools.configuration.common import RuntimeDependency +from spark_rapids_tools.configuration.submission.distributed_config import DistributedToolsConfig from spark_rapids_tools.configuration.tools_config import ToolsConfig from spark_rapids_tools.enums import DependencyType from spark_rapids_tools.storagelib import LocalPath, CspFs @@ -608,7 +609,7 @@ def populate_dependency_list() -> List[RuntimeDependency]: # check if the dependencies is defined in a config file config_obj = self.get_tools_config_obj() if config_obj is not None: - if config_obj.runtime.dependencies: + if config_obj.runtime and config_obj.runtime.dependencies: return config_obj.runtime.dependencies self.logger.info('The ToolsConfig did not specify the dependencies. 
' 'Falling back to the default dependencies.') @@ -939,10 +940,33 @@ def _prepare_local_job_arguments(self): 'sparkConfArgs': spark_conf_args, 'platformArgs': platform_args } + # Set the configuration for the distributed tools + distributed_tools_configs = self._get_distributed_tools_configs() + if distributed_tools_configs: + job_properties_json['distributedToolsConfigs'] = distributed_tools_configs rapids_job_container = RapidsJobPropContainer(prop_arg=job_properties_json, file_load=False) self.ctxt.set_ctxt('rapidsJobContainers', [rapids_job_container]) + def _get_distributed_tools_configs(self) -> Optional[DistributedToolsConfig]: + """ + Parse the tools configuration and return it as a distributed tools configuration object + """ + config_obj = self.get_tools_config_obj() + if config_obj and config_obj.submission: + if self.ctxt.is_distributed_mode(): + return config_obj + self.logger.warning( + 'Distributed tool configurations detected, but distributed mode is not enabled. ' + 'Use \'--submission_mode distributed\' flag to enable distributed mode. Switching to local mode.' + ) + elif self.ctxt.is_distributed_mode(): + self.logger.warning( + 'Distributed mode is enabled, but no distributed tool configurations were provided. ' + 'Using default settings.' + ) + return None + def _archive_results(self): self._archive_local_results() @@ -961,8 +985,12 @@ def _submit_jobs(self): executors_cnt = len(rapids_job_containers) if Utilities.conc_mode_enabled else 1 with ThreadPoolExecutor(max_workers=executors_cnt) as executor: for rapids_job in rapids_job_containers: - job_obj = self.ctxt.platform.create_local_submission_job(job_prop=rapids_job, - ctxt=self.ctxt) + if self.ctxt.is_distributed_mode(): + job_obj = self.ctxt.platform.create_distributed_submission_job(job_prop=rapids_job, + ctxt=self.ctxt) + else: + job_obj = self.ctxt.platform.create_local_submission_job(job_prop=rapids_job, + ctxt=self.ctxt) futures = executor.submit(job_obj.run_job) futures_list.append(futures) try: @@ -970,5 +998,5 @@ def _submit_jobs(self): result = future.result() results.append(result) except Exception as ex: # pylint: disable=broad-except - self.logger.error('Failed to download dependencies %s', ex) + self.logger.error('Failed to submit jobs %s', ex) raise ex diff --git a/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py b/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py index 7e850dcfa..714192271 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py +++ b/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py @@ -26,6 +26,7 @@ from spark_rapids_pytools.common.sys_storage import FSUtil from spark_rapids_pytools.common.utilities import ToolLogging, Utils from spark_rapids_tools import CspEnv, CspPath +from spark_rapids_tools.enums import SubmissionMode from spark_rapids_tools.utils import Utilities @@ -89,6 +90,12 @@ def get_deploy_mode(self) -> Any: def is_fat_wheel_mode(self) -> bool: return self.get_ctxt('fatWheelModeEnabled') + def is_distributed_mode(self) -> bool: + return self.get_ctxt('submissionMode') == SubmissionMode.DISTRIBUTED + + def is_local_mode(self) -> bool: + return self.get_ctxt('submissionMode') == SubmissionMode.LOCAL + def set_ctxt(self, key: str, val: Any): self.props['wrapperCtx'][key] = val diff --git a/user_tools/src/spark_rapids_pytools/resources/databricks_aws-configs.json b/user_tools/src/spark_rapids_pytools/resources/databricks_aws-configs.json index d0cde1298..9cddf96f8 100644 --- 
a/user_tools/src/spark_rapids_pytools/resources/databricks_aws-configs.json +++ b/user_tools/src/spark_rapids_pytools/resources/databricks_aws-configs.json @@ -28,8 +28,7 @@ "value": "a65839fbf1869f81a1632e09f415e586922e4f80" }, "size": 962685 - }, - "type": "jar" + } }, { "name": "AWS Java SDK Bundled", @@ -40,8 +39,7 @@ "value": "02deec3a0ad83d13d032b1812421b23d7a961eea" }, "size": 280645251 - }, - "type": "jar" + } } ], "333": [ diff --git a/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py b/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py index 978f683c1..2177dec82 100644 --- a/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py +++ b/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py @@ -29,8 +29,10 @@ from spark_rapids_pytools.common.utilities import ToolLogging, Utils from spark_rapids_tools.cloud import ClientCluster from spark_rapids_tools.utils import AbstractPropContainer, is_http_file +from ..configuration.submission.distributed_config import DistributedToolsConfig +from ..configuration.submission.local_config import LocalToolsConfig from ..configuration.tools_config import ToolsConfig -from ..enums import QualFilterApp, CspEnv, QualEstimationModel +from ..enums import QualFilterApp, CspEnv, QualEstimationModel, SubmissionMode from ..storagelib.csppath import CspPath from ..tools.autotuner import AutoTunerPropMgr from ..utils.util import dump_tool_usage, Utilities @@ -372,6 +374,9 @@ def process_jvm_args(self) -> None: self.p_args['toolArgs']['jobResources'] = adjusted_resources self.p_args['toolArgs']['log4jPath'] = Utils.resource_path('dev/log4j.properties') + def load_tools_config_internal(self) -> ToolsConfig: + return LocalToolsConfig.load_from_file(self.tools_config_path) + def load_tools_config(self) -> None: """ Load the tools config file if it is provided. 
It creates a ToolsConfig object and sets it @@ -382,7 +387,7 @@ def load_tools_config(self) -> None: if self.tools_config_path is not None: # the CLI provides a tools config file try: - self.p_args['toolArgs']['toolsConfig'] = ToolsConfig.load_from_file(self.tools_config_path) + self.p_args['toolArgs']['toolsConfig'] = self.load_tools_config_internal() except ValidationError as ve: # If required, we can dump the expected specification by appending # 'ToolsConfig.get_schema()' to the error message @@ -470,6 +475,7 @@ class QualifyUserArgModel(ToolUserArgModel): """ filter_apps: Optional[QualFilterApp] = None estimation_model_args: Optional[Dict] = dataclasses.field(default_factory=dict) + submission_mode: Optional[SubmissionMode] = None def init_tool_args(self) -> None: self.p_args['toolArgs']['platform'] = self.platform @@ -487,6 +493,7 @@ def init_tool_args(self) -> None: self.p_args['toolArgs']['estimationModelArgs'] = QualEstimationModel.create_default_model_args(def_model) else: self.p_args['toolArgs']['estimationModelArgs'] = self.estimation_model_args + self.submission_mode = self.submission_mode or SubmissionMode.get_default() @model_validator(mode='after') def validate_arg_cases(self) -> 'QualifyUserArgModel': @@ -497,6 +504,13 @@ def validate_arg_cases(self) -> 'QualifyUserArgModel': def is_concurrent_submission(self) -> bool: return self.p_args['toolArgs']['estimationModelArgs']['xgboostEnabled'] + def load_tools_config_internal(self) -> ToolsConfig: + # Override the method to load the tools config file based on the submission mode + config_class = ( + DistributedToolsConfig if self.submission_mode == SubmissionMode.DISTRIBUTED else LocalToolsConfig + ) + return config_class.load_from_file(self.tools_config_path) + def build_tools_args(self) -> dict: # At this point, if the platform is still none, then we can set it to the default value # which is the onPrem platform. @@ -532,7 +546,8 @@ def build_tools_args(self) -> dict: 'eventlogs': self.eventlogs, 'filterApps': QualFilterApp.fromstring(self.p_args['toolArgs']['filterApps']), 'toolsJar': self.p_args['toolArgs']['toolsJar'], - 'estimationModelArgs': self.p_args['toolArgs']['estimationModelArgs'] + 'estimationModelArgs': self.p_args['toolArgs']['estimationModelArgs'], + 'submissionMode': self.submission_mode } return wrapped_args diff --git a/user_tools/src/spark_rapids_tools/cmdli/tools_cli.py b/user_tools/src/spark_rapids_tools/cmdli/tools_cli.py index 0d46e5025..e575ee8c1 100644 --- a/user_tools/src/spark_rapids_tools/cmdli/tools_cli.py +++ b/user_tools/src/spark_rapids_tools/cmdli/tools_cli.py @@ -46,6 +46,7 @@ def qualification(self, jvm_threads: int = None, verbose: bool = None, tools_config_file: str = None, + submission_mode: str = None, **rapids_options) -> None: """The Qualification cmd provides estimated speedups by migrating Apache Spark applications to GPU accelerated clusters. @@ -83,6 +84,8 @@ def qualification(self, :param tools_config_file: Path to a configuration file that contains the tools' options. For sample configuration files, please visit https://github.com/NVIDIA/spark-rapids-tools/tree/main/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/valid + :param submission_mode: Submission mode to run the qualification tool. + Supported modes are "local" and "distributed". :param rapids_options: A list of valid Qualification tool options. Note that the wrapper ignores ["output-directory", "platform"] flags, and it does not support multiple "spark-property" arguments. 
@@ -95,6 +98,7 @@ def qualification(self, output_folder = Utils.get_value_or_pop(output_folder, rapids_options, 'o') filter_apps = Utils.get_value_or_pop(filter_apps, rapids_options, 'f') verbose = Utils.get_value_or_pop(verbose, rapids_options, 'v', False) + submission_mode = Utils.get_value_or_pop(submission_mode, rapids_options, 's') if verbose: ToolLogging.enable_debug_mode() init_environment('qual') @@ -117,7 +121,8 @@ def qualification(self, jvm_threads=jvm_threads, filter_apps=filter_apps, estimation_model_args=estimation_model_args, - tools_config_path=tools_config_file) + tools_config_path=tools_config_file, + submission_mode=submission_mode) if qual_args: tool_obj = QualificationAsLocal(platform_type=qual_args['runtimePlatform'], output_folder=qual_args['outputFolder'], diff --git a/user_tools/src/spark_rapids_tools/configuration/common.py b/user_tools/src/spark_rapids_tools/configuration/common.py index 439904cae..9df8a763e 100644 --- a/user_tools/src/spark_rapids_tools/configuration/common.py +++ b/user_tools/src/spark_rapids_tools/configuration/common.py @@ -22,7 +22,18 @@ from spark_rapids_tools.storagelib.tools.fs_utils import FileHashAlgorithm -class RuntimeDependencyType(BaseModel): +class BaseConfig(BaseModel, extra='forbid'): + """ + BaseConfig class for Pydantic models that enforces the `extra = forbid` + setting. This ensures that no extra keys are allowed in any model or + subclass that inherits from this base class. + + This base class is meant to be inherited by other Pydantic models related + to tools configurations so that we can enforce a global rule. + """ + + +class RuntimeDependencyType(BaseConfig): """Defines the type of runtime dependency required by the tools' java cmd.""" dep_type: DependencyType = Field( @@ -36,7 +47,7 @@ class RuntimeDependencyType(BaseModel): examples=['jars/*']) -class DependencyVerification(BaseModel): +class DependencyVerification(BaseConfig): """The verification information of a runtime dependency required by the tools' java cmd.""" size: int = Field( default=0, @@ -53,7 +64,7 @@ class DependencyVerification(BaseModel): }]) -class RuntimeDependency(BaseModel): +class RuntimeDependency(BaseConfig): """Holds information about a runtime dependency required by the tools' java cmd.""" name: str = Field( description='The name of the dependency.', @@ -72,3 +83,15 @@ class RuntimeDependency(BaseModel): verification: DependencyVerification = Field( default=None, description='Optional specification to verify the dependency file.') + + +class SparkProperty(BaseConfig): + """Represents a single Spark property with a name and value.""" + name: str = Field( + description='Name of the Spark property, e.g., "spark.executor.memory".') + value: str = Field( + description='Value of the Spark property, e.g., "4g".') + + +class SubmissionConfig(BaseConfig): + """Base class for the tools configuration.""" diff --git a/user_tools/src/spark_rapids_tools/configuration/runtime_conf.py b/user_tools/src/spark_rapids_tools/configuration/runtime_conf.py index 40ab68cf7..7878d369d 100644 --- a/user_tools/src/spark_rapids_tools/configuration/runtime_conf.py +++ b/user_tools/src/spark_rapids_tools/configuration/runtime_conf.py @@ -16,12 +16,12 @@ from typing import List -from pydantic import BaseModel, Field +from pydantic import Field -from spark_rapids_tools.configuration.common import RuntimeDependency +from spark_rapids_tools.configuration.common import RuntimeDependency, BaseConfig -class ToolsRuntimeConfig(BaseModel): +class ToolsRuntimeConfig(BaseConfig): 
"""The runtime configurations of the tools as defined by the user.""" dependencies: List[RuntimeDependency] = Field( description='The list of runtime dependencies required by the tools java cmd. ' diff --git a/user_tools/src/spark_rapids_tools/configuration/submission/__init__.py b/user_tools/src/spark_rapids_tools/configuration/submission/__init__.py new file mode 100644 index 000000000..e8b752e95 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/configuration/submission/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/user_tools/src/spark_rapids_tools/configuration/submission/distributed_config.py b/user_tools/src/spark_rapids_tools/configuration/submission/distributed_config.py new file mode 100644 index 000000000..f50f3b9b5 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/configuration/submission/distributed_config.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Configuration file for distributed submission mode """ +from typing import List, Optional + +from pydantic import Field + +from spark_rapids_tools.configuration.common import SparkProperty, SubmissionConfig +from spark_rapids_tools.configuration.tools_config import ToolsConfig + + +class DistributedSubmissionConfig(SubmissionConfig): + """Configuration class for distributed submission mode""" + remote_cache_dir: str = Field( + description='Remote cache directory where the intermediate output data from each task will be stored. ' + 'Default is hdfs:///tmp/spark_rapids_distributed_tools_cache.', + default=['hdfs:///tmp/spark_rapids_distributed_tools_cache'] + ) + + spark_properties: List[SparkProperty] = Field( + default_factory=list, + description='List of Spark properties to be used for the Spark session.', + examples=[{'name': 'spark.executor.memory', 'value': '4g'}, + {'name': 'spark.executor.cores', 'value': '4'}] + ) + + +class DistributedToolsConfig(ToolsConfig): + """Container for the distributed submission mode configurations. 
This is the parts of the configuration + that can be passed as an input to the CLI""" + submission: Optional[DistributedSubmissionConfig] = Field( + default=None, + description='Configuration related to distributed submission mode.') diff --git a/user_tools/src/spark_rapids_tools/configuration/submission/local_config.py b/user_tools/src/spark_rapids_tools/configuration/submission/local_config.py new file mode 100644 index 000000000..944f6fb0f --- /dev/null +++ b/user_tools/src/spark_rapids_tools/configuration/submission/local_config.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Configuration file for local submission mode """ +from typing import Optional + +from pydantic import Field + +from spark_rapids_tools.configuration.common import SubmissionConfig +from spark_rapids_tools.configuration.tools_config import ToolsConfig + + +class LocalSubmissionConfig(SubmissionConfig): + """Configuration class for local submission mode""" + + +class LocalToolsConfig(ToolsConfig): + """Container for the local submission mode configurations. This is the parts of the configuration that + can be passed as an input to the CLI""" + submission: Optional[LocalSubmissionConfig] = Field( + default=None, + description='Configuration related to local submission mode.') diff --git a/user_tools/src/spark_rapids_tools/configuration/tools_config.py b/user_tools/src/spark_rapids_tools/configuration/tools_config.py index b330cfc28..24627d497 100644 --- a/user_tools/src/spark_rapids_tools/configuration/tools_config.py +++ b/user_tools/src/spark_rapids_tools/configuration/tools_config.py @@ -18,25 +18,32 @@ import json from typing import Union, Optional -from pydantic import BaseModel, Field, ValidationError +from pydantic import Field, ValidationError from spark_rapids_tools import CspPathT +from spark_rapids_tools.configuration.common import BaseConfig, SubmissionConfig from spark_rapids_tools.configuration.runtime_conf import ToolsRuntimeConfig from spark_rapids_tools.utils import AbstractPropContainer -class ToolsConfig(BaseModel): +class ToolsConfig(BaseConfig): """Main container for the user's defined tools configuration""" api_version: float = Field( description='The version of the API that the tools are using. 
' 'This is used to test the compatibility of the ' 'configuration file against the current tools release.', - examples=['1.0'], - le=1.0, # minimum version compatible with the current tools implementation + examples=['1.0', '1.1'], + le=1.1, # maximum version compatible with the current tools implementation ge=1.0) - runtime: ToolsRuntimeConfig = Field( + + runtime: Optional[ToolsRuntimeConfig] = Field( + default=None, description='Configuration related to the runtime environment of the tools.') + submission: Optional[SubmissionConfig] = Field( + default=None, + description='Configuration related to the submission.') + @classmethod def load_from_file(cls, file_path: Union[str, CspPathT]) -> Optional['ToolsConfig']: """Load the tools configuration from a file""" @@ -44,7 +51,7 @@ def load_from_file(cls, file_path: Union[str, CspPathT]) -> Optional['ToolsConfi prop_container = AbstractPropContainer.load_from_file(file_path) return cls(**prop_container.props) except ValidationError as e: - # Do nothing. This is kept as a place holder if we want to log the error inside the + # Do nothing. This is kept as a placeholder if we want to log the error inside the # class first raise e diff --git a/user_tools/src/spark_rapids_tools/enums.py b/user_tools/src/spark_rapids_tools/enums.py index 46db8aad1..d2b450832 100644 --- a/user_tools/src/spark_rapids_tools/enums.py +++ b/user_tools/src/spark_rapids_tools/enums.py @@ -222,3 +222,13 @@ def create_default_model_args(cls, model_type: str) -> dict: 'xgboostEnabled': model_type == QualEstimationModel.XGBOOST, 'customModelFile': None, } + + +class SubmissionMode(EnumeratedType): + """Values used to define the submission mode of the applications""" + LOCAL = 'local' + DISTRIBUTED = 'distributed' + + @classmethod + def get_default(cls) -> 'SubmissionMode': + return cls.LOCAL diff --git a/user_tools/src/spark_rapids_tools_distributed/jar_cmd_args.py b/user_tools/src/spark_rapids_tools_distributed/jar_cmd_args.py new file mode 100644 index 000000000..c6513c1c6 --- /dev/null +++ b/user_tools/src/spark_rapids_tools_distributed/jar_cmd_args.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Jar command arguments for running the Tools JAR on Spark """ +from dataclasses import dataclass, field +from typing import List + + +@dataclass +class JarCmdArgs: + """ + Wrapper class to store the arguments required to run the Tools JAR on Spark. 
+ """ + jvm_args: List[str] = field(default=None, init=True) + classpath_arr: List[str] = field(default=None, init=True) + hadoop_classpath: str = field(default=None, init=True) + jar_main_class: str = field(default=None, init=True) + jar_output_dir_args: List[str] = field(default=None, init=True) + extra_rapids_args: List[str] = field(default=None, init=True) diff --git a/user_tools/tests/spark_rapids_tools_ut/conftest.py b/user_tools/tests/spark_rapids_tools_ut/conftest.py index de3f2da12..019026250 100644 --- a/user_tools/tests/spark_rapids_tools_ut/conftest.py +++ b/user_tools/tests/spark_rapids_tools_ut/conftest.py @@ -49,6 +49,7 @@ def gen_cpu_cluster_props(): autotuner_prop_path = 'worker_info.yaml' # valid tools config files valid_tools_conf_files = ['tools_config_00.yaml'] +valid_distributed_mode_tools_conf_files = ['tools_config_01.yaml', 'tools_config_02.yaml'] # invalid tools config files invalid_tools_conf_files = [ # test older API_version diff --git a/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-distributed-config-specification.json b/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-distributed-config-specification.json new file mode 100644 index 000000000..c183fec18 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-distributed-config-specification.json @@ -0,0 +1,262 @@ +{ + "$defs": { + "DependencyType": { + "description": "Represents the dependency type for the tools' java cmd.", + "enum": [ + "jar", + "archive" + ], + "title": "DependencyType", + "type": "string" + }, + "DependencyVerification": { + "additionalProperties": false, + "description": "The verification information of a runtime dependency required by the tools' java cmd.", + "properties": { + "size": { + "default": 0, + "description": "The size of the dependency file.", + "examples": [ + 3265393 + ], + "title": "Size", + "type": "integer" + }, + "file_hash": { + "$ref": "#/$defs/FileHashAlgorithm", + "default": null, + "description": "The hash function to verify the file.", + "examples": [ + { + "algorithm": "md5", + "value": "bc9bf7fedde0e700b974426fbd8d869c" + } + ] + } + }, + "title": "DependencyVerification", + "type": "object" + }, + "DistributedSubmissionConfig": { + "additionalProperties": false, + "description": "Configuration class for distributed submission mode", + "properties": { + "remote_cache_dir": { + "default": [ + "hdfs:///tmp/spark_rapids_distributed_tools_cache" + ], + "description": "Remote cache directory where the intermediate output data from each task will be stored. Default is hdfs:///tmp/spark_rapids_distributed_tools_cache.", + "title": "Remote Cache Dir", + "type": "string" + }, + "spark_properties": { + "description": "List of Spark properties to be used for the Spark session.", + "examples": [ + { + "name": "spark.executor.memory", + "value": "4g" + }, + { + "name": "spark.executor.cores", + "value": "4" + } + ], + "items": { + "$ref": "#/$defs/SparkProperty" + }, + "title": "Spark Properties", + "type": "array" + } + }, + "title": "DistributedSubmissionConfig", + "type": "object" + }, + "FileHashAlgorithm": { + "description": "Represents a file hash algorithm and its value. 
Used for verification against an existing file.", + "properties": { + "algorithm": { + "$ref": "#/$defs/HashAlgorithm" + }, + "value": { + "title": "Value", + "type": "string" + } + }, + "required": [ + "algorithm", + "value" + ], + "title": "FileHashAlgorithm", + "type": "object" + }, + "HashAlgorithm": { + "description": "Represents the supported hashing algorithms", + "enum": [ + "md5", + "sha1", + "sha256", + "sha512" + ], + "title": "HashAlgorithm", + "type": "string" + }, + "RuntimeDependency": { + "additionalProperties": false, + "description": "Holds information about a runtime dependency required by the tools' java cmd.", + "properties": { + "name": { + "description": "The name of the dependency.", + "examples": [ + "Spark-3.5.0", + "AWS Java SDK" + ], + "title": "Name", + "type": "string" + }, + "uri": { + "anyOf": [ + { + "format": "uri", + "minLength": 1, + "type": "string" + }, + { + "format": "file-path", + "type": "string" + } + ], + "description": "The location of the dependency file. It can be a URL to a remote web/storage or a file path.", + "examples": [ + "file:///path/to/file.tgz", + "https://mvn-url/24.08.1/rapids-4-spark-tools_2.12-24.08.1.jar", + "gs://bucket-name/path/to/file.jar" + ], + "title": "Uri" + }, + "dependency_type": { + "$ref": "#/$defs/RuntimeDependencyType", + "description": "Specifies the dependency type to determine how the item is processed. For example, jar files are appended to the java classpath while archive files such as spark are extracted first before adding subdirectory _/jars/* to the classpath." + }, + "verification": { + "$ref": "#/$defs/DependencyVerification", + "default": null, + "description": "Optional specification to verify the dependency file." + } + }, + "required": [ + "name", + "uri" + ], + "title": "RuntimeDependency", + "type": "object" + }, + "RuntimeDependencyType": { + "additionalProperties": false, + "description": "Defines the type of runtime dependency required by the tools' java cmd.", + "properties": { + "dep_type": { + "$ref": "#/$defs/DependencyType", + "description": "The type of the dependency." + }, + "relative_path": { + "default": null, + "description": "Specifies the relative path from within the archive file which will be added to the java cmd. Requires field dep_type to be set to (archive).", + "examples": [ + "jars/*" + ], + "title": "Relative Path", + "type": "string" + } + }, + "required": [ + "dep_type" + ], + "title": "RuntimeDependencyType", + "type": "object" + }, + "SparkProperty": { + "additionalProperties": false, + "description": "Represents a single Spark property with a name and value.", + "properties": { + "name": { + "description": "Name of the Spark property, e.g., \"spark.executor.memory\".", + "title": "Name", + "type": "string" + }, + "value": { + "description": "Value of the Spark property, e.g., \"4g\".", + "title": "Value", + "type": "string" + } + }, + "required": [ + "name", + "value" + ], + "title": "SparkProperty", + "type": "object" + }, + "ToolsRuntimeConfig": { + "additionalProperties": false, + "description": "The runtime configurations of the tools as defined by the user.", + "properties": { + "dependencies": { + "description": "The list of runtime dependencies required by the tools java cmd. Set this list to specify Spark binaries along with any other required jar files (i.e., hadoop jars, gcp connectors,..etc.). 
When specified, the default predefined dependencies will be ignored.", + "items": { + "$ref": "#/$defs/RuntimeDependency" + }, + "title": "Dependencies", + "type": "array" + } + }, + "required": [ + "dependencies" + ], + "title": "ToolsRuntimeConfig", + "type": "object" + } + }, + "additionalProperties": false, + "description": "Container for the distributed submission mode configurations. This is the parts of the configuration\nthat can be passed as an input to the CLI", + "properties": { + "api_version": { + "description": "The version of the API that the tools are using. This is used to test the compatibility of the configuration file against the current tools release.", + "examples": [ + "1.0" + ], + "maximum": 1.0, + "minimum": 1.0, + "title": "Api Version", + "type": "number" + }, + "runtime": { + "anyOf": [ + { + "$ref": "#/$defs/ToolsRuntimeConfig" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Configuration related to the runtime environment of the tools." + }, + "submission": { + "anyOf": [ + { + "$ref": "#/$defs/DistributedSubmissionConfig" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Configuration related to distributed submission mode." + } + }, + "required": [ + "api_version" + ], + "title": "DistributedToolsConfig", + "type": "object" +} diff --git a/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-config-specification.json b/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-local-config-specification.json similarity index 83% rename from user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-config-specification.json rename to user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-local-config-specification.json index 9dbef10ab..c0eb270cf 100644 --- a/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-config-specification.json +++ b/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/sample-local-config-specification.json @@ -10,6 +10,7 @@ "type": "string" }, "DependencyVerification": { + "additionalProperties": false, "description": "The verification information of a runtime dependency required by the tools' java cmd.", "properties": { "size": { @@ -65,7 +66,15 @@ "title": "HashAlgorithm", "type": "string" }, + "LocalSubmissionConfig": { + "additionalProperties": false, + "description": "Configuration class for local submission mode", + "properties": {}, + "title": "LocalSubmissionConfig", + "type": "object" + }, "RuntimeDependency": { + "additionalProperties": false, "description": "Holds information about a runtime dependency required by the tools' java cmd.", "properties": { "name": { @@ -115,6 +124,7 @@ "type": "object" }, "RuntimeDependencyType": { + "additionalProperties": false, "description": "Defines the type of runtime dependency required by the tools' java cmd.", "properties": { "dep_type": { @@ -138,6 +148,7 @@ "type": "object" }, "ToolsRuntimeConfig": { + "additionalProperties": false, "description": "The runtime configurations of the tools as defined by the user.", "properties": { "dependencies": { @@ -156,7 +167,8 @@ "type": "object" } }, - "description": "Main container for the user's defined tools configuration", + "additionalProperties": false, + "description": "Container for the local submission mode configurations. This is the parts of the configuration that\ncan be passed as an input to the CLI", "properties": { "api_version": { "description": "The version of the API that the tools are using. 
This is used to test the compatibility of the configuration file against the current tools release.", @@ -169,14 +181,33 @@ "type": "number" }, "runtime": { - "$ref": "#/$defs/ToolsRuntimeConfig", + "anyOf": [ + { + "$ref": "#/$defs/ToolsRuntimeConfig" + }, + { + "type": "null" + } + ], + "default": null, "description": "Configuration related to the runtime environment of the tools." + }, + "submission": { + "anyOf": [ + { + "$ref": "#/$defs/LocalSubmissionConfig" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Configuration related to local submission mode." } }, "required": [ - "api_version", - "runtime" + "api_version" ], - "title": "ToolsConfig", + "title": "LocalToolsConfig", "type": "object" } diff --git a/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/valid/tools_config_01.yaml b/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/valid/tools_config_01.yaml new file mode 100644 index 000000000..b16792020 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/valid/tools_config_01.yaml @@ -0,0 +1,9 @@ +# This yaml file is a sample configuration file for the distributed tools. It is valid +# only if `--submission_mode distributed` is passed to the CLI. It provides submission +# related configurations. +api_version: '1.1' +submission: + remote_cache_dir: 'hdfs:///tmp/spark_rapids_distributed_tools_cache' + spark_properties: + - name: 'spark.executor.memory' + value: '20g' diff --git a/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/valid/tools_config_02.yaml b/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/valid/tools_config_02.yaml new file mode 100644 index 000000000..726a02532 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/resources/tools_config/valid/tools_config_02.yaml @@ -0,0 +1,17 @@ +# This yaml file is a sample configuration file for the distributed tools. It is valid +# only if `--submission_mode distributed` is passed to the CLI. It provides runtime +# dependencies and submission related configurations. 
+api_version: '1.1' +runtime: + dependencies: + - name: my-spark350 + uri: https:///archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz + dependency_type: + dep_type: archive + # for tgz files, it is required to give the subfolder where the jars are located + relative_path: jars/* +submission: + remote_cache_dir: 'hdfs:///tmp/spark_rapids_distributed_tools_cache' + spark_properties: + - name: 'spark.executor.memory' + value: '20g' diff --git a/user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py b/user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py index 300751242..f889d1450 100644 --- a/user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py +++ b/user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py @@ -25,7 +25,7 @@ from spark_rapids_tools.cmdli.argprocessor import AbsToolUserArgModel, ArgValueCase from spark_rapids_tools.enums import QualFilterApp from .conftest import SparkRapidsToolsUT, autotuner_prop_path, all_cpu_cluster_props, all_csps, \ - valid_tools_conf_files, invalid_tools_conf_files + valid_tools_conf_files, invalid_tools_conf_files, valid_distributed_mode_tools_conf_files @dataclasses.dataclass @@ -77,25 +77,25 @@ def validate_args_w_savings_disabled(tool_name: str, t_args: dict): assert t_args['filterApps'] == QualFilterApp.get_default() @staticmethod - def create_tool_args_should_pass(tool_name: str, platform=None, cluster=None, - eventlogs=None, tools_jar=None, tools_config_path=None): + def create_tool_args_should_pass(tool_name: str, **kwargs): return AbsToolUserArgModel.create_tool_args(tool_name, - platform=platform, - cluster=cluster, - eventlogs=eventlogs, - tools_jar=tools_jar, - tools_config_path=tools_config_path) + platform=kwargs.get('platform'), + cluster=kwargs.get('cluster'), + eventlogs=kwargs.get('eventlogs'), + tools_jar=kwargs.get('tools_jar'), + tools_config_path=kwargs.get('tools_config_path'), + submission_mode=kwargs.get('submission_mode')) @staticmethod - def create_tool_args_should_fail(tool_name: str, platform=None, cluster=None, - eventlogs=None, tools_jar=None, tools_config_path=None): + def create_tool_args_should_fail(tool_name: str, **kwargs): with pytest.raises(SystemExit) as pytest_wrapped_e: AbsToolUserArgModel.create_tool_args(tool_name, - platform=platform, - cluster=cluster, - eventlogs=eventlogs, - tools_jar=tools_jar, - tools_config_path=tools_config_path) + platform=kwargs.get('platform'), + cluster=kwargs.get('cluster'), + eventlogs=kwargs.get('eventlogs'), + tools_jar=kwargs.get('tools_jar'), + tools_config_path=kwargs.get('tools_config_path'), + submission_mode=kwargs.get('submission_mode')) assert pytest_wrapped_e.type == SystemExit @staticmethod @@ -349,6 +349,20 @@ def test_invalid_tools_configs(self, get_ut_data_dir, tool_name, csp, tools_conf tools_config_path=tools_conf_path) assert pytest_wrapped_e.type == SystemExit + @pytest.mark.parametrize('tool_name', ['qualification']) + @pytest.mark.parametrize('csp', ['onprem']) + @pytest.mark.parametrize('submission_mode', ['distributed']) + @pytest.mark.parametrize('tools_conf_fname', valid_distributed_mode_tools_conf_files) + def test_distributed_mode_configs(self, get_ut_data_dir, tool_name, csp, submission_mode, tools_conf_fname): + tools_conf_path = f'{get_ut_data_dir}/tools_config/valid/{tools_conf_fname}' + # should pass: tools config file is provided + tool_args = self.create_tool_args_should_pass(tool_name, + platform=csp, + eventlogs=f'{get_ut_data_dir}/eventlogs', + tools_config_path=tools_conf_path, + 
submission_mode=submission_mode) + assert tool_args['toolsConfig'] is not None + def test_arg_cases_coverage(self): """ This test ensures that the above tests have covered all possible states of the `platform`, `cluster`,