Rather than manually specifying `attn_flash`, why not let the code adaptively figure out whether Flash is possible in the `forward()` method?
As far as I can tell, you can use it whenever there is no fancy relative positional bias and you are not doing an ONNX export.
I feel like it would be a case of doing something like:
```python
if not onnx_export and not t5_rel_bias and not ...:
    return F.scaled_dot_product_attention(...)
else:
    return some_vanilla_sdpa(...)
```
I don't know if there is a way to introspect whether we're doing an ONNX export. If not, we could simply add a parameter to the `forward()` method, like `onnx_exporting`.
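(One possibility: PyTorch exposes `torch.onnx.is_in_onnx_export()`, which is True while `torch.onnx.export()` is tracing the model. A minimal sketch of using it, where `attend` is a hypothetical helper and the hand-rolled fallback stands in for the library's own non-Flash path:)

```python
import math
import torch
import torch.nn.functional as F

def attend(q, k, v):
    # torch.onnx.is_in_onnx_export() is True only while torch.onnx.export() is tracing
    if not torch.onnx.is_in_onnx_export():
        return F.scaled_dot_product_attention(q, k, v)
    # plain softmax(QK^T / sqrt(d)) V, which exports to ONNX cleanly
    sim = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return sim.softmax(dim=-1) @ v
```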
Or at least, it would be nice if you could specify `attn_flash` in `forward()` rather than in the constructor of the transformer. Then during training I can set it to True, and during ONNX export (which also happens during training, as part of checkpointing) I can set it to False. Alternatively, I can recreate the model and copy the weights during export.
So, the reason for this is: I would like to train with `scaled_dot_product_attention()`, but when I do my checkpointing I export to ONNX, which doesn't support Flash. So I have to create a fresh model with Flash disabled, port the weights over, then do the ONNX export. If `attn_flash` were specified in the `forward()` method, I wouldn't have to copy the model; I could just change the arguments.
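For reference, the workaround I mean looks roughly like this (a sketch assuming x-transformers' `TransformerWrapper` / `Decoder` constructors; the sizes are placeholders):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

def make_model(flash: bool):
    # placeholder sizes; only attn_flash differs between the two copies
    return TransformerWrapper(
        num_tokens=256,
        max_seq_len=1024,
        attn_layers=Decoder(dim=512, depth=6, heads=8, attn_flash=flash),
    )

model = make_model(flash=True)           # trained with Flash / SDPA
export_model = make_model(flash=False)   # twin model with vanilla attention
export_model.load_state_dict(model.state_dict())
export_model.eval()

dummy = torch.randint(0, 256, (1, 1024))
torch.onnx.export(export_model, (dummy,), "model.onnx")
```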
Also, in the forward pass, you could check whether Flash is possible depending on:
- `sparse_topk` is disabled
- `talking_heads` is disabled
- `torch.__version__` >= 2.0.0
- `attn_mask is not None` or `is_causal == True`, but not both

If any of these fails, set `flash = False` (see the sketch below). Then the model always does the right thing, without the user ever having to think about it.
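Something like this sketch, where `can_use_flash` and `attend` are hypothetical names, and the boolean-mask convention (True = attend, matching `F.scaled_dot_product_attention`) is an assumption:

```python
import torch
import torch.nn.functional as F

def can_use_flash(sparse_topk, talking_heads, attn_mask, is_causal, onnx_exporting):
    torch_ok = int(torch.__version__.split('.')[0]) >= 2   # needs PyTorch >= 2.0
    mask_ok = (attn_mask is None) or (not is_causal)       # SDPA rejects attn_mask + is_causal together
    return (torch_ok and mask_ok and not sparse_topk
            and not talking_heads and not onnx_exporting)

def attend(q, k, v, *, attn_mask=None, is_causal=False, sparse_topk=None,
           talking_heads=False, onnx_exporting=False):
    if can_use_flash(sparse_topk, talking_heads, attn_mask, is_causal, onnx_exporting):
        return F.scaled_dot_product_attention(q, k, v,
                                              attn_mask=attn_mask, is_causal=is_causal)
    # fallback: explicit softmax(QK^T / sqrt(d)) V
    sim = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    if attn_mask is not None:
        sim = sim.masked_fill(~attn_mask, float('-inf'))   # True = attend
    if is_causal:
        i, j = sim.shape[-2:]
        causal = torch.ones(i, j, dtype=torch.bool, device=q.device).triu(j - i + 1)
        sim = sim.masked_fill(causal, float('-inf'))
    return sim.softmax(dim=-1) @ v
```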