diff --git a/README.md b/README.md index 5f4c965..8ae4450 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,10 @@ Also have a few ideas of my own that I will try and share in this repository, if - [x] for early exit, allow an extra transformer block head (separate from main transformer stem) - [x] figure out batched spec decoding - different rows may advance at different rates - [x] further optimize batched spec decoding, as losing some performance from all the indexing - seems like it will take some work for this technique to be actually usable +- [x] make batched spec decoding work with early exit strategy - [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the embeddings) - [ ] dedicate a morning to microoptimizations -- [ ] make batched spec decoding work with early exit strategy ## Citations diff --git a/setup.py b/setup.py index 2124a15..a2f7769 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'speculative-decoding', packages = find_packages(exclude=[]), - version = '0.0.10', + version = '0.0.11', license='MIT', description = 'Speculative Decoding', author = 'Phil Wang', diff --git a/speculative_decoding/speculative_decoding.py b/speculative_decoding/speculative_decoding.py index 8e9515a..d7f6e62 100644 --- a/speculative_decoding/speculative_decoding.py +++ b/speculative_decoding/speculative_decoding.py @@ -258,13 +258,14 @@ def speculative_decoding_with_same_model( gamma: int = 5, temperature = 1., filter_thres = 0.9, - lenience = 1. + lenience = 1., + pad_id = 0 ): """ eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192 """ - prompt_seq_len, out, device = prompt.shape[-1], prompt.clone(), prompt.device + batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device sample_num_times = max(0, seq_len - prompt_seq_len) cache = None @@ -273,7 +274,10 @@ def speculative_decoding_with_same_model( num_steps = 0 total_accepted = 0 - while out.shape[-1] < seq_len: + batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None] + seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long) + + while (seq_lens < seq_len).any(): # predict with smaller network @@ -281,7 +285,14 @@ def speculative_decoding_with_same_model( q_sampled_out = [] for _ in range(gamma): - small_logits, small_cache = net(out, cache = small_cache, return_cache = True, return_early_exit_only = True) + small_logits, small_cache = net( + out, + cache = small_cache, + return_cache = True, + return_early_exit_only = True, + seq_start_pos = out.shape[-1] - seq_lens + ) + small_logits = small_logits[:, -1] small_logits = top_k(small_logits, thres = filter_thres) @@ -289,6 +300,7 @@ def speculative_decoding_with_same_model( sample = gumbel_sample(small_logits, temperature = temperature, dim = -1) out = torch.cat((out, sample[..., None]), dim = -1) + seq_lens += 1 q_sampled_out.append(rearrange(sample, 'b -> b 1 1')) @@ -297,7 +309,14 @@ def speculative_decoding_with_same_model( # verify with larger network - logits, cache = net(out, cache = cache, early_exit_cache = small_cache, return_cache = True, start_from_early_exit_hiddens = True) + logits, cache = net( + out, + cache = cache, + early_exit_cache = small_cache, + return_cache = True, + start_from_early_exit_hiddens = True, + seq_start_pos = out.shape[-1] - seq_lens + ) logits = logits[..., -(gamma + 1):, :] logits = top_k(logits, thres = filter_thres) @@ -317,27 +336,69 @@ def speculative_decoding_with_same_model( r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1) accepted = find_first_true_index(r > (p / q)) - n = accepted.item() # need to handle batched spec decoding - total_accepted += n + total_accepted += accepted.float().mean() num_steps += 1 - if n < gamma: - adjusted_prob = F.relu(prob[:, n] - small_prob[:, n]) - prob_next = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) - out = out[:, :-(gamma - n)] + num_rejected = gamma - accepted + has_rejected = num_rejected > 0 - # adjust cache + accepted = rearrange(accepted, 'b -> b 1') + adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted]) + adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) + adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d') - next_seq_len = out.shape[-1] - cache = tuple(t[..., :next_seq_len, :] for t in cache) - small_cache = tuple(t[..., :next_seq_len, :] for t in small_cache) + prob_next = torch.where( + rearrange(has_rejected, '... -> ... 1'), + adjusted_prob, + prob_next + ) + + # do a bunch of slicing and align everything to the right, including kv caches - # sample the additional token + max_num_rejected = num_rejected.amax() + seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long) + seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None] + + seq_lens -= num_rejected + max_seq_len = seq_lens.amax() + + if batch > 1: + out = F.pad(out, (0, max_num_rejected), value = pad_id) + out = out[batch_range, seq_offset_indices] + + cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache) + small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache) + + cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache) + small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache) + + cache = tuple(t[batch_range, seq_offset_indices] for t in cache) + small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache) + + cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache) + small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache) + + if out.shape[-1] > max_seq_len: + out = out[:, -max_seq_len:] + cache = tuple(t[..., -max_seq_len:, :] for t in cache) + small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache) + + # sample the additional token, one of the tricks in the paper to better bound the worst case next_token = torch.multinomial(prob_next, 1) out = torch.cat((out, next_token), dim = -1) + seq_lens += 1 + + # now left align + + num_pad_left = out.shape[-1] - seq_lens + max_pad_left = num_pad_left.amax() + out = F.pad(out, (0, max_pad_left), value = pad_id) + + seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long) + out = out[batch_range, seq_len_range + num_pad_left[..., None]] return out[..., prompt_seq_len:], total_accepted / num_steps