Commit bc10f09

Merge branch 'dev' of github.com:NVIDIA/spark-rapids-tools into exec_stage_map

nartal1 committed Nov 2, 2023
2 parents 66c9607 + defd16c commit bc10f09
Showing 9 changed files with 69 additions and 17 deletions.
core/pom.xml (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@
     <artifactId>rapids-4-spark-tools_2.12</artifactId>
     <name>RAPIDS Accelerator for Apache Spark tools</name>
     <description>RAPIDS Accelerator for Apache Spark tools</description>
-    <version>23.08.3-SNAPSHOT</version>
+    <version>23.10.1-SNAPSHOT</version>
     <packaging>jar</packaging>
     <url>http://github.com/NVIDIA/spark-rapids-tools</url>

@@ -1035,7 +1035,12 @@ object AutoTuner extends Logging {
     representer.getPropertyUtils.setSkipMissingProperties(true)
     val constructor = new Constructor(classOf[ClusterProperties], new LoaderOptions())
     val yamlObjNested = new Yaml(constructor, representer)
-    Option(yamlObjNested.load(clusterProps).asInstanceOf[ClusterProperties])
+    val loadedClusterProps = yamlObjNested.load(clusterProps).asInstanceOf[ClusterProperties]
+    if (loadedClusterProps != null && loadedClusterProps.softwareProperties == null) {
+      logInfo("softwareProperties is empty from input worker_info file")
+      loadedClusterProps.softwareProperties = new util.LinkedHashMap[String, String]()
+    }
+    Option(loadedClusterProps)
   }

   def loadClusterProps(filePath: String): Option[ClusterProperties] = {
@@ -272,6 +272,15 @@ abstract class AppBase(
     }
   }

+  private def trimSchema(str: String): String = {
+    val index = str.lastIndexOf(",")
+    if (index != -1 && str.contains("...")) {
+      str.substring(0, index)
+    } else {
+      str
+    }
+  }
+
   // The ReadSchema metadata is only in the eventlog for DataSource V1 readers
   protected def checkMetadataForReadSchema(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
     // check if planInfo has ReadSchema
@@ -284,7 +293,7 @@
       val readSchema = ReadParser.formatSchemaStr(meta.getOrElse("ReadSchema", ""))
       val scanNode = allNodes.filter(node => {
         // Get ReadSchema of each Node and sanitize it for comparison
-        val trimmedNode = ReadParser.parseReadNode(node).schema.replace("...", "")
+        val trimmedNode = trimSchema(ReadParser.parseReadNode(node).schema)
         readSchema.contains(trimmedNode)
       }).filter(x => x.name.startsWith("Scan")).head

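For readers skimming the diff, a minimal Python sketch of the same trimming idea (the actual implementation is the Scala trimSchema above; the helper and the input string here are made up): when the ReadSchema string was truncated with "...", the last, partially captured column is dropped before comparing schemas.

def trim_schema(schema_str: str) -> str:
    # Drop the trailing, possibly truncated column when the schema string was cut off with "..."
    index = schema_str.rfind(',')
    if index != -1 and '...' in schema_str:
        return schema_str[:index]
    return schema_str

print(trim_schema('a:int,b:string,c:stru...'))  # -> 'a:int,b:string'
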
user_tools/src/spark_rapids_pytools/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -16,5 +16,5 @@

 from spark_rapids_pytools.build import get_version

-VERSION = '23.08.3'
+VERSION = '23.10.1'
 __version__ = get_version(VERSION)
user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py (11 changes: 11 additions & 0 deletions)
@@ -422,9 +422,20 @@ def _init_nodes(self):
             SparkNodeType.MASTER: master_node
         }

+    def _set_zone_from_props(self, prop_container: JSONPropertiesContainer):
+        """
+        Extracts the 'zoneUri' from the properties container and updates the environment variable dictionary.
+        """
+        if prop_container:
+            zone_uri = prop_container.get_value_silent('config', 'gceClusterConfig', 'zoneUri')
+            if zone_uri:
+                self.cli.env_vars['zone'] = FSUtil.get_resource_name(zone_uri)
+
     def _init_connection(self, cluster_id: str = None,
                          props: str = None) -> dict:
         cluster_args = super()._init_connection(cluster_id=cluster_id, props=props)
+        # extract and update zone to the environment variable
+        self._set_zone_from_props(cluster_args['props'])
         # propagate zone to the cluster
         cluster_args.setdefault('zone', self.cli.get_env_var('zone'))
         return cluster_args
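A rough sketch of what the new _set_zone_from_props picks up (the property values below are made up, and it assumes FSUtil.get_resource_name returns the last path component of a URI):

# Hypothetical snippet of the Dataproc cluster properties JSON
props = {
    'config': {
        'gceClusterConfig': {
            'zoneUri': 'https://www.googleapis.com/compute/v1/projects/my-project/zones/us-central1-a'
        }
    }
}
zone_uri = props['config']['gceClusterConfig']['zoneUri']
# FSUtil.get_resource_name(zone_uri) would keep only the trailing name
print(zone_uri.rsplit('/', 1)[-1])  # us-central1-a
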
user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py (8 changes: 5 additions & 3 deletions)
@@ -20,7 +20,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from logging import Logger
-from typing import Type, Any, List, Callable
+from typing import Type, Any, List, Callable, Union

 from spark_rapids_tools import EnumeratedType, CspEnv
 from spark_rapids_pytools.common.prop_manager import AbstractPropertiesContainer, JSONPropertiesContainer, \
@@ -369,7 +369,7 @@ def validate_env(self):
         self._handle_inconsistent_configurations(incorrect_envs)

     def run_sys_cmd(self,
-                    cmd,
+                    cmd: Union[str, list],
                     cmd_input: str = None,
                     fail_ok: bool = False,
                     env_vars: dict = None) -> str:
@@ -393,7 +393,9 @@ def process_streams(std_out, std_err):
             if len(stdout_splits) > 0:
                 std_out_lines = Utils.gen_multiline_str([f'\t| {line}' for line in stdout_splits])
                 stdout_str = f'\n\t<STDOUT>\n{std_out_lines}'
-            cmd_log_str = Utils.gen_joined_str(' ', process_credentials_option(cmd))
+            # if the command is already a list, use it as-is. Otherwise, split the string into a list.
+            cmd_list = cmd if isinstance(cmd, list) else cmd.split(' ')
+            cmd_log_str = Utils.gen_joined_str(' ', process_credentials_option(cmd_list))
             if len(stderr_splits) > 0:
                 std_err_lines = Utils.gen_multiline_str([f'\t| {line}' for line in stderr_splits])
                 stderr_str = f'\n\t<STDERR>\n{std_err_lines}'
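With cmd widened to Union[str, list], both call forms are accepted; only the log-string path normalizes the string form. A runnable sketch of that normalization with an illustrative command (the gcloud invocation is made up):

cmd = 'gcloud dataproc clusters list --region us-central1'   # could equally be a list of tokens
cmd_list = cmd if isinstance(cmd, list) else cmd.split(' ')
print(cmd_list)  # ['gcloud', 'dataproc', 'clusters', 'list', '--region', 'us-central1']
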
user_tools/src/spark_rapids_pytools/common/utilities.py (22 changes: 17 additions & 5 deletions)
@@ -379,23 +379,25 @@ class ToolsSpinner:
     A class to manage the spinner animation.
     Reference: https://stackoverflow.com/a/66558182
-    :param in_debug_mode: Flag indicating if running in debug (verbose) mode. Defaults to False.
+    :param enabled: Flag indicating if the spinner is enabled. Defaults to True.
     """
-    in_debug_mode: bool = field(default=False, init=True)
-    pixel_spinner: PixelSpinner = field(default=PixelSpinner('Processing...'), init=False)
+    enabled: bool = field(default=True, init=True)
+    pixel_spinner: PixelSpinner = field(default=PixelSpinner('Processing...', hide_cursor=False), init=False)
     end: str = field(default='Processing Completed!', init=False)
     timeout: float = field(default=0.1, init=False)
     completed: bool = field(default=False, init=False)
     spinner_thread: threading.Thread = field(default=None, init=False)
+    pause_event: threading.Event = field(default=threading.Event(), init=False)

     def _spinner_animation(self):
         while not self.completed:
             self.pixel_spinner.next()
             time.sleep(self.timeout)
+            while self.pause_event.is_set():
+                self.pause_event.wait(self.timeout)

     def start(self):
-        # Don't start if in debug mode
-        if not self.in_debug_mode:
+        if self.enabled:
             self.spinner_thread = threading.Thread(target=self._spinner_animation, daemon=True)
             self.spinner_thread.start()
         return self
@@ -404,6 +406,16 @@ def stop(self):
         self.completed = True
         print(f'\r\n{self.end}', flush=True)

+    def pause(self, insert_newline=False):
+        if self.enabled:
+            if insert_newline:
+                # Print a newline for visual separation
+                print()
+            self.pause_event.set()
+
+    def resume(self):
+        self.pause_event.clear()
+
     def __enter__(self):
         return self.start()

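An illustrative sketch of the new pause()/resume() hooks (the import path is inferred from the file path above, and the time.sleep calls are placeholders for real work): the animation is suspended while a prompt owns the console, then resumed.

import time
from spark_rapids_pytools.common.utilities import ToolsSpinner

with ToolsSpinner(enabled=True) as spinner:
    time.sleep(2)                          # stand-in for setup work; spinner animates in a daemon thread
    spinner.pause(insert_newline=True)     # stop redrawing and print a separating newline
    answer = input('Do you want to continue (yes/no): ')
    spinner.resume()                       # animation resumes after the prompt
    time.sleep(2)                          # stand-in for the remaining work
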
user_tools/src/spark_rapids_pytools/rapids/diagnostic.py (10 changes: 7 additions & 3 deletions)
@@ -40,16 +40,20 @@ def _process_custom_args(self):

         self.thread_num = thread_num
         self.logger.debug('Set thread number as: %d', self.thread_num)
-
-        self.logger.warning('This operation will collect sensitive information from your cluster, '
-                            'such as OS & HW info, Yarn/Spark configurations and log files etc.')
+        log_message = ('This operation will collect sensitive information from your cluster, '
+                       'such as OS & HW info, Yarn/Spark configurations and log files etc.')
         yes = self.wrapper_options.get('yes', False)
         if yes:
+            self.logger.warning(log_message)
             self.logger.info('Confirmed by command line option.')
         else:
+            # Pause the spinner for user prompt
+            self.spinner.pause(insert_newline=True)
+            print(log_message)
             user_input = input('Do you want to continue (yes/no): ')
             if user_input.lower() not in ['yes', 'y']:
                 raise RuntimeError('User canceled the operation.')
+            self.spinner.resume()

     def requires_cluster_connection(self) -> bool:
         return True
user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py (13 changes: 11 additions & 2 deletions)
@@ -26,6 +26,7 @@
 from logging import Logger
 from typing import Any, Callable, Dict, List

+import spark_rapids_pytools
 from spark_rapids_tools import CspEnv
 from spark_rapids_pytools.cloud_api.sp_types import get_platform, \
     ClusterBase, DeployMode, NodeHWInfo
@@ -59,6 +60,7 @@ class RapidsTool(object):
     name: str = field(default=None, init=False)
     ctxt: ToolContext = field(default=None, init=False)
     logger: Logger = field(default=None, init=False)
+    spinner: ToolsSpinner = field(default=None, init=False)

     def pretty_name(self):
         return self.name.capitalize()
@@ -119,6 +121,7 @@ def wrapper(self, *args, **kwargs):
     def __post_init__(self):
         # when debug is set to true set it in the environment.
         self.logger = ToolLogging.get_and_setup_logger(f'rapids.tools.{self.name}')
+        self.logger.info('Using Spark RAPIDS user tools version %s', spark_rapids_pytools.__version__)

     def _check_environment(self) -> None:
         self.ctxt.platform.setup_and_validate_env()
@@ -272,7 +275,9 @@ def _verify_exec_cluster(self):
             self._handle_non_running_exec_cluster(msg)

     def launch(self):
-        with ToolsSpinner(in_debug_mode=ToolLogging.is_debug_mode_enabled()):
+        # Spinner should not be enabled in debug mode
+        enable_spinner = not ToolLogging.is_debug_mode_enabled()
+        with ToolsSpinner(enabled=enable_spinner) as self.spinner:
             self._init_tool()
             self._connect_to_execution_cluster()
             self._process_arguments()
@@ -384,8 +389,12 @@ def _process_jar_arg(self):
                                                fail_ok=False,
                                                create_dir=True)
         self.logger.info('RAPIDS accelerator jar is downloaded to work_dir %s', jar_path)
-        # get the jar file name and add it to the tool args
+        # get the jar file name
         jar_file_name = FSUtil.get_resource_name(jar_path)
+        version_match = re.search(r'\d{2}\.\d{2}\.\d+', jar_file_name)
+        jar_version = version_match.group() if version_match else 'Unknown'
+        self.logger.info('Using Spark RAPIDS accelerator jar version %s', jar_version)
+        # add jar file name to the tool args
        self.ctxt.add_rapids_args('jarFileName', jar_file_name)
         self.ctxt.add_rapids_args('jarFilePath', jar_path)

Expand Down
