Skip to content

[Dtype] Low-precision Blackwell Datatype Support #18027

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

Kathryn-cat
Copy link
Contributor

@Kathryn-cat Kathryn-cat commented May 31, 2025

This PR focuses on supporting FP4/FP8 data types introduced in Blackwell architectures (sm_100).

TVM nd array stores subbyte data types in compact format, thus two FP4 would be stored in 1 byte. The size calculator for array allocator is modified accordingly.


Subtype arithmetic

The type __nv_fp4_e2m1 from <cuda_fp4.h> is a tag type and does not support pointer arithmetic. Accordingly, the compiler does not support index operations on an array declared with __nv_fp4_e2m1 directly. If any index operations like arr[0] + arr[1] is desired, user should declare the array as vector type like __nv_fp4x2_e2m1.

For example, suppose user creates an array A of type __nv_fp4_e2m1 with values

[-1 2 0.5 -6 -6 -2 2 3 4 1 -3 4 -2 2...]

extern "C" __global__ void __launch_bounds__(32) add_kernel(__nv_fp4_e2m1* __restrict__ A, __nv_fp4_e2m1* __restrict__ C) {
  C[((int)threadIdx.x)] = (__nv_fp4_e2m1)(((half)A[((int)threadIdx.x)]) + ((half)B[((int)threadIdx.x)]));
}

Printing out values of A[0], A[1], ... will show

A[0]: 2.000000
A[1]: -6.000000
A[2]: -2.000000
A[3]: 3.000000

This is because __nv_fp4_e2m1 is only a tag type. When it advances pointer, it advance by 1-byte at a time, yielding the upper 4 bits in the packed memory buffer. As a result, we should avoid directly doing indexing on __nv_fp4_e2m1 for arithmetic operations.

If user passes in __nv_fp4_e2m1 nd array and perform indexing, we can convert it to __nv_fp4x2_e2m1 and recalculate the indices if possible, but this requires more careful handling in the lowering process.

Thus, the original corresponding test case in test_target_codegen_cuda_fp4.py is removed.

@Kathryn-cat
Copy link
Contributor Author

Kathryn-cat commented May 31, 2025

Thanks @DerrickYLJ for coauthoring it! Would you like to add yourself as coauthor in the PR?

@tqchen
Copy link
Member

tqchen commented Jun 2, 2025

cc @Hzfengsy

@Hzfengsy
Copy link
Member

Hzfengsy commented Jun 2, 2025

Thanks for reminding me, I will take a close look tomorrow. also cc @LeiWang1999

Co-authored-by: DerrickYLJ <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants