 from torch.testing._internal import common_utils

 from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
+from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+)
 from torchao.quantization.quant_api import (
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     quantize_,
 )
 from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
+from torchao.utils import is_sm_at_least_90

 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -121,6 +126,69 @@ def test_sparse_marlin(self, compile):

         torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)

+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_fp8_cutlass_sparse(self, compile):
+        input = torch.rand((256, 256)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(256, 1024),
+                nn.Linear(1024, 256),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        apply_fake_sparsity(model)
+        model_copy = copy.deepcopy(model)
+
+        # Quantized
+        quantize_(model_copy.bfloat16(), Float8DynamicActivationFloat8WeightConfig())
+        dense_result = model_copy(input.bfloat16()).half()
+
+        # Sparse + quantized
+        quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+        if compile:
+            model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
+
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_fp8_cutlass_sparse_lowering_op_clone(self):
+        with torch.inference_mode():
+            model = nn.Linear(256, 1024).half().cuda().eval()
+            apply_fake_sparsity(model)
+            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+
+            original = model.weight.original_weight_tensor.tensor_impl.get_plain()
+            cloned = model.weight.original_weight_tensor.tensor_impl.clone().get_plain()
+
+            for o, c in zip(original, cloned):
+                torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)
+
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_fp8_cutlass_sparse_lowering_op_to(self):
+        # Need to run with inference mode to avoid dispatching to `aten.to_copy`
+        with torch.inference_mode():
+            model = nn.Linear(256, 1024).half().cuda().eval()
+            apply_fake_sparsity(model)
+            model_copy = copy.deepcopy(model)
+            expected = model_copy.weight.to(dtype=torch.float)
+
+            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+
+            original = torch.ops.aten.to.dtype_layout(
+                model.weight.original_weight_tensor.tensor_impl,
+                dtype=torch.float,
+                layout=torch.strided,
+            )
+            torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)
+

 class TestBlockSparseWeight(common_utils.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
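
For reference, a minimal standalone sketch (not part of the diff) of how the config under test would be exercised outside the test harness, built only from the calls the tests above use; the layer shapes and batch size are illustrative, and, per the skip decorators, this path needs a CUDA device with SM90 (e.g. H100):

import torch
import torch.nn as nn

from torchao.quantization import Float8DynamicActivationFloat8SemiSparseWeightConfig
from torchao.quantization.quant_api import quantize_
from torchao.sparsity import apply_fake_sparsity

# Small fp16 model, pruned with the same helper the tests use, which
# imposes the 2:4 semi-structured pattern the sparse kernels consume.
model = nn.Linear(256, 1024).half().cuda().eval()
apply_fake_sparsity(model)

# Swap the weights for fp8 semi-sparse tensors; activations are
# quantized to fp8 dynamically at inference time.
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

x = torch.rand((128, 256)).half().cuda()
with torch.inference_mode():
    y = model(x)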