
Commit

for testing first idea, before escalating to prophet net idea, with an early exit layer to logits as the small model using the same model
lucidrains committed Sep 18, 2023
1 parent 5129e4c commit d249bcc
Showing 2 changed files with 200 additions and 5 deletions.
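
The commit wires an early-exit head into the decoder so that the first few layers of the same network can act as the cheap "draft" model in speculative decoding, rather than a separately trained small model. For orientation, here is a deliberately simplified sketch of the draft-then-verify loop such a head plugs into. The repo's actual speculative_decoding function is not part of this diff, so everything below is illustrative only: it uses greedy agreement instead of the standard probabilistic min(1, p_full / p_draft) acceptance rule, and it skips KV caching.

    import torch

    def draft_then_verify_step(model, small_model, out, gamma):
        # illustrative only - not the repo's speculative_decoding implementation

        # 1. draft gamma tokens with the cheap early-exit view of the model
        drafted = []
        for _ in range(gamma):
            draft_logits = small_model(out)
            token = draft_logits[:, -1].argmax(dim = -1, keepdim = True)  # greedy for brevity
            out = torch.cat((out, token), dim = -1)
            drafted.append(token)

        # 2. score every drafted position with a single full forward pass
        full_logits = model(out)

        # 3. keep the longest prefix on which the full model agrees
        #    (real speculative decoding accepts each drafted token with
        #    probability min(1, p_full / p_draft) and resamples on rejection)
        num_accepted = 0
        for i, token in enumerate(drafted):
            pos = out.shape[-1] - gamma + i - 1   # position whose logits predict this token
            if full_logits[:, pos].argmax(dim = -1, keepdim = True).eq(token).all():
                num_accepted += 1
            else:
                break

        out = out[:, :out.shape[-1] - gamma + num_accepted]
        return out, num_accepted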
48 changes: 43 additions & 5 deletions speculative_decoding/speculative_decoding.py
@@ -63,6 +63,8 @@ def base_decoding(
 
     return out[..., prompt_seq_len:]
 
+# speculative decoding functions
+
 def safe_div(num, den, eps = 1e-10):
     return num / max(den, eps)
 
@@ -254,7 +256,8 @@ def __init__(
         dim_head = 64,
         ff_mult = 4,
         weight_tie_layers = False,
-        ignore_index = -1
+        ignore_index = -1,
+        early_exit_layer = None
     ):
         super().__init__()
         self.token_emb = nn.Embedding(num_tokens, dim)
@@ -279,14 +282,24 @@ def __init__(
             nn.Linear(dim, num_tokens, bias = False)
         )
 
+        self.early_exit_layer = early_exit_layer
+        self.to_early_exit_logits = None
+
+        if exists(early_exit_layer):
+            self.to_early_exit_logits = nn.Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, num_tokens, bias = False)
+            )
+
         self.ignore_index = ignore_index
 
     def forward(
         self,
         x,
         return_loss = False,
         return_cache = False,
-        cache = None
+        cache = None,
+        return_early_exit_only = False
     ):
         if return_loss:
             x, labels = x[:, :-1], x[:, 1:]
@@ -307,7 +320,11 @@ def forward(
         cache = default(cache, [])
         iter_cache = iter(cache)
 
-        for attn, ff in self.layers:
+        early_exit_hiddens = None
+
+        for ind, (attn, ff) in enumerate(self.layers):
+            layer = ind + 1
+
             residual = x
             attn_out, cached_kv = attn(x, cache = next(iter_cache, None))
             x = residual + attn_out
@@ -316,18 +333,39 @@
 
             x = ff(x) + x
 
+            if layer == self.early_exit_layer:
+                early_exit_hiddens = x
+
+                if return_early_exit_only:
+                    break
+
         new_cached_kvs = torch.stack(new_cached_kvs)
 
-        logits = self.to_logits(x)
+        to_logits = self.to_logits if not return_early_exit_only else self.to_early_exit_logits
+
+        logits = to_logits(x)
 
         if not return_loss:
             if not return_cache:
                 return logits
 
             return logits, new_cached_kvs
 
-        return F.cross_entropy(
+        loss = F.cross_entropy(
             rearrange(logits, 'b n c -> b c n'),
             labels,
             ignore_index = self.ignore_index
         )
+
+        if not exists(self.to_early_exit_logits):
+            return loss
+
+        early_exit_logits = self.to_early_exit_logits(early_exit_hiddens)
+
+        early_exit_loss = F.cross_entropy(
+            rearrange(early_exit_logits, 'b n c -> b c n'),
+            labels,
+            ignore_index = self.ignore_index
+        )
+
+        return loss, early_exit_loss
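
Taken together, these changes let one Decoder double as its own draft model. A usage sketch based on the new arguments above (shapes are illustrative; see train_early_exit.py below for the real call pattern):

    import torch
    from functools import partial

    from speculative_decoding import Decoder

    model = Decoder(num_tokens = 256, dim = 512, depth = 10, early_exit_layer = 2)

    tokens = torch.randint(0, 256, (1, 128))

    full_logits = model(tokens)                                   # runs all 10 layers
    draft_logits = model(tokens, return_early_exit_only = True)   # stops after layer 2, uses the early-exit head

    # with an early exit head attached, return_loss = True yields both losses
    seq = torch.randint(0, 256, (1, 129))
    loss, early_exit_loss = model(seq, return_loss = True)

    # the early-exit view is what gets passed to speculative decoding as the "small" model
    small_model = partial(model, return_early_exit_only = True)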
157 changes: 157 additions & 0 deletions train_early_exit.py
@@ -0,0 +1,157 @@
import gzip
import random
import tqdm
import numpy as np
import time
from functools import wraps, partial

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.cuda import synchronize, Event
from torch.utils.data import DataLoader, Dataset

timer = partial(Event, enable_timing = True)

from speculative_decoding import (
Decoder,
base_decoding,
speculative_decoding
)

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512
GAMMA = 5

DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

def benchmark(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        start_event = timer()
        end_event = timer()
        start_event.record()

        out = fn(*args, **kwargs)

        end_event.record()
        torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
        return out, elapsed_time_ms
    return inner

# instantiate transformer

device = torch.device(DEVICE_STR)

model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 10,
    early_exit_layer = 2   # use the same model as the small approximate model, worry about caching layer hiddens later
).to(device)

# prepare enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer

optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        loss, small_loss = model(data, return_loss = True)

        ((loss + small_loss) / GRAD_ACCUM_EVERY).backward()

    print(f"training loss: {loss.item():.3f}")
    print(f"training small loss: {small_loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)

            loss, small_loss = model(valid_data, return_loss = True)
            print(f"validation loss: {loss.item():.3f}")
            print(f"validation small loss: {small_loss.item():.3f}")

    if i % GENERATE_EVERY == 0:
        model.eval()

        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f"{prime} \n\n {'*' * 100}")

        prompt = inp[None, ...]

        sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

        small_model = partial(model, return_early_exit_only = True)
        (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding)(model, small_model, prompt, GENERATE_LENGTH, GAMMA)

        base_decode_output = decode_tokens(sampled[0])
        spec_decode_output = decode_tokens(spec_decode_sampled[0])

        print("\nbase decoding:\n\n", base_decode_output, "\n")
        print("\nspec decoding:\n\n", spec_decode_output, "\n")

        print(f'base decoding in: {base_decode_elapsed:.3f}ms\n')
        print(f'spec decoding in: {spec_decode_elapsed:.3f}ms\n')
        print(f'average num accepted: {num_accepted:.1f} / {GAMMA}\n')
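
A note on running the script: it expects the enwik8 dump at ./data/enwik8.gz, and the benchmark helper relies on torch.cuda.Event, so the generation timing only works on a CUDA device even though DEVICE_STR falls back to 'cpu'.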
