From 43e1543fa2d2942c2da47c131ab651c2db06d522 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 26 Sep 2023 18:52:23 +0200 Subject: [PATCH] just copy batched spec decoding and make batch early exit strategy work --- README.md | 2 +- setup.py | 2 +- speculative_decoding/speculative_decoding.py | 93 ++++++++++++++++---- 3 files changed, 79 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 5f4c965..8ae4450 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,10 @@ Also have a few ideas of my own that I will try and share in this repository, if - [x] for early exit, allow an extra transformer block head (separate from main transformer stem) - [x] figure out batched spec decoding - different rows may advance at different rates - [x] further optimize batched spec decoding, as losing some performance from all the indexing - seems like it will take some work for this technique to be actually usable +- [x] make batched spec decoding work with early exit strategy - [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the embeddings) - [ ] dedicate a morning to microoptimizations -- [ ] make batched spec decoding work with early exit strategy ## Citations diff --git a/setup.py b/setup.py index 2124a15..a2f7769 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'speculative-decoding', packages = find_packages(exclude=[]), - version = '0.0.10', + version = '0.0.11', license='MIT', description = 'Speculative Decoding', author = 'Phil Wang', diff --git a/speculative_decoding/speculative_decoding.py b/speculative_decoding/speculative_decoding.py index 8e9515a..d7f6e62 100644 --- a/speculative_decoding/speculative_decoding.py +++ b/speculative_decoding/speculative_decoding.py @@ -258,13 +258,14 @@ def speculative_decoding_with_same_model( gamma: int = 5, temperature = 1., filter_thres = 0.9, - lenience = 1. + lenience = 1., + pad_id = 0 ): """ eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192 """ - prompt_seq_len, out, device = prompt.shape[-1], prompt.clone(), prompt.device + batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device sample_num_times = max(0, seq_len - prompt_seq_len) cache = None @@ -273,7 +274,10 @@ def speculative_decoding_with_same_model( num_steps = 0 total_accepted = 0 - while out.shape[-1] < seq_len: + batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None] + seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long) + + while (seq_lens < seq_len).any(): # predict with smaller network @@ -281,7 +285,14 @@ def speculative_decoding_with_same_model( q_sampled_out = [] for _ in range(gamma): - small_logits, small_cache = net(out, cache = small_cache, return_cache = True, return_early_exit_only = True) + small_logits, small_cache = net( + out, + cache = small_cache, + return_cache = True, + return_early_exit_only = True, + seq_start_pos = out.shape[-1] - seq_lens + ) + small_logits = small_logits[:, -1] small_logits = top_k(small_logits, thres = filter_thres) @@ -289,6 +300,7 @@ def speculative_decoding_with_same_model( sample = gumbel_sample(small_logits, temperature = temperature, dim = -1) out = torch.cat((out, sample[..., None]), dim = -1) + seq_lens += 1 q_sampled_out.append(rearrange(sample, 'b -> b 1 1')) @@ -297,7 +309,14 @@ def speculative_decoding_with_same_model( # verify with larger network - logits, cache = net(out, cache = cache, early_exit_cache = small_cache, return_cache = True, start_from_early_exit_hiddens = True) + logits, cache = net( + out, + cache = cache, + early_exit_cache = small_cache, + return_cache = True, + start_from_early_exit_hiddens = True, + seq_start_pos = out.shape[-1] - seq_lens + ) logits = logits[..., -(gamma + 1):, :] logits = top_k(logits, thres = filter_thres) @@ -317,27 +336,69 @@ def speculative_decoding_with_same_model( r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1) accepted = find_first_true_index(r > (p / q)) - n = accepted.item() # need to handle batched spec decoding - total_accepted += n + total_accepted += accepted.float().mean() num_steps += 1 - if n < gamma: - adjusted_prob = F.relu(prob[:, n] - small_prob[:, n]) - prob_next = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) - out = out[:, :-(gamma - n)] + num_rejected = gamma - accepted + has_rejected = num_rejected > 0 - # adjust cache + accepted = rearrange(accepted, 'b -> b 1') + adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted]) + adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) + adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d') - next_seq_len = out.shape[-1] - cache = tuple(t[..., :next_seq_len, :] for t in cache) - small_cache = tuple(t[..., :next_seq_len, :] for t in small_cache) + prob_next = torch.where( + rearrange(has_rejected, '... -> ... 1'), + adjusted_prob, + prob_next + ) + + # do a bunch of slicing and align everything to the right, including kv caches - # sample the additional token + max_num_rejected = num_rejected.amax() + seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long) + seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None] + + seq_lens -= num_rejected + max_seq_len = seq_lens.amax() + + if batch > 1: + out = F.pad(out, (0, max_num_rejected), value = pad_id) + out = out[batch_range, seq_offset_indices] + + cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache) + small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache) + + cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache) + small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache) + + cache = tuple(t[batch_range, seq_offset_indices] for t in cache) + small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache) + + cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache) + small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache) + + if out.shape[-1] > max_seq_len: + out = out[:, -max_seq_len:] + cache = tuple(t[..., -max_seq_len:, :] for t in cache) + small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache) + + # sample the additional token, one of the tricks in the paper to better bound the worst case next_token = torch.multinomial(prob_next, 1) out = torch.cat((out, next_token), dim = -1) + seq_lens += 1 + + # now left align + + num_pad_left = out.shape[-1] - seq_lens + max_pad_left = num_pad_left.amax() + out = F.pad(out, (0, max_pad_left), value = pad_id) + + seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long) + out = out[batch_range, seq_len_range + num_pad_left[..., None]] return out[..., prompt_seq_len:], total_accepted / num_steps