Feature request: don't specify attn_flash. Select when possible #173

Open

pfeatherstone opened this issue Aug 9, 2023 · 2 comments

@pfeatherstone (Contributor)

Rather than manually specifying attn_flash, why not allow the code to figure out in the forward() method whether it can be used?
As far as I can tell, you can use it when there is no fancy relative positional bias and you are not doing an ONNX export.
I feel like it would be a case of doing something like:

if not onnx_export and not t5_rel_bias and not ...:
    return F.scaled_dot_product_attention(...)
else:
    return some_vanilla_sdpa(...)

I don't know if there is a way to introspect whether we're doing an ONNX export. If not, we could simply add a parameter to the forward() method, like onnx_exporting.
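For reference, torch.onnx.is_in_onnx_export() (plus torch.jit.is_tracing() for the tracing path) might cover the introspection part. A minimal sketch of the kind of dispatch I mean — sdpa_attend is a purely illustrative name, not anything in the library:

import torch
import torch.nn.functional as F

def sdpa_attend(q, k, v, mask = None, causal = False):
    # illustrative helper: use the fused kernel when we can,
    # fall back to vanilla attention when tracing for ONNX export.
    # mask, if given, is assumed boolean and broadcastable, True = attend,
    # and should not be combined with causal = True (SDPA rejects both)
    exporting = torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()

    if not exporting:
        return F.scaled_dot_product_attention(q, k, v, attn_mask = mask, is_causal = causal)

    # vanilla attention: softmax(q k^T / sqrt(d)) v
    scale = q.shape[-1] ** -0.5
    sim = (q @ k.transpose(-2, -1)) * scale
    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones(i, j, dtype = torch.bool, device = q.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
    if mask is not None:
        sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
    return sim.softmax(dim = -1) @ v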

@pfeatherstone (Contributor, Author) commented Aug 9, 2023

Or at least, it would be nice if you could specify attn_flash in forward() rather than in the constructor of the transformer. Then I could set it to True during training, and to False during ONNX export (which also happens during training, at checkpoint time). Alternatively, I can recreate the model with flash disabled and copy the weights over during export, roughly as sketched below.
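Concretely, that workaround looks something like this — the constructor arguments are just placeholder hyperparameters, not anything specific:

import torch
from x_transformers import TransformerWrapper, Decoder

# training model, flash attention enabled
model = TransformerWrapper(
    num_tokens  = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8, attn_flash = True))

# ... train, then at checkpoint time: rebuild with flash disabled and port the weights
export_model = TransformerWrapper(
    num_tokens  = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8, attn_flash = False))
export_model.load_state_dict(model.state_dict())
export_model.eval()

dummy = torch.randint(0, 20000, (1, 1024))
torch.onnx.export(export_model, (dummy,), 'model.onnx', opset_version = 17)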

@pfeatherstone (Contributor, Author) commented Aug 23, 2023

So, the reason for this is: I would like to train with scaled_dot_product_attention(), but when I do my checkpointing I export to ONNX, which doesn't support Flash. So I have to create a fresh model with Flash disabled, port the weights, then do the ONNX export. If attn_flash were specified in the forward method, I wouldn't have to copy the model; I could just change the arguments.
Also, in the forward pass, you could check whether Flash is possible depending on:

  • sparse_topk == False
  • talking_heads == False
  • torch.__version__ >= 2.0.0
  • either attn_mask is not None or is_causal==True (not both)
  • isinstance(rel_pos, RelativePositionBias) == False
  • isinstance(rel_pos, DynamicPositionBias) == False
  • not residual_attn and not cross_residual_attn
  • somehow detect if doing onnx export

If any of these conditions fails, set flash = False.
Then you basically never have to worry about it: the model always does the right thing, without the user having to think about flash at all. A rough sketch of such a check is below.
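Something along these lines — can_use_flash is purely illustrative (not part of x-transformers), and the import path for the bias classes is from memory:

import torch
from packaging import version
from x_transformers.x_transformers import RelativePositionBias, DynamicPositionBias

def can_use_flash(
    rel_pos = None,              # relative position bias module in use, if any
    sparse_topk = None,
    talking_heads = False,
    residual_attn = False,
    cross_residual_attn = False,
    has_attn_mask = False,
    is_causal = False
):
    # hypothetical helper mirroring the checklist above
    if version.parse(torch.__version__) < version.parse('2.0.0'):
        return False
    # the ONNX export / tracing path cannot use the fused kernel
    if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
        return False
    # scaled_dot_product_attention takes an explicit mask or is_causal=True, not both
    if has_attn_mask and is_causal:
        return False
    # features the fused kernel cannot express
    if sparse_topk or talking_heads or residual_attn or cross_residual_attn:
        return False
    # learned relative position biases must be added to the attention logits
    if isinstance(rel_pos, (RelativePositionBias, DynamicPositionBias)):
        return False
    return True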
