2024-12-13 nightly release (3928a1b)
pytorchbot committed Dec 13, 2024
1 parent 5885b3e commit f4e58d4
Showing 9 changed files with 240 additions and 118 deletions.
1 change: 1 addition & 0 deletions examples/retrieval/knn_index.py
@@ -43,6 +43,7 @@ def get_index(
res = faiss.StandardGpuResources()
# pyre-fixme[16]
config = faiss.GpuIndexIVFPQConfig()
# pyre-ignore[16]
index = faiss.GpuIndexIVFPQ(
res,
embedding_dim,
2 changes: 2 additions & 0 deletions torchrec/distributed/comm_ops.py
@@ -119,6 +119,8 @@ def _wait_impl(self) -> W:
"""

ret = self.wait_function.apply(self.pg, self, self.dummy_tensor)
if isinstance(ret, torch.Tensor) and ret.device.type == "cuda":
ret.record_stream(torch.get_device_module(ret.device).current_stream())
self.req = None
self.tensor = None
return ret
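Note on the `comm_ops.py` change: `record_stream` tells the CUDA caching allocator that the awaited tensor is consumed on the caller's current stream, so its memory is not recycled while kernels launched on the producing stream may still be using it. A minimal standalone sketch of the pattern (the function name is illustrative, not a TorchRec API):

```python
import torch

def mark_consumed_on_current_stream(ret: torch.Tensor) -> torch.Tensor:
    # Only accelerator tensors participate in stream-ordered allocation.
    if isinstance(ret, torch.Tensor) and ret.device.type == "cuda":
        # Record the consumer stream so the caching allocator defers reuse
        # of this memory until work queued on that stream has finished.
        ret.record_stream(torch.get_device_module(ret.device).current_stream())
    return ret
```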
5 changes: 4 additions & 1 deletion torchrec/distributed/embedding_types.py
@@ -8,6 +8,7 @@
# pyre-strict

import abc
import copy
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
@@ -499,7 +500,9 @@ def __init__(
shardable_params: Optional[List[str]] = None,
) -> None:
super().__init__()
self._fused_params = fused_params
self._fused_params: Optional[Dict[str, Any]] = (
copy.deepcopy(fused_params) if fused_params is not None else fused_params
)
if not shardable_params:
shardable_params = []
self._shardable_params: List[str] = shardable_params
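The `embedding_types.py` change deep-copies `fused_params` so the sharder owns its configuration and later mutation of the caller's dict cannot alias into it. A small sketch of the aliasing this guards against (the `Sharder` class and parameter values here are illustrative, not TorchRec code):

```python
import copy
from typing import Any, Dict, Optional

class Sharder:
    def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
        # Defensive copy: edits the caller makes after construction do not
        # leak into the sharder (and vice versa).
        self._fused_params: Optional[Dict[str, Any]] = (
            copy.deepcopy(fused_params) if fused_params is not None else None
        )

caller_params = {"learning_rate": 0.01}
sharder = Sharder(caller_params)
caller_params["learning_rate"] = 0.1  # does not touch the sharder's copy
assert sharder._fused_params == {"learning_rate": 0.01}
```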
46 changes: 26 additions & 20 deletions torchrec/distributed/planner/stats.py
@@ -333,32 +333,34 @@ def log(
"Perf (ms)",
"Storage (HBM, DDR)",
"Cache Load Factor",
"Pooling Factor",
"Num Poolings",
"Sum Pooling Factor",
"Sum Num Poolings",
"Num Indices",
"Output",
"Weighing",
"Weighted",
"Sharder",
"Features",
"Emb Dim (CW Dim)",
"Hash Size",
"Ranks",
],
[
"-----",
"----------",
"----------------",
"-----------",
"--------------------",
"-------------------",
"----------------",
"--------------",
"--------",
"----------",
"---------",
"----------",
"------------------",
"-----------",
"-------",
"-----", # FQN
"----------", # Sharding
"----------------", # Compute Kernel
"-----------", # Perf (ms)
"--------------------", # Storage (HBM, DDR)
"-------------------", # Cache Load Factor
"--------------------", # Sum Pooling Factor
"------------------", # Sum Num Poolings
"-------------", # Num Indices
"--------", # Output
"----------", # Weighted
"---------", # Sharder
"----------", # Features
"------------------", # Emb Dim (CW Dim)
"-----------", # Hash Size
"-------", # Ranks
],
]
feat_batch_sizes = [
@@ -404,9 +406,12 @@ def log(
and constraints[so.name].num_poolings
else [NUM_POOLINGS] * len(so.input_lengths)
)
num_indices = str(
round(sum(x * y for x, y in zip(so.input_lengths, num_poolings)), 3)
)
num_poolings = str(round(sum(num_poolings), 3))
output = "pooled" if so.is_pooled else "sequence"
weighing = "weighted" if so.is_weighted else "unweighted"
weighted = "weighted" if so.is_weighted else "unweighted"
sharder = sharder_map.get(get_sharder_name(type(so.module[1])), None)
sharder_name = type(sharder).__name__
num_features = len(so.input_lengths)
@@ -441,8 +446,9 @@ def log(
cache_load_factor,
pooling_factor,
num_poolings,
num_indices,
output,
weighing,
weighted,
sharder_name,
num_features,
embedding_dim,
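The new "Num Indices" column in the planner stats is the per-feature pooling factor weighted by the number of poolings, summed across a table's features, while "Sum Num Poolings" is the plain sum. A worked example with made-up values (not from a real sharding plan):

```python
# Per-feature averages for one embedding table (hypothetical numbers).
input_lengths = [20.0, 5.0, 12.5]  # avg ids per pooling (pooling factor)
num_poolings = [1.0, 2.0, 1.0]     # avg poolings per sample

num_indices = round(sum(x * y for x, y in zip(input_lengths, num_poolings)), 3)
sum_num_poolings = round(sum(num_poolings), 3)
print(num_indices, sum_num_poolings)  # 42.5 4.0
```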
27 changes: 27 additions & 0 deletions torchrec/distributed/tests/test_comm.py
@@ -767,3 +767,30 @@ def test_all_gather_base_pooled(
specify_pg=specify_pg,
gradient_division=gradient_division,
)

@classmethod
def _test_all_gather_base_pooled_cpu(
cls,
rank: int,
world_size: int,
backend: str,
) -> None:
pg = GroupMember.WORLD
if pg is None:
dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
pg = GroupMember.WORLD

device = torch.device(f"cpu")
input_tensor = torch.randn([4, 4], requires_grad=True).to(device)
comm_ops.all_gather_base_pooled(input_tensor, pg).wait()
dist.destroy_process_group()

def test_all_gather_base_pooled_cpu(
self,
) -> None:
self._run_multi_process_test(
world_size=self.WORLD_SIZE,
backend="gloo",
# pyre-ignore [6]
callable=self._test_all_gather_base_pooled_cpu,
)
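The added test runs `all_gather_base_pooled` on CPU over the gloo backend. A self-contained sketch of the same collective using plain `torch.distributed` with a single-rank gloo group (address and port are placeholders; TorchRec's test instead spawns `WORLD_SIZE` processes):

```python
import os
import torch
import torch.distributed as dist

def cpu_all_gather_demo() -> torch.Tensor:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    try:
        local = torch.randn(4, 4)
        gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, local)  # gloo supports all_gather on CPU
        return torch.cat(gathered, dim=0)
    finally:
        dist.destroy_process_group()
```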
@@ -152,6 +152,8 @@ def forward(self, x):
fqn="test_module",
args=[],
context=TrainPipelineContext(),
default_stream=MagicMock(),
dist_stream=MagicMock(),
)
# self-check - we want the state dict be the same between vanilla model and "rewritten model"
self.assertDictEqual(model.state_dict(), rewritten_model.state_dict())
115 changes: 51 additions & 64 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -536,6 +536,7 @@ def _pipeline_model(
model=self._model,
context=context,
dist_stream=self._data_dist_stream,
default_stream=torch.get_device_module(self._device).current_stream(),
batch=batch,
apply_jit=self._apply_jit,
pipelined_forward=pipelined_forward,
@@ -576,15 +577,14 @@ def copy_batch_to_gpu(
StopIteration: if the dataloader iterator is exhausted; unless
`self._execute_all_batches=True`, then returns None.
"""
context = None
context = self._create_context()
with record_function(f"## copy_batch_to_gpu {self._next_index} ##"):
with self._stream_context(self._memcpy_stream):
batch = self._next_batch(dataloader_iter)
if batch is not None:
batch = _to_device(batch, self._device, non_blocking=True)
elif not self._execute_all_batches:
raise StopIteration
context = self._create_context()
return batch, context

def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
@@ -747,25 +747,9 @@ def __init__(
)
self._start_batch = start_batch
self._stash_gradients = stash_gradients
logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")

# use two data streams to support two concurrent batches
self._embedding_odd_stream: Optional[torch.Stream] = (
(torch.get_device_module(self._device).Stream(priority=0))
if device.type in ["cuda", "mtia"]
else None
)
self._embedding_even_stream: Optional[torch.Stream] = (
(torch.get_device_module(self._device).Stream(priority=0))
if device.type in ["cuda", "mtia"]
else None
)
self._overarch_stream: Optional[torch.Stream] = (
(torch.get_device_module(self._device).Stream(priority=-1))
if device.type in ["cuda", "mtia"]
else None
)
self._embedding_odd_streams: List[Optional[torch.Stream]] = []
self._embedding_even_streams: List[Optional[torch.Stream]] = []
self._embedding_streams: List[Optional[torch.Stream]] = []
self._gradients: Dict[str, torch.Tensor] = {}

def _grad_swap(self) -> None:
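The semi-sync pipeline now keeps one flat list of per-module embedding streams instead of separate odd/even lists. A reduced sketch of how such a list could be built, with one side stream per pipelined module on accelerator devices and `None` placeholders on CPU (assumes a recent PyTorch with `torch.get_device_module` and `torch.Stream`; the function name is illustrative):

```python
from typing import List, Optional
import torch

def make_embedding_streams(
    device: torch.device, num_pipelined_modules: int
) -> List[Optional[torch.Stream]]:
    streams: List[Optional[torch.Stream]] = []
    for _ in range(num_pipelined_modules):
        streams.append(
            # One dedicated stream per pipelined embedding module.
            torch.get_device_module(device).Stream(priority=0)
            if device.type in ["cuda", "mtia"]
            else None  # CPU runs have no stream concept to use.
        )
    return streams
```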
@@ -778,12 +762,7 @@ def _grad_swap(self) -> None:
def _init_embedding_streams(self) -> None:

for _ in self._pipelined_modules:
self._embedding_odd_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
if self._device.type in ["cuda", "mtia"]
else None
)
self._embedding_even_streams.append(
self._embedding_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
if self._device.type in ["cuda", "mtia"]
else None
@@ -839,13 +818,9 @@ def is_semi_sync(self) -> bool:
return self.contexts[0].index >= self._start_batch
return False

def _mlp_optimizer_step(self) -> None:
def _mlp_optimizer_step(self, current_batch: int) -> None:
# special case: not all optimizers support optim.step() on null gradients
if (
len(self.batches) >= 1
and self.contexts[0].index == self._start_batch
and self._stash_gradients
):
if current_batch == self._start_batch and self._stash_gradients:
return
self._optimizer.step()

@@ -860,42 +835,56 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
self.contexts[2],
)

losses, output = self._mlp_forward(cast(In, self.batches[0]), self.contexts[0])
batch, context = self.batches[0], self.contexts[0]
is_semi_sync = context.index is not None and context.index >= self._start_batch
iteration: int = context.index or 0
losses, output = self._mlp_forward(cast(In, batch), context)

# After this point, pipelined preproc/module forward won't be called
# so we can advance their contexts to the context of the next batch already
# and also pop batch and context from self.batches and self.contexts
self.dequeue_batch()

# batch no longer needed - delete to free up memory
del batch

# cached preproc fwd results no longer needed - delete to free up memory
del context.preproc_fwd_results

# batch i+3
self.enqueue_batch(dataloader_iter)

if len(self.batches) >= 2 and self.is_semi_sync():
if len(self.batches) >= 1 and is_semi_sync:
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[1], self.contexts[1])
self.start_embedding_lookup(self.batches[0], self.contexts[0])

if len(self.batches) >= 3:
self.wait_sparse_data_dist(self.contexts[2])
if len(self.batches) >= 2:
self.wait_sparse_data_dist(self.contexts[1])

if self._model.training:
with record_function(f"## backward {self.contexts[0].index} ##"):
with record_function(f"## backward {iteration} ##"):
torch.sum(losses, dim=0).backward()
# pyre-ignore [6]
self.embedding_backward(self.contexts[0])
with record_function(f"## emb_backward {iteration} ##"):
# pyre-ignore [6]
self.embedding_backward(context)

with record_function(
f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
):
if self.is_semi_sync() and self._stash_gradients:
del context # context is no longer needed, deleting to free up memory

with record_function(f"## optimizer {iteration - 1} ##"):
if is_semi_sync and self._stash_gradients:
self._grad_swap()
self._mlp_optimizer_step()
self._mlp_optimizer_step(iteration)

with record_function(
f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
):
with record_function(f"## zero_grad {iteration - 1} ##"):
self._optimizer.zero_grad()
else:
del context

if len(self.batches) >= 2 and not self.is_semi_sync():
if len(self.batches) >= 1 and not is_semi_sync:
torch.cuda.synchronize() # needed to avoid race condition
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[1], self.contexts[1])
self.start_embedding_lookup(self.batches[0], self.contexts[0])

self.dequeue_batch()
return output

def _mlp_forward(
@@ -909,14 +898,9 @@ def _mlp_forward(

def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
default_stream = torch.get_device_module(self._device).current_stream()
streams = (
self._embedding_even_streams
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams
)
assert len(context.embedding_features) == len(context.embedding_tensors)
for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
streams,
self._embedding_streams,
context.embedding_tensors,
context.embedding_features,
context.detached_embedding_tensors,
Expand All @@ -939,7 +923,9 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
embs_to_backprop.append(tensor)
grads_to_use.append(grad)
else:
if isinstance(features, Iterable):
if isinstance(features, str):
invalid_features.append(features)
elif isinstance(features, Iterable):
invalid_features.extend(features)
else:
invalid_features.append(features)
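The reordered checks in `embedding_backward` matter because a Python `str` is itself `Iterable`; without the explicit string case first, a single feature name would be spread into characters by `extend`. A small sketch of the distinction (feature values here are hypothetical):

```python
from collections.abc import Iterable
from typing import Any, List

def collect_invalid(features: Any, invalid_features: List[Any]) -> None:
    if isinstance(features, str):
        invalid_features.append(features)   # keep "user_id" whole
    elif isinstance(features, Iterable):
        invalid_features.extend(features)
    else:
        invalid_features.append(features)

bad: List[Any] = []
collect_invalid("user_id", bad)     # -> ["user_id"], not ["u", "s", "e", ...]
collect_invalid(["f1", "f2"], bad)  # -> ["user_id", "f1", "f2"]
```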
@@ -1012,13 +998,14 @@ def start_embedding_lookup(
batch, context, torch.get_device_module(self._device).current_stream()
)
for i, module in enumerate(self._pipelined_modules):
stream = (
self._embedding_even_streams[i]
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams[i]
)
stream = self._embedding_streams[i]
with self._stream_context(stream):
_start_embedding_lookup(module, context, stream)
_start_embedding_lookup(
module,
context,
source_stream=self._data_dist_stream,
target_stream=stream,
)
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)
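`start_embedding_lookup` records an event after launching the lookup on its side stream; downstream code can make the consumer stream wait on that event instead of synchronizing the whole device. A generic sketch of that event handshake with plain CUDA streams (not the pipeline's context bookkeeping):

```python
import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    x = torch.randn(1024, 64, device="cuda")

    # Make the side stream wait for x to be produced on the default stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        y = x @ x.t()               # stand-in for the embedding lookup
        done = torch.cuda.Event()
        done.record()               # records on side_stream (current here)

    # Consumer (default stream) waits on the event before reading `y`.
    torch.cuda.current_stream().wait_event(done)
    z = y.sum()
```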