diff --git a/python_modules/dagster/dagster/_core/storage/cloud_storage_compute_log_manager.py b/python_modules/dagster/dagster/_core/storage/cloud_storage_compute_log_manager.py index 94cfbc1706f5c..b640c3940d540 100644 --- a/python_modules/dagster/dagster/_core/storage/cloud_storage_compute_log_manager.py +++ b/python_modules/dagster/dagster/_core/storage/cloud_storage_compute_log_manager.py @@ -1,5 +1,6 @@ import json import os +import sys import threading import time from abc import abstractmethod @@ -19,6 +20,7 @@ IO_TYPE_EXTENSION, LocalComputeLogManager, ) +from dagster._utils.error import serializable_error_info_from_exc_info SUBSCRIPTION_POLLING_INTERVAL = 5 @@ -87,6 +89,12 @@ def open_log_stream( def _on_capture_complete(self, log_key: Sequence[str]): self.upload_to_cloud_storage(log_key, ComputeIOType.STDOUT) self.upload_to_cloud_storage(log_key, ComputeIOType.STDERR) + try: + self.local_manager.delete_logs(log_key=log_key) + except Exception: + sys.stderr.write( + f"Exception deleting local logs after capture complete: {serializable_error_info_from_exc_info(sys.exc_info())}\n" + ) def is_capture_complete(self, log_key: Sequence[str]) -> bool: if self.local_manager.is_capture_complete(log_key): diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py index d5d4ee55ca5f1..bce5fbad87e6e 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py @@ -69,6 +69,13 @@ def simple(): event = capture_events[0] file_key = event.logs_captured_data.file_key log_key = manager.build_log_key_for_run(result.run_id, file_key) + + # verify locally cached logs are deleted after they are captured + local_path = manager._local_manager.get_captured_local_path( # noqa: SLF001 + log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT] + ) + assert not os.path.exists(local_path) + log_data = manager.get_log_data(log_key) stdout = log_data.stdout.decode("utf-8") # pyright: ignore[reportOptionalMemberAccess] assert stdout == HELLO_WORLD + SEPARATOR @@ -85,16 +92,11 @@ def simple(): for expected in EXPECTED_LOGS: assert expected in stderr_s3 - # Check download behavior by deleting locally cached logs - local_dir = os.path.dirname( - manager._local_manager.get_captured_local_path( # noqa: SLF001 - log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT] - ) - ) - for filename in os.listdir(local_dir): - os.unlink(os.path.join(local_dir, filename)) - log_data = manager.get_log_data(log_key) + + # Re-downloads the data to the local filesystem again + assert os.path.exists(local_path) + stdout = log_data.stdout.decode("utf-8") # pyright: ignore[reportOptionalMemberAccess] assert stdout == HELLO_WORLD + SEPARATOR