Enabling XPU in UT tests #2424

Draft pull request: wants to merge 1 commit into base: main.

25 changes: 13 additions & 12 deletions test/quantization/test_gptq.py
@@ -2,13 +2,16 @@
from pathlib import Path

import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import (
TestCase,
)

from torchao._models.llama.model import (
ModelArgs,
Transformer,
prepare_inputs_for_model,
)
from torchao.utils import auto_detect_device
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error
@@ -18,10 +21,10 @@

torch.manual_seed(0)

_DEVICE = auto_detect_device()

class TestGPTQ(TestCase):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_gptq_quantizer_int4_weight_only(self):
from torchao._models._eval import (
LMEvalInputRecorder,
@@ -30,7 +33,6 @@ def test_gptq_quantizer_int4_weight_only(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path(
"../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
)
@@ -80,7 +82,7 @@ def test_gptq_quantizer_int4_weight_only(self):
model = quantizer.quantize(model, *inputs).cuda()

model.reset_caches()
with torch.device("cuda"):
with torch.device(_DEVICE):
model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)

limit = 1
@@ -89,7 +91,7 @@ def test_gptq_quantizer_int4_weight_only(self):
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
device,
_DEVICE,
).run_eval(
["wikitext"],
limit,
@@ -102,7 +104,6 @@ def test_gptq_quantizer_int4_weight_only(self):

class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_multitensor_add_tensors(self):
from torchao.quantization.GPTQ import MultiTensor

@@ -115,7 +116,6 @@ def test_multitensor_add_tensors(self):
self.assertTrue(torch.equal(mt.values[1], tensor2))

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_multitensor_pad_unpad(self):
from torchao.quantization.GPTQ import MultiTensor

@@ -127,7 +127,6 @@ def test_multitensor_pad_unpad(self):
self.assertEqual(mt.count, 1)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_multitensor_inplace_operation(self):
from torchao.quantization.GPTQ import MultiTensor

@@ -138,7 +137,6 @@ def test_multitensor_inplace_operation(self):


class TestMultiTensorInputRecorder(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_multitensor_input_recorder(self):
from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder

@@ -159,7 +157,6 @@ def test_multitensor_input_recorder(self):
self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
self.assertEqual(MT_input[3], torch.float)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_gptq_with_input_recorder(self):
from torchao.quantization.GPTQ import (
Int4WeightOnlyGPTQQuantizer,
@@ -170,7 +167,7 @@ def test_gptq_with_input_recorder(self):

config = ModelArgs(n_layer=2)

with torch.device("cuda"):
with torch.device(_DEVICE):
model = Transformer(config)
model.setup_caches(max_batch_size=2, max_seq_length=100)
idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
@@ -191,7 +188,11 @@ def test_gptq_with_input_recorder(self):

args = input_recorder.get_recorded_inputs()

quantizer = Int4WeightOnlyGPTQQuantizer()
if _DEVICE == "xpu":
from torchao.dtypes import Int4XPULayout
quantizer = Int4WeightOnlyGPTQQuantizer(device=torch.device("xpu"), layout=Int4XPULayout())
else:
quantizer = Int4WeightOnlyGPTQQuantizer()

quantizer.quantize(model, *args)

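The test file above now resolves its device through `auto_detect_device()`, imported from `torchao.utils`, but the helper's implementation is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming it simply prefers CUDA, then XPU, then CPU; the real torchao implementation may differ.

```python
# Hypothetical sketch of auto_detect_device(); the actual helper lives in
# torchao.utils and its body is not shown in this PR.
import torch


def auto_detect_device() -> str:
    """Return the best available accelerator name, falling back to CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.xpu is only usable on builds with Intel GPU support.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


_DEVICE = auto_detect_device()  # e.g. "xpu" on an Intel GPU machine
```

With `_DEVICE` resolved once at module level, only the quantizer construction needs a device-specific branch: as the last hunk shows, the XPU path builds `Int4WeightOnlyGPTQQuantizer` with `Int4XPULayout`, while other devices keep the default constructor.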
11 changes: 4 additions & 7 deletions test/quantization/test_moe_quant.py
@@ -31,8 +31,11 @@
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_90,
auto_detect_device,
)

_DEVICE = auto_detect_device()

if torch.version.hip is not None:
pytest.skip(
"ROCm support for MoE quantization is under development",
@@ -52,7 +55,7 @@ def _test_impl_moe_quant(
base_class=AffineQuantizedTensor,
tensor_impl_class=None,
dtype=torch.bfloat16,
device="cuda",
device=_DEVICE,
fullgraph=False,
):
"""
@@ -114,8 +117,6 @@ def _test_impl_moe_quant(
]
)
def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")
if not TORCH_VERSION_AT_LEAST_2_5:
self.skipTest("Test only enabled for 2.5+")

@@ -138,10 +139,6 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
]
)
def test_int4wo_base(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")
if not is_sm_at_least_90():
self.skipTest("Requires CUDA capability >= 9.0")
if not TORCH_VERSION_AT_LEAST_2_5:
self.skipTest("Test only enabled for 2.5+")

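This file drops both the CUDA-availability check and the `is_sm_at_least_90()` gate from the int4 weight-only MoE tests and lets `_DEVICE` supply the default device. One possible device-aware guard that would keep the original CUDA capability requirement while still admitting XPU is sketched below. It is purely illustrative: the helper name and the assumption that the XPU path needs no capability gate are not from the PR.

```python
# Hypothetical device-aware guard, not part of this PR.
from typing import Optional

import torch


def int4wo_skip_reason() -> Optional[str]:
    """Return a skip reason, or None if the int4 weight-only test can run."""
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        if major < 9:
            return "Requires CUDA capability >= 9.0"
        return None
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return None
    return "Need a CUDA or XPU device"
```

A test could call `self.skipTest(reason)` whenever this returns a string, mirroring the checks that the diff removes.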
18 changes: 10 additions & 8 deletions test/quantization/test_qat.py
@@ -83,11 +83,13 @@
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_6,
auto_detect_device,
)

# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
_GPU_IS_AVAILABLE = True if torch.cuda.is_available() or torch.xpu.is_available() else False

_DEVICE = auto_detect_device()

class Sub(torch.nn.Module):
def __init__(self):
@@ -329,7 +331,7 @@ def _set_ptq_weight(
group_size,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to("cuda"),
q_weight.to(_DEVICE),
qat_linear.inner_k_tiles,
)
ptq_linear.weight = q_weight
@@ -600,13 +602,13 @@ def _assert_close_4w(self, val, ref):
print(mean_err)
self.assertTrue(mean_err < 0.05)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available")
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
inner_k_tiles = 8
scales_precision = torch.bfloat16
device = torch.device("cuda")
device = torch.device(_DEVICE)
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -654,13 +656,13 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available")
def test_qat_4w_linear(self):
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear

group_size = 128
device = torch.device("cuda")
device = torch.device(_DEVICE)
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
qat_linear = Int4WeightOnlyQATLinear(
@@ -701,14 +703,14 @@ def test_qat_4w_quantizer_gradients(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available")
def test_qat_4w_quantizer(self):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
from torchao.quantization.qat import Int4WeightOnlyQATQuantizer

group_size = 32
inner_k_tiles = 8
device = torch.device("cuda")
device = torch.device(_DEVICE)
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
m = M().to(device).to(dtype)
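test_qat.py replaces the CUDA-only `_CUDA_IS_AVAILABLE` flag with `_GPU_IS_AVAILABLE` and routes the hard-coded `"cuda"` strings through `_DEVICE`. Below is a minimal, self-contained illustration of that pattern; the test class and method names are hypothetical, while `auto_detect_device`, `_GPU_IS_AVAILABLE`, and the skip message mirror the diff.

```python
# Minimal illustration of the device-agnostic test pattern used in this PR;
# the class and test names here are hypothetical.
import unittest

import torch

from torchao.utils import auto_detect_device  # helper introduced by this PR

_GPU_IS_AVAILABLE = torch.cuda.is_available() or torch.xpu.is_available()
_DEVICE = auto_detect_device()


class ExampleDeviceAgnosticTest(unittest.TestCase):
    @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available")
    def test_tensor_lands_on_detected_device(self):
        # Runs on whichever accelerator auto_detect_device() picked.
        x = torch.randn(4, 4, device=torch.device(_DEVICE))
        self.assertEqual(x.device.type, _DEVICE)


if __name__ == "__main__":
    unittest.main()
```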