An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction
The following snippet should disable torch.compile; note that disable_compile is passed as a kwarg to generate. According to the documentation, it should override the corresponding value in generation_config:
    import os
    os.environ["TORCH_LOGS"] = "+dynamo"

    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch

    model_id = "google/gemma-2-2b-it"
    device = "cuda:0"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

    prompt = "<start_of_turn>user\nWrite a poem about the Kraken.<end_of_turn>\n<start_of_turn>model\n"
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
    outputs = model.generate(inputs, max_length=50, disable_compile=True)
    text = tokenizer.decode(outputs[0])
But we can still see dynamo tracing calls when we run it.
The reason appears to be this line, which uses self.generation_config instead of generation_config.
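To make the suspected cause concrete, here is a minimal, self-contained sketch of the pattern. It does not use transformers: `ToyGenerationConfig`, `ToyModel`, `_prepare`, `generate_buggy`, and `generate_fixed` are hypothetical names invented for illustration, and the real code path in `generate` is more involved. The sketch only shows how reading the model-level config instead of the per-call config silently drops a kwarg override.

```python
from dataclasses import dataclass, replace


@dataclass
class ToyGenerationConfig:
    # Hypothetical stand-in for a generation config; only one flag matters here.
    disable_compile: bool = False


class ToyModel:
    def __init__(self):
        # Model-level defaults, analogous to self.generation_config.
        self.generation_config = ToyGenerationConfig()

    def _prepare(self, **kwargs):
        # Per-call config: a copy of the defaults with kwargs applied on top.
        return replace(self.generation_config, **kwargs)

    def generate_buggy(self, **kwargs):
        generation_config = self._prepare(**kwargs)  # built but never consulted
        # Suspected bug pattern: the model-level default is read, so the
        # kwarg override has no effect.
        return not self.generation_config.disable_compile

    def generate_fixed(self, **kwargs):
        generation_config = self._prepare(**kwargs)
        # The per-call config is read, so the kwarg override is honoured.
        return not generation_config.disable_compile


model = ToyModel()
# Each call returns True when compilation would still be attempted.
buggy_compiles = model.generate_buggy(disable_compile=True)
fixed_compiles = model.generate_fixed(disable_compile=True)
print(buggy_compiles, fixed_compiles)
```

In this toy version, the buggy path still reports that compilation would happen even though `disable_compile=True` was passed, which matches the dynamo tracing seen in the reproduction.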
Note that this behaviour will be fixed by #36519 once it is merged. Alternatively, we could fix this issue separately first if that PR takes a long time to be approved.
Expected behavior
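As discussed above: passing disable_compile=True as a kwarg to generate should disable torch.compile, so no dynamo tracing calls should appear in the logs.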
System Info
transformers version: 4.49.0
Who can help?
@gante, @SunMarc, @ArthurZucker