diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index e631f77a3..a4bdc11b5 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -9,7 +9,6 @@ import abc import copy -import os from dataclasses import dataclass from enum import Enum, unique from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union @@ -21,6 +20,9 @@ from torch.distributed._tensor.placement_types import Placement from torch.nn.modules.module import _addindent from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.global_settings import ( + construct_sharded_tensor_from_metadata_enabled, +) from torchrec.distributed.types import ( get_tensor_size_bytes, ModuleSharder, @@ -346,8 +348,7 @@ def __init__( # option to construct ShardedTensor from metadata avoiding expensive all-gather self._construct_sharded_tensor_from_metadata: bool = ( - os.environ.get("TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA", "0") - == "1" + construct_sharded_tensor_from_metadata_enabled() ) def prefetch( diff --git a/torchrec/distributed/global_settings.py b/torchrec/distributed/global_settings.py index 2b957965c..fd86ac4bb 100644 --- a/torchrec/distributed/global_settings.py +++ b/torchrec/distributed/global_settings.py @@ -7,8 +7,14 @@ # pyre-strict +import os + PROPOGATE_DEVICE: bool = False +TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV = ( + "TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA" +) + def set_propogate_device(val: bool) -> None: global PROPOGATE_DEVICE @@ -18,3 +24,9 @@ def set_propogate_device(val: bool) -> None: def get_propogate_device() -> bool: global PROPOGATE_DEVICE return PROPOGATE_DEVICE + + +def construct_sharded_tensor_from_metadata_enabled() -> bool: + return ( + os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1" + )