Only run initial generate when compile=True
sheepymeh committed Mar 27, 2024
1 parent 01e3bc0 commit e6b529e
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions fam/llm/fast_inference_utils.py
@@ -367,27 +367,27 @@ def build_model(
             dynamic=True,
         )
 
-    encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
-    spk_emb = torch.randn((1, 256), device=device, dtype=precision)
-
-    device_sync(device=device) # MKG
-    t0 = time.perf_counter()
-    y = generate(
-        model,
-        encoded,
-        spk_emb,
-        max_new_tokens=200,
-        callback=lambda x: x,
-        temperature=torch.tensor(1.0, device=device, dtype=precision),
-        top_k=None,
-        top_p=torch.tensor(0.95, device=device, dtype=precision),
-        guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
-        end_of_audio_token=9999, # don't end early for compilation stage.
-    )
+        encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
+        spk_emb = torch.randn((1, 256), device=device, dtype=precision)
+
+        device_sync(device=device) # MKG
+        t0 = time.perf_counter()
+        y = generate(
+            model,
+            encoded,
+            spk_emb,
+            max_new_tokens=200,
+            callback=lambda x: x,
+            temperature=torch.tensor(1.0, device=device, dtype=precision),
+            top_k=None,
+            top_p=torch.tensor(0.95, device=device, dtype=precision),
+            guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
+            end_of_audio_token=9999, # don't end early for compilation stage.
+        )
 
-    device_sync(device=device) # MKG
+        device_sync(device=device) # MKG
 
-    print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+        print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
     return model, tokenizer, smodel, model_size
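For readers skimming the diff: the throw-away warm-up call to generate(), which exists only to trigger torch.compile's tracing and optimization, now runs inside the if compile: branch instead of unconditionally, so building the model with compile=False no longer pays for a 200-token dummy generation. Below is a rough sketch of the resulting section of build_model after this change; the if compile: guard is an assumption inferred from the commit title and the unchanged context lines, not text quoted from the hunk.

    if compile:  # assumed guard, implied by the commit title
        # ... torch.compile setup goes here (the `dynamic=True,` / `)` context
        #     lines in the hunk are the tail of that multi-line call) ...

        # Warm-up generation: run the decode loop once so graph capture and
        # optimization happen up front rather than on the first real request.
        # With compile=False this whole block is now skipped.
        encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
        spk_emb = torch.randn((1, 256), device=device, dtype=precision)

        device_sync(device=device) # MKG
        t0 = time.perf_counter()
        y = generate(
            model,
            encoded,
            spk_emb,
            max_new_tokens=200,
            callback=lambda x: x,
            temperature=torch.tensor(1.0, device=device, dtype=precision),
            top_k=None,
            top_p=torch.tensor(0.95, device=device, dtype=precision),
            guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
            end_of_audio_token=9999, # don't end early for compilation stage.
        )
        device_sync(device=device) # MKG
        print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

    return model, tokenizer, smodel, model_size

The practical upside is that callers who skip compilation (for instance, for quick debugging runs) get a model back immediately instead of waiting for the warm-up pass.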
