
Add flash params to csrc #5


Merged: 2 commits merged into main on May 14, 2025
Conversation

LoserCheems (Collaborator)

This pull request introduces a new header file, flash.h, which defines the core structures and functions for a CUDA-based multi-head attention mechanism. The changes add parameter structures for the forward and backward passes, utility constants, and function templates for launching the attention kernels.

Core functionality for multi-head attention:

  • Definition of parameter structures:
    • Added QKV_params to encapsulate query, key, and value tensor pointers, strides, and head-related dimensions.
    • Added ZeroHold_params to manage zero-hold states, causal masks, and their associated strides.
    • Introduced Flash_fwd_params and Flash_bwd_params, which extend the structures above with forward- and backward-pass parameters such as dropout, scaling factors, and random-state handling (see the sketch after this list).
  • Function templates for execution:
    • Added the templates run_mha_fwd_, run_mha_fwd_splitkv_dispatch, and run_mha_bwd_ for launching the forward and backward multi-head attention kernels on CUDA streams.
  • Namespace organization:
    • Encapsulated all additions within FLASH_NAMESPACE for modularity and clarity.
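
For concreteness, here is a minimal sketch of how a header like this could be laid out. It assumes FlashAttention-style conventions; every field name, stride, and template parameter below (e.g. kHeadDim, philox_seed) is an illustrative assumption rather than the merged header's exact contents:

```cpp
#pragma once

// Illustrative sketch only: field names, strides, and the inheritance chain
// follow common FlashAttention-style conventions and are assumptions, not
// the merged header's exact contents.

#include <cuda_runtime.h>
#include <cstdint>

#ifndef FLASH_NAMESPACE
#define FLASH_NAMESPACE flash  // assumed default; the repo may define this elsewhere
#endif

namespace FLASH_NAMESPACE {

struct QKV_params {
    using index_t = int64_t;
    // Device pointers to the query, key, and value tensors.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;
    // Strides between batches, rows (sequence positions), and heads.
    index_t q_batch_stride, k_batch_stride, v_batch_stride;
    index_t q_row_stride, k_row_stride, v_row_stride;
    index_t q_head_stride, k_head_stride, v_head_stride;
    // Number of query heads and key/value heads (they differ under GQA/MQA).
    int h, h_k;
};

struct ZeroHold_params : public QKV_params {
    // Device pointer to the zero-hold states used by the attention mask.
    void *__restrict__ zero_hold_ptr;
    // Strides for the zero-hold tensor.
    index_t zero_hold_batch_stride, zero_hold_head_stride, zero_hold_row_stride;
    // Whether a causal mask is applied.
    bool is_causal;
};

struct Flash_fwd_params : public ZeroHold_params {
    // Output tensor and per-row softmax log-sum-exp statistics.
    void *__restrict__ o_ptr;
    void *__restrict__ softmax_lse_ptr;
    // Dropout probability and softmax scaling factor.
    float p_dropout;
    float scale_softmax;
    // Philox RNG state so dropout is reproducible across reruns.
    uint64_t philox_seed, philox_offset;
};

struct Flash_bwd_params : public Flash_fwd_params {
    // Gradient of the output, plus buffers for dQ, dK, and dV.
    void *__restrict__ do_ptr;
    void *__restrict__ dq_ptr;
    void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;
};

// Launcher templates, instantiated per element type and head dimension.
template <typename T, int kHeadDim>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);

template <typename T, int kHeadDim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template <typename T, int kHeadDim>
void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

}  // namespace FLASH_NAMESPACE
```

The inheritance chain (QKV_params → ZeroHold_params → Flash_fwd_params → Flash_bwd_params) mirrors the description above: each pass reuses every field of the simpler structures without duplication. A caller would then pick a concrete instantiation at runtime, for example (again hypothetical, with head-dimension cutoffs chosen only for illustration):

```cpp
#include <cuda_fp16.h>

// Hypothetical runtime dispatch from a dynamic head_dim to the
// compile-time kHeadDim instantiations declared above.
void run_mha_fwd(FLASH_NAMESPACE::Flash_fwd_params &params,
                 cudaStream_t stream, int head_dim) {
    using FLASH_NAMESPACE::run_mha_fwd_;
    if (head_dim <= 64) {
        run_mha_fwd_<__half, 64>(params, stream);
    } else if (head_dim <= 128) {
        run_mha_fwd_<__half, 128>(params, stream);
    } else {
        run_mha_fwd_<__half, 256>(params, stream);
    }
}
```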

@LoserCheems LoserCheems merged commit 8890fe7 into main May 14, 2025