7
7
8
8
import re
9
9
import warnings
10
+ from collections .abc import MutableSequence
10
11
11
12
from textwrap import indent
12
- from typing import Any , Dict , List , Optional , overload , OrderedDict
13
+ from typing import Any , Dict , List , Optional , OrderedDict , overload
13
14
14
15
import torch
15
16
@@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
621
622
log(p(z | x, y))
622
623
623
624
Args:
624
- *modules (sequence of TensorDictModules ): An ordered sequence of
625
- :class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
625
+ *modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule ): An ordered sequence of
626
+ :class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
626
627
to be run sequentially.
628
+ The modules can be instances of TensorDictModuleBase or any other callable that matches the same call signature.
629
+ Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
630
+ and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
627
631
628
632
Keyword Args:
629
633
partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
@@ -794,14 +798,13 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
794
798
@overload
def __init__(
    self,
    modules: OrderedDict[str, TensorDictModuleBase | ProbabilisticTensorDictModule],
    partial_tolerant: bool = False,
    return_composite: bool | None = None,
    aggregate_probabilities: bool | None = None,
    include_sum: bool | None = None,
    inplace: bool | None = None,
) -> None:
    """Typing-only overload: construct from one ordered mapping of names to modules.

    This stub carries no behavior; it only declares the signature in which
    ``modules`` is a single ``OrderedDict`` whose keys are strings and whose
    values are ``TensorDictModuleBase`` or ``ProbabilisticTensorDictModule``
    instances. The remaining parameters keep the same names and defaults as
    the other ``__init__`` signatures.
    """
805
808
806
809
@overload
807
810
def __init__ (
@@ -812,8 +815,7 @@ def __init__(
812
815
aggregate_probabilities : bool | None = None ,
813
816
include_sum : bool | None = None ,
814
817
inplace : bool | None = None ,
815
- ) -> None :
816
- ...
818
+ ) -> None : ...
817
819
818
820
def __init__ (
819
821
self ,
@@ -829,7 +831,14 @@ def __init__(
829
831
"ProbabilisticTensorDictSequential must consist of zero or more "
830
832
"TensorDictModules followed by a ProbabilisticTensorDictModule"
831
833
)
832
- if not return_composite and not isinstance (
834
+ self ._ordered_dict = False
835
+ if len (modules ) == 1 and isinstance (modules [0 ], (OrderedDict , MutableSequence )):
836
+ if isinstance (modules [0 ], OrderedDict ):
837
+ modules_list = list (modules [0 ].values ())
838
+ self ._ordered_dict = True
839
+ else :
840
+ modules = modules_list = list (modules [0 ])
841
+ elif not return_composite and not isinstance (
833
842
modules [- 1 ],
834
843
(ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential ),
835
844
):
@@ -838,13 +847,22 @@ def __init__(
838
847
"an instance of ProbabilisticTensorDictModule or another "
839
848
"ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
840
849
)
850
+ else :
851
+ modules_list = list (modules )
852
+
841
853
# if the modules not including the final probabilistic module return the sampled
842
854
# key we won't be sampling it again, in that case
843
855
# ProbabilisticTensorDictSequential is presumably used to return the
844
856
# distribution using `get_dist` or to sample log_probabilities
845
- _ , out_keys = self ._compute_in_and_out_keys (modules [:- 1 ])
846
- self ._requires_sample = modules [- 1 ].out_keys [0 ] not in set (out_keys )
847
- self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
857
+ _ , out_keys = self ._compute_in_and_out_keys (modules_list [:- 1 ])
858
+ self ._requires_sample = modules_list [- 1 ].out_keys [0 ] not in set (out_keys )
859
+ if self ._ordered_dict :
860
+ self .__dict__ ["_det_part" ] = TensorDictSequential (
861
+ OrderedDict (list (modules [0 ].items ())[:- 1 ])
862
+ )
863
+ else :
864
+ self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
865
+
848
866
super ().__init__ (* modules , partial_tolerant = partial_tolerant )
849
867
self .return_composite = return_composite
850
868
self .aggregate_probabilities = aggregate_probabilities
@@ -885,7 +903,7 @@ def get_dist_params(
885
903
tds = self .det_part
886
904
type = interaction_type ()
887
905
if type is None :
888
- for m in reversed (self .module ):
906
+ for m in reversed (list ( self ._module_iter ()) ):
889
907
if hasattr (m , "default_interaction_type" ):
890
908
type = m .default_interaction_type
891
909
break
@@ -897,7 +915,7 @@ def get_dist_params(
897
915
@property
898
916
def num_samples (self ):
899
917
num_samples = ()
900
- for tdm in self .module :
918
+ for tdm in self ._module_iter () :
901
919
if isinstance (
902
920
tdm , (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential )
903
921
):
@@ -941,7 +959,7 @@ def get_dist(
941
959
942
960
td_copy = tensordict .copy ()
943
961
dists = {}
944
- for i , tdm in enumerate (self .module ):
962
+ for i , tdm in enumerate (self ._module_iter () ):
945
963
if isinstance (
946
964
tdm , (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential )
947
965
):
@@ -981,12 +999,21 @@ def default_interaction_type(self):
981
999
encountered is returned. If no such value is found, a default `interaction_type()` is returned.
982
1000
983
1001
"""
984
- for m in reversed (self .module ):
1002
+ for m in reversed (list ( self ._module_iter ()) ):
985
1003
interaction = getattr (m , "default_interaction_type" , None )
986
1004
if interaction is not None :
987
1005
return interaction
988
1006
return interaction_type ()
989
1007
1008
@property
def _last_module(self):
    """Return the final module of this sequence.

    If the instance was not flagged as built from an ``OrderedDict``
    (``self._ordered_dict`` is falsy), the last module is read by index from
    ``self.module``. Otherwise the modules are walked with
    ``self._module_iter()`` and the last one yielded is returned; ``None`` is
    returned when the iterator yields nothing.
    """
    if self._ordered_dict:
        # No positional indexing in this case: exhaust the iterator and keep
        # whichever module was yielded last.
        final = None
        for final in self._module_iter():  # noqa: B007
            pass
        return final
    return self.module[-1]
1016
+
990
1017
def log_prob (
991
1018
self ,
992
1019
tensordict ,
@@ -1103,7 +1130,7 @@ def log_prob(
1103
1130
include_sum = include_sum ,
1104
1131
** kwargs ,
1105
1132
)
1106
- last_module : ProbabilisticTensorDictModule = self .module [ - 1 ]
1133
+ last_module : ProbabilisticTensorDictModule = self ._last_module
1107
1134
out = last_module .log_prob (tensordict_inp , dist = dist , ** kwargs )
1108
1135
if is_tensor_collection (out ):
1109
1136
if tensordict_out is not None :
@@ -1162,7 +1189,7 @@ def forward(
1162
1189
else :
1163
1190
tensordict_exec = tensordict
1164
1191
if self .return_composite :
1165
- for m in self .module :
1192
+ for m in self ._module_iter () :
1166
1193
if isinstance (
1167
1194
m , (ProbabilisticTensorDictModule , ProbabilisticTensorDictModule )
1168
1195
):
@@ -1173,7 +1200,7 @@ def forward(
1173
1200
tensordict_exec = m (tensordict_exec , ** kwargs )
1174
1201
else :
1175
1202
tensordict_exec = self .get_dist_params (tensordict_exec , ** kwargs )
1176
- tensordict_exec = self .module [ - 1 ] (
1203
+ tensordict_exec = self ._last_module (
1177
1204
tensordict_exec , _requires_sample = self ._requires_sample
1178
1205
)
1179
1206
if tensordict_out is not None :
0 commit comments