From aac06084c3b5eae353eb04f773b6f7c4a118acdc Mon Sep 17 00:00:00 2001 From: Shuao Xiong Date: Mon, 28 Apr 2025 15:20:21 -0700 Subject: [PATCH] move module attribute inplace update to leaf function in ManagedCollisionModule (#2913) Summary: inplace update will cause unexpected module attribute mutation described in https://github.com/pytorch/pytorch/issues/70449 by moving it to leaf function we guaranteed no side effect during fx tracing. Reviewed By: zlzhao1104 Differential Revision: D73448087 --- torchrec/modules/mc_modules.py | 38 ++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py index 1a3dffe90..a472b88e1 100644 --- a/torchrec/modules/mc_modules.py +++ b/torchrec/modules/mc_modules.py @@ -54,6 +54,7 @@ def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor: return torch.cat([jt.values() for jt in jd.values()]) +# TODO: keep the old implementation for backward compatibility and will remove it later @torch.fx.wrap def _mcc_lazy_init( features: KeyedJaggedTensor, @@ -78,6 +79,34 @@ def _mcc_lazy_init( return (features, created_feature_order, features_order) +@torch.fx.wrap +def _mcc_lazy_init_inplace( + features: KeyedJaggedTensor, + feature_names: List[str], + features_order: List[int], + created_feature_order: List[bool], +) -> KeyedJaggedTensor: + input_feature_names: List[str] = features.keys() + if not created_feature_order or not created_feature_order[0]: + for f in feature_names: + features_order.append(input_feature_names.index(f)) + + if features_order == list(range(len(input_feature_names))): + features_order.clear() + + if len(created_feature_order) > 0: + created_feature_order[0] = True + else: + created_feature_order.append(True) + + if len(features_order) > 0: + features = features.permute( + features_order, + ) + + return features + + @torch.fx.wrap def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor: return torch.tensor(kjt.length_per_key()) @@ -298,6 +327,7 @@ class ManagedCollisionCollection(nn.Module): _table_to_features: Dict[str, List[str]] _features_order: List[int] + _created_feature_order: List[bool] # use list for inplace update in leaf function def __init__( self, @@ -338,7 +368,7 @@ def __init__( self._feature_names: List[str] = [ feature for config in embedding_configs for feature in config.feature_names ] - self._created_feature_order = False + self._created_feature_order: List[bool] = [False] self._features_order = [] def _create_feature_order( @@ -360,11 +390,7 @@ def forward( self, features: KeyedJaggedTensor, ) -> KeyedJaggedTensor: - ( - features, - self._created_feature_order, - self._features_order, - ) = _mcc_lazy_init( + features = _mcc_lazy_init_inplace( features, self._feature_names, self._features_order,