generalize to any number of leading start tokens for prophet
lucidrains committed Oct 4, 2023
1 parent 5e8b036 commit 4b77f26
Showing 2 changed files with 34 additions and 18 deletions.
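At a high level: the prophet wrapper previously took exactly one trailing embedding from the main model as the prophet network's start token; this commit generalizes that to any num_leading_start_tokens >= 1. A minimal training sketch of the new argument, assuming the Decoder and ModelWithProphetWrapper names and constructor arguments from this repository's README (hyperparameters and the wrapper's return signature are illustrative assumptions):

import torch
from speculative_decoding.speculative_decoding_with_prophet import (
    Decoder,
    ModelWithProphetWrapper
)

# assumed hyperparameters, for illustration only
model = Decoder(num_tokens = 256, dim = 512, depth = 10)
prophet = Decoder(num_tokens = 256, dim = 512, depth = 2)

model_and_prophet = ModelWithProphetWrapper(
    model,
    prophet,
    prophet_train_length = 8,      # length of the windows the prophet is trained on
    num_leading_start_tokens = 2,  # generalized by this commit (previously effectively 1)
    detach_model_embed_for_prophet = False
)

seq = torch.randint(0, 256, (1, 1024))
total_loss, (main_loss, prophet_loss) = model_and_prophet(seq)  # return signature assumed
total_loss.backward()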
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'speculative-decoding',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'Speculative Decoding',
   author = 'Phil Wang',
50 changes: 33 additions & 17 deletions speculative_decoding/speculative_decoding_with_prophet.py
@@ -222,16 +222,19 @@ def forward(
     ):
         has_start_tokens = exists(start_tokens)

+        start_token_len = 0
+        if exists(start_tokens):
+            if start_tokens.ndim == 2:
+                start_tokens = rearrange(start_tokens, 'b d -> b 1 d')
+
+            start_token_len = start_tokens.shape[-2]
+
         if return_loss:
-            start_index = (1 if has_start_tokens else 0)
-            x, labels = x[:, start_index:-1], x[:, 1:]
+            x, labels = x[:, start_token_len:-1], x[:, 1:]

         x = self.token_emb(x)

         if exists(start_tokens):
-            if start_tokens.ndim == 2:
-                start_tokens = rearrange(start_tokens, 'b d -> b 1 d')
-
             x = torch.cat((start_tokens, x), dim = 1)

         # handle seq start pos offset
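Why the label slice changes from a fixed offset of 1 to start_token_len: the leading start tokens are injected as embeddings in front of the token embeddings, so one input id must be dropped per start token to keep the model input and the labels the same length. A standalone sketch of the alignment (illustrative numbers, not repository code):

import torch

x = torch.arange(6)[None, :]      # token ids [[0, 1, 2, 3, 4, 5]]
start_token_len = 2               # two leading start-token embeddings

inp = x[:, start_token_len:-1]    # ids [[2, 3, 4]] -> 3 positions after embedding
labels = x[:, 1:]                 # targets [[1, 2, 3, 4, 5]] -> 5 positions

# after prepending the 2 start-token embeddings s0, s1 the input becomes
# [s0, s1, emb(2), emb(3), emb(4)] -> 5 positions, one per label
assert start_token_len + inp.shape[1] == labels.shape[1]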
@@ -311,11 +314,14 @@ def __init__(
         model_prophet_same_dim = model.dim == prophet.dim
         self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Linear(model.dim, prophet.dim, bias = False)

+        assert num_leading_start_tokens >= 1
+        self.num_leading_start_tokens = num_leading_start_tokens

         self.prophet_train_length = prophet_train_length
         self.detach_model_embed_for_prophet = detach_model_embed_for_prophet

     def forward(self, x):
+        num_start_tokens = self.num_leading_start_tokens
         batch, seq_len, device = *x.shape, x.device
         prophet_seq_len = self.prophet_train_length
         assert seq_len >= prophet_seq_len
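For context, the to_prophet_start_token module above is what later converts cached main-model embeddings into prophet start tokens: the identity when both networks share a dimension, a bias-free linear projection otherwise. A standalone sketch of that bridging step, now applied to several trailing embeddings at once (dimensions illustrative):

import torch
from torch import nn

model_dim, prophet_dim = 512, 256

to_prophet_start_token = (
    nn.Identity() if model_dim == prophet_dim
    else nn.Linear(model_dim, prophet_dim, bias = False)
)

# cached embeddings from the main model for the last 2 positions
embeds = torch.randn(1, 2, model_dim)
start_tokens = to_prophet_start_token(embeds)
assert start_tokens.shape == (1, 2, prophet_dim)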
@@ -334,17 +340,25 @@ def forward(self, x):
         batch_arange = torch.arange(batch, device = device, dtype = torch.long)
         prophet_seq_arange = torch.arange(prophet_seq_len, device = device, dtype = torch.long)

-        num_seq_train_prophet = seq_len - prophet_seq_len
+        num_seq_train_prophet = seq_len - prophet_seq_len - (num_start_tokens - 1)

         offsets = torch.arange(num_seq_train_prophet, device = device, dtype = torch.long)

         prophet_input = x[
-            batch_arange[..., None, None],
-            prophet_seq_arange + offsets[..., None]
+            batch_arange[:, None, None],
+            offsets[..., None] + prophet_seq_arange
         ]

         prophet_input = rearrange(prophet_input, '... n -> (...) n')

-        prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n d -> (b n) d')
+        start_tokens_arange = torch.arange(num_start_tokens, device = device, dtype = torch.long)
+
+        prophet_start_tokens = prophet_start_tokens[
+            batch_arange[:, None, None],
+            offsets[..., None] + start_tokens_arange
+        ]
+
+        prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n l d -> (b n) l d')

         prophet_loss, _ = self.prophet(prophet_input, start_tokens = prophet_start_tokens, return_loss = True)
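The windows above are gathered with broadcasted advanced indexing: batch_arange[:, None, None] broadcasts against offsets[..., None] + prophet_seq_arange, which enumerates every window start, and num_seq_train_prophet shrinks by num_start_tokens - 1 because each window now needs that many extra leading embeddings. The same indexing pattern in isolation (illustrative shapes, not repository code):

import torch

batch, seq_len = 2, 10
prophet_seq_len, num_start_tokens = 4, 2
x = torch.arange(batch * seq_len).reshape(batch, seq_len)

num_windows = seq_len - prophet_seq_len - (num_start_tokens - 1)  # 10 - 4 - 1 = 5

batch_arange = torch.arange(batch)[:, None, None]  # (2, 1, 1)
offsets = torch.arange(num_windows)[:, None]       # (5, 1)
window = torch.arange(prophet_seq_len)             # (4,)

windows = x[batch_arange, offsets + window]        # broadcasts to (2, 5, 4)
# windows[0] = [[0,1,2,3], [1,2,3,4], [2,3,4,5], [3,4,5,6], [4,5,6,7]]
assert windows.shape == (batch, num_windows, prophet_seq_len)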

@@ -380,6 +394,7 @@ def speculative_decoding_with_prophet_model(
     model = net.model
     to_prophet_start_token = net.to_prophet_start_token
     prophet = net.prophet
+    num_start_tokens = net.num_leading_start_tokens

     batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device

@@ -397,17 +412,18 @@

     # sample the first token from the main model

-    logits, cache = model(out, return_cache = True)
-    logits = logits[:, -1:]
-    logits = top_k(logits, thres = filter_thres)
-    sample = gumbel_sample(logits, temperature = temperature, dim = -1)
-    out = torch.cat((out, sample), dim = -1)
-    seq_lens += 1
+    for _ in range(num_start_tokens):
+        logits, cache = model(out, return_cache = True)
+        logits = logits[:, -1:]
+        logits = top_k(logits, thres = filter_thres)
+        sample = gumbel_sample(logits, temperature = temperature, dim = -1)
+        out = torch.cat((out, sample), dim = -1)
+        seq_lens += 1

     # now we have the first cached embedding to use as the prophet network start token for the speculative sampling

     _, embeds = cache
-    next_prophet_start_tokens = to_prophet_start_token(embeds[:, -1])
+    next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:])

     while (seq_lens < seq_len).any():
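The bootstrap now samples num_start_tokens tokens from the main model instead of one, so the cache is guaranteed to hold enough trailing embeddings: embeds[:, -num_start_tokens:] keeps shape (batch, num_start_tokens, dim), where the old embeds[:, -1] collapsed to (batch, dim). The slice in isolation (shapes illustrative):

import torch

num_start_tokens, dim = 3, 512

embeds = torch.randn(1, 16, dim)  # stand-in for the cached per-position embeddings

# previously: embeds[:, -1] -> (1, 512), a single start token
next_prophet_start_tokens = embeds[:, -num_start_tokens:]  # (1, 3, 512)
assert next_prophet_start_tokens.shape == (1, num_start_tokens, dim)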

@@ -521,7 +537,7 @@ def speculative_decoding_with_prophet_model(
         seq_lens += 1

         _, embeds = cache
-        next_prophet_start_tokens = to_prophet_start_token(embeds[:, -1])
+        next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:])

     # now left align

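For completeness, decoding with the updated wrapper might look like the following; the function name and its first two arguments appear in the diff above, while gamma, the keyword defaults, and the return values are assumptions modeled on the repository's other speculative decoding routines:

import torch
from speculative_decoding.speculative_decoding_with_prophet import (
    speculative_decoding_with_prophet_model
)

# model_and_prophet is the trained wrapper from the sketch near the top
prompt = torch.randint(0, 256, (1, 4))

sampled, accept_rate = speculative_decoding_with_prophet_model(  # return values assumed
    model_and_prophet,
    prompt,
    seq_len = 64,      # decode out to this total length
    gamma = 4,         # assumed: number of tokens drafted per speculative step
    temperature = 1.,  # assumed default
    filter_thres = 0.9
)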
