From e0d5079f9cd5229c35c4780031bfb1335eaf0226 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Jun 2025 10:25:45 +0530 Subject: [PATCH 01/12] start implementing disk offloading in group. --- go.diff | 218 ++++++++++++++++++++++++ src/diffusers/hooks/group_offloading.py | 99 ++++++++++- 2 files changed, 313 insertions(+), 4 deletions(-) create mode 100644 go.diff diff --git a/go.diff b/go.diff new file mode 100644 index 000000000000..7640b6b78c1e --- /dev/null +++ b/go.diff @@ -0,0 +1,218 @@ +diff --git a/diffusers/hooks/offload.py b/diffusers/hooks/offload.py +--- a/diffusers/hooks/offload.py ++++ b/diffusers/hooks/offload.py +@@ -1,6 +1,10 @@ + import os +-import torch ++import torch ++from safetensors.torch import save_file, load_file + ++import os + from typing import Optional, Union + from torch import nn + from .module_group import ModuleGroup +@@ -25,6 +29,32 @@ from .hooks import HookRegistry + from .hooks import GroupOffloadingHook, LazyPrefetchGroupOffloadingHook + ++# ------------------------------------------------------------------------------- ++# Helpers for disk/NVMe offload using safetensors ++# ------------------------------------------------------------------------------- ++def _offload_tensor_to_disk_st(tensor: torch.Tensor, path: str) -> None: ++ """ ++ Serialize a tensor out to disk in safetensors format. ++ We pin the CPU copy so that non_blocking loads can overlap copy/compute. ++ """ ++ os.makedirs(os.path.dirname(path), exist_ok=True) ++ cpu_t = tensor.detach().cpu().pin_memory() ++ save_file({"0": cpu_t}, path) ++ # free the original GPU tensor immediately ++ del tensor ++ ++def _load_tensor_from_disk_st( ++ path: str, device: torch.device, non_blocking: bool ++) -> torch.Tensor: ++ """ ++ Load a tensor back in with safetensors. ++ - If non_blocking on CUDA: load to CPU pinned memory, then .to(cuda, non_blocking=True). ++ - Otherwise: direct load_file(device=...). ++ """ ++ # fast path: direct to target device ++ if not (non_blocking and device.type == "cuda"): ++ data = load_file(path, device=device) ++ return data["0"] ++ # pinned-CPU fallback for true non-blocking ++ data = load_file(path, device="cpu") ++ cpu_t = data["0"] ++ return cpu_t.to(device, non_blocking=True) ++ ++ + def apply_group_offloading( + module: torch.nn.Module, + onload_device: torch.device, +- offload_device: torch.device = torch.device("cpu"), +- offload_type: str = "block_level", ++ offload_device: torch.device = torch.device("cpu"), ++ *, ++ offload_to_disk: bool = False, ++ offload_path: Optional[str] = None, ++ offload_type: str = "block_level", + num_blocks_per_group: Optional[int] = None, + non_blocking: bool = False, + use_stream: bool = False, +@@ -37,6 +67,10 @@ def apply_group_offloading( + Example: + ```python + >>> apply_group_offloading(... ) ++ # to store params on NVMe: ++ >>> apply_group_offloading( ++ ... model, ++ ... onload_device=torch.device("cuda"), ++ ... offload_to_disk=True, ++ ... offload_path="/mnt/nvme1/offload", ++ ... offload_type="block_level", ++ ... num_blocks_per_group=1, ++ ... 
) + ``` + """ + +@@ -69,6 +103,10 @@ def apply_group_offloading( + if num_blocks_per_group is None: + raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") ++ if offload_to_disk and offload_path is None: ++ raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") + + _apply_group_offloading_block_level( + module=module, ++ offload_to_disk=offload_to_disk, ++ offload_path=offload_path, + num_blocks_per_group=num_blocks_per_group, + offload_device=offload_device, + onload_device=onload_device, +@@ -79,6 +117,11 @@ def apply_group_offloading( + elif offload_type == "leaf_level": ++ if offload_to_disk and offload_path is None: ++ raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") + _apply_group_offloading_leaf_level( + module=module, ++ offload_to_disk=offload_to_disk, ++ offload_path=offload_path, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, +@@ -107,10 +150,16 @@ def _apply_group_offloading_block_level( + """ +- module: torch.nn.Module, +- num_blocks_per_group: int, +- offload_device: torch.device, +- onload_device: torch.device, ++ module: torch.nn.Module, ++ num_blocks_per_group: int, ++ offload_device: torch.device, ++ offload_to_disk: bool, ++ offload_path: Optional[str], ++ onload_device: torch.device, + non_blocking: bool, + stream: Union[torch.cuda.Stream, torch.Stream, None] = None, + record_stream: Optional[bool] = False, + low_cpu_mem_usage: bool = False, + ) -> None: +@@ -138,7 +187,9 @@ def _apply_group_offloading_block_level( + for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] + group = ModuleGroup( +- modules=current_modules, ++ modules=current_modules, ++ offload_to_disk=offload_to_disk, ++ offload_path=offload_path, + offload_device=offload_device, + onload_device=onload_device, + offload_leader=current_modules[-1], +@@ -187,10 +238,14 @@ def _apply_group_offloading_block_level( + unmatched_group = ModuleGroup( + modules=unmatched_modules, +- offload_device=offload_device, ++ offload_to_disk=offload_to_disk, ++ offload_path=offload_path, ++ offload_device=offload_device, + onload_device=onload_device, + offload_leader=module, + onload_leader=module, ++ # other args omitted for brevity... 
+ ) + + if stream is None: +@@ -216,10 +271,16 @@ def _apply_group_offloading_leaf_level( + """ +- module: torch.nn.Module, +- offload_device: torch.device, +- onload_device: torch.device, +- non_blocking: bool, ++ module: torch.nn.Module, ++ offload_device: torch.device, ++ offload_to_disk: bool, ++ offload_path: Optional[str], ++ onload_device: torch.device, ++ non_blocking: bool, + stream: Union[torch.cuda.Stream, torch.Stream, None] = None, + record_stream: Optional[bool] = False, + low_cpu_mem_usage: bool = False, + ) -> None: +@@ -229,7 +290,9 @@ def _apply_group_offloading_leaf_level( + for name, submodule in module.named_modules(): + if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): + continue +- group = ModuleGroup( ++ group = ModuleGroup( ++ offload_to_disk=offload_to_disk, ++ offload_path=offload_path, + modules=[submodule], + offload_device=offload_device, + onload_device=onload_device, +@@ -317,10 +380,14 @@ def _apply_group_offloading_leaf_level( + parent_module = module_dict[name] + assert getattr(parent_module, "_diffusers_hook", None) is None +- group = ModuleGroup( ++ group = ModuleGroup( ++ offload_to_disk=offload_to_disk, ++ offload_path=offload_path, + modules=[], + offload_device=offload_device, + onload_device=onload_device, ++ # additional args omitted for brevity... + ) + _apply_group_offloading_hook(parent_module, group, None) + +@@ -360,6 +427,38 @@ def _apply_lazy_group_offloading_hook( + registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) + + ++# ------------------------------------------------------------------------------- ++# Patch GroupOffloadingHook to use safetensors disk offload ++# ------------------------------------------------------------------------------- ++class GroupOffloadingHook: ++ def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup]): ++ self.group = group ++ self.next_group = next_group ++ # map param/buffer name -> file path ++ self.param_to_path: Dict[str,str] = {} ++ self.buffer_to_path: Dict[str,str] = {} ++ ++ def offload_parameters(self, module: nn.Module): ++ for name, param in module.named_parameters(recurse=False): ++ if self.group.offload_to_disk: ++ path = os.path.join(self.group.offload_path, f"{module.__class__.__name__}__{name}.safetensors") ++ _offload_tensor_to_disk_st(param.data, path) ++ self.param_to_path[name] = path ++ else: ++ param.data = param.data.to(self.group.offload_device, non_blocking=self.group.non_blocking) ++ ++ def onload_parameters(self, module: nn.Module): ++ for name, param in module.named_parameters(recurse=False): ++ if self.group.offload_to_disk: ++ path = self.param_to_path[name] ++ param.data = _load_tensor_from_disk_st(path, self.group.onload_device, self.group.non_blocking) ++ else: ++ param.data = param.data.to(self.group.onload_device, non_blocking=self.group.non_blocking) ++ ++ # analogous changes for buffers... 
++ diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 565f8f1ff860..a31acb5a2d6f 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -14,9 +14,10 @@ from contextlib import contextmanager, nullcontext from typing import Dict, List, Optional, Set, Tuple, Union +import os import torch - +import safetensors.torch from ..utils import get_logger, is_accelerate_available from .hooks import HookRegistry, ModelHook @@ -59,6 +60,8 @@ def __init__( record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, onload_self: bool = True, + offload_to_disk: bool = False, + offload_path: Optional[str] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -72,7 +75,29 @@ def __init__( self.record_stream = record_stream self.onload_self = onload_self self.low_cpu_mem_usage = low_cpu_mem_usage - self.cpu_param_dict = self._init_cpu_param_dict() + + self.offload_to_disk = offload_to_disk + self.offload_path = offload_path + self._is_offloaded_to_disk = False + + if self.offload_to_disk: + if self.offload_path is None: + raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") + self.safetensors_file_path = os.path.join(self.offload_path, f"group_{id(self)}.safetensors") + + all_tensors = [] + for module in self.modules: + all_tensors.extend(list(module.parameters())) + all_tensors.extend(list(module.buffers())) + all_tensors.extend(self.parameters) + all_tensors.extend(self.buffers) + all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates + + self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} + self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} + self.cpu_param_dict = {} + else: + self.cpu_param_dict = self._init_cpu_param_dict() if self.stream is None and self.record_stream: raise ValueError("`record_stream` cannot be True when `stream` is None.") @@ -124,6 +149,29 @@ def onload_(self): context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream) current_stream = torch_accelerator_module.current_stream() if self.record_stream else None + if self.offload_to_disk: + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + with context: + if self.stream is not None: + # Load to CPU, pin, and async copy to device for overlapping transfer and compute + loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") + for key, tensor_obj in self.key_to_tensor.items(): + pinned_tensor = loaded_cpu_tensors[key].pin_memory() + tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + tensor_obj.data.record_stream(current_stream) + else: + # Load directly to the target device (synchronous) + loaded_tensors = safetensors.torch.load_file( + self.safetensors_file_path, device=self.onload_device + ) + for key, tensor_obj in self.key_to_tensor.items(): + tensor_obj.data = loaded_tensors[key] + return + if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -169,6 +217,18 @@ def onload_(self): @torch.compiler.disable() def offload_(self): r"""Offloads the group of modules to the offload_device.""" + if self.offload_to_disk: + if not self._is_offloaded_to_disk: + os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) + tensors_to_save = { + key: 
tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items() + } + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + self._is_offloaded_to_disk = True + + for tensor_obj in self.tensor_to_key.keys(): + tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + return torch_accelerator_module = ( getattr(torch, torch.accelerator.current_accelerator().type) @@ -208,10 +268,13 @@ class GroupOffloadingHook(ModelHook): def __init__( self, group: ModuleGroup, - next_group: Optional[ModuleGroup] = None, + next_group: Optional[ModuleGroup] = None ) -> None: self.group = group self.next_group = next_group + # map param/buffer name -> file path + self.param_to_path: Dict[str,str] = {} + self.buffer_to_path: Dict[str,str] = {} def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.group.offload_leader == module: @@ -358,6 +421,8 @@ def apply_group_offloading( onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), offload_type: str = "block_level", + offload_to_disk: bool = False, + offload_path: Optional[str] = None, num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, @@ -401,6 +466,11 @@ def apply_group_offloading( offload_type (`str`, defaults to "block_level"): The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is "block_level". + offload_to_disk (`bool`, defaults to `False`): + If `True`, offload model parameters to disk instead of CPU RAM. This is useful when CPU memory is limited. + Requires `offload_path` to be set. + offload_path (`str`, *optional*): + The path to the directory where offloaded parameters will be stored when `offload_to_disk` is `True`. num_blocks_per_group (`int`, *optional*): The number of blocks per group when using offload_type="block_level". This is required when using offload_type="block_level". 
@@ -447,6 +517,9 @@ def apply_group_offloading( else: raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.") + if offload_to_disk and offload_path is None: + raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") + _raise_error_if_accelerate_model_or_sequential_hook_present(module) if offload_type == "block_level": @@ -458,6 +531,8 @@ def apply_group_offloading( num_blocks_per_group=num_blocks_per_group, offload_device=offload_device, onload_device=onload_device, + offload_to_disk=offload_to_disk, + offload_path=offload_path, non_blocking=non_blocking, stream=stream, record_stream=record_stream, @@ -468,6 +543,8 @@ def apply_group_offloading( module=module, offload_device=offload_device, onload_device=onload_device, + offload_to_disk=offload_to_disk, + offload_path=offload_path, non_blocking=non_blocking, stream=stream, record_stream=record_stream, @@ -481,6 +558,8 @@ def _apply_group_offloading_block_level( module: torch.nn.Module, num_blocks_per_group: int, offload_device: torch.device, + offload_to_disk: bool, + offload_path: Optional[str], onload_device: torch.device, non_blocking: bool, stream: Union[torch.cuda.Stream, torch.Stream, None] = None, @@ -535,6 +614,8 @@ def _apply_group_offloading_block_level( modules=current_modules, offload_device=offload_device, onload_device=onload_device, + offload_to_disk=offload_to_disk, + offload_path=offload_path, offload_leader=current_modules[-1], onload_leader=current_modules[0], non_blocking=non_blocking, @@ -567,6 +648,8 @@ def _apply_group_offloading_block_level( modules=unmatched_modules, offload_device=offload_device, onload_device=onload_device, + offload_to_disk=offload_to_disk, + offload_path=offload_path, offload_leader=module, onload_leader=module, parameters=parameters, @@ -586,6 +669,8 @@ def _apply_group_offloading_leaf_level( module: torch.nn.Module, offload_device: torch.device, onload_device: torch.device, + offload_to_disk: bool, + offload_path: Optional[str], non_blocking: bool, stream: Union[torch.cuda.Stream, torch.Stream, None] = None, record_stream: Optional[bool] = False, @@ -629,6 +714,8 @@ def _apply_group_offloading_leaf_level( modules=[submodule], offload_device=offload_device, onload_device=onload_device, + offload_to_disk=offload_to_disk, + offload_path=offload_path, offload_leader=submodule, onload_leader=submodule, non_blocking=non_blocking, @@ -675,6 +762,8 @@ def _apply_group_offloading_leaf_level( onload_device=onload_device, offload_leader=parent_module, onload_leader=parent_module, + offload_to_disk=offload_to_disk, + offload_path=offload_path, parameters=parameters, buffers=buffers, non_blocking=non_blocking, @@ -693,6 +782,8 @@ def _apply_group_offloading_leaf_level( modules=[], offload_device=offload_device, onload_device=onload_device, + offload_to_disk=offload_to_disk, + offload_path=offload_path, offload_leader=module, onload_leader=module, parameters=None, @@ -808,4 +899,4 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device: for submodule in module.modules(): if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device - raise ValueError("Group offloading is not enabled for the provided module.") + raise ValueError("Group offloading is not enabled for the provided module.") \ No newline at end of file From 49ac665460110b9771b78801e8de5374c8db9542 Mon Sep 17 00:00:00 2001 From: sayakpaul 
Date: Mon, 9 Jun 2025 10:26:00 +0530 Subject: [PATCH 02/12] delete diff file. --- go.diff | 218 -------------------------------------------------------- 1 file changed, 218 deletions(-) delete mode 100644 go.diff diff --git a/go.diff b/go.diff deleted file mode 100644 index 7640b6b78c1e..000000000000 --- a/go.diff +++ /dev/null @@ -1,218 +0,0 @@ -diff --git a/diffusers/hooks/offload.py b/diffusers/hooks/offload.py ---- a/diffusers/hooks/offload.py -+++ b/diffusers/hooks/offload.py -@@ -1,6 +1,10 @@ - import os --import torch -+import torch -+from safetensors.torch import save_file, load_file - -+import os - from typing import Optional, Union - from torch import nn - from .module_group import ModuleGroup -@@ -25,6 +29,32 @@ from .hooks import HookRegistry - from .hooks import GroupOffloadingHook, LazyPrefetchGroupOffloadingHook - -+# ------------------------------------------------------------------------------- -+# Helpers for disk/NVMe offload using safetensors -+# ------------------------------------------------------------------------------- -+def _offload_tensor_to_disk_st(tensor: torch.Tensor, path: str) -> None: -+ """ -+ Serialize a tensor out to disk in safetensors format. -+ We pin the CPU copy so that non_blocking loads can overlap copy/compute. -+ """ -+ os.makedirs(os.path.dirname(path), exist_ok=True) -+ cpu_t = tensor.detach().cpu().pin_memory() -+ save_file({"0": cpu_t}, path) -+ # free the original GPU tensor immediately -+ del tensor -+ -+def _load_tensor_from_disk_st( -+ path: str, device: torch.device, non_blocking: bool -+) -> torch.Tensor: -+ """ -+ Load a tensor back in with safetensors. -+ - If non_blocking on CUDA: load to CPU pinned memory, then .to(cuda, non_blocking=True). -+ - Otherwise: direct load_file(device=...). -+ """ -+ # fast path: direct to target device -+ if not (non_blocking and device.type == "cuda"): -+ data = load_file(path, device=device) -+ return data["0"] -+ # pinned-CPU fallback for true non-blocking -+ data = load_file(path, device="cpu") -+ cpu_t = data["0"] -+ return cpu_t.to(device, non_blocking=True) -+ -+ - def apply_group_offloading( - module: torch.nn.Module, - onload_device: torch.device, -- offload_device: torch.device = torch.device("cpu"), -- offload_type: str = "block_level", -+ offload_device: torch.device = torch.device("cpu"), -+ *, -+ offload_to_disk: bool = False, -+ offload_path: Optional[str] = None, -+ offload_type: str = "block_level", - num_blocks_per_group: Optional[int] = None, - non_blocking: bool = False, - use_stream: bool = False, -@@ -37,6 +67,10 @@ def apply_group_offloading( - Example: - ```python - >>> apply_group_offloading(... ) -+ # to store params on NVMe: -+ >>> apply_group_offloading( -+ ... model, -+ ... onload_device=torch.device("cuda"), -+ ... offload_to_disk=True, -+ ... offload_path="/mnt/nvme1/offload", -+ ... offload_type="block_level", -+ ... num_blocks_per_group=1, -+ ... 
) - ``` - """ - -@@ -69,6 +103,10 @@ def apply_group_offloading( - if num_blocks_per_group is None: - raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") -+ if offload_to_disk and offload_path is None: -+ raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") - - _apply_group_offloading_block_level( - module=module, -+ offload_to_disk=offload_to_disk, -+ offload_path=offload_path, - num_blocks_per_group=num_blocks_per_group, - offload_device=offload_device, - onload_device=onload_device, -@@ -79,6 +117,11 @@ def apply_group_offloading( - elif offload_type == "leaf_level": -+ if offload_to_disk and offload_path is None: -+ raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") - _apply_group_offloading_leaf_level( - module=module, -+ offload_to_disk=offload_to_disk, -+ offload_path=offload_path, - offload_device=offload_device, - onload_device=onload_device, - non_blocking=non_blocking, -@@ -107,10 +150,16 @@ def _apply_group_offloading_block_level( - """ -- module: torch.nn.Module, -- num_blocks_per_group: int, -- offload_device: torch.device, -- onload_device: torch.device, -+ module: torch.nn.Module, -+ num_blocks_per_group: int, -+ offload_device: torch.device, -+ offload_to_disk: bool, -+ offload_path: Optional[str], -+ onload_device: torch.device, - non_blocking: bool, - stream: Union[torch.cuda.Stream, torch.Stream, None] = None, - record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, - ) -> None: -@@ -138,7 +187,9 @@ def _apply_group_offloading_block_level( - for i in range(0, len(submodule), num_blocks_per_group): - current_modules = submodule[i : i + num_blocks_per_group] - group = ModuleGroup( -- modules=current_modules, -+ modules=current_modules, -+ offload_to_disk=offload_to_disk, -+ offload_path=offload_path, - offload_device=offload_device, - onload_device=onload_device, - offload_leader=current_modules[-1], -@@ -187,10 +238,14 @@ def _apply_group_offloading_block_level( - unmatched_group = ModuleGroup( - modules=unmatched_modules, -- offload_device=offload_device, -+ offload_to_disk=offload_to_disk, -+ offload_path=offload_path, -+ offload_device=offload_device, - onload_device=onload_device, - offload_leader=module, - onload_leader=module, -+ # other args omitted for brevity... 
- ) - - if stream is None: -@@ -216,10 +271,16 @@ def _apply_group_offloading_leaf_level( - """ -- module: torch.nn.Module, -- offload_device: torch.device, -- onload_device: torch.device, -- non_blocking: bool, -+ module: torch.nn.Module, -+ offload_device: torch.device, -+ offload_to_disk: bool, -+ offload_path: Optional[str], -+ onload_device: torch.device, -+ non_blocking: bool, - stream: Union[torch.cuda.Stream, torch.Stream, None] = None, - record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, - ) -> None: -@@ -229,7 +290,9 @@ def _apply_group_offloading_leaf_level( - for name, submodule in module.named_modules(): - if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): - continue -- group = ModuleGroup( -+ group = ModuleGroup( -+ offload_to_disk=offload_to_disk, -+ offload_path=offload_path, - modules=[submodule], - offload_device=offload_device, - onload_device=onload_device, -@@ -317,10 +380,14 @@ def _apply_group_offloading_leaf_level( - parent_module = module_dict[name] - assert getattr(parent_module, "_diffusers_hook", None) is None -- group = ModuleGroup( -+ group = ModuleGroup( -+ offload_to_disk=offload_to_disk, -+ offload_path=offload_path, - modules=[], - offload_device=offload_device, - onload_device=onload_device, -+ # additional args omitted for brevity... - ) - _apply_group_offloading_hook(parent_module, group, None) - -@@ -360,6 +427,38 @@ def _apply_lazy_group_offloading_hook( - registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) - - -+# ------------------------------------------------------------------------------- -+# Patch GroupOffloadingHook to use safetensors disk offload -+# ------------------------------------------------------------------------------- -+class GroupOffloadingHook: -+ def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup]): -+ self.group = group -+ self.next_group = next_group -+ # map param/buffer name -> file path -+ self.param_to_path: Dict[str,str] = {} -+ self.buffer_to_path: Dict[str,str] = {} -+ -+ def offload_parameters(self, module: nn.Module): -+ for name, param in module.named_parameters(recurse=False): -+ if self.group.offload_to_disk: -+ path = os.path.join(self.group.offload_path, f"{module.__class__.__name__}__{name}.safetensors") -+ _offload_tensor_to_disk_st(param.data, path) -+ self.param_to_path[name] = path -+ else: -+ param.data = param.data.to(self.group.offload_device, non_blocking=self.group.non_blocking) -+ -+ def onload_parameters(self, module: nn.Module): -+ for name, param in module.named_parameters(recurse=False): -+ if self.group.offload_to_disk: -+ path = self.param_to_path[name] -+ param.data = _load_tensor_from_disk_st(path, self.group.onload_device, self.group.non_blocking) -+ else: -+ param.data = param.data.to(self.group.onload_device, non_blocking=self.group.non_blocking) -+ -+ # analogous changes for buffers... 
-+ From 278cbc2e474ba027d87202b9b8f7c57f3899ca04 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Jun 2025 12:03:01 +0530 Subject: [PATCH 03/12] updates.patch --- src/diffusers/hooks/group_offloading.py | 23 ++++++++++------------- src/diffusers/models/modeling_utils.py | 20 ++++++++++++-------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index a31acb5a2d6f..ce6f47f67a09 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from contextlib import contextmanager, nullcontext from typing import Dict, List, Optional, Set, Tuple, Union -import os -import torch import safetensors.torch +import torch + from ..utils import get_logger, is_accelerate_available from .hooks import HookRegistry, ModelHook @@ -165,9 +166,10 @@ def onload_(self): tensor_obj.data.record_stream(current_stream) else: # Load directly to the target device (synchronous) - loaded_tensors = safetensors.torch.load_file( - self.safetensors_file_path, device=self.onload_device + onload_device = ( + self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device ) + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) for key, tensor_obj in self.key_to_tensor.items(): tensor_obj.data = loaded_tensors[key] return @@ -265,16 +267,12 @@ class GroupOffloadingHook(ModelHook): _is_stateful = False - def __init__( - self, - group: ModuleGroup, - next_group: Optional[ModuleGroup] = None - ) -> None: + def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None: self.group = group self.next_group = next_group # map param/buffer name -> file path - self.param_to_path: Dict[str,str] = {} - self.buffer_to_path: Dict[str,str] = {} + self.param_to_path: Dict[str, str] = {} + self.buffer_to_path: Dict[str, str] = {} def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.group.offload_leader == module: @@ -516,7 +514,6 @@ def apply_group_offloading( stream = torch.Stream() else: raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.") - if offload_to_disk and offload_path is None: raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") @@ -899,4 +896,4 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device: for submodule in module.modules(): if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device - raise ValueError("Group offloading is not enabled for the provided module.") \ No newline at end of file + raise ValueError("Group offloading is not enabled for the provided module.") diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 55ce0cf79fb9..ce57d17ab0d4 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -543,6 +543,8 @@ def enable_group_offload( onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), offload_type: str = "block_level", + offload_to_disk: bool = False, + offload_path: Optional[str] = None, num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, @@ 
-588,15 +590,17 @@ def enable_group_offload( f"open an issue at https://github.com/huggingface/diffusers/issues." ) apply_group_offloading( - self, - onload_device, - offload_device, - offload_type, - num_blocks_per_group, - non_blocking, - use_stream, - record_stream, + module=self, + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + non_blocking=non_blocking, + use_stream=use_stream, + record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, + offload_to_disk=offload_to_disk, + offload_path=offload_path, ) def save_pretrained( From d8179b10d3fa4bb4e8d1a0d2ff85b36463717f90 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 12 Jun 2025 10:27:48 +0530 Subject: [PATCH 04/12] offload_to_disk_path --- src/diffusers/hooks/group_offloading.py | 59 +++++++++---------------- src/diffusers/models/modeling_utils.py | 6 +-- 2 files changed, 24 insertions(+), 41 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index ce6f47f67a09..c2adb6ab1d4d 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -61,8 +61,7 @@ def __init__( record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, onload_self: bool = True, - offload_to_disk: bool = False, - offload_path: Optional[str] = None, + offload_to_disk_path: Optional[str] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -77,14 +76,11 @@ def __init__( self.onload_self = onload_self self.low_cpu_mem_usage = low_cpu_mem_usage - self.offload_to_disk = offload_to_disk - self.offload_path = offload_path + self.offload_to_disk_path = offload_to_disk_path self._is_offloaded_to_disk = False - if self.offload_to_disk: - if self.offload_path is None: - raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") - self.safetensors_file_path = os.path.join(self.offload_path, f"group_{id(self)}.safetensors") + if self.offload_to_disk_path: + self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors") all_tensors = [] for module in self.modules: @@ -150,7 +146,7 @@ def onload_(self): context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream) current_stream = torch_accelerator_module.current_stream() if self.record_stream else None - if self.offload_to_disk: + if self.offload_to_disk_path: if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -219,7 +215,7 @@ def onload_(self): @torch.compiler.disable() def offload_(self): r"""Offloads the group of modules to the offload_device.""" - if self.offload_to_disk: + if self.offload_to_disk_path: if not self._is_offloaded_to_disk: os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) tensors_to_save = { @@ -419,8 +415,7 @@ def apply_group_offloading( onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), offload_type: str = "block_level", - offload_to_disk: bool = False, - offload_path: Optional[str] = None, + offload_to_disk_path: Optional[str] = None, num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, @@ -464,11 +459,8 @@ def apply_group_offloading( offload_type (`str`, defaults to "block_level"): The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is "block_level". 
- offload_to_disk (`bool`, defaults to `False`): - If `True`, offload model parameters to disk instead of CPU RAM. This is useful when CPU memory is limited. - Requires `offload_path` to be set. - offload_path (`str`, *optional*): - The path to the directory where offloaded parameters will be stored when `offload_to_disk` is `True`. + offload_to_disk_path (`str`, *optional*): + The path to the directory where offloaded parameters will be stored. num_blocks_per_group (`int`, *optional*): The number of blocks per group when using offload_type="block_level". This is required when using offload_type="block_level". @@ -486,6 +478,8 @@ def apply_group_offloading( option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. + (TODO: include example with `offload_to_disk_path`) + Example: ```python >>> from diffusers import CogVideoXTransformer3DModel @@ -514,8 +508,6 @@ def apply_group_offloading( stream = torch.Stream() else: raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.") - if offload_to_disk and offload_path is None: - raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") _raise_error_if_accelerate_model_or_sequential_hook_present(module) @@ -528,8 +520,7 @@ def apply_group_offloading( num_blocks_per_group=num_blocks_per_group, offload_device=offload_device, onload_device=onload_device, - offload_to_disk=offload_to_disk, - offload_path=offload_path, + offload_to_disk_path=offload_to_disk_path, non_blocking=non_blocking, stream=stream, record_stream=record_stream, @@ -540,8 +531,7 @@ def apply_group_offloading( module=module, offload_device=offload_device, onload_device=onload_device, - offload_to_disk=offload_to_disk, - offload_path=offload_path, + offload_to_disk_path=offload_to_disk_path, non_blocking=non_blocking, stream=stream, record_stream=record_stream, @@ -555,8 +545,7 @@ def _apply_group_offloading_block_level( module: torch.nn.Module, num_blocks_per_group: int, offload_device: torch.device, - offload_to_disk: bool, - offload_path: Optional[str], + offload_to_disk_path: Optional[str], onload_device: torch.device, non_blocking: bool, stream: Union[torch.cuda.Stream, torch.Stream, None] = None, @@ -572,6 +561,7 @@ def _apply_group_offloading_block_level( The module to which group offloading is applied. offload_device (`torch.device`): The device to which the group of modules are offloaded. This should typically be the CPU. + offload_to_disk_path: TODO onload_device (`torch.device`): The device to which the group of modules are onloaded. 
non_blocking (`bool`): @@ -611,8 +601,7 @@ def _apply_group_offloading_block_level( modules=current_modules, offload_device=offload_device, onload_device=onload_device, - offload_to_disk=offload_to_disk, - offload_path=offload_path, + offload_to_disk_path=offload_to_disk_path, offload_leader=current_modules[-1], onload_leader=current_modules[0], non_blocking=non_blocking, @@ -645,8 +634,7 @@ def _apply_group_offloading_block_level( modules=unmatched_modules, offload_device=offload_device, onload_device=onload_device, - offload_to_disk=offload_to_disk, - offload_path=offload_path, + offload_to_disk_path=offload_to_disk_path, offload_leader=module, onload_leader=module, parameters=parameters, @@ -666,8 +654,7 @@ def _apply_group_offloading_leaf_level( module: torch.nn.Module, offload_device: torch.device, onload_device: torch.device, - offload_to_disk: bool, - offload_path: Optional[str], + offload_to_disk_path: Optional[str], non_blocking: bool, stream: Union[torch.cuda.Stream, torch.Stream, None] = None, record_stream: Optional[bool] = False, @@ -686,6 +673,7 @@ def _apply_group_offloading_leaf_level( The device to which the group of modules are offloaded. This should typically be the CPU. onload_device (`torch.device`): The device to which the group of modules are onloaded. + offload_to_disk_path: TODO non_blocking (`bool`): If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation and data transfer. @@ -711,8 +699,7 @@ def _apply_group_offloading_leaf_level( modules=[submodule], offload_device=offload_device, onload_device=onload_device, - offload_to_disk=offload_to_disk, - offload_path=offload_path, + offload_to_disk_path=offload_to_disk_path, offload_leader=submodule, onload_leader=submodule, non_blocking=non_blocking, @@ -759,8 +746,7 @@ def _apply_group_offloading_leaf_level( onload_device=onload_device, offload_leader=parent_module, onload_leader=parent_module, - offload_to_disk=offload_to_disk, - offload_path=offload_path, + offload_to_disk_path=offload_to_disk_path, parameters=parameters, buffers=buffers, non_blocking=non_blocking, @@ -779,8 +765,7 @@ def _apply_group_offloading_leaf_level( modules=[], offload_device=offload_device, onload_device=onload_device, - offload_to_disk=offload_to_disk, - offload_path=offload_path, + offload_to_disk_path=offload_to_disk_path, offload_leader=module, onload_leader=module, parameters=None, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index ce57d17ab0d4..c71a8b3b5ae5 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -543,8 +543,7 @@ def enable_group_offload( onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), offload_type: str = "block_level", - offload_to_disk: bool = False, - offload_path: Optional[str] = None, + offload_to_disk_path: Optional[str] = None, num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, @@ -599,8 +598,7 @@ def enable_group_offload( use_stream=use_stream, record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, - offload_to_disk=offload_to_disk, - offload_path=offload_path, + offload_to_disk_path=offload_to_disk_path, ) def save_pretrained( From 4e4842fb0b50aa65de2b88a10c8f98f45a73b9d1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 12 Jun 2025 10:48:57 +0530 Subject: [PATCH 05/12] check if safetensors already exist. 
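Rather than unconditionally re-serializing a group on every offload, the group's safetensors file is now written only when it has not been offloaded in this session and no file already exists at the target path (for example, one left over from a previous run with the same `offload_to_disk_path`). A minimal sketch of the check introduced below; the helper name is illustrative only, while the attribute and path semantics mirror the hunk that follows:

```python
import os


def _should_write_group_to_disk(is_offloaded_to_disk: bool, safetensors_file_path: str) -> bool:
    # Skip the write when the group was already serialized in this session, or when a
    # safetensors file from an earlier run is already present at the target path.
    return not is_offloaded_to_disk and not os.path.exists(safetensors_file_path)
```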
---
 src/diffusers/hooks/group_offloading.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index c2adb6ab1d4d..d00a00713754 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -216,13 +216,20 @@ def onload_(self):
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         if self.offload_to_disk_path:
-            if not self._is_offloaded_to_disk:
+            # TODO: we can potentially optimize this code path by checking if _all_ the desired
+            # safetensor files exist on the disk and, if so, skip this step entirely, reducing IO
+            # overhead. Currently, we just check if the given `safetensors_file_path` exists and, if not,
+            # we perform a write.
+            # Check if the file has been saved in this session or if it already exists on disk.
+            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
                 os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
                 tensors_to_save = {
                     key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                 }
                 safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-            self._is_offloaded_to_disk = True
+
+                # The group is now considered offloaded to disk for the rest of the session.
+                self._is_offloaded_to_disk = True
             for tensor_obj in self.tensor_to_key.keys():
                 tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

From 8029cd7ef09a34208c1941c1cb299f59ab47b2c3 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 12 Jun 2025 11:08:19 +0530
Subject: [PATCH 06/12] add test and clarify.

---
 src/diffusers/hooks/group_offloading.py |  1 +
 tests/models/test_modeling_common.py    | 30 +++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index d00a00713754..1ea60c3f3342 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -231,6 +231,7 @@ def offload_(self):
                 # The group is now considered offloaded to disk for the rest of the session.
                 self._is_offloaded_to_disk = True
+            # We do this to free up the RAM which is still holding on to the tensor data.
for tensor_obj in self.tensor_to_key.keys(): tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) return diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5087bd0094a5..75e04b0a509a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -15,6 +15,7 @@ import copy import gc +import glob import inspect import json import os @@ -1608,6 +1609,35 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) _ = model(**inputs_dict)[0] + @parameterized.expand([(False, "block_level"), (True, "leaf_level")]) + @require_torch_accelerator + @torch.no_grad() + def test_group_offloading_with_disk(self, record_stream, offload_type): + torch.manual_seed(0) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + if not getattr(model, "_supports_group_offloading", True): + return + + torch.manual_seed(0) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.eval() + additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1} + with tempfile.TemporaryDirectory() as tmpdir: + model.enable_group_offload( + torch_device, + offload_type=offload_type, + offload_to_disk_path=tmpdir, + use_stream=True, + record_stream=record_stream, + **additional_kwargs, + ) + has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") + assert has_safetensors + _ = model(**inputs_dict)[0] + def test_auto_model(self, expected_max_diff=5e-5): if self.forward_requires_fresh_args: model = self.model_class(**self.init_dict) From da11656af489861ccad3293db316efa27310c60a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Jun 2025 15:59:48 +0530 Subject: [PATCH 07/12] updates --- src/diffusers/hooks/group_offloading.py | 9 +++------ src/diffusers/models/modeling_utils.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 11d592b73ee6..f9fef12ca08c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -274,9 +274,6 @@ class GroupOffloadingHook(ModelHook): def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None: self.group = group self.next_group = next_group - # map param/buffer name -> file path - self.param_to_path: Dict[str, str] = {} - self.buffer_to_path: Dict[str, str] = {} def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.group.offload_leader == module: @@ -423,12 +420,12 @@ def apply_group_offloading( onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), offload_type: str = "block_level", - offload_to_disk_path: Optional[str] = None, num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, record_stream: bool = False, low_cpu_mem_usage: bool = False, + offload_to_disk_path: Optional[str] = None, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. 
To understand what group offloading is, and @@ -553,12 +550,12 @@ def _apply_group_offloading_block_level( module: torch.nn.Module, num_blocks_per_group: int, offload_device: torch.device, - offload_to_disk_path: Optional[str], onload_device: torch.device, non_blocking: bool, stream: Union[torch.cuda.Stream, torch.Stream, None] = None, record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, + offload_to_disk_path: Optional[str] = None, ) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to @@ -662,11 +659,11 @@ def _apply_group_offloading_leaf_level( module: torch.nn.Module, offload_device: torch.device, onload_device: torch.device, - offload_to_disk_path: Optional[str], non_blocking: bool, stream: Union[torch.cuda.Stream, torch.Stream, None] = None, record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, + offload_to_disk_path: Optional[str] = None, ) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index beaea4805050..5fa04fb2606f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -543,12 +543,12 @@ def enable_group_offload( onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), offload_type: str = "block_level", - offload_to_disk_path: Optional[str] = None, num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, record_stream: bool = False, low_cpu_mem_usage=False, + offload_to_disk_path: Optional[str] = None, ) -> None: r""" Activates group offloading for the current model. From 7d2295567f0f41f32f3925a1f768f4465d89335d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Jun 2025 16:10:39 +0530 Subject: [PATCH 08/12] update todos. --- src/diffusers/hooks/group_offloading.py | 13 +++++++++---- tests/models/test_modeling_common.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index f9fef12ca08c..7f932e9b7db2 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -464,8 +464,9 @@ def apply_group_offloading( offload_type (`str`, defaults to "block_level"): The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is "block_level". - offload_to_disk_path (`str`, *optional*): - The path to the directory where offloaded parameters will be stored. + offload_to_disk_path (`str`, *optional*, defaults to `None`): + The path to the directory where parameters will be offloaded. Setting this option can be useful in limited + RAM environment settings where a reasonable speed-memory trade-off is desired. num_blocks_per_group (`int`, *optional*): The number of blocks per group when using offload_type="block_level". This is required when using offload_type="block_level". @@ -566,7 +567,9 @@ def _apply_group_offloading_block_level( The module to which group offloading is applied. offload_device (`torch.device`): The device to which the group of modules are offloaded. This should typically be the CPU. - offload_to_disk_path: TODO + offload_to_disk_path (`str`, *optional*, defaults to `None`): + The path to the directory where parameters will be offloaded. 
Setting this option can be useful in limited + RAM environment settings where a reasonable speed-memory trade-off is desired. onload_device (`torch.device`): The device to which the group of modules are onloaded. non_blocking (`bool`): @@ -678,7 +681,9 @@ def _apply_group_offloading_leaf_level( The device to which the group of modules are offloaded. This should typically be the CPU. onload_device (`torch.device`): The device to which the group of modules are onloaded. - offload_to_disk_path: TODO + offload_to_disk_path (`str`, *optional*, defaults to `None`): + The path to the directory where parameters will be offloaded. Setting this option can be useful in limited + RAM environment settings where a reasonable speed-memory trade-off is desired. non_blocking (`bool`): If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation and data transfer. diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 32299eb38853..a2ad89101a29 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1720,7 +1720,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): **additional_kwargs, ) has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") - assert has_safetensors + assert has_safetensors, "No safetensors found in the directory." _ = model(**inputs_dict)[0] def test_auto_model(self, expected_max_diff=5e-5): From 90e546ada1e791798176575d16a8c8ddaecd068a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Jun 2025 16:16:16 +0530 Subject: [PATCH 09/12] update more docs. --- docs/source/en/optimization/memory.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 14a3b1f9049c..bc7fac0ee911 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -302,6 +302,13 @@ compute-bound, [group-offloading](#group-offloading) tends to be better. Group o +### Offloading to disk + +Group offloading can consume significant system RAM depending on the model size. In limited RAM environments, +it can be useful to offload to the second memory, instead. You can do this by setting the `offload_to_disk_path` +argument in either of [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer [here](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and +[here](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled. + ## Layerwise casting Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality. 
From 9e9465646f9e9efeb70001f662e7b733388ca330 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Jun 2025 16:24:40 +0530 Subject: [PATCH 10/12] add nvme save --- src/diffusers/utils/deep_nvme_utils.py | 1843 ++++++++++++++++++++++++ 1 file changed, 1843 insertions(+) create mode 100644 src/diffusers/utils/deep_nvme_utils.py diff --git a/src/diffusers/utils/deep_nvme_utils.py b/src/diffusers/utils/deep_nvme_utils.py new file mode 100644 index 000000000000..9f24bc8e5d9f --- /dev/null +++ b/src/diffusers/utils/deep_nvme_utils.py @@ -0,0 +1,1843 @@ +# mypy: allow-untyped-defs +""" +Adapted from +https://github.com/deepspeedai/DeepSpeedExamples/blob/28a984e77b8d096dadc6389b6d1440b823587e28/deepnvme/model_checkpoint/torch/serialization_fast_v2.6.0.py +""" + +import copyreg +import difflib +import functools +import io +import os +import pickle +import re +import shutil +import struct +import sys +import tarfile +import tempfile +import threading +import warnings +from contextlib import closing, contextmanager +from enum import Enum +from typing import ( + IO, + Any, + BinaryIO, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + cast, +) + +from typing_extensions import TypeAlias, TypeIs + +from .testing_utils import is_torch_available + + +if is_torch_available(): + import torch + import torch._weights_only_unpickler as _weights_only_unpickler + from torch._sources import get_source_lines_and_file + from torch._utils import _import_dotted_name + from torch.storage import _get_dtype_from_pickle_storage_type + from torch.types import Storage + + +__all__ = [ + "SourceChangeWarning", + "mkdtemp", + "register_package", + "check_module_version_greater_or_equal", + "validate_cuda_device", + "validate_hpu_device", + "location_tag", + "default_restore_location", + "normalize_storage_type", + "storage_to_tensor_type", + "save", + "load", + "StorageType", + "LoadEndianness", + "get_crc32_options", + "set_crc32_options", + "get_default_load_endianness", + "set_default_load_endianness", + "get_default_mmap_options", + "set_default_mmap_options", + "clear_safe_globals", + "get_safe_globals", + "add_safe_globals", + "safe_globals", + "get_unsafe_globals_in_checkpoint", + "skip_data", +] + +DEFAULT_PROTOCOL = 2 + +LONG_SIZE = struct.Struct("=l").size +INT_SIZE = struct.Struct("=i").size +SHORT_SIZE = struct.Struct("=h").size + +MAGIC_NUMBER = 0x1950A86A20F9469CFC6C +PROTOCOL_VERSION = 1001 +STORAGE_KEY_SEPARATOR = "," + +FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]] +MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]] +STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] + +IS_WINDOWS = sys.platform == "win32" + +UNSAFE_MESSAGE = ( + "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` " + "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " + "but it can result in arbitrary code execution. Do it only if you got the file from a " + "trusted source." 
+) + +if not IS_WINDOWS: + from mmap import MAP_PRIVATE, MAP_SHARED +else: + MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] + + +def _default_to_weights_only(pickle_module): + is_fbcode = not hasattr(torch.version, "git_version") + return pickle_module is None and not is_fbcode + + +# _serialization_tls is used to store thread local state specific to serialization +# that needs to be propagated to other files, in particular we use this for +# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) +# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +class _SerializationLocal(threading.local): + def __init__(self): + super().__init__() + self.map_location: Optional[MAP_LOCATION] = None + self.skip_data: bool = False + self.materialize_fake_tensors: bool = False + + +_serialization_tls = _SerializationLocal() + + +class SourceChangeWarning(Warning): + pass + + +@contextmanager +def mkdtemp(): + path = tempfile.mkdtemp() + try: + yield path + finally: + shutil.rmtree(path) + + +_package_registry: List[ + Tuple[ + int, + Callable[[STORAGE], Optional[str]], + Callable[[STORAGE, str], Optional[STORAGE]], + ] +] = [] + + +class LoadEndianness(Enum): + NATIVE = 1 + LITTLE = 2 + BIG = 3 + + +_default_load_endian: Optional[LoadEndianness] = None + + +def get_default_load_endianness() -> Optional[LoadEndianness]: + """ + Get fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, this byte order is used as fallback. By default, it's + "native" byte order. + + Returns: + default_load_endian: Optional[LoadEndianness] + """ + return _default_load_endian + + +def set_default_load_endianness(endianness): + """ + Set fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, this byte order is used as fallback. By default, it's + "native" byte order. + + Args: + endianness: the new fallback byte order + """ + global _default_load_endian + if not isinstance(endianness, LoadEndianness) and endianness is not None: + raise TypeError("Invalid argument type in function set_default_load_endianness") + _default_load_endian = endianness + + +_compute_crc32: bool = True + + +def get_crc32_options() -> bool: + """ + Get whether :func:`torch.save` computes and writes crc32 for each record. + + Defaults to ``True``. + """ + return _compute_crc32 + + +def set_crc32_options(compute_crc32: bool): + """ + Set whether :func:`torch.save` computes and writes crc32 for each record. + + .. note:: + Setting this to ``False`` may make unzipping of the ``torch.save`` output fail or warn due to corrupted CRC32. + However ``torch.load`` will be able to load the file. + + Args: + compute_crc32 (bool): set crc32 compuation flag + """ + global _compute_crc32 + _compute_crc32 = compute_crc32 + + +_default_mmap_options: int = MAP_PRIVATE + + +def get_default_mmap_options() -> int: + """ + Get default mmap options for :func:`torch.load` with ``mmap=True``. + + Defaults to ``mmap.MAP_PRIVATE``. + + + Returns: + default_mmap_options: int + """ + return _default_mmap_options + + +class set_default_mmap_options: + """ + Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. + + For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported. Please open an issue if you need + any other option to be added here. + + .. 
note:: + This feature is currently not supported for Windows. + + Args: + flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` + """ + + def __init__(self, flags: int) -> None: + if IS_WINDOWS: + raise RuntimeError("Changing the default mmap options is currently not supported for Windows") + if flags != MAP_PRIVATE and flags != MAP_SHARED: + raise ValueError( + "Invalid argument in function set_default_mmap_options, " + f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}" + ) + global _default_mmap_options + self.prev = _default_mmap_options + _default_mmap_options = flags + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global _default_mmap_options + _default_mmap_options = self.prev + + +def clear_safe_globals() -> None: + """ + Clears the list of globals that are safe for ``weights_only`` load. + """ + _weights_only_unpickler._clear_safe_globals() + + +def get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]: + """ + Returns the list of user-added globals that are safe for ``weights_only`` load. + """ + return _weights_only_unpickler._get_safe_globals() + + +def add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]) -> None: + """ + Marks the given globals as safe for ``weights_only`` load. For example, functions added to this list can be called + during unpickling, classes could be instantiated and have state set. + + Each item in the list can either be a function/class or a tuple of the form (function/class, string) where string + is the full path of the function/class. + + Within the serialized format, each function is identified with its full path as ``{__module__}.{__name__}``. When + calling this API, you can provide this full path that should match the one in the checkpoint otherwise the default + ``{fn.__module__}.{fn.__name__}`` will be used. + + Args: + safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe + + Example: + >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization") + >>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with + tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` + will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the + code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... + torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # + MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) + """ + _weights_only_unpickler._add_safe_globals(safe_globals) + + +class safe_globals(_weights_only_unpickler._safe_globals): + r"""Context-manager that adds certain globals as safe for ``weights_only`` load. + + Args: + safe_globals: List of globals for weights_only load. + + Example: + >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization") + >>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with + tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` + will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. 
# Check the + code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... with + torch.serialization.safe_globals([MyTensor]): ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, + -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) >>> assert torch.serialization.get_safe_globals() == [] + """ + + +def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> List[str]: + """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``. + + For a given function or class ``f``, the corresponding string will be of the form ``{f.__module__}.{f.__name__}``. + + This function will return any GLOBALs in the checkpoint that are not in the set marked safe for ``weights_only`` + (either via :func:`add_safe_globals` or :class:`safe_globals` context or allowlisted by ``torch`` by default). + + .. note:: + This function will statically disassemble the pickle file in the checkpoint. The implication is any classes + dynamically pushed onto the stack during unpickling will not be included in the output. + + Args: + f: File-like object or string containing the checkpoint object saved via ``torch.save`` + + Returns: + A list of strings of pickle GLOBALs in the checkpoint that are not allowlisted for ``weights_only``. + """ + default_safe_globals_strings = set(_weights_only_unpickler._get_allowed_globals().keys()) + user_safe_global_strings = set(_weights_only_unpickler._get_user_allowed_globals().keys()) + safe_global_strings = default_safe_globals_strings.union(user_safe_global_strings) + + with _open_file_like(f, "rb") as opened_file: + if not _is_zipfile(opened_file): + raise ValueError("Expected input to be a checkpoint returned by torch.save") + with _open_zipfile_reader(opened_file) as zip_file: + if _is_torchscript_zip(zip_file): + raise ValueError( + "Expected input to be a checkpoint returned by torch.save but got a torchscript checkpoint" + ) + data_file = io.BytesIO(zip_file.get_record("data.pkl")) + all_globals = _weights_only_unpickler.get_globals_in_pkl(data_file) + return list(all_globals.difference(safe_global_strings)) + + +class skip_data: + """ + Context-manager that skips writing storage bytes for ``torch.save`` calls. + + Storages will still be saved, but the space that their bytes would usually be written to will be empty space. The + storage bytes can then be populated in a separate pass. + + .. warning:: + The ``skip_data`` context manager is an early prototype and is subject to change. + + Args: + materialize_fake_tensors: Whether to materialize FakeTensors. + + Example: + >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows") >>> import tempfile >>> t = torch.randn(2, 3) >>> with + tempfile.NamedTemporaryFile() as f: ... with torch.serialization.skip_data(): ... torch.save(t, f.name) ... 
+ torch.load(f.name, weights_only=True) tensor([[0., 0., 0.], + [0., 0., 0.]]) + """ + + def __init__(self, materialize_fake_tensors: bool = False): + self.materialize_fake_tensors = materialize_fake_tensors + + def __enter__(self): + global _serialization_tls + self._old_skip_data = _serialization_tls.skip_data + self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors + _serialization_tls.skip_data = True + _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors + + def __exit__(self, type, value, tb): + global _serialization_tls + _serialization_tls.skip_data = self._old_skip_data + _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors + + +def _is_zipfile(f) -> bool: + # This is a stricter implementation than zipfile.is_zipfile(). + # zipfile.is_zipfile() is True if the magic number appears anywhere in the + # binary. Since we expect the files here to be generated by torch.save or + # torch.jit.save, it's safe to only check the start bytes and avoid + # collisions and assume the zip has only 1 file. + # See bugs.python.org/issue28494. + + start = f.tell() + # Read the first few bytes and match against the ZIP file signature + local_header_magic_number = b"PK\x03\x04" + read_bytes = f.read(len(local_header_magic_number)) + f.seek(start) + return read_bytes == local_header_magic_number + + +def register_package( + priority: int, + tagger: Callable[[STORAGE], Optional[str]], + deserializer: Callable[[STORAGE, str], Optional[STORAGE]], +): + """ + Registers callables for tagging and deserializing storage objects with an associated priority. Tagging associates a + device with a storage object at save time while deserializing moves a storage object to an appropriate device at + load time. :attr:`tagger` and :attr:`deserializer` are run in the order given by their :attr:`priority` until a + tagger/deserializer returns a value that is not `None`. + + To override the deserialization behavior for a device in the global registry, one can register a tagger with a + higher priority than the existing tagger. + + This function can also be used to register a tagger and deserializer for new devices. + + Args: + priority: Indicates the priority associated with the tagger and deserializer, where a lower + value indicates higher priority. + tagger: Callable that takes in a storage object and returns its tagged device as a string + or None. + deserializer: Callable that takes in storage object and a device string and returns a storage + object on the appropriate device or None. + + Returns: + `None` + + Example: + >>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): + >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU + device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return + obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize) + """ + queue_elem = (priority, tagger, deserializer) + _package_registry.append(queue_elem) + _package_registry.sort() + + +def check_module_version_greater_or_equal( + module, + req_version_tuple, + error_if_malformed=True, +): + """ + Check if a module's version satisfies requirements + + Usually, a module's version string will be like 'x.y.z', which would be represented as a tuple (x, y, z), but + sometimes it could be an unexpected format. 
If the version string does not match the given tuple's format up to the + length of the tuple, then error and exit or emit a warning. + + Args: + module: the module to check the version of + req_version_tuple: tuple (usually of ints) representing the required version + error_if_malformed: whether we should exit if module version string is malformed + + Returns: + requirement_is_met: bool + """ + try: + version_strs = module.__version__.split(".") + # Cast module version fields to match the types of the required version + module_version = tuple(type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)) + requirement_is_met = module_version >= req_version_tuple + + except Exception as e: + message = ( + f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared" + f" with tuple {str(req_version_tuple)}" + ) + if error_if_malformed: + raise RuntimeError(message) from e + else: + warnings.warn(message + ", but continuing assuming that requirement is met") + requirement_is_met = True + + return requirement_is_met + + +def _cpu_tag(obj): + if obj.device.type == "cpu": + return "cpu" + + +def _mps_tag(obj): + if obj.device.type == "mps": + return "mps" + + +def _meta_tag(obj): + if obj.device.type == "meta": + return "meta" + + +def _backend_tag(backend_name, obj): + if backend_name == "privateuse1": + backend_name = torch._C._get_privateuse1_backend_name() + if obj.device.type == backend_name: + if obj.device.index is None: + return backend_name + else: + return backend_name + ":" + str(obj.device.index) + + +def _cpu_deserialize(obj, location): + if location == "cpu": + return obj + + +def _mps_deserialize(obj, location): + if location.startswith("mps"): + return obj.mps() + + +def _meta_deserialize(obj, location): + if location == "meta": + return torch.UntypedStorage(obj.nbytes(), device="meta") + + +def _validate_device(location, backend_name): + """ + Check whether the device index of specified backend is valid + + In case of privateuse1 backend, your must first register a device_module for privateuse1 using + torch._register_device_module. Implement the following methods in device_module like cuda: + device_module._utils._get_device_index(location, True), device_module.device_count(). + + Args: + location: string of device + backend_name: the backend name or the name of privateuse1, which can be renamed + + Returns: + device_index: int + """ + if not hasattr(torch, backend_name): + raise RuntimeError( + f"The {backend_name.upper()} device module is not registered. " + "If you are running on a CPU-only machine, " + "please use torch.load with map_location=torch.device('cpu') " + "to map your storages to the CPU." + ) + device_module = getattr(torch, backend_name) + if hasattr(device_module, "_utils") and hasattr(device_module._utils, "_get_device_index"): + device_index = device_module._utils._get_device_index(location, True) + device = torch.device(backend_name, device_index) + else: + device = torch.device(location) + device_index = device.index if device.index else 0 + if hasattr(device_module, "is_available") and not device_module.is_available(): + raise RuntimeError( + f"Attempting to deserialize object on a {backend_name.upper()} " + f"device but torch.{backend_name}.is_available() is False. " + "If you are running on a CPU-only machine, " + "please use torch.load with map_location=torch.device('cpu') " + "to map your storages to the CPU." 
+ ) + if hasattr(device_module, "device_count"): + device_count = device_module.device_count() + if device_index >= device_count: + raise RuntimeError( + f"Attempting to deserialize object on {backend_name.upper()} device " + f"{device_index} but torch.{backend_name}.device_count() is {device_count}. " + "Please use torch.load with map_location to map your storages " + "to an existing device." + ) + return device + + +def validate_cuda_device(location): + return _validate_device(location, "cuda").index + + +def validate_hpu_device(location): + return _validate_device(location, "hpu").index + + +def _deserialize(backend_name, obj, location): + if backend_name == "privateuse1": + backend_name = torch._C._get_privateuse1_backend_name() + if location.startswith(backend_name): + device = _validate_device(location, backend_name) + return obj.to(device=device) + + +register_package(10, _cpu_tag, _cpu_deserialize) +register_package( + 20, + functools.partial(_backend_tag, "cuda"), + functools.partial(_deserialize, "cuda"), +) +register_package(21, _mps_tag, _mps_deserialize) +register_package(22, _meta_tag, _meta_deserialize) +register_package( + 23, + functools.partial(_backend_tag, "privateuse1"), + functools.partial(_deserialize, "privateuse1"), +) +register_package( + 24, + functools.partial(_backend_tag, "hpu"), + functools.partial(_deserialize, "hpu"), +) +register_package( + 25, + functools.partial(_backend_tag, "xpu"), + functools.partial(_deserialize, "xpu"), +) + + +def location_tag( + storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage], +): + for _, tagger, _ in _package_registry: + location = tagger(storage) + if location: + return location + raise RuntimeError("don't know how to determine data location of " + torch.typename(storage)) + + +def default_restore_location(storage, location): + """ + Restores `storage` using a deserializer function registered for the `location`. + + This function looks in the registry for deserializer functions that match the `location`. If found, it attempts to + use them, in priority order, to restore `storage` until one returns a not `None` result. If no deserializer can be + found in the registry, or all found fail to bear a result, it raises a `RuntimeError`. + + Args: + storage (STORAGE): the storage object to restore + location (str): the location tag associated with the storage object + + Returns: + storage: Optional[STORAGE] + + Raises: + RuntimeError: If no deserializer matching `location` is found in the registry or if + all matching ones return `None`. 
+ """ + for _, _, fn in _package_registry: + result = fn(storage, location) + if result is not None: + return result + raise RuntimeError( + "don't know how to restore data location of " + torch.typename(storage) + " (tagged with " + location + ")" + ) + + +def normalize_storage_type(storage_type): + return getattr(torch, storage_type.__name__) + + +def storage_to_tensor_type(storage): + storage_type = type(storage) + module = _import_dotted_name(storage_type.__module__) + return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) + + +def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]: + return isinstance(name_or_buffer, (str, os.PathLike)) + + +class _opener: + def __init__(self, file_like): + self.file_like = file_like + + def __enter__(self): + return self.file_like + + def __exit__(self, *args): + pass + + +class _open_file(_opener): + def __init__(self, name, mode): + super().__init__(open(name, mode)) + + def __exit__(self, *args): + self.file_like.close() + + +class _open_buffer_reader(_opener): + def __init__(self, buffer): + super().__init__(buffer) + _check_seekable(buffer) + + +class _open_buffer_writer(_opener): + def __exit__(self, *args): + self.file_like.flush() + + +def _open_file_like(name_or_buffer, mode): + if _is_path(name_or_buffer): + return _open_file(name_or_buffer, mode) + else: + if "w" in mode: + return _open_buffer_writer(name_or_buffer) + elif "r" in mode: + return _open_buffer_reader(name_or_buffer) + else: + raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") + + +class _open_zipfile_reader(_opener): + def __init__(self, name_or_buffer) -> None: + super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) + + +class _open_zipfile_writer_file(_opener): + def __init__(self, name) -> None: + self.file_stream = None + self.name = str(name) + try: + self.name.encode("ascii") + except UnicodeEncodeError: + # PyTorchFileWriter only supports ascii filename. + # For filenames with non-ascii characters, we rely on Python + # for writing out the file. + self.file_stream = io.FileIO(self.name, mode="w") + super().__init__(torch._C.PyTorchFileWriter(self.file_stream, _compute_crc32)) + else: + super().__init__(torch._C.PyTorchFileWriter(self.name, _compute_crc32)) + + def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + if self.file_stream is not None: + self.file_stream.close() + + +class _open_zipfile_writer_buffer(_opener): + def __init__(self, buffer) -> None: + if not callable(getattr(buffer, "write", None)): + msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'" + if not hasattr(buffer, "write"): + raise AttributeError(msg) + raise TypeError(msg) + self.buffer = buffer + super().__init__(torch._C.PyTorchFileWriter(buffer, _compute_crc32)) + + def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + self.buffer.flush() + + +def _open_zipfile_writer(name_or_buffer): + container: Type[_opener] + if _is_path(name_or_buffer): + container = _open_zipfile_writer_file + else: + container = _open_zipfile_writer_buffer + return container(name_or_buffer) + + +def _is_compressed_file(f) -> bool: + compress_modules = ["gzip"] + try: + return f.__module__ in compress_modules + except AttributeError: + return False + + +def _should_read_directly(f): + """ + Checks if f is a file that should be read directly. It should be read directly if it is backed by a real file (has + a fileno) and is not a a compressed file (e.g. 
gzip) + """ + if _is_compressed_file(f): + return False + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False + + +def _check_seekable(f) -> bool: + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = ( + str(e) + + ". You can only torch.load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead." + ) + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + return False + + +def _check_dill_version(pickle_module) -> None: + """Checks if using dill as the pickle module, and if so, checks if it is the correct version. + If dill version is lower than 0.3.1, a ValueError is raised. + + Args: + pickle_module: module used for pickling metadata and objects + + """ + if pickle_module is not None and pickle_module.__name__ == "dill": + required_dill_version = (0, 3, 1) + if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False): + raise ValueError( + ( + "'torch' supports dill >= {}, but you have dill {}. Please upgrade dill or switch to 'pickle'" + ).format( + ".".join([str(num) for num in required_dill_version]), + pickle_module.__version__, + ) + ) + + +def _check_save_filelike(f): + if not _is_path(f) and not hasattr(f, "write"): + raise AttributeError("expected 'f' to be string, path, or a file-like object with a 'write' attribute") + + +def save( + obj: object, + f: FILE_LIKE, + pickle_module: Any = pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, + _use_new_zipfile_serialization: bool = True, + _disable_byteorder_record: bool = False, +) -> None: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. `>> # xdoctest: +SKIP("makes cwd dirty") >>> # Save to file >>> x = torch.tensor([0, 1, 2, 3, 4]) >>> + torch.save(x, "tensor.pt") >>> # Save to io.BytesIO buffer >>> buffer = io.BytesIO() >>> torch.save(x, buffer) + """ + torch._C._log_api_usage_once("torch.save") + _check_dill_version(pickle_module) + _check_save_filelike(f) + + if _use_new_zipfile_serialization: + with _open_zipfile_writer(f) as opened_zipfile: + _save( + obj, + opened_zipfile, + pickle_module, + pickle_protocol, + _disable_byteorder_record, + ) + return + else: + global _serialization_tls + if _serialization_tls.skip_data: + raise RuntimeError("Cannot use skip_data=True with _use_new_zipfile_serialization=False") + with _open_file_like(f, "wb") as opened_file: + _legacy_save(obj, opened_file, pickle_module, pickle_protocol) + + +def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: + import torch.nn as nn + + serialized_container_types = {} + serialized_storages: Dict[str, Tuple[torch.UntypedStorage, torch.dtype]] = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. + # TODO: This feature could be added in the future + storage_dtypes: Dict[int, torch.dtype] = {} + + def persistent_id(obj: Any) -> Optional[Tuple]: + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. 
This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, type) and issubclass(obj, nn.Module): + if obj in serialized_container_types: + return None + serialized_container_types[obj] = True + source_file = source = None + try: + source_lines, _, source_file = get_source_lines_and_file(obj) + source = "".join(source_lines) + except Exception: # saving the source is optional, so we can ignore any errors + warnings.warn( + "Couldn't retrieve source code for container of " + "type " + obj.__name__ + ". It won't be checked " + "for correctness upon loading." + ) + return ("module", obj, source_file, source) + + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + storage: torch.UntypedStorage + + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + dtype = obj.dtype + storage_numel = obj._size() + + elif isinstance(obj, torch.UntypedStorage): + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + dtype = torch.uint8 + storage_numel = storage.nbytes() + else: + raise TypeError(f"type not recognized: {type(obj)}") + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in storage_dtypes: + if storage_dtype != storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that view the same data as different types" + ) + else: + storage_dtypes[storage.data_ptr()] = storage_dtype + + view_metadata: Optional[Tuple[str, int, int]] + + # Offset is always 0, but we keep it for backwards compatibility + # with the old serialization format (which supported storage views) + offset = 0 + storage_key = str(storage._cdata) + location = location_tag(storage) + + # TODO: There's an issue here with FC. It might be impossible to + # solve, but it's worth noting. Imagine we save a list `[storage, + # tensor]`, where `tensor.storage()` is the same as `storage`, and + # `tensor.element_size() > 1`. Let's say that `tensor.dtype == + # torch.float`. The storage will be serialized with element size + # of 1, since we're choosing to serialize the first occurance of + # a duplicate storage. Since this legacy serialization format saves + # the numel of the storage, rather than nbytes directly, we'll be + # effectively saving nbytes in this case. We'll be able to load it + # and the tensor back up with no problems in _this_ and future + # versions of pytorch, but in older versions, here's the problem: + # the storage will be loaded up as a UntypedStorage, and then the + # FloatTensor will loaded and the UntypedStorage will be assigned to + # it. Since the storage dtype does not match the tensor dtype, this + # will cause an error. If we reverse the list, like `[tensor, + # storage]`, then we will save the `tensor.storage()` as a faked + # `FloatStorage`, and the saved size will be the correct + # dtype-specific numel count that old versions expect. 
`tensor` + # will be able to load up properly in old versions, pointing to + # a FloatStorage. However, `storage` is still being translated to + # a UntypedStorage, and it will try to resolve to the same + # FloatStorage that `tensor` contains. This will also cause an + # error. It doesn't seem like there's any way around this. + # Probably, we just cannot maintain FC for the legacy format if the + # saved list contains both a tensor and a storage that point to the + # same data. We should still be able to maintain FC for lists of + # just tensors, as long as all views share the same dtype as the + # tensor they are viewing. + + if storage_key not in serialized_storages: + serialized_storages[storage_key] = (storage, dtype) + is_view = storage._cdata != storage._cdata + if is_view: + view_metadata = (str(storage._cdata), offset, storage.nbytes()) + else: + view_metadata = None + + res = ( + "storage", + storage_type, + storage_key, + location, + storage_numel, + view_metadata, + ) + return res + return None + + sys_info = dict( # noqa + protocol_version=PROTOCOL_VERSION, + little_endian=sys.byteorder == "little", + type_sizes=dict( # noqa + short=SHORT_SIZE, + int=INT_SIZE, + long=LONG_SIZE, + ), + ) + + pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) + pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) + pickle_module.dump(sys_info, f, protocol=pickle_protocol) + + class PyTorchLegacyPickler(pickle_module.Pickler): + def persistent_id(self, obj): + return persistent_id(obj) + + pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol) + pickler.dump(obj) + + serialized_storage_keys = sorted(serialized_storages.keys()) + pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) + f.flush() + if hasattr(f, "save_torch_storage_object_list"): + sorted_storage_objects = [serialized_storages[key] for key in serialized_storage_keys] + f.save_torch_storage_object_list(sorted_storage_objects, True) + else: + for key in serialized_storage_keys: + storage, dtype = serialized_storages[key] + storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype)) + + +def _save( + obj, + zip_file, + pickle_module, + pickle_protocol, + _disable_byteorder_record, +): + serialized_storages = {} + id_map: Dict[int, str] = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. + # TODO: This feature could be added in the future + storage_dtypes: Dict[int, torch.dtype] = {} + + def persistent_id(obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. 
This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + + else: + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. If storage is + # not allocated, don't perform this check + if str(storage.device) != "meta" and storage.data_ptr() != 0: + if storage.data_ptr() in storage_dtypes: + if storage_dtype != storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that view the same data as different types" + ) + else: + storage_dtypes[storage.data_ptr()] = storage_dtype + + storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) + if hasattr(obj, "_fake_device") and obj._fake_device is not None: + location = str(obj._fake_device) + else: + location = location_tag(storage) + serialized_storages[storage_key] = storage + + return ("storage", storage_type, storage_key, location, storage_numel) + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + + class PyTorchPickler(pickle_module.Pickler): # type: ignore[name-defined] + def persistent_id(self, obj): + return persistent_id(obj) + + pickler = PyTorchPickler(data_buf, protocol=pickle_protocol) + pickler.dump(obj) + data_value = data_buf.getvalue() + zip_file.write_record("data.pkl", data_value, len(data_value)) + + # Write byte order marker + if not _disable_byteorder_record: + if sys.byteorder not in ["little", "big"]: + raise ValueError("Unknown endianness type: " + sys.byteorder) + + zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder)) + + # Write each tensor to a file named tensor/the_tensor_key in the zip archive + for key in sorted(serialized_storages.keys()): + name = f"data/{key}" + storage = serialized_storages[key] + num_bytes = storage.nbytes() + global _serialization_tls + if _serialization_tls.skip_data: + zip_file.write_record_metadata(name, num_bytes) + else: + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + zip_file.write_record(name, storage, num_bytes) + + +def load( + f: FILE_LIKE, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: Optional[bool] = None, + mmap: Optional[bool] = None, + **pickle_load_args: Any, +) -> Any: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. 
`>> # xdoctest: +SKIP("undefined filepaths") >>> torch.load("tensors.pt", weights_only=True) # Load all tensors + onto the CPU >>> torch.load("tensors.pt", map_location=torch.device("cpu"), weights_only=True) # Load all + tensors onto the CPU, using a function >>> torch.load( ... "tensors.pt", map_location=lambda storage, loc: + storage, weights_only=True ... ) # Load all tensors onto GPU 1 >>> torch.load( ... "tensors.pt", ... + map_location=lambda storage, loc: storage.cuda(1), ... weights_only=True, ... ) # type: ignore[attr-defined] # + Map tensors from GPU 1 to GPU 0 >>> torch.load("tensors.pt", map_location={"cuda:1": "cuda:0"}, + weights_only=True) # Load tensor from io.BytesIO object # Loading from a buffer setting weights_only=False, + warning this can be unsafe >>> with open("tensor.pt", "rb") as f: ... buffer = io.BytesIO(f.read()) >>> + torch.load(buffer, weights_only=False) # Load a module with 'ascii' encoding for unpickling # Loading from a + module setting weights_only=False, warning this can be unsafe >>> torch.load("module.pt", encoding="ascii", + weights_only=False) + """ + torch._C._log_api_usage_once("torch.load") + DOCS_MESSAGE = ( + "\n\nCheck the documentation of torch.load to learn more about types accepted by default with " + "weights_only https://pytorch.org/docs/stable/generated/torch.load.html." + ) + + def _get_wo_message(message: str) -> str: + unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default." + has_unsafe_global = re.search(unsafe_global_pattern, message) is not None + blocklist_pattern = r"whose module (\S+) is blocked" + has_blocklist = re.search(blocklist_pattern, message) is not None + import_pattern = r"(\S+) must be (\S+) to load" + has_import = re.search(import_pattern, message) is not None + if has_unsafe_global: + updated_message = ( + "Weights only load failed. This file can still be loaded, to do so you have two options, " + "\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. " + f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check " + "the recommended steps in the following error message.\n\tWeightsUnpickler error: " + message + ) + else: + if has_import: + return f"Weights only load failed. {message}\n {UNSAFE_MESSAGE}\n" + else: + updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n" + if not has_blocklist: + updated_message += ( + "Please file an issue with the following so that we can make " + "`weights_only=True` compatible with your use case: WeightsUnpickler error: " + ) + updated_message += message + return updated_message + DOCS_MESSAGE + + global _serialization_tls + skip_data = _serialization_tls.skip_data + if skip_data: + raise RuntimeError( + "`torch.load` called within a torch.serialization.skip_data context manager " + "is not supported yet. Please call torch.load outside the skip_data context manager." 
+ ) + + weights_only_not_set = weights_only is None + + if weights_only_not_set: + weights_only = _default_to_weights_only(pickle_module) + + true_values = ["1", "y", "yes", "true"] + # Add ability to force safe only or non-safe weight loads via environment variables + force_weights_only_load = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values + force_no_weights_only_load = os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values + + if force_weights_only_load and force_no_weights_only_load: + raise RuntimeError( + "Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` " + "should be set, but both were set." + ) + elif force_weights_only_load: + weights_only = True + elif force_no_weights_only_load: + # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only + if weights_only_not_set: + warnings.warn( + "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the" + "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.", + UserWarning, + stacklevel=2, + ) + weights_only = False + + if weights_only: + if pickle_module is not None: + raise RuntimeError("Can not safely load weights when explicit pickle_module is specified") + else: + if pickle_module is None: + pickle_module = pickle + + # make flipping default BC-compatible + if mmap is None: + mmap = False + + _check_dill_version(pickle_module) + + if "encoding" not in pickle_load_args.keys(): + pickle_load_args["encoding"] = "utf-8" + + with _open_file_like(f, "rb") as opened_file: + if _is_zipfile(opened_file): + # The zipfile reader is going to advance the current file position. + # If we want to actually tail call to torch.jit.load, we need to + # reset back to the original position. + orig_position = opened_file.tell() + overall_storage = None + with _open_zipfile_reader(opened_file) as opened_zipfile: + if _is_torchscript_zip(opened_zipfile): + warnings.warn( + "'torch.load' received a zip file that looks like a TorchScript archive" + " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" + " silence this warning)", + UserWarning, + ) + if weights_only: + raise RuntimeError( + "Cannot use ``weights_only=True`` with TorchScript archives passed to " + "``torch.load``. " + UNSAFE_MESSAGE + ) + opened_file.seek(orig_position) + return torch.jit.load(opened_file, map_location=map_location) + if mmap: + if not _is_path(f): + raise ValueError("f must be a file path in order to use the mmap argument") + size = os.path.getsize(f) + if not IS_WINDOWS: + shared = get_default_mmap_options() == MAP_SHARED + else: + shared = False + overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size) + if weights_only: + try: + return _load( + opened_zipfile, + map_location, + _weights_only_unpickler, + overall_storage=overall_storage, + **pickle_load_args, + ) + except pickle.UnpicklingError as e: + raise pickle.UnpicklingError(_get_wo_message(str(e))) from None + return _load( + opened_zipfile, + map_location, + pickle_module, + overall_storage=overall_storage, + **pickle_load_args, + ) + if mmap: + f_name = "" if not isinstance(f, str) else f"{f}, " + raise RuntimeError( + "mmap can only be used with files saved with " + f"`torch.save({f_name}_use_new_zipfile_serialization=True), " + "please torch.save your checkpoint with this option in order to use mmap." 
+ ) + if weights_only: + try: + return _legacy_load( + opened_file, + map_location, + _weights_only_unpickler, + **pickle_load_args, + ) + except pickle.UnpicklingError as e: + raise pickle.UnpicklingError(_get_wo_message(str(e))) from None + return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args) + + +# Register pickling support for layout instances such as +# torch.sparse_coo, etc +def _get_layout(name): + """Get layout extension object from its string representation.""" + cache = _get_layout.cache # type: ignore[attr-defined] + if not cache: + for v in torch.__dict__.values(): + if isinstance(v, torch.layout): + cache[str(v)] = v + return cache[name] + + +# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087 +_get_layout.cache = {} # type: ignore[attr-defined] +copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) + + +def _legacy_load(f, map_location, pickle_module, **pickle_load_args): + deserialized_objects: Dict[int, Any] = {} + + restore_location = _get_restore_location(map_location) + + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if type(name) is str and "Storage" in name: + try: + return StorageType(name) + except KeyError: + pass + return super().find_class(mod_name, name) + + def _check_container_source(container_type, source_file, original_source): + try: + current_source = "".join(get_source_lines_and_file(container_type)[0]) + except Exception: # saving the source is optional, so we can ignore any errors + warnings.warn( + "Couldn't retrieve source code for container of " + "type " + container_type.__name__ + ". It won't be checked " + "for correctness upon loading." + ) + return + if original_source != current_source: + if container_type.dump_patches: + file_name = container_type.__name__ + ".patch" + diff = difflib.unified_diff( + current_source.split("\n"), + original_source.split("\n"), + source_file, + source_file, + lineterm="", + ) + lines = "\n".join(diff) + try: + with open(file_name, "a+") as f: + file_size = f.seek(0, 2) + f.seek(0) + if file_size == 0: + f.write(lines) + elif file_size != len(lines) or f.read() != lines: + raise OSError + msg = ( + "Saved a reverse patch to " + file_name + ". " + "Run `patch -p0 < " + file_name + "` to revert your " + "changes." + ) + except OSError: + msg = ( + "Tried to save a patch, but couldn't create a " + "writable file " + file_name + ". Make sure it " + "doesn't exist and your working directory is " + "writable." + ) + else: + msg = ( + "you can retrieve the original source code by " + "accessing the object's source attribute or set " + "`torch.nn.Module.dump_patches = True` and use the " + "patch tool to revert the changes." + ) + msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}" + warnings.warn(msg, SourceChangeWarning) + + def legacy_load(f): + deserialized_objects: Dict[int, Any] = {} + + def persistent_load(saved_id): + if isinstance(saved_id, tuple): + # Ignore containers that don't have any sources saved + if all(saved_id[1:]): + _check_container_source(*saved_id) + return saved_id[0] + return deserialized_objects[int(saved_id)] + + with closing(tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)) as tar, mkdtemp() as tmpdir: + if pickle_module is _weights_only_unpickler: + raise RuntimeError( + "Cannot use ``weights_only=True`` with files saved in the legacy .tar format. 
" + UNSAFE_MESSAGE + ) + tar.extract("storages", path=tmpdir) + with open(os.path.join(tmpdir, "storages"), "rb", 0) as f: + num_storages = pickle_module.load(f, **pickle_load_args) + for _ in range(num_storages): + args = pickle_module.load(f, **pickle_load_args) + key, location, storage_type = args + dtype = storage_type._dtype + obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype)) + obj = restore_location(obj, location) + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[key] = torch.storage.TypedStorage( + wrap_storage=obj, dtype=dtype, _internal=True + ) + + storage_views = pickle_module.load(f, **pickle_load_args) + for target_cdata, root_cdata, offset, numel in storage_views: + root = deserialized_objects[root_cdata] + element_size = torch._utils._element_size(root.dtype) + offset_bytes = offset * element_size + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[target_cdata] = torch.storage.TypedStorage( + wrap_storage=root._untyped_storage[offset_bytes : offset_bytes + numel * element_size], + dtype=root.dtype, + _internal=True, + ) + + tar.extract("tensors", path=tmpdir) + with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f: + num_tensors = pickle_module.load(f, **pickle_load_args) + for _ in range(num_tensors): + args = pickle_module.load(f, **pickle_load_args) + key, storage_id, _original_tensor_type = args + storage = deserialized_objects[storage_id] + (ndim,) = struct.unpack(" str: + # When using encoding='bytes' in Py3, some **internal** keys stored as + # strings in Py2 are loaded as bytes. This function decodes them with + # ascii encoding, one that Py3 uses by default. + # + # NOTE: This should only be used on internal keys (e.g., `typename` and + # `location` in `persistent_load` below! 
+ if isinstance(bytes_str, bytes): + return bytes_str.decode("ascii") + return bytes_str + + +def _get_restore_location(map_location): + if map_location is None: + restore_location = default_restore_location + elif isinstance(map_location, dict): + + def restore_location(storage, location): + location = map_location.get(location, location) + return default_restore_location(storage, location) + + elif isinstance(map_location, (str, bytes)): + + def restore_location(storage, location): + return default_restore_location(storage, map_location) + + elif isinstance(map_location, torch.device): + + def restore_location(storage, location): + return default_restore_location(storage, str(map_location)) + + else: + + def restore_location(storage, location): + result = map_location(storage, location) + if result is None: + result = default_restore_location(storage, location) + return result + + return restore_location + + +class StorageType: + def __init__(self, name): + self._dtype = _get_dtype_from_pickle_storage_type(name) + + @property + def dtype(self): + return self._dtype + + def __str__(self): + return f"StorageType(dtype={self.dtype})" + + +def _load( + zip_file, + map_location, + pickle_module, + pickle_file="data.pkl", + overall_storage=None, + **pickle_load_args, +): + restore_location = _get_restore_location(map_location) + + loaded_storages = {} + + # check if byteswapping is needed + byteordername = "byteorder" + byteorderdata = None + if zip_file.has_record(byteordername): + byteorderdata = zip_file.get_record(byteordername) + if byteorderdata not in [b"little", b"big"]: + raise ValueError("Unknown endianness type: " + byteorderdata.decode()) + elif get_default_load_endianness() == LoadEndianness.LITTLE or get_default_load_endianness() is None: + byteorderdata = b"little" + elif get_default_load_endianness() == LoadEndianness.BIG: + byteorderdata = b"big" + elif get_default_load_endianness() == LoadEndianness.NATIVE: + pass + else: + raise ValueError("Invalid load endianness type") + + if not zip_file.has_record(byteordername) and get_default_load_endianness() is None and sys.byteorder == "big": + # Default behaviour was changed + # See https://github.com/pytorch/pytorch/issues/101688 + warnings.warn( + "The default load endianness for checkpoints without a byteorder mark " + "on big endian machines was changed from 'native' to 'little' endian, " + "to avoid this behavior please use " + "torch.serialization.set_default_load_endianness to set " + "the desired default load endianness", + UserWarning, + ) + + def load_tensor(dtype, numel, key, location): + name = f"data/{key}" + if torch._guards.detect_fake_mode(None) is not None: + nbytes = numel * torch._utils._element_size(dtype) + storage = torch.UntypedStorage(nbytes, device="meta") + elif overall_storage is not None: + storage_offset = zip_file.get_record_offset(name) + storage = overall_storage[storage_offset : storage_offset + numel] + else: + storage = ( + zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage + ) + # swap here if byteswapping is needed + if byteorderdata is not None: + if byteorderdata.decode() != sys.byteorder: + storage.byteswap(dtype) + + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + typed_storage = torch.storage.TypedStorage( + wrap_storage=restore_location(storage, location), + dtype=dtype, + _internal=True, + ) + + if typed_storage._data_ptr() != 0: + loaded_storages[key] = typed_storage + + return typed_storage 
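For orientation, the `save`/`load` pair adapted in this file keeps the `torch.save`/`torch.load` calling convention, so existing call sites only need to switch the import. A minimal round-trip sketch, assuming the module ships at the path this patch adds (`src/diffusers/utils/deep_nvme_utils.py`) and is importable as `diffusers.utils.deep_nvme_utils`:

```python
# Round-trip sketch for the adapted serialization module; not part of the patch itself.
import io

import torch

from diffusers.utils import deep_nvme_utils  # assumed import path for the file added above

buffer = io.BytesIO()
deep_nvme_utils.save({"weight": torch.randn(2, 3)}, buffer)  # mirrors torch.save(obj, f)
buffer.seek(0)
state = deep_nvme_utils.load(buffer, map_location="cpu", weights_only=True)  # mirrors torch.load
```

With `weights_only=True`, unpickling goes through the bundled `torch._weights_only_unpickler`, so only allowlisted globals are reconstructed, exactly as in stock PyTorch.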
+ + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + storage_type, key, location, numel = data + if storage_type is torch.UntypedStorage: + dtype = torch.uint8 + else: + dtype = storage_type.dtype + + if key in loaded_storages: + typed_storage = loaded_storages[key] + else: + nbytes = numel * torch._utils._element_size(dtype) + typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) + + return typed_storage + + load_module_mapping: Dict[str, str] = { + # See https://github.com/pytorch/pytorch/pull/51633 + "torch.tensor": "torch._tensor" + } + + # Need to subclass Unpickler instead of directly monkey-patching the find_class method + # because it's marked readonly in pickle. + # The type: ignore is because mypy can't statically determine the type of this class. + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732 + # Lets us override the imports that pickle uses when unpickling an object. + # This is useful for maintaining BC if we change a module path that tensor instantiation relies on. + def find_class(self, mod_name, name): + if type(name) is str and "Storage" in name: + try: + return StorageType(name) + except KeyError: + pass + mod_name = load_module_mapping.get(mod_name, mod_name) + return super().find_class(mod_name, name) + + # Load the data (which may in turn use `persistent_load` to load tensors) + data_file = io.BytesIO(zip_file.get_record(pickle_file)) + + unpickler = UnpicklerWrapper(data_file, **pickle_load_args) + unpickler.persistent_load = persistent_load + # Needed for tensors where storage device and rebuild tensor device are + # not connected (wrapper subclasses and tensors rebuilt using numpy) + global _serialization_tls + _serialization_tls.map_location = map_location + result = unpickler.load() + _serialization_tls.map_location = None + + torch._utils._validate_loaded_sparse_tensors() + torch._C._log_api_usage_metadata("torch.load.metadata", {"serialization_id": zip_file.serialization_id()}) + return result + + +def _is_torchscript_zip(zip_file): + return "constants.pkl" in zip_file.get_all_records() From b535b99e13e359680419d86f5c623d151946f527 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Jun 2025 17:00:04 +0530 Subject: [PATCH 11/12] nvme support --- src/diffusers/hooks/group_offloading.py | 35 +++++++++++---- src/diffusers/models/modeling_utils.py | 2 + src/diffusers/utils/import_utils.py | 18 ++++++++ src/diffusers/utils/state_dict_utils.py | 58 ++++++++++++++++++++++++- 4 files changed, 104 insertions(+), 9 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 7f932e9b7db2..cb9f1cd789ef 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -20,6 +20,7 @@ import torch from ..utils import get_logger, is_accelerate_available +from ..utils.import_utils import is_deepspeed_available, is_deepspeed_version from .hooks import HookRegistry, ModelHook @@ -27,6 +28,8 @@ from accelerate.hooks import AlignDevicesHook, CpuOffload from accelerate.utils import send_to_device +if is_deepspeed_available() and is_deepspeed_version(">=", "0.16"): + from ..utils.state_dict_utils import _fast_aio_save 
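The hunks below thread a private `_enable_deepnvme_disk_offloading` flag from `ModelMixin.enable_group_offload` down through `apply_group_offloading` and `ModuleGroup`, and switch the on-disk write from `safetensors.torch.save_file` to `_fast_aio_save` when the flag is set. A hedged usage sketch of the resulting API; the model class, repository id, and NVMe path are placeholders, and the flag requires a DeepSpeed build that satisfies the version guard above:

```python
# Usage sketch only; names below are placeholders, not part of the patch.
import torch
from diffusers import AutoModel  # any diffusers ModelMixin subclass exposes the same method

model = AutoModel.from_pretrained("some/repo", subfolder="transformer", torch_dtype=torch.bfloat16)
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="/mnt/nvme/offload",  # parameter groups are serialized under this directory
    _enable_deepnvme_disk_offloading=True,     # default False keeps the existing safetensors path
)
```

Leaving the flag at its default keeps the pre-existing behaviour: groups are still written with `safetensors.torch.save_file` and loaded back with `safetensors.torch.load_file`.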
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -62,6 +65,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
+        _enable_deepnvme_disk_offloading: Optional[bool] = False
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -80,7 +84,9 @@ def __init__(
         self._is_offloaded_to_disk = False

         if self.offload_to_disk_path:
-            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+            self._enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading
+            ext = ".pt" if _enable_deepnvme_disk_offloading else ".safetensors"
+            self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.{ext}")

             all_tensors = []
             for module in self.modules:
@@ -153,8 +159,8 @@ def onload_(self):
         with context:
             if self.stream is not None:
-                # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+                # Load to CPU from disk, pin, and async copy to device for overlapping transfer and compute
+                loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
                 for key, tensor_obj in self.key_to_tensor.items():
                     pinned_tensor = loaded_cpu_tensors[key].pin_memory()
                     tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
@@ -165,7 +171,7 @@ def onload_(self):
                 onload_device = (
                     self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                 )
-                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
                 for key, tensor_obj in self.key_to_tensor.items():
                     tensor_obj.data = loaded_tensors[key]
                 return
@@ -218,15 +224,18 @@ def offload_(self):
         if self.offload_to_disk_path:
             # TODO: we can potentially optimize this code path by checking if the _all_ the desired
             # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+            # overhead. Currently, we just check if the given `param_file_path` exists and if not
             # we perform a write.
             # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            if not self._is_offloaded_to_disk and not os.path.exists(self.param_file_path):
+                os.makedirs(os.path.dirname(self.param_file_path), exist_ok=True)
                 tensors_to_save = {
                     key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                 }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+                if not self._enable_deepnvme_disk_offloading:
+                    safetensors.torch.save_file(tensors_to_save, self.param_file_path)
+                else:
+                    _fast_aio_save(tensors_to_save, self.param_file_path)

             # The group is now considered offloaded to disk for the rest of the session.
             self._is_offloaded_to_disk = True
@@ -426,6 +435,7 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    _enable_deepnvme_disk_offloading: Optional[bool] = False
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -531,6 +541,7 @@ def apply_group_offloading(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
@@ -542,6 +553,7 @@ def apply_group_offloading(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -557,6 +569,7 @@ def _apply_group_offloading_block_level(
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    _enable_deepnvme_disk_offloading: Optional[bool] = False
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -617,6 +630,7 @@ def _apply_group_offloading_block_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -651,6 +665,7 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
+        _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
     )
     if stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None)
@@ -667,6 +682,7 @@ def _apply_group_offloading_leaf_level(
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    _enable_deepnvme_disk_offloading: Optional[bool] = False
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -717,6 +733,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
         )
         _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
@@ -764,6 +781,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
         )
         _apply_group_offloading_hook(parent_module, group, None)

@@ -785,6 +803,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=False,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 5fa04fb2606f..bf2bb3596d55 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -549,6 +549,7 @@ def enable_group_offload(
         record_stream: bool = False,
         low_cpu_mem_usage=False,
         offload_to_disk_path: Optional[str] = None,
+        _enable_deepnvme_disk_offloading: Optional[bool] = False
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -599,6 +600,7 @@ def enable_group_offload(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             offload_to_disk_path=offload_to_disk_path,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
         )

     def save_pretrained(
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index f12e9de33172..7173d18cc4e7 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -220,6 +220,11 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
 _nltk_available, _nltk_version = _is_package_available("nltk")
 _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_deepspeed_available, _deepspeed_version = _is_package_available("deepspeed")
+
+
+def is_deepspeed_available():
+    return _deepspeed_available


 def is_torch_available():
@@ -655,6 +660,19 @@ def is_torch_version(operation: str, version: str):
     return compare_versions(parse(_torch_version), operation, version)


+def is_deepspeed_version(operation: str, version: str):
+    """
+    Compares the current DeepSpeed version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A string version of DeepSpeed
+    """
+    return compare_versions(parse(_deepspeed_version), operation, version)
+
+
 def is_torch_xla_version(operation: str, version: str):
     """
     Compares the current torch_xla version to a given reference with an operation.
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 50bfce8b15eb..4f494ef8fe97 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -18,13 +18,19 @@
 import enum
 import json

-from .import_utils import is_torch_available
+from .import_utils import is_deepspeed_available, is_deepspeed_version, is_torch_available
 from .logging import get_logger


 if is_torch_available():
     import torch

+if is_deepspeed_available() and is_deepspeed_version(">", "0.16"):
+    from deepspeed.io import FastFileWriter, FastFileWriterConfig
+    from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder
+
+    from .deep_nvme_utils import save as _nvme_save
+

 logger = get_logger(__name__)
@@ -364,3 +370,53 @@ def _load_sft_state_dict_metadata(model_file: str):
         return json.loads(raw) if raw else None
     else:
         return None
+
+
+# Utilities below are taken from
+# https://github.com/deepspeedai/DeepSpeedExamples/blob/28a984e77b8d096dadc6389b6d1440b823587e28/deepnvme/model_checkpoint/torch_save_utils.py#L16
+def _load_io_ops(args):
+    if AsyncIOBuilder().is_compatible():
+        AsyncIOBuilder().load(verbose=False)
+    if args.gpu and GDSBuilder().is_compatible():
+        GDSBuilder().load(verbose=False)
+
+
+def _get_aio_handle():
+    AIO_QUEUE_DEPTH = 8
+    AIO_BLOCK_SIZE = 8 * (1024**2)
+    AIO_INTRA_OP_PARALLEL = 1
+    AIO_SINGLE_SUBMIT = False
+
+    h = (
+        AsyncIOBuilder()
+        .load(verbose=False)
+        .aio_handle(
+            block_size=AIO_BLOCK_SIZE,
+            queue_depth=AIO_QUEUE_DEPTH,
+            single_submit=AIO_SINGLE_SUBMIT,
+            overlap_events=AIO_SINGLE_SUBMIT,
+            intra_op_parallelism=AIO_INTRA_OP_PARALLEL,
+        )
+    )
+    return h
+
+
+def _get_aio_components():
+    PINNED_BUFFER_MB = 64
+    h = _get_aio_handle()
+    pinned_memory = torch.zeros(PINNED_BUFFER_MB * (1024**2), dtype=torch.uint8, device="cpu").pin_memory()
+    return h, pinned_memory
+
+
+def _fast_aio_save(buffer, file, single_io_buffer=False):
+    h, pinned_memory = _get_aio_components()
+    fast_writer_config = FastFileWriterConfig(
+        dnvme_handle=h,
+        pinned_tensor=pinned_memory,
+        double_buffer=not single_io_buffer,
+        num_parallel_writers=1,
+        writer_rank=0,
+    )
+
+    ds_fast_writer = FastFileWriter(file_path=file, config=fast_writer_config)
+    _nvme_save(f=ds_fast_writer, obj=buffer, _use_new_zipfile_serialization=False)
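
Aside (not part of the patch series): the snippet below is a minimal, self-contained sketch of the per-group disk round trip that `ModuleGroup.offload_()` and `onload_()` perform with the default safetensors backend added above. When `_enable_deepnvme_disk_offloading` is set, the save call is swapped for `_fast_aio_save` and the load for `torch.load(..., weights_only=True)`. The group file name, tensor keys, and devices here are illustrative.

```python
import os

import torch
from safetensors.torch import load_file, save_file


def offload_group_to_disk(tensors: dict, path: str) -> None:
    # Serialize one offload group's tensors into a single file under the offload path.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cpu_tensors = {key: t.detach().to("cpu") for key, t in tensors.items()}
    save_file(cpu_tensors, path)


def onload_group_from_disk(path: str, device: torch.device, non_blocking: bool = True) -> dict:
    # Load to CPU, pin, then copy to the onload device; pinning is what allows the
    # host-to-device copy to be issued asynchronously and overlap with compute.
    loaded = load_file(path, device="cpu")
    return {key: t.pin_memory().to(device, non_blocking=non_blocking) for key, t in loaded.items()}


if torch.cuda.is_available():
    group_file = "/mnt/nvme1/offload/group_0.safetensors"  # illustrative path
    params = {"weight": torch.randn(4, 4, device="cuda")}
    offload_group_to_disk(params, group_file)
    restored = onload_group_from_disk(group_file, torch.device("cuda"))
```

The hooks in the patch do the same thing per `ModuleGroup`, but assign each `tensor.data` in place so the module keeps its original parameter objects.
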
From 33f30ef86ec2995cf5d7567e366decf1f23d2b93 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 19 Jun 2025 18:41:46 +0530
Subject: [PATCH 12/12] updates

---
 src/diffusers/hooks/group_offloading.py | 40 +++++++++++++++----------
 src/diffusers/models/modeling_utils.py  |  4 +--
 src/diffusers/utils/state_dict_utils.py |  1 +
 3 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index cb9f1cd789ef..5b4c22a5f462 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -65,7 +65,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
-        _enable_deepnvme_disk_offloading: Optional[bool] = False
+        _enable_deepnvme_disk_offloading: Optional[bool] = False,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -86,7 +86,7 @@ def __init__(
         if self.offload_to_disk_path:
             self._enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading
             ext = ".pt" if _enable_deepnvme_disk_offloading else ".safetensors"
-            self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.{ext}")
+            self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}{ext}")

             all_tensors = []
             for module in self.modules:
@@ -160,7 +160,10 @@ def onload_(self):
         with context:
             if self.stream is not None:
                 # Load to CPU from disk, pin, and async copy to device for overlapping transfer and compute
-                loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
+                if self._enable_deepnvme_disk_offloading:
+                    loaded_cpu_tensors = torch.load(self.param_file_path, weights_only=True, map_location="cpu")
+                else:
+                    loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
                 for key, tensor_obj in self.key_to_tensor.items():
                     pinned_tensor = loaded_cpu_tensors[key].pin_memory()
                     tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
@@ -171,7 +174,12 @@ def onload_(self):
                 onload_device = (
                     self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                 )
-                loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
+                if self._enable_deepnvme_disk_offloading:
+                    loaded_tensors = torch.load(
+                        self.param_file_path, weights_only=True, map_location=onload_device
+                    )
+                else:
+                    loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
                 for key, tensor_obj in self.key_to_tensor.items():
                     tensor_obj.data = loaded_tensors[key]
                 return
@@ -232,10 +240,10 @@ def offload_(self):
                 tensors_to_save = {
                     key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                 }
-                if not self._enable_deepnvme_disk_offloading:
-                    safetensors.torch.save_file(tensors_to_save, self.param_file_path)
-                else:
+                if self._enable_deepnvme_disk_offloading:
                     _fast_aio_save(tensors_to_save, self.param_file_path)
+                else:
+                    safetensors.torch.save_file(tensors_to_save, self.param_file_path)

             # The group is now considered offloaded to disk for the rest of the session.
             self._is_offloaded_to_disk = True
@@ -435,7 +443,7 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
-    _enable_deepnvme_disk_offloading: Optional[bool] = False
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -541,7 +549,7 @@ def apply_group_offloading(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
@@ -553,7 +561,7 @@ def apply_group_offloading(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -569,7 +577,7 @@ def _apply_group_offloading_block_level(
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
-    _enable_deepnvme_disk_offloading: Optional[bool] = False
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -630,7 +638,7 @@ def _apply_group_offloading_block_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -682,7 +690,7 @@ def _apply_group_offloading_leaf_level(
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
-    _enable_deepnvme_disk_offloading: Optional[bool] = False
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -733,7 +741,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
@@ -781,7 +789,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         _apply_group_offloading_hook(parent_module, group, None)

@@ -803,7 +811,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=False,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index bf2bb3596d55..1d0c6815940a 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -549,7 +549,7 @@ def enable_group_offload(
         record_stream: bool = False,
         low_cpu_mem_usage=False,
         offload_to_disk_path: Optional[str] = None,
-        _enable_deepnvme_disk_offloading: Optional[bool] = False
+        _enable_deepnvme_disk_offloading: Optional[bool] = False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -600,7 +600,7 @@ def enable_group_offload(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             offload_to_disk_path=offload_to_disk_path,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )

     def save_pretrained(
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 4f494ef8fe97..aea77b5405dc 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -420,3 +420,4 @@ def _fast_aio_save(buffer, file, single_io_buffer=False):

     ds_fast_writer = FastFileWriter(file_path=file, config=fast_writer_config)
     _nvme_save(f=ds_fast_writer, obj=buffer, _use_new_zipfile_serialization=False)
+    ds_fast_writer.close()
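
Aside (not part of the patch series): a hypothetical end-to-end usage sketch assembled from the parameters this series adds. The model class, checkpoint id, and NVMe mount point are illustrative, and `_enable_deepnvme_disk_offloading` is a private flag that additionally requires DeepSpeed > 0.16 built with async I/O (and optionally GDS) support.

```python
import torch
from diffusers import FluxTransformer2DModel  # any diffusers model exposing enable_group_offload()

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    offload_to_disk_path="/mnt/nvme1/offload",  # one param file is written here per offload group
    _enable_deepnvme_disk_offloading=True,  # route writes through DeepSpeed's async NVMe writer (.pt files)
)
```

Leaving the flag at its default keeps the behavior from the earlier disk-offloading patches: one `.safetensors` file per group, written with `safetensors.torch.save_file`.
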