Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit db953ec

Browse files
authored
Merge pull request #747 from jumpstarter-dev/backport-746-to-release-0.7
[Backport release-0.7] flasher: make unknown exceptions retriable
2 parents 6280cb1 + 3d4a2ed commit db953ec

File tree

2 files changed

+208
-52
lines changed

2 files changed

+208
-52
lines changed

packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py

Lines changed: 57 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import threading
77
import time
8+
from concurrent.futures import CancelledError
89
from contextlib import contextmanager
910
from dataclasses import dataclass
1011
from pathlib import Path, PosixPath
@@ -177,14 +178,18 @@ def flash( # noqa: C901
177178
self.logger.info(f"Flash operation succeeded on attempt {attempt + 1}")
178179
break
179180
except Exception as e:
180-
# Check if this is a retryable or non-retryable error (including sub-exceptions)
181-
retryable_error = self._get_retryable_error(e)
182-
non_retryable_error = self._get_non_retryable_error(e)
181+
# Categorize the exception as retryable or non-retryable
182+
categorized_error = self._categorize_exception(e)
183183

184-
if retryable_error is not None:
184+
if isinstance(categorized_error, FlashNonRetryableError):
185+
# Non-retryable error, fail immediately
186+
self.logger.error(f"Flash operation failed with non-retryable error: {categorized_error}")
187+
raise FlashError(f"Flash operation failed: {categorized_error}") from e
188+
else:
189+
# Retryable error
185190
if attempt < retries:
186191
self.logger.warning(
187-
f"Flash attempt {attempt + 1} failed with retryable error: {retryable_error}"
192+
f"Flash attempt {attempt + 1} failed with retryable error: {categorized_error}"
188193
)
189194
self.logger.info(f"Retrying flash operation (attempt {attempt + 2}/{retries + 1})")
190195
# Wait a bit before retrying
@@ -193,86 +198,88 @@ def flash( # noqa: C901
193198
else:
194199
self.logger.error(f"Flash operation failed after {retries + 1} attempts")
195200
raise FlashError(
196-
f"Flash operation failed after {retries + 1} attempts. Last error: {retryable_error}"
201+
f"Flash operation failed after {retries + 1} attempts. Last error: {categorized_error}"
197202
) from e
198-
elif non_retryable_error is not None:
199-
# Non-retryable error, fail immediately
200-
self.logger.error(f"Flash operation failed with non-retryable error: {non_retryable_error}")
201-
raise FlashError(f"Flash operation failed: {non_retryable_error}") from e
202-
else:
203-
# Unexpected error, don't retry
204-
self.logger.error(f"Flash operation failed with unexpected error: {e}")
205-
raise FlashError(f"Flash operation failed: {e}") from e
206203

207204

208205
total_time = time.time() - start_time
209206
# total time in minutes:seconds
210207
minutes, seconds = divmod(total_time, 60)
211208
self.logger.info(f"Flashing completed in {int(minutes)}m {int(seconds):02d}s")
212209

213-
def _get_retryable_error(self, exception: Exception) -> FlashRetryableError | None:
214-
"""Find a retryable error in an exception (or any of its causes).
210+
def _categorize_exception(self, exception: Exception) -> FlashRetryableError | FlashNonRetryableError:
211+
"""Categorize an exception as retryable or non-retryable.
212+
213+
This method searches through the exception chain (including ExceptionGroups)
214+
to find FlashRetryableError or FlashNonRetryableError instances.
215+
216+
Priority:
217+
1. FlashNonRetryableError - highest priority, fail immediately
218+
2. FlashRetryableError - retry with backoff
219+
3. Unknown exceptions - log full stack trace and treat as retryable
215220
216221
Args:
217-
exception: The exception to check
222+
exception: The exception to categorize
218223
219224
Returns:
220-
The FlashRetryableError if found, None otherwise
225+
FlashRetryableError or FlashNonRetryableError
221226
"""
222-
# Check if this is an ExceptionGroup and look through its exceptions
223-
if hasattr(exception, 'exceptions'):
224-
for sub_exc in exception.exceptions:
225-
result = self._get_retryable_error(sub_exc)
226-
if result is not None:
227-
return result
228-
229-
# Check the current exception
230-
if isinstance(exception, FlashRetryableError):
231-
return exception
227+
# First pass: look for non-retryable errors (highest priority)
228+
non_retryable = self._find_exception_in_chain(exception, FlashNonRetryableError)
229+
if non_retryable is not None:
230+
return non_retryable
231+
232+
# Second pass: look for retryable errors
233+
retryable = self._find_exception_in_chain(exception, FlashRetryableError)
234+
if retryable is not None:
235+
return retryable
236+
237+
# CancelledError is a special case that should be treated as non-retryable
238+
if isinstance(exception, CancelledError):
239+
return FlashNonRetryableError("Operation cancelled")
240+
241+
# Unknown exception - log full stack trace and wrap as retryable
242+
self.logger.exception(
243+
f"Unknown exception encountered during flash operation, treating as retryable: "
244+
f"{type(exception).__name__}: {exception}"
245+
)
246+
wrapped_exception = FlashRetryableError(f"Unknown error occurred: {type(exception).__name__}: {exception}")
247+
wrapped_exception.__cause__ = exception
248+
return wrapped_exception
232249

233-
# Check the cause chain
234-
current = getattr(exception, '__cause__', None)
235-
while current is not None:
236-
if isinstance(current, FlashRetryableError):
237-
return current
238-
# Also check if the cause is an ExceptionGroup
239-
if hasattr(current, 'exceptions'):
240-
for sub_exc in current.exceptions:
241-
result = self._get_retryable_error(sub_exc)
242-
if result is not None:
243-
return result
244-
current = getattr(current, '__cause__', None)
245-
return None
250+
def _find_exception_in_chain(self, exception: Exception, target_type: type) -> Exception | None:
251+
"""Find an exception of a specific type in an exception chain.
246252
247-
def _get_non_retryable_error(self, exception: Exception) -> FlashNonRetryableError | None:
248-
"""Find a non-retryable error in an exception (or any of its causes).
253+
Searches through the exception, its ExceptionGroup members (if any),
254+
and the cause chain recursively.
249255
250256
Args:
251-
exception: The exception to check
257+
exception: The exception to search
258+
target_type: The exception type to search for
252259
253260
Returns:
254-
The FlashNonRetryableError if found, None otherwise
261+
The found exception instance if found, None otherwise
255262
"""
256263
# Check if this is an ExceptionGroup and look through its exceptions
257264
if hasattr(exception, 'exceptions'):
258265
for sub_exc in exception.exceptions:
259-
result = self._get_non_retryable_error(sub_exc)
266+
result = self._find_exception_in_chain(sub_exc, target_type)
260267
if result is not None:
261268
return result
262269

263270
# Check the current exception
264-
if isinstance(exception, FlashNonRetryableError):
271+
if isinstance(exception, target_type):
265272
return exception
266273

267274
# Check the cause chain
268275
current = getattr(exception, '__cause__', None)
269276
while current is not None:
270-
if isinstance(current, FlashNonRetryableError):
277+
if isinstance(current, target_type):
271278
return current
272279
# Also check if the cause is an ExceptionGroup
273280
if hasattr(current, 'exceptions'):
274281
for sub_exc in current.exceptions:
275-
result = self._get_non_retryable_error(sub_exc)
282+
result = self._find_exception_in_chain(sub_exc, target_type)
276283
if result is not None:
277284
return result
278285
current = getattr(current, '__cause__', None)
@@ -1156,7 +1163,7 @@ def base():
11561163
@click.option(
11571164
"--fls-version",
11581165
type=str,
1159-
default="0.1.5", # TODO(majopela): set default to "" once fls is included in our images
1166+
default="0.1.9", # TODO(majopela): set default to "" once fls is included in our images
11601167
help="Download an specific fls version from the github releases",
11611168
)
11621169
@debug_console_option

packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py

Lines changed: 151 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from concurrent.futures import CancelledError
2+
13
import click
24
import pytest
35

4-
from .client import BaseFlasherClient
6+
from .client import BaseFlasherClient, FlashNonRetryableError, FlashRetryableError
57
from jumpstarter.common.exceptions import ArgumentError
68

79

@@ -12,7 +14,14 @@ def __init__(self):
1214
self._manifest = None
1315
self._console_debug = False
1416
self.logger = type(
15-
"MockLogger", (), {"warning": lambda msg: None, "info": lambda msg: None, "error": lambda msg: None}
17+
"MockLogger",
18+
(),
19+
{
20+
"warning": lambda *args, **kwargs: None,
21+
"info": lambda *args, **kwargs: None,
22+
"error": lambda *args, **kwargs: None,
23+
"exception": lambda *args, **kwargs: None,
24+
},
1625
)()
1726

1827
def close(self):
@@ -49,3 +58,143 @@ def test_flash_fails_with_invalid_headers():
4958

5059
with pytest.raises(ArgumentError, match="Invalid header name 'Invalid Header': must be an HTTP token"):
5160
client.flash("test.raw", headers={"Invalid Header": "value"})
61+
62+
63+
def test_categorize_exception_returns_non_retryable_when_present():
64+
"""Test that non-retryable errors take priority"""
65+
client = MockFlasherClient()
66+
67+
# Direct non-retryable error
68+
error = FlashNonRetryableError("Config error")
69+
result = client._categorize_exception(error)
70+
assert isinstance(result, FlashNonRetryableError)
71+
assert str(result) == "Config error"
72+
73+
74+
def test_categorize_exception_returns_retryable_when_present():
75+
"""Test that retryable errors are returned"""
76+
client = MockFlasherClient()
77+
78+
# Direct retryable error
79+
error = FlashRetryableError("Network timeout")
80+
result = client._categorize_exception(error)
81+
assert isinstance(result, FlashRetryableError)
82+
assert str(result) == "Network timeout"
83+
84+
85+
def test_categorize_exception_wraps_unknown_exceptions():
86+
"""Test that unknown exceptions are wrapped as retryable"""
87+
client = MockFlasherClient()
88+
89+
# Unknown exception type
90+
error = ValueError("Something went wrong")
91+
result = client._categorize_exception(error)
92+
assert isinstance(result, FlashRetryableError)
93+
assert "ValueError" in str(result)
94+
assert "Something went wrong" in str(result)
95+
# Verify the cause chain is preserved
96+
assert result.__cause__ is error
97+
98+
99+
def test_categorize_exception_cancelled_error_is_non_retryable():
100+
"""Test that CancelledError is treated as non-retryable"""
101+
client = MockFlasherClient()
102+
103+
# CancelledError should be treated as non-retryable
104+
error = CancelledError()
105+
result = client._categorize_exception(error)
106+
assert isinstance(result, FlashNonRetryableError)
107+
assert "Operation cancelled" in str(result)
108+
109+
110+
def test_categorize_exception_searches_cause_chain():
111+
"""Test that categorization searches through the cause chain"""
112+
client = MockFlasherClient()
113+
114+
# Create a chain: generic -> generic -> retryable
115+
root = FlashRetryableError("Root cause")
116+
middle = ValueError("Middle error")
117+
middle.__cause__ = root
118+
top = RuntimeError("Top error")
119+
top.__cause__ = middle
120+
121+
result = client._categorize_exception(top)
122+
assert isinstance(result, FlashRetryableError)
123+
assert str(result) == "Root cause"
124+
125+
126+
def test_find_exception_in_chain_finds_target_type():
127+
"""Test that _find_exception_in_chain correctly finds the target type"""
128+
client = MockFlasherClient()
129+
130+
# Create a chain with retryable error
131+
retryable = FlashRetryableError("Network error")
132+
generic = RuntimeError("Generic error")
133+
generic.__cause__ = retryable
134+
135+
result = client._find_exception_in_chain(generic, FlashRetryableError)
136+
assert result is retryable
137+
assert str(result) == "Network error"
138+
139+
140+
def test_find_exception_in_chain_returns_none_when_not_found():
141+
"""Test that _find_exception_in_chain returns None when target not found"""
142+
client = MockFlasherClient()
143+
144+
error = ValueError("Some error")
145+
result = client._find_exception_in_chain(error, FlashRetryableError)
146+
assert result is None
147+
148+
149+
def test_find_exception_in_chain_handles_exception_groups():
150+
"""Test that _find_exception_in_chain searches through ExceptionGroups"""
151+
client = MockFlasherClient()
152+
153+
# Create an ExceptionGroup with a retryable error
154+
retryable = FlashRetryableError("Network timeout")
155+
generic = ValueError("Generic error")
156+
157+
# Mock an ExceptionGroup (Python 3.11+)
158+
class MockExceptionGroup(Exception):
159+
def __init__(self, message, exceptions):
160+
super().__init__(message)
161+
self.exceptions = exceptions
162+
163+
group = MockExceptionGroup("Multiple errors", [generic, retryable])
164+
165+
result = client._find_exception_in_chain(group, FlashRetryableError)
166+
assert result is retryable
167+
168+
169+
def test_categorize_exception_with_nested_exception_groups():
170+
"""Test categorization with nested ExceptionGroups"""
171+
client = MockFlasherClient()
172+
173+
# Create nested ExceptionGroups
174+
non_retryable = FlashNonRetryableError("Config error")
175+
176+
class MockExceptionGroup(Exception):
177+
def __init__(self, message, exceptions):
178+
super().__init__(message)
179+
self.exceptions = exceptions
180+
181+
inner_group = MockExceptionGroup("Inner errors", [non_retryable])
182+
outer_group = MockExceptionGroup("Outer errors", [ValueError("Other"), inner_group])
183+
184+
result = client._categorize_exception(outer_group)
185+
assert isinstance(result, FlashNonRetryableError)
186+
assert str(result) == "Config error"
187+
188+
189+
def test_categorize_exception_preserves_cause_for_wrapped_exceptions():
190+
"""Test that wrapped unknown exceptions preserve the cause chain"""
191+
client = MockFlasherClient()
192+
193+
original = IOError("File not found")
194+
result = client._categorize_exception(original)
195+
196+
assert isinstance(result, FlashRetryableError)
197+
assert result.__cause__ is original
198+
# IOError is an alias for OSError in Python 3
199+
assert "OSError" in str(result) or "IOError" in str(result)
200+
assert "File not found" in str(result)

0 commit comments

Comments
 (0)