[Triton/XPU] Support 4bit dequantization logic on Triton #1629

Open
Devjiu wants to merge 5 commits into main from dmitriim/add_xpu_triton_kernel

Conversation


@Devjiu Devjiu commented May 8, 2025

This PR adds an XPU backend and a Triton kernel for dequantizing the nf4 dtype.
Triton is used as an optional import.
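A minimal sketch of the optional-import and registration pattern (op names follow the diff hunks discussed in the review below; treat register_kernel and triton_ops as placeholders for the corresponding bitsandbytes internals):

import warnings

import torch

try:
    # Triton is optional: if it is missing, the Triton-backed XPU kernels are simply not registered.
    import triton  # noqa: F401

    triton_available = True
except ImportError as e:
    print("Import error:", e)
    triton_available = False

# torch.xpu only exists from PyTorch 2.3 on, so guard the attribute access as well.
if torch.__version__ >= (2, 3) and torch.xpu.is_available():
    if triton_available:
        # `register_kernel` and `triton_ops` are bitsandbytes internals (see the review hunks below).
        register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
        register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
    else:
        warnings.warn("XPU is available, but neither ipex nor triton was found.")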

Tests:

  • tests/test_functional.py::TestQuantize4BitFunctional: supported nf4/fp4 cases
  • tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional: implemented quantize_blockwise with a binary search that is faster on XPU (a rough sketch of the idea follows this list)
  • tests/test_linear4bit.py
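To illustrate the binary-search idea referenced above, here is a rough PyTorch-level sketch with illustrative names; the actual kernel in this PR is written in Triton and runs per block on the device:

import torch


def quantize_blockwise_binary_search(A: torch.Tensor, code: torch.Tensor, blocksize: int = 256):
    """Map each element of A to the index of its nearest entry in a sorted quantization code.

    `code` is assumed to be sorted ascending with at most 256 entries (as in 8-bit blockwise
    quantization), so the result fits in uint8.
    """
    flat = A.contiguous().view(-1)
    n = flat.numel()
    blocks = -(-n // blocksize)  # ceil division

    # Pad to a whole number of blocks, then normalize each block by its absmax.
    padded = torch.zeros(blocks * blocksize, dtype=flat.dtype, device=flat.device)
    padded[:n] = flat
    padded = padded.view(blocks, blocksize)
    absmax = padded.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
    scaled = (padded / absmax).view(-1)

    # Binary search into the sorted code, then snap to whichever neighbour is closer.
    code = code.to(scaled.dtype).contiguous()
    idx = torch.searchsorted(code, scaled).clamp(1, code.numel() - 1)
    left, right = code[idx - 1], code[idx]
    idx = torch.where((scaled - left).abs() <= (right - scaled).abs(), idx - 1, idx)
    return idx.to(torch.uint8)[:n], absmax.view(-1)


# Example usage with a stand-in code table:
code = torch.linspace(-1.0, 1.0, 256)
q, absmax = quantize_blockwise_binary_search(torch.randn(1000), code)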

Signed-off-by: Dmitrii Makarenko [email protected]

@Devjiu Devjiu force-pushed the dmitriim/add_xpu_triton_kernel branch 5 times, most recently from a1faeb4 to 679cedc on May 14, 2025 16:49
Author

Devjiu commented May 14, 2025

BNB_TEST_DEVICE="xpu" pytest -s tests/test_linear4bit.py 
88 passed in 11.91s 

BNB_TEST_DEVICE="xpu" pytest -s tests/test_functional.py
953 passed, 170 skipped, 9 deselected, 37 warnings in 235.89s (0:03:55)

@Devjiu Devjiu force-pushed the dmitriim/add_xpu_triton_kernel branch from 679cedc to ea15027 on May 14, 2025 16:56
@Devjiu Devjiu marked this pull request as ready for review May 14, 2025 16:59
@Devjiu Devjiu changed the title from "[xpu/triton] Add trtion dequantization kernel" to "[Triton/XPU] Support 4bit dequantization logic on Triton" on May 14, 2025
@jiqing-feng
Contributor

Thanks for your contribution, but this PR seems to have a conflict with bitsandbytes-intel. We might need to discuss further to determine the priority.

Author

Devjiu commented May 15, 2025

Thanks for your contribution, but this PR seems to have a conflict with bitsandbytes-intel. We might need to discuss further to determine the priority.

Roughly speaking, this is not a conflict. It is a different implementation that can be used depending on the availability of ipex.

@Egor-Krivov
Contributor

Thanks for your contribution, but this PR seems to have a conflict with bitsandbytes-intel. We might need to discuss further to determine the priority.

Could you clarify the nature of the conflict? This PR provides a 4-bit implementation for users who just install bitsandbytes, without any additional plugins or libraries like IPEX or bitsandbytes-intel. For example, installing PEFT will only pull in bitsandbytes.

Given the current implementation, if the user additionally installs bitsandbytes-intel, it should just replace the kernels defined in the main repo.

Contributor

jiqing-feng commented May 16, 2025

When @matthewdouglas says we'd like to enable the CPU path without IPEX, that's because non-Intel CPUs do not support IPEX. But XPU is an Intel-specific device, so all XPUs support IPEX. We'd better install IPEX on XPU by default so we can get a significant speed-up.

More specifically, not all XPU ops have an IPEX optimization. I can see most of the ops in this PR are duplicated in my PR (as they were the same as the CPU implementation, I was wondering whether we could just move these ops to the default ops). So the design is a little confusing to me. Should we keep both repos implementing XPU ops?

Anyway, the PEFT example is a good point. Let's sync offline. I'd like to hear your opinion. :)

@yao-matrix

Since Triton is platform agnostic, is it possible to upstream your ops to the bitsandbytes/triton folder?

Author

Devjiu commented May 16, 2025

Since Triton is platform agnostic, is it possible to upstream your ops to the bitsandbytes/triton folder?

@matthewdouglas Please take a look.
@yao-matrix For my approach I got approval to use Triton for XPU, but maybe we can share the code base. As you know, though, in Triton different hardware requires slightly different kernels to be efficient, so ultimately it is not completely platform agnostic.

@Devjiu Devjiu force-pushed the dmitriim/add_xpu_triton_kernel branch from ea15027 to fbb2d00 on May 16, 2025 13:38

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@matthewdouglas matthewdouglas added this to the v0.47.0 milestone May 20, 2025
Author

Devjiu commented May 22, 2025

Local test run on PVC:

BNB_TEST_DEVICE="xpu"  pytest -rf --ignore test_optim.py --ignore test_triton.py --ignore test_cuda_setup_evaluator.py
2196 passed, 1555 skipped, 178 deselected, 33 xfailed, 189 warnings in 357.17s (0:05:57)

@yao-matrix

@matthewdouglas, could you please take a look at it? The background is: we'd like to contribute Triton ops to bnb and have XPU support the bnb Triton backend. Thanks very much.

@Devjiu Devjiu force-pushed the dmitriim/add_xpu_triton_kernel branch from fb48d76 to 1414628 on May 26, 2025 16:37
Comment on lines 37 to 39
if torch.xpu.is_available():
    from .backends.xpu import ops as xpu_ops

Member

Since we're still supporting torch 2.2 for the moment, we'll want to guard this, as torch.xpu didn't exist until 2.3.0.

Author

Added a torch.__version__ >= (2, 3) check.

Comment on lines +122 to +118
# Check if this is fine and fast
if A.dtype != torch.uint8:
    A = A.squeeze().view(torch.uint8).unsqueeze(1)
Member

TBH this was meant to support FSDP1; I'm not sure it's necessary to try to support other storage dtypes for FSDP2, or if it's worth considering at all. @Titus-von-Koeller may have more thoughts on that.

Author

I added this just to be able to pass the tests. If something more complex is needed here, or if it's better to just skip when the storage type is not uint8, let me know. I'm not sure about this type-casting approach.
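For illustration only, a tiny round-trip of the cast above (assuming the packed 4-bit data was re-stored under a wider dtype, e.g. by FSDP):

import torch

# Packed 4-bit weights are normally stored as an (n, 1) uint8 tensor.
packed = torch.randint(0, 256, (8, 1), dtype=torch.uint8)

# Pretend an outer wrapper (e.g. FSDP) re-stored the same bytes under bfloat16 ...
as_bf16 = packed.squeeze().view(torch.bfloat16)

# ... then the cast from the diff recovers the original uint8 byte view.
restored = as_bf16.squeeze().view(torch.uint8).unsqueeze(1)
assert torch.equal(restored, packed)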

@Devjiu Devjiu force-pushed the dmitriim/add_xpu_triton_kernel branch from b70a09c to ebed8a6 on May 28, 2025 18:32
- if torch.__version__ >= (2, 7):
+ # With default torch, error:
+ # NotImplementedError: The operator 'aten::_int_mm' for XPU
+ if ipex_xpu and torch.__version__ >= (2, 7):
Author

@jiqing-feng PTAL, I am not sure about this fix.

Contributor

As far as I know, torch._int_mm is a PyTorch op that does not require ipex.

Author

I've checked on 2.8.0a0+git129a297, and FYI: https://github.com/pytorch/pytorch/blob/cd9ff41282ecc7666cfd0fc07e758adb150e55b0/test/inductor/test_select_algorithm.py#L117
from the pytorch repo:

    @patches
    @skipIfXpu(msg="XPU has not supported _int_mm yet")
    def test__int_mm(self):

Contributor

I see, the op name confused me; I thought it was an original PyTorch op.

@@ -287,6 +288,8 @@ def test_linear_kbit_fp32_bias(device, module):
 def test_kbit_backprop(device, module):
     if device == "cpu":
         pytest.xfail("Test is not yet supported on CPU")
+    if device == "xpu" and ipex_xpu:

Contributor

If you want to skip the Triton test, maybe: if device == "xpu" and not ipex_xpu and is_triton_available:

Author

Yep, agree. This is a pretty ugly check (too long). Maybe some helper functions should be added.
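For reference, a hypothetical shape such a helper could take (ipex_xpu and is_triton_available are the flags from the suggestion above; the actual helper added in the PR may differ):

def is_xpu_triton(device: str) -> bool:
    # True when the XPU path is backed by the Triton kernels rather than IPEX.
    return device == "xpu" and not ipex_xpu and is_triton_available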

Author

done


    triton_available = True
except ImportError as e:
    print("Import error:", e)
    triton_available = False
Contributor

You can move this check to utils so you can easily use it anywhere you need, just like ipex.
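A small sketch of what such a utils-level flag could look like (the exact module and name are assumptions; it simply mirrors the try/except probe shown above, like the existing ipex check):

import importlib.util

# Cheap availability probe that avoids importing triton at module-import time.
is_triton_available = importlib.util.find_spec("triton") is not None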

Author

done

@@ -49,3 +52,16 @@ def _(
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")

return out.reshape(shape)
elif triton_available:
# IPEX should be faster for xpu, so at first checking if it is available.
Contributor

Maybe move this comment to IPEX kernel registration?

Author

moved higher

    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
    warnings.warn("XPU available, but nor ipex or trtion package is found.")
Contributor

Suggested change
- warnings.warn("XPU available, but nor ipex or trtion package is found.")
+ warnings.warn("XPU available, but nor ipex or triton package is found.")

Author

done

- if torch.xpu.is_available():
-     from .backends.xpu import ops as xpu_ops
+ # xpu was introduced in PyTorch 2.3
+ if torch.__version__ >= (2, 3):
Contributor

Suggested change
- if torch.__version__ >= (2, 3):
+ if torch.__version__ >= (2, 3) and torch.xpu.is_available():
+     from .backends.xpu import ops as xpu_ops

Author

done

-     from .backends.xpu import ops as xpu_ops
+ # xpu was introduced in PyTorch 2.3
+ if torch.__version__ >= (2, 3):
+     if torch.xpu.is_available():
Contributor

Suggested change
- if torch.xpu.is_available():

+ # xpu was introduced in PyTorch 2.3
+ if torch.__version__ >= (2, 3):
+     if torch.xpu.is_available():
+         from .backends.xpu import ops as xpu_ops
Contributor

Suggested change
- from .backends.xpu import ops as xpu_ops

Author

It's not used directly in this file, but it's required for backend registration. cpu_ops, default_ops and cuda_ops are also not used in this file.
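For example, the registration-only import can simply be kept and marked for linters (a sketch; the # noqa: F401 marker is one common way to do it):

# Imported only for its side effect: importing the module registers the XPU kernels.
from .backends.xpu import ops as xpu_ops  # noqa: F401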

Contributor

I just suggest merging the 2 if statements into 1 line for simplicity.

@Devjiu Devjiu force-pushed the dmitriim/add_xpu_triton_kernel branch from 283be54 to a1826d6 on June 3, 2025 12:40
@Egor-Krivov
Contributor

Current performance numbers for inference latency (the script is run in Docker):
~128 ms per token with the eager attention implementation

For comparison, with the current main and IPEX installed (docker image intel/intel-extension-for-pytorch:2.7.10-xpu), the best I can get is:
~168 ms per token with the default attention implementation

Host:

sycl-ls
[level_zero:gpu][level_zero:0] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) Arc(TM) B570 Graphics 20.1.0 [1.6.32567+19]
[opencl:cpu][opencl:0] Intel(R) OpenCL, Intel(R) Core(TM) Ultra 7 265K OpenCL 3.0 (Build 0) [2025.19.4.0.18_160000.xmain-hotfix]
[opencl:gpu][opencl:1] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) B570 Graphics OpenCL 3.0 NEO  [25.05.32567]

Benchmarking script:

import torch
# FOR IPEX
# import intel_extension_for_pytorch

import time
from pathlib import Path
from collections import defaultdict


from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers import TextStreamer
import numpy as np

MAX_NEW_TOKENS = 256

get_time = time.time

root = Path(__file__).parent.parent

system_prompt = "You are a helpful assistant"
user_prompt = """Summarize this text please: 

```Tell me, O muse, of that ingenious hero who travelled far and wide after he had sacked the famous town of Troy. Many cities did he visit, and many were the nations with whose manners and customs he was acquainted; moreover he suffered much by sea while trying to save his own life and bring his men safely home; but do what he might he could not save his men, for they perished through their own sheer folly in eating the cattle of the Sun-god Hyperion; so the god prevented them from ever reaching home. Tell me, too, about all these things, O daughter of Jove, from whatsoever source you may know them.

So now all who escaped death in battle or by shipwreck had got safely home except Ulysses, and he, though he was longing to return to his wife and country, was detained by the goddess Calypso, who had got him into a large cave and wanted to marry him. But as years went by, there came a time when the gods settled that he should go back to Ithaca; even then, however, when he was among his own people, his troubles were not yet over; nevertheless all the gods had now begun to pity him except Neptune, who still persecuted him without ceasing and would not let him get home.

Now Neptune had gone off to the Ethiopians, who are at the world's end, and lie in two halves, the one looking West and the other East. He had gone there to accept a hecatomb of sheep and oxen, and was enjoying himself at his festival; but the other gods met in the house of Olympian Jove, and the sire of gods and men spoke first. At that moment he was thinking of Aegisthus, who had been killed by Agamemnon's son Orestes; so he said to the other gods:

"See now, how men lay blame upon us gods for what is after all nothing but their own folly. Look at Aegisthus; he must needs make love to Agamemnon's wife unrighteously and then kill Agamemnon, though he knew it would be the death of him; for I sent Mercury to warn him not to do either of these things, inasmuch as Orestes would be sure to take his revenge when he grew up and wanted to return home. Mercury told him this in all good will but he would not listen, and now he has paid for everything in full."

Then Minerva said, "Father, son of Saturn, King of kings, it served Aegisthus right, and so it would any one else who does as he did; but Aegisthus is neither here nor there; it is for Ulysses that my heart bleeds, when I think of his sufferings in that lonely sea-girt island, far away, poor man, from all his friends. It is an island covered with forest, in the very middle of the sea, and a goddess lives there, daughter of the magician Atlas, who looks after the bottom of the ocean, and carries the great columns that keep heaven and earth asunder. This daughter of Atlas has got hold of poor unhappy Ulysses, and keeps trying by every kind of blandishment to make him forget his home, so that he is tired of life, and thinks of nothing but how he may once more see the smoke of his own chimneys. You, sir, take no heed of this, and yet when Ulysses was before Troy did he not propitiate you with many a burnt sacrifice? Why then should you keep on being so angry with him?"

And Jove said, "My child, what are you talking about? How can I forget Ulysses than whom there is no more capable man on earth, nor more liberal in his offerings to the immortal gods that live in heaven? Bear in mind, however, that Neptune is still furious with Ulysses for having blinded an eye of Polyphemus king of the Cyclopes. Polyphemus is son to Neptune by the nymph Thoosa, daughter to the sea-king Phorcys; therefore though he will not kill Ulysses outright, he torments him by preventing him from getting home. Still, let us lay our heads together and see how we can help him to return; Neptune will then be pacified, for if we are all of a mind he can hardly stand out against us."```"""

prompt = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt},
]


def get_inputs(tokenizer):
    inputs = tokenizer.apply_chat_template(
        prompt,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    return inputs


def get_streamer(tokenizer):
    streamer = Streamer()

    # streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    return streamer


class Streamer:
    def __init__(self, print_median=False):
        self.times = []
        self.print_median = print_median

    def put(self, t):
        self.times.append(get_time())
        if len(self.times) > 1:
            print("Token latency: {:.1f} ms".format(1000 * (self.times[-1] - self.times[-2])))

        if len(self.times) % 10 == 3 and self.print_median:
            ts = np.array(self.times)
            diff = ts[1:] - ts[:-1]
            # print("Token latency:", 1000 * diff, "ms")
            print("Token latency median:", np.median(1000 * diff), "ms")

    def end(self, *args):
        print(args)

        times = np.array(self.times)
        diff = times[1:] - times[:-1]
        print("Median latency, ms", np.median(diff) * 1000)
        percentiles = [10, 25, 50, 75, 90]
        print("Latency percentiles", {p: round(1000 * float(np.percentile(diff, p)), 1) for p in percentiles})

device = 'xpu'
model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")
# FOR IPEX DEFAULT ATTENTION IS FASTER
# model = AutoModelForCausalLM.from_pretrained(model_id)


inputs = get_inputs(tokenizer)
streamer = get_streamer(tokenizer)
inputs.to(device)
print("Tokens: ", inputs[0])


generation_config = GenerationConfig(
    use_cache=True,
    forced_eos_token_id=1,
    eos_token_id=1,
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=False,
)

outputs = model.generate(
    **inputs,
    streamer=streamer,
    generation_config=generation_config,
)

@Egor-Krivov Egor-Krivov mentioned this pull request Jun 4, 2025
Devjiu added 4 commits June 4, 2025 12:54
This PR adds an XPU backend and a Triton kernel for dequantization of the nf4 dtype.
Triton is an optional import.
Tests:
	tests/test_functional.py::TestQuantize4BitFunctional supported nf4/fp4 cases
	tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional implemented quantize_blockwise with binary search that works faster for XPU
	tests/test_linear4bit.py

Signed-off-by: Dmitrii Makarenko <[email protected]>
@Devjiu Devjiu force-pushed the dmitriim/add_xpu_triton_kernel branch from a1826d6 to d0736f6 on June 4, 2025 12:58