From d36fdec17fa853b08fff0aa4e04b516c21225e67 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Tue, 7 Jan 2025 10:15:20 -0800 Subject: [PATCH] Re-land D66465376 (#2637) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2637 Re-land diff D66465376 NOTE: use jit.ignore on the forward function to get rid of jit script error with `TensorDict` ``` def test_td_scripting(self) -> None: class TestModule(torch.nn.Module): torch.jit.ignore # <----- test fails without this ignore def forward(self, x: Union[TensorDict, KeyedJaggedTensor]) -> torch.Tensor: if isinstance(x, TensorDict): keys = list(x.keys()) return torch.cat([x[key]._values for key in keys], dim=0) else: return x._values m = TestModule() gm = torch.fx.symbolic_trace(m) jm = torch.jit.script(gm) values = torch.tensor([0, 1, 2, 3, 2, 3, 4]) kjt = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2", "f3"], values=values, offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]), ) torch.testing.assert_allclose(jm(kjt), values) ``` Reviewed By: dstaay-fb Differential Revision: D66460392 fbshipit-source-id: 6fe35ebf2d1ebbac11b7cbba5efda1af1026028e --- torchrec/distributed/embedding.py | 8 ++++++++ torchrec/distributed/embeddingbag.py | 7 +++++++ torchrec/modules/embedding_modules.py | 8 ++++++++ torchrec/sparse/jagged_tensor.py | 7 +++++-- 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index b33d81635..9c314f6fa 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -97,6 +97,14 @@ except OSError: pass +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + + logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 8cfd16ae9..ca9f6a18e 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -102,6 +102,13 @@ except OSError: pass +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: return ( diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 307d66639..9a1878361 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -21,6 +21,14 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + + @torch.fx.wrap def reorder_inverse_indices( inverse_indices: Optional[Tuple[List[str], torch.Tensor]], diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 15952bfa5..8468c9977 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -49,9 +49,12 @@ # OSS try: - pass + from tensordict import TensorDict except ImportError: - pass + + class TensorDict: + pass + logger: logging.Logger = logging.getLogger()