add prefix flash attn
isky-cd committed Jun 3, 2024
1 parent d63e068 commit beddf23
Showing 11 changed files with 717 additions and 145 deletions.
117 changes: 67 additions & 50 deletions colossalai/inference/batch_bucket.py
@@ -119,7 +119,7 @@ def batch_token_ids(self) -> List[List[int]]:
out.append(seq.input_token_id + seq.output_token_id)
return out

def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
def streamingllm_update_batch(self):
"""
Update sequence_lengths and block_tables when it is necessary to swap out a block.
"""
@@ -132,9 +132,9 @@ def streamingllm_update_batch(self, start_token_size: int, generated_token_size:
block_tables_list = self._block_tables.tolist()
for batch_id in range(self.current_batch_size):
# We assume that the start tokens occupy the entire first block.
if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1:
if self.cache_manager.check_block_full(block_tables_list[batch_id][-1]):
need_update = True
sequence_lengths_list[batch_id] = start_token_size + generated_token_size
sequence_lengths_list[batch_id] = sequence_lengths_list[batch_id] - self.block_size
block_id = block_tables_list[batch_id].pop(1)
updated_block_ids.append(block_id)
block_tables_list[batch_id].append(-1)
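For intuition, here is a minimal, self-contained sketch of the eviction rule implemented above. It is illustrative only, not the ColossalAI implementation, and uses seq_len % block_size == 0 as a stand-in for cache_manager.check_block_full:

def evict_oldest_nonsink_block(block_table, seq_len, block_size):
    # Block 0 holds the "sink" start tokens and is never evicted.
    # Once the last block fills up, drop the oldest non-sink block and
    # leave a -1 placeholder so a fresh block can be allocated later.
    if seq_len > block_size and seq_len % block_size == 0:
        evicted = block_table.pop(1)
        block_table.append(-1)
        return block_table, seq_len - block_size, evicted
    return block_table, seq_len, None

table, length, freed = evict_oldest_nonsink_block([0, 1, 2, 3], seq_len=16, block_size=4)
# table == [0, 2, 3, -1], length == 12, freed == 1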
@@ -181,46 +181,13 @@ def _make_compact(self) -> None:
self._block_tables_helper.fill_(-1)
self._current_batch_size = valid_num

def add_seq(
self,
seq: Sequence,
alloc_block_table: torch.Tensor = None,
alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None,
) -> Union[torch.Tensor, None]:
"""Add a single sequence to the batch.
User could opt to provide either a block table or a function to allocate block tables.
Args:
seq (Sequence): The sequence to be added to the batch
alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence
alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence,
which is expected to reserve blocks and update status of kv-cache manager.
Returns:
block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager.
None if the sequence cannot be added.
"""
block_table = None
# TODO might consider sorting by length
if self._current_batch_size < self.max_batch_size:
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size
self._sequence_lengths[self._current_batch_size] = seq.sentence_len
# NOTE the added seq still require block table allocation by kvcache manager
block_table = self._block_tables[self._current_batch_size - 1]
if alloc_block_table is not None:
# copy block ids from provided block tables
self._block_tables[self._current_batch_size - 1] = alloc_block_table
elif alloc_block_table_fn:
alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item())
self._current_batch_size += 1
return block_table

def add_seqs(
self,
seqs: List[Sequence],
alloc_block_tables: torch.Tensor = None,
alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None,
need_reused_block_table: bool = False,
streamingllm_prompt_len: List[int] = None,
) -> Union[torch.Tensor, None]:
"""Add a list of sequences to the batch.
Users may provide either block tables or a function to allocate block tables.
@@ -230,7 +197,8 @@ def add_seqs(
alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence
alloc_block_table_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences,
which is expected to reserve blocks and update status of kv-cache manager.
need_reused_block_table (bool): Whether to reuse cached block tables.
streamingllm_prompt_len (List[int]): The per-sequence prompt lengths used when streamingLLM is enabled.
Returns:
block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager.
None if the sequences cannot be added.
@@ -243,15 +211,24 @@ def add_seqs(
num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs))
block_tables = None
if num_seqs_to_add > 0:
# NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
for i, seq in enumerate(seqs[:num_seqs_to_add]):
if need_reused_block_table:
block_tables[i] = seq.block_table
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size + i
# TODO external (rename): modify Sequence.sentence_len to seq_len
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
# NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
if need_reused_block_table:
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor(
[streamingllm_prompt_len[seq_id] for seq_id in range(num_seqs_to_add)], dtype=torch.int32
)
else:
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
if alloc_block_tables is not None:
# copy block ids from provided block tables
self._block_tables[
@@ -269,20 +246,32 @@ def add_seqs(
return block_tables
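The length-selection branch above can be read in isolation: when cached block tables are reused, the per-sequence prompt lengths come from streamingllm_prompt_len; otherwise they are read off the sequences themselves. A standalone sketch of that choice follows; the _Seq stand-in is an assumption for illustration, not colossalai's Sequence class:

import torch
from dataclasses import dataclass

@dataclass
class _Seq:  # minimal stand-in for a sequence with a sentence_len attribute
    sentence_len: int

def pick_lengths(seqs, need_reused_block_table, streamingllm_prompt_len=None):
    # Mirrors the branch in add_seqs: reused sequences take their lengths
    # from the cached prompt lengths, fresh sequences from sentence_len.
    if need_reused_block_table:
        return torch.tensor(streamingllm_prompt_len[: len(seqs)], dtype=torch.int32)
    return torch.tensor([s.sentence_len for s in seqs], dtype=torch.int32)

print(pick_lengths([_Seq(5), _Seq(9)], need_reused_block_table=False))  # tensor([5, 9], dtype=torch.int32)
print(pick_lengths([_Seq(5), _Seq(9)], True, streamingllm_prompt_len=[12, 12]))  # tensor([12, 12], dtype=torch.int32)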

def pop_seq_update_batch(
self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
self,
request_id: int,
free_block_table_fn: Callable[[torch.Tensor], None] = None,
update_cached_sequences_fn: Callable[[Sequence, torch.Tensor], None] = None,
) -> Tuple[Sequence, Union[torch.Tensor, None]]:
"""Pop a single sequence by id from the batch, and update the batch bucket status.
Args:
request_id (int): The uid of the sequence
free_block_table_fn (Callable): The function to free the block table of a sequence;
if not provided, the block table must be released manually after calling this method.
update_cached_sequences_fn (Callable[[Sequence, torch.Tensor], None]): When streamingLLM is enabled, finished
sequences are saved to cached_sequences_dict; this function performs that update.
Returns:
A tuple of: seq (Sequence): The target sequence,
and block_table (torch.Tensor): the block table of the target sequence indicating its corresponding blocks;
None if the sequence is not found or free_block_table_fn is provided.
"""

# When update_cached_sequences_fn is not None, it indicates that streamingLLM is enabled.
# In that case, completed sequences are stored in cached_sequences_dict and are not
# released within the current function.
assert (
free_block_table_fn is None or update_cached_sequences_fn is None
), f"free_block_table_fn={free_block_table_fn} and update_cached_sequences_fn={update_cached_sequences_fn} cannot be passed simultaneously."

seq: Sequence = self._sequences_dict.get(request_id)
block_table = None
if seq is not None:
@@ -305,15 +294,23 @@ def pop_seq_update_batch(
if free_block_table_fn:
free_block_table_fn(self._block_tables[seq_b_idx])
else:
block_table = self._block_tables[seq_b_idx].detach().clone()
if update_cached_sequences_fn:
# When enabling streamingllm, save previous inference sequences.
update_cached_sequences_fn(seq, self._block_tables[seq_b_idx])
else:
block_table = self._block_tables[seq_b_idx].detach().clone()
# replace block table of the target seq with that of the last seq in the batch
self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]
self._block_tables[last_seq_b_idx].fill_(-1)
else:
if free_block_table_fn:
free_block_table_fn(self._block_tables[0])
else:
block_table = self._block_tables[0].detach().clone()
if update_cached_sequences_fn:
# When enabling streamingllm, save previous inference sequences.
update_cached_sequences_fn(seq, self._block_tables[0])
else:
block_table = self._block_tables[0].detach().clone()
self._sequence_lengths[0].fill_(0)
self._block_tables[0].fill_(-1)
self._sequences_indexes.pop(request_id)
@@ -376,17 +373,29 @@ def pop_n_seqs(
return seqs, block_tables

def pop_finished(
self, free_block_table_fn: Callable[[torch.Tensor], None] = None
self,
free_block_table_fn: Callable[[torch.Tensor], None] = None,
update_cached_sequences_fn: Callable[[Sequence, torch.Tensor], None] = None,
) -> Tuple[List[Sequence], List[torch.Tensor]]:
"""Pop finished sequences in the batch and a list of block tables of the finished sequences,
if free_block_table_fn is not provided.
Args:
free_block_table_fn (Callable): The function to free the block table of a single sequence
update_cached_sequences_fn (Callable[[Sequence, torch.Tensor], None]): When streamingLLM is enabled, finished
sequences are saved to cached_sequences_dict; this function performs that update.
Returns:
A tuple of: finished_seqs (List[Sequence]): The finished sequences,
and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
"""

# When update_cached_sequences_fn is not None, it indicates that streamingLLM is enabled.
# In that case, completed sequences are stored in cached_sequences_dict and are not
# released within the current function.
assert (
free_block_table_fn is None or update_cached_sequences_fn is None
), f"free_block_table_fn={free_block_table_fn} and update_cached_sequences_fn={update_cached_sequences_fn} cannot be passed simultaneously."

finished_seqs = []
finished_block_tables = []
for seq in self._sequences_dict.values():
@@ -397,7 +406,7 @@ def pop_finished(
# For now, the performance difference is not significant, so we use the first method to pop seqs.
# Precise evaluations to be done.
for seq in finished_seqs:
_, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
_, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn, update_cached_sequences_fn)
if block_table is not None:
finished_block_tables.append(block_table)
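Callers pass at most one of the two callbacks: free_block_table_fn to release KV-cache blocks immediately, or, with streamingLLM enabled, update_cached_sequences_fn to stash finished sequences for later prefix reuse. Below is a sketch of what a caller-side cache callback could look like; cached_sequences_dict and the callback body are assumptions, not the engine's actual code:

import torch

cached_sequences_dict = {}

def update_cached_sequences_fn(seq, block_table: torch.Tensor):
    # Keep the finished sequence together with a copy of its block table,
    # so a follow-up request can reuse the still-resident KV blocks.
    cached_sequences_dict[seq.request_id] = (seq, block_table.detach().clone())

# Hypothetical call site:
# finished_seqs, _ = batch.pop_finished(update_cached_sequences_fn=update_cached_sequences_fn)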

@@ -480,6 +489,14 @@ def merge(self, other: "BatchBucket") -> List[int]:

return unmerged_ids

def has_reused_seqs(self) -> bool:
    """Check whether any sequence in the batch reuses a cached block table."""
    for seq in self.seqs_li:
        if seq.block_table is not None:
            return True
    return False

########## The following methods are expected to be used in modeling ###########

# For compatibility.
30 changes: 22 additions & 8 deletions colossalai/inference/config.py
@@ -54,6 +54,7 @@ class InputMetaData(RPC_PARAM):
Args:
block_tables (torch.Tensor, optional): Sequences' BlockTables Defaults to None.
sequence_lengths (torch.Tensor): A tensor containing sequence lengths.
current_prompt_lengths (torch.Tensor): A tensor containing current prompt lengths.
fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None.
batch_size (int, optional): The current batch size. Defaults to 64.
is_prompts (bool, optional): Indicates whether prefill or decoding. Defaults to False(decoding).
@@ -70,6 +71,7 @@ class InputMetaData(RPC_PARAM):

block_tables: torch.Tensor = None
sequence_lengths: torch.Tensor = None
current_prompt_lengths: torch.Tensor = None
fd_inter_tensor: FDIntermTensors = None
batch_size: int = 64 # current_batch_size
is_prompts: bool = False
@@ -89,6 +91,7 @@ def to_rpc_param(self) -> Dict[str, any]:
return {
"block_tables": self.block_tables.tolist(),
"sequence_lengths": self.sequence_lengths.tolist(),
"current_prompt_lengths": self.current_prompt_lengths.tolist(),
"batch_size": self.batch_size,
"is_prompts": self.is_prompts,
"use_cuda_kernel": self.use_cuda_kernel,
@@ -117,6 +120,9 @@ def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData":
sequence_lengths=torch.tensor(
rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
),
current_prompt_lengths=torch.tensor(
rpc_dict["current_prompt_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
),
batch_size=rpc_dict["batch_size"],
is_prompts=rpc_dict["is_prompts"],
use_cuda_kernel=rpc_dict["use_cuda_kernel"],
@@ -134,6 +140,7 @@ def __repr__(self) -> str:
return (
f"InputMetaData(block_tables={self.block_tables}, "
f"sequence_lengths={self.sequence_lengths}, "
f"current_prompt_lengths={self.current_prompt_lengths}, "
f"fd_inter_tensor={self.fd_inter_tensor}, "
f"batch_size={self.batch_size}, "
f"is_prompts={self.is_prompts}, "
@@ -280,14 +287,21 @@ def _verify_config(self) -> None:
"{input_text}" in self.prompt_template
), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"

assert (
self.start_token_size <= self.block_size
), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}."
assert (
self.generated_token_size % self.block_size == 0
), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}."
# We assume that start_token_size occupies one block.
self.start_token_size = self.block_size
if self.enable_streamingllm:
assert (
not self.use_cuda_graph
), "We currently do not support using streamingLLM and CUDA graph simultaneously."
assert (
self.max_input_len <= self.generated_token_size
), f"When enabling streamingLLM, max_input_len={self.max_input_len} must be less than or equal to generated_token_size={self.generated_token_size}."
assert (
self.start_token_size <= self.block_size
), f"According to the paper https://arxiv.org/pdf/2309.17453, a start_token_size greater than 4 has little impact on inference performance. Therefore, we require start_token_size to be less than or equal to the block_size={self.block_size}, but got {self.start_token_size}."
assert (
self.generated_token_size % self.block_size == 0
), f"The generated_token_size must be a multiple of the block_size, but got generated_token_size={self.generated_token_size}."
# We assume that start_token_size occupies one block.
self.start_token_size = self.block_size
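Taken together, the checks above admit configurations like the following hypothetical one (field names follow the asserts in this hunk; the concrete values are illustrative, not recommended defaults):

from colossalai.inference.config import InferenceConfig

# start_token_size <= block_size, generated_token_size % block_size == 0,
# max_input_len <= generated_token_size, and CUDA graph disabled.
config = InferenceConfig(
    enable_streamingllm=True,
    block_size=16,
    start_token_size=4,        # rounded up internally to one full block (16)
    generated_token_size=512,
    max_input_len=512,
    use_cuda_graph=False,
)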

def to_generation_config(self, model_config) -> GenerationConfig:
meta_config = {
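The new current_prompt_lengths field travels through the same serialization round trip as the other integer tensors in InputMetaData: to_rpc_param flattens tensors into plain lists for the RPC payload, and from_rpc_param rebuilds them on the target device. A minimal sketch of that pattern in isolation (not the actual RPC_PARAM machinery):

import torch

def tensor_to_rpc(t: torch.Tensor) -> list:
    # Tensors are shipped over RPC as plain Python lists.
    return t.tolist()

def tensor_from_rpc(values: list, device: str = "cpu") -> torch.Tensor:
    # Rebuilt with the same dtype InputMetaData uses for length tensors.
    return torch.tensor(values, dtype=torch.int, device=device)

lengths = torch.tensor([7, 12, 3], dtype=torch.int)
assert torch.equal(tensor_from_rpc(tensor_to_rpc(lengths)), lengths)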