[COMMON/JAX] Support sliding window on THD format #1327
Signed-off-by: Reese Wang <[email protected]>
/te-ci jax L1
/te-ci
@@ -46,6 +46,30 @@ class AttnMaskType(Enum):
    CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK

    def is_causal(self):
Nice pythonic helpers!
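For readers without the diff open, helpers of this kind could look roughly like the sketch below. Only `is_causal` and the two bottom-right member names appear in the hunk above. `MaskType`, the other member names, the integer values, and the `is_padding` / `is_bottom_right` helpers are hypothetical stand-ins, not the code in this PR.

```python
from enum import Enum


class MaskType(Enum):
    """Hypothetical stand-in for AttnMaskType; the integer values are illustrative."""

    NO_MASK = 0
    PADDING_MASK = 1
    CAUSAL_MASK = 2
    PADDING_CAUSAL_MASK = 3
    CAUSAL_BOTTOM_RIGHT_MASK = 4
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5

    def is_causal(self):
        """True for every causal variant, top-left or bottom-right aligned."""
        return "CAUSAL" in self.name

    def is_padding(self):
        """True for variants that also mask padded positions (assumed helper)."""
        return "PADDING" in self.name

    def is_bottom_right(self):
        """True for variants aligned to the bottom-right corner (assumed helper)."""
        return "BOTTOM_RIGHT" in self.name
```

Call sites can then ask `mask_type.is_causal()` instead of comparing against a list of enum members.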
// sliding window
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
  (window_size_right == -1 || window_size_right == 0)) ||
 (cudnn_runtime_version >= 90200 &&
  ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
   ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
    (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
    ((!is_ragged && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
     (is_ragged && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
Could you factor the self-attn bottom_right combination into this condition as well?
(!is_ragged &&
 (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
  (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
   max_seqlen_q == max_seqlen_kv))) ||
(is_ragged &&
 (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
  (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
   max_seqlen_q == max_seqlen_kv)))
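To make the suggested condition easier to check by hand, here is a rough Python transcription of it. The function name and the plain-string mask names are assumptions for illustration; the real check is the C++ expression above.

```python
def causal_mask_allowed(mask_type, is_ragged, max_seqlen_q, max_seqlen_kv):
    """Mirror of the suggested mask-type condition (illustrative only).

    mask_type is a plain string here, standing in for NVTE_Mask_Type.
    """
    # Bottom-right alignment matches top-left only when q and kv lengths agree.
    self_attn = max_seqlen_q == max_seqlen_kv
    if is_ragged:
        # Ragged (THD) layouts need the padding-aware causal variants.
        return mask_type == "PADDING_CAUSAL" or (
            mask_type == "PADDING_CAUSAL_BOTTOM_RIGHT" and self_attn
        )
    # Non-ragged layouts use the plain causal variants.
    return mask_type == "CAUSAL" or (
        mask_type == "CAUSAL_BOTTOM_RIGHT" and self_attn
    )
```

With this shape, a bottom-right mask is accepted only in the self-attention case where `max_seqlen_q == max_seqlen_kv`, which is the combination the comment asks for.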
Other than that one comment, LGTM, thanks!
The changes to the code LGTM. The CI builds failed, so they need to be checked and re-run.
Description
Support THD format with sliding window
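As a reference for what the combination means, the sketch below builds the attention mask in plain Python for several sequences packed along the token axis, with a causal sliding window applied inside each sequence. It is a conceptual illustration, not the fused kernel: the function name is hypothetical, and it assumes the convention that `window_left = -1` means an unbounded left window and that the right window is 0.

```python
def packed_sliding_window_mask(seqlens, window_left):
    """Return mask[q][k] == True where query token q may attend to key token k.

    seqlens: lengths of the sequences packed back to back on the token axis.
    window_left: number of earlier tokens visible to a query; -1 means unbounded.
    The right window is fixed at 0, i.e. the mask is causal.
    """
    # Label every packed token with its segment index and position in the segment.
    seg, pos = [], []
    for s, n in enumerate(seqlens):
        seg.extend([s] * n)
        pos.extend(range(n))

    total = len(seg)
    mask = [[False] * total for _ in range(total)]
    for q in range(total):
        for k in range(total):
            same_segment = seg[q] == seg[k]
            causal = pos[k] <= pos[q]
            in_window = window_left < 0 or pos[q] - pos[k] <= window_left
            mask[q][k] = same_segment and causal and in_window
    return mask
```

For `seqlens=[3, 2]` and `window_left=1`, the last token of the first sequence sees only itself and its immediate predecessor, and no token sees across the sequence boundary.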
Type of change
Changes
is_ragged and BSHD||SBHD with !is_ragged in nvte_get_fused_attn_backend
AttnMaskType
Checklist: