moving on to prophet net / megabyte-esque architecture idea
lucidrains committed Sep 18, 2023
1 parent dea6376 commit 082371c
Showing 5 changed files with 128 additions and 16 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -8,8 +8,9 @@ Also have a few ideas of my own that I will try and share in this repository, if

## Todo

- [x] in early exit scheme, cache the hidden layer during spec decoding, as small and large models share the same first few layers

- [ ] figure out batched spec decoding - different rows may advance at different rates
- [ ] in early exit scheme, cache the hidden layer during spec decoding, as small and large models share the same first few layers
- [ ] for early exit, allow an extra transformer block head (separate from main transformer stem)
- [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the key / values)

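The last todo item above is where this commit is headed: a ProphetNet-style multi-token prediction head, structured like MEGABYTE's hierarchical transformer and fed from the large model's cached embeddings. Below is a minimal, hypothetical sketch of such a head; the module name, default shapes, and per-offset projection heads are assumptions for illustration, not this repository's API.

```python
import torch
from torch import nn

class ProphetHead(nn.Module):
    """Hypothetical MEGABYTE-style prophet head: a small transformer that reads the
    large model's cached embeddings and drafts the next few tokens in one pass."""

    def __init__(self, dim = 512, num_tokens = 256, depth = 2, heads = 8, prophet_len = 4):
        super().__init__()
        layer = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first = True)
        self.transformer = nn.TransformerEncoder(layer, depth)
        # one projection per future offset, loosely following ProphetNet's n-gram objective
        self.to_logits = nn.ModuleList([nn.Linear(dim, num_tokens) for _ in range(prophet_len)])

    def forward(self, cached_embeds):                  # (batch, seq, dim), taken from the large model's cache
        hiddens = self.transformer(cached_embeds)
        last = hiddens[:, -1]                          # draft from the final position
        return torch.stack([proj(last) for proj in self.to_logits], dim = 1)  # (batch, prophet_len, num_tokens)

# toy usage: draft logits for the next 4 tokens given 10 cached embeddings
draft_logits = ProphetHead()(torch.randn(1, 10, 512))  # -> (1, 4, 256)
```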
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'speculative-decoding',
packages = find_packages(exclude=[]),
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'Speculative Decoding',
author = 'Phil Wang',
speculative_decoding/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,6 @@
from speculative_decoding.speculative_decoding import (
Decoder,
base_decoding,
speculative_decoding
speculative_decoding,
speculative_decoding_with_same_model
)
speculative_decoding/speculative_decoding.py (132 changes: 121 additions & 11 deletions)
@@ -7,7 +7,6 @@

from rotary_embedding_torch import RotaryEmbedding
from beartype import beartype
from beartype.typing import Optional

from collections import namedtuple

@@ -161,8 +160,100 @@ def speculative_decoding(
# adjust cache

next_seq_len = out.shape[-1]
cache = cache[..., :next_seq_len, :]
small_cache = small_cache[..., :next_seq_len, :]
cache = tuple(t[..., :next_seq_len, :] for t in cache)
small_cache = tuple(t[..., :next_seq_len, :] for t in small_cache)

# sample the additional token

next_token = torch.multinomial(prob_next, 1)

out = torch.cat((out, next_token), dim = -1)

return out[..., prompt_seq_len:], total_accepted / num_steps

@torch.no_grad()
def speculative_decoding_with_same_model(
net: Module,
small_net: Module,
prompt: Tensor,
seq_len: int,
gamma: int = 5,
temperature = 1.,
filter_thres = 0.9,
lenience = 1.
):
"""
eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
"""

prompt_seq_len, out, device = prompt.shape[-1], prompt.clone(), prompt.device
sample_num_times = max(0, seq_len - prompt_seq_len)

assert prompt.shape[0] == 1, 'batched spec decoding not supported yet'

cache = None

num_steps = 0
total_accepted = 0

while out.shape[-1] < seq_len:

# predict with smaller network

all_small_logits = []
q_sampled_out = []

for _ in range(gamma):
small_logits, cache = small_net(out, cache = cache, return_cache = True)
small_logits = small_logits[:, -1]

small_logits = top_k(small_logits, thres = filter_thres)
all_small_logits.append(small_logits)

sample = gumbel_sample(small_logits, temperature = temperature, dim = -1)
out = torch.cat((out, sample[..., None]), dim = -1)

q_sampled_out.append(rearrange(sample, 'b -> b 1 1'))

q_sampled_out = torch.cat(q_sampled_out, dim = -2)
small_logits = torch.stack(all_small_logits, dim = -2)

# verify with larger network

logits, cache = net(out, cache = cache, return_cache = True, start_from_early_exit_hiddens = True)

logits = logits[..., -(gamma + 1):, :]
logits = top_k(logits, thres = filter_thres)

# prob and prob of small model (p(x) and q(x) in algorithm 1)

prob = safe_div(logits, temperature).softmax(dim = -1)
small_prob = safe_div(small_logits, temperature).softmax(dim = -1)

p, prob_next = prob[:, :-1], prob[:, -1]

p = p.gather(-1, q_sampled_out)
q = small_prob.gather(-1, q_sampled_out) * lenience

p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)]

r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1)

accepted = find_first_true_index(r > (p / q))
n = accepted.item() # need to handle batched spec decoding

total_accepted += n
num_steps += 1

if n < gamma:
adjusted_prob = F.relu(prob[:, n] - small_prob[:, n])
prob_next = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
out = out[:, :-(gamma - n)]

# adjust cache

next_seq_len = out.shape[-1]
cache = tuple(t[..., :next_seq_len, :] for t in cache)

# sample the additional token

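The accept/reject logic in the new speculative_decoding_with_same_model above follows the same modified rejection sampling as the two-model path (Algorithm 1 of https://arxiv.org/abs/2211.17192): each drafted token is kept with probability min(1, p/q), and on the first rejection the next token is redrawn from the normalized residual relu(p - q). A toy, self-contained illustration of that rule with made-up distributions:

```python
import torch
import torch.nn.functional as F

# toy verifier (p) and draft (q) distributions over a 4-token vocabulary
p = torch.tensor([0.10, 0.60, 0.20, 0.10])
q = torch.tensor([0.40, 0.30, 0.20, 0.10])

token = torch.multinomial(q, 1)                                   # draft proposes a token from q
accept = torch.rand(()) <= (p[token] / q[token]).clamp(max = 1.)  # keep it with probability min(1, p/q)

if not accept:
    residual = F.relu(p - q)                                      # rejected: resample from the
    token = torch.multinomial(residual / residual.sum(), 1)       # normalized residual distribution
```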
@@ -309,7 +400,7 @@ def forward(
return_cache = False,
cache = None,
return_early_exit_only = False,
start_from_early_exit_hiddens: Optional[Tensor] = None
start_from_early_exit_hiddens = False
):
if return_loss:
x, labels = x[:, :-1], x[:, 1:]
@@ -324,11 +415,12 @@

cache_kvs = cache_embeds = None

if exists(cache):
if exists(cache) :
cache_kvs, cache_embeds = cache
assert not self.training
num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2]
x = x[:, -num_tokens_keep:]
if not start_from_early_exit_hiddens:
assert not self.training
num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2]
x = x[:, -num_tokens_keep:]

cache_kvs = default(cache_kvs, [])
iter_cache_kvs = iter(cache_kvs)
@@ -338,10 +430,28 @@
# handle if previous cached embedding layer from early exit layer passed in

layers = self.layers

if start_from_early_exit_hiddens:
assert not return_early_exit_only and exists(self.early_exit_layer)
layers = layers[self.early_exit_layer - 1:]
x = start_from_early_exit_hiddens
assert not return_early_exit_only and exists(cache_embeds)
early_exit_layer_index = self.early_exit_layer

cache_embeds_len = cache_embeds.shape[-2]

assert cache_embeds_len <= x.shape[-2]

early_exit_layers, layers = layers[:early_exit_layer_index], layers[early_exit_layer_index:]
x = x[:, cache_embeds_len:]

for ind, (attn, ff) in enumerate(early_exit_layers):
residual = x
attn_out, cached_kv = attn(x, cache = next(iter_cache_kvs, None))
x = residual + attn_out

new_cached_kvs.append(cached_kv)

x = ff(x) + x

x = torch.cat((cache_embeds, x), dim = -2)

# main transformer body

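In the new start_from_early_exit_hiddens path above, only the tokens the draft pass has not already embedded are run through the shared early layers, and the cached early-exit embeddings are concatenated back in front before the remaining layers. A toy sketch of just that shape bookkeeping (batch size, lengths, and width are made up):

```python
import torch

batch, dim = 1, 512
cache_embeds = torch.randn(batch, 7, dim)        # hiddens the draft pass already produced at the early-exit layer
x = torch.randn(batch, 10, dim)                  # token embeddings for the full current sequence

fresh = x[:, cache_embeds.shape[-2]:]            # only the 3 newest tokens go through the shared early layers
# ... run `fresh` through layers[:early_exit_layer] here ...
x = torch.cat((cache_embeds, fresh), dim = -2)   # rejoin before the remaining, large-model-only layers
assert x.shape[1] == 10
```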
train_early_exit.py (4 changes: 2 additions & 2 deletions)
@@ -16,7 +16,7 @@
from speculative_decoding import (
Decoder,
base_decoding,
speculative_decoding
speculative_decoding_with_same_model
)

# constants
@@ -145,7 +145,7 @@ def __len__(self):
sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

small_model = partial(model, return_early_exit_only = True)
(spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding)(model,small_model, prompt, GENERATE_LENGTH, GAMMA)
(spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_same_model)(model,small_model, prompt, GENERATE_LENGTH, GAMMA)

base_decode_output = decode_tokens(sampled[0])
spec_decode_output = decode_tokens(spec_decode_sampled[0])
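For context, both decoding paths here are timed with the script's own benchmark helper, which is defined earlier in train_early_exit.py and not shown in this diff. Judging from the call sites it returns the wrapped function's output together with the elapsed time; the following is only a stand-in sketch assuming that behaviour:

```python
import time
from functools import wraps

def benchmark(fn):
    # stand-in only: returns (fn's output, elapsed seconds), matching how the call sites unpack it
    @wraps(fn)
    def inner(*args, **kwargs):
        start = time.perf_counter()
        out = fn(*args, **kwargs)
        return out, time.perf_counter() - start
    return inner
```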
