diff --git a/setup.py b/setup.py index a2f7769..3b92aa2 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'speculative-decoding', packages = find_packages(exclude=[]), - version = '0.0.11', + version = '0.0.12', license='MIT', description = 'Speculative Decoding', author = 'Phil Wang', diff --git a/speculative_decoding/speculative_decoding.py b/speculative_decoding/speculative_decoding.py index d7f6e62..cbfe332 100644 --- a/speculative_decoding/speculative_decoding.py +++ b/speculative_decoding/speculative_decoding.py @@ -192,6 +192,7 @@ def speculative_decoding( has_rejected = num_rejected > 0 accepted = rearrange(accepted, 'b -> b 1') + accepted.clamp_(max = gamma - 1) adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted]) adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d') @@ -228,9 +229,10 @@ def speculative_decoding( small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache) if out.shape[-1] > max_seq_len: - out = out[:, -max_seq_len:] - cache = tuple(t[..., -max_seq_len:, :] for t in cache) - small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache) + left_index = out.shape[-1] - max_seq_len + out = out[:, left_index:] + cache = tuple(t[..., left_index:, :] for t in cache) + small_cache = tuple(t[..., left_index:, :] for t in small_cache) # sample the additional token, one of the tricks in the paper to better bound the worst case @@ -344,6 +346,8 @@ def speculative_decoding_with_same_model( has_rejected = num_rejected > 0 accepted = rearrange(accepted, 'b -> b 1') + accepted.clamp_(max = gamma - 1) + adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted]) adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d') @@ -380,9 +384,10 @@ def speculative_decoding_with_same_model( small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache) if out.shape[-1] > max_seq_len: - out = out[:, -max_seq_len:] - cache = tuple(t[..., -max_seq_len:, :] for t in cache) - small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache) + left_index = out.shape[-1] - max_seq_len + out = out[:, left_index:] + cache = tuple(t[..., left_index:, :] for t in cache) + small_cache = tuple(t[..., left_index:, :] for t in small_cache) # sample the additional token, one of the tricks in the paper to better bound the worst case