@@ -198,11 +198,11 @@ def scaler_index_map(b, i, layer_ref, *_):
198
198
ks_bp = (None , 1 , bk )
199
199
200
200
in_specs = [
201
- pl .BlockSpec (q_index_map , q_bp ),
202
- pl .BlockSpec (kv_index_map , kv_bp ),
203
- pl .BlockSpec (kv_index_map , kv_bp ),
204
- pl .BlockSpec (scaler_index_map , ks_bp ),
205
- pl .BlockSpec (scaler_index_map , ks_bp ),
201
+ pl .BlockSpec (index_map = q_index_map , block_shape = q_bp ),
202
+ pl .BlockSpec (index_map = kv_index_map , block_shape = kv_bp ),
203
+ pl .BlockSpec (index_map = kv_index_map , block_shape = kv_bp ),
204
+ pl .BlockSpec (index_map = scaler_index_map , block_shape = ks_bp ),
205
+ pl .BlockSpec (index_map = scaler_index_map , block_shape = ks_bp ),
206
206
]
207
207
inputs = (
208
208
start ,
@@ -229,9 +229,9 @@ def scaler_index_map(b, i, layer_ref, *_):
229
229
num_scalar_prefetch = 5 ,
230
230
in_specs = in_specs ,
231
231
out_specs = [
232
- pl .BlockSpec (q_index_map , (None , time , head_dim )),
233
- pl .BlockSpec (q_index_map , (None , time , head_dim )),
234
- pl .BlockSpec (q_index_map , (None , time , head_dim )),
232
+ pl .BlockSpec (index_map = q_index_map , block_shape = (None , time , head_dim )),
233
+ pl .BlockSpec (index_map = q_index_map , block_shape = (None , time , head_dim )),
234
+ pl .BlockSpec (index_map = q_index_map , block_shape = (None , time , head_dim )),
235
235
],
236
236
grid = (batch_size , seq_len // bk ),
237
237
),
@@ -397,11 +397,11 @@ def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_):
397
397
ks_bp = (None , 1 , bk )
398
398
399
399
in_specs = [
400
- pl .BlockSpec (lambda b , i , * _ : (b , 0 , 0 ), (None , time , head_dim )), # q
401
- pl .BlockSpec (kv_index_map , kv_bp ), # k
402
- pl .BlockSpec (kv_index_map , kv_bp ), # v
403
- pl .BlockSpec (kv_scale_index_map , ks_bp ), # k_scaler
404
- pl .BlockSpec (kv_scale_index_map , ks_bp ), # v_scaler
400
+ pl .BlockSpec (index_map = lambda b , i , * _ : (b , 0 , 0 ), block_shape = (None , time , head_dim )), # q
401
+ pl .BlockSpec (index_map = kv_index_map , block_shape = kv_bp ), # k
402
+ pl .BlockSpec (index_map = kv_index_map , block_shape = kv_bp ), # v
403
+ pl .BlockSpec (index_map = kv_scale_index_map , block_shape = ks_bp ), # k_scaler
404
+ pl .BlockSpec (index_map = kv_scale_index_map , block_shape = ks_bp ), # v_scaler
405
405
]
406
406
407
407
inputs = (
@@ -430,9 +430,9 @@ def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_):
430
430
num_scalar_prefetch = 6 ,
431
431
in_specs = in_specs ,
432
432
out_specs = [
433
- pl .BlockSpec (lambda b , * _ : (b , 0 , 0 ), (None , time , head_dim )),
434
- pl .BlockSpec (lambda b , * _ : (b , 0 , 0 ), (None , time , head_dim )),
435
- pl .BlockSpec (lambda b , * _ : (b , 0 , 0 ), (None , time , head_dim )),
433
+ pl .BlockSpec (index_map = lambda b , * _ : (b , 0 , 0 ), block_shape = (None , time , head_dim )),
434
+ pl .BlockSpec (index_map = lambda b , * _ : (b , 0 , 0 ), block_shape = (None , time , head_dim )),
435
+ pl .BlockSpec (index_map = lambda b , * _ : (b , 0 , 0 ), block_shape = (None , time , head_dim )),
436
436
],
437
437
grid = (batch_size , seq_len // bk ),
438
438
),
0 commit comments