-
Notifications
You must be signed in to change notification settings - Fork 70
check bcast interfering with tma #5556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
!test |
|
Review updated until commit de72855 Description
|
| Relevant files | |||||||
|---|---|---|---|---|---|---|---|
| Enhancement |
| ||||||
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Broadcast validation logic
|
Test failures
-
(Medium, 267)
NVFuser TMA broadcast assertion failures across TmaWarpSpecializedTest suitesTest Name GB200 H100 Source CombinedSchedulerTest.ThunderLayerNormBackward ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_0_M_128_N_1024_WarpSpecializedOnTIDyRegisterSharing_NoneStageSlicePosition_None_CpAsyncBulkTensorTile ❌ ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_0_M_500_N_1024_WarpSpecializedOnTIDyRegisterSharing_NoneStageSlicePosition_None_CpAsyncBulkTensorTile ❌ ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_1_M_128_N_2048_WarpSpecializedOnTIDyRegisterSharing_64_168StageSlicePosition_None_CpAsyncBulkTensorTile ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg1_M_500_N_1024_WarpSpecializedOnTIDyRegisterSharing_64_168StageSlicePosition_None_CpAsyncBulkTensorTile ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_1024_N_1024_WarpSpecializedOnTIDyRegisterSharing_64_168StageSlicePosition_None_CpAsyncBulkTensorTile ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_128_N_1024_PipelinedMBarrierForWAR_CpAsyncBulkTensorTile ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_128_N_2048_Pipelined_CpAsyncBulkTensorTile ❌ ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_500_N_2048_WarpSpecializedOnTIDyRegisterSharing_64_168StageSlicePosition_None_CpAsyncBulkTensorTile ❌ Link Hopper/TmaCircularBufferingTest.Matmul/stage_4_prefetch_0_M_1024_N_1024_PipelinedMBarrierForWAR_CpAsyncBulkTensorTile ❌ Link ... with 199 more test failures omitted. Check internal logs. -
(Medium, 2)
nvFuser internal TMA broadcast assertion failure in tests.python.test_schedule_opsTest Name GB200 H100 Source tests.python.test_schedule_ops.TestScheduleOps.test_var_mean_tma_user_schedule ❌ ❌ -
(Medium, 1)
NVFP4 grouped_mm numerical mismatch in test_with_id_model_indexerTest Name GB200 Source tests.python.direct.test_with_id_model_indexer.test_layout_op_and_cutlass_nvfp4_grouped_mm[out_dtype=torch.bfloat16-tokens_per_expert_neg_one=[115, 144, 8]-config=[1024, 128, 256]] ❌
|
!test |
|
!test |
e00a9ed to
de72855
Compare
|
!test |
Check broadcast interference with TMA loads
Case 1: Broadcast before TMA load (in gmem tensor logical domain)
Case 2: Broadcast after TMA load (downstream of TMA-loaded tensor)
validateTMAConsumerBroadcasts()Root Cause:
TMA auto-fills out-of-bounds accesses with zeros. When broadcast dimensions
are merged with non-broadcast and loaded with TMA, TMA
treats broadcast as a physical tile dimension and loads extra rows/columns
as zeros, breaking broadcast semantics (which should replicate values, not
fill zeros).
Testing:
See test
BroadcastDownstreamOfTMALoadfor detailed example showing thenumerical error when this restriction is violated.