diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 7b8c8763..b3510e60 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -17,7 +17,6 @@ ) def test_take(x, data): # TODO: - # * negative indices # * different dtypes for indices # axis is optional but only if x.ndim == 1 @@ -28,7 +27,7 @@ def test_take(x, data): kw = {"axis": data.draw(_axis_st)} axis = kw.get("axis", 0) _indices = data.draw( - st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True), + st.lists(st.integers(-x.shape[axis], x.shape[axis] - 1), min_size=1, unique=True), label="_indices", ) n_axis = axis if axis>=0 else x.ndim + axis @@ -77,7 +76,6 @@ def test_take(x, data): ) def test_take_along_axis(x, data): # TODO - # 2. negative indices # 3. different dtypes for indices # 4. "broadcast-compatible" indices axis = data.draw( @@ -97,7 +95,7 @@ def test_take_along_axis(x, data): hh.arrays( shape=idx_shape, dtype=dh.default_int, - elements={"min_value": 0, "max_value": x.shape[n_axis]-1} + elements={"min_value": -x.shape[n_axis], "max_value": x.shape[n_axis]-1} ), label="indices" )