diff --git a/fam/llm/fast_inference_utils.py b/fam/llm/fast_inference_utils.py
index cbeb708..39ea506 100644
--- a/fam/llm/fast_inference_utils.py
+++ b/fam/llm/fast_inference_utils.py
@@ -367,27 +367,27 @@ def build_model(
             dynamic=True,
         )
 
-    encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
-    spk_emb = torch.randn((1, 256), device=device, dtype=precision)
-
-    device_sync(device=device)  # MKG
-    t0 = time.perf_counter()
-    y = generate(
-        model,
-        encoded,
-        spk_emb,
-        max_new_tokens=200,
-        callback=lambda x: x,
-        temperature=torch.tensor(1.0, device=device, dtype=precision),
-        top_k=None,
-        top_p=torch.tensor(0.95, device=device, dtype=precision),
-        guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
-        end_of_audio_token=9999,  # don't end early for compilation stage.
-    )
+        encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
+        spk_emb = torch.randn((1, 256), device=device, dtype=precision)
+
+        device_sync(device=device)  # MKG
+        t0 = time.perf_counter()
+        y = generate(
+            model,
+            encoded,
+            spk_emb,
+            max_new_tokens=200,
+            callback=lambda x: x,
+            temperature=torch.tensor(1.0, device=device, dtype=precision),
+            top_k=None,
+            top_p=torch.tensor(0.95, device=device, dtype=precision),
+            guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
+            end_of_audio_token=9999,  # don't end early for compilation stage.
+        )
 
-    device_sync(device=device)  # MKG
+        device_sync(device=device)  # MKG
 
-    print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+        print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
     return model, tokenizer, smodel, model_size
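
For reference, below is a minimal self-contained sketch of the pattern this diff enforces: the throwaway warm-up generation (which can take minutes) only pays off when torch.compile is enabled, so it belongs inside the compile branch. All names here are illustrative (a toy model in plain PyTorch), not the fam/llm helpers from the diff.

# Illustrative sketch only, not fam/llm code: the warm-up forward pass is
# gated on compilation, mirroring the indentation change in the diff above.
import time

import torch

def build_toy_model(compile_model: bool = True, device: str = "cpu"):
    model = torch.nn.Linear(16, 16).to(device)

    if compile_model:
        model = torch.compile(model, dynamic=True)

        # One throwaway forward pass triggers (slow) compilation now so
        # that later calls run at full speed; skipped when not compiling.
        t0 = time.perf_counter()
        with torch.no_grad():
            model(torch.randn(1, 16, device=device))
        print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

    return model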