
Commit c801f10

[sparse] Add in missing op support for FP8 Sparse (#3014)
* [sparse] Add in missing op support for FP8 Sparse

  Summary: For ads, some op support is missing in their lowering stack, namely `.to(dtype=torch.float)` and `.clone()`. This PR adds support for these ops on the `CutlassSemiSparseLayout`.

  Test Plan:

  ```
  python test/test_sparse_api -k lowering
  ```

* update
* ruff fix
* update tests
* fix test to add in layout kwarg
* skip non h100
1 parent f75b251 commit c801f10
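For context, here is a minimal sketch of what the newly supported ops look like from user code, mirroring the tests added below. It assumes an H100-class GPU (SM 9.0 or newer), per the skip conditions in the tests; all imports and attribute paths are taken from the test file.

```python
import torch
import torch.nn as nn

from torchao.quantization import Float8DynamicActivationFloat8SemiSparseWeightConfig
from torchao.quantization.quant_api import quantize_
from torchao.sparsity import apply_fake_sparsity

# Run under inference mode so `.to(...)` dispatches to aten.to.dtype_layout
# rather than `aten.to_copy` (see the comment in the lowering test below).
with torch.inference_mode():
    model = nn.Linear(256, 1024).half().cuda().eval()
    apply_fake_sparsity(model)  # prune weights to a semi-structured sparse pattern
    quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

    impl = model.weight.original_weight_tensor.tensor_impl
    cloned = impl.clone()  # aten.clone.default, newly supported
    as_float = torch.ops.aten.to.dtype_layout(  # aten.to.dtype_layout, newly supported
        impl, dtype=torch.float, layout=torch.strided
    )
```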

2 files changed (+80 −0 lines)


test/sparsity/test_sparse_api.py: 68 additions & 0 deletions
@@ -12,12 +12,17 @@
 from torch.testing._internal import common_utils
 
 from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
+from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+)
 from torchao.quantization.quant_api import (
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     quantize_,
 )
 from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
+from torchao.utils import is_sm_at_least_90
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -121,6 +126,69 @@ def test_sparse_marlin(self, compile):
 
         torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_fp8_cutlass_sparse(self, compile):
+        input = torch.rand((256, 256)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(256, 1024),
+                nn.Linear(1024, 256),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        apply_fake_sparsity(model)
+        model_copy = copy.deepcopy(model)
+
+        # Quantized
+        quantize_(model_copy.bfloat16(), Float8DynamicActivationFloat8WeightConfig())
+        dense_result = model_copy(input.bfloat16()).half()
+
+        # Sparse + quantized
+        quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+        if compile:
+            model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
+
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_fp8_cutlass_sparse_lowering_op_clone(self):
+        with torch.inference_mode():
+            model = nn.Linear(256, 1024).half().cuda().eval()
+            apply_fake_sparsity(model)
+            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+
+            original = model.weight.original_weight_tensor.tensor_impl.get_plain()
+            cloned = model.weight.original_weight_tensor.tensor_impl.clone().get_plain()
+
+            for o, c in zip(original, cloned):
+                torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)
+
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_fp8_cutlass_sparse_lowering_op_to(self):
+        # Need to run with inference mode to avoid dispatching to `aten.to_copy`
+        with torch.inference_mode():
+            model = nn.Linear(256, 1024).half().cuda().eval()
+            apply_fake_sparsity(model)
+            model_copy = copy.deepcopy(model)
+            expected = model_copy.weight.to(dtype=torch.float)
+
+            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+
+            original = torch.ops.aten.to.dtype_layout(
+                model.weight.original_weight_tensor.tensor_impl,
+                dtype=torch.float,
+                layout=torch.strided,
+            )
+            torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)
+
 
 class TestBlockSparseWeight(common_utils.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

torchao/dtypes/floatx/cutlass_semi_sparse_layout.py: 12 additions & 0 deletions
@@ -100,6 +100,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             raise ValueError(
                 f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
             )
+        elif func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+        elif func is aten.to.dtype_layout:
+            dense, scale, _ = args[0].get_plain()
+            dense = dense.to(
+                *args[1:],
+                dtype=kwargs.get("dtype", dense.dtype),
+                device=kwargs.get("device", dense.device),
+            )
+            return scale * dense
 
         raise NotImplementedError(
             f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"
