diff --git a/python_modules/dagster/dagster/_core/errors.py b/python_modules/dagster/dagster/_core/errors.py index e2da56d5eb05d..5ea13ffe99314 100644 --- a/python_modules/dagster/dagster/_core/errors.py +++ b/python_modules/dagster/dagster/_core/errors.py @@ -280,8 +280,9 @@ def user_code_error_boundary( """ check.callable_param(msg_fn, "msg_fn") check.class_param(error_cls, "error_cls", superclass=DagsterUserCodeExecutionError) + from dagster._utils.error import redact_user_stacktrace_if_enabled - with raise_execution_interrupts(): + with redact_user_stacktrace_if_enabled(), raise_execution_interrupts(): if log_manager: log_manager.begin_python_log_capture() try: diff --git a/python_modules/dagster/dagster/_core/execution/plan/utils.py b/python_modules/dagster/dagster/_core/execution/plan/utils.py index eaeb65a9e03c5..9d0869a6817aa 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/utils.py +++ b/python_modules/dagster/dagster/_core/execution/plan/utils.py @@ -42,12 +42,13 @@ def op_execution_error_boundary( as respecting the RetryPolicy if present. """ from dagster._core.execution.context.system import StepExecutionContext + from dagster._utils.error import redact_user_stacktrace_if_enabled check.callable_param(msg_fn, "msg_fn") check.class_param(error_cls, "error_cls", superclass=DagsterUserCodeExecutionError) check.inst_param(step_context, "step_context", StepExecutionContext) - with raise_execution_interrupts(): + with redact_user_stacktrace_if_enabled(), raise_execution_interrupts(): step_context.log.begin_python_log_capture() retry_policy = step_context.op_retry_policy diff --git a/python_modules/dagster/dagster/_utils/error.py b/python_modules/dagster/dagster/_utils/error.py index 291bce2f0b3bd..280fb5aaaf9ad 100644 --- a/python_modules/dagster/dagster/_utils/error.py +++ b/python_modules/dagster/dagster/_utils/error.py @@ -1,14 +1,18 @@ +import contextlib import logging import os +import sys import traceback import uuid -from collections.abc import Sequence +from collections.abc import Mapping, Sequence +from contextvars import ContextVar from types import TracebackType from typing import Any, NamedTuple, Optional, Union from typing_extensions import TypeAlias import dagster._check as check +from dagster._core.errors import DagsterUserCodeExecutionError from dagster._serdes import whitelist_for_serdes @@ -93,6 +97,45 @@ def _should_redact_user_code_error() -> bool: "DAGSTER_REDACTED_ERROR_LOGGER_NAME", "dagster.redacted_errors" ) +error_id_by_exception: ContextVar[Mapping[int, str]] = ContextVar( + "error_id_by_exception", default={} +) + + +@contextlib.contextmanager +def redact_user_stacktrace_if_enabled(): + """Context manager which, if a user has enabled redacting user code errors, logs exceptions raised from within, + and clears the stacktrace from the exception. It also marks the exception to be redacted if it was to be persisted + or otherwise serialized to be sent to Dagster Plus. This is useful for preventing sensitive information from + being leaked in error messages. + """ + if not _should_redact_user_code_error(): + yield + else: + try: + yield + except BaseException as e: + exc_info = sys.exc_info() + + # Generate a unique error ID for this error, or re-use an existing one + # if this error has already been seen + existing_error_id = error_id_by_exception.get().get(id(e)) + + if not existing_error_id: + error_id = str(uuid.uuid4()) + + # Track the error ID for this exception so we can redact it later + error_id_by_exception.set({**error_id_by_exception.get(), id(e): error_id}) + masked_logger = logging.getLogger(_REDACTED_ERROR_LOGGER_NAME) + + masked_logger.error( + f"Error occurred during user code execution, error ID {error_id}", + exc_info=exc_info, + ) + + # Redact the stacktrace to ensure it will not be passed to Dagster Plus + raise e.with_traceback(None) from None + def serializable_error_info_from_exc_info( exc_info: ExceptionInfo, @@ -116,27 +159,37 @@ def serializable_error_info_from_exc_info( e = check.not_none(e, additional_message=additional_message) tb = check.not_none(tb, additional_message=additional_message) - from dagster._core.errors import DagsterUserCodeExecutionError, DagsterUserCodeProcessError - - if isinstance(e, DagsterUserCodeExecutionError) and _should_redact_user_code_error(): - error_id = str(uuid.uuid4()) - masked_logger = logging.getLogger(_REDACTED_ERROR_LOGGER_NAME) - - masked_logger.error( - f"Error occurred during user code execution, error ID {error_id}", - exc_info=exc_info, - ) - return SerializableErrorInfo( - message=( - f"Error occurred during user code execution, error ID {error_id}. " - "The error has been masked to prevent leaking sensitive information. " - "Search in logs for this error ID for more details." - ), - stack=[], - cls_name="DagsterRedactedUserCodeError", - cause=None, - context=None, - ) + from dagster._core.errors import DagsterUserCodeProcessError + + err_id = error_id_by_exception.get().get(id(e)) + if err_id: + if isinstance(e, DagsterUserCodeExecutionError): + return SerializableErrorInfo( + message=( + f"Error occurred during user code execution, error ID {err_id}. " + "The error has been masked to prevent leaking sensitive information. " + "Search in logs for this error ID for more details." + ), + stack=[], + cls_name="DagsterRedactedUserCodeError", + cause=None, + context=None, + ) + else: + tb_exc = traceback.TracebackException(exc_type, e, tb) + error_info = _serializable_error_info_from_tb(tb_exc) + return SerializableErrorInfo( + message=error_info.message + + ( + f"Error ID {err_id}. " + "The error has been masked to prevent leaking sensitive information. " + "Search in logs for this error ID for more details." + ), + stack=[], + cls_name=error_info.cls_name, + cause=None, + context=None, + ) if ( hoist_user_code_error diff --git a/python_modules/dagster/dagster_tests/core_tests/test_mask_user_code_errors.py b/python_modules/dagster/dagster_tests/core_tests/test_mask_user_code_errors.py index f51fa894a6149..a7fe9d16acf1a 100644 --- a/python_modules/dagster/dagster_tests/core_tests/test_mask_user_code_errors.py +++ b/python_modules/dagster/dagster_tests/core_tests/test_mask_user_code_errors.py @@ -2,12 +2,18 @@ import sys import time import traceback -from typing import Any +from typing import Any, Callable import pytest from dagster import Config, RunConfig, config_mapping, job, op +from dagster._core.definitions.events import Failure from dagster._core.definitions.timestamp import TimestampWithTimezone -from dagster._core.errors import DagsterUserCodeProcessError +from dagster._core.errors import ( + DagsterExecutionInterruptedError, + DagsterUserCodeExecutionError, + DagsterUserCodeProcessError, + user_code_error_boundary, +) from dagster._core.test_utils import environ, instance_for_test from dagster._utils.error import ( _serializable_error_info_from_tb, @@ -30,16 +36,107 @@ def __init__(self): ) +class hunter2: + pass + + @pytest.fixture(scope="function") def enable_masking_user_code_errors() -> Any: with environ({"DAGSTER_REDACT_USER_CODE_ERRORS": "1"}): yield -def test_masking_op_execution(enable_masking_user_code_errors) -> Any: +def test_masking_basic(enable_masking_user_code_errors): + try: + with user_code_error_boundary( + error_cls=DagsterUserCodeExecutionError, + msg_fn=lambda: "hunter2", + ): + + def hunter2(): + raise UserError() + + hunter2() + except Exception: + exc_info = sys.exc_info() + err_info = serializable_error_info_from_exc_info(exc_info) + + assert "hunter2" not in str(err_info) + + +def test_masking_nested_user_code_err_boundaries(enable_masking_user_code_errors): + try: + with user_code_error_boundary( + error_cls=DagsterUserCodeExecutionError, + msg_fn=lambda: "hunter2 as well", + ): + with user_code_error_boundary( + error_cls=DagsterUserCodeExecutionError, + msg_fn=lambda: "hunter2", + ): + + def hunter2(): + raise UserError() + + hunter2() + except Exception: + exc_info = sys.exc_info() + err_info = serializable_error_info_from_exc_info(exc_info) + + assert "hunter2" not in str(err_info) + + +def test_masking_nested_user_code_err_boundaries_reraise(enable_masking_user_code_errors): + try: + try: + with user_code_error_boundary( + error_cls=DagsterUserCodeExecutionError, + msg_fn=lambda: "hunter2", + ): + + def hunter2(): + raise UserError() + + hunter2() + except Exception as e: + # Mimics behavior of resource teardown, which runs in a + # user_code_error_boundary after the user code raises an error + with user_code_error_boundary( + error_cls=DagsterUserCodeExecutionError, + msg_fn=lambda: "teardown after we raised hunter2 error", + ): + # do teardown stuff + raise e + + except Exception: + exc_info = sys.exc_info() + err_info = serializable_error_info_from_exc_info(exc_info) + + assert "hunter2" not in str(err_info) + + +@pytest.mark.parametrize( + "exc_name, expect_exc_name_in_error, build_exc", + [ + ("UserError", False, lambda: UserError()), + ("TypeError", False, lambda: TypeError("hunter2")), + ("KeyboardInterrupt", True, lambda: KeyboardInterrupt()), + ("DagsterExecutionInterruptedError", True, lambda: DagsterExecutionInterruptedError()), + ("Failure", True, lambda: Failure("asdf")), + ], +) +def test_masking_op_execution( + enable_masking_user_code_errors, + exc_name: str, + expect_exc_name_in_error: bool, + build_exc: Callable[[], BaseException], +) -> Any: @op def throws_user_error(_): - raise UserError() + def hunter2(): + raise build_exc() + + hunter2() @job def job_def(): @@ -47,12 +144,24 @@ def job_def(): result = job_def.execute_in_process(raise_on_error=False) assert not result.success - assert not any("hunter2" in str(event) for event in result.all_events) + + # Ensure error message and contents of user code don't leak (e.g. hunter2 text or function name) + assert not any("hunter2" in str(event).lower() for event in result.all_events), [ + str(event) for event in result.all_events if "hunter2" in str(event) + ] + step_error = next(event for event in result.all_events if event.is_step_failure) - assert ( - step_error.step_failure_data.error - and step_error.step_failure_data.error.cls_name == "DagsterRedactedUserCodeError" - ) + + if expect_exc_name_in_error: + assert ( + step_error.step_failure_data.error + and step_error.step_failure_data.error.cls_name == exc_name + ) + else: + assert ( + step_error.step_failure_data.error + and step_error.step_failure_data.error.cls_name == "DagsterRedactedUserCodeError" + ) ERROR_ID_REGEX = r"Error occurred during user code execution, error ID ([a-z0-9\-]+)"