diff --git a/README.md b/README.md index 8ae4450..693fca0 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,9 @@ Also have a few ideas of my own that I will try and share in this repository, if - [x] make batched spec decoding work with early exit strategy - [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the embeddings) + - [x] complete prophet net with hierarchical transformer training + - [ ] complete the spec decoding algorithm using trained prophet net transformer + - [ ] dedicate a morning to microoptimizations ## Citations diff --git a/setup.py b/setup.py index 3b92aa2..a3ada28 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'speculative-decoding', packages = find_packages(exclude=[]), - version = '0.0.12', + version = '0.0.14', license='MIT', description = 'Speculative Decoding', author = 'Phil Wang', diff --git a/speculative_decoding/speculative_decoding_with_prophet.py b/speculative_decoding/speculative_decoding_with_prophet.py new file mode 100644 index 0000000..3e093f4 --- /dev/null +++ b/speculative_decoding/speculative_decoding_with_prophet.py @@ -0,0 +1,513 @@ +import math + +import torch +from torch.nn import Module, ModuleList +from torch import nn, einsum, Tensor +import torch.nn.functional as F + +from rotary_embedding_torch import RotaryEmbedding +from beartype import beartype + +from collections import namedtuple + +from einops import rearrange + +# constants + +Cache = namedtuple('Cache', ['cached_kvs', 'embeds']) + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# sampling helpers + +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + +def gumbel_sample(t, temperature = 1., dim = -1): + return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim) + +def top_k(logits, thres = 0.9): + k = math.ceil((1 - thres) * logits.shape[-1]) + val, ind = torch.topk(logits, k) + probs = torch.full_like(logits, float('-inf')) + probs.scatter_(-1, ind, val) + return probs + +# rotary embeddings + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, seq_len): + t = torch.arange(seq_len, device = self.inv_freq.device).type_as(self.inv_freq) + freqs = einsum('i, j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + return freqs + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(pos, t): + seq_len = t.shape[-2] + pos = pos[-seq_len:, :] + return t * pos.cos() + rotate_half(t) * pos.sin() + +# different decoding strategies + +@torch.no_grad() +def base_decoding( + net: Module, + prompt: Tensor, + seq_len: int, + temperature = 1., + filter_thres = 0.9, +): + prompt_seq_len, out = prompt.shape[-1], prompt.clone() + sample_num_times = max(0, seq_len - prompt_seq_len) + + cache = None + + for _ in range(sample_num_times): + logits, cache = net(out, cache = cache, return_cache = True) + logits = logits[:, -1] + + logits = top_k(logits, thres = filter_thres) + sample = gumbel_sample(logits, temperature = temperature, dim = -1) + + out = torch.cat((out, sample[..., None]), dim = -1) + + return out[..., prompt_seq_len:] + +# speculative decoding functions + +def safe_div(num, den, eps = 1e-10): + return num / max(den, eps) + +def find_first_true_index(bool_tensor, dim = -1): + return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim) + +@torch.no_grad() +def speculative_decoding_with_prophet_model( + net: Module, + prompt: Tensor, + seq_len: int, + gamma: int = 5, + temperature = 1., + filter_thres = 0.9, + lenience = 1., + pad_id = 0 +): + """ + eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192 + """ + + raise NotImplementedError + + batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device + sample_num_times = max(0, seq_len - prompt_seq_len) + + cache = None + small_cache = None + + num_steps = 0 + total_accepted = 0 + + batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None] + seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long) + + while (seq_lens < seq_len).any(): + + # predict with smaller network + + all_small_logits = [] + q_sampled_out = [] + + for _ in range(gamma): + small_logits, small_cache = net( + out, + cache = small_cache, + return_cache = True, + return_early_exit_only = True, + seq_start_pos = out.shape[-1] - seq_lens + ) + + small_logits = small_logits[:, -1] + + small_logits = top_k(small_logits, thres = filter_thres) + all_small_logits.append(small_logits) + + sample = gumbel_sample(small_logits, temperature = temperature, dim = -1) + out = torch.cat((out, sample[..., None]), dim = -1) + seq_lens += 1 + + q_sampled_out.append(rearrange(sample, 'b -> b 1 1')) + + q_sampled_out = torch.cat(q_sampled_out, dim = -2) + small_logits = torch.stack(all_small_logits, dim = -2) + + # verify with larger network + + logits, cache = net( + out, + cache = cache, + early_exit_cache = small_cache, + return_cache = True, + start_from_early_exit_hiddens = True, + seq_start_pos = out.shape[-1] - seq_lens + ) + + logits = logits[..., -(gamma + 1):, :] + logits = top_k(logits, thres = filter_thres) + + # prob and prob of small model (p(x) and q(x) in algorithm 1) + + prob = safe_div(logits, temperature).softmax(dim = -1) + small_prob = safe_div(small_logits, temperature).softmax(dim = -1) + + p, prob_next = prob[:, :-1], prob[:, -1] + + p = p.gather(-1, q_sampled_out) + q = small_prob.gather(-1, q_sampled_out) * lenience + + p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)] + + r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1) + + accepted = find_first_true_index(r > (p / q)) + + total_accepted += accepted.float().mean() + num_steps += 1 + + num_rejected = gamma - accepted + has_rejected = num_rejected > 0 + + accepted = rearrange(accepted, 'b -> b 1') + accepted.clamp_(max = gamma - 1) + + adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted]) + adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) + adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d') + + prob_next = torch.where( + rearrange(has_rejected, '... -> ... 1'), + adjusted_prob, + prob_next + ) + + # do a bunch of slicing and align everything to the right, including kv caches + + max_num_rejected = num_rejected.amax() + seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long) + seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None] + + seq_lens -= num_rejected + max_seq_len = seq_lens.amax() + + if batch > 1: + out = F.pad(out, (0, max_num_rejected), value = pad_id) + out = out[batch_range, seq_offset_indices] + + cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache) + small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache) + + cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache) + small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache) + + cache = tuple(t[batch_range, seq_offset_indices] for t in cache) + small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache) + + cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache) + small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache) + + if out.shape[-1] > max_seq_len: + left_index = out.shape[-1] - max_seq_len + out = out[:, left_index:] + cache = tuple(t[..., left_index:, :] for t in cache) + small_cache = tuple(t[..., left_index:, :] for t in small_cache) + + # sample the additional token, one of the tricks in the paper to better bound the worst case + + next_token = torch.multinomial(prob_next, 1) + + out = torch.cat((out, next_token), dim = -1) + seq_lens += 1 + + # now left align + + num_pad_left = out.shape[-1] - seq_lens + max_pad_left = num_pad_left.amax() + out = F.pad(out, (0, max_pad_left), value = pad_id) + + seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long) + out = out[batch_range, seq_len_range + num_pad_left[..., None]] + + return out[..., prompt_seq_len:], total_accepted / num_steps + +# norm + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim = -1) * self.scale * self.gamma + +# attention and feedforward + +class CausalAttention(Module): + def __init__( + self, + dim, + *, + dim_head = 64, + heads = 8, + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + dim_inner = dim_head * heads + + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False) + self.to_out = nn.Linear(dim_inner, dim, bias = False) + + def forward( + self, + x, + cache = None, + context_mask = None, + rotary_emb = None + ): + h, device = self.heads, x.device + + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h) + + if exists(cache): + ck, cv = cache.unbind(dim = 1) + k = torch.cat((ck, k), dim = -2) + v = torch.cat((cv, v), dim = -2) + + cached_kv = torch.stack((k, v), dim = 1) + + if exists(rotary_emb): + q = apply_rotary_pos_emb(rotary_emb, q) + k = apply_rotary_pos_emb(rotary_emb, k) + + sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + i, j = sim.shape[-2:] + causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + if exists(context_mask): + context_mask = rearrange(context_mask, 'b j -> b 1 1 j') + sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max) + + attn = sim.softmax(dim = -1) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + + return out, cached_kv + +def FeedForward(dim, mult = 4): + dim_inner = dim * mult + return nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Linear(dim_inner, dim) + ) + +# main class + +class Decoder(Module): + def __init__( + self, + *, + num_tokens, + dim, + depth, + heads = 8, + dim_head = 64, + ff_mult = 4, + ignore_index = -1 + ): + super().__init__() + self.dim = dim + self.token_emb = nn.Embedding(num_tokens, dim) + + self.layers = ModuleList([]) + + self.rotary_emb = RotaryEmbedding(dim = dim_head) + + for _ in range(depth): + self.layers.append(ModuleList([ + CausalAttention(dim = dim, dim_head = dim_head, heads = heads), + FeedForward(dim = dim, mult = ff_mult) + ])) + + self.to_logits = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, num_tokens, bias = False) + ) + + self.ignore_index = ignore_index + + def forward( + self, + x, + start_tokens = None, + return_loss = False, + return_cache = False, + seq_start_pos = None, + cache = None + ): + has_start_tokens = exists(start_tokens) + + if return_loss: + label_start_index = (1 if not has_start_tokens else 0) + x, labels = x[:, :-1], x[:, label_start_index:] + + x = self.token_emb(x) + + if exists(start_tokens): + x = torch.cat((start_tokens, x), dim = 1) + + # handle seq start pos offset + + self_attn_kv_mask = None + if exists(seq_start_pos): + batch, seq_len = x.shape[:2] + seq_range = torch.arange(seq_len, device = x.device, dtype = torch.long) + self_attn_kv_mask = seq_range >= seq_start_pos[..., None] + + # relative positional encoding + + rotary_emb = self.rotary_emb(x.shape[-2]) + + # setup cache + + new_cached_kvs = [] + + cache_kvs = cache_embeds = None + + if exists(cache): + cache_kvs, cache_embeds = cache + + if exists(cache_kvs): + iter_cache_kvs = iter(cache_kvs.unbind(dim = 1)) + else: + iter_cache_kvs = iter([]) + + # if cache passed in, just use the last token + + if exists(cache): + num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2] + x = x[:, -num_tokens_keep:] + + # main transformer body + + for ind, (attn, ff) in enumerate(self.layers): + layer = ind + 1 + + residual = x + attn_out, cached_kv = attn(x, rotary_emb = rotary_emb, cache = next(iter_cache_kvs, None)) + x = residual + attn_out + + new_cached_kvs.append(cached_kv) + + new_cached_kvs = torch.stack(new_cached_kvs, dim = 1) + + logits = self.to_logits(x) + + if not return_loss: + if not return_cache: + return logits + + return logits, Cache(new_cached_kvs, x) + + loss = F.cross_entropy( + rearrange(logits, 'b n c -> b c n'), + labels, + ignore_index = self.ignore_index + ) + + return loss, Cache(new_cached_kvs, x) + +class ModelWithProphetWrapper(Module): + def __init__( + self, + model: Decoder, + prophet: Decoder, + prophet_train_length = 8, # should be greater than spec decoding gamma, as main model cache embedding is one step behind + detach_model_embed_for_prophet = False + ): + super().__init__() + self.model = model + self.prophet = prophet + + model_prophet_same_dim = model.dim == prophet.dim + self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Sequential(RMSNorm(model.dim), nn.Linear(model.dim, prophet.dim, bias = False)) + + self.prophet_train_length = prophet_train_length + self.detach_model_embed_for_prophet = detach_model_embed_for_prophet + + def forward(self, x): + batch, seq_len, device = *x.shape, x.device + prophet_seq_len = self.prophet_train_length + assert seq_len >= prophet_seq_len + + total_loss = 0. + + main_loss, (cached_kvs, embeds) = self.model(x, return_loss = True) + + total_loss = total_loss + main_loss + + if self.detach_model_embed_for_prophet: + embeds = embeds.detach() + + prophet_start_tokens = self.to_prophet_start_token(embeds) + + batch_arange = torch.arange(batch, device = device, dtype = torch.long) + prophet_seq_arange = torch.arange(prophet_seq_len, device = device, dtype = torch.long) + + num_seq_train_prophet = seq_len - prophet_seq_len + offsets = torch.arange(num_seq_train_prophet, device = device, dtype = torch.long) + + prophet_input = x[ + batch_arange[..., None, None], + prophet_seq_arange + offsets[..., None] + ] + + prophet_input = rearrange(prophet_input, '... n -> (...) n') + + prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n d -> (b n) 1 d') + + prophet_loss, _ = self.prophet(prophet_input, start_tokens = prophet_start_tokens, return_loss = True) + + total_loss = total_loss + prophet_loss + + return total_loss, (main_loss, prophet_loss) diff --git a/train_prophet.py b/train_prophet.py new file mode 100644 index 0000000..9313a97 --- /dev/null +++ b/train_prophet.py @@ -0,0 +1,159 @@ +import gzip +import random +import tqdm +import numpy as np +import time +from functools import wraps, partial + +import torch +from torch.optim import Adam +from torch.nn import functional as F +from torch.cuda import synchronize, Event +from torch.utils.data import DataLoader, Dataset + +timer = partial(Event, enable_timing = True) + +from speculative_decoding.speculative_decoding_with_prophet import ( + Decoder, + ModelWithProphetWrapper, + base_decoding, + speculative_decoding_with_prophet_model +) + +# constants + +NUM_BATCHES = int(1e5) +BATCH_SIZE = 4 +GRAD_ACCUM_EVERY = 4 +LEARNING_RATE = 1e-4 +PRIME_LENGTH = 128 +GENERATE_EVERY = 10 +GENERATE_LENGTH = 512 +SEQ_LEN = 512 +GAMMA = 5 +EARLY_EXIT_LOSS_WEIGHT = 1. + +DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu' + +# helpers + +def cycle(loader): + while True: + for data in loader: + yield data + +def decode_token(token): + return str(chr(max(32, token))) + +def decode_tokens(tokens): + return "".join(list(map(decode_token, tokens))) + +def benchmark(fn): + @wraps(fn) + def inner(*args, **kwargs): + start_event = timer() + end_event = timer() + start_event.record() + + out = fn(*args, **kwargs) + + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + return out, elapsed_time_ms + return inner + +# instantiate transformer + +device = torch.device(DEVICE_STR) + +model = Decoder( + num_tokens = 256, + dim = 512, + depth = 10 +) + +prophet = Decoder( + num_tokens = 256, + dim = 128, + depth = 4 +) + +model_and_prophet = ModelWithProphetWrapper( + model, + prophet, + prophet_train_length = GAMMA + 2, + detach_model_embed_for_prophet = False # train end to end, shouldn't hurt (although benefits is dubious) given ProphetNet paper - of course, trying to get to the bottom of the benefits in spec decoding setting here +).to(device) + +# prepare enwik8 data + +with gzip.open("./data/enwik8.gz") as file: + data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() + np_train, np_valid = np.split(data, [int(90e6)]) + data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid) + +class TextSamplerDataset(Dataset): + def __init__(self, data, seq_len): + super().__init__() + self.data = data + self.seq_len = seq_len + + def __getitem__(self, index): + rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) + full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long() + return full_seq.to(device) + + def __len__(self): + return self.data.size(0) // self.seq_len + +train_dataset = TextSamplerDataset(data_train, SEQ_LEN) +val_dataset = TextSamplerDataset(data_val, SEQ_LEN) +train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE)) + +# optimizer + +optim = Adam(model.parameters(), lr = LEARNING_RATE) + +# training + +for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"): + model.train() + + for _ in range(GRAD_ACCUM_EVERY): + data = next(train_loader) + + total_loss, (loss, prophet_loss) = model_and_prophet(data) + + (total_loss / GRAD_ACCUM_EVERY).backward() + + print(f"training loss: {loss.item():.3f}") + print(f"training prophet loss: {prophet_loss.item():.3f}") + + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + + optim.step() + optim.zero_grad() + + if i % GENERATE_EVERY == 0: + model.eval() + + inp = random.choice(val_dataset)[:PRIME_LENGTH] + prime = decode_tokens(inp) + print(f"%s \n\n %s", (prime, "*" * 100)) + + prompt = inp[None, ...] + + sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH) + + (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_prophet_model)(model, prompt, GENERATE_LENGTH, GAMMA) + + base_decode_output = decode_tokens(sampled[0]) + spec_decode_output = decode_tokens(spec_decode_sampled[0]) + + print("\nbase decoding:\n\n", base_decode_output, "\n") + print("\nspec decoding:\n\n", spec_decode_output, "\n") + + print(f'base decoding in: {base_decode_elapsed:.3f}ms\n') + print(f'spec decoding in: {spec_decode_elapsed:.3f}ms\n') + print(f'average num accepted: {num_accepted:.1f} / {GAMMA}\n')