# NOTE(review): reconstructed from a whitespace-collapsed diff. Only definitions that are
# fully visible on the post-patch side are reproduced here. The functions that take `self`
# are methods of the flasher client class whose `class` header lies outside the visible
# hunks; they are shown dedented and belong back inside that class.
from jumpstarter_driver_flashers.bundle import FlasherBundleManifestV1Alpha1

from jumpstarter.client.decorators import driver_click_group
from jumpstarter.common.exceptions import ArgumentError, JumpstarterException


class FlashError(JumpstarterException):
    """Base exception for flash-related errors."""


class FlashRetryableError(FlashError):
    """Exception for retryable flash errors (network, timeout, etc.)."""


class FlashNonRetryableError(FlashError):
    """Exception for non-retryable flash errors (configuration, file system, etc.)."""


debug_console_option = click.option("--console-debug", is_flag=True, help="Enable console debug mode")

# Size of the sliding window of console output that is searched for completion markers.
_ACCUMULATED_OUTPUT_LIMIT = 64 * 1024
# How long to keep polling for a completion marker once dd's /proc entry has disappeared.
_MARKER_GRACE_SECONDS = 5


def _find_flash_error(
    exception: BaseException,
    error_type: type[FlashError],
    _seen: set[int] | None = None,
) -> FlashError | None:
    """Return the first ``error_type`` instance reachable from ``exception``.

    The search follows two kinds of links: the members of anything exposing an
    ``exceptions`` attribute (exception groups) and the explicit ``__cause__``
    chain. Implicit ``__context__`` links are not followed.

    Args:
        exception: Root exception to inspect.
        error_type: Exception class to look for.
        _seen: Internal set of already-visited object ids. It stops the walk on a
            cyclic cause chain (e.g. ``raise e from e``), which would otherwise
            loop forever.

    Returns:
        The matching exception instance, or None when there is none.
    """
    seen = set() if _seen is None else _seen
    if id(exception) in seen:
        return None
    seen.add(id(exception))

    def search_members(group: BaseException) -> FlashError | None:
        # Duck-typed on purpose: anything carrying `exceptions` is treated as a group.
        for member in getattr(group, "exceptions", None) or ():
            found = _find_flash_error(member, error_type, seen)
            if found is not None:
                return found
        return None

    # Group members are inspected before the exception itself.
    found = search_members(exception)
    if found is not None:
        return found
    if isinstance(exception, error_type):
        return exception

    # Walk the explicit cause chain; each cause may itself be a group.
    cause = getattr(exception, "__cause__", None)
    while cause is not None and id(cause) not in seen:
        seen.add(id(cause))
        if isinstance(cause, error_type):
            return cause
        found = search_members(cause)
        if found is not None:
            return found
        cause = getattr(cause, "__cause__", None)
    return None


def _get_retryable_error(self, exception: Exception) -> FlashRetryableError | None:
    """Find a retryable error in an exception (or any of its causes).

    Args:
        exception: The exception to check

    Returns:
        The FlashRetryableError if found, None otherwise
    """
    return _find_flash_error(exception, FlashRetryableError)


def _get_non_retryable_error(self, exception: Exception) -> FlashNonRetryableError | None:
    """Find a non-retryable error in an exception (or any of its causes).

    Args:
        exception: The exception to check

    Returns:
        The FlashNonRetryableError if found, None otherwise
    """
    return _find_flash_error(exception, FlashNonRetryableError)


