From 2ac580eb9bde327c98fc0206a52a060874894e7f Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 20 Sep 2023 18:38:12 +0200
Subject: [PATCH] fix caching in speculative decoding with same model

---
 setup.py                                     |  2 +-
 speculative_decoding/speculative_decoding.py | 66 +++++++++++++-------
 2 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/setup.py b/setup.py
index 484cff0..6d8f923 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'speculative-decoding',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'Speculative Decoding',
   author = 'Phil Wang',
diff --git a/speculative_decoding/speculative_decoding.py b/speculative_decoding/speculative_decoding.py
index ff70c0b..2c716cf 100644
--- a/speculative_decoding/speculative_decoding.py
+++ b/speculative_decoding/speculative_decoding.py
@@ -191,6 +191,7 @@ def speculative_decoding_with_same_model(
     assert prompt.shape[0] == 1, 'batched spec decoding not supported yet'
 
     cache = None
+    small_cache = None
 
     num_steps = 0
     total_accepted = 0
@@ -203,7 +204,7 @@ def speculative_decoding_with_same_model(
         q_sampled_out = []
 
         for _ in range(gamma):
-            small_logits, cache = net(out, cache = cache, return_cache = True, return_early_exit_only = True)
+            small_logits, small_cache = net(out, cache = small_cache, return_cache = True, return_early_exit_only = True)
             small_logits = small_logits[:, -1]
 
             small_logits = top_k(small_logits, thres = filter_thres)
@@ -219,7 +220,7 @@ def speculative_decoding_with_same_model(
 
         # verify with larger network
 
-        logits, cache = net(out, cache = cache, return_cache = True, start_from_early_exit_hiddens = True)
+        logits, cache = net(out, cache = cache, early_exit_cache = small_cache, return_cache = True, start_from_early_exit_hiddens = True)
 
         logits = logits[..., -(gamma + 1):, :]
         logits = top_k(logits, thres = filter_thres)
@@ -253,6 +254,7 @@ def speculative_decoding_with_same_model(
             next_seq_len = out.shape[-1]
 
             cache = tuple(t[..., :next_seq_len, :] for t in cache)
+            small_cache = tuple(t[..., :next_seq_len, :] for t in small_cache)
 
         # sample the additional token
 
@@ -352,9 +354,9 @@ def __init__(
         heads = 8,
         dim_head = 64,
         ff_mult = 4,
-        weight_tie_layers = False,
         ignore_index = -1,
         early_exit_layer = None,
+        early_exit_extra_transformer_blocks = 0,
         detach_early_exit_hiddens = False
     ):
         super().__init__()
@@ -364,16 +366,11 @@ def __init__(
 
         rotary_emb = RotaryEmbedding(dim = dim_head)
 
-        attn = None
-        ff = None
-
         for _ in range(depth):
-
-            if not weight_tie_layers or not (exists(attn) and exists(ff)):
-                attn = CausalAttention(dim = dim, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb)
-                ff = FeedForward(dim = dim, mult = ff_mult)
-
-            self.layers.append(ModuleList([attn, ff]))
+            self.layers.append(ModuleList([
+                CausalAttention(dim = dim, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
 
         self.to_logits = nn.Sequential(
             RMSNorm(dim),
@@ -383,8 +380,15 @@ def __init__(
         self.detach_early_exit_hiddens = detach_early_exit_hiddens
         self.early_exit_layer = early_exit_layer
         self.to_early_exit_logits = None
+        self.early_exit_transformer_blocks = ModuleList([])
 
         if exists(early_exit_layer):
+            for _ in range(early_exit_extra_transformer_blocks):
+                self.early_exit_transformer_blocks.append(ModuleList([
+                    CausalAttention(dim = dim, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
+                    FeedForward(dim = dim, mult = ff_mult)
+                ]))
+
             self.to_early_exit_logits = nn.Sequential(
                 RMSNorm(dim),
                 nn.Linear(dim, num_tokens, bias = False)
@@ -398,6 +402,7 @@ def forward(
         return_loss = False,
         return_cache = False,
         cache = None,
+        early_exit_cache = None,
         return_early_exit_only = False,
         start_from_early_exit_hiddens = False
     ):
@@ -406,34 +411,28 @@ def forward(
 
         x = self.token_emb(x)
 
-        # next cache
+        # setup cache
 
         new_cached_kvs = []
 
-        # if cache passed in, just use the last token
-
         cache_kvs = cache_embeds = None
 
-        if exists(cache) :
+        if exists(cache):
             cache_kvs, cache_embeds = cache
-            if not start_from_early_exit_hiddens:
-                assert not self.training
-                num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2]
-                x = x[:, -num_tokens_keep:]
 
         cache_kvs = default(cache_kvs, [])
         iter_cache_kvs = iter(cache_kvs)
 
-        early_exit_hiddens = None
-
         # handle if previous cached embedding layer from early exit layer passed in
 
         layers = self.layers
 
         if start_from_early_exit_hiddens:
-            assert not return_early_exit_only and exists(cache_embeds)
+            assert not return_early_exit_only and exists(early_exit_cache)
             early_exit_layer_index = self.early_exit_layer
 
+            early_cache_kvs, cache_embeds = early_exit_cache
+
             cache_embeds_len = cache_embeds.shape[-2]
 
             assert cache_embeds_len <= x.shape[-2]
@@ -441,9 +440,11 @@ def forward(
             early_exit_layers, layers = layers[:early_exit_layer_index], layers[early_exit_layer_index:]
             x = x[:, cache_embeds_len:]
 
+            iter_early_cache_kvs = iter(early_cache_kvs)
+
             for ind, (attn, ff) in enumerate(early_exit_layers):
                 residual = x
-                attn_out, cached_kv = attn(x, cache = next(iter_cache_kvs, None))
+                attn_out, cached_kv = attn(x, cache = next(iter_early_cache_kvs, None))
                 x = residual + attn_out
 
                 new_cached_kvs.append(cached_kv)
@@ -452,6 +453,14 @@ def forward(
 
             x = torch.cat((cache_embeds, x), dim = -2)
 
+        # if cache passed in, just use the last token
+
+        if exists(cache) :
+            num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2]
+            x = x[:, -num_tokens_keep:]
+
+        early_exit_hiddens = None
+
         # main transformer body
 
         for ind, (attn, ff) in enumerate(layers):
@@ -471,6 +480,15 @@ def forward(
                 if self.detach_early_exit_hiddens:
                     early_exit_hiddens = early_exit_hiddens.detach()
 
+                for early_exit_attn, early_exit_ff in self.early_exit_transformer_blocks:
+                    residual = x
+                    attn_out, cached_kv = early_exit_attn(x, cache = next(iter_cache_kvs, None))
+                    x = residual + attn_out
+
+                    new_cached_kvs.append(cached_kv)
+
+                    x = early_exit_ff(x) + x
+
                 if return_early_exit_only:
                     break
 
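
Note (illustration only, not part of the patch): the fix keeps two separate key/value caches, `small_cache` for the early-exit draft passes and `cache` for the full-depth verification pass, and both must be truncated to the accepted length whenever the verifier rejects speculated tokens. The following standalone sketch shows only that cache bookkeeping; the tensor shapes, layer counts and the value of `next_seq_len` are made up for the example and are not taken from the repository.

    import torch

    # toy dimensions for the illustration
    batch, heads, dim_head = 1, 4, 8

    # pretend ten tokens have been decoded so far: the full model and the
    # early-exit draft head each hold their own per-layer key/value cache
    cache       = tuple(torch.randn(batch, heads, 10, dim_head) for _ in range(6))
    small_cache = tuple(torch.randn(batch, heads, 10, dim_head) for _ in range(2))

    # suppose verification rejected the last three speculated tokens,
    # so only seven positions of the sequence survive
    next_seq_len = 7

    # roll both caches back in lockstep, mirroring the two truncation lines in
    # the patch; trimming only one of them would leave the draft pass attending
    # to keys/values for tokens that no longer exist in the output sequence
    cache       = tuple(t[..., :next_seq_len, :] for t in cache)
    small_cache = tuple(t[..., :next_seq_len, :] for t in small_cache)

    assert all(t.shape[-2] == next_seq_len for t in (*cache, *small_cache))

In the patch itself, the draft loop now reads and writes `small_cache`, while the full-depth call keeps its own `cache` and receives the draft state separately through the new `early_exit_cache` argument, so the shallow and deep passes no longer share and clobber a single cache variable.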