Skip to content

Commit

Permalink
[Feature] Fix type assertion in Seq build
Browse files Browse the repository at this point in the history
ghstack-source-id: 5f23ec7fde998ddd68793ece58a50d6aee33fc4d
Pull Request resolved: #1143
  • Loading branch information
vmoens committed Dec 18, 2024
1 parent 22432ff commit 2360386
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@
import inspect
import warnings
from textwrap import indent
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
MutableSequence,
Optional,
Sequence,
Tuple,
Union,
)

import torch
from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads
Expand Down Expand Up @@ -981,20 +992,20 @@ def __init__(
else:
if isinstance(in_keys, (str, tuple)):
in_keys = [in_keys]
elif not isinstance(in_keys, list):
elif not isinstance(in_keys, MutableSequence):
raise ValueError(self._IN_KEY_ERR)
self._kwargs = None

if isinstance(out_keys, (str, tuple)):
out_keys = [out_keys]
elif not isinstance(out_keys, list):
elif not isinstance(out_keys, MutableSequence):
raise ValueError(self._OUT_KEY_ERR)
try:
in_keys = unravel_key_list(in_keys)
in_keys = unravel_key_list(list(in_keys))
except Exception:
raise ValueError(self._IN_KEY_ERR)
try:
out_keys = unravel_key_list(out_keys)
out_keys = unravel_key_list(list(out_keys))
except Exception:
raise ValueError(self._OUT_KEY_ERR)

Expand Down
29 changes: 29 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import unittest
import weakref
from collections import OrderedDict
from collections.abc import MutableSequence

import pytest
import torch
Expand Down Expand Up @@ -118,6 +119,34 @@ def test_from_str_correct_raise(self, unsupported_type_str):


class TestTDModule:
class MyMutableSequence(MutableSequence):
def __init__(self, initial_data=None):
self._data = [] if initial_data is None else list(initial_data)

def __getitem__(self, index):
return self._data[index]

def __setitem__(self, index, value):
self._data[index] = value

def __delitem__(self, index):
del self._data[index]

def __len__(self):
return len(self._data)

def insert(self, index, value):
self._data.insert(index, value)

def test_mutable_sequence(self):
in_keys = self.MyMutableSequence(["a", "b", "c"])
out_keys = self.MyMutableSequence(["d", "e", "f"])
mod = TensorDictModule(lambda *x: x, in_keys=in_keys, out_keys=out_keys)
td = mod(TensorDict(a=0, b=0, c=0))
assert "d" in td
assert "e" in td
assert "f" in td

def test_auto_unravel(self):
tdm = TensorDictModule(
lambda x: x,
Expand Down

0 comments on commit 2360386

Please sign in to comment.