Skip to content

Commit 5584ffa

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 1bd2406 + a1fe539 commit 5584ffa

File tree

3 files changed

+70
-25
lines changed

3 files changed

+70
-25
lines changed

tensordict/nn/probabilistic.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
import re
99
import warnings
10+
from collections.abc import MutableSequence
1011

1112
from textwrap import indent
12-
from typing import Any, Dict, List, Optional, overload, OrderedDict
13+
from typing import Any, Dict, List, Optional, OrderedDict, overload
1314

1415
import torch
1516

@@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
621622
log(p(z | x, y))
622623
623624
Args:
624-
*modules (sequence of TensorDictModules): An ordered sequence of
625-
:class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
625+
*modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule): An ordered sequence of
626+
:class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
626627
to be run sequentially.
628+
The modules can be instances of TensorDictModuleBase or any other function that matches this signature.
629+
Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
630+
and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
627631
628632
Keyword Args:
629633
partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
@@ -794,14 +798,13 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
794798
@overload
795799
def __init__(
796800
self,
797-
modules: OrderedDict,
801+
modules: OrderedDict[str, TensorDictModuleBase | ProbabilisticTensorDictModule],
798802
partial_tolerant: bool = False,
799803
return_composite: bool | None = None,
800804
aggregate_probabilities: bool | None = None,
801805
include_sum: bool | None = None,
802806
inplace: bool | None = None,
803-
) -> None:
804-
...
807+
) -> None: ...
805808

806809
@overload
807810
def __init__(
@@ -812,8 +815,7 @@ def __init__(
812815
aggregate_probabilities: bool | None = None,
813816
include_sum: bool | None = None,
814817
inplace: bool | None = None,
815-
) -> None:
816-
...
818+
) -> None: ...
817819

