Skip to content

Commit a33fa82

Browse files
committed
Move tests for geomspace and logspace to dedicated classes to combine them
1 parent 60fc4e2 commit a33fa82

File tree

1 file changed

+110
-126
lines changed

1 file changed

+110
-126
lines changed

dpnp/tests/test_arraycreation.py

Lines changed: 110 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
get_array,
2222
get_float_dtypes,
2323
has_support_aspect64,
24-
is_lts_driver,
25-
is_tgllp_iris_xe,
26-
is_win_platform,
2724
)
2825
from .third_party.cupy import testing
2926

@@ -85,6 +82,61 @@ def test_validate_positional_args(self, xp):
8582
)
8683

8784

85+
class TestGeomspace:
86+
@pytest.mark.parametrize("sign", [-1, 1])
87+
@pytest.mark.parametrize("dtype", get_all_dtypes())
88+
@pytest.mark.parametrize("num", [2, 4, 8, 3, 9, 27])
89+
@pytest.mark.parametrize("endpoint", [True, False])
90+
def test_basic(self, sign, dtype, num, endpoint):
91+
start = 2 * sign
92+
stop = 127 * sign
93+
94+
func = lambda xp: xp.geomspace(
95+
start, stop, num, endpoint=endpoint, dtype=dtype
96+
)
97+
98+
np_res = func(numpy)
99+
dpnp_res = func(dpnp)
100+
101+
assert_allclose(dpnp_res, np_res, rtol=1e-06)
102+
103+
@pytest.mark.parametrize("start", [1j, 1 + 1j])
104+
@pytest.mark.parametrize("stop", [10j, 10 + 10j])
105+
def test_complex(self, start, stop):
106+
func = lambda xp: xp.geomspace(start, stop, num=10)
107+
np_res = func(numpy)
108+
dpnp_res = func(dpnp)
109+
assert_allclose(dpnp_res, np_res, rtol=1e-06)
110+
111+
@pytest.mark.parametrize("axis", [0, 1])
112+
def test_axis(self, axis):
113+
func = lambda xp: xp.geomspace([2, 3], [20, 15], num=10, axis=axis)
114+
np_res = func(numpy)
115+
dpnp_res = func(dpnp)
116+
assert_allclose(dpnp_res, np_res, rtol=1e-06)
117+
118+
def test_num_zero(self):
119+
func = lambda xp: xp.geomspace(1, 10, num=0, endpoint=False)
120+
np_res = func(numpy)
121+
dpnp_res = func(dpnp)
122+
assert_allclose(dpnp_res, np_res)
123+
124+
@pytest.mark.parametrize(
125+
"start, stop, num",
126+
[
127+
(0, 5, 3),
128+
(2, 0, 3),
129+
(0, 0, 3),
130+
(dpnp.array([0]), 7, 10),
131+
(-2, numpy.array([[0]]), 7),
132+
([2, 4, 0], 3, 5),
133+
],
134+
)
135+
def test_zero_error(self, start, stop, num):
136+
with pytest.raises(ValueError):
137+
dpnp.geomspace(start, stop, num)
138+
139+
88140
class TestLinspace:
89141
@pytest.mark.parametrize("start", [0, -5, 10, -2.5, 9.7])
90142
@pytest.mark.parametrize("stop", [0, 10, -2, 20.5, 120])
@@ -210,6 +262,61 @@ def test_float_num(self, xp):
210262
_ = xp.linspace(0, 1, num=2.5)
211263

212264

