Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import threading
import time
from concurrent.futures import CancelledError
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path, PosixPath
Expand Down Expand Up @@ -177,14 +178,18 @@ def flash( # noqa: C901
self.logger.info(f"Flash operation succeeded on attempt {attempt + 1}")
break
except Exception as e:
# Check if this is a retryable or non-retryable error (including sub-exceptions)
retryable_error = self._get_retryable_error(e)
non_retryable_error = self._get_non_retryable_error(e)
# Categorize the exception as retryable or non-retryable
categorized_error = self._categorize_exception(e)

if retryable_error is not None:
if isinstance(categorized_error, FlashNonRetryableError):
# Non-retryable error, fail immediately
self.logger.error(f"Flash operation failed with non-retryable error: {categorized_error}")
raise FlashError(f"Flash operation failed: {categorized_error}") from e
else:
# Retryable error
if attempt < retries:
self.logger.warning(
f"Flash attempt {attempt + 1} failed with retryable error: {retryable_error}"
f"Flash attempt {attempt + 1} failed with retryable error: {categorized_error}"
)
self.logger.info(f"Retrying flash operation (attempt {attempt + 2}/{retries + 1})")
# Wait a bit before retrying
Expand All @@ -193,86 +198,88 @@ def flash( # noqa: C901
else:
self.logger.error(f"Flash operation failed after {retries + 1} attempts")
raise FlashError(
f"Flash operation failed after {retries + 1} attempts. Last error: {retryable_error}"
f"Flash operation failed after {retries + 1} attempts. Last error: {categorized_error}"
) from e
elif non_retryable_error is not None:
# Non-retryable error, fail immediately
self.logger.error(f"Flash operation failed with non-retryable error: {non_retryable_error}")
raise FlashError(f"Flash operation failed: {non_retryable_error}") from e
else:
# Unexpected error, don't retry
self.logger.error(f"Flash operation failed with unexpected error: {e}")
raise FlashError(f"Flash operation failed: {e}") from e


total_time = time.time() - start_time
# total time in minutes:seconds
minutes, seconds = divmod(total_time, 60)
self.logger.info(f"Flashing completed in {int(minutes)}m {int(seconds):02d}s")

def _get_retryable_error(self, exception: Exception) -> FlashRetryableError | None:
"""Find a retryable error in an exception (or any of its causes).
def _categorize_exception(self, exception: Exception) -> FlashRetryableError | FlashNonRetryableError:
"""Categorize an exception as retryable or non-retryable.

This method searches through the exception chain (including ExceptionGroups)
to find FlashRetryableError or FlashNonRetryableError instances.

Priority:
1. FlashNonRetryableError - highest priority, fail immediately
2. FlashRetryableError - retry with backoff
3. Unknown exceptions - log full stack trace and treat as retryable

Args:
exception: The exception to check
exception: The exception to categorize

Returns:
The FlashRetryableError if found, None otherwise
FlashRetryableError or FlashNonRetryableError
"""
# Check if this is an ExceptionGroup and look through its exceptions
if hasattr(exception, 'exceptions'):
for sub_exc in exception.exceptions:
result = self._get_retryable_error(sub_exc)
if result is not None:
return result

# Check the current exception
if isinstance(exception, FlashRetryableError):
return exception
# First pass: look for non-retryable errors (highest priority)
non_retryable = self._find_exception_in_chain(exception, FlashNonRetryableError)
if non_retryable is not None:
return non_retryable

# Second pass: look for retryable errors
retryable = self._find_exception_in_chain(exception, FlashRetryableError)
if retryable is not None:
return retryable

# CancelledError is a special case that should be treated as non-retryable
if isinstance(exception, CancelledError):
return FlashNonRetryableError("Operation cancelled")

# Unknown exception - log full stack trace and wrap as retryable
self.logger.exception(
f"Unknown exception encountered during flash operation, treating as retryable: "
f"{type(exception).__name__}: {exception}"
)
wrapped_exception = FlashRetryableError(f"Unknown error occurred: {type(exception).__name__}: {exception}")
wrapped_exception.__cause__ = exception
return wrapped_exception

# Check the cause chain
current = getattr(exception, '__cause__', None)
while current is not None:
if isinstance(current, FlashRetryableError):
return current
# Also check if the cause is an ExceptionGroup
if hasattr(current, 'exceptions'):
for sub_exc in current.exceptions:
result = self._get_retryable_error(sub_exc)
if result is not None:
return result
current = getattr(current, '__cause__', None)
return None
def _find_exception_in_chain(self, exception: Exception, target_type: type) -> Exception | None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have the leaf_exceptions helper for flattening exception groups.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh wait I missed this... ok I will look into reusing the other implementation, hopefully they are compatible enough.

"""Find an exception of a specific type in an exception chain.

def _get_non_retryable_error(self, exception: Exception) -> FlashNonRetryableError | None:
"""Find a non-retryable error in an exception (or any of its causes).
Searches through the exception, its ExceptionGroup members (if any),
and the cause chain recursively.

Args:
exception: The exception to check
exception: The exception to search
target_type: The exception type to search for

Returns:
The FlashNonRetryableError if found, None otherwise
The found exception instance if found, None otherwise
"""
# Check if this is an ExceptionGroup and look through its exceptions
if hasattr(exception, 'exceptions'):
for sub_exc in exception.exceptions:
result = self._get_non_retryable_error(sub_exc)
result = self._find_exception_in_chain(sub_exc, target_type)
if result is not None:
return result

# Check the current exception
if isinstance(exception, FlashNonRetryableError):
if isinstance(exception, target_type):
return exception

# Check the cause chain
current = getattr(exception, '__cause__', None)
while current is not None:
if isinstance(current, FlashNonRetryableError):
if isinstance(current, target_type):
return current
# Also check if the cause is an ExceptionGroup
if hasattr(current, 'exceptions'):
for sub_exc in current.exceptions:
result = self._get_non_retryable_error(sub_exc)
result = self._find_exception_in_chain(sub_exc, target_type)
if result is not None:
return result
current = getattr(current, '__cause__', None)
Expand Down Expand Up @@ -1156,7 +1163,7 @@ def base():
@click.option(
"--fls-version",
type=str,
default="0.1.5", # TODO(majopela): set default to "" once fls is included in our images
default="0.1.9", # TODO(majopela): set default to "" once fls is included in our images
help="Download an specific fls version from the github releases",
)
@debug_console_option
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from concurrent.futures import CancelledError

import click
import pytest

from .client import BaseFlasherClient
from .client import BaseFlasherClient, FlashNonRetryableError, FlashRetryableError
from jumpstarter.common.exceptions import ArgumentError


Expand All @@ -12,7 +14,14 @@ def __init__(self):
self._manifest = None
self._console_debug = False
self.logger = type(
"MockLogger", (), {"warning": lambda msg: None, "info": lambda msg: None, "error": lambda msg: None}
"MockLogger",
(),
{
"warning": lambda *args, **kwargs: None,
"info": lambda *args, **kwargs: None,
"error": lambda *args, **kwargs: None,
"exception": lambda *args, **kwargs: None,
},
)()

def close(self):
Expand Down Expand Up @@ -49,3 +58,143 @@ def test_flash_fails_with_invalid_headers():

with pytest.raises(ArgumentError, match="Invalid header name 'Invalid Header': must be an HTTP token"):
client.flash("test.raw", headers={"Invalid Header": "value"})


def test_categorize_exception_returns_non_retryable_when_present():
"""Test that non-retryable errors take priority"""
client = MockFlasherClient()

# Direct non-retryable error
error = FlashNonRetryableError("Config error")
result = client._categorize_exception(error)
assert isinstance(result, FlashNonRetryableError)
assert str(result) == "Config error"


def test_categorize_exception_returns_retryable_when_present():
"""Test that retryable errors are returned"""
client = MockFlasherClient()

# Direct retryable error
error = FlashRetryableError("Network timeout")
result = client._categorize_exception(error)
assert isinstance(result, FlashRetryableError)
assert str(result) == "Network timeout"


def test_categorize_exception_wraps_unknown_exceptions():
"""Test that unknown exceptions are wrapped as retryable"""
client = MockFlasherClient()

# Unknown exception type
error = ValueError("Something went wrong")
result = client._categorize_exception(error)
assert isinstance(result, FlashRetryableError)
assert "ValueError" in str(result)
assert "Something went wrong" in str(result)
# Verify the cause chain is preserved
assert result.__cause__ is error


def test_categorize_exception_cancelled_error_is_non_retryable():
"""Test that CancelledError is treated as non-retryable"""
client = MockFlasherClient()

# CancelledError should be treated as non-retryable
error = CancelledError()
result = client._categorize_exception(error)
assert isinstance(result, FlashNonRetryableError)
assert "Operation cancelled" in str(result)


def test_categorize_exception_searches_cause_chain():
"""Test that categorization searches through the cause chain"""
client = MockFlasherClient()

# Create a chain: generic -> generic -> retryable
root = FlashRetryableError("Root cause")
middle = ValueError("Middle error")
middle.__cause__ = root
top = RuntimeError("Top error")
top.__cause__ = middle

result = client._categorize_exception(top)
assert isinstance(result, FlashRetryableError)
assert str(result) == "Root cause"


def test_find_exception_in_chain_finds_target_type():
"""Test that _find_exception_in_chain correctly finds the target type"""
client = MockFlasherClient()

# Create a chain with retryable error
retryable = FlashRetryableError("Network error")
generic = RuntimeError("Generic error")
generic.__cause__ = retryable

result = client._find_exception_in_chain(generic, FlashRetryableError)
assert result is retryable
assert str(result) == "Network error"


def test_find_exception_in_chain_returns_none_when_not_found():
"""Test that _find_exception_in_chain returns None when target not found"""
client = MockFlasherClient()

error = ValueError("Some error")
result = client._find_exception_in_chain(error, FlashRetryableError)
assert result is None


def test_find_exception_in_chain_handles_exception_groups():
"""Test that _find_exception_in_chain searches through ExceptionGroups"""
client = MockFlasherClient()

# Create an ExceptionGroup with a retryable error
retryable = FlashRetryableError("Network timeout")
generic = ValueError("Generic error")

# Mock an ExceptionGroup (Python 3.11+)
class MockExceptionGroup(Exception):
def __init__(self, message, exceptions):
super().__init__(message)
self.exceptions = exceptions

group = MockExceptionGroup("Multiple errors", [generic, retryable])

result = client._find_exception_in_chain(group, FlashRetryableError)
assert result is retryable


def test_categorize_exception_with_nested_exception_groups():
"""Test categorization with nested ExceptionGroups"""
client = MockFlasherClient()

# Create nested ExceptionGroups
non_retryable = FlashNonRetryableError("Config error")

class MockExceptionGroup(Exception):
def __init__(self, message, exceptions):
super().__init__(message)
self.exceptions = exceptions

inner_group = MockExceptionGroup("Inner errors", [non_retryable])
outer_group = MockExceptionGroup("Outer errors", [ValueError("Other"), inner_group])

result = client._categorize_exception(outer_group)
assert isinstance(result, FlashNonRetryableError)
assert str(result) == "Config error"


def test_categorize_exception_preserves_cause_for_wrapped_exceptions():
"""Test that wrapped unknown exceptions preserve the cause chain"""
client = MockFlasherClient()

original = IOError("File not found")
result = client._categorize_exception(original)

assert isinstance(result, FlashRetryableError)
assert result.__cause__ is original
# IOError is an alias for OSError in Python 3
assert "OSError" in str(result) or "IOError" in str(result)
assert "File not found" in str(result)
Loading