Skip to content

Commit 09dd3dc

Browse files
authored
python: depend on official NVIDIA CUDA packages (#2355)
Signed-off-by: Jared Van Bortel <[email protected]>
1 parent c779d8a commit 09dd3dc

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

gpt4all-bindings/python/gpt4all/_pyllmodel.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,27 @@
2828
EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
2929

3030

31+
# Find CUDA libraries from the official NVIDIA packages and preload them.
# `cuda_found` ends up True only if both the CUDA runtime and cuBLAS shared
# libraries were loaded successfully; it stays False in every other case so
# that importing this module never fails just because CUDA is unavailable.
cuda_found = False
if platform.system() in ('Linux', 'Windows'):
    try:
        from nvidia import cuda_runtime, cublas
    except ImportError:
        pass  # CUDA is optional
    else:
        # Library locations relative to each package's install directory.
        if platform.system() == 'Linux':
            cudalib = 'lib/libcudart.so.12'
            cublaslib = 'lib/libcublas.so.12'
        else:  # Windows
            cudalib = r'bin\cudart64_12.dll'
            cublaslib = r'bin\cublas64_12.dll'

        # preload the CUDA libs so the backend can find them
        try:
            ctypes.CDLL(os.path.join(cuda_runtime.__path__[0], cudalib), mode=ctypes.RTLD_GLOBAL)
            ctypes.CDLL(os.path.join(cublas.__path__[0], cublaslib), mode=ctypes.RTLD_GLOBAL)
        except OSError:
            # The packages are importable but a library is missing or could not
            # be loaded (e.g. a different CUDA major version is installed).
            # Treat this the same as CUDA not being installed at all.
            pass
        else:
            cuda_found = True
50+
51+
3152
# TODO: provide a config file to make this more robust
3253
MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build"
3354

@@ -218,7 +239,16 @@ def __init__(self, model_path: str, n_ctx: int, ngl: int, backend: str):
218239
model = llmodel.llmodel_model_create2(self.model_path, backend.encode(), ctypes.byref(err))
219240
if model is None:
220241
s = err.value
221-
raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
242+
errmsg = 'null' if s is None else s.decode()
243+
244+
if (
245+
backend == 'cuda'
246+
and not cuda_found
247+
and errmsg.startswith('Could not find any implementations for backend')
248+
):
249+
print('WARNING: CUDA runtime libraries not found. Try `pip install "gpt4all[cuda]"`\n', file=sys.stderr)
250+
251+
raise RuntimeError(f"Unable to instantiate model: {errmsg}")
222252
self.model: ctypes.c_void_p | None = model
223253

224254
def __del__(self, llmodel=llmodel):

gpt4all-bindings/python/setup.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,15 @@ def get_long_description():
9393
'typing-extensions>=4.3.0; python_version >= "3.9" and python_version < "3.11"',
9494
],
9595
extras_require={
96+
'cuda': [
97+
'nvidia-cuda-runtime-cu12',
98+
'nvidia-cublas-cu12',
99+
],
100+
'all': [
101+
'gpt4all[cuda]; platform_system == "Windows" or platform_system == "Linux"',
102+
],
96103
'dev': [
104+
'gpt4all[all]',
97105
'pytest',
98106
'twine',
99107
'wheel',

0 commit comments

Comments
 (0)