Skip to content

Commit 9898850

Browse files
authored
Fix linting errors in tests (#188)
* apply auto-fixes * Fix linting errors in tests/ * Fix version check
1 parent e652b9a commit 9898850

15 files changed

+18
-42
lines changed

tests/test_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,5 @@
2121
#
2222
def test_api():
2323
import cebra.distributions
24-
from cebra.distributions import TimedeltaDistribution
2524

2625
cebra.distributions.TimedeltaDistribution

tests/test_cli.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,3 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22-
import argparse
23-
24-
import pytest

tests/test_criterions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22-
import numpy as np
2322
import pytest
2423
import torch
2524
from torch import nn
@@ -294,7 +293,7 @@ def _sample_dist_matrices(seed):
294293

295294

296295
@pytest.mark.parametrize("seed", [42, 4242, 424242])
297-
def test_infonce(seed):
296+
def test_infonce_check_output_parts(seed):
298297
pos_dist, neg_dist = _sample_dist_matrices(seed)
299298

300299
ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)

tests/test_datasets.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ def test_demo():
6969
@pytest.mark.requires_dataset
7070
def test_hippocampus():
7171
pytest.skip("Outdated")
72-
73-
from cebra.datasets import hippocampus # noqa: F401
7472
dataset = cebra.datasets.init("rat-hippocampus-single")
7573
loader = cebra.data.ContinuousDataLoader(
7674
dataset=dataset,
@@ -99,8 +97,6 @@ def test_hippocampus():
9997

10098
@pytest.mark.requires_dataset
10199
def test_monkey():
102-
from cebra.datasets import monkey_reaching # noqa: F401
103-
104100
dataset = cebra.datasets.init(
105101
"area2-bump-pos-active-passive",
106102
path=pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40",
@@ -111,8 +107,6 @@ def test_monkey():
111107

112108
@pytest.mark.requires_dataset
113109
def test_allen():
114-
from cebra.datasets import allen # noqa: F401
115-
116110
pytest.skip("Test takes too long")
117111

118112
ca_dataset = cebra.datasets.init("allen-movie-one-ca-VISp-100-train-10-111")

tests/test_demo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#
2222
import glob
2323
import re
24-
import sys
2524

2625
import pytest
2726

tests/test_distributions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def prepare(N=1000, n=128, d=5, probs=[0.3, 0.1, 0.6], device="cpu"):
4343
continuous = torch.randn(N, d).to(device)
4444

4545
rand = torch.from_numpy(np.random.randint(0, N, (n,))).to(device)
46-
qidx = discrete[rand].to(device)
46+
_ = discrete[rand].to(device)
4747
query = continuous[rand] + 0.1 * torch.randn(n, d).to(device)
4848
query = query.to(device)
4949

@@ -173,7 +173,7 @@ def test_mixed():
173173
discrete, continuous)
174174

175175
reference_idx = distribution.sample_prior(10)
176-
positive_idx = distribution.sample_conditional(reference_idx)
176+
_ = distribution.sample_conditional(reference_idx)
177177

178178
# The conditional distribution p(· | disc, cont) should yield
179179
# samples where the label exactly matches the reference sample.
@@ -193,7 +193,7 @@ def test_continuous(benchmark):
193193
def _test_distribution(dist):
194194
distribution = dist(continuous)
195195
reference_idx = distribution.sample_prior(10)
196-
positive_idx = distribution.sample_conditional(reference_idx)
196+
_ = distribution.sample_conditional(reference_idx)
197197
return distribution
198198

199199
distribution = _test_distribution(

tests/test_grid_search.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
# limitations under the License.
2121
#
2222
import numpy as np
23-
import pytest
2423

2524
import cebra
2625
import cebra.grid_search

tests/test_integration_train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
# limitations under the License.
2121
#
2222
import itertools
23-
from typing import List
2423

2524
import pytest
2625
import torch

tests/test_load.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
import itertools
2323
import pathlib
2424
import pickle
25-
import platform
2625
import tempfile
27-
import unittest
28-
from unittest.mock import patch
2926

3027
import h5py
3128
import hdf5storage
@@ -125,7 +122,7 @@ def generate_numpy_confounder(filename, dtype):
125122

126123

127124
@register("npz")
128-
def generate_numpy_path(filename, dtype):
125+
def generate_numpy_path_2(filename, dtype):
129126
A = np.arange(1000, dtype=dtype).reshape(10, 100)
130127
np.savez(filename, array=A, other_data="test")
131128
loaded_A = cebra_load.load(pathlib.Path(filename))
@@ -418,7 +415,7 @@ def generate_csv_path(filename, dtype):
418415

419416
@register_error("csv")
420417
def generate_csv_empty_file(filename, dtype):
421-
with open(filename, "w") as creating_new_csv_file:
418+
with open(filename, "w") as _:
422419
pass
423420
_ = cebra_load.load(filename)
424421

@@ -619,7 +616,6 @@ def generate_pickle_invalid_key(filename, dtype):
619616

620617
@register_error("pkl", "p")
621618
def generate_pickle_no_array(filename, dtype):
622-
A = np.arange(1000, dtype=dtype).reshape(10, 100)
623619
with open(filename, "wb") as f:
624620
pickle.dump({"A": "test_1", "B": "test_2"}, f)
625621
_ = cebra_load.load(filename)

tests/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def test_version_check(version, raises):
155155
cebra.models.model._check_torch_version(raise_error=True)
156156

157157

158-
def test_version_check():
159-
raises = not cebra.models.model._check_torch_version(raise_error=False)
158+
def test_version_check_dropout_available():
159+
raises = cebra.models.model._check_torch_version(raise_error=False)
160160
if raises:
161161
assert len(cebra.models.get_options("*dropout*")) == 0
162162
else:

0 commit comments

Comments
 (0)