Skip to content

Commit

Permalink
[Feature] Plotting TensorDictSequential graphs
Browse files Browse the repository at this point in the history
ghstack-source-id: 9f27d6b67f7b0946f70d12efcb677e6139bd1ec1
Pull Request resolved: #1144
  • Loading branch information
vmoens committed Dec 18, 2024
1 parent e073cbe commit 684dca7
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import collections
import contextlib
import logging
from copy import deepcopy
from typing import Any, Callable, Iterable, List, OrderedDict, overload
Expand Down Expand Up @@ -578,3 +579,64 @@ def __setitem__(

def __delitem__(self, index: int | slice | str) -> None:
self.module.__delitem__(idx=index)

def plot(self, example_input: TensorDictBase | None = None, **kwargs):
import pydot

graph = pydot.Dot(
"my_graph", graph_type="digraph", bgcolor="yellow", splines="curved"
)
graph.set_bgcolor("white")

if example_input is not None:
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
fake_td = example_input.apply(
lambda x: converter.from_real_tensor(fake_mode, x)
)
else:
fake_td = None
fake_mode = contextlib.nullcontext()

with fake_mode:
iterator = (
enumerate(self._module_iter())
if not isinstance(self.module, nn.ModuleDict)
else self.module.items()
)
for name, module in iterator:
graph.add_node(
pydot.Node(str(name), shape="box")
) # label=str(node.module)))

# Check if in_keys are there already
in_keys = module.in_keys
for in_key in in_keys:
if in_key not in graph.obj_dict["nodes"]:
in_key_node = pydot.Node(
in_key, label=in_key, shape="plaintext"
)
graph.add_node(in_key_node)
in_key_edge = pydot.Edge(
in_key, str(name), color="blue", style="arrow"
)
graph.add_edge(in_key_edge)

if not isinstance(module, TensorDictModule):
fake_td = self._run_module(module, fake_td, **kwargs)

out_keys = module.out_keys
for out_key in out_keys:
if out_key not in graph.obj_dict["nodes"]:
out_key_node = pydot.Node(
out_key, label=out_key, shape="plaintext"
)
graph.add_node(out_key_node)
out_key_edge = pydot.Edge(
str(name), out_key, color="blue", style="arrow"
)
graph.add_edge(out_key_edge)

graph.write_png("/Users/vmoens/Downloads/my_graph.png")

0 comments on commit 684dca7

Please sign in to comment.