Handle meta tensors in FX quantization (pytorch#142262)
Summary:
X-link: pytorch/torchrec#2622


If a module being quantized contains some meta tensors and some tensors with an actual device, we should not fail quantization.

Quantization should also not fail if the new quantized module is created on a meta device.
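
For illustration, a minimal sketch (not part of the commit) of the mixed-device case this change tolerates; `MixedDeviceModule` is a hypothetical module holding one real-device parameter and one meta parameter:

```python
import torch
import torch.nn as nn

# Hypothetical module mixing a real-device parameter with a meta one.
class MixedDeviceModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.real = nn.Parameter(torch.zeros(4))                 # on cpu
        self.meta = nn.Parameter(torch.empty(4, device="meta"))  # no storage

mod = MixedDeviceModule()
devices = {p.device for p in mod.parameters()}
print(devices)  # {device(type='cpu'), device(type='meta')}
# Two devices, but one is "meta" -- the relaxed assert in the diff below
# now accepts this instead of failing.
```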

Test Plan:
```
buck run fbcode//mode/dev-nosan fbcode//torchrec/fb/quant/tests:test_embedding_modules
```

Differential Revision: D66895899
kausv authored and facebook-github-bot committed Dec 10, 2024
1 parent 20718cd commit fc96877
Showing 1 changed file with 2 additions and 2 deletions: torch/ao/quantization/quantize.py
```diff
@@ -781,10 +781,10 @@ def swap_module(
             # respect device affinity when swapping modules
             devices = _get_unique_devices_(mod)
             assert (
-                len(devices) <= 1
+                len(devices) <= 1 or (len(devices) == 2 and torch.device("meta") in devices)
             ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
             device = next(iter(devices)) if len(devices) > 0 else None
-            if device:
+            if device and torch.device("meta") not in devices:
                 new_mod.to(device)
     return new_mod
```
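
A note on the skipped `new_mod.to(device)` call (my illustration, not part of the diff): a meta tensor has no storage, so copying it to a real device raises, which is one failure mode the new guard avoids:

```python
import torch

t = torch.empty(4, device="meta")
try:
    t.to("cpu")  # meta tensors carry no data to copy
except NotImplementedError as err:
    print(err)  # e.g. "Cannot copy out of meta tensor; no data!"
```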

