diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py
index 6be8077b8..22cf7419e 100644
--- a/cuda_core/cuda/core/experimental/_device.py
+++ b/cuda_core/cuda/core/experimental/_device.py
@@ -957,34 +957,42 @@ def __new__(cls, device_id=None):
 
         # important: creating a Device instance does not initialize the GPU!
         if device_id is None:
-            device_id = handle_return(runtime.cudaGetDevice())
-            assert_type(device_id, int)
+            err, dev = driver.cuCtxGetDevice()
+            if err == 0:
+                device_id = int(dev)
+            else:
+                ctx = handle_return(driver.cuCtxGetCurrent())
+                assert int(ctx) == 0
+                device_id = 0  # cudart behavior
+            assert isinstance(device_id, int), f"{device_id=}"
         else:
-            total = handle_return(runtime.cudaGetDeviceCount())
-            assert_type(device_id, int)
-            if not (0 <= device_id < total):
+            total = handle_return(driver.cuDeviceGetCount())
+            if not isinstance(device_id, int) or not (0 <= device_id < total):
                 raise ValueError(f"device_id must be within [0, {total}), got {device_id}")
 
         # ensure Device is singleton
         if not hasattr(_tls, "devices"):
-            total = handle_return(runtime.cudaGetDeviceCount())
+            total = handle_return(driver.cuDeviceGetCount())
             _tls.devices = []
             for dev_id in range(total):
                 dev = super().__new__(cls)
+                dev._id = dev_id
                 # If the device is in TCC mode, or does not support memory pools for some other reason,
                 # use the SynchronousMemoryResource which does not use memory pools.
                 if (
                     handle_return(
-                        runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
+                        driver.cuDeviceGetAttribute(
+                            driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
+                        )
                     )
                 ) == 1:
                     dev._mr = _DefaultAsyncMempool(dev_id)
                 else:
                     dev._mr = _SynchronousMemoryResource(dev_id)
 
-                dev._has_inited = False
                 dev._properties = None
+                _tls.devices.append(dev)
 
         return _tls.devices[device_id]
 
 