
Commit aa43584

kausv authored and facebook-github-bot committed
Handle meta tensors in FX quantization (pytorch#142262)
Summary: X-link: pytorch/torchrec#2622

If the module being quantized contains some meta tensors and some tensors on a real device, quantization should not fail. It should also not fail if the new quantized module is created on a meta device. When the device set includes meta, copying from meta to meta is unnecessary, and copying from another device to meta can be skipped.

Test Plan:
```
buck run fbcode//mode/dev-nosan fbcode//torchrec/fb/quant/tests:test_embedding_modules
```

Reviewed By: emlin

Differential Revision: D66895899
1 parent e15442a commit aa43584

File tree

1 file changed: +3 −3 lines changed


torch/ao/quantization/quantize.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -300,8 +300,8 @@ def insert_activation_post_process(m, special_act_post_process=None):
 
 
 def _get_unique_devices_(module):
-    return {p.device for p in module.parameters()} | {
-        p.device for p in module.buffers()
+    return {p.device for p in module.parameters() if p.device.type != "meta"} | {
+        p.device for p in module.buffers() if p.device.type != "meta"
     }
 
 
@@ -779,7 +779,7 @@ def swap_module(
     # respect device affinity when swapping modules
     devices = _get_unique_devices_(mod)
     assert (
-        len(devices) <= 1
+        len(devices) <= 1 or (len(devices) == 2 and torch.device("meta") in devices)
     ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
     device = next(iter(devices)) if len(devices) > 0 else None
     if device:
```
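
To make the effect of the change concrete, here is a minimal, hypothetical sketch (the standalone helper `get_unique_devices` and the toy `MixedDeviceModule` are illustrations for this note, not PyTorch source) showing how filtering out meta-device tensors keeps device-affinity detection from failing on a module that mixes meta and real-device state:

```python
import torch
import torch.nn as nn


def get_unique_devices(module: nn.Module) -> set:
    # Mirrors the patched _get_unique_devices_: devices of type "meta"
    # are ignored when collecting the module's device set.
    return {p.device for p in module.parameters() if p.device.type != "meta"} | {
        b.device for b in module.buffers() if b.device.type != "meta"
    }


class MixedDeviceModule(nn.Module):
    # A toy module with one real (CPU) parameter and one meta-device buffer,
    # similar to a partially materialized module being quantized.
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.register_buffer("scale", torch.empty(4, device="meta"))


mod = MixedDeviceModule()
devices = get_unique_devices(mod)
print(devices)  # {device(type='cpu')} -- the meta buffer is filtered out

# The relaxed assertion from swap_module would also tolerate an explicit
# {cpu, meta} pair, so module swapping proceeds instead of raising:
assert len(devices) <= 1 or (len(devices) == 2 and torch.device("meta") in devices)
```

Before this patch, the unfiltered device set for such a module would have been `{cpu, meta}`, and the original `len(devices) <= 1` assertion in `swap_module` would have raised.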
