Revert "Add back alloc/free hack for unsharded data"
This reverts commit 0386841.
epwalsh committed Apr 24, 2024
1 parent: 0386841 · commit: 5f140f2
Showing 1 changed file with 7 additions and 22 deletions.
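
For context on what is being reverted: the "alloc/free hack" kept a persistent params_data_unsharded buffer on the handle and released the memory behind it between unshards (via free_storage/alloc_storage) instead of simply dropping the tensor. Below is a minimal sketch of that storage-resizing pattern; it assumes the olmo_core.utils helpers behave like the usual PyTorch storage-resize trick, which is an assumption about their implementation, not a quote of it.

import torch


def free_storage(tensor: torch.Tensor) -> None:
    # Assumed sketch: keep the tensor object (and its shape metadata) alive,
    # but release the memory behind it by shrinking its storage to zero bytes.
    tensor.untyped_storage().resize_(0)


def alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    # Assumed sketch: re-allocate the backing memory so the tensor can be
    # written into again, e.g. as an all-gather output buffer of shape `size`.
    tensor.untyped_storage().resize_(size.numel() * tensor.element_size())


# The reverted hack in one picture: a single cached buffer whose memory is
# released after reshard_() and re-allocated before the next unshard.
buf = torch.empty(1024)        # cached once, like `params_data_unsharded`
free_storage(buf)              # after resharding: tensor stays, memory goes
alloc_storage(buf, buf.shape)  # before the next unshard: memory comes back
buf.zero_()                    # buffer is writable again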
src/olmo_core/distributed/fsdp/flat_param_handle.py (29 changes: 7 additions & 22 deletions)
@@ -17,7 +17,7 @@
 )
 from olmo_core.distributed.utils import get_rank, get_world_size
 from olmo_core.stream import Stream
-from olmo_core.utils import alloc_storage, free_storage, get_default_device
+from olmo_core.utils import get_default_device

 log = logging.getLogger(__name__)

@@ -44,11 +44,6 @@ class FlatParamHandle:
     Consolidated data for all of the local sharded data of the parameters including padding.
     """

-    params_data_unsharded: Optional[torch.Tensor] = None
-    """
-    Holds the unsharded data for ``params_data``.
-    """
-
     params_sharded_data_lp: Optional[torch.Tensor] = None
     """
     Low-precision version of sharded ``params_data``.
Expand Down Expand Up @@ -246,26 +241,20 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
if rank0_only or dist.get_backend() == dist.Backend.GLOO:
return

dtype = dtype or self.params_data.dtype

# Initialize unsharded, padded ``params_data`` without the all-gather.
if self.params_data_unsharded is None or self.params_data_unsharded.dtype != dtype:
self.params_data_unsharded = torch.empty(
self.params_data.unsharded_shape, dtype=dtype, device=self.device
)
else:
alloc_storage(self.params_data_unsharded, torch.Size(self.params_data.unsharded_shape))

self.params_data.unshard_(unsharded_data=self.params_data_unsharded, dtype=dtype, rank0_only=rank0_only)
all_params_unsharded_data = torch.empty(
self.params_data.unsharded_shape, dtype=dtype or self.params_data.dtype, device=self.device
)
self.params_data.unshard_(unsharded_data=all_params_unsharded_data, dtype=dtype, rank0_only=rank0_only)

# Cast sharded ``params_data`` to ``dtype``.
if dtype is not None:
self.params_sharded_data_lp = self.params_data.sharded_chunk(self.params_data_unsharded)
self.params_sharded_data_lp = self.params_data.sharded_chunk(all_params_unsharded_data)
self.params_sharded_data_lp.copy_(self.params_data.sharded_data)

# Initialize unsharded, padded gradient.
if set_grads and self.params_unsharded_grad is None:
self.params_unsharded_grad = torch.zeros_like(self.params_data_unsharded)
self.params_unsharded_grad = torch.zeros_like(all_params_unsharded_data)

def unshard_(
self,
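
The params_sharded_data_lp lines in the hunk above are the mixed-precision path: the local full-precision shard is copied (and thereby cast) into its own chunk of the low-precision gather buffer before the all-gather, so the collective moves only reduced-precision bytes. A minimal standalone sketch of that idea follows, using hypothetical names rather than the ShardedFlatTensor API.

import torch


def cast_shard_into_gather_buffer(
    sharded_fp32: torch.Tensor, rank: int, world_size: int, dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor:
    # Allocate the unsharded gather buffer directly in the low-precision dtype.
    gather_buffer = torch.empty(world_size * sharded_fp32.numel(), dtype=dtype)
    # This rank's slot of the buffer, analogous to `sharded_chunk(...)` above.
    local_chunk = gather_buffer.view(world_size, -1)[rank]
    # copy_() casts fp32 -> bf16 on the fly, so the collective that follows
    # would only move low-precision bytes over the network.
    local_chunk.copy_(sharded_fp32)
    return gather_buffer


buf = cast_shard_into_gather_buffer(torch.randn(8), rank=0, world_size=2)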
@@ -344,8 +333,6 @@ def reshard_(self, writeback: bool = False):
         if not self.params:
             return

-        assert self.params_data_unsharded is not None
-
         self.params_data.reshard_(writeback=writeback)
         offset = 0
         for flat_param in self.params:
@@ -355,8 +342,6 @@ def reshard_(self, writeback: bool = False):
             flat_param.data = self.params_data[offset : offset + flat_param.sharded_numel]
             offset += flat_param.sharded_numel

-        free_storage(self.params_data_unsharded)
-
     def pre_reduce_scatter_grads_(
         self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
     ):
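
After the revert, the handle no longer owns a cached unsharded buffer: each pre_unshard_ call allocates a fresh temporary (all_params_unsharded_data), unshard_ gathers into it, and reshard_ simply lets it go out of scope, presumably leaving memory reuse to the caching allocator. The following toy class, with hypothetical names rather than the real FlatParamHandle API, sketches that lifecycle under those assumptions.

from typing import Optional

import torch


class TinyHandle:
    """Toy model (hypothetical names) of the buffer lifecycle after this revert."""

    def __init__(self, sharded: torch.Tensor, world_size: int):
        self.sharded = sharded          # this rank's padded shard
        self.world_size = world_size
        self.unsharded: Optional[torch.Tensor] = None

    def pre_unshard_(self, dtype: Optional[torch.dtype] = None) -> None:
        # A fresh temporary gather buffer is allocated on every call; nothing
        # is cached across unshard/reshard cycles anymore.
        self.unsharded = torch.empty(
            self.world_size * self.sharded.numel(),
            dtype=dtype or self.sharded.dtype,
            device=self.sharded.device,
        )

    def unshard_(self) -> torch.Tensor:
        # Stand-in for the all-gather: each rank's chunk would land in
        # `self.unsharded`; here we only fill our own chunk.
        assert self.unsharded is not None
        self.unsharded.view(self.world_size, -1)[0].copy_(self.sharded)
        return self.unsharded

    def reshard_(self) -> None:
        # Dropping the reference releases the memory back to the allocator;
        # no explicit free_storage() bookkeeping is needed.
        self.unsharded = None


handle = TinyHandle(torch.arange(4, dtype=torch.float32), world_size=2)
handle.pre_unshard_(dtype=torch.bfloat16)
full = handle.unshard_()
handle.reshard_()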
