From 684dca7bdba48aaf98c24c32597da5fc5d730d31 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 18 Dec 2024 11:14:45 +0000 Subject: [PATCH] [Feature] Plotting TensorDictSequential graphs ghstack-source-id: 9f27d6b67f7b0946f70d12efcb677e6139bd1ec1 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1144 --- tensordict/nn/sequence.py | 62 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index adb2ff314..38560fcbf 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -6,6 +6,7 @@ from __future__ import annotations import collections +import contextlib import logging from copy import deepcopy from typing import Any, Callable, Iterable, List, OrderedDict, overload @@ -578,3 +579,64 @@ def __setitem__( def __delitem__(self, index: int | slice | str) -> None: self.module.__delitem__(idx=index) + + def plot(self, example_input: TensorDictBase | None = None, **kwargs): + import pydot + + graph = pydot.Dot( + "my_graph", graph_type="digraph", bgcolor="yellow", splines="curved" + ) + graph.set_bgcolor("white") + + if example_input is not None: + from torch._subclasses.fake_tensor import FakeTensorMode + + fake_mode = FakeTensorMode() + converter = fake_mode.fake_tensor_converter + fake_td = example_input.apply( + lambda x: converter.from_real_tensor(fake_mode, x) + ) + else: + fake_td = None + fake_mode = contextlib.nullcontext() + + with fake_mode: + iterator = ( + enumerate(self._module_iter()) + if not isinstance(self.module, nn.ModuleDict) + else self.module.items() + ) + for name, module in iterator: + graph.add_node( + pydot.Node(str(name), shape="box") + ) # label=str(node.module))) + + # Check if in_keys are there already + in_keys = module.in_keys + for in_key in in_keys: + if in_key not in graph.obj_dict["nodes"]: + in_key_node = pydot.Node( + in_key, label=in_key, shape="plaintext" + ) + graph.add_node(in_key_node) + in_key_edge = pydot.Edge( + in_key, str(name), color="blue", style="arrow" + ) + graph.add_edge(in_key_edge) + + if not isinstance(module, TensorDictModule): + fake_td = self._run_module(module, fake_td, **kwargs) + + out_keys = module.out_keys + for out_key in out_keys: + if out_key not in graph.obj_dict["nodes"]: + out_key_node = pydot.Node( + out_key, label=out_key, shape="plaintext" + ) + graph.add_node(out_key_node) + out_key_edge = pydot.Edge( + str(name), out_key, color="blue", style="arrow" + ) + graph.add_edge(out_key_edge) + + graph.write_png("/Users/vmoens/Downloads/my_graph.png")