@@ -198,11 +198,11 @@ def scaler_index_map(b, i, layer_ref, *_):
198
198
ks_bp = (None , 1 , bk )
199
199
200
200
in_specs = [
201
- pl .BlockSpec (q_index_map , q_bp ),
202
- pl .BlockSpec (kv_index_map , kv_bp ),
203
- pl .BlockSpec (kv_index_map , kv_bp ),
204
- pl .BlockSpec (scaler_index_map , ks_bp ),
205
- pl .BlockSpec (scaler_index_map , ks_bp ),
201
+ pl .BlockSpec (index_map = q_index_map , block_shape = q_bp ),
202
+ pl .BlockSpec (index_map = kv_index_map , block_shape = kv_bp ),
203
+ pl .BlockSpec (index_map = kv_index_map , block_shape = kv_bp ),
204
+ pl .BlockSpec (index_map = scaler_index_map , block_shape = ks_bp ),
205
+ pl .BlockSpec (index_map = scaler_index_map , block_shape = ks_bp ),
206
206
]
207
207
inputs = (
208
208
start ,
@@ -229,9 +229,15 @@ def scaler_index_map(b, i, layer_ref, *_):
229
229
num_scalar_prefetch = 5 ,
230
230
in_specs = in_specs ,
231
231
out_specs = [
232
- pl .BlockSpec (q_index_map , (None , time , head_dim )),
233
- pl .BlockSpec (q_index_map , (None , time , head_dim )),
234
- pl .BlockSpec (q_index_map , (None , time , head_dim )),
232
+ pl .BlockSpec (
233
+ index_map = q_index_map , block_shape = (None , time , head_dim )
234
+ ),
235
+ pl .BlockSpec (
236
+ index_map = q_index_map , block_shape = (None , time , head_dim )
237
+ ),
238
+ pl .BlockSpec (
239
+ index_map = q_index_map , block_shape = (None , time , head_dim )
240
+ ),
235
241
],
236
242
grid = (batch_size , seq_len // bk ),
237
243
),
@@ -397,11 +403,14 @@ def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_):
397
403
ks_bp = (None , 1 , bk )
398
404
399
405
in_specs = [
400
- pl .BlockSpec (lambda b , i , * _ : (b , 0 , 0 ), (None , time , head_dim )), # q
401
- pl .BlockSpec (kv_index_map , kv_bp ), # k
402
- pl .BlockSpec (kv_index_map , kv_bp ), # v
403
- pl .BlockSpec (kv_scale_index_map , ks_bp ), # k_scaler
404
- pl .BlockSpec (kv_scale_index_map , ks_bp ), # v_scaler
406
+ pl .BlockSpec (
407
+ index_map = lambda b , i , * _ : (b , 0 , 0 ),
408
+ block_shape = (None , time , head_dim ),
409
+ ), # q
410
+ pl .BlockSpec (index_map = kv_index_map , block_shape = kv_bp ), # k
411
+ pl .BlockSpec (index_map = kv_index_map , block_shape = kv_bp ), # v
412
+ pl .BlockSpec (index_map = kv_scale_index_map , block_shape = ks_bp ), # k_scaler
413
+ pl .BlockSpec (index_map = kv_scale_index_map , block_shape = ks_bp ), # v_scaler
405
414
]
406
415
407
416
inputs = (
@@ -430,9 +439,18 @@ def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_):
430
439
num_scalar_prefetch = 6 ,
431
440
in_specs = in_specs ,
432
441
out_specs = [
433
- pl .BlockSpec (lambda b , * _ : (b , 0 , 0 ), (None , time , head_dim )),
434
- pl .BlockSpec (lambda b , * _ : (b , 0 , 0 ), (None , time , head_dim )),
435
- pl .BlockSpec (lambda b , * _ : (b , 0 , 0 ), (None , time , head_dim )),
442
+ pl .BlockSpec (
443
+ index_map = lambda b , * _ : (b , 0 , 0 ),
444
+ block_shape = (None , time , head_dim ),
445
+ ),
446
+ pl .BlockSpec (
447
+ index_map = lambda b , * _ : (b , 0 , 0 ),
448
+ block_shape = (None , time , head_dim ),
449
+ ),
450
+ pl .BlockSpec (
451
+ index_map = lambda b , * _ : (b , 0 , 0 ),
452
+ block_shape = (None , time , head_dim ),
453
+ ),
436
454
],
437
455
grid = (batch_size , seq_len // bk ),
438
456
),
0 commit comments