@@ -1389,29 +1389,44 @@ def init_kv_cache(
1389
1389
]
1390
1390
1391
1391
def cuda_graph_warmup (self , bs : int , max_s : int , max_bt : int ):
1392
- input_ids = torch .zeros (bs , dtype = torch .int64 , device = self .device )
1393
- position_ids = torch .zeros (bs , dtype = torch .int32 , device = self .device )
1394
- slots = torch .arange (bs , dtype = torch .int64 , device = self .device )
1392
+ max_bs = max (self .cuda_graphs .keys ()) if self .cuda_graphs else None
1395
1393
input_lengths = [max_s ] * bs
1396
1394
cache_lengths = [0 ] * bs
1397
- input_lengths_tensor = (
1398
- torch .ones (bs , dtype = torch .int32 , device = self .device ) * max_s
1399
- )
1400
- cache_lengths_tensor = torch .zeros (bs , dtype = torch .int32 , device = self .device )
1401
- block_tables = torch .arange (
1402
- max_bt , dtype = torch .int32 , device = self .device
1403
- ).repeat (bs )
1404
- block_tables = block_tables .reshape ((bs , max_bt ))
1395
+ if max_bs is None :
1396
+ input_ids = torch .zeros (bs , dtype = torch .int64 , device = self .device )
1397
+ position_ids = torch .zeros (bs , dtype = torch .int32 , device = self .device )
1398
+ slots = torch .arange (bs , dtype = torch .int64 , device = self .device )
1399
+ input_lengths_tensor = (
1400
+ torch .ones (bs , dtype = torch .int32 , device = self .device ) * max_s
1401
+ )
1402
+ cache_lengths_tensor = torch .zeros (
1403
+ bs , dtype = torch .int32 , device = self .device
1404
+ )
1405
+ block_tables = torch .arange (
1406
+ max_bt , dtype = torch .int32 , device = self .device
1407
+ ).repeat (bs )
1408
+ block_tables = block_tables .reshape ((bs , max_bt ))
1409
+ if ATTENTION == "flashinfer" :
1410
+ block_tables = block_tables_to_ragged (
1411
+ block_tables = block_tables ,
1412
+ input_lengths = input_lengths ,
1413
+ cache_lengths = cache_lengths ,
1414
+ input_lengths_tensor = input_lengths_tensor ,
1415
+ cache_lengths_tensor = cache_lengths_tensor ,
1416
+ max_current_length = max_s ,
1417
+ )
1418
+ else :
1419
+ input_ids = self .cuda_graphs [max_bs ]["input_ids" ][:bs ]
1420
+ position_ids = self .cuda_graphs [max_bs ]["position_ids" ][:bs ]
1421
+ if ATTENTION == "flashinfer" :
1422
+ block_tables = self .cuda_graphs [max_bs ]["block_tables" ][: bs * max_bt ]
1423
+ else :
1424
+ block_tables = self .cuda_graphs [max_bs ]["block_tables" ][:bs ]
1425
+ slots = self .cuda_graphs [max_bs ]["slots" ][:bs ]
1426
+ input_lengths_tensor = self .cuda_graphs [max_bs ]["input_lengths" ][:bs ]
1427
+ cache_lengths_tensor = self .cuda_graphs [max_bs ]["cache_lengths" ][:bs ]
1405
1428
1406
1429
if ATTENTION == "flashinfer" :
1407
- block_tables = block_tables_to_ragged (
1408
- block_tables = block_tables ,
1409
- input_lengths = input_lengths ,
1410
- cache_lengths = cache_lengths ,
1411
- input_lengths_tensor = input_lengths_tensor ,
1412
- cache_lengths_tensor = cache_lengths_tensor ,
1413
- max_current_length = max_s ,
1414
- )
1415
1430
from text_generation_server .layers .attention .flashinfer import (
1416
1431
create_decode_state_cuda_graphs ,
1417
1432
)
0 commit comments