Skip to content

Commit

Permalink
[tests] Refactor TorchAO serialization fast tests (#10271)
Browse files Browse the repository at this point in the history
refactor
  • Loading branch information
a-r-r-o-w authored and sayakpaul committed Dec 23, 2024
1 parent 2ff6512 commit 1d6c9f4
Showing 1 changed file with 35 additions and 40 deletions.
75 changes: 35 additions & 40 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,21 +447,19 @@ def test_wrong_config(self):
self.get_dummy_components(TorchAoConfig("int42"))


# This class is not to be run as a test by itself. See the tests that follow this class
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
quant_method, quant_method_kwargs = None, None
device = "cuda"

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

def get_dummy_model(self, device=None):
quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
quantized_model = FluxTransformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
Expand Down Expand Up @@ -497,15 +495,15 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
"timestep": timestep,
}

def test_original_model_expected_slice(self):
quantized_model = self.get_dummy_model(torch_device)
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
inputs = self.get_dummy_tensor_inputs(torch_device)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

def check_serialization_expected_slice(self, expected_slice):
quantized_model = self.get_dummy_model(self.device)
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)

with tempfile.TemporaryDirectory() as tmp_dir:
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
Expand All @@ -524,36 +522,33 @@ def check_serialization_expected_slice(self, expected_slice):
)
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

def test_serialization_expected_slice(self):
self.check_serialization_expected_slice(self.serialized_expected_slice)


class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
serialized_expected_slice = expected_slice
device = "cuda"


class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
serialized_expected_slice = expected_slice
device = "cuda"


class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
serialized_expected_slice = expected_slice
device = "cpu"


class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
serialized_expected_slice = expected_slice
device = "cpu"
def test_int_a8w8_cuda(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = "cuda"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

def test_int_a16w8_cuda(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = "cuda"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

def test_int_a8w8_cpu(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = "cpu"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

def test_int_a16w8_cpu(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = "cpu"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
Expand Down

0 comments on commit 1d6c9f4

Please sign in to comment.