[WIP] Move float8 cutlass sparse layout to Float8SemiSparseTensor #3182
Changes from 2 commits
New test file (108 added lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Float8WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.testing.utils import skip_if_rocm
from torchao.utils import torch_version_at_least

BF16_ACT_CONFIG = Float8WeightOnlyConfig(
    group_size=128,
    packing_format="cutlass_semi_sparse",
)


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestFloat8SemiSparseTensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @skip_if_rocm("ROCm enablement in progress")
    @parametrize("config", [BF16_ACT_CONFIG])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 12),
        ],
    )
    def test_linear(self, config, sizes):
        dtype = torch.bfloat16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        apply_fake_sparsity(linear)
        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @skip_if_rocm("ROCm enablement in progress")
    @unittest.skip("Fix later")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

    @skip_if_rocm("ROCm enablement in progress")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Float8SemiSparseTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Float8SemiSparseTensor'>",
            )


instantiate_parametrized_tests(TestFloat8SemiSparseTensor)


if __name__ == "__main__":
    run_tests()
```
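For context, `compute_error` in the assertions above is torchao's SQNR helper, so `> 20` requires roughly 20 dB of signal-to-quantization-noise ratio between the dense bf16 reference and the quantized sparse output. A minimal sketch of the metric, assuming the standard SQNR definition (the helper's exact implementation may differ):

```python
import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels: energy of the
    # reference signal relative to the energy of the approximation error.
    return 20 * torch.log10(
        torch.linalg.norm(ref) / torch.linalg.norm(ref - approx)
    )

ref = torch.randn(32, 256)
noisy = ref + 0.01 * torch.randn_like(ref)
print(sqnr_db(ref, noisy))  # roughly 40 dB for 1% relative noise
```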
Existing file, `PackingFormat` enum:

```diff
@@ -32,3 +32,4 @@ class PackingFormat(str, Enum):
     needed for the rest of the system to understand the specific format that's adopted.
     """
     OPAQUE = "opaque"
+    # todo: add semi-sparse
```
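The `# todo` above suggests a follow-up enum member. A hypothetical sketch of what it could look like; the `SEMI_SPARSE` name and its string value are assumptions, not part of this PR:

```python
from enum import Enum

class PackingFormat(str, Enum):
    """Describes how packed weight data is laid out, so the rest of the
    system can understand the specific format that's adopted."""

    OPAQUE = "opaque"
    # Assumed follow-up: a 2:4 semi-structured sparse packing, which would
    # pair with the "cutlass_semi_sparse" string used in the tests above.
    SEMI_SPARSE = "semi_sparse"
```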
New file implementing the tensor subclass (114 added lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import torch

from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _quantize_affine_float8,
)
from torchao.utils import TorchAOBaseTensor

__all__ = ["Float8SemiSparseTensor"]
aten = torch.ops.aten


class Float8SemiSparseTensor(TorchAOBaseTensor):
    tensor_data_names = ["sparse", "scale", "meta"]

    def __new__(
        cls,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        kwargs = {}
        kwargs["device"] = sparse.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        shape = (sparse.shape[0], 2 * sparse.shape[-1])
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        super().__init__()
        self.sparse = sparse
        self.meta = meta
        self.scale = scale

    def _quantization_type(self):
        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        from torchao.sparsity.utils import mask_creator

        dense = w * mask_creator(w).bool()

        scale = _choose_scale_float8(
            dense,
            block_size=block_size,
            float8_dtype=torch.float8_e4m3fn,
        )

        w_fp8 = _quantize_affine_float8(
            dense,
            scale=scale,
            float8_dtype=torch.float8_e4m3fn,
        )

        sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8)

        return cls(
            sparse,
            meta,
            scale,
        )


implements = Float8SemiSparseTensor.implements
implements_torch_function = Float8SemiSparseTensor.implements_torch_function


@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8

    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )

    input = input_tensor.qdata
    input_scale = input_tensor.scale
    weight = weight_tensor.sparse
    weight_meta = weight_tensor.meta
    weight_scale = weight_tensor.scale
    out_dtype = input_tensor.dtype

    out = rowwise_scaled_linear_sparse_cutlass_f8f8(
        input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
    )

    return out


Float8SemiSparseTensor.__module__ = "torchao.quantization"

# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Float8SemiSparseTensor])
```
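A minimal usage sketch of the subclass above, assuming an SM90 GPU with the CUTLASS fp8 sparse ops available and that the class is importable from `torchao.quantization` as registered above; the per-row `block_size` is an assumption chosen to match the rowwise-scaled kernel used in dispatch:

```python
import torch
from torchao.quantization import Float8SemiSparseTensor  # per __module__ above

# High-precision weight; from_hp prunes it to a 2:4 pattern, quantizes the
# surviving values to fp8 e4m3, and packs them with CUTLASS metadata.
w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
w_q = Float8SemiSparseTensor.from_hp(w, block_size=[1, w.shape[1]])

print(w_q.shape)         # torch.Size([256, 128]): logical shape is preserved
print(w_q.sparse.shape)  # expected (256, 64): 2:4 keeps half the columns
print(w_q.scale.shape)   # expected (256, 1): one scale per output row
```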
Review discussion:

**Reviewer:** I don't think this config makes sense; it's not something we support. From what I understand, this is a bf16 activation + fp8 sparse weight? We only have kernel support for fp8 x fp8 2:4 sparse matmul, no support for mixed input dtypes currently.
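For reference, the supported fp8 x fp8 2:4 path quantizes activations dynamically as well. A hedged sketch of the flow that `test_fp8_cutlass_sparse` exercises, using the config named later in this thread (exact constructor arguments are an assumption):

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    quantize_,
)
from torchao.sparsity.sparse_api import apply_fake_sparsity

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
apply_fake_sparsity(linear)  # prune the weight to a 2:4 pattern first

# Both the activations and the 2:4-sparse weight end up in fp8, matching
# the fp8 x fp8 sparse matmul kernel support described above.
quantize_(linear, Float8DynamicActivationFloat8SemiSparseWeightConfig())
out = linear(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))
```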
**Author:** You're right; it seems I should be mirroring `test_fp8_cutlass_sparse` (from `test_sparse_api.py`) instead, with the difference being that the new flag/config exposes the tensor subclass being added?
**Reviewer:** I think `Float8DynamicActivationFloat8SemiSparseWeightConfig` should eventually resolve to your subclass, but I would like to hold off on that until …
**Author:** Sounds good. In that case, should we sequence the changes as follows?

1. `Float8SemiSparseTensor` with linear support
2. `Float8DynamicActivationFloat8SemiSparseWeightConfig` resolving to it after QRT
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
**Author:** @jcaip @jerryzh168 Mind confirming that you're on board with this direction?
**Reviewer:** Yeah, that sounds good to me.