818820
def __init__(
819821
self,
@@ -829,7 +831,14 @@ def __init__(
829831
"ProbabilisticTensorDictSequential must consist of zero or more "
830832
"TensorDictModules followed by a ProbabilisticTensorDictModule"
831833
)
832-
if not return_composite and not isinstance(
834+
self._ordered_dict = False
835+
if len(modules) == 1 and isinstance(modules[0], (OrderedDict, MutableSequence)):
836+
if isinstance(modules[0], OrderedDict):
837+
modules_list = list(modules[0].values())
838+
self._ordered_dict = True
839+
else:
840+
modules = modules_list = list(modules[0])
841+
elif not return_composite and not isinstance(
833842
modules[-1],
834843
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
835844
):
@@ -838,13 +847,22 @@ def __init__(
838847
"an instance of ProbabilisticTensorDictModule or another "
839848
"ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
840849
)
850+
else:
851+
modules_list = list(modules)
852+
841853
# if the modules not including the final probabilistic module return the sampled
842854
# key we won't be sampling it again, in that case
843855
# ProbabilisticTensorDictSequential is presumably used to return the
844856
# distribution using `get_dist` or to sample log_probabilities
845-
_, out_keys = self._compute_in_and_out_keys(modules[:-1])
846-
self._requires_sample = modules[-1].out_keys[0] not in set(out_keys)
847-
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
857+
_, out_keys = self._compute_in_and_out_keys(modules_list[:-1])
858+
self._requires_sample = modules_list[-1].out_keys[0] not in set(out_keys)
859+
if self._ordered_dict:
860+
self.__dict__["_det_part"] = TensorDictSequential(
861+
OrderedDict(list(modules[0].items())[:-1])
862+
)
863+
else:
864+
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
865+
848866
super().__init__(*modules, partial_tolerant=partial_tolerant)
849867
self.return_composite = return_composite
850868
self.aggregate_probabilities = aggregate_probabilities
@@ -885,7 +903,7 @@ def get_dist_params(
885903
tds = self.det_part
886904
type = interaction_type()
887905
if type is None:
888-
for m in reversed(self.module):
906+
for m in reversed(list(self._module_iter())):
889907
if hasattr(m, "default_interaction_type"):
890908
type = m.default_interaction_type
891909
break
@@ -897,7 +915,7 @@ def get_dist_params(
897915
@property
898916
def num_samples(self):
899917
num_samples = ()
900-
for tdm in self.module:
918+
for tdm in self._module_iter():
901919
if isinstance(
902920
tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
903921
):
@@ -941,7 +959,7 @@ def get_dist(
941959

942960
td_copy = tensordict.copy()
943961
dists = {}
944-
for i, tdm in enumerate(self.module):
962+
for i, tdm in enumerate(self._module_iter()):
945963
if isinstance(
946964
tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
947965
):
@@ -981,12 +999,21 @@ def default_interaction_type(self):
981999
encountered is returned. If no such value is found, a default `interaction_type()` is returned.
9821000
9831001
"""
984-
for m in reversed(self.module):
1002+
for m in reversed(list(self._module_iter())):
9851003
interaction = getattr(m, "default_interaction_type", None)
9861004
if interaction is not None:
9871005
return interaction
9881006
return interaction_type()
9891007

1008+
@property
1009+
def _last_module(self):
1010+
if not self._ordered_dict:
1011+
return self.module[-1]
1012+
mod = None
1013+
for mod in self._module_iter(): # noqa: B007
1014+
continue
1015+
return mod
1016+
9901017
def log_prob(
9911018
self,
9921019
tensordict,
@@ -1103,7 +1130,7 @@ def log_prob(
11031130
include_sum=include_sum,
11041131
**kwargs,
11051132
)
1106-
last_module: ProbabilisticTensorDictModule = self.module[-1]
1133+
last_module: ProbabilisticTensorDictModule = self._last_module
11071134
out = last_module.log_prob(tensordict_inp, dist=dist, **kwargs)
11081135
if is_tensor_collection(out):
11091136
if tensordict_out is not None:
@@ -1162,7 +1189,7 @@ def forward(
11621189
else:
11631190
tensordict_exec = tensordict
11641191
if self.return_composite:
1165-
for m in self.module:
1192+
for m in self._module_iter():
11661193
if isinstance(
11671194
m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictModule)
11681195
):
@@ -1173,7 +1200,7 @@ def forward(
11731200
tensordict_exec = m(tensordict_exec, **kwargs)
11741201
else:
11751202
tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
1176-
tensordict_exec = self.module[-1](
1203+
tensordict_exec = self._last_module(
11771204
tensordict_exec, _requires_sample=self._requires_sample
11781205
)
11791206
if tensordict_out is not None:

tensordict/nn/sequence.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,18 @@ class TensorDictSequential(TensorDictModule):
5353
buffers) will be concatenated in a single list.
5454
5555
Args:
56-
modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
56+
modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]):
57+
ordered sequence of callables that take a TensorDictBase as input and return a TensorDictBase.
58+
These can be instances of TensorDictModuleBase or any other function that matches this signature.
59+
Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
60+
and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
5761
Keyword Args:
5862
partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
5963
If so, the only module that will be executed are those who can be executed given the keys that
6064
are present.
6165
Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
6266
stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
63-
looking for those that have the required keys, if any.
67+
looking for those that have the required keys, if any. Defaults to False.
6468
selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
6569
``out_keys`` will be written.
6670

test/test_nn.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,9 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
480480

481481
in_keys = ["in"]
482482
net = TensorDictModule(module=net, in_keys=in_keys, out_keys=out_keys)
483+
corr = TensorDictModule(
484+
lambda low: max_dist - low.abs(), in_keys=out_keys, out_keys=out_keys
485+
)
483486

484487
kwargs = {
485488
"distribution_class": distributions.Uniform,
@@ -494,7 +497,7 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
494497
in_keys=dist_in_keys, out_keys=["out"], **kwargs
495498
)
496499

497-
tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)
500+
tensordict_module = ProbabilisticTensorDictSequential(net, corr, prob_module)
498501
assert tensordict_module.default_interaction_type is not None
499502

500503
td = TensorDict({"in": torch.randn(3, 3)}, [3])
@@ -2156,6 +2159,8 @@ def test_nested_keys_probabilistic_normal(self, log_prob_key):
21562159
in_keys=[("data", "states")],
21572160
out_keys=[("data", "scale")],
21582161
)
2162+
scale_module.module.weight.data.abs_()
2163+
scale_module.module.bias.data.abs_()
21592164
td = TensorDict(
21602165
{"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]
21612166
)
@@ -3019,7 +3024,8 @@ def test_prob_module_nested(self, interaction, map_names):
30193024
"interaction", [InteractionType.MODE, InteractionType.MEAN]
30203025
)
30213026
@pytest.mark.parametrize("return_log_prob", [True, False])
3022-
def test_prob_module_seq(self, interaction, return_log_prob):
3027+
@pytest.mark.parametrize("ordereddict", [True, False])
3028+
def test_prob_module_seq(self, interaction, return_log_prob, ordereddict):
30233029
params = TensorDict(
30243030
{
30253031
"params": {
@@ -3042,7 +3048,7 @@ def test_prob_module_seq(self, interaction, return_log_prob):
30423048
("nested", "cont"): distributions.Normal,
30433049
}
30443050
backbone = TensorDictModule(lambda: None, in_keys=[], out_keys=[])
3045-
module = ProbabilisticTensorDictSequential(
3051+
args = [
30463052
backbone,
30473053
ProbabilisticTensorDictModule(
30483054
in_keys=in_keys,
@@ -3052,7 +3058,15 @@ def test_prob_module_seq(self, interaction, return_log_prob):
30523058
default_interaction_type=interaction,
30533059
return_log_prob=return_log_prob,
30543060
),
3055-
)
3061+
]
3062+
if ordereddict:
3063+
args = [
3064+
OrderedDict(
3065+
backbone=args[0],
3066+
proba=args[1],
3067+
)
3068+
]
3069+
module = ProbabilisticTensorDictSequential(*args)
30563070
sample = module(params)
30573071
if return_log_prob:
30583072
assert "cont_log_prob" in sample.keys()

0 commit comments

Comments
 (0)