early exit worked, add ability to detach, in case training with auxiliary early exit loss hurts the main model
lucidrains committed Sep 18, 2023
1 parent d249bcc commit 174680b
Showing 3 changed files with 9 additions and 3 deletions.
setup.py: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'speculative-decoding',
   packages = find_packages(exclude=[]),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'Speculative Decoding',
   author = 'Phil Wang',
speculative_decoding/speculative_decoding.py: 6 additions & 1 deletion
@@ -257,7 +257,8 @@ def __init__(
         ff_mult = 4,
         weight_tie_layers = False,
         ignore_index = -1,
-        early_exit_layer = None
+        early_exit_layer = None,
+        detach_early_exit_hiddens = False
     ):
         super().__init__()
         self.token_emb = nn.Embedding(num_tokens, dim)
@@ -282,6 +283,7 @@ def __init__(
             nn.Linear(dim, num_tokens, bias = False)
         )
 
+        self.detach_early_exit_hiddens = detach_early_exit_hiddens
         self.early_exit_layer = early_exit_layer
         self.to_early_exit_logits = None

@@ -336,6 +338,9 @@ def forward(
             if layer == self.early_exit_layer:
                 early_exit_hiddens = x
 
+                if self.detach_early_exit_hiddens:
+                    early_exit_hiddens = early_exit_hiddens.detach()
+
                 if return_early_exit_only:
                     break

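The new flag works by cutting the autograd graph at the early-exit hiddens, so the auxiliary head can still be trained while its loss no longer sends gradients into the shared trunk. A minimal sketch of that mechanism (the tensor names here are illustrative, not from the repository):

import torch

trunk_out = torch.ones(2, requires_grad = True)  # stands in for a layer's hiddens
head_w = torch.ones(2, requires_grad = True)     # stands in for the early-exit head

hiddens = trunk_out * 2
aux_loss = (hiddens.detach() * head_w).sum()     # the cut detach_early_exit_hiddens = True makes
aux_loss.backward()

print(trunk_out.grad)  # None -> the trunk never sees the auxiliary loss
print(head_w.grad)     # tensor([2., 2.]) -> the exit head still learns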
train_early_exit.py: 2 additions & 1 deletion
@@ -31,6 +31,7 @@
 GENERATE_LENGTH = 512
 SEQ_LEN = 512
 GAMMA = 5
+EARLY_EXIT_LOSS_WEIGHT = 1.
 
 DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -113,7 +114,7 @@ def __len__(self):
 
         loss, small_loss = model(data, return_loss = True)
 
-        ((loss + small_loss) / GRAD_ACCUM_EVERY).backward()
+        ((loss + small_loss * EARLY_EXIT_LOSS_WEIGHT) / GRAD_ACCUM_EVERY).backward()
 
     print(f"training loss: {loss.item():.3f}")
     print(f"training small loss: {small_loss.item():.3f}")
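Taken together, the two training-script changes let the auxiliary early-exit loss be downweighted if it hurts the main objective. A sketch of one accumulation cycle, assuming model, train_loader, and optim from the surrounding script (the GRAD_ACCUM_EVERY and weight values are illustrative):

EARLY_EXIT_LOSS_WEIGHT = 0.5  # < 1. softens the early-exit objective
GRAD_ACCUM_EVERY = 4          # illustrative, not from this diff

for _ in range(GRAD_ACCUM_EVERY):
    data = next(train_loader)
    loss, small_loss = model(data, return_loss = True)
    ((loss + small_loss * EARLY_EXIT_LOSS_WEIGHT) / GRAD_ACCUM_EVERY).backward()

optim.step()
optim.zero_grad()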
