From 94720d532a53b7b6f573ccb70242efd965bdc9e5 Mon Sep 17 00:00:00 2001
From: Joe Wang
Date: Thu, 24 Apr 2025 12:05:03 -0700
Subject: [PATCH] Add virtual size support for ShardedTensor

Differential Revision: D73567630
---
 torchrec/distributed/embedding_kernel.py | 133 +++++++++++++++++++++--
 1 file changed, 121 insertions(+), 12 deletions(-)

diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py
index f3bb60619..bf40b8cdf 100644
--- a/torchrec/distributed/embedding_kernel.py
+++ b/torchrec/distributed/embedding_kernel.py
@@ -8,6 +8,7 @@
 # pyre-strict
 
 import abc
+import copy
 import logging
 from collections import defaultdict, OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -26,7 +27,12 @@
     ShardedEmbeddingTable,
 )
 from torchrec.distributed.shards_wrapper import LocalShardsWrapper
-from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
+from torchrec.distributed.types import (
+    Shard,
+    ShardedTensor,
+    ShardedTensorMetadata,
+    ShardMetadata,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -56,6 +62,100 @@ def config(self) -> GroupedEmbeddingConfig:
         pass
 
 
+def manually_craft_local_metadata(
+    local_metadata: ShardMetadata,
+    param: Union[torch.Tensor, PartiallyMaterializedTensor],
+    my_rank: int,
+) -> None:
+    local_metadata.shard_sizes = list(param.size())
+    local_metadata.shard_offsets = [
+        my_rank if dim == 0 else 0 for dim in range(len(param.size()))
+    ]
+
+
+def manually_crafted_global_metadata(
+    metadata: ShardedTensorMetadata,
+    my_rank: int,
+    param: Union[torch.Tensor, PartiallyMaterializedTensor],
+) -> None:
+    # Update tensor properties from the local tensor; these should be identical across all ranks.
+    metadata.tensor_properties.dtype = param.dtype
+    metadata.tensor_properties.requires_grad = param.requires_grad
+
+    # Manually craft the metadata so that every rank other than this one appears to
+    # hold only 1 row; this is enough to pass ShardedTensor's shard-overlap checks.
+    # NOTE: this currently only works for row-wise sharding.
+    fake_total_rows = param.size()[0] + len(metadata.shards_metadata) - 1
+    metadata.size = torch.Size(
+        [
+            fake_total_rows if dim == 0 else param.size(dim)
+            for dim in range(len(param.size()))
+        ]
+    )
+
+    for rank, shard_metadata in enumerate(metadata.shards_metadata):
+        if rank < my_rank:
+            shard_metadata.shard_sizes = [
+                1 if dim == 0 else param.size(dim) for dim in range(len(param.size()))
+            ]
+            shard_metadata.shard_offsets = [
+                rank if dim == 0 else 0 for dim in range(len(param.size()))
+            ]
+        elif rank == my_rank:
+            manually_craft_local_metadata(shard_metadata, param, my_rank)
+        else:
+            shard_metadata.shard_sizes = [
+                1 if dim == 0 else param.size(dim) for dim in range(len(param.size()))
+            ]
+            shard_metadata.shard_offsets = [
+                rank + param.size(0) - 1 if dim == 0 else 0
+                for dim in range(len(param.size()))
+            ]
+
+
+def create_virtual_sharded_tensors(
+    embedding_tables: List[ShardedEmbeddingTable],
+    params: Union[List[torch.Tensor], List[PartiallyMaterializedTensor]],
+    pg: Optional[dist.ProcessGroup] = None,
+    prefix: str = "",
+) -> List[ShardedTensor]:
+    key_to_local_shards: Dict[str, List[Shard]] = defaultdict(list)
+    key_to_global_metadata: Dict[str, ShardedTensorMetadata] = {}
+
+    def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
+        return prefix + f"{embedding_table.name}"
+
+    my_rank = dist.get_rank()
+    for embedding_table, param in zip(embedding_tables, params):
+        key = get_key_from_embedding_table(embedding_table)
+        assert embedding_table.kv_zch
+
+        assert embedding_table.global_metadata is not None and pg is not None
+        global_metadata = copy.deepcopy(embedding_table.global_metadata)
+        manually_crafted_global_metadata(global_metadata, my_rank, param)
+        key_to_global_metadata[key] = global_metadata
+
+        assert embedding_table.local_metadata is not None
+        local_metadata = copy.deepcopy(embedding_table.local_metadata)
+        manually_craft_local_metadata(local_metadata, param, my_rank)
+
+        key_to_local_shards[key].append(Shard(param, local_metadata))
+
+    result: List[ShardedTensor] = []
+    if pg is not None:
+        for key in key_to_local_shards:
+            global_metadata = key_to_global_metadata[key]
+            result.append(
+                ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=key_to_local_shards[key],
+                    sharded_tensor_metadata=global_metadata,
+                    process_group=pg,
+                    skip_tensor_size_check=True,
+                )
+            )
+    return result
+
+
 def get_state_dict(
     embedding_tables: List[ShardedEmbeddingTable],
     params: Union[
@@ -68,6 +168,7 @@
     pg: Optional[dist.ProcessGroup] = None,
     destination: Optional[Dict[str, Any]] = None,
     prefix: str = "",
+    use_virtual_size: bool = False,
 ) -> Dict[str, Any]:
     if destination is None:
         destination = OrderedDict()
@@ -84,11 +185,15 @@
     # pyre-ignore[33]
     key_to_local_tensor_shards: Dict[str, List[Any]] = defaultdict(list)
 
+    # Validate the function input: use_virtual_size must match each table's kv_zch flag.
+    for emb_table in embedding_tables:
+        assert use_virtual_size == emb_table.kv_zch
+
     def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
         return prefix + f"{embedding_table.name}.weight"
 
     for embedding_table, param in zip(embedding_tables, params):
-        key = get_key_from_embedding_table(embedding_table)
+        weights_key = get_key_from_embedding_table(embedding_table)
         is_quant = embedding_table.compute_kernel in [
             EmbeddingComputeKernel.QUANT,
             EmbeddingComputeKernel.QUANT_UVM,
@@ -103,17 +208,18 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str
             qbias = param[2]
             param = param[0]
 
-        assert embedding_table.local_rows == param.size(  # pyre-ignore[16]
-            0
-        ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}"  # pyre-ignore[16]
+        if not embedding_table.kv_zch:
+            assert embedding_table.local_rows == param.size(  # pyre-ignore[16]
+                0
+            ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}"  # pyre-ignore[16]
 
         if qscale is not None:
             assert embedding_table.local_cols == param.size(1)  # pyre-ignore[16]
 
         if embedding_table.dtensor_metadata is not None and pg is not None:
             # DTensor path
-            key_to_dtensor_metadata[key] = embedding_table.dtensor_metadata
-            key_to_local_tensor_shards[key].append(
+            key_to_dtensor_metadata[weights_key] = embedding_table.dtensor_metadata
+            key_to_local_tensor_shards[weights_key].append(
                 [
                     param,
                     embedding_table.local_metadata.shard_offsets,  # pyre-ignore[16]
@@ -127,21 +233,23 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str
             embedding_table.global_metadata.tensor_properties.requires_grad = (
                 param.requires_grad  # pyre-ignore[16]
             )
-            key_to_global_metadata[key] = embedding_table.global_metadata
+            key_to_global_metadata[weights_key] = embedding_table.global_metadata
 
-            key_to_local_shards[key].append(
+            # For KV ZCH cases we use the virtual space; the logic is the same as for non-KV-ZCH cases.
+            key_to_local_shards[weights_key].append(
                 # pyre-fixme[6]: For 1st argument expected `Tensor` but got
                 #  `Union[Module, Tensor]`.
                 # pyre-fixme[6]: For 2nd argument expected `ShardMetadata` but got
                 #  `Optional[ShardMetadata]`.
                 Shard(param, embedding_table.local_metadata)
             )
+
         else:
-            destination[key] = param
+            destination[weights_key] = param
             if qscale is not None:
-                destination[f"{key}_qscale"] = qscale
+                destination[f"{weights_key}_qscale"] = qscale
             if qbias is not None:
-                destination[f"{key}_qbias"] = qbias
+                destination[f"{weights_key}_qbias"] = qbias
 
     if pg is not None:
         # Populate the remaining destinations that have a global metadata
@@ -152,6 +260,7 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str
                     local_shards=key_to_local_shards[key],
                     sharded_tensor_metadata=global_metadata,
                     process_group=pg,
+                    skip_tensor_size_check=use_virtual_size,
                 )
             )
         # DTensor path
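
A minimal sketch (not part of the diff above) of the fake row-wise layout that manually_crafted_global_metadata produces, assuming a hypothetical world size of 4, my_rank = 2, and a local shard of shape [10, 4]. The helper name fake_row_wise_layout is made up for illustration; it only mirrors the patch's shard size/offset arithmetic with plain Python lists so the resulting layout can be checked by hand.

# Illustrative sketch; mirrors the row-wise arithmetic of
# manually_craft_local_metadata / manually_crafted_global_metadata.
from typing import List, Tuple


def fake_row_wise_layout(
    local_shape: Tuple[int, int], my_rank: int, world_size: int
) -> Tuple[List[int], List[List[int]], List[List[int]]]:
    local_rows, cols = local_shape
    # Every rank other than my_rank is faked as holding exactly 1 row, so the
    # shards tile [0, fake_total_rows) without overlapping.
    fake_total_rows = local_rows + world_size - 1
    sizes: List[List[int]] = []
    offsets: List[List[int]] = []
    for rank in range(world_size):
        if rank < my_rank:
            sizes.append([1, cols])
            offsets.append([rank, 0])
        elif rank == my_rank:
            sizes.append([local_rows, cols])
            offsets.append([my_rank, 0])
        else:
            sizes.append([1, cols])
            offsets.append([rank + local_rows - 1, 0])
    return [fake_total_rows, cols], sizes, offsets


if __name__ == "__main__":
    global_size, sizes, offsets = fake_row_wise_layout((10, 4), my_rank=2, world_size=4)
    print(global_size)  # [13, 4]
    print(sizes)        # [[1, 4], [1, 4], [10, 4], [1, 4]]
    print(offsets)      # [[0, 0], [1, 0], [2, 0], [12, 0]]

The four shards cover rows [0, 13) contiguously, which is what lets the ShardedTensor overlap validation accept the virtually sized tensor while skip_tensor_size_check bypasses the global size check.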