28
28
29
29
30
30
class RepoStatus (Enum ):
31
- """Collections of Repo Statuses"""
31
+ """Repo Statuses"""
32
32
33
33
CLEAN = 1
34
34
CHANGES_NOT_ADDED = 2
@@ -39,7 +39,7 @@ class RepoStatus(Enum):
39
39
class SizeInfo :
40
40
"""Size info for a file or the repo in bytes.
41
41
42
- Deduped size can't be disabled. So it will always be there .
42
+ Deduped size can't be disabled. So it is always performed .
43
43
44
44
Both sparsified and gzipped are optional. They are applied in the following
45
45
order if both are enabled:
@@ -59,7 +59,7 @@ class SizeInfo:
59
59
class _SHA1_Tensor :
60
60
"""Representing a tensor using sha1(s) from SHA1 store.
61
61
62
- It can be either a dense one or 2 sparse one with SST and DST.
62
+ It can be either a dense one or two sparse one ( SST and DST) .
63
63
"""
64
64
65
65
is_dense : bool = True
@@ -68,23 +68,35 @@ class _SHA1_Tensor:
68
68
dst_sha1 : str = ""
69
69
70
70
71
- def _recursive_apply_to_elements (data : Union [List [Any ], Dict [str , Any ]], fn : Any ) -> None :
71
+ def _recursive_apply_to_elements (data : Union [List [Any ], Dict [str , Any ]], fn : Any , names : List [ str ] ) -> None :
72
72
"""Helper function to traverse a dict recursively and apply a function to leafs.
73
73
74
- The input `data` is a dict or a list and it should only contain dict and list.
74
+
75
+ Args:
76
+ data (dict or list):
77
+ A dict or a list and it should only contain dict and list.
78
+ fn (Any):
79
+ A call back function on each element. Signature:
80
+ fn(element: Any, names: List[str]) -> Any
81
+ names (list):
82
+ Stack of names for making the element path.
75
83
"""
76
84
if isinstance (data , list ):
77
85
for i , _ in enumerate (data ):
86
+ names .append (str (i ))
78
87
if isinstance (data [i ], (list , dict )):
79
- _recursive_apply_to_elements (data [i ], fn )
88
+ _recursive_apply_to_elements (data [i ], fn , names )
80
89
else :
81
- data [i ] = fn (data [i ])
90
+ data [i ] = fn (data [i ], names )
91
+ names .pop ()
82
92
elif isinstance (data , dict ):
83
93
for key in data .keys ():
94
+ names .append (str (key ))
84
95
if isinstance (data [key ], (list , dict )):
85
- _recursive_apply_to_elements (data [key ], fn )
96
+ _recursive_apply_to_elements (data [key ], fn , names )
86
97
else :
87
- data [key ] = fn (data [key ])
98
+ data [key ] = fn (data [key ], names )
99
+ names .pop ()
88
100
else :
89
101
assert False , f"Unexpected data type: { type (data )} "
90
102
@@ -250,21 +262,21 @@ def add(
250
262
# yet. Need to figure out a method for delta tracking.
251
263
if per_tensor :
252
264
253
- def fn (element : Any ) -> Any :
265
+ def fn (element : Any , names : List [ str ] ) -> Any :
254
266
"""Callback on each leaf object for _recursive_apply_to_elements below."""
255
267
if isinstance (element , Tensor ):
256
268
if sparsify :
257
269
# TODO (Min): here we will optionally do SST/DST and add those
258
270
# tensors with sparsity.
259
271
# Remember to update ret_state_dict
260
272
raise NotImplementedError ()
261
- sha1 = self ._sha1_store .add (element , compress = gzip )
273
+ sha1 = self ._sha1_store .add (element , compress = gzip , name = "." . join ( names ) )
262
274
return _SHA1_Tensor (is_dense = True , dense_sha1 = sha1 )
263
275
else :
264
276
return element
265
277
266
278
state_dict = torch .load (file_path )
267
- _recursive_apply_to_elements (state_dict , fn )
279
+ _recursive_apply_to_elements (state_dict , fn , [] )
268
280
file_path_or_state_dict = state_dict
269
281
270
282
# Add this top-level object.
0 commit comments