From 98b7a4ee950ec4fa6954177bcad5057d1b354098 Mon Sep 17 00:00:00 2001 From: Partho Sarthi Date: Tue, 24 Oct 2023 13:21:03 -0700 Subject: [PATCH] Fix spinner animation blocking diagnostic prompt Signed-off-by: Partho Sarthi --- .../spark_rapids_pytools/common/utilities.py | 22 ++++++++++++++----- .../spark_rapids_pytools/rapids/diagnostic.py | 10 ++++++--- .../rapids/rapids_tool.py | 5 ++++- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/user_tools/src/spark_rapids_pytools/common/utilities.py b/user_tools/src/spark_rapids_pytools/common/utilities.py index b8d56817d..fb2666cd7 100644 --- a/user_tools/src/spark_rapids_pytools/common/utilities.py +++ b/user_tools/src/spark_rapids_pytools/common/utilities.py @@ -379,23 +379,25 @@ class ToolsSpinner: A class to manage the spinner animation. Reference: https://stackoverflow.com/a/66558182 - :param in_debug_mode: Flag indicating if running in debug (verbose) mode. Defaults to False. + :param enabled: Flag indicating if the spinner is enabled. Defaults to True. """ - in_debug_mode: bool = field(default=False, init=True) - pixel_spinner: PixelSpinner = field(default=PixelSpinner('Processing...'), init=False) + enabled: bool = field(default=True, init=True) + pixel_spinner: PixelSpinner = field(default=PixelSpinner('Processing...', hide_cursor=False), init=False) end: str = field(default='Processing Completed!', init=False) timeout: float = field(default=0.1, init=False) completed: bool = field(default=False, init=False) spinner_thread: threading.Thread = field(default=None, init=False) + pause_event: threading.Event = field(default=threading.Event(), init=False) def _spinner_animation(self): while not self.completed: self.pixel_spinner.next() time.sleep(self.timeout) + while self.pause_event.is_set(): + self.pause_event.wait(self.timeout) def start(self): - # Don't start if in debug mode - if not self.in_debug_mode: + if self.enabled: self.spinner_thread = threading.Thread(target=self._spinner_animation, daemon=True) self.spinner_thread.start() return self @@ -404,6 +406,16 @@ def stop(self): self.completed = True print(f'\r\n{self.end}', flush=True) + def pause(self, insert_newline=False): + if self.enabled: + if insert_newline: + # Print a newline for visual separation + print() + self.pause_event.set() + + def resume(self): + self.pause_event.clear() + def __enter__(self): return self.start() diff --git a/user_tools/src/spark_rapids_pytools/rapids/diagnostic.py b/user_tools/src/spark_rapids_pytools/rapids/diagnostic.py index 35df84fb1..510600497 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/diagnostic.py +++ b/user_tools/src/spark_rapids_pytools/rapids/diagnostic.py @@ -40,16 +40,20 @@ def _process_custom_args(self): self.thread_num = thread_num self.logger.debug('Set thread number as: %d', self.thread_num) - - self.logger.warning('This operation will collect sensitive information from your cluster, ' - 'such as OS & HW info, Yarn/Spark configurations and log files etc.') + log_message = ('This operation will collect sensitive information from your cluster, ' + 'such as OS & HW info, Yarn/Spark configurations and log files etc.') yes = self.wrapper_options.get('yes', False) if yes: + self.logger.warning(log_message) self.logger.info('Confirmed by command line option.') else: + # Pause the spinner for user prompt + self.spinner.pause(insert_newline=True) + print(log_message) user_input = input('Do you want to continue (yes/no): ') if user_input.lower() not in ['yes', 'y']: raise RuntimeError('User canceled the operation.') + self.spinner.resume() def requires_cluster_connection(self) -> bool: return True diff --git a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py index b686871be..de86f63d1 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py +++ b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py @@ -59,6 +59,7 @@ class RapidsTool(object): name: str = field(default=None, init=False) ctxt: ToolContext = field(default=None, init=False) logger: Logger = field(default=None, init=False) + spinner: ToolsSpinner = field(default=None, init=False) def pretty_name(self): return self.name.capitalize() @@ -272,7 +273,9 @@ def _verify_exec_cluster(self): self._handle_non_running_exec_cluster(msg) def launch(self): - with ToolsSpinner(in_debug_mode=ToolLogging.is_debug_mode_enabled()): + # Spinner should not be enabled in debug mode + enable_spinner = not ToolLogging.is_debug_mode_enabled() + with ToolsSpinner(enabled=enable_spinner) as self.spinner: self._init_tool() self._connect_to_execution_cluster() self._process_arguments()