Skip to content

Commit 6b94411

Browse files
seanx92facebook-github-bot
authored andcommitted
move module attribute inplace update to leaf function in ManagedCollisionModule
Summary: inplace update will cause unexpected module attribute mutation described in pytorch/pytorch#70449 by moving it to leaf function we guaranteed no side effect during fx tracing. Differential Revision: D73448087
1 parent a28ac22 commit 6b94411

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

torchrec/modules/mc_modules.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,27 @@ def _mcc_lazy_init(
5959
features: KeyedJaggedTensor,
6060
feature_names: List[str],
6161
features_order: List[int],
62-
created_feature_order: bool,
63-
) -> Tuple[KeyedJaggedTensor, bool, List[int]]: # features_order
62+
created_feature_order: List[bool],
63+
) -> KeyedJaggedTensor:
6464
input_feature_names: List[str] = features.keys()
65-
if not created_feature_order:
65+
if not created_feature_order or not created_feature_order[0]:
6666
for f in feature_names:
6767
features_order.append(input_feature_names.index(f))
6868

6969
if features_order == list(range(len(input_feature_names))):
70-
features_order = torch.jit.annotate(List[int], [])
71-
created_feature_order = True
70+
features_order.clear()
71+
72+
if len(created_feature_order) > 0:
73+
created_feature_order[0] = True
74+
else:
75+
created_feature_order.append(True)
7276

7377
if len(features_order) > 0:
7478
features = features.permute(
7579
features_order,
7680
)
7781

78-
return (features, created_feature_order, features_order)
82+
return features
7983

8084

8185
@torch.fx.wrap
@@ -298,6 +302,7 @@ class ManagedCollisionCollection(nn.Module):
298302

299303
_table_to_features: Dict[str, List[str]]
300304
_features_order: List[int]
305+
_created_feature_order: List[bool] # use list for inplace update in leaf function
301306

302307
def __init__(
303308
self,
@@ -338,7 +343,7 @@ def __init__(
338343
self._feature_names: List[str] = [
339344
feature for config in embedding_configs for feature in config.feature_names
340345
]
341-
self._created_feature_order = False
346+
self._created_feature_order = [False]
342347
self._features_order = []
343348

344349
def _create_feature_order(
@@ -360,11 +365,7 @@ def forward(
360365
self,
361366
features: KeyedJaggedTensor,
362367
) -> KeyedJaggedTensor:
363-
(
364-
features,
365-
self._created_feature_order,
366-
self._features_order,
367-
) = _mcc_lazy_init(
368+
features = _mcc_lazy_init(
368369
features,
369370
self._feature_names,
370371
self._features_order,

0 commit comments

Comments
 (0)