@@ -183,9 +183,17 @@ def to_nvfp4(
         if is_swizzled_scales:
             M, K = data_hp.shape[0], data_hp.shape[1]
             scale_shape = (M, K // block_size)
-            blockwise_scales = to_blocked(
-                blockwise_scales.view(scale_shape)
-            ).flatten()
+            # print(1, blockwise_scales.shape)
+            blockwise_scales = blockwise_scales.view(scale_shape)
+            # print(2, blockwise_scales.shape, blockwise_scales)
+            blockwise_scales = to_blocked(blockwise_scales)
+            # print(3, blockwise_scales.shape, blockwise_scales)
+
+            # match the shape of data_hp
+            scale_M = ceil_div(data_hp.shape[0], 128) * 128
+            scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 4
+            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+            # print(4, blockwise_scales.shape, blockwise_scales)

         return NVFP4Tensor(
             data_lp,
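The padding arithmetic above can be checked in isolation. A minimal sketch, assuming `block_size=16` and the 128x4 swizzled scale layout that `to_blocked` produces; the `ceil_div` here is a local stand-in for the library helper:

```python
# Minimal sketch of the padded scale shape, assuming block_size=16 and a
# 128x4 swizzle granularity; ceil_div is a local stand-in, not the library's.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

M, K, block_size = 100, 96, 16
n_scale_cols = K // block_size           # one e4m3 scale per 16-value block -> 6
scale_M = ceil_div(M, 128) * 128         # rows padded to a multiple of 128 -> 128
scale_K = ceil_div(n_scale_cols, 4) * 4  # scale columns padded to a multiple of 4 -> 8
assert (scale_M, scale_K) == (128, 8)
```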
@@ -220,6 +228,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)

+        # next: debug scale shape here
         data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
         scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
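For context, the reshape-and-multiply in this hunk is a standard blockwise dequantization broadcast; a self-contained sketch with plain tensors, none of the NVFP4 machinery:

```python
# Blockwise dequantization broadcast with plain tensors: each scale
# multiplies one block_size-long run of values along K.
import torch

M, K, block_size = 4, 32, 16
data_f32 = torch.randn(M, K)             # stand-in for the unpacked fp4 values
scales = torch.rand(M, K // block_size)  # stand-in for get_hp_scales()

out = data_f32.view(M, K // block_size, block_size) * scales.view(
    M, K // block_size, 1
)
out = out.view(M, K)
assert out.shape == (M, K)
```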
@@ -237,15 +246,17 @@ def get_hp_scales(self) -> torch.Tensor:
             torch.Tensor: Scales of the NVFP4Tensor
         """
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        print("is_transposed", is_transposed)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3

         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
+            # import pdb; pdb.set_trace()
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)

         return (
             scale_e4m3.to(self._orig_dtype)
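This hunk relies on `.t()` being a zero-copy view that is its own inverse, so the scales stored transposed by `nvfp4_t` (later in this diff) can be restored before un-swizzling. A quick sanity sketch with plain tensors:

```python
import torch

scales = torch.arange(12, dtype=torch.float32).view(3, 4)  # stand-in swizzled scales
scales_t = scales.t()            # what nvfp4_t would store on the transposed tensor

assert scales_t.data_ptr() == scales.data_ptr()  # a view, no copy
assert torch.equal(scales_t.t(), scales)         # .t() undoes itself
assert scales_t.t().is_contiguous()              # back to the layout from_blocked expects
```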
@@ -366,6 +377,7 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")

     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert x._scale_e4m3.is_contiguous(), "Only support contiguous scale for now"

     M, K = x.shape[0], x.shape[1]

@@ -407,7 +419,7 @@ def nvfp4_slice(func, types, args, kwargs):
             else None
         )

-        sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
+        sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start, end, 1)
         sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)

     elif dim == 1:
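Switching from `start_idx`/`end_idx` to the raw `start`/`end` assumes the scales are now kept as a 2D `(scale_M, scale_K)` tensor whose rows line up one-to-one with data rows; for swizzled scales that only holds when the slice boundaries fall on 128-row tiles. A hypothetical alignment check, not part of the diff:

```python
# Hypothetical helper, not in the diff: row-slice boundaries must land on
# the 128-row swizzle tiles for scale rows to track data rows one-to-one.
def is_tile_aligned(start: int, end: int, tile: int = 128) -> bool:
    return start % tile == 0 and end % tile == 0

assert is_tile_aligned(128, 256)
assert not is_tile_aligned(100, 200)
```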
@@ -452,20 +464,24 @@ def nvfp4_slice(func, types, args, kwargs):
             # Full width - no slicing needed
             sliced_scale = x._scale_e4m3
         else:
-            # Extract specific column blocks from each row block
-            # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
-            elements_per_row_block = n_col_blocks * elements_per_block
-
-            # Build list of slices to extract
-            slices_to_extract = []
-            for row_block in range(n_row_blocks):
-                row_start = row_block * elements_per_row_block
-                col_start = row_start + start_col_block * elements_per_block
-                col_end = row_start + end_col_block * elements_per_block
-                slices_to_extract.append(x._scale_e4m3[col_start:col_end])
-
-            # Concatenate all the slices
-            sliced_scale = torch.cat(slices_to_extract, dim=0)
+            if False:
+                # Extract specific column blocks from each row block
+                # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
+                elements_per_row_block = n_col_blocks * elements_per_block
+
+                # Build list of slices to extract
+                slices_to_extract = []
+                for row_block in range(n_row_blocks):
+                    row_start = row_block * elements_per_row_block
+                    col_start = row_start + start_col_block * elements_per_block
+                    col_end = row_start + end_col_block * elements_per_block
+                    slices_to_extract.append(x._scale_e4m3[col_start:col_end])
+
+                # Concatenate all the slices
+                sliced_scale = torch.cat(slices_to_extract, dim=0)
+            sliced_scale = aten.slice.Tensor(
+                x._scale_e4m3, dim, start_scale_col, end_scale_col, step
+            ).contiguous()

         # Slice the data tensor
         packed_start = None if start is None else start // 2
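The replacement collapses the per-row-block gather into a single dim-1 slice, which only works because the scales are now a 2D `(scale_M, scale_K)` tensor. A sketch of the index mapping, recomputing `start_scale_col`/`end_scale_col` illustratively (the diff's own values are computed elsewhere in the file):

```python
# Illustrative recomputation of the scale-column bounds for a K slice,
# assuming block_size=16 and 4-scale-column swizzle tiles.
block_size = 16
start, end = 64, 192                   # element slice along K

start_scale_col = start // block_size  # 4
end_scale_col = end // block_size      # 12

# Tiles are 4 scale columns (64 elements) wide, so bounds must be 4-aligned.
assert start_scale_col % 4 == 0 and end_scale_col % 4 == 0
```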
@@ -537,7 +553,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3,
+        old._scale_e4m3.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
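Transposing the scales alongside the packed data keeps the two tensors' orientations in sync without moving memory; a sketch of the metadata-only effect, and of the layout asserts this enables in `_addmm_nvfp4_dispatch` below:

```python
import torch

qdata = torch.zeros(4, 8, dtype=torch.uint8)   # stand-in packed fp4 payload
scales = torch.zeros(4, 2, dtype=torch.uint8)  # stand-in swizzled scales

qdata_t, scales_t = qdata.t(), scales.t()      # views only, storage is shared
assert qdata_t.data_ptr() == qdata.data_ptr()
assert scales_t.t().is_contiguous()            # the new assert on b's scales holds
```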
@@ -577,6 +593,8 @@ def _addmm_nvfp4_dispatch(
     """
     assert a.qdata.is_contiguous()
     assert b.qdata.t().is_contiguous()
+    assert a._scale_e4m3.is_contiguous()
+    assert b._scale_e4m3.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
@@ -615,7 +633,7 @@ def _addmm_nvfp4_dispatch(
         a.qdata.view(torch.float4_e2m1fn_x2),
         b.qdata.view(torch.float4_e2m1fn_x2),
         a_scale_blocked.view(torch.float8_e4m3fn),
-        b_scale_blocked.view(torch.float8_e4m3fn),
+        b_scale_blocked.t().view(torch.float8_e4m3fn),
         bias=None if should_add_bias_separately else bias,
         out_dtype=a._orig_dtype,
         # scale_result=scale_result,  # Not supported yet
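Because the weight reaches the matmul as a transposed NVFP4Tensor, its swizzled scales arrive as a transposed view; the added `.t()` undoes that view so the kernel is handed the original contiguous blocked buffer, assuming that contiguous layout is what `torch._scaled_mm` expects. A sketch of the layout restore:

```python
import torch

b_scale = torch.zeros(8, 4, dtype=torch.uint8)  # stand-in blocked scales
b_scale_blocked = b_scale.t()                   # as carried on the transposed weight

restored = b_scale_blocked.t()
assert restored.is_contiguous()                   # contiguous again for the kernel
assert restored.data_ptr() == b_scale.data_ptr()  # still no copy
```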