Multi round dialogue branch #5772

Draft · wants to merge 4 commits into base: main
150 changes: 102 additions & 48 deletions colossalai/inference/batch_bucket.py
@@ -31,6 +31,9 @@ def __init__(
fd_interm_tensor=None,
device=None,
dtype=torch.float16,
enable_streamingllm: bool = False,
start_token_size: int = 4,
generated_token_size: int = 512,
):
self.num_heads = num_heads
self.head_dim = head_dim
@@ -45,12 +48,19 @@ def __init__(
self._use_spec_dec = False
self._num_tokens_to_verify = None

self.enable_streamingllm = enable_streamingllm
self.start_token_size = start_token_size
self.generated_token_size = generated_token_size

self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
if enable_streamingllm:
max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
else:
max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
self._block_tables_helper = torch.full_like(self._block_tables, -1)
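As a quick orientation on the new block budget, here is a minimal arithmetic sketch (the concrete sizes are illustrative example values, not values dictated by this diff):

```python
# Illustrative only: assumed example sizes, not values taken from this PR.
start_token_size = 4        # attention-sink tokens kept at the front
generated_token_size = 512  # rolling window of generated tokens
block_size = 16             # tokens per KV-cache block

# StreamingLLM bounds the per-sequence cache to the sink tokens plus the rolling
# window, with one extra block of headroom for the block about to be swapped out.
max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
print(max_blocks_per_seq)  # (4 + 512 + 15) // 16 + 1 = 34
```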

@@ -109,6 +119,33 @@ def batch_token_ids(self) -> List[List[int]]:
out.append(seq.input_token_id + seq.output_token_id)
return out

def streamingllm_update_batch(self):
"""
Update sequence_lengths and block_tables when it is necessary to swap out a block.
"""

updated_block_ids = []

if self.current_batch_size > 0:
need_update = False
sequence_lengths_list = self._sequence_lengths.tolist()
block_tables_list = self._block_tables.tolist()
for batch_id in range(self.current_batch_size):
# We assume that the start tokens occupy the entire first block.
if self.cache_manager.check_block_full(block_tables_list[batch_id][-1]):
need_update = True
sequence_lengths_list[batch_id] = sequence_lengths_list[batch_id] - self.block_size
block_id = block_tables_list[batch_id].pop(1)
updated_block_ids.append(block_id)
block_tables_list[batch_id].append(-1)
if need_update:
self._sequence_lengths = torch.tensor(
sequence_lengths_list, dtype=self._sequence_lengths.dtype, device=self.device
)
self._block_tables = torch.tensor(block_tables_list, dtype=self._block_tables.dtype, device=self.device)

return updated_block_ids
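For context, a hedged sketch of the caller side; `bucket` and `free_block_fn` are assumed names, and the actual engine wiring in this PR may differ:

```python
def reclaim_swapped_blocks(bucket, free_block_fn) -> int:
    """Run the streamingllm bookkeeping step and release the evicted blocks.

    `bucket` is a BatchBucket with streamingllm enabled; `free_block_fn` is
    assumed to be whatever hook the kv-cache manager exposes to release a
    physical block by id.
    """
    updated_block_ids = bucket.streamingllm_update_batch()
    for block_id in updated_block_ids:
        free_block_fn(block_id)
    return len(updated_block_ids)
```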

def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,
@@ -144,46 +181,13 @@ def _make_compact(self) -> None:
self._block_tables_helper.fill_(-1)
self._current_batch_size = valid_num

def add_seq(
self,
seq: Sequence,
alloc_block_table: torch.Tensor = None,
alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None,
) -> Union[torch.Tensor, None]:
"""Add a single sequence to the batch.
User could opt to provide either a block table or a function to allocate block tables.

Args:
seq (Sequence): The sequence to be added to the batch
alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence
alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence,
which is expected to reserve blocks and update status of kv-cache manager.

Returns:
block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager.
None if the sequence cannot be added.
"""
block_table = None
# TODO might consider sorting by length
if self._current_batch_size < self.max_batch_size:
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size
self._sequence_lengths[self._current_batch_size] = seq.sentence_len
# NOTE the added seq still require block table allocation by kvcache manager
block_table = self._block_tables[self._current_batch_size - 1]
if alloc_block_table is not None:
# copy block ids from provided block tables
self._block_tables[self._current_batch_size - 1] = alloc_block_table
elif alloc_block_table_fn:
alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item())
self._current_batch_size += 1
return block_table

def add_seqs(
self,
seqs: List[Sequence],
alloc_block_tables: torch.Tensor = None,
alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None,
need_reused_block_table: bool = False,
streaningllm_prompt_len: int = 0,
) -> Union[torch.Tensor, None]:
"""Add a list of sequences to the batch.
User could opt to provide either block tables or a function to allocate block tables.
@@ -193,7 +197,8 @@ def add_seqs(
alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence
alloc_block_table_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences,
which is expected to reserve blocks and update status of kv-cache manager.

need_reused_block_table (bool): Whether to reuse cached block tables.
streaningllm_prompt_len (List[int]): The per-sequence prompt lengths used for streamingLLM.
Returns:
block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager.
None if the sequences cannot be added.
@@ -206,15 +211,24 @@ def add_seqs(
num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs))
block_tables = None
if num_seqs_to_add > 0:
# NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
for i, seq in enumerate(seqs[:num_seqs_to_add]):
if need_reused_block_table:
block_tables[i] = seq.block_table
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size + i
# TODO external (rename): modify Sequence.sentence_len to seq_len
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
# NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
if need_reused_block_table:
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor(
[streaningllm_prompt_len[seq_id] for seq_id in range(num_seqs_to_add)], dtype=torch.int32
)
else:
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
if alloc_block_tables is not None:
# copy block ids from provided block tables
self._block_tables[
@@ -232,20 +246,32 @@ def add_seqs(
return block_tables
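To make the new add_seqs parameters concrete, a hedged usage sketch; everything other than the add_seqs signature (the helper name and how prompt lengths are derived) is an assumption for illustration:

```python
from typing import List

def resume_cached_dialogue(bucket, seqs: List) -> None:
    """Re-add previously served sequences, reusing their cached block tables.

    Assumption: each sequence's cached prompt length equals the length of its
    input token ids; the real engine may track this differently.
    """
    prompt_lens = [len(seq.input_token_id) for seq in seqs]
    bucket.add_seqs(
        seqs,
        need_reused_block_table=True,
        streaningllm_prompt_len=prompt_lens,
    )
```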

def pop_seq_update_batch(
self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
self,
request_id: int,
free_block_table_fn: Callable[[torch.Tensor], None] = None,
update_cached_sequences_fn: Callable[[Sequence], None] = None,
) -> Tuple[Sequence, Union[torch.Tensor, None]]:
"""Pop a single sequence by id from the batch, and update the batch bucket status.

Args:
request_id (int): The uid of the sequence
free_block_table_fn (Callable): The function to free the block table of a sequence,
if not provided, then we have to release the block table manually after calling this method

update_cached_sequences_fn (Callable[[Sequence], None]): When streamingllm is enabled, finished inference sequences are saved to cached_sequences_dict;
this function is used to update cached_sequences_dict.
Returns:
A tuple of: seq (Sequence): The target sequence
and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks,
none if the sequence is not found or free_block_table_fn is provided.
"""

# When update_cached_sequences_fn is not None, it indicates that we have enabled streamingLLM.
# At this point, completed sequences will be stored in cached_sequences_dict and will not
# be released within the current function.
assert (
free_block_table_fn is None or update_cached_sequences_fn is None
), f"free_block_table_fn={free_block_table_fn} and update_cached_sequences_fn={update_cached_sequences_fn} cannot be passed simultaneously."

seq: Sequence = self._sequences_dict.get(request_id)
block_table = None
if seq is not None:
@@ -268,15 +294,23 @@ def pop_seq_update_batch(
if free_block_table_fn:
free_block_table_fn(self._block_tables[seq_b_idx])
else:
block_table = self._block_tables[seq_b_idx].detach().clone()
if update_cached_sequences_fn:
# When enabling streamingllm, save previous inference sequences.
update_cached_sequences_fn(seq, self._block_tables[seq_b_idx])
else:
block_table = self._block_tables[seq_b_idx].detach().clone()
# replace block table of the target seq with that of the last seq in the batch
self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]
self._block_tables[last_seq_b_idx].fill_(-1)
else:
if free_block_table_fn:
free_block_table_fn(self._block_tables[0])
else:
block_table = self._block_tables[0].detach().clone()
if update_cached_sequences_fn:
# When enabling streamingllm, save previous inference sequences.
update_cached_sequences_fn(seq, self._block_tables[seq_b_idx])
else:
block_table = self._block_tables[0].detach().clone()
self._sequence_lengths[0].fill_(0)
self._block_tables[0].fill_(-1)
self._sequences_indexes.pop(request_id)
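For orientation, a minimal sketch of what an update_cached_sequences_fn callback could look like; the two-argument form mirrors the call sites above, and the shape of cached_sequences_dict is an assumption since it is not shown in this diff:

```python
from typing import Dict, Tuple

import torch

# Assumed container; the docstrings mention cached_sequences_dict but its exact
# shape is not shown in this diff.
cached_sequences_dict: Dict[int, Tuple["Sequence", torch.Tensor]] = {}

def update_cached_sequences_fn(seq: "Sequence", block_table: torch.Tensor) -> None:
    # Keep the finished sequence and a snapshot of its block table so a later
    # dialogue round can reuse the KV cache instead of recomputing the prefix.
    cached_sequences_dict[seq.request_id] = (seq, block_table.detach().clone())
```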
@@ -339,17 +373,29 @@ def pop_n_seqs(
return seqs, block_tables

def pop_finished(
self, free_block_table_fn: Callable[[torch.Tensor], None] = None
self,
free_block_table_fn: Callable[[torch.Tensor], None] = None,
update_cached_sequences_fn: Callable[[Sequence], None] = None,
) -> Tuple[List[Sequence], List[torch.Tensor]]:
"""Pop finished sequences in the batch and a list of block tables of the finished sequences,
if free_block_table_fn is not provided.

Args:
free_block_table_fn (Callable): The function to free the block table of a single sequence
update_cached_sequences_fn (Callable[[Sequence], None]): When streamingllm is enabled, finished inference sequences are saved to cached_sequences_dict;
this function is used to update cached_sequences_dict.
Returns:
A tuple of: finished_seqs (List[Sequence]): The finished sequences,
and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
"""

# When update_cached_sequences_fn is not None, it indicates that we have enabled streamingLLM.
# At this point, completed sequences will be stored in cached_sequences_dict and will not
# be released within the current function.
assert (
free_block_table_fn is None or update_cached_sequences_fn is None
), f"free_block_table_fn={free_block_table_fn} and update_cached_sequences_fn={update_cached_sequences_fn} cannot be passed simultaneously."

finished_seqs = []
finished_block_tables = []
for seq in self._sequences_dict.values():
@@ -360,7 +406,7 @@ def pop_finished(
# For now, the performance difference is not significant, so we use the first method to pop seqs.
# Precise evaluations to be done.
for seq in finished_seqs:
_, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
_, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn, update_cached_sequences_fn)
if block_table is not None:
finished_block_tables.append(block_table)
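A hedged sketch of the two mutually exclusive cleanup paths at the call site; the wrapper and the cache-manager hook are assumed names, not this PR's API:

```python
def cleanup_finished(bucket, cache_manager=None, cache_cb=None):
    """Pop finished sequences, either freeing their blocks (standard path) or
    keeping them cached for the next dialogue round (streamingllm path).
    Only one of the two callbacks may be passed."""
    if cache_cb is not None:
        # StreamingLLM: retain finished sequences and their block tables.
        return bucket.pop_finished(update_cached_sequences_fn=cache_cb)
    # Standard path: release blocks right away; free_block_table is assumed to be
    # the kv-cache manager's per-sequence release hook.
    return bucket.pop_finished(free_block_table_fn=cache_manager.free_block_table)
```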

@@ -443,6 +489,14 @@ def merge(self, other: "BatchBucket") -> List[int]:

return unmerged_ids

def has_reused_seqs(self):
"""Check whether any sequence in the batch carries a reused (cached) block table."""
has_reused = False
for seq in self.seqs_li:
if seq.block_table is not None:
has_reused = True
return has_reused

########## The following methods are expected to be used in modeling ###########

# For compatibility.