
Commit

another tiny step
lucidrains committed Sep 18, 2023
1 parent bf903f2 commit dea6376
Showing 1 changed file with 14 additions and 2 deletions.
speculative_decoding/speculative_decoding.py: 14 additions & 2 deletions
@@ -7,6 +7,7 @@
 
 from rotary_embedding_torch import RotaryEmbedding
 from beartype import beartype
+from beartype.typing import Optional
 
 from collections import namedtuple
 
@@ -307,7 +308,8 @@ def forward(
         return_loss = False,
         return_cache = False,
         cache = None,
-        return_early_exit_only = False
+        return_early_exit_only = False,
+        start_from_early_exit_hiddens: Optional[Tensor] = None
     ):
         if return_loss:
             x, labels = x[:, :-1], x[:, 1:]
@@ -333,7 +335,17 @@
 
         early_exit_hiddens = None
 
-        for ind, (attn, ff) in enumerate(self.layers):
+        # handle a previously cached embedding from the early exit layer being passed in
+
+        layers = self.layers
+        if exists(start_from_early_exit_hiddens):
+            assert not return_early_exit_only and exists(self.early_exit_layer)
+            layers = layers[self.early_exit_layer - 1:]
+            x = start_from_early_exit_hiddens
+
+        # main transformer body
+
+        for ind, (attn, ff) in enumerate(layers):
             layer = ind + 1
 
             residual = x
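
In the early-exit speculative decoding scheme, the first few layers of the network double as the small draft model, so once drafting is done the full model should not redo the work below the exit point. This change lets the decoder's forward pass resume from hidden states cached at the early exit layer instead of recomputing those lower layers. Below is a minimal self-contained sketch of that two-pass pattern; the ToyDecoder class, its constructor arguments, its return values, and the plain linear blocks standing in for attention/feedforward layers are all illustrative assumptions, not the repository's actual Decoder API.

    import torch
    from torch import nn, Tensor
    from typing import Optional

    def exists(val):
        return val is not None

    class ToyDecoder(nn.Module):
        # hypothetical stand-in for the repo's Decoder, with linear blocks in place of attn + ff

        def __init__(self, dim = 64, depth = 6, early_exit_layer = 2):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
            self.early_exit_layer = early_exit_layer

        def forward(
            self,
            x: Tensor,
            return_early_exit_only = False,
            start_from_early_exit_hiddens: Optional[Tensor] = None
        ):
            layers = self.layers
            resuming = exists(start_from_early_exit_hiddens)

            # resume from the cached early exit hiddens, skipping the layers already run
            if resuming:
                assert not return_early_exit_only and exists(self.early_exit_layer)
                layers = layers[self.early_exit_layer:]
                x = start_from_early_exit_hiddens

            early_exit_hiddens = None

            for ind, layer in enumerate(layers):
                x = torch.relu(layer(x)) + x

                # cache the hiddens at the early exit layer on the first (draft) pass
                if not resuming and (ind + 1) == self.early_exit_layer:
                    early_exit_hiddens = x

                    if return_early_exit_only:
                        return early_exit_hiddens

            return x, early_exit_hiddens

    model = ToyDecoder()
    seq = torch.randn(1, 8, 64)

    # pass 1 - run only up to the early exit layer (the cheap draft path)
    hiddens = model(seq, return_early_exit_only = True)

    # pass 2 - finish the remaining layers from the cached hiddens
    # (x is overwritten internally, mirroring how the new argument works in the diff)
    out, _ = model(seq, start_from_early_exit_hiddens = hiddens)

One difference worth flagging: the commit slices at self.early_exit_layer - 1, which re-enters at the exit layer itself, while the sketch caches the hiddens after that layer and resumes at the next one; which convention is correct depends on where the full file captures early_exit_hiddens, which this diff does not show.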
