Commit: 2024-12-18 nightly release (92b903f)

pytorchbot committed Dec 18, 2024
1 parent e4eb029 · commit deb7122

Showing 3 changed files with 34 additions and 6 deletions.
@@ -2257,6 +2257,7 @@ def setUp(self) -> None:
         torch.autograd.set_detect_anomaly(False)
         self._exit_stack = ExitStack()
         self._exit_stack.enter_context(
+            # type: ignore[attr-defined]
             torch._dynamo.config.patch(
                 optimize_ddp="python_reducer_without_compiled_forward"
             ),
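For context, the pattern this test setup relies on: torch._dynamo.config.patch(...) returns a context manager that restores the previous config value on exit, and ExitStack keeps it active for the test's lifetime. A minimal, hypothetical sketch (the class name and tearDown are illustrative, not part of the diff):

from contextlib import ExitStack

import torch


class ExampleTest:
    def setUp(self) -> None:
        self._exit_stack = ExitStack()
        # The patch stays active until the stack is closed; per the flag name,
        # dynamo uses the Python DDP reducer without compiling the forward.
        self._exit_stack.enter_context(
            torch._dynamo.config.patch(
                optimize_ddp="python_reducer_without_compiled_forward"
            )
        )

    def tearDown(self) -> None:
        self._exit_stack.close()  # restores the previous optimize_ddp value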
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -1002,6 +1002,7 @@ def start_embedding_lookup(
             context,
             source_stream=self._data_dist_stream,
             target_stream=stream,
+            stream_context=self._stream_context,
         )
         event = torch.get_device_module(self._device).Event()
         event.record()
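The hunk above also shows the device-agnostic synchronization pattern: torch.get_device_module(device) resolves to the backend module (torch.cuda for CUDA devices), so Event and stream come from the right backend. A small sketch, assuming a CUDA-capable build:

import torch


def record_completion_event(device: torch.device):
    # torch.get_device_module(torch.device("cuda")) returns the torch.cuda
    # module, so .Event() is backend-specific (torch.cuda.Event here).
    device_module = torch.get_device_module(device)
    event = device_module.Event()
    event.record()  # records on the device's current stream
    return event

# Hypothetical usage:
# event = record_completion_event(torch.device("cuda"))
# event.wait()  # make future work on the current stream wait for it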
38 changes: 32 additions & 6 deletions torchrec/distributed/train_pipeline/utils.py
@@ -11,6 +11,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
+from contextlib import AbstractContextManager
 from dataclasses import dataclass, field

 from itertools import chain
@@ -248,6 +249,21 @@ def recursive_record_stream(
             recursive_record_stream(v, stream)


+class NoOpStream:
+    """No-Op Context manager that takes in a stream"""
+
+    def __init__(self, stream: Optional[torch.Stream]) -> None:
+        self._stream = stream
+
+    def __enter__(self) -> "NoOpStream":
+        """Return `self` upon entering the runtime context."""
+        return self
+
+    # pyre-ignore
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        return None
+
+
 class PipelinedPreproc(torch.nn.Module):
     """
     Wrapper around preproc module found during model graph traversal for sparse data dist
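The NoOpStream class added above gives callers a uniform interface: code can always write `with stream_context(stream):`, and when no dist stream exists the context simply does nothing. A minimal sketch of the fallback (the `maybe_stream` variable is illustrative):

from typing import Optional

import torch

maybe_stream: Optional[torch.Stream] = None  # e.g. no data-dist stream set up

# Mirrors the selection in PipelinedPreproc.__init__ below: fall back to the
# no-op context when there is no stream to switch to.
stream_context = torch.cuda.stream if maybe_stream is not None else NoOpStream

with stream_context(maybe_stream):
    pass  # runs on whatever stream is already current; nothing is switched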
@@ -297,6 +313,17 @@ def __init__(
f"Preproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!"
)

if self._dist_stream:
device: torch.device = self._dist_stream.device
# pyre-ignore
self._stream_context = (
torch.get_device_module(device).stream
if device.type in ["cuda", "mtia"]
else torch.cuda.stream
)
else:
self._stream_context = NoOpStream

@property
def preproc_module(self) -> torch.nn.Module:
return self._preproc_module
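The dispatch added in the __init__ hunk above can be read as a small factory: pick the backend's stream context for cuda/mtia devices, default to torch.cuda.stream otherwise, and fall back to NoOpStream when no dist stream exists. A condensed sketch of the same logic (the function name is illustrative):

from typing import Optional

import torch


def select_stream_context(dist_stream: Optional[torch.Stream]):
    if dist_stream is None:
        # NoOpStream is defined earlier in this diff; it enters and exits
        # without switching any stream.
        return NoOpStream
    device = dist_stream.device
    if device.type in ["cuda", "mtia"]:
        # torch.get_device_module(device).stream is torch.cuda.stream on
        # CUDA and the MTIA equivalent on MTIA.
        return torch.get_device_module(device).stream
    return torch.cuda.stream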
@@ -341,8 +368,7 @@ def forward(self, *input, **kwargs) -> Any:

         with record_function(f"## sdd_input_preproc {self._context.index} ##"):
             # should be no-op as we call this in dist stream
-            # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-            with torch.cuda.stream(self._dist_stream):
+            with self._stream_context(self._dist_stream):
                 res = self._preproc_module(*args, **kwargs)

             # Ensure preproc modules output is safe to use from default stream later
@@ -364,8 +390,7 @@ def forward(self, *input, **kwargs) -> Any:
f"Result of preproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!"
)

# pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
with torch.cuda.stream(self._default_stream):
with self._stream_context(self._default_stream):
# Cache results, only during _start_data_dist
self._context.preproc_fwd_results[self._fqn] = res

@@ -760,10 +785,11 @@ def _start_embedding_lookup(
     context: EmbeddingTrainPipelineContext,
     source_stream: Optional[torch.Stream],
     target_stream: Optional[torch.Stream],
+    # pyre-ignore[2]
+    stream_context: Callable[..., AbstractContextManager[Any, Any]],
 ) -> None:
     module_context = context.module_contexts[module.forward.name]
-    # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-    with torch.cuda.stream(source_stream):
+    with stream_context(source_stream):
         kjt = context.input_dist_tensors_requests[module.forward.name].wait()

     if target_stream is not None:
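Taken together, the new stream_context parameter means _start_embedding_lookup no longer hard-codes torch.cuda.stream: whatever context factory the pipeline selected (a backend stream context or NoOpStream) wraps the wait on the input-dist request. A condensed, hypothetical illustration of the call pattern, not the full function:

from contextlib import AbstractContextManager
from typing import Any, Callable, Optional

import torch


def wait_under_stream(
    request,  # e.g. an awaitable input-dist request with a .wait() method
    source_stream: Optional[torch.Stream],
    stream_context: Callable[..., AbstractContextManager],
) -> Any:
    # Same shape as the diff: the wait happens under the data-dist stream
    # when one exists, and under a no-op context otherwise.
    with stream_context(source_stream):
        return request.wait()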
