Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More aggressively mask user code errors when masking enabled #27183

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python_modules/dagster/dagster/_core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,9 @@ def user_code_error_boundary(
"""
check.callable_param(msg_fn, "msg_fn")
check.class_param(error_cls, "error_cls", superclass=DagsterUserCodeExecutionError)
from dagster._utils.error import redact_user_stacktrace_if_enabled

with raise_execution_interrupts():
with redact_user_stacktrace_if_enabled(), raise_execution_interrupts():
if log_manager:
log_manager.begin_python_log_capture()
try:
Expand Down
3 changes: 2 additions & 1 deletion python_modules/dagster/dagster/_core/execution/plan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ def op_execution_error_boundary(
as respecting the RetryPolicy if present.
"""
from dagster._core.execution.context.system import StepExecutionContext
from dagster._utils.error import redact_user_stacktrace_if_enabled

check.callable_param(msg_fn, "msg_fn")
check.class_param(error_cls, "error_cls", superclass=DagsterUserCodeExecutionError)
check.inst_param(step_context, "step_context", StepExecutionContext)

with raise_execution_interrupts():
with redact_user_stacktrace_if_enabled(), raise_execution_interrupts():
step_context.log.begin_python_log_capture()
retry_policy = step_context.op_retry_policy

Expand Down
97 changes: 75 additions & 22 deletions python_modules/dagster/dagster/_utils/error.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import contextlib
import logging
import os
import sys
import traceback
import uuid
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from contextvars import ContextVar
from types import TracebackType
from typing import Any, NamedTuple, Optional, Union

from typing_extensions import TypeAlias

import dagster._check as check
from dagster._core.errors import DagsterUserCodeExecutionError
from dagster._serdes import whitelist_for_serdes


Expand Down Expand Up @@ -93,6 +97,45 @@ def _should_redact_user_code_error() -> bool:
"DAGSTER_REDACTED_ERROR_LOGGER_NAME", "dagster.redacted_errors"
)

error_id_by_exception: ContextVar[Mapping[int, str]] = ContextVar(
"error_id_by_exception", default={}
)
Comment on lines +100 to +102
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name here should probably have 'redacted' in it: redacted_user_code_error_id_by_exception?



@contextlib.contextmanager
def redact_user_stacktrace_if_enabled():
"""Context manager which, if a user has enabled redacting user code errors, logs exceptions raised from within,
and clears the stacktrace from the exception. It also marks the exception to be redacted if it was to be persisted
or otherwise serialized to be sent to Dagster Plus. This is useful for preventing sensitive information from
being leaked in error messages.
"""
if not _should_redact_user_code_error():
yield
else:
try:
yield
except BaseException as e:
exc_info = sys.exc_info()

# Generate a unique error ID for this error, or re-use an existing one
# if this error has already been seen
existing_error_id = error_id_by_exception.get().get(id(e))

if not existing_error_id:
error_id = str(uuid.uuid4())

# Track the error ID for this exception so we can redact it later
error_id_by_exception.set({**error_id_by_exception.get(), id(e): error_id})
Comment on lines +100 to +128
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the way you are using the context var here is equivalent to just having a process global dict. What exactly is the intention here and do any of the existing tests validate that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the goal is to increase the set of cases that this logic can handle (all kinds of exceptions besides DagsterUserCodeExecutionError can be emitted within user code, like KeyboardInterrupt or SystemExit or other DagsterError subclasses) while still only triggering the redaction if the exception was actually raised within an op_execution_error_boundary or user_code_error_boundary. I don't have a strong opinion about global dict vs. contextvar.

masked_logger = logging.getLogger(_REDACTED_ERROR_LOGGER_NAME)

masked_logger.error(
f"Error occurred during user code execution, error ID {error_id}",
exc_info=exc_info,
)

# Redact the stacktrace to ensure it will not be passed to Dagster Plus
raise e.with_traceback(None) from None


