more fixes
lucidrains committed Oct 1, 2023
1 parent 43e1543 commit 0434d06
Showing 2 changed files with 12 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'speculative-decoding',
  packages = find_packages(exclude=[]),
-  version = '0.0.11',
+  version = '0.0.12',
  license='MIT',
  description = 'Speculative Decoding',
  author = 'Phil Wang',
17 changes: 11 additions & 6 deletions speculative_decoding/speculative_decoding.py
@@ -192,6 +192,7 @@ def speculative_decoding(
        has_rejected = num_rejected > 0

        accepted = rearrange(accepted, 'b -> b 1')
+        accepted.clamp_(max = gamma - 1)
        adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted])
        adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
        adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d')
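
The added clamp_ guards the gather on the next line: when every draft token is accepted, accepted equals gamma and would index one past the end of the gamma-position probability tensors. A minimal sketch of that edge case (the (batch, gamma, vocab) shapes and the accepted = gamma - num_rejected derivation are assumptions for illustration, not read from this diff):

import torch

batch, gamma, vocab = 2, 5, 100
prob = torch.softmax(torch.randn(batch, gamma, vocab), dim = -1)

num_rejected = torch.tensor([0, 2])           # row 0 accepts every draft token
accepted = gamma - num_rejected               # tensor([5, 3]); 5 is one past the end

accepted.clamp_(max = gamma - 1)              # the fix: index stays in [0, gamma - 1]

batch_range = torch.arange(batch)
resample_prob = prob[batch_range, accepted]   # safe gather, shape (batch, vocab)

Rows that rejected nothing presumably discard the resampled token downstream (has_rejected gates it), so the clamp only has to keep the indexing legal.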
@@ -228,9 +229,10 @@ def speculative_decoding(
        small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)

        if out.shape[-1] > max_seq_len:
-            out = out[:, -max_seq_len:]
-            cache = tuple(t[..., -max_seq_len:, :] for t in cache)
-            small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache)
+            left_index = out.shape[-1] - max_seq_len
+            out = out[:, left_index:]
+            cache = tuple(t[..., left_index:, :] for t in cache)
+            small_cache = tuple(t[..., left_index:, :] for t in small_cache)
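
Computing left_index once from out and reusing it on both caches keeps all three tensors trimmed at the same absolute position; slicing each with [-max_seq_len:] does not, whenever a cache's sequence length differs from out's. A small sketch of the misalignment (the one-step lag between out and its KV cache is an assumption for illustration):

import torch

max_seq_len = 8
out = torch.arange(10)[None, :]                # (1, 10) generated token ids
cache = torch.zeros(1, 2, 9, 64)               # (b, h, n, d): one step behind out

left_index = out.shape[-1] - max_seq_len       # 2, measured on out alone

out = out[:, left_index:]                      # keeps absolute positions 2..9
cache = cache[..., left_index:, :]             # keeps the matching 7 cached steps

# the old cache[..., -max_seq_len:, :] would keep 8 steps starting at
# position 1, one position to the left of the window kept in out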

        # sample the additional token, one of the tricks in the paper to better bound the worst case
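
That comment refers to the extra draw taken when all gamma drafts survive, so a verification pass always yields at least one token even in the worst case. It pairs with the residual resampling visible in the hunks above: on a rejection, the replacement token is drawn from norm(relu(p - q)), which keeps the output distributed exactly as the large model's p, per the speculative sampling papers the repository implements. A single-position sketch of that residual draw (shapes simplified, not lifted from this file):

import torch
import torch.nn.functional as F

vocab = 100
p = torch.softmax(torch.randn(vocab), dim = -1)   # large model at the rejected position
q = torch.softmax(torch.randn(vocab), dim = -1)   # small draft model, same position

adjusted = F.relu(p - q)
adjusted = adjusted / adjusted.sum()              # mirrors the adjusted_prob lines above

next_token = torch.multinomial(adjusted, 1)       # resampled token after a rejection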

@@ -344,6 +346,8 @@ def speculative_decoding_with_same_model(
        has_rejected = num_rejected > 0

        accepted = rearrange(accepted, 'b -> b 1')
+        accepted.clamp_(max = gamma - 1)
+
        adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted])
        adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
        adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d')
@@ -380,9 +384,10 @@ def speculative_decoding_with_same_model(
        small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)

        if out.shape[-1] > max_seq_len:
-            out = out[:, -max_seq_len:]
-            cache = tuple(t[..., -max_seq_len:, :] for t in cache)
-            small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache)
+            left_index = out.shape[-1] - max_seq_len
+            out = out[:, left_index:]
+            cache = tuple(t[..., left_index:, :] for t in cache)
+            small_cache = tuple(t[..., left_index:, :] for t in small_cache)

        # sample the additional token, one of the tricks in the paper to better bound the worst case

