Skip to content

Commit

Permalink
Fix: Always propagate pytorch task worker process exception timestamp…
Browse files Browse the repository at this point in the history
… to task exception (#3057)

* Fix: Always propagate pytorch task worker process exception timestamp to task exception

Signed-off-by: Fabio Grätz <[email protected]>

* Fix existing recoverable error test

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
  • Loading branch information
fg91 and Fabio Grätz authored Jan 18, 2025
1 parent 665c44d commit 3260ddf
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
5 changes: 3 additions & 2 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ class FlyteUserException(_FlyteException):
class FlyteUserRuntimeException(_FlyteException):
_ERROR_CODE = "USER:RuntimeError"

def __init__(self, exc_value: Exception):
def __init__(self, exc_value: Exception, timestamp: typing.Optional[float] = None):
"""
FlyteUserRuntimeException is thrown when a user code raises an exception.
:param exc_value: The exception that was raised from user code.
:param timestamp: The timestamp as fractional seconds since epoch when the exception was raised.
"""
self._exc_value = exc_value
super().__init__(str(exc_value))
super().__init__(str(exc_value), timestamp=timestamp)

@property
def value(self):
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException
from flytekit.extend import IgnoreOutputs, TaskPlugins
from flytekit.loggers import logger

Expand Down Expand Up @@ -475,7 +475,7 @@ def fn_partial():
# the automatically assigned timestamp based on exception creation time
raise FlyteRecoverableException(e.format_msg(), timestamp=first_failure.timestamp)
else:
raise RuntimeError(e.format_msg())
raise FlyteUserRuntimeException(e, timestamp=first_failure.timestamp)
except SignalException as e:
logger.exception(f"Elastic launch agent process terminating: {e}")
raise IgnoreOutputs()
Expand Down
38 changes: 36 additions & 2 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flytekit import task, workflow
from flytekit.core.context_manager import FlyteContext, FlyteContextManager, ExecutionState, ExecutionParameters, OutputMetadataTracker
from flytekit.configuration import SerializationSettings
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException

@pytest.fixture(autouse=True, scope="function")
def restore_env():
Expand Down Expand Up @@ -223,7 +223,7 @@ def wf(recoverable: bool):
with pytest.raises(FlyteRecoverableException):
wf(recoverable=recoverable)
else:
with pytest.raises(RuntimeError):
with pytest.raises(FlyteUserRuntimeException):
wf(recoverable=recoverable)


Expand Down Expand Up @@ -276,3 +276,37 @@ def test_task_omp_set():
assert os.environ["OMP_NUM_THREADS"] == "42"

test_task_omp_set()


def test_exception_timestamp() -> None:
    """Test that a non-recoverable worker failure surfaces with a timestamp.

    The worker function raises a plain ``Exception``. The elastic task is
    expected to surface it as a ``FlyteUserRuntimeException`` whose
    ``timestamp`` attribute has been populated.
    """
    @task(
        task_config=Elastic(
            nnodes=1,
            nproc_per_node=2,
        )
    )
    def test_task():
        raise Exception("Test exception")

    # Pin the concrete exception type instead of the catch-all ``Exception``:
    # an unrelated error should fail at the ``raises`` check rather than slip
    # through to the attribute access below.
    with pytest.raises(FlyteUserRuntimeException) as e:
        test_task()

    assert e.value.timestamp is not None


def test_recoverable_exception_timestamp() -> None:
    """Test that a recoverable worker failure surfaces with a timestamp.

    The worker function raises ``FlyteRecoverableException``. The elastic
    task is expected to re-raise an exception of the same type whose
    ``timestamp`` attribute has been populated.
    """
    @task(
        task_config=Elastic(
            nnodes=1,
            nproc_per_node=2,
        )
    )
    def test_task():
        raise FlyteRecoverableException("Recoverable test exception")

    # Pin the concrete exception type instead of the catch-all ``Exception``
    # so the test also verifies that the failure stays classified as
    # recoverable.
    with pytest.raises(FlyteRecoverableException) as e:
        test_task()

    assert e.value.timestamp is not None

0 comments on commit 3260ddf

Please sign in to comment.