diff --git a/setup.py b/setup.py
index 51285b9..63f73f6 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'speculative-decoding',
   packages = find_packages(exclude=[]),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'Speculative Decoding',
   author = 'Phil Wang',
diff --git a/speculative_decoding/speculative_decoding.py b/speculative_decoding/speculative_decoding.py
index f7d1c1d..4059a3c 100644
--- a/speculative_decoding/speculative_decoding.py
+++ b/speculative_decoding/speculative_decoding.py
@@ -257,7 +257,8 @@ def __init__(
         ff_mult = 4,
         weight_tie_layers = False,
         ignore_index = -1,
-        early_exit_layer = None
+        early_exit_layer = None,
+        detach_early_exit_hiddens = False
     ):
         super().__init__()
         self.token_emb = nn.Embedding(num_tokens, dim)
@@ -282,6 +283,7 @@ def __init__(
             nn.Linear(dim, num_tokens, bias = False)
         )
 
+        self.detach_early_exit_hiddens = detach_early_exit_hiddens
         self.early_exit_layer = early_exit_layer
         self.to_early_exit_logits = None
 
@@ -336,6 +338,9 @@ def forward(
             if layer == self.early_exit_layer:
                 early_exit_hiddens = x
 
+                if self.detach_early_exit_hiddens:
+                    early_exit_hiddens = early_exit_hiddens.detach()
+
                 if return_early_exit_only:
                     break
 
diff --git a/train_early_exit.py b/train_early_exit.py
index 4eb3a14..b373e71 100644
--- a/train_early_exit.py
+++ b/train_early_exit.py
@@ -31,6 +31,7 @@
 GENERATE_LENGTH = 512
 SEQ_LEN = 512
 GAMMA = 5
+EARLY_EXIT_LOSS_WEIGHT = 1.
 
 DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -113,7 +114,7 @@ def __len__(self):
 
         loss, small_loss = model(data, return_loss = True)
 
-        ((loss + small_loss) / GRAD_ACCUM_EVERY).backward()
+        ((loss + small_loss * EARLY_EXIT_LOSS_WEIGHT) / GRAD_ACCUM_EVERY).backward()
 
     print(f"training loss: {loss.item():.3f}")
     print(f"training small loss: {small_loss.item():.3f}")