Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import logging
import os
import platform
from pathlib import Path
import re
from typing import Optional
Expand All @@ -19,6 +20,26 @@
logger = logging.getLogger(__name__)


def _format_env_set_instruction(var_name: str, value: str) -> str:
if platform.system() == "Windows":
return (
f"Set it in PowerShell with `$env:{var_name}='{value}'` or in Command Prompt with"
f" `set {var_name}={value}`."
)

return f"Set it with `export {var_name}={value}`."


def _format_env_clear_instruction(var_name: str) -> str:
if platform.system() == "Windows":
return (
f"Clear it in PowerShell with `Remove-Item Env:{var_name}` or in Command Prompt with"
f" `set {var_name}=`."
)

return f"Clear it with `unset {var_name}`."


def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
"""
Get the disk path to the CUDA BNB native library specified by the
Expand All @@ -38,38 +59,38 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
if not rocm_override_value:
raise RuntimeError(
f"BNB_CUDA_VERSION={cuda_override_value} detected but this is not a CUDA build!\n"
"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
"Clear the variable and retry: unset BNB_CUDA_VERSION\n"
f"Use BNB_ROCM_VERSION instead. {_format_env_set_instruction('BNB_ROCM_VERSION', '<version>')}\n"
f"{_format_env_clear_instruction('BNB_CUDA_VERSION')}\n"
)
logger.warning(
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} is set but ignored on this ROCm build. "
"Clear the variable: unset BNB_CUDA_VERSION",
f"{_format_env_clear_instruction('BNB_CUDA_VERSION')}",
)
if rocm_override_value:
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
logger.warning(
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
"If this was unintended clear the variable and retry: unset BNB_ROCM_VERSION\n",
f"If this was unintended, {_format_env_clear_instruction('BNB_ROCM_VERSION')}\n",
)
elif torch.version.cuda:
if rocm_override_value:
if not cuda_override_value:
raise RuntimeError(
f"BNB_ROCM_VERSION={rocm_override_value} detected but this is not a ROCm build!\n"
"Use BNB_CUDA_VERSION instead: export BNB_CUDA_VERSION=<version>\n"
"Clear the variable and retry: unset BNB_ROCM_VERSION\n"
f"Use BNB_CUDA_VERSION instead. {_format_env_set_instruction('BNB_CUDA_VERSION', '<version>')}\n"
f"{_format_env_clear_instruction('BNB_ROCM_VERSION')}\n"
)
logger.warning(
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} is set but ignored on this CUDA build. "
"Clear the variable: unset BNB_ROCM_VERSION",
f"{_format_env_clear_instruction('BNB_ROCM_VERSION')}",
)
if cuda_override_value:
library_name = re.sub(r"cuda\d+", f"cuda{cuda_override_value}", library_name, count=1)
logger.warning(
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} environment variable detected; loading {library_name}.\n"
"This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
"If this was unintended clear the variable and retry: unset BNB_CUDA_VERSION\n",
f"If this was unintended, {_format_env_clear_instruction('BNB_CUDA_VERSION')}\n",
)
else:
if rocm_override_value or cuda_override_value:
Expand Down
33 changes: 32 additions & 1 deletion tests/test_cuda_setup_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pytest

from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path
from bitsandbytes.cextension import (
BNB_BACKEND,
_format_env_clear_instruction,
_format_env_set_instruction,
get_cuda_bnb_library_path,
)
from bitsandbytes.cuda_specs import CUDASpecs


Expand All @@ -14,6 +19,32 @@ def cuda120_spec() -> CUDASpecs:
)


def test_format_env_set_instruction_windows(monkeypatch):
monkeypatch.setattr("bitsandbytes.cextension.platform.system", lambda: "Windows")
assert _format_env_set_instruction("BNB_CUDA_VERSION", "<version>") == (
"Set it in PowerShell with `$env:BNB_CUDA_VERSION='<version>'` or in Command Prompt with"
" `set BNB_CUDA_VERSION=<version>`."
)


def test_format_env_clear_instruction_windows(monkeypatch):
monkeypatch.setattr("bitsandbytes.cextension.platform.system", lambda: "Windows")
assert _format_env_clear_instruction("BNB_CUDA_VERSION") == (
"Clear it in PowerShell with `Remove-Item Env:BNB_CUDA_VERSION` or in Command Prompt with"
" `set BNB_CUDA_VERSION=`."
)


def test_format_env_set_instruction_unix(monkeypatch):
monkeypatch.setattr("bitsandbytes.cextension.platform.system", lambda: "Linux")
assert _format_env_set_instruction("BNB_CUDA_VERSION", "<version>") == "Set it with `export BNB_CUDA_VERSION=<version>`."


def test_format_env_clear_instruction_unix(monkeypatch):
monkeypatch.setattr("bitsandbytes.cextension.platform.system", lambda: "Linux")
assert _format_env_clear_instruction("BNB_CUDA_VERSION") == "Clear it with `unset BNB_CUDA_VERSION`."


@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
"""Without overrides, library path uses the detected CUDA 12.0 version."""
Expand Down