def serializable_error_info_from_exc_info(
exc_info: ExceptionInfo,
Expand All @@ -116,27 +159,37 @@ def serializable_error_info_from_exc_info(
e = check.not_none(e, additional_message=additional_message)
tb = check.not_none(tb, additional_message=additional_message)

from dagster._core.errors import DagsterUserCodeExecutionError, DagsterUserCodeProcessError

if isinstance(e, DagsterUserCodeExecutionError) and _should_redact_user_code_error():
error_id = str(uuid.uuid4())
masked_logger = logging.getLogger(_REDACTED_ERROR_LOGGER_NAME)

masked_logger.error(
f"Error occurred during user code execution, error ID {error_id}",
exc_info=exc_info,
)
return SerializableErrorInfo(
message=(
f"Error occurred during user code execution, error ID {error_id}. "
"The error has been masked to prevent leaking sensitive information. "
"Search in logs for this error ID for more details."
),
stack=[],
cls_name="DagsterRedactedUserCodeError",
cause=None,
context=None,
)
from dagster._core.errors import DagsterUserCodeProcessError

err_id = error_id_by_exception.get().get(id(e))
if err_id:
if isinstance(e, DagsterUserCodeExecutionError):
return SerializableErrorInfo(
message=(
f"Error occurred during user code execution, error ID {err_id}. "
"The error has been masked to prevent leaking sensitive information. "
"Search in logs for this error ID for more details."
),
stack=[],
cls_name="DagsterRedactedUserCodeError",
cause=None,
context=None,
)
else:
tb_exc = traceback.TracebackException(exc_type, e, tb)
Comment on lines +166 to +179
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth explaining the difference between these two cases: with user code errors, you don't even want to show the message, but with other errors (framework errors, interrupts, or Failure / RetryRequested raised within the error boundary), the message is not sensitive and can be displayed for clarity, while the traceback is still sensitive and must be hidden.

error_info = _serializable_error_info_from_tb(tb_exc)
return SerializableErrorInfo(
message=error_info.message
+ (
f"Error ID {err_id}. "
"The error has been masked to prevent leaking sensitive information. "
"Search in logs for this error ID for more details."
),
stack=[],
cls_name=error_info.cls_name,
cause=None,
context=None,
)

if (
hoist_user_code_error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
import sys
import time
import traceback
from typing import Any
from typing import Any, Callable

import pytest
from dagster import Config, RunConfig, config_mapping, job, op
from dagster._core.definitions.events import Failure
from dagster._core.definitions.timestamp import TimestampWithTimezone
from dagster._core.errors import DagsterUserCodeProcessError
from dagster._core.errors import (
DagsterExecutionInterruptedError,
DagsterUserCodeExecutionError,
DagsterUserCodeProcessError,
user_code_error_boundary,
)
from dagster._core.test_utils import environ, instance_for_test
from dagster._utils.error import (
_serializable_error_info_from_tb,
Expand All @@ -30,29 +36,132 @@ def __init__(self):
)


class hunter2:
    """Empty placeholder class named after the ``hunter2`` sentinel string.

    NOTE(review): its usage is not visible in this excerpt; presumably it
    serves as a user-defined type whose name must not leak into serialized
    errors -- confirm against the rest of the test module.
    """

    pass


@pytest.fixture(scope="function")
def enable_masking_user_code_errors() -> Any:
    """Run the requesting test with ``DAGSTER_REDACT_USER_CODE_ERRORS`` set to ``"1"``.

    The variable is applied through the ``environ`` context manager, which
    wraps the ``yield`` so the override only spans the body of the test.
    """
    redaction_env = {"DAGSTER_REDACT_USER_CODE_ERRORS": "1"}
    with environ(redaction_env):
        yield


def test_masking_op_execution(enable_masking_user_code_errors) -> Any:
def test_masking_basic(enable_masking_user_code_errors):
    """Check that an error raised inside one ``user_code_error_boundary`` serializes
    without the sentinel ``hunter2``.

    ``hunter2`` is both the boundary's ``msg_fn`` text and the name of the
    nested function that raises, so a leaked message or stack frame would
    reintroduce it into the serialized error.
    """
    try:
        with user_code_error_boundary(
            error_cls=DagsterUserCodeExecutionError,
            msg_fn=lambda: "hunter2",
        ):

            # The function name is deliberate: it would show up in any
            # unredacted stack trace.
            def hunter2():
                raise UserError()

            hunter2()
    except Exception:
        # Serialize while the exception is still being handled so that
        # sys.exc_info() refers to it.
        serialized_text = str(serializable_error_info_from_exc_info(sys.exc_info()))

    assert "hunter2" not in serialized_text


def test_masking_nested_user_code_err_boundaries(enable_masking_user_code_errors):
    """Check that an error crossing two stacked ``user_code_error_boundary``
    contexts serializes without the sentinel ``hunter2``.

    Both boundaries carry a ``msg_fn`` text containing ``hunter2``, and the
    raising function is itself named ``hunter2``.
    """
    try:
        # One ``with`` with two managers enters them left to right, exactly
        # like two nested ``with`` statements (outer first, inner second).
        with user_code_error_boundary(
            error_cls=DagsterUserCodeExecutionError,
            msg_fn=lambda: "hunter2 as well",
        ), user_code_error_boundary(
            error_cls=DagsterUserCodeExecutionError,
            msg_fn=lambda: "hunter2",
        ):

            def hunter2():
                raise UserError()

            hunter2()
    except Exception:
        serialized_text = str(serializable_error_info_from_exc_info(sys.exc_info()))

    assert "hunter2" not in serialized_text


def test_masking_nested_user_code_err_boundaries_reraise(enable_masking_user_code_errors):
    """Check that an error caught from one boundary and re-raised inside a second
    boundary serializes without the sentinel ``hunter2``.

    The first boundary wraps the raising ``hunter2`` function; the caught
    exception object is then raised again from within a second boundary.
    """
    try:
        try:
            with user_code_error_boundary(
                error_cls=DagsterUserCodeExecutionError,
                msg_fn=lambda: "hunter2",
            ):

                def hunter2():
                    raise UserError()

                hunter2()
        except Exception as first_boundary_error:
            # Mimics behavior of resource teardown, which runs in a
            # user_code_error_boundary after the user code raises an error
            with user_code_error_boundary(
                error_cls=DagsterUserCodeExecutionError,
                msg_fn=lambda: "teardown after we raised hunter2 error",
            ):
                # do teardown stuff, then surface the very same exception object
                raise first_boundary_error

    except Exception:
        serialized_text = str(serializable_error_info_from_exc_info(sys.exc_info()))

    assert "hunter2" not in serialized_text


@pytest.mark.parametrize(
"exc_name, expect_exc_name_in_error, build_exc",
[
("UserError", False, lambda: UserError()),
("TypeError", False, lambda: TypeError("hunter2")),
("KeyboardInterrupt", True, lambda: KeyboardInterrupt()),
("DagsterExecutionInterruptedError", True, lambda: DagsterExecutionInterruptedError()),
("Failure", True, lambda: Failure("asdf")),
],
)
def test_masking_op_execution(
enable_masking_user_code_errors,
exc_name: str,
expect_exc_name_in_error: bool,
build_exc: Callable[[], BaseException],
) -> Any:
@op
def throws_user_error(_):
raise UserError()
def hunter2():
raise build_exc()

hunter2()

@job
def job_def():
throws_user_error()

result = job_def.execute_in_process(raise_on_error=False)
assert not result.success
assert not any("hunter2" in str(event) for event in result.all_events)

# Ensure error message and contents of user code don't leak (e.g. hunter2 text or function name)
assert not any("hunter2" in str(event).lower() for event in result.all_events), [
str(event) for event in result.all_events if "hunter2" in str(event)
]

step_error = next(event for event in result.all_events if event.is_step_failure)
assert (
step_error.step_failure_data.error
and step_error.step_failure_data.error.cls_name == "DagsterRedactedUserCodeError"
)

if expect_exc_name_in_error:
assert (
step_error.step_failure_data.error
and step_error.step_failure_data.error.cls_name == exc_name
)
else:
assert (
step_error.step_failure_data.error
and step_error.step_failure_data.error.cls_name == "DagsterRedactedUserCodeError"
)


ERROR_ID_REGEX = r"Error occurred during user code execution, error ID ([a-z0-9\-]+)"
Expand Down
Loading