diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index a4f94d541d..0fd97c4008 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -274,7 +274,13 @@ def __init__( is_remote_folder = backend != '' if is_remote_folder: # If uploading to a remote path, use a temporary directory to save local checkpoints. - local_folder = os.path.join(tempfile.mkdtemp(), local_folder) + root_temp_folder = None + if os.environ.get('TMPDIR') is not None: + root_temp_folder = os.environ.get('TMPDIR') + elif os.path.exists('/local_disk0/'): # Probably we are on MLR, so we have to use /local_disk0 + root_temp_folder = '/local_disk0/temp' + os.makedirs(root_temp_folder, exist_ok=True) + local_folder = os.path.join(tempfile.mkdtemp(prefix=root_temp_folder), local_folder) filename = str(filename) remote_file_name = str(remote_file_name) if remote_file_name is not None else None