Skip to content
22 changes: 18 additions & 4 deletions examples/commons/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np
import torch
import torch.cuda.nvtx as nvtx
import torch.fx
import torch.nn as nn
from commons.utils.nvtx_op import output_nvtx_hook, register_setter_and_getter_for_nvtx
Expand Down Expand Up @@ -394,14 +395,27 @@ def forward(self, kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
and self._data_parallel_embedding_collection is None
), "either model_parallel_embedding_collection or data_parallel_embedding_collection must be not None"
embeddings: Dict[str, JaggedTensor] = {}

mp_result = None
if self._model_parallel_embedding_collection is not None:
mp_embeddings_awaitables = self._model_parallel_embedding_collection(kjt)
embeddings = {**embeddings, **(mp_embeddings_awaitables.wait())}
mp_result = self._model_parallel_embedding_collection(kjt)

dp_result = None
if self._data_parallel_embedding_collection is not None:
with torch.cuda.stream(self._side_stream):
dp_embeddings = self._data_parallel_embedding_collection(kjt)
dp_result = self._data_parallel_embedding_collection(kjt)

if mp_result is not None:
nvtx.range_push("MP Embedding")
embeddings = {**embeddings, **mp_result.wait()}
nvtx.range_pop()

if dp_result is not None:
nvtx.range_push("DP Embedding")
torch.cuda.current_stream().wait_stream(self._side_stream)
embeddings = {**embeddings, **dp_embeddings}
embeddings = {**embeddings, **dp_result}
nvtx.range_pop()

return embeddings

def export_local_embedding(self, table_name: str) -> Tuple[np.ndarray, np.ndarray]:
Expand Down
6,244 changes: 6,244 additions & 0 deletions examples/hstu/trace.json

Large diffs are not rendered by default.