Commit

2024-12-21 nightly release (efca1d6)

pytorchbot committed Dec 21, 2024
1 parent db292ff commit 8e419b5
Showing 11 changed files with 268 additions and 70 deletions.
11 changes: 7 additions & 4 deletions torchrec/distributed/comm_ops.py
@@ -107,10 +107,13 @@ def __init__(self, pg: dist.ProcessGroup, device: torch.device) -> None:
# This dummy tensor is used to build the autograd graph between
# CommOp-Req and CommOp-Await. The actual forward tensors, and backwards gradient tensors
# are stored in self.tensor
self.dummy_tensor: torch.Tensor = torch.empty(
1,
requires_grad=True,
device=device,
# torch.zeros is a call_function, not placeholder, hence fx.trace incompatible.
self.dummy_tensor: torch.Tensor = torch.zeros_like(
torch.empty(
1,
requires_grad=True,
device=device,
)
)

def _wait_impl(self) -> W:
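
For context on the fx-tracing comment above, a minimal sketch (not part of this commit) of how torch.fx classifies nodes: module inputs are recorded as placeholder nodes, while tensors created inside forward (for example via torch.zeros) are recorded as call_function nodes.

import torch
import torch.fx


class Tiny(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is traced as a placeholder node; torch.zeros(1) as a call_function node
        return x + torch.zeros(1)


gm = torch.fx.symbolic_trace(Tiny())
for node in gm.graph.nodes:
    print(node.op, node.target)
# roughly: placeholder x / call_function torch.zeros / call_function add / output output
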
33 changes: 27 additions & 6 deletions torchrec/distributed/embedding.py
@@ -28,6 +28,7 @@
import torch
from torch import distributed as dist, nn
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
from torch.distributed._tensor import DTensor
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_sharding import (
@@ -102,9 +103,27 @@
EC_INDEX_DEDUP: bool = False


def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
# pyre-ignore
return ps.sharding_spec.shards[0].placement.device().type
def get_device_from_parameter_sharding(
ps: ParameterSharding,
) -> TypeUnion[str, Tuple[str, ...]]:
"""
Returns a list of device types per shard if the table is sharded across different device types,
else returns a single device type for the table parameter
"""
if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
raise ValueError("Expected EnumerableShardingSpec as input to the function")

device_type_list: Tuple[str, ...] = tuple(
# pyre-fixme[16]: `Optional` has no attribute `device`
[shard.placement.device().type for shard in ps.sharding_spec.shards]
)
if len(set(device_type_list)) == 1:
return device_type_list[0]
else:
assert (
ps.sharding_type == "row_wise"
), "Only row_wise sharding supports sharding across multiple device types for a table"
return device_type_list


def set_ec_index_dedup(val: bool) -> None:
@@ -248,13 +267,13 @@ def create_sharding_infos_by_sharding_device_group(
module: EmbeddingCollectionInterface,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
fused_params: Optional[Dict[str, Any]],
) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
) -> Dict[Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:

if fused_params is None:
fused_params = {}

sharding_type_device_group_to_sharding_infos: Dict[
Tuple[str, str], List[EmbeddingShardingInfo]
Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
] = {}
# state_dict returns parameter.Tensor, which loses parameter level attributes
parameter_by_name = dict(module.named_parameters())
@@ -280,7 +299,9 @@ def create_sharding_infos_by_sharding_device_group(
assert param_name in parameter_by_name or param_name in state_dict
param = parameter_by_name.get(param_name, state_dict[param_name])

device_group = get_device_from_parameter_sharding(parameter_sharding)
device_group: TypeUnion[str, Tuple[str, ...]] = (
get_device_from_parameter_sharding(parameter_sharding)
)
if (
parameter_sharding.sharding_type,
device_group,
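
A self-contained sketch of the device-group logic this file now uses (hypothetical data, plain Python rather than the torchrec types): shards that all share one device type collapse to a single string, while a row-wise table spanning device types keeps the full tuple as the grouping key.

from typing import Tuple, Union


def device_group(shard_device_types: Tuple[str, ...], sharding_type: str) -> Union[str, Tuple[str, ...]]:
    # Mirrors get_device_from_parameter_sharding: collapse when homogeneous,
    # otherwise only row_wise sharding may span multiple device types.
    if len(set(shard_device_types)) == 1:
        return shard_device_types[0]
    assert sharding_type == "row_wise", (
        "Only row_wise sharding supports sharding across multiple device types for a table"
    )
    return shard_device_types


print(device_group(("cuda", "cuda"), "table_wise"))  # "cuda"
print(device_group(("cpu", "cuda"), "row_wise"))     # ("cpu", "cuda")
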
15 changes: 14 additions & 1 deletion torchrec/distributed/embedding_lookup.py
@@ -677,25 +677,30 @@ def __init__(
grouped_configs: List[GroupedEmbeddingConfig],
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
shard_index: Optional[int] = None,
) -> None:
# TODO rename to _create_embedding_kernel
def _create_lookup(
config: GroupedEmbeddingConfig,
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
shard_index: Optional[int] = None,
) -> BaseBatchedEmbedding[
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
]:
return QuantBatchedEmbedding(
config=config,
device=device,
fused_params=fused_params,
shard_index=shard_index,
)

super().__init__()
self._emb_modules: nn.ModuleList = nn.ModuleList()
for config in grouped_configs:
self._emb_modules.append(_create_lookup(config, device, fused_params))
self._emb_modules.append(
_create_lookup(config, device, fused_params, shard_index)
)

self._feature_splits: List[int] = [
config.num_features() for config in grouped_configs
@@ -1076,6 +1081,7 @@ def __init__(
world_size: int,
fused_params: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
super().__init__()
self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []
@@ -1089,11 +1095,18 @@ def __init__(
"meta" if device is not None and device.type == "meta" else "cuda"
)
for rank in range(world_size):
# propagate shard index to get the correct runtime_device based on shard metadata
# in case of heterogeneous sharding of a single table across different device types
shard_index = (
rank if isinstance(device_type_from_sharding_infos, tuple) else None
)
device = rank_device(device_type, rank)
self._embedding_lookups_per_rank.append(
MetaInferGroupedEmbeddingsLookup(
grouped_configs=grouped_configs_per_rank[rank],
device=rank_device(device_type, rank),
fused_params=fused_params,
shard_index=shard_index,
)
)

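
A small sketch (hypothetical values) of the shard_index propagation above: when a table is sharded across device types, each per-rank lookup is given its rank as the shard index so it can resolve the correct runtime device from shard metadata; otherwise the index stays None.

from typing import Optional, Tuple, Union

device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = ("cpu", "cuda")
world_size = 2

for rank in range(world_size):
    shard_index: Optional[int] = (
        rank if isinstance(device_type_from_sharding_infos, tuple) else None
    )
    # each per-rank MetaInferGroupedEmbeddingsLookup would receive this index
    print(rank, shard_index)
# 0 0
# 1 1
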
109 changes: 75 additions & 34 deletions torchrec/distributed/quant_embedding.py
@@ -17,6 +17,7 @@
IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torch import nn
from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
from torchrec.distributed.embedding import (
create_sharding_infos_by_sharding_device_group,
EmbeddingShardingInfo,
@@ -56,8 +57,11 @@
dtype_to_data_type,
EmbeddingConfig,
)
from torchrec.quant.embedding_modules import (
from torchrec.modules.utils import (
_fx_trec_get_feature_length,
_get_batching_hinted_output,
)
from torchrec.quant.embedding_modules import (
EmbeddingCollection as QuantEmbeddingCollection,
MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
)
@@ -66,6 +70,7 @@

torch.fx.wrap("len")
torch.fx.wrap("_get_batching_hinted_output")
torch.fx.wrap("_fx_trec_get_feature_length")

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
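
As a side note on the torch.fx.wrap calls above, a minimal sketch (independent of torchrec) of what wrapping does: the wrapped function is treated as a leaf during symbolic tracing, so it appears as a single call in the traced graph instead of being inlined.

import torch
import torch.fx


@torch.fx.wrap
def helper(x: torch.Tensor) -> torch.Tensor:
    return x * 2


class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return helper(x) + 1


gm = torch.fx.symbolic_trace(M())
print(gm.code)  # helper(x) shows up as one call node rather than an inlined x * 2
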
@@ -83,14 +88,32 @@ def record_stream(self, stream: torch.Stream) -> None:
ctx.record_stream(stream)


def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
# pyre-ignore
return ps.sharding_spec.shards[0].placement.device().type
def get_device_from_parameter_sharding(
ps: ParameterSharding,
) -> Union[str, Tuple[str, ...]]:
"""
Returns a list of device types per shard if the table is sharded across different device types,
else returns a single device type for the table parameter
"""
if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
raise ValueError("Expected EnumerableShardingSpec as input to the function")

device_type_list: Tuple[str, ...] = tuple(
# pyre-fixme[16]: `Optional` has no attribute `device`
[shard.placement.device().type for shard in ps.sharding_spec.shards]
)
if len(set(device_type_list)) == 1:
return device_type_list[0]
else:
assert (
ps.sharding_type == "row_wise"
), "Only row_wise sharding supports sharding across multiple device types for a table"
return device_type_list


def get_device_from_sharding_infos(
emb_shard_infos: List[EmbeddingShardingInfo],
) -> str:
) -> Union[str, Tuple[str, ...]]:
res = list(
{
get_device_from_parameter_sharding(ps.param_sharding)
@@ -101,6 +124,13 @@ def get_device_from_sharding_infos(
return res[0]


def get_device_for_first_shard_from_sharding_infos(
emb_shard_infos: List[EmbeddingShardingInfo],
) -> str:
device_type = get_device_from_sharding_infos(emb_shard_infos)
return device_type[0] if isinstance(device_type, tuple) else device_type


def create_infer_embedding_sharding(
sharding_type: str,
sharding_infos: List[EmbeddingShardingInfo],
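
A hedged sketch of the first-shard helper defined above (hypothetical inputs, plain values instead of EmbeddingShardingInfo): when the sharding spans device types, the sharder keys its env lookup off the first shard's device type.

from typing import Tuple, Union


def first_shard_device(device_type: Union[str, Tuple[str, ...]]) -> str:
    # mirrors get_device_for_first_shard_from_sharding_infos
    return device_type[0] if isinstance(device_type, tuple) else device_type


print(first_shard_device(("cpu", "cuda")))  # "cpu"
print(first_shard_device("cuda"))           # "cuda"
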
@@ -112,8 +142,8 @@ def create_infer_embedding_sharding(
List[torch.Tensor],
List[torch.Tensor],
]:
device_type_from_sharding_infos: str = get_device_from_sharding_infos(
sharding_infos
device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
get_device_from_sharding_infos(sharding_infos)
)

if device_type_from_sharding_infos in ["cuda", "mtia"]:
@@ -132,7 +162,9 @@ def create_infer_embedding_sharding(
raise ValueError(
f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
)
elif device_type_from_sharding_infos == "cpu":
elif device_type_from_sharding_infos == "cpu" or isinstance(
device_type_from_sharding_infos, tuple
):
if sharding_type == ShardingType.ROW_WISE.value:
return InferRwSequenceEmbeddingSharding(
sharding_infos=sharding_infos,
@@ -173,17 +205,6 @@ def _get_unbucketize_tensor_via_length_alignment(
return bucketize_permute_tensor


@torch.fx.wrap
def _fx_trec_get_feature_length(
features: KeyedJaggedTensor, embedding_names: List[str]
) -> torch.Tensor:
torch._assert(
len(embedding_names) == len(features.keys()),
"embedding output and features mismatch",
)
return features.lengths()


def _construct_jagged_tensors_tw(
embeddings: List[torch.Tensor],
embedding_names_per_rank: List[List[str]],
@@ -327,6 +348,7 @@ def _construct_jagged_tensors(
rw_feature_length_after_bucketize: Optional[torch.Tensor],
cw_features_to_permute_indices: Dict[str, torch.Tensor],
key_to_feature_permuted_coordinates: Dict[str, torch.Tensor],
device_type: Union[str, Tuple[str, ...]],
) -> Dict[str, JaggedTensor]:

# Validating sharding type and parameters
@@ -345,15 +367,24 @@
features_before_input_dist_length = _fx_trec_get_feature_length(
features_before_input_dist, embedding_names
)
embeddings = [
_get_batching_hinted_output(
_fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]),
embeddings[i],
)
for i in range(len(embedding_names_per_rank))
]
input_embeddings = []
for i in range(len(embedding_names_per_rank)):
if isinstance(device_type, tuple) and device_type[i] != "cpu":
# batching hint is already propagated and passed for this case
# upstream
input_embeddings.append(embeddings[i])
else:
input_embeddings.append(
_get_batching_hinted_output(
_fx_trec_get_feature_length(
features[i], embedding_names_per_rank[i]
),
embeddings[i],
)
)

return _construct_jagged_tensors_rw(
embeddings,
input_embeddings,
embedding_names,
features_before_input_dist_length,
features_before_input_dist.values() if need_indices else None,
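
A minimal mirror (hypothetical inputs) of the per-rank branch above: with a heterogeneous device tuple, only the cpu shards still need the batching hint applied here, since non-cpu shards already had it applied upstream.

from typing import List, Tuple, Union


def needs_batching_hint(device_type: Union[str, Tuple[str, ...]], num_ranks: int) -> List[bool]:
    # True means _get_batching_hinted_output is still applied to that rank's embeddings
    return [
        not (isinstance(device_type, tuple) and device_type[i] != "cpu")
        for i in range(num_ranks)
    ]


print(needs_batching_hint(("cpu", "cuda"), 2))  # [True, False]
print(needs_batching_hint("cpu", 2))            # [True, True]
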
@@ -437,13 +468,13 @@ def __init__(
self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()

self._sharding_type_device_group_to_sharding_infos: Dict[
Tuple[str, str], List[EmbeddingShardingInfo]
Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
] = create_sharding_infos_by_sharding_device_group(
module, table_name_to_parameter_sharding, fused_params
)

self._sharding_type_device_group_to_sharding: Dict[
Tuple[str, str],
Tuple[str, Union[str, Tuple[str, ...]]],
EmbeddingSharding[
InferSequenceShardingContext,
InputDistOutputs,
@@ -457,7 +488,11 @@ def __init__(
(
env
if not isinstance(env, Dict)
else env[get_device_from_sharding_infos(embedding_configs)]
else env[
get_device_for_first_shard_from_sharding_infos(
embedding_configs
)
]
),
device if get_propogate_device() else None,
)
@@ -580,7 +615,7 @@ def tbes_configs(

def sharding_type_device_group_to_sharding_infos(
self,
) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:
return self._sharding_type_device_group_to_sharding_infos

def embedding_configs(self) -> List[EmbeddingConfig]:
@@ -714,6 +749,9 @@ def input_dist(
unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
bucket_mapping_tensor=bucket_mapping_tensor_list[i],
bucketized_length=bucketized_length_list[i],
embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[
i
],
)
)
return input_dist_result_list
@@ -796,7 +834,7 @@ def output_jt_dict(
) -> Dict[str, JaggedTensor]:
jt_dict_res: Dict[str, JaggedTensor] = {}
for (
(sharding_type, _),
(sharding_type, device_type),
emb_sharding,
features_sharding,
embedding_names,
@@ -844,6 +882,7 @@ def output_jt_dict(
),
cw_features_to_permute_indices=self._features_to_permute_indices,
key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates,
device_type=device_type,
)
for embedding_name in embedding_names:
jt_dict_res[embedding_name] = jt_dict[embedding_name]
@@ -872,7 +911,9 @@ def create_context(self) -> EmbeddingCollectionContext:
return EmbeddingCollectionContext(sharding_contexts=[])

@property
def shardings(self) -> Dict[Tuple[str, str], FeatureShardingMixIn]:
def shardings(
self,
) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]:
# pyre-ignore [7]
return self._sharding_type_device_group_to_sharding

@@ -965,7 +1006,7 @@ def __init__(
self,
input_feature_names: List[str],
sharding_type_device_group_to_sharding: Dict[
Tuple[str, str],
Tuple[str, Union[str, Tuple[str, ...]]],
EmbeddingSharding[
InferSequenceShardingContext,
InputDistOutputs,
