From f4e55c4d240942a565c92664724fa2419f8e8265 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 3 Oct 2023 16:59:45 +0200 Subject: [PATCH] complete speculative sampling with prophet net on cached embeddings idea! --- README.md | 6 +- setup.py | 2 +- .../speculative_decoding_with_prophet.py | 352 +++++++++--------- train_prophet.py | 8 +- 4 files changed, 193 insertions(+), 175 deletions(-) diff --git a/README.md b/README.md index 5833a4f..3b047bb 100644 --- a/README.md +++ b/README.md @@ -17,11 +17,9 @@ Also have a few ideas of my own that I will try and share in this repository, if - [x] figure out batched spec decoding - different rows may advance at different rates - [x] further optimize batched spec decoding, as losing some performance from all the indexing - seems like it will take some work for this technique to be actually usable - [x] make batched spec decoding work with early exit strategy +- [x] complete speculative sampling with prophet transformer idea - seems to work well! 🙌 -- [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the embeddings) - - [x] complete prophet net with hierarchical transformer training - - [ ] complete the spec decoding algorithm using trained prophet net transformer - +- [ ] get some wandb charts and see how prophet compares with early exit strategy, share on repository - [ ] for early exit strategy, try randomly summing last cached embedding back to the same model (a la alphafold2 recycling), randomly cropped along sequence length, and train early exit loss this way. see if one can improve the gamma this way - [ ] dedicate a morning to microoptimizations diff --git a/setup.py b/setup.py index ab4f765..8e114de 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'speculative-decoding', packages = find_packages(exclude=[]), - version = '0.0.16', + version = '0.1.0', license='MIT', description = 'Speculative Decoding', author = 'Phil Wang', diff --git a/speculative_decoding/speculative_decoding_with_prophet.py b/speculative_decoding/speculative_decoding_with_prophet.py index 31f4248..e4e2c5b 100644 --- a/speculative_decoding/speculative_decoding_with_prophet.py +++ b/speculative_decoding/speculative_decoding_with_prophet.py @@ -92,171 +92,6 @@ def base_decoding( return out[..., prompt_seq_len:] -# speculative decoding functions - -def safe_div(num, den, eps = 1e-10): - return num / max(den, eps) - -def find_first_true_index(bool_tensor, dim = -1): - return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim) - -@torch.no_grad() -def speculative_decoding_with_prophet_model( - net: Module, - prompt: Tensor, - seq_len: int, - gamma: int = 5, - temperature = 1., - filter_thres = 0.9, - lenience = 1., - pad_id = 0 -): - """ - eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192 - """ - - raise NotImplementedError - - batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device - sample_num_times = max(0, seq_len - prompt_seq_len) - - cache = None - small_cache = None - - num_steps = 0 - total_accepted = 0 - - batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None] - seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long) - - while (seq_lens < seq_len).any(): - - # predict with smaller network - - all_small_logits = [] - q_sampled_out = [] - - for _ in range(gamma): - small_logits, small_cache = net( - out, - cache = small_cache, - return_cache = True, - return_early_exit_only = True, - seq_start_pos = out.shape[-1] - seq_lens - ) - - small_logits = small_logits[:, -1] - - small_logits = top_k(small_logits, thres = filter_thres) - all_small_logits.append(small_logits) - - sample = gumbel_sample(small_logits, temperature = temperature, dim = -1) - out = torch.cat((out, sample[..., None]), dim = -1) - seq_lens += 1 - - q_sampled_out.append(rearrange(sample, 'b -> b 1 1')) - - q_sampled_out = torch.cat(q_sampled_out, dim = -2) - small_logits = torch.stack(all_small_logits, dim = -2) - - # verify with larger network - - logits, cache = net( - out, - cache = cache, - early_exit_cache = small_cache, - return_cache = True, - start_from_early_exit_hiddens = True, - seq_start_pos = out.shape[-1] - seq_lens - ) - - logits = logits[..., -(gamma + 1):, :] - logits = top_k(logits, thres = filter_thres) - - # prob and prob of small model (p(x) and q(x) in algorithm 1) - - prob = safe_div(logits, temperature).softmax(dim = -1) - small_prob = safe_div(small_logits, temperature).softmax(dim = -1) - - p, prob_next = prob[:, :-1], prob[:, -1] - - p = p.gather(-1, q_sampled_out) - q = small_prob.gather(-1, q_sampled_out) * lenience - - p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)] - - r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1) - - accepted = find_first_true_index(r > (p / q)) - - total_accepted += accepted.float().mean() - num_steps += 1 - - num_rejected = gamma - accepted - has_rejected = num_rejected > 0 - - accepted = rearrange(accepted, 'b -> b 1') - accepted.clamp_(max = gamma - 1) - - adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted]) - adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) - adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d') - - prob_next = torch.where( - rearrange(has_rejected, '... -> ... 1'), - adjusted_prob, - prob_next - ) - - # do a bunch of slicing and align everything to the right, including kv caches - - max_num_rejected = num_rejected.amax() - seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long) - seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None] - - seq_lens -= num_rejected - max_seq_len = seq_lens.amax() - - if batch > 1: - out = F.pad(out, (0, max_num_rejected), value = pad_id) - out = out[batch_range, seq_offset_indices] - - cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache) - small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache) - - cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache) - small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache) - - cache = tuple(t[batch_range, seq_offset_indices] for t in cache) - small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache) - - cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache) - small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache) - - if out.shape[-1] > max_seq_len: - left_index = out.shape[-1] - max_seq_len - out = out[:, left_index:] - cache = tuple(t[..., left_index:, :] for t in cache) - small_cache = tuple(t[..., left_index:, :] for t in small_cache) - - # sample the additional token, one of the tricks in the paper to better bound the worst case - - next_token = torch.multinomial(prob_next, 1) - - out = torch.cat((out, next_token), dim = -1) - seq_lens += 1 - - # now left align - - num_pad_left = out.shape[-1] - seq_lens - max_pad_left = num_pad_left.amax() - out = F.pad(out, (0, max_pad_left), value = pad_id) - - seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long) - out = out[batch_range, seq_len_range + num_pad_left[..., None]] - - return out[..., prompt_seq_len:], total_accepted / num_steps - # norm class RMSNorm(Module): @@ -394,6 +229,9 @@ def forward( x = self.token_emb(x) if exists(start_tokens): + if start_tokens.ndim == 2: + start_tokens = rearrange(start_tokens, 'b d -> b 1 d') + x = torch.cat((start_tokens, x), dim = 1) # handle seq start pos offset @@ -504,10 +342,192 @@ def forward(self, x): prophet_input = rearrange(prophet_input, '... n -> (...) n') - prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n d -> (b n) 1 d') + prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n d -> (b n) d') prophet_loss, _ = self.prophet(prophet_input, start_tokens = prophet_start_tokens, return_loss = True) total_loss = total_loss + prophet_loss return total_loss, (main_loss, prophet_loss) + +# speculative decoding functions + +def safe_div(num, den, eps = 1e-10): + return num / max(den, eps) + +def find_first_true_index(bool_tensor, dim = -1): + return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim) + +@torch.no_grad() +def speculative_decoding_with_prophet_model( + net: ModelWithProphetWrapper, + prompt: Tensor, + seq_len: int, + gamma: int = 5, + temperature = 1., + filter_thres = 0.9, + lenience = 1., + pad_id = 0 +): + """ + eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192 + """ + + # extract model, prophet, and model to prophet (if their model dimensions differ) + + model = net.model + to_prophet_start_token = net.to_prophet_start_token + prophet = net.prophet + + batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device + + if (seq_len - prompt_seq_len) <= 0: + return prompt, None + + cache = None + small_cache = None + + num_steps = 0 + total_accepted = 0 + + batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None] + seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long) + + # sample the first token from the main model + + logits, cache = model(out, return_cache = True) + logits = logits[:, -1:] + logits = top_k(logits, thres = filter_thres) + sample = gumbel_sample(logits, temperature = temperature, dim = -1) + out = torch.cat((out, sample), dim = -1) + seq_lens += 1 + + # now we have the first cached embedding to use as the prophet network start token for the speculative sampling + + _, embeds = cache + next_prophet_start_tokens = to_prophet_start_token(embeds[:, -1]) + + while (seq_lens < seq_len).any(): + + # predict with smaller network + + all_small_logits = [] + q_sampled_out = [] + + small_cache = None + num_tokens = 2 # the main model embeddings is 1 step behind the main sequence + + for _ in range(gamma): + small_logits, small_cache = prophet( + out[..., -num_tokens:], + start_tokens = next_prophet_start_tokens, + cache = small_cache, + return_cache = True + ) + + small_logits = small_logits[:, -1:] + + small_logits = top_k(small_logits, thres = filter_thres) + all_small_logits.append(small_logits) + + sample = gumbel_sample(small_logits, temperature = temperature, dim = -1) + out = torch.cat((out, sample), dim = -1) + + seq_lens += 1 + num_tokens += 1 + + q_sampled_out.append(rearrange(sample, '... -> ... 1')) + + q_sampled_out = torch.cat(q_sampled_out, dim = -2) + small_logits = torch.cat(all_small_logits, dim = -2) + + # verify with larger network + + logits, cache = model( + out, + cache = cache, + return_cache = True, + seq_start_pos = out.shape[-1] - seq_lens + ) + + logits = logits[..., -(gamma + 1):, :] + logits = top_k(logits, thres = filter_thres) + + # prob and prob of small model (p(x) and q(x) in algorithm 1) + + prob = safe_div(logits, temperature).softmax(dim = -1) + small_prob = safe_div(small_logits, temperature).softmax(dim = -1) + + p, prob_next = prob[:, :-1], prob[:, -1] + + p = p.gather(-1, q_sampled_out) + q = small_prob.gather(-1, q_sampled_out) * lenience + + p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)] + + r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1) + + accepted = find_first_true_index(r > (p / q)) + + total_accepted += accepted.float().mean() + num_steps += 1 + + num_rejected = gamma - accepted + has_rejected = num_rejected > 0 + + accepted = rearrange(accepted, 'b -> b 1') + accepted.clamp_(max = gamma - 1) + + adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted]) + adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) + adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d') + + prob_next = torch.where( + rearrange(has_rejected, '... -> ... 1'), + adjusted_prob, + prob_next + ) + + # do a bunch of slicing and align everything to the right, including kv caches + + max_num_rejected = num_rejected.amax() + seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long) + seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None] + + seq_lens -= num_rejected + max_seq_len = seq_lens.amax() + + if batch > 1: + out = F.pad(out, (0, max_num_rejected), value = pad_id) + out = out[batch_range, seq_offset_indices] + + cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache) + cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache) + cache = tuple(t[batch_range, seq_offset_indices] for t in cache) + cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache) + + if out.shape[-1] > max_seq_len: + left_index = out.shape[-1] - max_seq_len + out = out[:, left_index:] + cache = tuple(t[..., left_index:, :] for t in cache) + + # sample the additional token, one of the tricks in the paper to better bound the worst case + + next_token = torch.multinomial(prob_next, 1) + + out = torch.cat((out, next_token), dim = -1) + seq_lens += 1 + + _, embeds = cache + next_prophet_start_tokens = to_prophet_start_token(embeds[:, -1]) + + # now left align + + num_pad_left = out.shape[-1] - seq_lens + max_pad_left = num_pad_left.amax() + out = F.pad(out, (0, max_pad_left), value = pad_id) + + seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long) + out = out[batch_range, seq_len_range + num_pad_left[..., None]] + + return out[..., prompt_seq_len:], total_accepted / num_steps diff --git a/train_prophet.py b/train_prophet.py index 3d862a4..db6219b 100644 --- a/train_prophet.py +++ b/train_prophet.py @@ -27,7 +27,7 @@ GRAD_ACCUM_EVERY = 4 LEARNING_RATE = 1e-4 PRIME_LENGTH = 128 -GENERATE_EVERY = 10 +GENERATE_EVERY = 100 GENERATE_LENGTH = 512 SEQ_LEN = 512 GAMMA = 5 @@ -135,8 +135,8 @@ def __len__(self): optim.step() optim.zero_grad() - if False and i % GENERATE_EVERY == 0: - model.eval() + if i % GENERATE_EVERY == 0: + model_and_prophet.eval() inp = random.choice(val_dataset)[:PRIME_LENGTH] prime = decode_tokens(inp) @@ -146,7 +146,7 @@ def __len__(self): sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH) - (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_prophet_model)(model, prompt, GENERATE_LENGTH, GAMMA) + (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_prophet_model)(model_and_prophet, prompt, GENERATE_LENGTH, GAMMA) base_decode_output = decode_tokens(sampled[0]) spec_decode_output = decode_tokens(spec_decode_sampled[0])