Skip to content

Commit

Permalink
[Feature] Plotting TensorDictSequential graphs
Browse files Browse the repository at this point in the history
ghstack-source-id: ff93fb45f6d64b3ab960cc801631923305b879ca
Pull Request resolved: #1144
  • Loading branch information
vmoens committed Dec 18, 2024
1 parent 2360386 commit 3b33b7b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings

from textwrap import indent
from typing import Any, Dict, List, Optional, overload, OrderedDict
from typing import Any, Dict, List, Optional, OrderedDict, overload

import torch

Expand Down Expand Up @@ -800,8 +800,7 @@ def __init__(
aggregate_probabilities: bool | None = None,
include_sum: bool | None = None,
inplace: bool | None = None,
) -> None:
...
) -> None: ...

@overload
def __init__(
Expand All @@ -812,8 +811,7 @@ def __init__(
aggregate_probabilities: bool | None = None,
include_sum: bool | None = None,
inplace: bool | None = None,
) -> None:
...
) -> None: ...

def __init__(
self,
Expand Down
62 changes: 62 additions & 0 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import collections
import contextlib
import logging
from copy import deepcopy
from typing import Any, Callable, Iterable, List, OrderedDict, overload
Expand Down Expand Up @@ -574,3 +575,64 @@ def __setitem__(

def __delitem__(self, index: int | slice | str) -> None:
self.module.__delitem__(idx=index)

def plot(self, example_input: TensorDictBase | None = None, **kwargs):
import pydot

graph = pydot.Dot(
"my_graph", graph_type="digraph", bgcolor="yellow", splines="curved"
)
graph.set_bgcolor("white")

if example_input is not None:
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
fake_td = example_input.apply(
lambda x: converter.from_real_tensor(fake_mode, x)
)
else:
fake_td = None
fake_mode = contextlib.nullcontext()

with fake_mode:
iterator = (
enumerate(self._module_iter())
if not isinstance(self.module, nn.ModuleDict)
else self.module.items()
)
for name, module in iterator:
graph.add_node(
pydot.Node(str(name), shape="box")
) # label=str(node.module)))

# Check if in_keys are there already
in_keys = module.in_keys
for in_key in in_keys:
if in_key not in graph.obj_dict["nodes"]:
in_key_node = pydot.Node(
in_key, label=in_key, shape="plaintext"
)
graph.add_node(in_key_node)
in_key_edge = pydot.Edge(
in_key, str(name), color="blue", style="arrow"
)
graph.add_edge(in_key_edge)

if not isinstance(module, TensorDictModule):
fake_td = self._run_module(module, fake_td, **kwargs)

out_keys = module.out_keys
for out_key in out_keys:
if out_key not in graph.obj_dict["nodes"]:
out_key_node = pydot.Node(
out_key, label=out_key, shape="plaintext"
)
graph.add_node(out_key_node)
out_key_edge = pydot.Edge(
str(name), out_key, color="blue", style="arrow"
)
graph.add_edge(out_key_edge)

graph.write_png("/Users/vmoens/Downloads/my_graph.png")

0 comments on commit 3b33b7b

Please sign in to comment.