complete speculative sampling with prophet net on cached embeddings idea!
lucidrains committed Oct 3, 2023
1 parent 46408b0 commit f4e55c4
Showing 4 changed files with 193 additions and 175 deletions.
6 changes: 2 additions & 4 deletions README.md
@@ -17,11 +17,9 @@ Also have a few ideas of my own that I will try and share in this repository, if
- [x] figure out batched spec decoding - different rows may advance at different rates
- [x] further optimize batched spec decoding, as some performance is lost to all the indexing - seems like it will take some work for this technique to be actually usable
- [x] make batched spec decoding work with early exit strategy
- [x] complete speculative sampling with prophet transformer idea - seems to work well! 🙌

- [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the embeddings) - see the sketch after this list
- [x] complete prophet net with hierarchical transformer training
- [ ] complete the spec decoding algorithm using trained prophet net transformer

- [ ] get some wandb charts and see how prophet compares with early exit strategy, share in the repository
- [ ] for early exit strategy, try randomly summing the last cached embedding back into the same model (a la alphafold2 recycling), randomly cropped along the sequence length, and train the early exit loss this way. see if this improves gamma
- [ ] dedicate a morning to microoptimizations
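
A minimal sketch of the cached-embedding wiring described in the prophet net item above - all module names and dimensions here are illustrative assumptions, not this repository's API. The large model's last hidden state (the cached embedding) is projected down and used as the start token from which the small prophet head drafts `gamma` tokens:

```python
import torch
from torch import nn

num_tokens, large_dim, prophet_dim, gamma = 256, 512, 128, 5

# toy stand-ins - the real large model is a full decoder and the prophet is a small hierarchical transformer
large_trunk    = nn.Embedding(num_tokens, large_dim)
to_prophet     = nn.Linear(large_dim, prophet_dim)   # maps the cached embedding to the prophet's start token
prophet_embed  = nn.Embedding(num_tokens, prophet_dim)
prophet_logits = nn.Linear(prophet_dim, num_tokens)

prompt = torch.randint(0, num_tokens, (1, 16))

embeds = large_trunk(prompt)                 # (1, 16, large_dim) - the "cached embeddings"
hidden = to_prophet(embeds[:, -1])           # (1, prophet_dim) - prophet start token from the last position

drafted = []
for _ in range(gamma):                       # the prophet drafts gamma tokens autoregressively
    token = prophet_logits(hidden).argmax(dim = -1)   # greedy draft, just for this sketch
    drafted.append(token)
    hidden = prophet_embed(token)

drafted = torch.stack(drafted, dim = -1)     # (1, gamma) - handed to the large model for verification
```

The actual implementation further down in this diff drafts with sampling rather than argmax and verifies the drafted tokens against the large model.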

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'speculative-decoding',
packages = find_packages(exclude=[]),
version = '0.0.16',
version = '0.1.0',
license='MIT',
description = 'Speculative Decoding',
author = 'Phil Wang',
352 changes: 186 additions & 166 deletions speculative_decoding/speculative_decoding_with_prophet.py
@@ -92,171 +92,6 @@ def base_decoding(

return out[..., prompt_seq_len:]

# speculative decoding functions

def safe_div(num, den, eps = 1e-10):
return num / max(den, eps)

def find_first_true_index(bool_tensor, dim = -1):
return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)

@torch.no_grad()
def speculative_decoding_with_prophet_model(
net: Module,
prompt: Tensor,
seq_len: int,
gamma: int = 5,
temperature = 1.,
filter_thres = 0.9,
lenience = 1.,
pad_id = 0
):
"""
eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
"""

raise NotImplementedError

batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device
sample_num_times = max(0, seq_len - prompt_seq_len)

cache = None
small_cache = None

num_steps = 0
total_accepted = 0

batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None]
seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long)

while (seq_lens < seq_len).any():

# predict with smaller network

all_small_logits = []
q_sampled_out = []

for _ in range(gamma):
small_logits, small_cache = net(
out,
cache = small_cache,
return_cache = True,
return_early_exit_only = True,
seq_start_pos = out.shape[-1] - seq_lens
)

small_logits = small_logits[:, -1]

small_logits = top_k(small_logits, thres = filter_thres)
all_small_logits.append(small_logits)

sample = gumbel_sample(small_logits, temperature = temperature, dim = -1)
out = torch.cat((out, sample[..., None]), dim = -1)
seq_lens += 1

q_sampled_out.append(rearrange(sample, 'b -> b 1 1'))

q_sampled_out = torch.cat(q_sampled_out, dim = -2)
small_logits = torch.stack(all_small_logits, dim = -2)

# verify with larger network

logits, cache = net(
out,
cache = cache,
early_exit_cache = small_cache,
return_cache = True,
start_from_early_exit_hiddens = True,
seq_start_pos = out.shape[-1] - seq_lens
)

logits = logits[..., -(gamma + 1):, :]
logits = top_k(logits, thres = filter_thres)

# prob and prob of small model (p(x) and q(x) in algorithm 1)

prob = safe_div(logits, temperature).softmax(dim = -1)
small_prob = safe_div(small_logits, temperature).softmax(dim = -1)

p, prob_next = prob[:, :-1], prob[:, -1]

p = p.gather(-1, q_sampled_out)
q = small_prob.gather(-1, q_sampled_out) * lenience

p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)]

r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1)

accepted = find_first_true_index(r > (p / q))

total_accepted += accepted.float().mean()
num_steps += 1

num_rejected = gamma - accepted
has_rejected = num_rejected > 0

accepted = rearrange(accepted, 'b -> b 1')
accepted.clamp_(max = gamma - 1)

adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted])
adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d')

prob_next = torch.where(
rearrange(has_rejected, '... -> ... 1'),
adjusted_prob,
prob_next
)

# do a bunch of slicing and align everything to the right, including kv caches

max_num_rejected = num_rejected.amax()
seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]

seq_lens -= num_rejected
max_seq_len = seq_lens.amax()

if batch > 1:
out = F.pad(out, (0, max_num_rejected), value = pad_id)
out = out[batch_range, seq_offset_indices]

cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache)

cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache)

cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache)

cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)
small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)

if out.shape[-1] > max_seq_len:
left_index = out.shape[-1] - max_seq_len
out = out[:, left_index:]
cache = tuple(t[..., left_index:, :] for t in cache)
small_cache = tuple(t[..., left_index:, :] for t in small_cache)

# sample the additional token, one of the tricks in the paper to better bound the worst case

next_token = torch.multinomial(prob_next, 1)

out = torch.cat((out, next_token), dim = -1)
seq_lens += 1

# now left align

num_pad_left = out.shape[-1] - seq_lens
max_pad_left = num_pad_left.amax()
out = F.pad(out, (0, max_pad_left), value = pad_id)

seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long)
out = out[batch_range, seq_len_range + num_pad_left[..., None]]

return out[..., prompt_seq_len:], total_accepted / num_steps

# norm

class RMSNorm(Module):
@@ -394,6 +229,9 @@ def forward(
x = self.token_emb(x)

if exists(start_tokens):
if start_tokens.ndim == 2:
start_tokens = rearrange(start_tokens, 'b d -> b 1 d')

x = torch.cat((start_tokens, x), dim = 1)

# handle seq start pos offset
@@ -504,10 +342,192 @@ def forward(self, x):

prophet_input = rearrange(prophet_input, '... n -> (...) n')

prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n d -> (b n) 1 d')
prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n d -> (b n) d')

prophet_loss, _ = self.prophet(prophet_input, start_tokens = prophet_start_tokens, return_loss = True)

total_loss = total_loss + prophet_loss

return total_loss, (main_loss, prophet_loss)

# speculative decoding functions

def safe_div(num, den, eps = 1e-10):
return num / max(den, eps)

def find_first_true_index(bool_tensor, dim = -1):
# returns, per row, the number of leading False values along `dim` - i.e. the index of the first True, or the dim size if there is none
return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)

@torch.no_grad()
def speculative_decoding_with_prophet_model(
net: ModelWithProphetWrapper,
prompt: Tensor,
seq_len: int,
gamma: int = 5,
temperature = 1.,
filter_thres = 0.9,
lenience = 1.,
pad_id = 0
):
"""
eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
"""

# extract model, prophet, and model to prophet (if their model dimensions differ)

model = net.model
to_prophet_start_token = net.to_prophet_start_token
prophet = net.prophet

batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device

if (seq_len - prompt_seq_len) <= 0:
return prompt, None

cache = None
small_cache = None

num_steps = 0
total_accepted = 0

batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None]
seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long)

# sample the first token from the main model

logits, cache = model(out, return_cache = True)
logits = logits[:, -1:]
logits = top_k(logits, thres = filter_thres)
sample = gumbel_sample(logits, temperature = temperature, dim = -1)
out = torch.cat((out, sample), dim = -1)
seq_lens += 1

# now we have the first cached embedding to use as the prophet network start token for the speculative sampling

_, embeds = cache
next_prophet_start_tokens = to_prophet_start_token(embeds[:, -1])

while (seq_lens < seq_len).any():

# predict with smaller network

all_small_logits = []
q_sampled_out = []

small_cache = None
num_tokens = 2 # the main model embeddings are 1 step behind the main sequence

for _ in range(gamma):
small_logits, small_cache = prophet(
out[..., -num_tokens:],
start_tokens = next_prophet_start_tokens,
cache = small_cache,
return_cache = True
)

small_logits = small_logits[:, -1:]

small_logits = top_k(small_logits, thres = filter_thres)
all_small_logits.append(small_logits)

sample = gumbel_sample(small_logits, temperature = temperature, dim = -1)
out = torch.cat((out, sample), dim = -1)

seq_lens += 1
num_tokens += 1

q_sampled_out.append(rearrange(sample, '... -> ... 1'))

q_sampled_out = torch.cat(q_sampled_out, dim = -2)
small_logits = torch.cat(all_small_logits, dim = -2)

# verify with larger network

logits, cache = model(
out,
cache = cache,
return_cache = True,
seq_start_pos = out.shape[-1] - seq_lens
)

logits = logits[..., -(gamma + 1):, :]
logits = top_k(logits, thres = filter_thres)

# prob and prob of small model (p(x) and q(x) in algorithm 1)

prob = safe_div(logits, temperature).softmax(dim = -1)
small_prob = safe_div(small_logits, temperature).softmax(dim = -1)

p, prob_next = prob[:, :-1], prob[:, -1]

p = p.gather(-1, q_sampled_out)
q = small_prob.gather(-1, q_sampled_out) * lenience

p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)]

r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1)

accepted = find_first_true_index(r > (p / q))
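# a drafted token is accepted while u <= p / q (algorithm 1); find_first_true_index
# counts the leading accepts per row, i.e. how many drafted tokens are kept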

total_accepted += accepted.float().mean()
num_steps += 1

num_rejected = gamma - accepted
has_rejected = num_rejected > 0

accepted = rearrange(accepted, 'b -> b 1')
accepted.clamp_(max = gamma - 1)

adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted])
adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d')
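# adjusted_prob is the residual distribution norm(max(0, p - q)) from the paper,
# used to resample the token at the first rejected position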

prob_next = torch.where(
rearrange(has_rejected, '... -> ... 1'),
adjusted_prob,
prob_next
)

# do a bunch of slicing and align everything to the right, including kv caches

max_num_rejected = num_rejected.amax()
seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]

seq_lens -= num_rejected
max_seq_len = seq_lens.amax()

if batch > 1:
out = F.pad(out, (0, max_num_rejected), value = pad_id)
out = out[batch_range, seq_offset_indices]

cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)

if out.shape[-1] > max_seq_len:
left_index = out.shape[-1] - max_seq_len
out = out[:, left_index:]
cache = tuple(t[..., left_index:, :] for t in cache)

# sample the additional token, one of the tricks in the paper to better bound the worst case

next_token = torch.multinomial(prob_next, 1)

out = torch.cat((out, next_token), dim = -1)
seq_lens += 1

_, embeds = cache
next_prophet_start_tokens = to_prophet_start_token(embeds[:, -1])

# now left align

num_pad_left = out.shape[-1] - seq_lens
max_pad_left = num_pad_left.amax()
out = F.pad(out, (0, max_pad_left), value = pad_id)

seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long)
out = out[batch_range, seq_len_range + num_pad_left[..., None]]

return out[..., prompt_seq_len:], total_accepted / num_steps
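
For reference, a self-contained toy of the accept/reject arithmetic performed in the verification step above, with made-up probabilities and no dependence on the classes in this file - all names are illustrative:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
gamma, vocab = 5, 8

p_full  = torch.randn(1, gamma, vocab).softmax(dim = -1)   # large model probs at each drafted position
q_full  = torch.randn(1, gamma, vocab).softmax(dim = -1)   # prophet probs at each drafted position
drafted = torch.randint(0, vocab, (1, gamma, 1))            # tokens the prophet sampled

p = p_full.gather(-1, drafted).squeeze(-1)                   # (1, gamma)
q = q_full.gather(-1, drafted).squeeze(-1)

u = torch.rand_like(q)
num_accepted = ((u > (p / q)).cumsum(dim = -1) == 0).sum(dim = -1)   # same rule as find_first_true_index

# residual distribution at the first rejected position
idx = num_accepted.clamp(max = gamma - 1)
residual = F.relu(p_full[0, idx] - q_full[0, idx])
residual = residual / residual.sum(dim = -1, keepdim = True)

print(num_accepted.item(), residual.sum().item())            # residual sums to 1
```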