
Commit 3f46701

TroyGarden authored and facebook-github-bot committed
use pin_memory for ModelInput and minor refactoring (#2910)
Summary:

# context
* This is some BE work and minor refactoring done while working on pipeline optimization.
* The major change is adding a "pin_memory" option to the test_input file for `ModelInput` generation:
```
The `pin_memory()` call on all KJT tensors is important for the training benchmark, and is
also a valid option for the prod training scenario: TrainModelInput should be created in
pinned memory for a fast transfer to GPU. For more on pin_memory:
https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory
```
* Minor refactoring includes: (1) default parameters for the TrainPipeline benchmark so that the embedding size, batch size, etc. are reasonable; (2) a fix for the batch index error in the trace, which previously used (curr_index + 1); (3) splitting the `EmbeddingPipelinedForward` __call__ function into two parts.
* Trace comparison: the `pin_memory()` call on the ModelInput is critical for a non-blocking CPU-to-GPU data copy.
  * before: copy_batch_to_gpu is the same size as the GPU data transfer {F1977324224}
  * after: copy_batch_to_gpu is hardly seen in the trace {F1977324220}

Reviewed By: aporialiao

Differential Revision: D73514639
1 parent 7bd2afc commit 3f46701
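
As background for the pin_memory motivation in the summary, here is a minimal, hedged sketch (not part of this commit) of why pinned host memory matters: a `.to(device, non_blocking=True)` copy can only overlap with compute when the source CPU tensor is page-locked, which is the effect visible in the traces referenced above. Shapes and stream names are illustrative.

```python
# Minimal sketch, assuming a CUDA device is available; tensor shapes are arbitrary.
import torch

batch_cpu = torch.randn(8192, 512).pin_memory()   # page-locked host memory
copy_stream = torch.cuda.Stream()

with torch.cuda.stream(copy_stream):
    # With a pinned source, this host-to-device copy is truly asynchronous and can
    # overlap with compute running on the default stream.
    batch_gpu = batch_cpu.to(torch.device("cuda"), non_blocking=True)

torch.cuda.current_stream().wait_stream(copy_stream)  # sync before using batch_gpu
```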

File tree

5 files changed: +61 −17 lines changed


torchrec/distributed/benchmark/benchmark_train_sparsenn.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -54,17 +54,17 @@
 
 @dataclass
 class RunOptions:
-    world_size: int = 4
-    num_batches: int = 20
+    world_size: int = 2
+    num_batches: int = 10
     sharding_type: ShardingType = ShardingType.TABLE_WISE
     input_type: str = "kjt"
     profile: str = ""
 
 
 @dataclass
 class EmbeddingTablesConfig:
-    num_unweighted_features: int = 4
-    num_weighted_features: int = 4
+    num_unweighted_features: int = 100
+    num_weighted_features: int = 100
     embedding_feature_dim: int = 512
 
 def generate_tables(
```
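
A small usage sketch for the defaults changed above; the import path mirrors the file being edited, and the instantiation itself is illustrative only, not taken from the benchmark script.

```python
# Sketch only: constructing the benchmark configs with the new defaults.
from torchrec.distributed.benchmark.benchmark_train_sparsenn import (
    EmbeddingTablesConfig,
    RunOptions,
)

run_options = RunOptions()               # world_size=2, num_batches=10
tables_config = EmbeddingTablesConfig()  # 100 unweighted + 100 weighted features, dim 512
```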

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -867,6 +867,7 @@ def trace_handler(prof) -> None:
             profile_memory=True,
             with_flops=True,
             with_modules=True,
+            with_stack=False,  # usually we don't want to show the entire stack in the trace
             on_trace_ready=trace_handler,
         ) as p:
             for i in range(num_profiles):
```
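
For reference, a hedged sketch of a `torch.profiler` setup using the same flags as the diff above; `trace_handler` and `num_profiles` mirror the surrounding benchmark code and are placeholders here, not the benchmark's actual implementation.

```python
# Sketch of the profiler flags above; with_stack=False keeps traces small and readable.
from torch.profiler import ProfilerActivity, profile


def trace_handler(prof) -> None:
    prof.export_chrome_trace("trace.json")  # placeholder: write the trace somewhere useful


num_profiles = 2
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    with_flops=True,
    with_modules=True,
    with_stack=False,  # full Python stacks bloat the trace; skip them by default
    on_trace_ready=trace_handler,
) as p:
    for _ in range(num_profiles):
        pass  # one benchmark iteration per pass in the real code
```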

torchrec/distributed/test_utils/test_input.py

Lines changed: 29 additions & 2 deletions
```diff
@@ -205,6 +205,7 @@ def generate_local_batches(
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
         all_zeros: bool = False,
+        pin_memory: bool = False,  # pin_memory is needed for the training-job QPS benchmark
     ) -> List["ModelInput"]:
         """
         Returns multi-rank batches (ModelInput) of world_size
@@ -224,6 +225,7 @@
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
                 all_zeros=all_zeros,
+                pin_memory=pin_memory,
             )
             for _ in range(world_size)
         ]
@@ -256,9 +258,15 @@ def generate(
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
         all_zeros: bool = False,
+        pin_memory: bool = False,  # pin_memory is needed for the training-job QPS benchmark
     ) -> "ModelInput":
         """
         Returns a single batch of `ModelInput`
+
+        The `pin_memory()` call on all KJT tensors is important for the training benchmark,
+        and is also a valid option for the prod training scenario: TrainModelInput should be
+        created in pinned memory for a fast transfer to GPU. For more on pin_memory:
+        https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory
         """
         float_features = (
             torch.zeros((batch_size, num_float_features), device=device)
@@ -279,6 +287,7 @@ def generate(
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
                 all_zeros=all_zeros,
+                pin_memory=pin_memory,
             )
             if tables is not None and len(tables) > 0
             else None
@@ -297,6 +306,7 @@ def generate(
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
                 all_zeros=all_zeros,
+                pin_memory=pin_memory,
             )
             if weighted_tables is not None and len(weighted_tables) > 0
             else None
@@ -306,6 +316,9 @@ def generate(
             if all_zeros
             else torch.rand((batch_size,), device=device)
         )
+        if pin_memory:
+            float_features = float_features.pin_memory()
+            label = label.pin_memory()
         return ModelInput(
             float_features=float_features,
             idlist_features=idlist_features,
@@ -404,13 +417,18 @@ def _assemble_kjt(
         device: Optional[torch.device] = None,
         use_offsets: bool = False,
         offsets_dtype: torch.dtype = torch.int64,
+        pin_memory: bool = False,
     ) -> KeyedJaggedTensor:
         """
-
         Assembles a KeyedJaggedTensor (KJT) from the provided per-feature lengths and indices.
 
         This method is used to generate corresponding local_batches and global_batch KJTs.
         It concatenates the lengths and indices for each feature to form a complete KJT.
+
+        The `pin_memory()` call on all KJT tensors is important for the training benchmark,
+        and is also a valid option for the prod training scenario: TrainModelInput should be
+        created in pinned memory for a fast transfer to GPU. For more on pin_memory:
+        https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory
         """
 
         lengths = torch.cat(lengths_per_feature)
@@ -422,6 +440,11 @@ def _assemble_kjt(
                 [torch.tensor([0], device=device), lengths.cumsum(0)]
             ).to(offsets_dtype)
             lengths = None
+        if pin_memory:
+            indices = indices.pin_memory()
+            lengths = lengths.pin_memory() if lengths is not None else None
+            weights = weights.pin_memory() if weights is not None else None
+            offsets = offsets.pin_memory() if offsets is not None else None
         return KeyedJaggedTensor(features, indices, weights, lengths, offsets)
 
     @staticmethod
@@ -440,6 +463,7 @@ def create_standard_kjt(
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
         all_zeros: bool = False,
+        pin_memory: bool = False,
     ) -> KeyedJaggedTensor:
         features, lengths_per_feature, indices_per_feature = (
             ModelInput._create_features_lengths_indices(
@@ -462,6 +486,7 @@ def create_standard_kjt(
             device=device,
             use_offsets=use_offsets,
             offsets_dtype=offsets_dtype,
+            pin_memory=pin_memory,
         )
 
     @staticmethod
@@ -555,14 +580,15 @@ class TdModelInput(ModelInput):
 
 @dataclass
 class TestSparseNNInputConfig:
-    batch_size: int = 1
+    batch_size: int = 8192
     num_float_features: int = 10
     feature_pooling_avg: int = 10
     use_offsets: bool = False
     dev_str: str = ""
     long_kjt_indices: bool = True
    long_kjt_offsets: bool = True
     long_kjt_lengths: bool = True
+    pin_memory: bool = True
 
     def generate_model_input(
         self,
@@ -584,4 +610,5 @@ def generate_model_input(
             indices_dtype=torch.int64 if self.long_kjt_indices else torch.int32,
             offsets_dtype=torch.int64 if self.long_kjt_offsets else torch.int32,
             lengths_dtype=torch.int64 if self.long_kjt_lengths else torch.int32,
+            pin_memory=self.pin_memory,
         )
```
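
A hedged sketch of what the new `pin_memory` path produces: a `KeyedJaggedTensor` whose host tensors are pinned, so the later device copy can be non-blocking. The toy keys and values below are assumptions, not taken from the test code.

```python
# Sketch: build a KJT from pinned host tensors, then copy it to GPU without blocking.
# Requires a CUDA-enabled build, since pin_memory() allocates page-locked memory.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

values = torch.tensor([1, 2, 3, 4], dtype=torch.int64).pin_memory()
lengths = torch.tensor([2, 0, 1, 1], dtype=torch.int64).pin_memory()  # 2 keys x batch of 2
kjt = KeyedJaggedTensor(keys=["f1", "f2"], values=values, lengths=lengths)

kjt_gpu = kjt.to(torch.device("cuda"), non_blocking=True)
```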

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -680,7 +680,7 @@ def copy_batch_to_gpu(
         `self._execute_all_batches=True`, then returns None.
         """
         context = self._create_context()
-        with record_function(f"## copy_batch_to_gpu {self._next_index} ##"):
+        with record_function(f"## copy_batch_to_gpu {context.index} ##"):
             with self._stream_context(self._memcpy_stream):
                 batch = self._next_batch(dataloader_iter)
                 if batch is not None:
@@ -1008,11 +1008,13 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
            context.detached_embedding_tensors,
        ):
            grads = [tensor.grad for tensor in detached_emb_tensors]
-            # Some embeddings may never get used in the final loss computation,
-            # so the grads will be `None`. If we don't exclude these, it will fail
-            # with error: "grad can be implicitly created only for scalar outputs"
-            # Alternatively, if the tensor has only 1 element, pytorch can still
-            # figure out how to do autograd
+            """
+            Some embeddings may never get used in the final loss computation,
+            so the grads will be `None`. If we don't exclude these, it will fail
+            with error: "grad can be implicitly created only for scalar outputs"
+            Alternatively, if the tensor has only 1 element, pytorch can still
+            figure out how to do autograd
+            """
            embs_to_backprop, grads_to_use, invalid_features = [], [], []
            assert len(embedding_features) == len(emb_tensors)
            for features, tensor, grad in zip(embedding_features, emb_tensors, grads):
```
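
A hedged sketch of the pattern the comment above describes (not the pipeline code itself): some detached embedding tensors never receive a gradient from the loss, so the `None` entries must be filtered out before calling `torch.autograd.backward` on the originals. Variable names are illustrative.

```python
import torch

emb_tensors = [torch.randn(4, 8, requires_grad=True) for _ in range(3)]
detached = [t.detach().requires_grad_() for t in emb_tensors]

loss = detached[0].sum() + detached[2].sum()  # detached[1] never reaches the loss
loss.backward()

grads = [t.grad for t in detached]            # grads[1] is None
embs_to_backprop = [t for t, g in zip(emb_tensors, grads) if g is not None]
grads_to_use = [g for g in grads if g is not None]

# Without the filtering, backward() on a non-scalar tensor with a None grad fails with
# "grad can be implicitly created only for scalar outputs".
torch.autograd.backward(embs_to_backprop, grads_to_use)
```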

torchrec/distributed/train_pipeline/utils.py

Lines changed: 19 additions & 5 deletions
```diff
@@ -600,6 +600,7 @@ def __call__(
                 self._stream
             )
         ctx.record_stream(cur_stream)
+
         awaitable = self._context.embedding_a2a_requests.pop(self._name)
         # in case of MC modules
         is_mc_module: bool = isinstance(awaitable, Iterable)
@@ -613,6 +614,24 @@ def __call__(
             embeddings = (
                 awaitable.wait()
             )  # trigger awaitable manually for type checking
+
+        self.detach_embeddings(embeddings=embeddings, cur_stream=cur_stream)
+
+        if is_mc_module:
+            return (LazyNoWait(embeddings), LazyNoWait(remapped_kjts))
+        else:
+            return LazyNoWait(embeddings)
+
+    def detach_embeddings(
+        self,
+        embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
+        cur_stream: torch.Stream,
+    ) -> None:
+        """
+        Detach the grad from the embeddings so that the backward/opt step for the embeddings
+        won't be invoked by loss.backward(); instead, there is a dedicated embedding_backward
+        call in the semi-sync pipeline.
+        """
         tensors = []
         detached_tensors = []
         # in case of EC, embeddings are Dict[str, JaggedTensor]
@@ -650,11 +669,6 @@ def __call__(
         self._context.embedding_features.append([list(embeddings.keys())])
         self._context.detached_embedding_tensors.append(detached_tensors)
 
-        if is_mc_module:
-            return (LazyNoWait(embeddings), LazyNoWait(remapped_kjts))
-        else:
-            return LazyNoWait(embeddings)
-
 
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
     """
```
