diff --git a/setup.py b/setup.py index 41bc8c0..28365a4 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'speculative-decoding', packages = find_packages(exclude=[]), - version = '0.1.1', + version = '0.1.2', license='MIT', description = 'Speculative Decoding', author = 'Phil Wang', diff --git a/speculative_decoding/speculative_decoding_with_prophet.py b/speculative_decoding/speculative_decoding_with_prophet.py index e1669d1..e1ec08c 100644 --- a/speculative_decoding/speculative_decoding_with_prophet.py +++ b/speculative_decoding/speculative_decoding_with_prophet.py @@ -222,16 +222,19 @@ def forward( ): has_start_tokens = exists(start_tokens) + start_token_len = 0 + if exists(start_tokens): + if start_tokens.ndim == 2: + start_tokens = rearrange(start_tokens, 'b d -> b 1 d') + + start_token_len = start_tokens.shape[-2] + if return_loss: - start_index = (1 if has_start_tokens else 0) - x, labels = x[:, start_index:-1], x[:, 1:] + x, labels = x[:, start_token_len:-1], x[:, 1:] x = self.token_emb(x) if exists(start_tokens): - if start_tokens.ndim == 2: - start_tokens = rearrange(start_tokens, 'b d -> b 1 d') - x = torch.cat((start_tokens, x), dim = 1) # handle seq start pos offset @@ -311,11 +314,14 @@ def __init__( model_prophet_same_dim = model.dim == prophet.dim self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Linear(model.dim, prophet.dim, bias = False) + assert num_leading_start_tokens >= 1 self.num_leading_start_tokens = num_leading_start_tokens + self.prophet_train_length = prophet_train_length self.detach_model_embed_for_prophet = detach_model_embed_for_prophet def forward(self, x): + num_start_tokens = self.num_leading_start_tokens batch, seq_len, device = *x.shape, x.device prophet_seq_len = self.prophet_train_length assert seq_len >= prophet_seq_len @@ -334,17 +340,25 @@ def forward(self, x): batch_arange = torch.arange(batch, device = device, dtype = torch.long) prophet_seq_arange = torch.arange(prophet_seq_len, device = device, dtype = torch.long) - num_seq_train_prophet = seq_len - prophet_seq_len + num_seq_train_prophet = seq_len - prophet_seq_len - (num_start_tokens - 1) + offsets = torch.arange(num_seq_train_prophet, device = device, dtype = torch.long) prophet_input = x[ - batch_arange[..., None, None], - prophet_seq_arange + offsets[..., None] + batch_arange[:, None, None], + offsets[..., None] + prophet_seq_arange ] prophet_input = rearrange(prophet_input, '... n -> (...) n') - prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n d -> (b n) d') + start_tokens_arange = torch.arange(num_start_tokens, device = device, dtype = torch.long) + + prophet_start_tokens = prophet_start_tokens[ + batch_arange[:, None, None], + offsets[..., None] + start_tokens_arange + ] + + prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n l d -> (b n) l d') prophet_loss, _ = self.prophet(prophet_input, start_tokens = prophet_start_tokens, return_loss = True) @@ -380,6 +394,7 @@ def speculative_decoding_with_prophet_model( model = net.model to_prophet_start_token = net.to_prophet_start_token prophet = net.prophet + num_start_tokens = net.num_leading_start_tokens batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device @@ -397,17 +412,18 @@ def speculative_decoding_with_prophet_model( # sample the first token from the main model - logits, cache = model(out, return_cache = True) - logits = logits[:, -1:] - logits = top_k(logits, thres = filter_thres) - sample = gumbel_sample(logits, temperature = temperature, dim = -1) - out = torch.cat((out, sample), dim = -1) - seq_lens += 1 + for _ in range(num_start_tokens): + logits, cache = model(out, return_cache = True) + logits = logits[:, -1:] + logits = top_k(logits, thres = filter_thres) + sample = gumbel_sample(logits, temperature = temperature, dim = -1) + out = torch.cat((out, sample), dim = -1) + seq_lens += 1 # now we have the first cached embedding to use as the prophet network start token for the speculative sampling _, embeds = cache - next_prophet_start_tokens = to_prophet_start_token(embeds[:, -1]) + next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:]) while (seq_lens < seq_len).any(): @@ -521,7 +537,7 @@ def speculative_decoding_with_prophet_model( seq_lens += 1 _, embeds = cache - next_prophet_start_tokens = to_prophet_start_token(embeds[:, -1]) + next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:]) # now left align