def _perform_flash_operation(
    self,
    partition: str | None,
    path: PathBuf,
    image_url: str,
    should_download_to_httpd: bool,
    storage_thread: threading.Thread | None,
    error_queue: Queue,
    cacert_file: str | None,
    insecure_tls: bool,
    headers: dict[str, str] | None,
    bearer_token: str | None,
):
    """Perform the actual flash operation with console setup.

    This method contains all the console operations that can be retried: it opens
    the ``_busybox`` console, resolves the target device, exports the DHCP
    details, runs the manifest's preflash commands, brings the network up with
    ``udhcpc``, flashes the image and finally reboots and powers off the target.

    Args:
        partition: Explicit target name; falls back to the driver's default
            target and then to the manifest's default target.
        path: Image path, passed through to ``_flash_with_progress``.
        image_url: URL the image is fetched from on the target.
        should_download_to_httpd: When true, wait for the background storage
            transfer instead of setting up flasher-side SSL.
        storage_thread: Background transfer thread handed to
            ``_wait_for_storage_thread``.
        error_queue: Queue handed to ``_wait_for_storage_thread`` alongside the
            thread.
        cacert_file: CA certificate handed to ``_setup_flasher_ssl``.
        insecure_tls: Passed through to ``_flash_with_progress``.
        headers: Extra HTTP headers, combined with ``bearer_token`` by
            ``_prepare_headers``.
        bearer_token: Bearer token for HTTP authentication.

    Raises:
        ArgumentError: If no partition or default target can be determined.
    """
    with self._busybox() as console:
        manifest = self.manifest
        prompt = manifest.spec.login.prompt

        target = partition or self.call("get_default_target") or manifest.spec.default_target
        if not target:
            raise ArgumentError("No partition or default target specified")

        target_device = self._get_target_device(target, manifest, console)
        self.logger.info(f"Using target block device: {target_device}")

        console.sendline(f"export dhcp_addr={self._dhcp_details.ip_address}")
        console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
        console.sendline(f"export gw_addr={self._dhcp_details.gateway}")
        console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)

        # Preflash commands are executed before the flash operation,
        # generally used to clean up boot entries in existing devices.
        for preflash_command in manifest.spec.preflash_commands:
            self.logger.info(f"Running preflash command: {preflash_command}")
            console.sendline(preflash_command)
            console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)

        # Make sure that the device is connected to the network and has an IP address.
        console.sendline("udhcpc")
        console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)

        stored_cacert = None
        if should_download_to_httpd:
            self._wait_for_storage_thread(storage_thread, error_queue)
        else:
            stored_cacert = self._setup_flasher_ssl(console, manifest, cacert_file)

        header_args = self._prepare_headers(headers, bearer_token)

        self._flash_with_progress(
            console,
            manifest,
            path,
            image_url,
            target_device,
            insecure_tls,
            stored_cacert,
            header_args,
        )

        console.sendline("reboot")
        time.sleep(2)
        self.logger.info("Powering off target")
        self.power.off()


def _monitor_flash_progress(self, console, prompt):
    """Monitor flash progress by accumulating console output and looking for completion markers.

    The dd process is located with ``pidof dd`` and its ``/proc/<pid>/fdinfo/1``
    entry is polled about once per second to report progress. All console output
    read here is accumulated and searched for the ``FLASH_COMPLETE`` /
    ``FLASH_FAILED`` markers.

    Args:
        console: Console object exposing ``sendline``, ``expect`` and ``before``.
        prompt: Shell prompt pattern that terminates each command's output.

    Raises:
        FlashRetryableError: If the ``FLASH_FAILED`` marker is seen.
        FlashNonRetryableError: If no dd process is found and no marker was
            printed yet, or if dd exits and no marker shows up within the grace
            period.
    """
    console.sendline("pidof dd")
    console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
    pidof_output = console.before.decode(errors="ignore")
    # Keep the pidof output: a marker may already have been printed by now.
    accumulated_output = pidof_output

    # The PID is the first line consisting only of digits; other lines may be the
    # echoed command or error messages.
    dd_pid = next(
        (line.strip() for line in pidof_output.splitlines() if line.strip().isdigit()),
        None,
    )

    if dd_pid is None:
        # The pipeline may have finished (or failed) before pidof ran. Honour a marker
        # that is already present instead of misreporting the outcome as a missing PID;
        # a FLASH_FAILED marker raises FlashRetryableError from the check itself.
        if self._check_completion_markers(accumulated_output):
            return
        self.logger.error("Could not find dd process ID")
        raise FlashNonRetryableError("Could not find dd process ID")

    last_pos = 0
    last_time = time.time()
    dd_finished_time = None

    while True:
        console.sendline(f"cat /proc/{dd_pid}/fdinfo/1")
        console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
        data = console.before.decode(errors="ignore")

        # Always accumulate, so a marker split across two reads is still found.
        accumulated_output = self._update_accumulated_output(accumulated_output, data)

        if "FLASH_" in data:
            self.logger.debug(f"Found FLASH marker in data: {data}")
            self.logger.debug(f"Current accumulated output: {accumulated_output}")

        if self._check_completion_markers(accumulated_output):
            break

        if "No such file or directory" in data:
            # dd is gone but no marker yet: keep polling for a bounded time.
            dd_finished_time = self._handle_dd_finished(dd_finished_time)
            continue

        last_pos, last_time = self._update_progress_stats(data, last_pos, last_time)
        time.sleep(1)


