diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index e3dc92752..78527c474 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -5,6 +5,7 @@ import sys import threading import time +from concurrent.futures import CancelledError from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path, PosixPath @@ -177,14 +178,18 @@ def flash( # noqa: C901 self.logger.info(f"Flash operation succeeded on attempt {attempt + 1}") break except Exception as e: - # Check if this is a retryable or non-retryable error (including sub-exceptions) - retryable_error = self._get_retryable_error(e) - non_retryable_error = self._get_non_retryable_error(e) + # Categorize the exception as retryable or non-retryable + categorized_error = self._categorize_exception(e) - if retryable_error is not None: + if isinstance(categorized_error, FlashNonRetryableError): + # Non-retryable error, fail immediately + self.logger.error(f"Flash operation failed with non-retryable error: {categorized_error}") + raise FlashError(f"Flash operation failed: {categorized_error}") from e + else: + # Retryable error if attempt < retries: self.logger.warning( - f"Flash attempt {attempt + 1} failed with retryable error: {retryable_error}" + f"Flash attempt {attempt + 1} failed with retryable error: {categorized_error}" ) self.logger.info(f"Retrying flash operation (attempt {attempt + 2}/{retries + 1})") # Wait a bit before retrying @@ -193,16 +198,8 @@ def flash( # noqa: C901 else: self.logger.error(f"Flash operation failed after {retries + 1} attempts") raise FlashError( - f"Flash operation failed after {retries + 1} attempts. Last error: {retryable_error}" + f"Flash operation failed after {retries + 1} attempts. 
def _categorize_exception(self, exception: Exception) -> FlashRetryableError | FlashNonRetryableError:
    """Categorize an exception as retryable or non-retryable.

    The exception itself, the members of any ExceptionGroup and the
    ``__cause__`` chain are all searched (see ``_find_exception_in_chain``).

    Priority:
        1. FlashNonRetryableError - returned as-is
        2. FlashRetryableError - returned as-is
        3. CancelledError - wrapped in a FlashNonRetryableError
        4. Anything else - logged with its traceback and wrapped in a
           FlashRetryableError

    Args:
        exception: The exception to categorize

    Returns:
        FlashRetryableError or FlashNonRetryableError
    """
    # First pass: a non-retryable error anywhere in the chain wins.
    non_retryable = self._find_exception_in_chain(exception, FlashNonRetryableError)
    if non_retryable is not None:
        return non_retryable

    # Second pass: look for retryable errors.
    retryable = self._find_exception_in_chain(exception, FlashRetryableError)
    if retryable is not None:
        return retryable

    # Cancellation must never be retried. Search the whole chain rather than
    # only the top-level exception, so a CancelledError that arrives wrapped
    # in an ExceptionGroup or as a __cause__ is still recognised.
    cancelled = self._find_exception_in_chain(exception, CancelledError)
    if cancelled is not None:
        cancelled_error = FlashNonRetryableError("Operation cancelled")
        cancelled_error.__cause__ = cancelled
        return cancelled_error

    # Unknown exception - log full stack trace and wrap as retryable.
    # exc_info is passed explicitly so the traceback of *this* exception is
    # logged even when the method is called outside an ``except`` block.
    self.logger.exception(
        f"Unknown exception encountered during flash operation, treating as retryable: "
        f"{type(exception).__name__}: {exception}",
        exc_info=exception,
    )
    wrapped_exception = FlashRetryableError(f"Unknown error occurred: {type(exception).__name__}: {exception}")
    wrapped_exception.__cause__ = exception
    return wrapped_exception
_find_exception_in_chain(self, exception: Exception, target_type: type) -> Exception | None: + """Find an exception of a specific type in an exception chain. - def _get_non_retryable_error(self, exception: Exception) -> FlashNonRetryableError | None: - """Find a non-retryable error in an exception (or any of its causes). + Searches through the exception, its ExceptionGroup members (if any), + and the cause chain recursively. Args: - exception: The exception to check + exception: The exception to search + target_type: The exception type to search for Returns: - The FlashNonRetryableError if found, None otherwise + The found exception instance if found, None otherwise """ # Check if this is an ExceptionGroup and look through its exceptions if hasattr(exception, 'exceptions'): for sub_exc in exception.exceptions: - result = self._get_non_retryable_error(sub_exc) + result = self._find_exception_in_chain(sub_exc, target_type) if result is not None: return result # Check the current exception - if isinstance(exception, FlashNonRetryableError): + if isinstance(exception, target_type): return exception # Check the cause chain current = getattr(exception, '__cause__', None) while current is not None: - if isinstance(current, FlashNonRetryableError): + if isinstance(current, target_type): return current # Also check if the cause is an ExceptionGroup if hasattr(current, 'exceptions'): for sub_exc in current.exceptions: - result = self._get_non_retryable_error(sub_exc) + result = self._find_exception_in_chain(sub_exc, target_type) if result is not None: return result current = getattr(current, '__cause__', None) @@ -1156,7 +1163,7 @@ def base(): @click.option( "--fls-version", type=str, - default="0.1.5", # TODO(majopela): set default to "" once fls is included in our images + default="0.1.9", # TODO(majopela): set default to "" once fls is included in our images help="Download an specific fls version from the github releases", ) @debug_console_option diff --git 
def test_categorize_exception_returns_non_retryable_when_present():
    """Test that non-retryable errors are returned and take priority"""
    client = MockFlasherClient()

    # Direct non-retryable error
    error = FlashNonRetryableError("Config error")
    result = client._categorize_exception(error)
    assert isinstance(result, FlashNonRetryableError)
    assert str(result) == "Config error"

    # Priority: a retryable error caused by a non-retryable one must still
    # be categorized as non-retryable.
    wrapper = FlashRetryableError("Network timeout")
    wrapper.__cause__ = error
    result = client._categorize_exception(wrapper)
    assert isinstance(result, FlashNonRetryableError)
    assert str(result) == "Config error"


def test_categorize_exception_returns_retryable_when_present():
    """Test that retryable errors are returned"""
    client = MockFlasherClient()

    # Direct retryable error
    error = FlashRetryableError("Network timeout")
    result = client._categorize_exception(error)
    assert isinstance(result, FlashRetryableError)
    assert str(result) == "Network timeout"
class _FakeExceptionGroup(Exception):
    """Minimal stand-in for ExceptionGroup: an exception carrying ``exceptions``."""

    def __init__(self, message, exceptions):
        super().__init__(message)
        self.exceptions = exceptions


def test_categorize_exception_wraps_unknown_exceptions():
    """Test that unknown exceptions are wrapped as retryable"""
    client = MockFlasherClient()

    # Unknown exception type
    error = ValueError("Something went wrong")
    result = client._categorize_exception(error)
    assert isinstance(result, FlashRetryableError)
    assert "ValueError" in str(result)
    assert "Something went wrong" in str(result)
    # Verify the cause chain is preserved
    assert result.__cause__ is error


def test_categorize_exception_cancelled_error_is_non_retryable():
    """Test that CancelledError is treated as non-retryable"""
    client = MockFlasherClient()

    result = client._categorize_exception(CancelledError())
    assert isinstance(result, FlashNonRetryableError)
    assert "Operation cancelled" in str(result)


def test_categorize_exception_searches_cause_chain():
    """Test that categorization searches through the cause chain"""
    client = MockFlasherClient()

    # Chain: top (generic) -> middle (generic) -> root (retryable)
    root = FlashRetryableError("Root cause")
    middle = ValueError("Middle error")
    middle.__cause__ = root
    top = RuntimeError("Top error")
    top.__cause__ = middle

    result = client._categorize_exception(top)
    assert isinstance(result, FlashRetryableError)
    assert str(result) == "Root cause"


def test_find_exception_in_chain_finds_target_type():
    """Test that _find_exception_in_chain correctly finds the target type"""
    client = MockFlasherClient()

    # Generic error whose direct cause is the retryable error
    retryable = FlashRetryableError("Network error")
    generic = RuntimeError("Generic error")
    generic.__cause__ = retryable

    result = client._find_exception_in_chain(generic, FlashRetryableError)
    assert result is retryable
    assert str(result) == "Network error"


def test_find_exception_in_chain_returns_none_when_not_found():
    """Test that _find_exception_in_chain returns None when target not found"""
    client = MockFlasherClient()

    error = ValueError("Some error")
    result = client._find_exception_in_chain(error, FlashRetryableError)
    assert result is None


def test_find_exception_in_chain_handles_exception_groups():
    """Test that _find_exception_in_chain searches through ExceptionGroups"""
    client = MockFlasherClient()

    retryable = FlashRetryableError("Network timeout")
    generic = ValueError("Generic error")
    group = _FakeExceptionGroup("Multiple errors", [generic, retryable])

    result = client._find_exception_in_chain(group, FlashRetryableError)
    assert result is retryable


def test_categorize_exception_with_nested_exception_groups():
    """Test categorization with nested ExceptionGroups"""
    client = MockFlasherClient()

    non_retryable = FlashNonRetryableError("Config error")
    inner_group = _FakeExceptionGroup("Inner errors", [non_retryable])
    outer_group = _FakeExceptionGroup("Outer errors", [ValueError("Other"), inner_group])

    result = client._categorize_exception(outer_group)
    assert isinstance(result, FlashNonRetryableError)
    assert str(result) == "Config error"


def test_categorize_exception_preserves_cause_for_wrapped_exceptions():
    """Test that wrapped unknown exceptions preserve the cause chain"""
    client = MockFlasherClient()

    original = OSError("File not found")
    result = client._categorize_exception(original)

    assert isinstance(result, FlashRetryableError)
    assert result.__cause__ is original
    # The wrapper message embeds type(original).__name__
    assert "OSError" in str(result)
    assert "File not found" in str(result)