From 4587a7ae43b5dafaee371665c4e796da1a977356 Mon Sep 17 00:00:00 2001 From: Kaustubh Vartak Date: Fri, 20 Dec 2024 23:59:41 -0800 Subject: [PATCH] Handle meta tensors in FX quantization (#142262) Summary: X-link: https://github.com/pytorch/torchrec/pull/2622 If the module being quantized contains some meta tensors and some tensors with an actual device, we should not fail quantization. Quantization should also not fail if the new quantized module is created on a meta device. If devices contain meta, copying from meta to meta is not necessary, and copying from another device to meta can be skipped. Test Plan: ``` buck run fbcode//mode/dev-nosan fbcode//torchrec/fb/quant/tests:test_embedding_modules ``` Differential Revision: D66895899 --- torch/ao/quantization/quantize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 267db4c5540d54..ae547905eeb6cb 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -300,8 +300,8 @@ def insert_activation_post_process(m, special_act_post_process=None): def _get_unique_devices_(module): - return {p.device for p in module.parameters()} | { - p.device for p in module.buffers() + return {p.device for p in module.parameters() if p.device.type != "meta"} | { + p.device for p in module.buffers() if p.device.type != "meta" } @@ -779,7 +779,7 @@ def swap_module( # respect device affinity when swapping modules devices = _get_unique_devices_(mod) assert ( - len(devices) <= 1 + len(devices) <= 1 or (len(devices) == 2 and torch.device("meta") in devices) ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" device = next(iter(devices)) if len(devices) > 0 else None if device: