@@ -152,3 +152,48 @@ def test_launch_scalar_argument(python_type, cpp_type, init_value):
152152
153153    # Check result 
154154    assert  arr [0 ] ==  init_value , f"Expected { init_value } { arr [0 ]}  
155+ 
156+ 
@pytest.mark.skipif(os.environ.get("CUDA_PATH") is None, reason="need cg header")
def test_cooperative_launch():
    """Check that a grid-synchronizing kernel can be launched cooperatively.

    The test compiles a minimal kernel that calls ``grid.sync()`` through
    ``cooperative_groups`` and then verifies two things:

    * launching with ``cooperative_launch=True`` and an oversized grid raises
      ``ValueError``;
    * launching with ``cooperative_launch=True`` and a 1x1 configuration
      completes and the stream synchronizes without error.

    Skipped when ``CUDA_PATH`` is not set, because the path is used to locate
    ``cooperative_groups.h``.
    """
    dev = Device()
    dev.set_current()
    s = dev.create_stream(options={"nonblocking": True})

    # Minimal kernel whose only job is to perform a grid-wide sync.
    code = r"""
    #include <cooperative_groups.h>

    extern "C" __global__ void test_grid_sync() {
        namespace cg = cooperative_groups;
        auto grid = cg::this_grid();
        grid.sync();
    }
    """

    # Compile to a cubin for the current device; the include path is needed
    # so the compiler can find cooperative_groups.h.
    arch = "".join(f"{i}" for i in dev.compute_capability)
    include_path = str(pathlib.Path(os.environ["CUDA_PATH"]) / "include")
    pro_opts = ProgramOptions(std="c++17", arch=f"sm_{arch}", include_path=include_path)
    prog = Program(code, code_type="c++", options=pro_opts)
    ker = prog.compile("cubin").get_kernel("test_grid_sync")

    # NOTE: the negative case (launching WITHOUT cooperative_launch) is kept
    # disabled because the resulting failure seems to be a sticky error.
    # config = LaunchConfig(grid=1, block=1)
    # launch(s, config, ker)
    # from cuda.core.experimental._utils.cuda_utils import CUDAError
    # with pytest.raises(CUDAError) as e:
    #     s.sync()
    # assert "CUDA_ERROR_LAUNCH_FAILED" in str(e)

    # An oversized grid must be rejected at launch time with ValueError.
    # (Presumably it exceeds what the device can keep co-resident for a
    # cooperative launch -- confirm against the launch implementation.)
    block = 128
    config = LaunchConfig(
        grid=dev.properties.max_grid_dim_x // block + 1,
        block=block,
        cooperative_launch=True,
    )
    with pytest.raises(ValueError):
        launch(s, config, ker)

    # A single-block, single-thread cooperative launch must succeed.
    config = LaunchConfig(grid=1, block=1, cooperative_launch=True)
    launch(s, config, ker)
    s.sync()
0 commit comments