diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py b/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py index b8f4f6a29..766ccc3f6 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py @@ -20,7 +20,7 @@ from dataclasses import dataclass, field from enum import Enum from logging import Logger -from typing import Type, Any, List, Callable +from typing import Type, Any, List, Callable, Union from spark_rapids_tools import EnumeratedType, CspEnv from spark_rapids_pytools.common.prop_manager import AbstractPropertiesContainer, JSONPropertiesContainer, \ @@ -369,7 +369,7 @@ def validate_env(self): self._handle_inconsistent_configurations(incorrect_envs) def run_sys_cmd(self, - cmd, + cmd: Union[str, list[str]], cmd_input: str = None, fail_ok: bool = False, env_vars: dict = None) -> str: @@ -393,7 +393,11 @@ def process_streams(std_out, std_err): if len(stdout_splits) > 0: std_out_lines = Utils.gen_multiline_str([f'\t| {line}' for line in stdout_splits]) stdout_str = f'\n\t\n{std_out_lines}' - cmd_log_str = Utils.gen_joined_str(' ', process_credentials_option(cmd)) + if isinstance(cmd, list): + cmd_list = cmd + else: + cmd_list = cmd.split(' ') + cmd_log_str = Utils.gen_joined_str(' ', process_credentials_option(cmd_list)) if len(stderr_splits) > 0: std_err_lines = Utils.gen_multiline_str([f'\t| {line}' for line in stderr_splits]) stderr_str = f'\n\t\n{std_err_lines}'