Skip to content

Commit fd7b962

Browse files
min-xu-ai and flying-x authored
[fix] unclose FD and not load/store metadata many times (#1038)
* [fix] unclose FD and not load/store metadata many times
* one more stat
* Update fairscale/experimental/wgit/sha1_store.py
* add name to the objects when added
* dict key can be int from a state_dict
* removed top_level_objects key; it should be added into repo, not sha1_store

Co-authored-by: Min Xu <[email protected]>
1 parent b0c3fe1 commit fd7b962

File tree

3 files changed

+172
-134
lines changed

3 files changed

+172
-134
lines changed

fairscale/experimental/wgit/repo.py

Lines changed: 24 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828

2929

3030
class RepoStatus(Enum):
31-
"""Collections of Repo Statuses"""
31+
"""Repo Statuses"""
3232

3333
CLEAN = 1
3434
CHANGES_NOT_ADDED = 2
@@ -39,7 +39,7 @@ class RepoStatus(Enum):
3939
class SizeInfo:
4040
"""Size info for a file or the repo in bytes.
4141
42-
Deduped size can't be disabled. So it will always be there.
42+
Deduped size can't be disabled. So it is always performed.
4343
4444
Both sparsified and gzipped are optional. They are applied in the following
4545
order if both are enabled:
@@ -59,7 +59,7 @@ class SizeInfo:
5959
class _SHA1_Tensor:
6060
"""Representing a tensor using sha1(s) from SHA1 store.
6161
62-
It can be either a dense one or 2 sparse one with SST and DST.
62+
It can be either a dense one or two sparse one (SST and DST).
6363
"""
6464

6565
is_dense: bool = True
@@ -68,23 +68,35 @@ class _SHA1_Tensor:
6868
dst_sha1: str = ""
6969

7070

71-
def _recursive_apply_to_elements(data: Union[List[Any], Dict[str, Any]], fn: Any) -> None:
71+
def _recursive_apply_to_elements(data: Union[List[Any], Dict[str, Any]], fn: Any, names: List[str]) -> None:
7272
"""Helper function to traverse a dict recursively and apply a function to leafs.
7373
74-
The input `data` is a dict or a list and it should only contain dict and list.
74+
75+
Args:
76+
data (dict or list):
77+
A dict or a list and it should only contain dict and list.
78+
fn (Any):
79+
A call back function on each element. Signature:
80+
fn(element: Any, names: List[str]) -> Any
81+
names (list):
82+
Stack of names for making the element path.
7583
"""
7684
if isinstance(data, list):
7785
for i, _ in enumerate(data):
86+
names.append(str(i))
7887
if isinstance(data[i], (list, dict)):
79-
_recursive_apply_to_elements(data[i], fn)
88+
_recursive_apply_to_elements(data[i], fn, names)
8089
else:
81-
data[i] = fn(data[i])
90+
data[i] = fn(data[i], names)
91+
names.pop()
8292
elif isinstance(data, dict):
8393
for key in data.keys():
94+
names.append(str(key))
8495
if isinstance(data[key], (list, dict)):
85-
_recursive_apply_to_elements(data[key], fn)
96+
_recursive_apply_to_elements(data[key], fn, names)
8697
else:
87-
data[key] = fn(data[key])
98+
data[key] = fn(data[key], names)
99+
names.pop()
88100
else:
89101
assert False, f"Unexpected data type: {type(data)}"
90102

@@ -250,21 +262,21 @@ def add(
250262
# yet. Need to figure out a method for delta tracking.
251263
if per_tensor:
252264

253-
def fn(element: Any) -> Any:
265+
def fn(element: Any, names: List[str]) -> Any:
254266
"""Callback on each leaf object for _recursive_apply_to_elements below."""
255267
if isinstance(element, Tensor):
256268
if sparsify:
257269
# TODO (Min): here we will optionally do SST/DST and add those
258270
# tensors with sparsity.
259271
# Remember to update ret_state_dict
260272
raise NotImplementedError()
261-
sha1 = self._sha1_store.add(element, compress=gzip)
273+
sha1 = self._sha1_store.add(element, compress=gzip, name=".".join(names))
262274
return _SHA1_Tensor(is_dense=True, dense_sha1=sha1)
263275
else:
264276
return element
265277

266278
state_dict = torch.load(file_path)
267-
_recursive_apply_to_elements(state_dict, fn)
279+
_recursive_apply_to_elements(state_dict, fn, [])
268280
file_path_or_state_dict = state_dict
269281

270282
# Add this top-level object.

0 commit comments

Comments
 (0)