[Triton/XPU] Support 4bit dequantization logic on Triton #1629
base: main
Conversation
Force-pushed from a1faeb4 to 679cedc
BNB_TEST_DEVICE="xpu" pytest -s tests/test_linear4bit.py
88 passed in 11.91s
BNB_TEST_DEVICE="xpu" pytest -s tests/test_functional.py
953 passed, 170 skipped, 9 deselected, 37 warnings in 235.89s (0:03:55)
Force-pushed from 679cedc to ea15027
Thanks for your contribution, but this PR seems to conflict with bitsandbytes-intel. We might need to discuss further to determine the priority.
Roughly speaking, this is not a conflict. It is a different implementation that can be used depending on the availability of ipex.
Could you clarify the nature of the conflict? This PR provides a 4bit implementation for users who just install bitsandbytes without any additional plugins or libraries like ipex. With the current implementation, if the user additionally installs ipex, the ipex path is used instead.
When @matthewdouglas says we'd like to enable CPU without the IPEX path, that's because non-Intel CPUs do not support IPEX. But XPU is an Intel-specific device, so all XPUs support IPEX. We'd better install IPEX on XPU by default so we can get a significant speed-up. More specifically, not all XPU ops have an ipex optimization. I can see that most of the ops in this PR duplicate those in my PR (since they were the same as the CPU implementation, I was thinking we could just move these ops to the default ops?). So the design is a little confusing to me. Should we keep both repos to implement XPU ops? Anyway, the PEFT example is a good point. Let's sync offline. Would like to hear your opinion. :)
@matthewdouglas Please take a look.
Force-pushed from ea15027 to fbb2d00
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Local test run on PVC:
BNB_TEST_DEVICE="xpu" pytest -rf --ignore test_optim.py --ignore test_triton.py --ignore test_cuda_setup_evaluator.py
2196 passed, 1555 skipped, 178 deselected, 33 xfailed, 189 warnings in 357.17s (0:05:57)
@matthewdouglas, could you please take a look at it? The background is: we'd like to contribute Triton ops to bnb and add XPU support for the bnb Triton backend. Thanks very much.
Force-pushed from fb48d76 to 1414628
bitsandbytes/__init__.py
Outdated
if torch.xpu.is_available():
    from .backends.xpu import ops as xpu_ops
Since we're still supporting torch 2.2 for the moment, we'll want to guard this since torch.xpu didn't exist until 2.3.0.
Added a torch.__version__ check >= (2, 3).
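For reference, a minimal sketch of the guarded import being discussed, as it would sit in bitsandbytes/__init__.py (this mirrors the suggestion later in the thread; torch.__version__ is a TorchVersion object and supports tuple comparison):

```python
import torch

# torch.xpu only exists since PyTorch 2.3, so check the version before querying it.
if torch.__version__ >= (2, 3) and torch.xpu.is_available():
    from .backends.xpu import ops as xpu_ops  # noqa: F401  (importing registers the XPU kernels)
```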
# Check if this is fine and fast
if A.dtype != torch.uint8:
    A = A.squeeze().view(torch.uint8).unsqueeze(1)
TBH this was meant to support FSDP1; I'm not sure it's necessary to try to support other storage dtypes for FSDP2, or if it's worth considering at all. @Titus-von-Koeller may have more thoughts on that.
I added this just to be able to pass tests. If something more complex is needed, or if it is better to skip when the storage type is not uint8, let me know. I'm not sure about this approach with type casting.
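For context, a standalone illustration of the byte reinterpretation in the diff above (the data here is made up, not the PR's): viewing a non-uint8 storage tensor as torch.uint8 reinterprets the same memory as raw bytes without copying, which is what lets a packed 4-bit payload stored in another dtype be unpacked.

```python
import torch

# Stand-in for a packed 4-bit payload handed back in fp16 storage.
A = torch.randn(16, 1, dtype=torch.float16)
if A.dtype != torch.uint8:
    A = A.squeeze().view(torch.uint8).unsqueeze(1)  # (16, 1) fp16 -> (32, 1) uint8, zero-copy
print(A.shape, A.dtype)  # torch.Size([32, 1]) torch.uint8
```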
Force-pushed from b70a09c to ebed8a6
if torch.__version__ >= (2, 7):
# With default torch, error:
# NotImplementedError: The operator 'aten::_int_mm' for XPU
if ipex_xpu and torch.__version__ >= (2, 7):
@jiqing-feng PTAL, I am not sure about this fix.
As far as I know, torch._int_mm is a PyTorch op that does not require ipex.
I've checked on 2.8.0a0+git129a297.
FYI, from the pytorch repo: https://github.com/pytorch/pytorch/blob/cd9ff41282ecc7666cfd0fc07e758adb150e55b0/test/inductor/test_select_algorithm.py#L117

@patches
@skipIfXpu(msg="XPU has not supported _int_mm yet")
def test__int_mm(self):
I see, the op name confused me; I thought it was an original PyTorch op.
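To make the failure mode concrete, here is a small hedged sketch (shapes and the fallback are illustrative, not PR code): torch._int_mm is a core PyTorch op, but on builds or devices that don't implement it (such as XPU without ipex), the call fails, which is why the int8 path above is gated on ipex_xpu.

```python
import torch

a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
b = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
try:
    out = torch._int_mm(a, b)  # int8 x int8 -> int32 matmul
except (NotImplementedError, RuntimeError):
    # Fallback when the device or build lacks _int_mm support.
    out = a.to(torch.int32) @ b.to(torch.int32)
print(out.dtype)  # torch.int32
```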
tests/test_modules.py
Outdated
@@ -287,6 +288,8 @@ def test_linear_kbit_fp32_bias(device, module):
def test_kbit_backprop(device, module):
    if device == "cpu":
        pytest.xfail("Test is not yet supported on CPU")
    if device == "xpu" and ipex_xpu:
@jiqing-feng PTAL
If you want to skip the triton test, maybe: if device == "xpu" and not ipex_xpu and is_triton_available:
Yep, agree. This is a pretty ugly check (too long). Maybe some helper functions should be added.
done
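One possible shape for such a helper (the name and signature are hypothetical, not the PR's actual code): wrap the long condition so the test skip stays readable.

```python
def is_xpu_triton_path(device: str, ipex_xpu: bool, triton_available: bool) -> bool:
    """True when an XPU device would hit the Triton kernels (ipex not installed)."""
    return device == "xpu" and not ipex_xpu and triton_available

# Possible usage inside a test:
#     if is_xpu_triton_path(device, ipex_xpu, is_triton_available):
#         pytest.xfail("Test is not yet supported on the XPU Triton path")
```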
bitsandbytes/backends/triton/ops.py
Outdated
    triton_available = True
except ImportError as e:
    print("Import error:", e)
    triton_available = False
You can move this check to utils so you can easily use it anywhere you need, just like ipex.
done
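A sketch of what such a shared utility check could look like (the function name is illustrative; the PR may implement it differently, e.g. with a guarded import as in the diff above):

```python
import importlib.util

def is_triton_available() -> bool:
    """Return True if the triton package can be imported."""
    return importlib.util.find_spec("triton") is not None
```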
bitsandbytes/backends/xpu/ops.py
Outdated
@@ -49,3 +52,16 @@ def _(
    raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")

    return out.reshape(shape)
elif triton_available:
    # IPEX should be faster for xpu, so at first checking if it is available.
Maybe move this comment to IPEX kernel registration?
moved higher
bitsandbytes/backends/xpu/ops.py
Outdated
    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
    warnings.warn("XPU available, but nor ipex or trtion package is found.")
-    warnings.warn("XPU available, but nor ipex or trtion package is found.")
+    warnings.warn("XPU available, but nor ipex or triton package is found.")
done
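Putting the pieces of this thread together, a rough sketch of the registration priority (reusing ipex_xpu, triton_available, register_kernel, and triton_ops as they appear in the PR's bitsandbytes/backends/xpu/ops.py; this is not the exact file contents):

```python
import warnings

if ipex_xpu:
    # IPEX should be faster on XPU, so its kernels take priority when it is installed.
    pass  # ipex-optimized kernels are registered here
elif triton_available:
    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
    warnings.warn("XPU available, but neither the ipex nor the triton package was found.")
```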
bitsandbytes/__init__.py
Outdated
if torch.xpu.is_available():
    from .backends.xpu import ops as xpu_ops
# xpu was introduced in PyTorch 2.3
if torch.__version__ >= (2, 3):
-if torch.__version__ >= (2, 3):
+if torch.__version__ >= (2, 3) and torch.xpu.is_available():
+    from .backends.xpu import ops as xpu_ops
done
bitsandbytes/__init__.py
Outdated
from .backends.xpu import ops as xpu_ops
# xpu was introduced in PyTorch 2.3
if torch.__version__ >= (2, 3):
    if torch.xpu.is_available():
if torch.xpu.is_available():
bitsandbytes/__init__.py
Outdated
# xpu was introduced in PyTorch 2.3
if torch.__version__ >= (2, 3):
    if torch.xpu.is_available():
        from .backends.xpu import ops as xpu_ops
from .backends.xpu import ops as xpu_ops |
It's not used directly in this file, but it's required for backend registration. cpu_ops, default_ops and cuda_ops are also not used in this file.
I just suggest merging the 2 if statements into 1 line for simplicity.
Force-pushed from 283be54 to a1826d6
Current performance numbers for inference latency (script is running in docker):
For comparison, with current main and IPEX installed (docker image):
Host:
Benchmarking script:
This PR adds the XPU backend and a Triton kernel for nf4 dequantization. Triton is an optional import.
Tests:
- tests/test_functional.py::TestQuantize4BitFunctional: supported nf4/fp4 cases
- tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional: implemented quantize_blockwise with a binary search that works faster on XPU
- tests/test_linear4bit.py
Signed-off-by: Dmitrii Makarenko <[email protected]>
Force-pushed from a1826d6 to d0736f6
This PR adds the XPU backend and a Triton kernel for dequantization of the nf4 dtype.
Triton is used as an optional import.
Tests:
- tests/test_functional.py::TestQuantize4BitFunctional: supported nf4/fp4 cases
- tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional: implemented quantize_blockwise with a binary search that works faster on XPU
- tests/test_linear4bit.py
Signed-off-by: Dmitrii Makarenko [email protected]
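As a quick sanity check of what the PR enables, a minimal usage sketch (assuming an XPU-enabled PyTorch build with this backend; it falls back to CPU otherwise, and the exact reconstruction error will vary):

```python
import torch
import bitsandbytes.functional as F

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
w = torch.randn(1024, 1024, dtype=torch.float16, device=device)

# Quantize to NF4 and dequantize again through the registered kernels.
q, state = F.quantize_4bit(w, quant_type="nf4")
w_hat = F.dequantize_4bit(q, state)
print((w.float() - w_hat.float()).abs().mean())
```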