Skip to content

Commit

Permalink
fix early exit spec decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 26, 2023
1 parent e6ebcdc commit 8e425d4
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions speculative_decoding/speculative_decoding.py
Original file line number | Diff line number | Diff line change
Expand Up
@@ -267,8 +267,6 @@ def speculative_decoding_with_same_model(
prompt_seq_len, out, device = prompt.shape[-1], prompt.clone(), prompt.device
sample_num_times = max(0, seq_len - prompt_seq_len)

- assert prompt.shape[0] == 1, 'batched spec decoding not supported yet'
-
cache = None
small_cache = None

Expand Down
Expand Up
@@ -540,7 +538,10 @@ def forward(
early_exit_layers, layers = layers[:early_exit_layer_index], layers[early_exit_layer_index:]
x = x[:, cache_embeds_len:]

- iter_early_cache_kvs = iter(early_cache_kvs)
+ if exists(early_cache_kvs):
+     iter_early_cache_kvs = iter(early_cache_kvs.unbind(dim = 1))
+ else:
+     iter_early_cache_kvs = iter([])

for ind, (attn, ff) in enumerate(early_exit_layers):
residual = x
Expand Down

0 comments on commit 8e425d4

Please sign in to comment.