Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 18, 2023
1 parent 082371c commit 24e90c3
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'speculative-decoding',
packages = find_packages(exclude=[]),
- version = '0.0.5',
+ version = '0.0.6',
license='MIT',
description = 'Speculative Decoding',
author = 'Phil Wang',
Expand Down
3 changes: 1 addition & 2 deletions speculative_decoding/speculative_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def speculative_decoding(
@torch.no_grad()
def speculative_decoding_with_same_model(
net: Module,
- small_net: Module,
prompt: Tensor,
seq_len: int,
gamma: int = 5,
Expand Down Expand Up @@ -204,7 +203,7 @@ def speculative_decoding_with_same_model(
q_sampled_out = []

for _ in range(gamma):
- small_logits, cache = small_net(out, cache = cache, return_cache = True)
+ small_logits, cache = net(out, cache = cache, return_cache = True, return_early_exit_only = True)
small_logits = small_logits[:, -1]

small_logits = top_k(small_logits, thres = filter_thres)
Expand Down
3 changes: 1 addition & 2 deletions train_early_exit.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ def __len__(self):

sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

- small_model = partial(model, return_early_exit_only = True)
- (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_same_model)(model,small_model, prompt, GENERATE_LENGTH, GAMMA)
+ (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_same_model)(model, prompt, GENERATE_LENGTH, GAMMA)

base_decode_output = decode_tokens(sampled[0])
spec_decode_output = decode_tokens(spec_decode_sampled[0])
Expand Down

0 comments on commit 24e90c3

Please sign in to comment.