|
28 | 28 | EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
|
29 | 29 |
|
30 | 30 |
|
| 31 | +# Find CUDA libraries from the official packages |
| 32 | +cuda_found = False |
| 33 | +if platform.system() in ('Linux', 'Windows'): |
| 34 | + try: |
| 35 | + from nvidia import cuda_runtime, cublas |
| 36 | + except ImportError: |
| 37 | + pass # CUDA is optional |
| 38 | + else: |
| 39 | + if platform.system() == 'Linux': |
| 40 | + cudalib = 'lib/libcudart.so.12' |
| 41 | + cublaslib = 'lib/libcublas.so.12' |
| 42 | + else: # Windows |
| 43 | + cudalib = r'bin\cudart64_12.dll' |
| 44 | + cublaslib = r'bin\cublas64_12.dll' |
| 45 | + |
| 46 | + # preload the CUDA libs so the backend can find them |
| 47 | + ctypes.CDLL(os.path.join(cuda_runtime.__path__[0], cudalib), mode=ctypes.RTLD_GLOBAL) |
| 48 | + ctypes.CDLL(os.path.join(cublas.__path__[0], cublaslib), mode=ctypes.RTLD_GLOBAL) |
| 49 | + cuda_found = True |
| 50 | + |
| 51 | + |
31 | 52 | # TODO: provide a config file to make this more robust
|
32 | 53 | MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build"
|
33 | 54 |
|
@@ -218,7 +239,16 @@ def __init__(self, model_path: str, n_ctx: int, ngl: int, backend: str):
|
218 | 239 | model = llmodel.llmodel_model_create2(self.model_path, backend.encode(), ctypes.byref(err))
|
219 | 240 | if model is None:
|
220 | 241 | s = err.value
|
221 |
| - raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}") |
| 242 | + errmsg = 'null' if s is None else s.decode() |
| 243 | + |
| 244 | + if ( |
| 245 | + backend == 'cuda' |
| 246 | + and not cuda_found |
| 247 | + and errmsg.startswith('Could not find any implementations for backend') |
| 248 | + ): |
| 249 | + print('WARNING: CUDA runtime libraries not found. Try `pip install "gpt4all[cuda]"`\n', file=sys.stderr) |
| 250 | + |
| 251 | + raise RuntimeError(f"Unable to instantiate model: {errmsg}") |
222 | 252 | self.model: ctypes.c_void_p | None = model
|
223 | 253 |
|
224 | 254 | def __del__(self, llmodel=llmodel):
|
|
0 commit comments