[Feature] td_flatten_with_keys (#675)
vmoens authored Feb 13, 2024
1 parent 010dfb2 commit 86d6406
Showing 3 changed files with 95 additions and 11 deletions.
67 changes: 60 additions & 7 deletions tensordict/_pytree.py
@@ -4,15 +4,18 @@
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Tuple

import torch
from tensordict import (
LazyStackedTensorDict,
PersistentTensorDict,
SubTensorDict,
TensorDict,
TensorDictBase,
)
from tensordict.utils import implement_for

try:
from torch.utils._pytree import Context, register_pytree_node
from torch.utils._pytree import Context, MappingKey, register_pytree_node
except ImportError:
from torch.utils._pytree import (
_register_pytree_node as register_pytree_node,
@@ -80,8 +83,16 @@ def _str_to_tensordictdict(str_spec: str) -> Tuple[List[str], str]:


def _tensordict_flatten(d: TensorDict) -> Tuple[List[Any], Context]:
return list(d.values()), {
"keys": list(d.keys()),
items = tuple(d.items())
if items:
keys, values = zip(*d.items())
keys = list(keys)
values = list(values)
else:
keys = []
values = []
return values, {
"keys": keys,
"batch_size": d.batch_size,
"names": d.names,
"device": d.device,
@@ -96,17 +107,59 @@ def _tensordictdict_unflatten(values: List[Any], context: Context) -> Dict[Any,
if all(val.device == device for val in values if hasattr(val, "device"))
else None
)
batch_size = context["batch_size"]
names = context["names"]
keys = context["keys"]
batch_dims = len(batch_size)
if any(tensor.shape[:batch_dims] != batch_size for tensor in values):
batch_size = torch.Size([])
names = None
return TensorDict(
dict(zip(context["keys"], values)),
context["batch_size"],
names=context["names"],
dict(zip(keys, values)),
batch_size=batch_size,
names=names,
device=device,
_run_checks=False,
)


for cls in PYTREE_REGISTERED_TDS:
def _td_flatten_with_keys(
d: TensorDictBase,
):
items = tuple(d.items())
if items:
keys, values = zip(*d.items())
keys = list(keys)
values = list(values)
else:
keys = []
values = []
return [(MappingKey(k), v) for k, v in zip(keys, values)], {
"keys": keys,
"batch_size": d.batch_size,
"names": d.names,
"device": d.device,
}


@implement_for("torch", None, "2.3")
def _register_td_node(cls):
register_pytree_node(
cls,
_tensordict_flatten,
_tensordictdict_unflatten,
)


@implement_for("torch", "2.3")
def _register_td_node(cls): # noqa: F811
register_pytree_node(
cls,
_tensordict_flatten,
_tensordictdict_unflatten,
flatten_with_keys_fn=_td_flatten_with_keys,
)


for cls in PYTREE_REGISTERED_TDS:
_register_td_node(cls)
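
For context (illustration only, not part of the diff): the extra flatten_with_keys_fn registered above is what lets torch.utils._pytree.tree_map_with_path (torch >= 2.3) visit a TensorDict's leaves together with their key paths, each path entry being a MappingKey. A minimal sketch of the behaviour this enables, mirroring test_map_with_path further down:

from tensordict import TensorDict
from torch.utils._pytree import tree_map_with_path  # available from torch 2.3

def show_path(path, tensor):
    # path is a tuple of MappingKey entries; .key recovers the dictionary key
    print([p.key for p in path], tensor.shape)
    return tensor

td = TensorDict({"a": {"b": {"c": [1]}}}, [1])
tree_map_with_path(show_path, td)  # prints ['a', 'b', 'c'] torch.Size([1])
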
7 changes: 4 additions & 3 deletions tensordict/utils.py
@@ -822,9 +822,10 @@ def __init__(
implement_for._setters.append(self)

@staticmethod
def check_version(version, from_version, to_version):
return (from_version is None or parse(version) >= parse(from_version)) and (
to_version is None or parse(version) < parse(to_version)
def check_version(version: str, from_version: str | None, to_version: str | None):
version = parse(".".join([str(v) for v in parse(version).release]))
return (from_version is None or version >= parse(from_version)) and (
to_version is None or version < parse(to_version)
)

@staticmethod
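
On the utils.py change: check_version now reduces the installed version to its release segment before comparing, so pre-release and nightly builds are matched against their release triple instead of falling short of the lower bound. A small sketch of the effect (illustration only; the version strings below are made up):

from packaging.version import parse

def check_version(version, from_version, to_version):
    # normalise the installed version to its release segment, e.g. "2.3.0.dev1" -> "2.3.0"
    version = parse(".".join(str(v) for v in parse(version).release))
    return (from_version is None or version >= parse(from_version)) and (
        to_version is None or version < parse(to_version)
    )

print(check_version("2.3.0.dev1", "2.3", None))  # True: the dev build counts as 2.3.0
print(parse("2.3.0.dev1") >= parse("2.3"))       # False: what the previous comparison did
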
32 changes: 31 additions & 1 deletion test/test_functorch.py
@@ -8,7 +8,7 @@
import pytest
import torch

from _utils_internal import expand_list, TestTensorDictsBase
from _utils_internal import expand_list, get_available_devices, TestTensorDictsBase

from tensordict import LazyStackedTensorDict, TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
@@ -17,6 +17,7 @@
make_functional,
repopulate_module,
)
from tensordict.utils import implement_for
from torch import nn
from torch.nn import Linear
from torch.utils._pytree import tree_map
@@ -633,3 +634,32 @@ def test_pytree_vs_apply(self):
for v1, v2 in zip(td_pytree.values(True), td_apply.values(True)):
# recursively checks the shape, including for the nested tensordicts
assert v1.shape == v2.shape

@implement_for("torch", "2.3")
def test_map_with_path(self):
def assert_path(path, tensor):
assert path[0].key == "a"
assert path[1].key == "b"
assert path[2].key == "c"
return tensor

td = TensorDict({"a": {"b": {"c": [1]}}}, [1])
torch.utils._pytree.tree_map_with_path(assert_path, td)

@implement_for("torch", None, "2.3")
def test_map_with_path(self): # noqa: F811
pytest.skip(reason="tree_map_with_path not implemented")

@pytest.mark.parametrize("dest", get_available_devices())
def test_device_map(self, dest):
td = TensorDict({"a": {"b": {"c": [1]}, "d": [2]}}, [1], device="cpu")
td_device = tree_map(lambda x: x.to(dest), td)
if dest == torch.device("cpu"):
assert td_device.device == torch.device("cpu")
else:
assert td_device.device is None

def test_shape_map(self):
td = TensorDict({"a": {"b": {"c": [1]}, "d": [2]}}, [1])
td_no_shape = tree_map(lambda x: x.squeeze(), td)
assert td_no_shape.shape == torch.Size([])
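
The last two tests pin down the metadata rules in _tensordictdict_unflatten: the rebuilt TensorDict keeps the original batch_size (and names) only while every mapped leaf still matches it, and otherwise falls back to an empty batch size. A short sketch (illustration only, not part of the diff):

from tensordict import TensorDict
from torch.utils._pytree import tree_map

td = TensorDict({"a": {"b": {"c": [1]}, "d": [2]}}, [1])
same = tree_map(lambda x: x + 1, td)            # leaf shapes unchanged -> batch size kept
print(same.batch_size)                          # torch.Size([1])
squeezed = tree_map(lambda x: x.squeeze(), td)  # leading dim removed -> batch size dropped
print(squeezed.batch_size)                      # torch.Size([])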