265+
class TestLogspace:
266+
@pytest.mark.parametrize("dtype", get_all_dtypes())
267+
@pytest.mark.parametrize("num", [2, 4, 8, 3, 9, 27])
268+
@pytest.mark.parametrize("endpoint", [True, False])
269+
def test_basic(self, dtype, num, endpoint):
270+
start = 2
271+
stop = 5
272+
base = 2
273+
274+
func = lambda xp: xp.logspace(
275+
start, stop, num, endpoint=endpoint, dtype=dtype, base=base
276+
)
277+
278+
np_res = func(numpy)
279+
dpnp_res = func(dpnp)
280+
assert_allclose(dpnp_res, np_res, rtol=1e-06)
281+
282+
@testing.with_requires("numpy>=1.25.0")
283+
@pytest.mark.parametrize("axis", [0, 1])
284+
def test_axis(self, axis):
285+
func = lambda xp: xp.logspace(
286+
[2, 3], [20, 15], num=2, base=[[1, 3], [5, 7]], axis=axis
287+
)
288+
assert_dtype_allclose(func(dpnp), func(numpy))
289+
290+
def test_list_input(self):
291+
expected = numpy.logspace([0], [2], base=[5])
292+
result = dpnp.logspace([0], [2], base=[5])
293+
assert_dtype_allclose(result, expected)
294+
295+
296+
class TestSpaceLike:
297+
@pytest.mark.parametrize("func", ["geomspace", "linspace", "logspace"])
298+
@pytest.mark.parametrize(
299+
"start_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
300+
)
301+
@pytest.mark.parametrize(
302+
"stop_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
303+
)
304+
def test_numpy_dtype(self, func, start_dtype, stop_dtype):
305+
start = numpy.array([1, 2, 3], dtype=start_dtype)
306+
stop = numpy.array([11, 7, -2], dtype=stop_dtype)
307+
getattr(dpnp, func)(start, stop, 10)
308+
309+
@pytest.mark.parametrize("xp", [dpnp, numpy])
310+
@pytest.mark.parametrize("func", ["geomspace", "logspace"])
311+
@pytest.mark.parametrize(
312+
"start, stop, num",
313+
[(2, 5, -3), ([2, 3], 5, -3)],
314+
)
315+
def test_space_num_error(self, xp, func, start, stop, num):
316+
with pytest.raises(ValueError):
317+
getattr(xp, func)(start, stop, num)
318+
319+
213320
class TestTrace:
214321
@pytest.mark.parametrize("a_sh", [(3, 4), (2, 2, 2)])
215322
@pytest.mark.parametrize(
@@ -871,19 +978,6 @@ def test_dpctl_tensor_input(func, args):
871978
assert_array_equal(X, Y)
872979

873980

874-
@pytest.mark.parametrize("func", ["geomspace", "linspace", "logspace"])
875-
@pytest.mark.parametrize(
876-
"start_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
877-
)
878-
@pytest.mark.parametrize(
879-
"stop_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
880-
)
881-
def test_space_numpy_dtype(func, start_dtype, stop_dtype):
882-
start = numpy.array([1, 2, 3], dtype=start_dtype)
883-
stop = numpy.array([11, 7, -2], dtype=stop_dtype)
884-
getattr(dpnp, func)(start, stop, 10)
885-
886-
887981
@pytest.mark.parametrize(
888982
"arrays",
889983
[[], [[1]], [[1, 2, 3], [4, 5, 6]], [[1, 2], [3, 4], [5, 6]]],
@@ -908,116 +1002,6 @@ def test_set_shape(shape):
9081002
assert_array_equal(na, da)
9091003

9101004

911-
@pytest.mark.parametrize(
912-
"start, stop, num",
913-
[
914-
(0, 5, 3),
915-
(2, 0, 3),
916-
(0, 0, 3),
917-
(dpnp.array([0]), 7, 10),
918-
(-2, numpy.array([[0]]), 7),
919-
([2, 4, 0], 3, 5),
920-
],
921-
)
922-
def test_geomspace_zero_error(start, stop, num):
923-
with pytest.raises(ValueError):
924-
dpnp.geomspace(start, stop, num)
925-
926-
927-
@pytest.mark.parametrize("xp", [dpnp, numpy])
928-
@pytest.mark.parametrize("func", ["geomspace", "logspace"])
929-
@pytest.mark.parametrize(
930-
"start, stop, num",
931-
[(2, 5, -3), ([2, 3], 5, -3)],
932-
)
933-
def test_space_num_error(xp, func, start, stop, num):
934-
with pytest.raises(ValueError):
935-
getattr(xp, func)(start, stop, num)
936-
937-
938-
@pytest.mark.parametrize("sign", [-1, 1])
939-
@pytest.mark.parametrize("dtype", get_all_dtypes())
940-
@pytest.mark.parametrize("num", [2, 4, 8, 3, 9, 27])
941-
@pytest.mark.parametrize("endpoint", [True, False])
942-
def test_geomspace(sign, dtype, num, endpoint):
943-
start = 2 * sign
944-
stop = 127 * sign
945-
946-
func = lambda xp: xp.geomspace(
947-
start, stop, num, endpoint=endpoint, dtype=dtype
948-
)
949-
950-
np_res = func(numpy)
951-
dpnp_res = func(dpnp)
952-
953-
assert_allclose(dpnp_res, np_res, rtol=1e-06)
954-
955-
956-
@pytest.mark.parametrize("start", [1j, 1 + 1j])
957-
@pytest.mark.parametrize("stop", [10j, 10 + 10j])
958-
def test_geomspace_complex(start, stop):
959-
func = lambda xp: xp.geomspace(start, stop, num=10)
960-
np_res = func(numpy)
961-
dpnp_res = func(dpnp)
962-
assert_allclose(dpnp_res, np_res, rtol=1e-06)
963-
964-
965-
@pytest.mark.parametrize("axis", [0, 1])
966-
def test_geomspace_axis(axis):
967-
func = lambda xp: xp.geomspace([2, 3], [20, 15], num=10, axis=axis)
968-
np_res = func(numpy)
969-
dpnp_res = func(dpnp)
970-
assert_allclose(dpnp_res, np_res, rtol=1e-06)
971-
972-
973-
def test_geomspace_num0():
974-
func = lambda xp: xp.geomspace(1, 10, num=0, endpoint=False)
975-
np_res = func(numpy)
976-
dpnp_res = func(dpnp)
977-
assert_allclose(dpnp_res, np_res)
978-
979-
980-
@pytest.mark.parametrize("dtype", get_all_dtypes())
981-
@pytest.mark.parametrize("num", [2, 4, 8, 3, 9, 27])
982-
@pytest.mark.parametrize("endpoint", [True, False])
983-
def test_logspace(dtype, num, endpoint):
984-
if not is_win_platform() and is_tgllp_iris_xe() and is_lts_driver():
985-
if (
986-
dpnp.issubdtype(dtype, dpnp.integer)
987-
and num in [8, 27]
988-
and endpoint is True
989-
):
990-
pytest.skip("SAT-7978")
991-
992-
start = 2
993-
stop = 5
994-
base = 2
995-
996-
func = lambda xp: xp.logspace(
997-
start, stop, num, endpoint=endpoint, dtype=dtype, base=base
998-
)
999-
1000-
np_res = func(numpy)
1001-
dpnp_res = func(dpnp)
1002-
1003-
assert_allclose(dpnp_res, np_res, rtol=1e-06)
1004-
1005-
1006-
@testing.with_requires("numpy>=1.25.0")
1007-
@pytest.mark.parametrize("axis", [0, 1])
1008-
def test_logspace_axis(axis):
1009-
func = lambda xp: xp.logspace(
1010-
[2, 3], [20, 15], num=2, base=[[1, 3], [5, 7]], axis=axis
1011-
)
1012-
assert_dtype_allclose(func(dpnp), func(numpy))
1013-
1014-
1015-
def test_logspace_list_input():
1016-
expected = numpy.logspace([0], [2], base=[5])
1017-
result = dpnp.logspace([0], [2], base=[5])
1018-
assert_dtype_allclose(result, expected)
1019-
1020-
10211005
@pytest.mark.parametrize(
10221006
"data", [(), 1, (2, 3), [4], numpy.array(5), numpy.array([6, 7])]
10231007
)

0 commit comments

Comments
 (0)