fix caching in speculative decoding with same model
lucidrains committed Sep 20, 2023
1 parent 24e90c3 commit 2ac580e
Showing 2 changed files with 43 additions and 25 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'speculative-decoding',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'Speculative Decoding',
   author = 'Phil Wang',
speculative_decoding/speculative_decoding.py (66 changes: 42 additions & 24 deletions)

@@ -191,6 +191,7 @@ def speculative_decoding_with_same_model(
     assert prompt.shape[0] == 1, 'batched spec decoding not supported yet'
 
     cache = None
+    small_cache = None
 
     num_steps = 0
     total_accepted = 0
@@ -203,7 +204,7 @@
         q_sampled_out = []
 
         for _ in range(gamma):
-            small_logits, cache = net(out, cache = cache, return_cache = True, return_early_exit_only = True)
+            small_logits, small_cache = net(out, cache = small_cache, return_cache = True, return_early_exit_only = True)
             small_logits = small_logits[:, -1]
 
             small_logits = top_k(small_logits, thres = filter_thres)
@@ -219,7 +220,7 @@

         # verify with larger network
 
-        logits, cache = net(out, cache = cache, return_cache = True, start_from_early_exit_hiddens = True)
+        logits, cache = net(out, cache = cache, early_exit_cache = small_cache, return_cache = True, start_from_early_exit_hiddens = True)
 
         logits = logits[..., -(gamma + 1):, :]
         logits = top_k(logits, thres = filter_thres)
@@ -253,6 +254,7 @@

         next_seq_len = out.shape[-1]
         cache = tuple(t[..., :next_seq_len, :] for t in cache)
+        small_cache = tuple(t[..., :next_seq_len, :] for t in small_cache)
 
         # sample the additional token

@@ -352,9 +354,9 @@ def __init__(
         heads = 8,
         dim_head = 64,
         ff_mult = 4,
-        weight_tie_layers = False,
         ignore_index = -1,
         early_exit_layer = None,
+        early_exit_extra_transformer_blocks = 0,
         detach_early_exit_hiddens = False
     ):
         super().__init__()
@@ -364,16 +366,11 @@ def __init__(

         rotary_emb = RotaryEmbedding(dim = dim_head)
 
-        attn = None
-        ff = None
-
         for _ in range(depth):
-
-            if not weight_tie_layers or not (exists(attn) and exists(ff)):
-                attn = CausalAttention(dim = dim, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb)
-                ff = FeedForward(dim = dim, mult = ff_mult)
-
-            self.layers.append(ModuleList([attn, ff]))
+            self.layers.append(ModuleList([
+                CausalAttention(dim = dim, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
 
         self.to_logits = nn.Sequential(
             RMSNorm(dim),
@@ -383,8 +380,15 @@ def __init__(
         self.detach_early_exit_hiddens = detach_early_exit_hiddens
         self.early_exit_layer = early_exit_layer
         self.to_early_exit_logits = None
+        self.early_exit_transformer_blocks = ModuleList([])
 
         if exists(early_exit_layer):
+            for _ in range(early_exit_extra_transformer_blocks):
+                self.early_exit_transformer_blocks.append(ModuleList([
+                    CausalAttention(dim = dim, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
+                    FeedForward(dim = dim, mult = ff_mult)
+                ]))
+
             self.to_early_exit_logits = nn.Sequential(
                 RMSNorm(dim),
                 nn.Linear(dim, num_tokens, bias = False)
@@ -398,6 +402,7 @@ def forward(
         return_loss = False,
         return_cache = False,
         cache = None,
+        early_exit_cache = None,
         return_early_exit_only = False,
         start_from_early_exit_hiddens = False
     ):
@@ -406,44 +411,40 @@ def forward(

         x = self.token_emb(x)
 
-        # next cache
+        # setup cache
 
         new_cached_kvs = []
 
-        # if cache passed in, just use the last token
-
         cache_kvs = cache_embeds = None
 
-        if exists(cache) :
+        if exists(cache):
             cache_kvs, cache_embeds = cache
-            if not start_from_early_exit_hiddens:
-                assert not self.training
-                num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2]
-                x = x[:, -num_tokens_keep:]
 
         cache_kvs = default(cache_kvs, [])
         iter_cache_kvs = iter(cache_kvs)
 
-        early_exit_hiddens = None
-
         # handle if previous cached embedding layer from early exit layer passed in
 
         layers = self.layers
 
         if start_from_early_exit_hiddens:
-            assert not return_early_exit_only and exists(cache_embeds)
+            assert not return_early_exit_only and exists(early_exit_cache)
             early_exit_layer_index = self.early_exit_layer
+
+            early_cache_kvs, cache_embeds = early_exit_cache
 
             cache_embeds_len = cache_embeds.shape[-2]
 
             assert cache_embeds_len <= x.shape[-2]
 
             early_exit_layers, layers = layers[:early_exit_layer_index], layers[early_exit_layer_index:]
             x = x[:, cache_embeds_len:]
 
+            iter_early_cache_kvs = iter(early_cache_kvs)
+
             for ind, (attn, ff) in enumerate(early_exit_layers):
                 residual = x
-                attn_out, cached_kv = attn(x, cache = next(iter_cache_kvs, None))
+                attn_out, cached_kv = attn(x, cache = next(iter_early_cache_kvs, None))
                 x = residual + attn_out
 
                 new_cached_kvs.append(cached_kv)
@@ -452,6 +453,14 @@ def forward

         x = torch.cat((cache_embeds, x), dim = -2)
 
+        # if cache passed in, just use the last token
+
+        if exists(cache) :
+            num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2]
+            x = x[:, -num_tokens_keep:]
+
+        early_exit_hiddens = None
+
         # main transformer body
 
         for ind, (attn, ff) in enumerate(layers):
@@ -471,6 +480,15 @@ def forward(
                 if self.detach_early_exit_hiddens:
                     early_exit_hiddens = early_exit_hiddens.detach()
 
+                for early_exit_attn, early_exit_ff in self.early_exit_transformer_blocks:
+                    residual = x
+                    attn_out, cached_kv = early_exit_attn(x, cache = next(iter_cache_kvs, None))
+                    x = residual + attn_out
+
+                    new_cached_kvs.append(cached_kv)
+
+                    x = early_exit_ff(x) + x
+
                 if return_early_exit_only:
                     break

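The heart of the fix is in the first two hunks: before this commit, the draft (early-exit) pass and the verification pass wrote their KV entries into the same `cache` variable, even though they run different depths of the same network. Below is a toy stand-in (not this repository's API) showing how a shared cache ends up holding entries from the wrong pass:

    import torch

    def cached_len(cache):
        return 0 if cache is None else cache.shape[-2]

    def fake_pass(tokens, cache, tag):
        # append one "kv entry" per token not yet covered by the cache,
        # tagged by which network supposedly computed it
        new = tokens.shape[-1] - cached_len(cache)
        entries = torch.full((1, new, 1), tag)
        return entries if cache is None else torch.cat((cache, entries), dim = -2)

    tokens = torch.arange(6).unsqueeze(0)

    # buggy: one shared cache mixes draft (0.) and verifier (1.) entries
    shared = fake_pass(tokens[:, :3], None, tag = 0.)  # draft writes 3 entries
    shared = fake_pass(tokens, shared, tag = 1.)       # verifier treats them as "already seen"
    print(shared.flatten())                            # tensor([0., 0., 0., 1., 1., 1.])

    # fixed: each pass owns its cache, so every entry came from the right network
    small_cache = fake_pass(tokens[:, :3], None, tag = 0.)
    cache = fake_pass(tokens, None, tag = 1.)
    print(cache.flatten())                             # tensor([1., 1., 1., 1., 1., 1.])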

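Putting the pieces together, here is a self-contained sketch of the decode-loop bookkeeping the commit establishes. `stub_net` is a stand-in for the repository's `net` (the real signature differs), and the rejection step is faked with a random draw; what matters is that the draft pass threads `small_cache`, the verify pass threads `cache`, and both get trimmed to the accepted length each round:

    import torch

    def stub_net(tokens, cache, full):
        # stand-in forward: a "KV cache" is one (1, seen, dim) tensor per layer
        n_layers, dim, num_tokens = (4 if full else 2), 8, 256
        seen = 0 if cache is None else cache[0].shape[-2]
        new = tokens.shape[-1] - seen
        if cache is None:
            cache = tuple(torch.randn(1, new, dim) for _ in range(n_layers))
        else:
            cache = tuple(torch.cat((t, torch.randn(1, new, dim)), dim = -2) for t in cache)
        return torch.randn(1, tokens.shape[-1], num_tokens), cache

    out = torch.randint(0, 256, (1, 4))
    cache = small_cache = None   # the fix: one cache per pass, never shared
    gamma = 3

    while out.shape[-1] < 16:
        for _ in range(gamma):   # draft gamma tokens with the small (early-exit) pass
            small_logits, small_cache = stub_net(out, small_cache, full = False)
            next_token = small_logits[:, -1].argmax(dim = -1, keepdim = True)
            out = torch.cat((out, next_token), dim = -1)

        # verify all drafted tokens in one full pass
        logits, cache = stub_net(out, cache, full = True)

        # fake rejection sampling: drop a random suffix of the drafted tokens
        n_rejected = int(torch.randint(0, gamma + 1, ()))
        out = out[:, :out.shape[-1] - n_rejected]

        # trim BOTH caches back to the accepted length (the second trim is what this commit adds)
        next_seq_len = out.shape[-1]
        cache = tuple(t[..., :next_seq_len, :] for t in cache)
        small_cache = tuple(t[..., :next_seq_len, :] for t in small_cache)

        # sample the extra token the verifier guarantees each round
        next_token = logits[:, next_seq_len - 1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = -1)

    print(out.shape)   # at least 16 accepted tokens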
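Finally, a sketch of opting into the new hyperparameter. The import path and the exact set of constructor arguments are assumptions inferred from this diff, not verified against the package's `__init__`:

    import torch
    from speculative_decoding.speculative_decoding import Decoder

    model = Decoder(
        num_tokens = 256,
        dim = 512,
        depth = 10,
        heads = 8,
        dim_head = 64,
        early_exit_layer = 4,                      # draft from the 4th layer's hiddens
        early_exit_extra_transformer_blocks = 2,   # new: give the draft path 2 private blocks
    )

    # the draft pass now runs 4 shared layers plus its 2 extra blocks
    small_logits, small_cache = model(
        torch.randint(0, 256, (1, 8)),
        return_cache = True,
        return_early_exit_only = True
    )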