just copy batched spec decoding and make batch early exit strategy work
lucidrains committed Sep 26, 2023
1 parent 8e425d4 commit 43e1543
Showing 3 changed files with 79 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -12,10 +12,10 @@ Also have a few ideas of my own that I will try and share in this repository, if
- [x] for early exit, allow an extra transformer block head (separate from main transformer stem)
- [x] figure out batched spec decoding - different rows may advance at different rates
- [x] further optimize batched spec decoding, as losing some performance from all the indexing - seems like it will take some work for this technique to be actually usable
- [x] make batched spec decoding work with early exit strategy

- [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the embeddings)
- [ ] dedicate a morning to microoptimizations
- [ ] make batched spec decoding work with early exit strategy

## Citations

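For orientation on the checklist items above, here is a minimal sketch of the accept / reject rule from algorithm 1 of https://arxiv.org/abs/2211.17192 (the paper cited in the code further below), which this repository implements in batched form. The function and tensor names are illustrative assumptions, not the repository's API.

import torch

# Editorial illustration for a single unbatched sequence; not this repository's API.
def accept_drafted_tokens(draft_tokens, draft_probs, target_probs):
    # draft_tokens: (gamma,) token ids sampled from the small / draft model
    # draft_probs:  (gamma, vocab) draft distributions at each drafted position
    # target_probs: (gamma + 1, vocab) target distributions at the same
    #               positions, plus one extra position for the bonus token
    gamma = draft_tokens.shape[0]

    for t in range(gamma):
        token = draft_tokens[t]
        p = target_probs[t, token]  # target probability of the drafted token
        q = draft_probs[t, token]   # draft probability of the drafted token

        # accept the drafted token with probability min(1, p / q)
        if torch.rand(()) < torch.clamp(p / q, max = 1.):
            continue

        # rejected: resample this position from the adjusted distribution
        # norm(max(0, p - q)), then stop
        adjusted = torch.relu(target_probs[t] - draft_probs[t])
        adjusted = adjusted / adjusted.sum()
        return t, torch.multinomial(adjusted, 1)

    # every drafted token accepted: sample one bonus token from the target
    # model's distribution at the position after the last drafted token
    return gamma, torch.multinomial(target_probs[gamma], 1)

The diff below applies the same rule per batch row: find_first_true_index(r > (p / q)) gives the number of drafted tokens each row keeps, and prob_next holds the adjusted (or bonus) distribution that the extra token is drawn from.
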
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'speculative-decoding',
packages = find_packages(exclude=[]),
version = '0.0.10',
version = '0.0.11',
license='MIT',
description = 'Speculative Decoding',
author = 'Phil Wang',
93 changes: 77 additions & 16 deletions speculative_decoding/speculative_decoding.py
@@ -258,13 +258,14 @@ def speculative_decoding_with_same_model(
gamma: int = 5,
temperature = 1.,
filter_thres = 0.9,
lenience = 1.
lenience = 1.,
pad_id = 0
):
"""
eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
"""

prompt_seq_len, out, device = prompt.shape[-1], prompt.clone(), prompt.device
batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device
sample_num_times = max(0, seq_len - prompt_seq_len)

cache = None
@@ -273,22 +274,33 @@
num_steps = 0
total_accepted = 0

while out.shape[-1] < seq_len:
batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None]
seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long)

while (seq_lens < seq_len).any():

# predict with smaller network

all_small_logits = []
q_sampled_out = []

for _ in range(gamma):
small_logits, small_cache = net(out, cache = small_cache, return_cache = True, return_early_exit_only = True)
small_logits, small_cache = net(
out,
cache = small_cache,
return_cache = True,
return_early_exit_only = True,
seq_start_pos = out.shape[-1] - seq_lens
)

small_logits = small_logits[:, -1]

small_logits = top_k(small_logits, thres = filter_thres)
all_small_logits.append(small_logits)

sample = gumbel_sample(small_logits, temperature = temperature, dim = -1)
out = torch.cat((out, sample[..., None]), dim = -1)
seq_lens += 1

q_sampled_out.append(rearrange(sample, 'b -> b 1 1'))

@@ -297,7 +309,14 @@

# verify with larger network

logits, cache = net(out, cache = cache, early_exit_cache = small_cache, return_cache = True, start_from_early_exit_hiddens = True)
logits, cache = net(
out,
cache = cache,
early_exit_cache = small_cache,
return_cache = True,
start_from_early_exit_hiddens = True,
seq_start_pos = out.shape[-1] - seq_lens
)

logits = logits[..., -(gamma + 1):, :]
logits = top_k(logits, thres = filter_thres)
@@ -317,27 +336,69 @@
r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1)

accepted = find_first_true_index(r > (p / q))
n = accepted.item() # need to handle batched spec decoding

total_accepted += n
total_accepted += accepted.float().mean()
num_steps += 1

if n < gamma:
adjusted_prob = F.relu(prob[:, n] - small_prob[:, n])
prob_next = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
out = out[:, :-(gamma - n)]
num_rejected = gamma - accepted
has_rejected = num_rejected > 0

# adjust cache
accepted = rearrange(accepted, 'b -> b 1')
adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted])
adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d')

next_seq_len = out.shape[-1]
cache = tuple(t[..., :next_seq_len, :] for t in cache)
small_cache = tuple(t[..., :next_seq_len, :] for t in small_cache)
prob_next = torch.where(
rearrange(has_rejected, '... -> ... 1'),
adjusted_prob,
prob_next
)

# do a bunch of slicing and align everything to the right, including kv caches

# sample the additional token
max_num_rejected = num_rejected.amax()
seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]

seq_lens -= num_rejected
max_seq_len = seq_lens.amax()

if batch > 1:
out = F.pad(out, (0, max_num_rejected), value = pad_id)
out = out[batch_range, seq_offset_indices]

cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache)

cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache)

cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache)

cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)
small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)

if out.shape[-1] > max_seq_len:
out = out[:, -max_seq_len:]
cache = tuple(t[..., -max_seq_len:, :] for t in cache)
small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache)

# sample the additional token, one of the tricks in the paper to better bound the worst case

next_token = torch.multinomial(prob_next, 1)

out = torch.cat((out, next_token), dim = -1)
seq_lens += 1

# now left align

num_pad_left = out.shape[-1] - seq_lens
max_pad_left = num_pad_left.amax()
out = F.pad(out, (0, max_pad_left), value = pad_id)

seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long)
out = out[batch_range, seq_len_range + num_pad_left[..., None]]

return out[..., prompt_seq_len:], total_accepted / num_steps

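The "align everything to the right" and "now left align" steps above are what let each batch row advance at its own rate. Below is a simplified, self-contained sketch of that pad-and-gather bookkeeping for the token tensor alone; the helper names are hypothetical, details such as which side gets padded may differ from the committed code, and the commit applies the same gather to every kv-cache tensor along its sequence dimension.

import torch
import torch.nn.functional as F

# Editorial illustration with hypothetical helpers; sequences are kept
# right-aligned (left padded) while decoding and only left-aligned at the end.

def drop_rejected_and_realign(out, seq_lens, num_rejected, pad_id = 0):
    # out:          (batch, width) token ids, valid tokens right-aligned
    # seq_lens:     (batch,) number of valid tokens per row
    # num_rejected: (batch,) trailing drafted tokens to discard per row
    batch, width = out.shape
    device = out.device
    batch_range = torch.arange(batch, device = device)[:, None]

    max_num_rejected = int(num_rejected.amax())

    # pad on the left, then shift each row right by the number of tokens it
    # rejected, so the surviving tokens end flush with the right edge again
    padded = F.pad(out, (max_num_rejected, 0), value = pad_id)
    positions = torch.arange(width, device = device) + (max_num_rejected - num_rejected)[:, None]
    out = padded[batch_range, positions]

    seq_lens = seq_lens - num_rejected

    # trim leading columns that no longer hold any row's valid tokens
    max_seq_len = int(seq_lens.amax())
    if out.shape[-1] > max_seq_len:
        out = out[:, -max_seq_len:]

    return out, seq_lens

def left_align(out, seq_lens, pad_id = 0):
    # undo the left padding so every row's valid tokens start at column 0
    batch, width = out.shape
    device = out.device
    batch_range = torch.arange(batch, device = device)[:, None]

    num_pad_left = width - seq_lens
    max_len = int(seq_lens.amax())

    padded = F.pad(out, (0, max_len), value = pad_id)  # room to gather past the right edge
    positions = torch.arange(max_len, device = device) + num_pad_left[:, None]
    return padded[batch_range, positions]

Keeping rows right-aligned during decoding means the freshly sampled token can always be appended with a plain torch.cat, while seq_start_pos = out.shape[-1] - seq_lens tells the network how much left padding each row carries.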
