@@ -52,6 +52,8 @@ def per_token_cast_to_fp8(x: torch.Tensor):
 
 
 def cast_fp8_to_fp32(x_fp8: torch.Tensor, x_scales: torch.Tensor):
+    # TODO(shifangx): remove print after debugging
+    print(f"in cast_fp8_to_fp32, x_fp8.shape: {x_fp8.shape}, x_scales.shape: {x_scales.shape}")
     if x_fp8.numel() == 0:
         return x_fp8.to(torch.bfloat16)
     if x_scales.dtype == torch.int:
@@ -91,11 +93,21 @@ def int32_to_8floats_lookup(tensor: torch.Tensor, table: torch.Tensor) -> torch.
 
 
 def cast_nvfp4_to_fp32(x_nvfp4: torch.Tensor, x_scales: torch.Tensor, x_sf_scale: float, use_ue8m0_for_nvfp4_sf: bool = False):
+    # TODO(shifangx): remove print after debugging
+    print(f"in cast_nvfp4_to_fp32, x_nvfp4.shape: {x_nvfp4.shape}, x_scales.shape: {x_scales.shape}")
     NVFP4_TABLE = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1.0, -1.5, -2, -3, -4, -6], dtype=torch.float32, device='cuda')
     if use_ue8m0_for_nvfp4_sf:
         x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23
         x_scales = x_scales.view(dtype=torch.float)
     else:
+        # shape of x_scales: (32, 4, rm, 4, rk, l)
+        dim_0, dim_1, dim_2, dim_3, dim_4, dim_5 = x_scales.shape
+        assert dim_0 == 32 and dim_1 == 4 and dim_3 == 4, "x_scales must have shape (32, 4, rm, 4, rk, l)"
+        rm = dim_2
+        rk = dim_4
+        l = dim_5
+        x_scales = x_scales.view(dtype=torch.float8_e4m3fn).permute(5, 2, 0, 1, 4, 3)  # shape of x_scales: (l, rm, 32, 4, rk, 4)
+        x_scales = x_scales.reshape(l, rm * 32 * 4, rk * 4)  # shape of x_scales: (l, m, k)
         x_scales = x_scales.view(dtype=torch.float8_e4m3fn).to(torch.float32)
     x_sf_scale = x_sf_scale.view(*x_sf_scale.shape, 1)
     x_scales = x_scales * (1 / x_sf_scale)
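
A standalone sketch (not part of the diff) of what the two scale-decoding paths do: the new else branch reinterprets the swizzled (32, 4, rm, 4, rk, l) byte tensor as e4m3 and unpermutes it into (l, m, k) with m = 128 * rm and k = 4 * rk, while the ue8m0 path shifts an 8-bit exponent into the float32 exponent field. All names and sizes below are made up for illustration, and it assumes a PyTorch build with float8_e4m3fn support.

import torch

# Hypothetical sizes; real shapes come from the kernel's scale-factor layout,
# with m = 128 * rm and k = 4 * rk.
rm, rk, l = 2, 3, 1

# Raw swizzled scale bytes; values 0..126 are finite e4m3fn encodings.
packed = torch.randint(0, 127, (32, 4, rm, 4, rk, l), dtype=torch.uint8)

# Same transform as the added else branch: reinterpret as e4m3,
# bring (l, rm) to the front, then flatten to (l, m, k).
scales = packed.view(dtype=torch.float8_e4m3fn).permute(5, 2, 0, 1, 4, 3)
scales = scales.reshape(l, rm * 32 * 4, rk * 4).to(torch.float32)
assert scales.shape == (l, 128 * rm, 4 * rk)

# ue8m0 path for comparison: a biased 8-bit exponent e decodes to
# 2**(e - 127) once shifted into the float32 exponent field (values
# below 128 here, matching the int8 view used in the diff).
e = torch.tensor([120, 126, 127], dtype=torch.int8)
decoded = (e.to(torch.int) << 23).view(torch.float32)
# decoded == tensor([0.0078125, 0.5, 1.0])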