Skip to content

Commit 0a23047

Browse files
Thomas Polasek authored and facebook-github-bot committed
Convert FBCODE to use the Ruff Formatter
Summary: Converts the directory specified to use the Ruff formatter. This is the last big diff to convert all of Fbcode to Ruff. pomsky_fix_bugs drop-conflicts bypass-github-export-checks allow-large-files Reviewed By: amyreese Differential Revision: D66886610 fbshipit-source-id: 8276a7f6164efec189ca0b87e535543ed5bc3615
1 parent b731879 commit 0a23047

File tree

7 files changed

+18
-14
lines changed

7 files changed

+18
-14
lines changed

tests/test_manifest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ def test_replicated_entries_only_on_rank_0(rank: int) -> None:
725725

726726

727727
def _update_local_manifest_with_merged_entries(
728-
local_manifest: Dict[str, Entry]
728+
local_manifest: Dict[str, Entry],
729729
) -> None:
730730
"""
731731
Update the expected local manifest with manually merged ShardedTensorEntries

torchsnapshot/asyncio_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def _run_once(self):
4747
timeout = (
4848
0
4949
if ready or self._stopping
50-
else min(max(scheduled[0]._when - now, 0), 86400) if scheduled else None
50+
else min(max(scheduled[0]._when - now, 0), 86400)
51+
if scheduled
52+
else None
5153
)
5254
event_list = self._selector.select(timeout)
5355
self._process_events(event_list)

torchsnapshot/io_preparers/tensor.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def can_load_inplace(
199199

200200
@staticmethod
201201
def empty_tensor_from_entry(
202-
entry: Union[TensorEntry, ChunkedTensorEntry]
202+
entry: Union[TensorEntry, ChunkedTensorEntry],
203203
) -> torch.Tensor:
204204
if entry.dtype in SUPPORTED_QUANTIZED_DTYPES:
205205
# TODO: we can't allocate empty quantized tensors because we don't
@@ -394,11 +394,15 @@ def tensor_copy(dst: torch.Tensor, src: torch.Tensor) -> None:
394394
# a region of the larger tensor's storage contain data that does not match
395395
# the larger tensor's qscheme.
396396

397-
if src.is_quantized and (
398-
not dst.is_quantized # Copying from quantized Tensor to non-quantized Tensor is not allowed
399-
or dst.qscheme() != src.qscheme() # Quantized copy only works with same qscheme
400-
or dst.dtype != src.dtype # Quantized copy requires matching dtypes
401-
or (dst._is_view() and not _q_params_equal(dst, src)) # See the top comment
397+
if (
398+
src.is_quantized
399+
and (
400+
not dst.is_quantized # Copying from quantized Tensor to non-quantized Tensor is not allowed
401+
or dst.qscheme()
402+
!= src.qscheme() # Quantized copy only works with same qscheme
403+
or dst.dtype != src.dtype # Quantized copy requires matching dtypes
404+
or (dst._is_view() and not _q_params_equal(dst, src)) # See the top comment
405+
)
402406
):
403407
# TODO: tile the dequantize -> copy to reduce memory footprint
404408
src = _tensor_dequantize(src)

torchsnapshot/manifest_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _get_rank_to_manifest(metadata: SnapshotMetadata) -> List[Dict[str, Entry]]:
109109

110110

111111
def _get_merged_sharded_tensor_entries(
112-
rank_to_manifest: List[Dict[str, Entry]]
112+
rank_to_manifest: List[Dict[str, Entry]],
113113
) -> Dict[str, Entry]:
114114
groups = defaultdict(list)
115115
for manifest in rank_to_manifest:
@@ -130,7 +130,7 @@ def _get_merged_sharded_tensor_entries(
130130

131131

132132
def _get_merged_dtensor_entries(
133-
rank_to_manifest: List[Dict[str, Entry]]
133+
rank_to_manifest: List[Dict[str, Entry]],
134134
) -> Dict[str, Entry]:
135135
"""
136136
Merge all DTensor entries across ranks if sharded

torchsnapshot/partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def partition_write_reqs(
283283

284284

285285
def _consolidate_replicated_chunked_tensor_entries(
286-
rank_to_entries: List[Dict[str, Entry]]
286+
rank_to_entries: List[Dict[str, Entry]],
287287
) -> List[Dict[str, Entry]]:
288288
groups: Dict[str, List[ChunkedTensorEntry]] = defaultdict(list)
289289

torchsnapshot/serialization.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,7 @@ def contiguous_view_as_untyped_storage(tensor: torch.Tensor) -> UntypedStorage:
245245
else:
246246
untyped_storage = tensor.storage().untyped()
247247
return untyped_storage[
248-
tensor.storage_offset()
249-
* tensor.element_size() : tensor.storage_offset()
248+
tensor.storage_offset() * tensor.element_size() : tensor.storage_offset()
250249
* tensor.element_size()
251250
+ tensor.nelement() * tensor.element_size()
252251
]

torchsnapshot/snapshot.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,6 @@ def _coalesce_path_and_replicated(
863863
app_state: AppState,
864864
replicated: List[str],
865865
) -> Tuple[str, Set[str]]:
866-
867866
rank = pg_wrapper.get_rank()
868867

869868
# coalesce path

0 commit comments

Comments (0)