diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 267db4c5540d54..ae547905eeb6cb 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -300,8 +300,8 @@ def insert_activation_post_process(m, special_act_post_process=None): def _get_unique_devices_(module): - return {p.device for p in module.parameters()} | { - p.device for p in module.buffers() + return {p.device for p in module.parameters() if p.device.type != "meta"} | { + p.device for p in module.buffers() if p.device.type != "meta" } @@ -779,7 +779,7 @@ def swap_module( # respect device affinity when swapping modules devices = _get_unique_devices_(mod) assert ( - len(devices) <= 1 + len(devices) <= 1 or (len(devices) == 2 and torch.device("meta") in devices) ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" device = next(iter(devices)) if len(devices) > 0 else None if device: