Skip to content
57 changes: 39 additions & 18 deletions cuda_core/cuda/core/experimental/_utils/cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from cuda import cudart as runtime
from cuda import nvrtc

from cuda.core.experimental._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS
from cuda.core.experimental._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS


class CUDAError(Exception):
pass
Expand Down Expand Up @@ -45,27 +48,45 @@ def cast_to_3_tuple(label, cfg):
return cfg + (1,) * (3 - len(cfg))


def _check_driver_error(error):
if error == driver.CUresult.CUDA_SUCCESS:
return
name_err, name = driver.cuGetErrorName(error)
if name_err != driver.CUresult.CUDA_SUCCESS:
raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
name = name.decode()
expl = DRIVER_CU_RESULT_EXPLANATIONS.get(int(error))
if expl is not None:
raise CUDAError(f"{name}: {expl}")
desc_err, desc = driver.cuGetErrorString(error)
if desc_err != driver.CUresult.CUDA_SUCCESS:
raise CUDAError(f"{name}")
desc = desc.decode()
raise CUDAError(f"{name}: {desc}")


def _check_runtime_error(error):
if error == runtime.cudaError_t.cudaSuccess:
return
name_err, name = runtime.cudaGetErrorName(error)
if name_err != runtime.cudaError_t.cudaSuccess:
raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
name = name.decode()
expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error))
if expl is not None:
raise CUDAError(f"{name}: {expl}")
desc_err, desc = runtime.cudaGetErrorString(error)
if desc_err != runtime.cudaError_t.cudaSuccess:
raise CUDAError(f"{name}")
desc = desc.decode()
raise CUDAError(f"{name}: {desc}")


def _check_error(error, handle=None):
if isinstance(error, driver.CUresult):
if error == driver.CUresult.CUDA_SUCCESS:
return
err, name = driver.cuGetErrorName(error)
if err == driver.CUresult.CUDA_SUCCESS:
err, desc = driver.cuGetErrorString(error)
if err == driver.CUresult.CUDA_SUCCESS:
raise CUDAError(f"{name.decode()}: {desc.decode()}")
else:
raise CUDAError(f"unknown error: {error}")
_check_driver_error(error)
elif isinstance(error, runtime.cudaError_t):
if error == runtime.cudaError_t.cudaSuccess:
return
err, name = runtime.cudaGetErrorName(error)
if err == runtime.cudaError_t.cudaSuccess:
err, desc = runtime.cudaGetErrorString(error)
if err == runtime.cudaError_t.cudaSuccess:
raise CUDAError(f"{name.decode()}: {desc.decode()}")
else:
raise CUDAError(f"unknown error: {error}")
_check_runtime_error(error)
elif isinstance(error, nvrtc.nvrtcResult):
if error == nvrtc.nvrtcResult.NVRTC_SUCCESS:
return
Expand Down
Loading
Loading