diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index cfcec3b0c3c407..0847075a97ad88 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -781,10 +781,10 @@ def swap_module( # respect device affinity when swapping modules devices = _get_unique_devices_(mod) assert ( - len(devices) <= 1 + len(devices) <= 1 or (len(devices) == 2 and torch.device("meta") in devices) ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" device = next(iter(devices)) if len(devices) > 0 else None - if device: + if device and torch.device("meta") not in devices: new_mod.to(device) return new_mod