necessary setup for batched speculative decoding
lucidrains committed Sep 22, 2023
1 parent a011329 commit 689d459
Showing 2 changed files with 58 additions and 11 deletions.
setup.py: 3 changes (1 addition & 2 deletions)
@@ -3,7 +3,7 @@
setup(
name = 'speculative-decoding',
packages = find_packages(exclude=[]),
version = '0.0.7',
version = '0.0.8',
license='MIT',
description = 'Speculative Decoding',
author = 'Phil Wang',
@@ -19,7 +19,6 @@
install_requires=[
'beartype',
'einops>=0.6.1',
'rotary-embedding-torch>=0.3.0',
'torch>=1.12',
],
classifiers=[
speculative_decoding/speculative_decoding.py: 66 changes (57 additions & 9 deletions)
@@ -43,6 +43,35 @@ def top_k(logits, thres = 0.9):
probs.scatter_(-1, ind, val)
return probs

# rotary embeddings

class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)

def forward(self, seq_len, offset = None):
t = torch.arange(seq_len, device = self.inv_freq.device).type_as(self.inv_freq)
t = rearrange(t, 'n -> 1 n')

if exists(offset):
t = t + offset[..., None]

freqs = torch.einsum('b n , d -> b n d', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
return freqs

def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(pos, t):
seq_len = t.shape[-2]
pos = rearrange(pos, 'b n d -> b 1 n d')
pos = pos[..., -seq_len:, :]
return t * pos.cos() + rotate_half(t) * pos.sin()
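# An illustrative aside, not part of this commit: the helpers above compose as
#
#   rotary = RotaryEmbedding(dim = 64)
#   offset = torch.tensor([0, 3])          # optional per-sample position offset
#   pos    = rotary(10, offset = offset)   # (batch, seq, dim) rotation angles
#   q      = torch.randn(2, 8, 10, 64)     # (batch, heads, seq, dim_head)
#   q      = apply_rotary_pos_emb(pos, q)  # rotated queries, same shape
#
# apply_rotary_pos_emb broadcasts the angles over the head dimension and keeps
# only the last seq_len positions of pos, so a shorter query tensor (for
# example a single cached decoding step) picks up the most recent positions.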

# different decoding strategies

@torch.no_grad()
@@ -282,7 +311,6 @@ def __init__(
self,
dim,
*,
rotary_emb: RotaryEmbedding,
dim_head = 64,
heads = 8,
):
@@ -292,15 +320,16 @@ def __init__
dim_inner = dim_head * heads

self.norm = RMSNorm(dim)
self.rotary_emb = rotary_emb

self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
self.to_out = nn.Linear(dim_inner, dim, bias = False)

def forward(
self,
x,
cache = None
cache = None,
context_mask = None,
rotary_emb = None
):
h, device = self.heads, x.device

@@ -315,7 +344,9 @@ def forward(

cached_kv = torch.stack((k, v))

q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
if exists(rotary_emb):
q = apply_rotary_pos_emb(rotary_emb, q)
k = apply_rotary_pos_emb(rotary_emb, k)

sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

@@ -324,6 +355,10 @@

sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

if exists(context_mask):
context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -364,11 +399,11 @@ def __init__(

self.layers = ModuleList([])

rotary_emb = RotaryEmbedding(dim = dim_head)
self.rotary_emb = RotaryEmbedding(dim = dim_head)

for _ in range(depth):
self.layers.append(ModuleList([
CausalAttention(dim = dim, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
FeedForward(dim = dim, mult = ff_mult)
]))

@@ -401,6 +436,7 @@ def forward(
x,
return_loss = False,
return_cache = False,
seq_start_pos_offset = None,
cache = None,
early_exit_cache = None,
return_early_exit_only = False,
@@ -411,6 +447,18 @@

x = self.token_emb(x)

# handle seq start pos offset

self_attn_kv_mask = None
if exists(seq_start_pos_offset):
batch, seq_len = x.shape[:2]
seq_range = torch.arange(seq_len, device = x.device, dtype = torch.long)
self_attn_kv_mask = seq_range >= seq_start_pos_offset[..., None]

# relative positional encoding

rotary_emb = self.rotary_emb(x.shape[-2], offset = seq_start_pos_offset)
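# An illustrative aside, not part of this commit: one plausible use of
# seq_start_pos_offset is left padding, where the offset counts each sample's
# pad tokens, e.g.
#
#   seq_start_pos_offset = torch.tensor([0, 2])    # second sample: 2 pads
#   torch.arange(5) >= seq_start_pos_offset[..., None]
#   # tensor([[ True,  True,  True,  True,  True],
#   #         [False, False,  True,  True,  True]])
#
# Positions before each offset are excluded from attention as keys/values,
# and the same offset shifts that sample's rotary positions.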

# setup cache

new_cached_kvs = []
@@ -444,7 +492,7 @@

for ind, (attn, ff) in enumerate(early_exit_layers):
residual = x
attn_out, cached_kv = attn(x, cache = next(iter_early_cache_kvs, None))
attn_out, cached_kv = attn(x, context_mask = self_attn_kv_mask, rotary_emb = rotary_emb, cache = next(iter_early_cache_kvs, None))
x = residual + attn_out

new_cached_kvs.append(cached_kv)
@@ -467,7 +515,7 @@
layer = ind + 1

residual = x
attn_out, cached_kv = attn(x, cache = next(iter_cache_kvs, None))
attn_out, cached_kv = attn(x, rotary_emb = rotary_emb, cache = next(iter_cache_kvs, None))
x = residual + attn_out

new_cached_kvs.append(cached_kv)
@@ -482,7 +530,7 @@

for early_exit_attn, early_exit_ff in self.early_exit_transformer_blocks:
residual = x
attn_out, cached_kv = early_exit_attn(x, cache = next(iter_cache_kvs, None))
attn_out, cached_kv = early_exit_attn(x, rotary_emb = rotary_emb, cache = next(iter_cache_kvs, None))
x = residual + attn_out

new_cached_kvs.append(cached_kv)
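The new forward argument is what the commit title refers to: each sample in a batch can carry its own start offset, which masks the positions before it out of attention and shifts that sample's rotary positions by the same amount. A minimal usage sketch, under assumptions this diff does not confirm (the class name Decoder, its constructor arguments, and that the plain forward call returns logits):

    import torch
    from speculative_decoding.speculative_decoding import Decoder  # class name assumed

    # two prompts of different lengths, left-padded to a shared length of 6
    pad_id = 0
    x = torch.stack([
        torch.tensor([5, 7, 9, 2, 4, 6]),   # no padding
        torch.tensor([0, 0, 0, 3, 8, 1]),   # three left pad tokens
    ])

    # per-sample offset = number of left pad tokens (assumes pads only appear
    # as a contiguous prefix); positions before the offset are masked out of
    # attention and the rotary positions are shifted by the same amount
    seq_start_pos_offset = (x == pad_id).sum(dim = -1)   # tensor([0, 3])

    model = Decoder(num_tokens = 256, dim = 512, depth = 6)   # arguments assumed

    logits = model(x, seq_start_pos_offset = seq_start_pos_offset)

Because rotary embeddings encode relative position, shifting every position of a sample by a constant offset leaves attention between its unmasked tokens unchanged while keeping the per-sample bookkeeping needed for batched speculative decoding.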
