2024-12-11 nightly release (33349ec)
pytorchbot committed Dec 11, 2024
1 parent 876c8fc commit 818174f
Showing 15 changed files with 511 additions and 35 deletions.
55 changes: 47 additions & 8 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -46,15 +46,18 @@
PartiallyMaterializedTensor,
)
from torch import nn
from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
from torchrec.distributed.comm import get_local_rank, get_node_group_size
from torchrec.distributed.composable.table_batched_embedding_slice import (
TableBatchedEmbeddingSlice,
)
from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
from torchrec.distributed.embedding_types import (
compute_kernel_to_embedding_location,
DTensorMetadata,
GroupedEmbeddingConfig,
)
from torchrec.distributed.shards_wrapper import LocalShardsWrapper
from torchrec.distributed.types import (
Shard,
ShardedTensor,
@@ -213,6 +216,7 @@ class ShardParams:
optimizer_states: List[Optional[Tuple[torch.Tensor]]]
local_metadata: List[ShardMetadata]
embedding_weights: List[torch.Tensor]
dtensor_metadata: List[DTensorMetadata]

def get_optimizer_single_value_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
@@ -389,7 +393,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
continue
if table_config.name not in table_to_shard_params:
table_to_shard_params[table_config.name] = ShardParams(
optimizer_states=[], local_metadata=[], embedding_weights=[]
optimizer_states=[],
local_metadata=[],
embedding_weights=[],
dtensor_metadata=[],
)
optimizer_state_values = None
if optimizer_states:
@@ -410,6 +417,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
table_to_shard_params[table_config.name].local_metadata.append(
local_metadata
)
table_to_shard_params[table_config.name].dtensor_metadata.append(
table_config.dtensor_metadata
)
table_to_shard_params[table_config.name].embedding_weights.append(weight)

seen_tables = set()
@@ -474,7 +484,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
# pyre-ignore
def get_sharded_optim_state(
momentum_idx: int, state_key: str
) -> ShardedTensor:
) -> Union[ShardedTensor, DTensor]:
assert momentum_idx > 0
momentum_local_shards: List[Shard] = []
optimizer_sharded_tensor_metadata: ShardedTensorMetadata
@@ -528,12 +538,41 @@ def get_sharded_optim_state(
)
)

# TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=momentum_local_shards,
sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
process_group=self._pg,
)
# Convert optimizer state to DTensor if enabled
if table_config.dtensor_metadata:
# if the state is row-wise (1-D), we use Shard(0), regardless of how the table is sharded
if optim_state.dim() == 1:
stride = (1,)
placements = (
(Replicate(), DTensorShard(0))
if table_config.dtensor_metadata.mesh.ndim == 2
else (DTensorShard(0),)
)
else:
stride = table_config.dtensor_metadata.stride
placements = table_config.dtensor_metadata.placements

return DTensor.from_local(
local_tensor=LocalShardsWrapper(
local_shards=[x.tensor for x in momentum_local_shards],
local_offsets=[ # pyre-ignore[6]
x.metadata.shard_offsets
for x in momentum_local_shards
],
),
device_mesh=table_config.dtensor_metadata.mesh,
placements=placements,
shape=optimizer_sharded_tensor_metadata.size,
stride=stride,
run_check=False,
)
else:
# TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=momentum_local_shards,
sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
process_group=self._pg,
)

num_states: int = min(
# pyre-ignore
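For context, a minimal single-process sketch of the conversion pattern introduced above: a row-wise optimizer state shard is wrapped in LocalShardsWrapper and turned into a DTensor with DTensor.from_local. The gloo process group, the 1-element CPU device mesh, and the tensor sizes are illustrative assumptions, not part of the commit.

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh
from torchrec.distributed.shards_wrapper import LocalShardsWrapper

# single-process setup (assumption): gloo backend, CPU device mesh of size 1
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# one local 1-D shard of a row-wise optimizer state, starting at row offset 0
local_state = torch.rand(8)
wrapped = LocalShardsWrapper(local_shards=[local_state], local_offsets=[[0]])

# row-wise state is placed as Shard(0), regardless of how the table itself is sharded
dt = DTensor.from_local(
    local_tensor=wrapped,
    device_mesh=mesh,
    placements=(Shard(0),),
    shape=torch.Size([8]),  # with world_size 1 the global size equals the local size
    stride=(1,),
    run_check=False,
)
print(dt.shape, dt.placements)
dist.destroy_process_group()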
52 changes: 39 additions & 13 deletions torchrec/distributed/shards_wrapper.py
@@ -68,10 +68,15 @@ def __new__(

# we calculate the total tensor size by "concat" on the second tensor dimension
cat_tensor_shape = list(local_shards[0].size())
if len(local_shards) > 1: # column-wise sharding
if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding
for shard in local_shards[1:]:
cat_tensor_shape[1] += shard.size()[1]

# when the optimizer state is sharded row-wise, we calculate the total tensor size by "concat" on the first tensor dimension
if len(local_shards) > 1 and local_shards[0].ndim == 1: # row-wise sharding
for shard in local_shards[1:]:
cat_tensor_shape[0] += shard.size()[0]

wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
wrapper_shape = torch.Size(cat_tensor_shape)
chunks_meta = [
@@ -110,6 +115,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
aten.equal.default: cls.handle_equal,
aten.detach.default: cls.handle_detach,
aten.clone.default: cls.handle_clone,
aten.new_empty.default: cls.handle_new_empty,
}

if func in dispatcher:
@@ -153,18 +159,28 @@ def handle_to_copy(args, kwargs):
def handle_view(args, kwargs):
view_shape = args[1]
res_shards_list = []
if (
len(args[0].local_shards()) > 1
and args[0].storage_metadata().size[0] == view_shape[0]
and args[0].storage_metadata().size[1] == view_shape[1]
):
# This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on
# init calls view_as() on the global tensor shape
# will fail because the view shape is not applicable to individual shards.
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
if len(args[0].local_shards()) > 1:
if args[0].local_shards()[0].ndim == 2:
assert (
args[0].storage_metadata().size[0] == view_shape[0]
and args[0].storage_metadata().size[1] == view_shape[1]
)
# This accounts for a DTensor quirk: when multiple shards are present on a rank, DTensor
# on init calls view_as() with the global tensor shape, which would fail because the view
# shape is not applicable to the individual shards.
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
elif args[0].local_shards()[0].ndim == 1:
assert args[0].storage_metadata().size[0] == view_shape[0]
# This case is for optimizer sharding: regardless of the table's sharding type, optimizer state is row-wise sharded
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
else:
raise NotImplementedError("No support for view on tensors ndim > 2")
else:
# view is called per shard
res_shards_list = [
@@ -220,6 +236,16 @@ def handle_clone(args, kwargs):
]
return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())

@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_new_empty(args, kwargs):
self_ls = args[0]
return LocalShardsWrapper(
[torch.empty_like(shard) for shard in self_ls._local_shards],
self_ls.local_offsets(),
)

@property
def device(self) -> torch._C.device: # type: ignore[override]
return (
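And a small sketch of the LocalShardsWrapper behavior added above: with 1-D (row-wise) shards the wrapper's global shape is the concatenation along dim 0, and aten.new_empty now returns a wrapper of empty shards with the same shapes and offsets. The shard sizes and offsets below are made-up values for illustration.

import torch
from torchrec.distributed.shards_wrapper import LocalShardsWrapper

# two 1-D shards of a row-wise sharded optimizer state held on this rank
shards = [torch.rand(4), torch.rand(6)]
offsets = [[0], [4]]
wrapper = LocalShardsWrapper(local_shards=shards, local_offsets=offsets)

# global size is the concat along dim 0 for 1-D shards: torch.Size([10])
print(wrapper.storage_metadata().size)

# new_empty dispatches to handle_new_empty, producing empty shards with the same layout
empty_wrapper = wrapper.new_empty(wrapper.size())
print([s.shape for s in empty_wrapper.local_shards()])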
1 change: 0 additions & 1 deletion torchrec/distributed/tests/test_infer_shardings.py
@@ -2158,7 +2158,6 @@ def test_sharded_quant_mc_ec_rw(
eviction_policy=DistanceLFU_EvictionPolicy(),
)
},
# pyre-ignore [6] Incompatible parameter type
embedding_configs=mi.tables,
),
)
1 change: 0 additions & 1 deletion torchrec/distributed/tests/test_mc_embedding.py
@@ -88,7 +88,6 @@ def __init__(
),
ManagedCollisionCollection(
managed_collision_modules=mc_modules,
# pyre-ignore
embedding_configs=tables,
),
return_remapped_features=self._return_remapped,
1 change: 0 additions & 1 deletion torchrec/distributed/tests/test_mc_embeddingbag.py
@@ -78,7 +78,6 @@ def __init__(
),
ManagedCollisionCollection(
managed_collision_modules=mc_modules,
# pyre-ignore
embedding_configs=tables,
),
return_remapped_features=self._return_remapped,
1 change: 1 addition & 0 deletions torchrec/inference/inference_legacy/tests/test_modules.py
@@ -40,5 +40,6 @@ def test_quantize_shard_cuda(self) -> None:
quantized_model = quantize_inference_model(model)
sharded_model, _ = shard_quant_model(quantized_model)

# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`.
sharded_qebc = sharded_model._module.sparse.ebc
self.assertEqual(len(sharded_qebc.tbes), 1)