To reproduce
from torchao import quantize_
from torchao.quantization import int8_weight_only
from torch import nn
import torch

linear = nn.Linear(1024, 1024)
quantize_(linear, int8_weight_only())
linear.cuda()
linear.compile()
linear(torch.randn(1, 1024, device="cuda"))

linear.cpu()   # this will error
linear.cuda()  # this will also error
Error
Traceback (most recent call last):
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 945, in _apply
    torch.utils.swap_tensors(param, param_applied)
  File "/home/xxx/python3.10/site-packages/torch/utils/__init__.py", line 51, in swap_tensors
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
RuntimeError: Cannot swap t1 because it has weakref associated with it

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xxx/debug.py", line 11, in <module>
    linear.cpu()
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 1118, in cpu
    return self._apply(lambda t: t.cpu())
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 949, in _apply
    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Linear.weight
This seems like a problem for tensor subclass + compile in general, not limited to AQT. Even doing compile(disable=False) still has this error.
@jerryzh168 May I know if anyone is looking at this issue? It seems to affect tensor subclass + compile in general, so maybe I can open an issue in core instead?
torchao: 0.7.0+git26648c2c (install from source)
pytorch: tested with 2.5.0 and 2.6.0.dev20241102+cu124