From 6f4bfe2c194b99605786603637c6967a2c5027e5 Mon Sep 17 00:00:00 2001 From: Boris Sarana Date: Mon, 6 Jan 2025 15:05:43 -0800 Subject: [PATCH] Move sharding optimization flag to global_settings (#2665) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2665 As per the title, move the configuration flag to a separate module for better abstraction and a simpler rollout. Reviewed By: iamzainhuda Differential Revision: D67777011 fbshipit-source-id: 8a659bee7b81d3181c4014fdf2678c69b306b8c1 --- torchrec/distributed/embedding_types.py | 7 ++++--- torchrec/distributed/global_settings.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index e631f77a3..a4bdc11b5 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -9,7 +9,6 @@ import abc import copy -import os from dataclasses import dataclass from enum import Enum, unique from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union @@ -21,6 +20,9 @@ from torch.distributed._tensor.placement_types import Placement from torch.nn.modules.module import _addindent from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.global_settings import ( + construct_sharded_tensor_from_metadata_enabled, +) from torchrec.distributed.types import ( get_tensor_size_bytes, ModuleSharder, @@ -346,8 +348,7 @@ def __init__( # option to construct ShardedTensor from metadata avoiding expensive all-gather self._construct_sharded_tensor_from_metadata: bool = ( - os.environ.get("TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA", "0") - == "1" + construct_sharded_tensor_from_metadata_enabled() ) def prefetch( diff --git a/torchrec/distributed/global_settings.py b/torchrec/distributed/global_settings.py index 2b957965c..fd86ac4bb 100644 --- a/torchrec/distributed/global_settings.py +++ b/torchrec/distributed/global_settings.py @@ 
-7,8 +7,14 @@ # pyre-strict +import os + PROPOGATE_DEVICE: bool = False +TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV = ( + "TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA" +) + def set_propogate_device(val: bool) -> None: global PROPOGATE_DEVICE @@ -18,3 +24,9 @@ def set_propogate_device(val: bool) -> None: def get_propogate_device() -> bool: global PROPOGATE_DEVICE return PROPOGATE_DEVICE + + +def construct_sharded_tensor_from_metadata_enabled() -> bool: + return ( + os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1" + )