Skip to content

Commit

Permalink
Handle meta tensors in FX quantization (pytorch#142262)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/torchrec#2622


If module being quantized contains a some meta tensors and some tensors with actual device, we should not fail quantization.

Quantization should also not fail if new quantized module is created on a meta device.

If devices contain meta, copying from meta to meta is not necessary, copying from another device to meta can be skipped.

Test Plan:
```
buck run fbcode//mode/dev-nosan fbcode//torchrec/fb/quant/tests:test_embedding_modules
```

Differential Revision: D66895899
  • Loading branch information
kausv authored and facebook-github-bot committed Dec 19, 2024
1 parent 288aa87 commit 8e5ad73
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torch/ao/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ def insert_activation_post_process(m, special_act_post_process=None):


def _get_unique_devices_(module):
return {p.device for p in module.parameters()} | {
p.device for p in module.buffers()
return {p.device for p in module.parameters() if p.device.type != "meta"} | {
p.device for p in module.buffers() if p.device.type != "meta"
}


Expand Down Expand Up @@ -779,9 +779,10 @@ def swap_module(
# respect device affinity when swapping modules
devices = _get_unique_devices_(mod)
assert (
len(devices) <= 1
len(devices) <= 1 or (len(devices) == 2 and torch.device("meta") in devices)
), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
device = next(iter(devices)) if len(devices) > 0 else None
#import fbvscode; fbvscode.attach_debugger()
if device:
new_mod.to(device)
return new_mod
Expand Down

0 comments on commit 8e5ad73

Please sign in to comment.