Skip to content

move module attribute inplace update to leaf function in ManagedCollisionModule #2913

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions torchrec/modules/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
return torch.cat([jt.values() for jt in jd.values()])


# TODO: keep the old implementation for backward compatibility and will remove it later
@torch.fx.wrap
def _mcc_lazy_init(
features: KeyedJaggedTensor,
Expand All @@ -78,6 +79,34 @@ def _mcc_lazy_init(
return (features, created_feature_order, features_order)


@torch.fx.wrap
def _mcc_lazy_init_inplace(
features: KeyedJaggedTensor,
feature_names: List[str],
features_order: List[int],
created_feature_order: List[bool],
) -> KeyedJaggedTensor:
input_feature_names: List[str] = features.keys()
if not created_feature_order or not created_feature_order[0]:
for f in feature_names:
features_order.append(input_feature_names.index(f))

if features_order == list(range(len(input_feature_names))):
features_order.clear()

if len(created_feature_order) > 0:
created_feature_order[0] = True
else:
created_feature_order.append(True)

if len(features_order) > 0:
features = features.permute(
features_order,
)

return features


@torch.fx.wrap
def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor:
return torch.tensor(kjt.length_per_key())
Expand Down Expand Up @@ -298,6 +327,7 @@ class ManagedCollisionCollection(nn.Module):

_table_to_features: Dict[str, List[str]]
_features_order: List[int]
_created_feature_order: List[bool] # use list for inplace update in leaf function

def __init__(
self,
Expand Down Expand Up @@ -338,7 +368,7 @@ def __init__(
self._feature_names: List[str] = [
feature for config in embedding_configs for feature in config.feature_names
]
self._created_feature_order = False
self._created_feature_order: List[bool] = [False]
self._features_order = []

def _create_feature_order(
Expand All @@ -360,11 +390,7 @@ def forward(
self,
features: KeyedJaggedTensor,
) -> KeyedJaggedTensor:
(
features,
self._created_feature_order,
self._features_order,
) = _mcc_lazy_init(
features = _mcc_lazy_init_inplace(
features,
self._feature_names,
self._features_order,
Expand Down
Loading