
[Bug Report] run_with_cache(device=...) permanently moves the model and leaves cfg.device stale #1336

@RecreationalMath


Describe the bug
On a single-device model, TransformerBridge.run_with_cache(input, device=...) moves the underlying model (and the input tensors) to that device and never moves them back. The device= argument is meant to choose where cached activations are stored, not to relocate the model: the legacy get_caching_hooks documents device as "the device to store on", and the move_model option of ActivationCache.to is deprecated. After the call, the model lives on the cache device while cfg.device still reports the original device, and any subsequent forward/generate call fails.

Code example

import torch
from transformer_lens.model_bridge import TransformerBridge

m = TransformerBridge.boot_transformers("distilgpt2", device="mps")  # or "cuda"
toks = m.to_tokens("hello world")
_, cache = m.run_with_cache(toks, device="cpu")        # intent: offload the cache to CPU

print(next(m.original_model.parameters()).device)      # cpu  <- the MODEL was moved
print(m.cfg.device)                                     # mps  <- now stale / inconsistent
m.generate(m.to_tokens("again"), max_new_tokens=3)      # RuntimeError: Placeholder storage has not been allocated on MPS device!

Root cause: the single-device branch of run_with_cache (bridge.py L2058-L2063) runs self.original_model = self.original_model.to(cache_device) and never restores it (the finally at L2082-L2084 only removes hooks). The per-activation caching hook already offloads cache tensors via tensor.detach().to(cache_device) (L1980), so the model move is unnecessary. The n_devices > 1 branch (L2046) already declines to move the model: it warns and leaves cache entries on their per-layer devices.
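To illustrate why the hook-level offload is sufficient on its own, here is a minimal sketch using plain PyTorch forward hooks. It is not the TransformerLens implementation: run_with_cache_sketch and the toy nn.Sequential model are hypothetical names made up for this example. The "meta" device stands in for a second device so the snippet runs on a CPU-only machine (meta tensors keep shape and dtype but no data); in real use the cache device would be "cpu" or similar.

```python
import torch
from torch import nn


def run_with_cache_sketch(model, x, cache_device=None):
    """Run model(x) and cache each direct child's output on cache_device.

    Hypothetical helper, not TransformerLens code. Only the cached copies
    are moved; the model and the input stay on their original device.
    """
    cache = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            tensor = output.detach()
            if cache_device is not None:
                # Offload the cached copy only; the live activation is untouched.
                tensor = tensor.to(cache_device)
            cache[name] = tensor
            # Returning None leaves the module's output unchanged.
        return hook

    for name, child in model.named_children():
        handles.append(child.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            out = model(x)
    finally:
        # Always remove hooks, even if the forward pass raises.
        for handle in handles:
            handle.remove()
    return out, cache


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)

# "meta" is a stand-in for a device other than the model's.
out, cache = run_with_cache_sketch(model, x, cache_device="meta")

model_devices = {p.device.type for p in model.parameters()}
cache_devices = {t.device.type for t in cache.values()}
print(model_devices, cache_devices)
```

In this sketch the model's parameters stay on their original device while every cache entry lands on the cache device, so a later forward pass on the same model needs no restore step.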

System Info

  • Installed from source; also present in released v3.1.0 through v3.2.1
  • macOS (Apple Silicon / MPS), reproduced above; the same code path affects any non-CPU primary device (e.g. CUDA)
  • Python 3.12

Note: the bug does not reproduce on a CPU-only setup, where device="cpu" makes the move a no-op. It surfaces only when the model's device differs from the cache device, which is why CI (CPU-only) does not catch it.
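One possible shape for a regression check is a guard that snapshots the parameter devices before a call and compares them afterwards. The sketch below uses plain PyTorch; param_devices and assert_devices_unchanged are hypothetical helpers, not part of TransformerLens, and the toy nn.Linear stands in for the bridged model. Moving the toy module to "meta" simulates an unwanted relocation on a CPU-only machine; whether the same trick would exercise the run_with_cache path on a CPU-only runner is an untested assumption.

```python
from contextlib import contextmanager

import torch
from torch import nn


def param_devices(module):
    """Return the set of devices (as strings) holding the module's parameters."""
    return {str(p.device) for p in module.parameters()}


@contextmanager
def assert_devices_unchanged(module):
    """Raise AssertionError if the wrapped block relocates the module."""
    before = param_devices(module)
    yield
    after = param_devices(module)
    if after != before:
        raise AssertionError(f"model moved: {sorted(before)} -> {sorted(after)}")


toy = nn.Linear(4, 2)

# A plain forward pass leaves the parameters where they are, so this passes.
with assert_devices_unchanged(toy):
    toy(torch.randn(1, 4))

# Simulate the bug: the wrapped call relocates the module as a side effect.
try:
    with assert_devices_unchanged(toy):
        toy.to("meta")
    caught = False
except AssertionError:
    caught = True

print(caught)
```

On a machine with an accelerator, the same guard could wrap the run_with_cache(toks, device="cpu") call from the reproduction above.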

Additional context
Found while implementing #697 (adding return_cache to generate), where exposing a device= cache-offload option (maintainer's note) led me to run_with_cache(device=...).

Checklist

  • I have checked that there is no similar issue in the repo
