From 5885b3e5e4235142ee444746aaa66c3fc92810d3 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 12 Dec 2024 11:35:00 +0000 Subject: [PATCH] 2024-12-12 nightly release (e1b5edd871256161410a082a7d72290669704624) --- torchrec/distributed/embeddingbag.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 06ad9f26e..84e033a31 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -65,6 +65,7 @@ QuantizedCommCodecs, ShardedTensor, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, ) @@ -938,7 +939,11 @@ def _initialize_torch_state(self) -> None: # noqa ShardedTensor._init_from_local_shards( local_shards, self._name_to_table_size[table_name], - process_group=self._env.process_group, + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), ) )