fix multi-gpu with static cache (#32543)
SunMarc authored and ArthurZucker committed Aug 20, 2024
1 parent 0cc22c9 commit 6be68b3
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/cache_utils.py: 2 additions & 2 deletions
@@ -1067,6 +1067,8 @@ def update(
             A tuple containing the updated key and value states.
         """
         cache_position = cache_kwargs.get("cache_position")
+        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
+        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]

@@ -1078,8 +1080,6 @@ def update(
         # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
         # operation, that avoids copies and uses less memory.
         try:
-            # If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one
-            cache_position.to(device=k_out.device)
             k_out.index_copy_(2, cache_position, key_states)
             v_out.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
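Why the change works: torch.Tensor.to is not an in-place operation, so the removed call cache_position.to(device=k_out.device) returned a new tensor whose result was discarded, leaving the device mismatch in place. Moving self.key_cache[layer_idx] and self.value_cache[layer_idx] to the device of the incoming key_states / value_states instead means the in-place index_copy_ always sees the cache, the index, and the new states on the same device when the model is sharded across several GPUs. Below is a minimal sketch of that behavior, not the transformers code; the tensor shapes and the CPU/GPU placement are assumptions for illustration, and it needs PyTorch with at least one CUDA device.

    # Minimal sketch (assumed shapes and device setup, not the transformers implementation)
    # of the device-placement issue fixed by this commit.
    import torch

    if torch.cuda.is_available():
        # A static cache slot allocated on the CPU, and incoming states on the GPU,
        # as can happen when a model is sharded across several devices.
        cache = torch.zeros(1, 1, 8, 4)
        new_states = torch.ones(1, 1, 2, 4, device="cuda")
        cache_position = torch.tensor([0, 1], device="cuda")

        # The removed line: Tensor.to is out of place, so discarding its result
        # changes nothing and the device mismatch remains.
        cache_position.to(device=cache.device)  # no effect

        # The added approach: move the cache to the device of the incoming states,
        # then the in-place index_copy_ operates on a single device.
        cache = cache.to(device=new_states.device)
        cache.index_copy_(2, cache_position, new_states)
        print(cache.device)  # cuda:0

Because the result of .to is assigned back into self.key_cache / self.value_cache, the move happens at most once per layer: once the cache already lives on the right device, .to returns it unchanged.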
