Skip to content

Commit e652b9a

Browse files
authored
Add additional tests for TensorDataset (#187)
* Add additional tests for TensorDataset
* Add explicit casting to avoid Windows error
1 parent 51f048d commit e652b9a

File tree

2 files changed

+119
-7
lines changed

2 files changed

+119
-7
lines changed

cebra/data/datasets.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(self,
7171
super().__init__(device=device)
7272
self.neural = self._to_tensor(neural, check_dtype="float").float()
7373
self.continuous = self._to_tensor(continuous, check_dtype="float")
74-
self.discrete = self._to_tensor(discrete, check_dtype="integer")
74+
self.discrete = self._to_tensor(discrete, check_dtype="int")
7575
if self.continuous is None and self.discrete is None:
7676
raise ValueError(
7777
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
@@ -87,7 +87,7 @@ def _to_tensor(
8787
8888
Args:
8989
array: Array to check.
90-
check_dtype (list, optional): If not `None`, list of dtypes to which the values in `array`
90+
check_dtype: If not `None`, list of dtypes to which the values in `array`
9191
must belong to. Defaults to None.
9292
9393
Returns:
@@ -98,11 +98,20 @@ def _to_tensor(
9898
if isinstance(array, np.ndarray):
9999
array = torch.from_numpy(array)
100100
if check_dtype is not None:
101+
if check_dtype not in ["int", "float"]:
102+
raise ValueError(
103+
f"check_dtype must be 'int' or 'float', got {check_dtype}")
101104
if (check_dtype == "int" and not cebra_helper._is_integer(array)
102105
) or (check_dtype == "float" and
103106
not cebra_helper._is_floating(array)):
104107
raise TypeError(
105108
f"Array has type {array.dtype} instead of {check_dtype}.")
109+
if cebra_helper._is_floating(array):
110+
array = array.float()
111+
if cebra_helper._is_integer(array):
112+
# NOTE(stes): Required for standardizing number format on
113+
# windows machines.
114+
array = array.long()
106115
return array
107116

108117
@property

tests/test_datasets.py

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def test_demo():
6868

6969
@pytest.mark.requires_dataset
7070
def test_hippocampus():
71-
from cebra.datasets import hippocampus
72-
7371
pytest.skip("Outdated")
72+
73+
from cebra.datasets import hippocampus # noqa: F401
7474
dataset = cebra.datasets.init("rat-hippocampus-single")
7575
loader = cebra.data.ContinuousDataLoader(
7676
dataset=dataset,
@@ -99,7 +99,7 @@ def test_hippocampus():
9999

100100
@pytest.mark.requires_dataset
101101
def test_monkey():
102-
from cebra.datasets import monkey_reaching
102+
from cebra.datasets import monkey_reaching # noqa: F401
103103

104104
dataset = cebra.datasets.init(
105105
"area2-bump-pos-active-passive",
@@ -111,7 +111,7 @@ def test_monkey():
111111

112112
@pytest.mark.requires_dataset
113113
def test_allen():
114-
from cebra.datasets import allen
114+
from cebra.datasets import allen # noqa: F401
115115

116116
pytest.skip("Test takes too long")
117117

@@ -148,7 +148,7 @@ def test_allen():
148148
multisubject_options.extend(
149149
cebra.datasets.get_options(
150150
"rat-hippocampus-multisubjects-3fold-trial-split*"))
151-
except:
151+
except: # noqa: E722
152152
options = []
153153

154154

@@ -388,3 +388,106 @@ def test_download_file_wrong_content_disposition(filename, url,
388388
expected_checksum=expected_checksum,
389389
location=temp_dir,
390390
file_name=filename)
391+
392+
393+
@pytest.mark.parametrize("neural, continuous, discrete", [
394+
(np.random.randn(100, 30), np.random.randn(
395+
100, 2), np.random.randint(0, 5, (100,))),
396+
(np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))),
397+
(np.random.randn(200, 40), np.random.randn(200, 5), None),
398+
])
399+
def test_tensor_dataset_initialization(neural, continuous, discrete):
400+
dataset = cebra.data.datasets.TensorDataset(neural,
401+
continuous=continuous,
402+
discrete=discrete)
403+
assert dataset.neural.shape == neural.shape
404+
if continuous is not None:
405+
assert dataset.continuous.shape == continuous.shape
406+
if discrete is not None:
407+
assert dataset.discrete.shape == discrete.shape
408+
409+
410+
def test_tensor_dataset_invalid_initialization():
411+
neural = np.random.randn(100, 30)
412+
with pytest.raises(ValueError):
413+
cebra.data.datasets.TensorDataset(neural)
414+
415+
416+
@pytest.mark.parametrize("neural, continuous, discrete", [
417+
(np.random.randn(100, 30), np.random.randn(
418+
100, 2), np.random.randint(0, 5, (100,))),
419+
(np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))),
420+
(np.random.randn(200, 40), np.random.randn(200, 5), None),
421+
])
422+
def test_tensor_dataset_length(neural, continuous, discrete):
423+
dataset = cebra.data.datasets.TensorDataset(neural,
424+
continuous=continuous,
425+
discrete=discrete)
426+
assert len(dataset) == len(neural)
427+
428+
429+
@pytest.mark.parametrize("neural, continuous, discrete", [
430+
(np.random.randn(100, 30), np.random.randn(
431+
100, 2), np.random.randint(0, 5, (100,))),
432+
(np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))),
433+
(np.random.randn(200, 40), np.random.randn(200, 5), None),
434+
])
435+
def test_tensor_dataset_getitem(neural, continuous, discrete):
436+
dataset = cebra.data.datasets.TensorDataset(neural,
437+
continuous=continuous,
438+
discrete=discrete)
439+
index = torch.randint(0, len(dataset), (10,))
440+
batch = dataset[index]
441+
assert batch.shape[0] == len(index)
442+
assert batch.shape[1] == neural.shape[1]
443+
444+
445+
def test_tensor_dataset_invalid_discrete_type():
446+
neural = np.random.randn(100, 30)
447+
continuous = np.random.randn(100, 2)
448+
discrete = np.random.randn(100, 2) # Invalid type: float instead of int
449+
with pytest.raises(TypeError):
450+
cebra.data.datasets.TensorDataset(neural,
451+
continuous=continuous,
452+
discrete=discrete)
453+
454+
455+
@pytest.mark.parametrize("array, check_dtype, expected_dtype", [
456+
(np.random.randn(100, 30), "float", torch.float32),
457+
(np.random.randint(0, 5, (100, 30)), "int", torch.int64),
458+
(torch.randn(100, 30), "float", torch.float32),
459+
(torch.randint(0, 5, (100, 30)), "int", torch.int64),
460+
(None, None, None),
461+
])
462+
def test_to_tensor(array, check_dtype, expected_dtype):
463+
dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
464+
continuous=np.random.randn(
465+
10, 2))
466+
result = dataset._to_tensor(array, check_dtype=check_dtype)
467+
if array is None:
468+
assert result is None
469+
else:
470+
assert isinstance(result, torch.Tensor)
471+
assert result.dtype == expected_dtype
472+
473+
474+
def test_to_tensor_invalid_dtype():
475+
dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
476+
continuous=np.random.randn(
477+
10, 2))
478+
array = np.random.randn(100, 30)
479+
with pytest.raises(TypeError):
480+
dataset._to_tensor(array, check_dtype="int")
481+
array = np.random.randint(0, 5, (100, 30))
482+
with pytest.raises(TypeError):
483+
dataset._to_tensor(array, check_dtype="float")
484+
485+
486+
def test_to_tensor_invalid_check_dtype():
487+
dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
488+
continuous=np.random.randn(
489+
10, 2))
490+
array = np.random.randn(100, 30)
491+
with pytest.raises(ValueError,
492+
match="check_dtype must be 'int' or 'float', got"):
493+
dataset._to_tensor(array, check_dtype="invalid_dtype")

0 commit comments

Comments (0)