def _check_completion_markers(self, accumulated_output):
    """Check for completion markers in accumulated output.

    Returns:
        True when ``FLASH_COMPLETE`` is present, False when no marker is present.

    Raises:
        FlashRetryableError: If ``FLASH_FAILED`` is present (and
            ``FLASH_COMPLETE`` is not).
    """
    if "FLASH_COMPLETE" in accumulated_output:
        self.logger.info("Flash operation completed successfully")
        return True
    if "FLASH_FAILED" in accumulated_output:
        self.logger.error(f"FLASH_FAILED marker detected in output:\n{accumulated_output}")
        raise FlashRetryableError("Flash operation failed - curl or pipeline failed")
    return False


def _handle_dd_finished(self, dd_finished_time):
    """Handle case when dd process has finished but no completion marker found yet.

    Args:
        dd_finished_time: Timestamp at which dd was first seen to be gone, or
            None on the first call.

    Returns:
        The timestamp to pass back in on the next call.

    Raises:
        FlashNonRetryableError: If more than the grace period has elapsed since
            dd was first seen to be gone.
    """
    now = time.time()
    if dd_finished_time is None:
        dd_finished_time = now
    elif now - dd_finished_time > _MARKER_GRACE_SECONDS:
        raise FlashNonRetryableError("Flash operation completed without success/failure marker")
    # Give the marker echo a moment to arrive before the caller polls again.
    time.sleep(1)
    return dd_finished_time


def _update_accumulated_output(self, accumulated_output, data):
    """Update accumulated output with new data, keeping only last 64KB.

    Returns:
        The concatenation of both strings, truncated to its trailing 64K characters
        to prevent memory growth.
    """
    return (accumulated_output + data)[-_ACCUMULATED_OUTPUT_LIMIT:]


def _update_progress_stats(self, data, last_pos, last_time):
    """Update progress statistics and log progress if enough time has elapsed.

    Args:
        data: Text expected to contain a ``pos: <bytes>`` line.
        last_pos: Byte position at the previous progress log.
        last_time: Timestamp of the previous progress log.

    Returns:
        The ``(last_pos, last_time)`` pair to carry into the next call; it is
        unchanged unless a progress line was logged.
    """
    match = re.search(r"pos:\s+(\d+)", data)
    if match is None:
        return last_pos, last_time

    current_bytes = int(match.group(1))
    current_time = time.time()
    elapsed = current_time - last_time

    if elapsed >= 5.0:  # Update speed every 5 seconds
        speed_mb = ((current_bytes - last_pos) / (1024 * 1024)) / elapsed
        total_mb = current_bytes / (1024 * 1024)
        self.logger.info(f"Flash progress: {total_mb:.2f} MB, Speed: {speed_mb:.2f} MB/s")
        last_pos = current_bytes
        last_time = current_time

    return last_pos, last_time