Commit 550a5ef

Saving some VRAM.
- 8B on 4xL4 with attention=flashdecoding: 4.28GB left before, 4.32GB left after, so about 40MB saved.
- Effect not as visible with attention=flashinfer and n_shard=1. I suspect it's linked to the torch allocator.
1 parent: d471805

File tree: 1 file changed

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 34 additions & 19 deletions
@@ -1389,29 +1389,44 @@ def init_kv_cache(
         ]
 
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
-        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
+        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
         input_lengths = [max_s] * bs
         cache_lengths = [0] * bs
-        input_lengths_tensor = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        block_tables = torch.arange(
-            max_bt, dtype=torch.int32, device=self.device
-        ).repeat(bs)
-        block_tables = block_tables.reshape((bs, max_bt))
+        if max_bs is None:
+            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            slots = torch.arange(bs, dtype=torch.int64, device=self.device)
+            input_lengths_tensor = (
+                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+            )
+            cache_lengths_tensor = torch.zeros(
+                bs, dtype=torch.int32, device=self.device
+            )
+            block_tables = torch.arange(
+                max_bt, dtype=torch.int32, device=self.device
+            ).repeat(bs)
+            block_tables = block_tables.reshape((bs, max_bt))
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=input_lengths,
+                    cache_lengths=cache_lengths,
+                    input_lengths_tensor=input_lengths_tensor,
+                    cache_lengths_tensor=cache_lengths_tensor,
+                    max_current_length=max_s,
+                )
+        else:
+            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
+            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
+            if ATTENTION == "flashinfer":
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
+            else:
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
+            slots = self.cuda_graphs[max_bs]["slots"][:bs]
+            input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
+            cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
 
         if ATTENTION == "flashinfer":
-            block_tables = block_tables_to_ragged(
-                block_tables=block_tables,
-                input_lengths=input_lengths,
-                cache_lengths=cache_lengths,
-                input_lengths_tensor=input_lengths_tensor,
-                cache_lengths_tensor=cache_lengths_tensor,
-                max_current_length=max_s,
-            )
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
            )
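
The saving comes from the new else branch: once the largest batch size has been captured, smaller warmup runs take [:bs] views of the tensors already stored in self.cuda_graphs[max_bs] instead of allocating fresh ones. Below is a minimal, self-contained sketch of that pattern; GraphCache and warmup_buffers are hypothetical names used only for illustration, and it assumes (as the slicing requires) that the largest batch size is warmed up first.

# Sketch only: GraphCache / warmup_buffers are made-up names, not TGI APIs.
import torch


class GraphCache:
    """Per-batch-size warmup buffers; smaller sizes reuse views of the largest."""

    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.graphs = {}  # bs -> dict of tensors, analogous to self.cuda_graphs
        self.device = device

    def warmup_buffers(self, bs: int) -> dict:
        max_bs = max(self.graphs.keys()) if self.graphs else None
        if max_bs is None:
            # First warmup (largest batch size): allocate real buffers once.
            buffers = {
                "input_ids": torch.zeros(bs, dtype=torch.int64, device=self.device),
                "slots": torch.arange(bs, dtype=torch.int64, device=self.device),
            }
        else:
            # Later warmups (smaller batch sizes): take [:bs] views of the
            # existing buffers, so no new device memory is requested.
            buffers = {
                "input_ids": self.graphs[max_bs]["input_ids"][:bs],
                "slots": self.graphs[max_bs]["slots"][:bs],
            }
        self.graphs[bs] = buffers
        return buffers


cache = GraphCache()
big = cache.warmup_buffers(128)   # allocates
small = cache.warmup_buffers(16)  # reuses: views share storage with `big`
assert small["input_ids"].data_ptr() == big["input_ids"].data_ptr()

Because the sliced views share storage with the original buffers, later CUDA graph captures record pointers into memory that is already resident, which would explain why the amount saved depends on the attention backend and on how the torch caching allocator pools those allocations.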
