Skip to content

Commit aee9ab5

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Model wrapper for DLRM (#3128)
Summary: Pull Request resolved: #3128 * Added model wrapper for DLRM. The wrapper will take ModelInput as an only parameter in the forward method. The forward method will return just the prediction if it's in inference mode and losses with prediction if it's in training mode. (Because training pipeline expects loss and prediction. See https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline/train_pipelines.py#L670) * Added the parameterized unit tests to cover the model's wrapper Reviewed By: aliafzal Differential Revision: D77167717
1 parent bde9888 commit aee9ab5

File tree

2 files changed

+142
-1
lines changed

2 files changed

+142
-1
lines changed

torchrec/models/dlrm.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
# pyre-strict
99

10-
from typing import Dict, List, Optional, Tuple
10+
from typing import Dict, List, Optional, Tuple, Union
1111

1212
import torch
1313
from torch import nn
1414
from torchrec.datasets.utils import Batch
15+
from torchrec.distributed.test_utils.test_input import ModelInput
1516
from torchrec.modules.crossnet import LowRankCrossNet
1617
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1718
from torchrec.modules.mlp import MLP
@@ -899,3 +900,34 @@ def forward(
899900
loss = self.loss_fn(logits, batch.labels.float())
900901

901902
return loss, (loss.detach(), logits.detach(), batch.labels.detach())
903+
904+
905+
class DLRMWrapper(DLRM):
906+
# pyre-ignore[14, 15]
907+
def forward(
908+
self, model_input: ModelInput
909+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
910+
"""
911+
Forward pass for the DLRMWrapper.
912+
913+
Args:
914+
model_input (ModelInput): Contains dense and sparse features.
915+
916+
Returns:
917+
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
918+
If training, returns (loss, prediction). Otherwise, returns prediction.
919+
"""
920+
pred = super().forward(
921+
dense_features=model_input.float_features,
922+
sparse_features=model_input.idlist_features, # pyre-ignore[6]
923+
)
924+
925+
if self.training:
926+
# Calculate loss and return both loss and prediction
927+
loss = torch.nn.functional.binary_cross_entropy_with_logits(
928+
pred.squeeze(), model_input.label
929+
)
930+
return (loss, pred)
931+
else:
932+
# Return just the prediction
933+
return pred

torchrec/models/tests/test_dlrm.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
# pyre-strict
99

1010
import unittest
11+
from dataclasses import dataclass
12+
from typing import List
1113

1214
import torch
15+
from parameterized import parameterized
1316
from torch import nn
1417
from torch.testing import FileCheck # @manual
1518
from torchrec.datasets.utils import Batch
19+
from torchrec.distributed.test_utils.test_input import ModelInput
1620
from torchrec.fx import symbolic_trace
1721
from torchrec.ir.serializer import JsonSerializer
1822
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
@@ -23,6 +27,7 @@
2327
DLRM_DCN,
2428
DLRM_Projection,
2529
DLRMTrain,
30+
DLRMWrapper,
2631
InteractionArch,
2732
InteractionDCNArch,
2833
InteractionProjectionArch,
@@ -1283,3 +1288,107 @@ def test_export_serialization(self) -> None:
12831288
deserialized_logits = deserialized_model(features, sparse_features)
12841289

12851290
self.assertEqual(deserialized_logits.size(), (B, 1))
1291+
1292+
1293+
class DLRMWrapperTest(unittest.TestCase):
1294+
@dataclass
1295+
class WrapperTestParams:
1296+
# input parameters
1297+
embedding_configs: List[EmbeddingBagConfig]
1298+
sparse_feature_keys: List[str]
1299+
sparse_feature_values: List[int]
1300+
sparse_feature_offsets: List[int]
1301+
# expected output parameters
1302+
expected_output_size: tuple[int, ...]
1303+
1304+
@parameterized.expand(
1305+
[
1306+
(
1307+
"basic_with_multiple_features",
1308+
WrapperTestParams(
1309+
embedding_configs=[
1310+
EmbeddingBagConfig(
1311+
name="t1",
1312+
embedding_dim=8,
1313+
num_embeddings=100,
1314+
feature_names=["f1", "f3"],
1315+
),
1316+
EmbeddingBagConfig(
1317+
name="t2",
1318+
embedding_dim=8,
1319+
num_embeddings=100,
1320+
feature_names=["f2"],
1321+
),
1322+
],
1323+
sparse_feature_keys=["f1", "f3", "f2"],
1324+
sparse_feature_values=[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3],
1325+
sparse_feature_offsets=[0, 2, 4, 6, 8, 10, 11],
1326+
expected_output_size=(2, 1),
1327+
),
1328+
),
1329+
(
1330+
"empty_sparse_features",
1331+
WrapperTestParams(
1332+
embedding_configs=[
1333+
EmbeddingBagConfig(
1334+
name="t1",
1335+
embedding_dim=8,
1336+
num_embeddings=100,
1337+
feature_names=["f1"],
1338+
),
1339+
],
1340+
sparse_feature_keys=["f1"],
1341+
sparse_feature_values=[],
1342+
sparse_feature_offsets=[0, 0, 0],
1343+
expected_output_size=(2, 1),
1344+
),
1345+
),
1346+
]
1347+
)
1348+
def test_wrapper_functionality(
1349+
self, _test_name: str, test_params: WrapperTestParams
1350+
) -> None:
1351+
B = 2
1352+
D = 8
1353+
dense_in_features = 100
1354+
1355+
ebc = EmbeddingBagCollection(tables=test_params.embedding_configs)
1356+
1357+
dlrm_wrapper = DLRMWrapper(
1358+
embedding_bag_collection=ebc,
1359+
dense_in_features=dense_in_features,
1360+
dense_arch_layer_sizes=[20, D],
1361+
over_arch_layer_sizes=[5, 1],
1362+
)
1363+
1364+
# Create ModelInput
1365+
dense_features = torch.rand((B, dense_in_features))
1366+
sparse_features = KeyedJaggedTensor.from_offsets_sync(
1367+
keys=test_params.sparse_feature_keys,
1368+
values=torch.tensor(test_params.sparse_feature_values, dtype=torch.long),
1369+
offsets=torch.tensor(test_params.sparse_feature_offsets, dtype=torch.long),
1370+
)
1371+
1372+
model_input = ModelInput(
1373+
float_features=dense_features,
1374+
idlist_features=sparse_features,
1375+
idscore_features=None,
1376+
label=torch.rand((B,)),
1377+
)
1378+
1379+
# Test eval mode - should return just logits
1380+
dlrm_wrapper.eval()
1381+
logits = dlrm_wrapper(model_input)
1382+
self.assertIsInstance(logits, torch.Tensor)
1383+
self.assertEqual(logits.size(), test_params.expected_output_size)
1384+
1385+
# Test training mode - should return (loss, logits) tuple
1386+
dlrm_wrapper.train()
1387+
result = dlrm_wrapper(model_input)
1388+
self.assertIsInstance(result, tuple)
1389+
self.assertEqual(len(result), 2)
1390+
loss, pred = result
1391+
self.assertIsInstance(loss, torch.Tensor)
1392+
self.assertIsInstance(pred, torch.Tensor)
1393+
self.assertEqual(loss.size(), ()) # scalar loss
1394+
self.assertEqual(pred.size(), test_params.expected_output_size)

0 commit comments

Comments
 (0)