@@ -1282,6 +1282,7 @@ def test_broadcast_arrays():
12821282 ["linspace" , "logspace" , "geomspace" ],
12831283 ids = ["linspace" , "logspace" , "geomspace" ],
12841284)
1285+ @pytest .mark .parametrize ("dtype" , [None , "int" , "float" ], ids = [None , "int" , "float" ])
12851286@pytest .mark .parametrize (
12861287 "start, stop, num_samples, endpoint, axis" ,
12871288 [
@@ -1294,12 +1295,20 @@ def test_broadcast_arrays():
12941295 (1 , np .array ([5 , 6 ]), 30 , False , - 1 ),
12951296 ],
12961297)
1297- def test_space_ops (op , start , stop , num_samples , endpoint , axis ):
1298+ def test_space_ops (op , dtype , start , stop , num_samples , endpoint , axis ):
12981299 pt_func = getattr (pt , op )
12991300 np_func = getattr (np , op )
1300- z = pt_func (start , stop , num_samples , endpoint = endpoint , axis = axis )
1301+ dtype = dtype + config .floatX [- 2 :] if dtype is not None else dtype
1302+ z = pt_func (start , stop , num_samples , endpoint = endpoint , axis = axis , dtype = dtype )
13011303
1302- numpy_res = np_func (start , stop , num = num_samples , endpoint = endpoint , axis = axis )
1304+ numpy_res = np_func (
1305+ start , stop , num = num_samples , endpoint = endpoint , dtype = dtype , axis = axis
1306+ )
13031307 pytensor_res = function (inputs = [], outputs = z , mode = "FAST_COMPILE" )()
13041308
1305- np .testing .assert_allclose (pytensor_res , numpy_res , atol = 1e-6 , rtol = 1e-6 )
1309+ np .testing .assert_allclose (
1310+ pytensor_res ,
1311+ numpy_res ,
1312+ atol = 1e-6 if config .floatX .endswith ("64" ) else 1e-4 ,
1313+ rtol = 1e-6 if config .floatX .endswith ("64" ) else 1e-4 ,
1314+ )
0 commit comments