@@ -54,6 +54,7 @@ def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
5454 return torch .cat ([jt .values () for jt in jd .values ()])
5555
5656
57+ # TODO: keep the old implementation for backward compatibility and will remove it later
5758@torch .fx .wrap
5859def _mcc_lazy_init (
5960 features : KeyedJaggedTensor ,
@@ -78,6 +79,34 @@ def _mcc_lazy_init(
7879 return (features , created_feature_order , features_order )
7980
8081
82+ @torch .fx .wrap
83+ def _mcc_lazy_init_inplace (
84+ features : KeyedJaggedTensor ,
85+ feature_names : List [str ],
86+ features_order : List [int ],
87+ created_feature_order : List [bool ],
88+ ) -> KeyedJaggedTensor :
89+ input_feature_names : List [str ] = features .keys ()
90+ if not created_feature_order or not created_feature_order [0 ]:
91+ for f in feature_names :
92+ features_order .append (input_feature_names .index (f ))
93+
94+ if features_order == list (range (len (input_feature_names ))):
95+ features_order .clear ()
96+
97+ if len (created_feature_order ) > 0 :
98+ created_feature_order [0 ] = True
99+ else :
100+ created_feature_order .append (True )
101+
102+ if len (features_order ) > 0 :
103+ features = features .permute (
104+ features_order ,
105+ )
106+
107+ return features
108+
109+
81110@torch .fx .wrap
82111def _get_length_per_key (kjt : KeyedJaggedTensor ) -> torch .Tensor :
83112 return torch .tensor (kjt .length_per_key ())
@@ -298,6 +327,7 @@ class ManagedCollisionCollection(nn.Module):
298327
299328 _table_to_features : Dict [str , List [str ]]
300329 _features_order : List [int ]
330+ _created_feature_order : List [bool ] # use list for inplace update in leaf function
301331
302332 def __init__ (
303333 self ,
@@ -338,7 +368,7 @@ def __init__(
338368 self ._feature_names : List [str ] = [
339369 feature for config in embedding_configs for feature in config .feature_names
340370 ]
341- self ._created_feature_order = False
371+ self ._created_feature_order = [ False ]
342372 self ._features_order = []
343373
344374 def _create_feature_order (
@@ -360,11 +390,7 @@ def forward(
360390 self ,
361391 features : KeyedJaggedTensor ,
362392 ) -> KeyedJaggedTensor :
363- (
364- features ,
365- self ._created_feature_order ,
366- self ._features_order ,
367- ) = _mcc_lazy_init (
393+ features = _mcc_lazy_init_inplace (
368394 features ,
369395 self ._feature_names ,
370396 self ._features_order ,
0 commit comments