Describe the bug
On a single-device model, TransformerBridge.run_with_cache(input, device=...) moves the underlying model (and the input tensors) to that device and never moves them back. The device= argument is meant to choose where cached activations are stored, not to relocate the model. Both the legacy get_caching_hooks (which documents device as "the device to store on") and ActivationCache.to (whose move_model is deprecated) confirm this. After the call, the model lives on device while cfg.device still reports the original device, so any subsequent forward/generate call fails.
Code example
import torch
from transformer_lens.model_bridge import TransformerBridge
m = TransformerBridge.boot_transformers("distilgpt2", device="mps") # or "cuda"
toks = m.to_tokens("hello world")
_, cache = m.run_with_cache(toks, device="cpu") # intent: offload the cache to CPU
print(next(m.original_model.parameters()).device) # cpu <- the MODEL was moved
print(m.cfg.device) # mps <- now stale / inconsistent
m.generate(m.to_tokens("again"), max_new_tokens=3) # RuntimeError: Placeholder storage has not been allocated on MPS device!
Root cause: the single-device branch of run_with_cache (bridge.py L2058-L2063) runs self.original_model = self.original_model.to(cache_device) and never restores it (the finally at L2082-L2084 only removes hooks). The per-activation caching hook already offloads cache tensors via tensor.detach().to(cache_device) (L1980), so the model move is unnecessary. The n_devices > 1 branch (L2046) already declines to move the model: it warns and leaves cache entries on their per-layer devices.
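To illustrate why the model move is redundant, here is a minimal plain-PyTorch sketch of the same idea: forward hooks offload each activation with tensor.detach().to(cache_device) while the model stays put. It is not the bridge's code; run_with_cache_sketch and the toy nn.Sequential model are hypothetical names used only for illustration.

```python
import torch
from torch import nn


def run_with_cache_sketch(model: nn.Module, x: torch.Tensor, cache_device=None):
    """Hypothetical helper: cache each submodule's output on cache_device
    without ever calling model.to(...)."""
    cache = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                t = output.detach()
                # Only the cached copy is moved; the model is left alone.
                cache[name] = t.to(cache_device) if cache_device is not None else t
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module (empty name)
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            out = model(x)
    finally:
        # Mirrors the existing finally block: only hooks are cleaned up.
        for h in handles:
            h.remove()
    return out, cache


# Toy stand-in for the real model; swap "cpu" for the device you offload to.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
device_before = next(model.parameters()).device
x = torch.randn(3, 4).to(device_before)  # inputs stay on the model's device
out, cache = run_with_cache_sketch(model, x, cache_device="cpu")
device_after = next(model.parameters()).device
```

With the model on mps or cuda and cache_device="cpu", the cache entries should land on the CPU while device_before and device_after stay equal, which is the behavior the device= argument is meant to have.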
System Info
- Installed from source; also present in released v3.1.0 through v3.2.1
- macOS (Apple Silicon / MPS), reproduced above; the same code path affects any non-CPU primary device (e.g. CUDA)
- Python 3.12
Note: it does not reproduce on a CPU-only setup, where device="cpu" makes the move a no-op. It surfaces only when the model's device differs from the cache device, which is why CI (CPU) does not catch it.
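Until the single-device branch stops moving the model, a caller-side workaround could look like the sketch below: record the model's device before the call and move it back afterwards. preserve_device is a hypothetical helper, not part of TransformerLens, and it assumes that restoring the parameters' device is enough to bring the model back in line with cfg.device.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def preserve_device(model: nn.Module):
    """Hypothetical helper: restore the model's device when the block exits."""
    original = next(model.parameters()).device
    try:
        yield model
    finally:
        # nn.Module.to moves parameters in place, so this undoes the relocation.
        if next(model.parameters()).device != original:
            model.to(original)


# Toy stand-in; with the bridge this would presumably be used as:
#   with preserve_device(m.original_model):
#       _, cache = m.run_with_cache(toks, device="cpu")
model = nn.Linear(4, 2)
start_device = next(model.parameters()).device
with preserve_device(model):
    model.to("cpu")  # stands in for the unwanted move inside run_with_cache
end_device = next(model.parameters()).device
```

This costs two model transfers per call, so it is only a stopgap; not moving the model in the first place avoids the round trip entirely.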
Additional context
Found while implementing #697 (adding return_cache to generate), where exposing a device= cache-offload option (maintainer's note) led me to run_with_cache(device=...).
Checklist