diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 1214e82a5aac..6332aa42ad64 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -119,7 +119,7 @@ def batch_token_ids(self) -> List[List[int]]: out.append(seq.input_token_id + seq.output_token_id) return out - def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int): + def streamingllm_update_batch(self): """ Update sequence_lengths and block_tables when it is necessary to swap out a block. """ @@ -132,9 +132,9 @@ def streamingllm_update_batch(self, start_token_size: int, generated_token_size: block_tables_list = self._block_tables.tolist() for batch_id in range(self.current_batch_size): # We assume that the start token occupies the entire first block. - if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1: + if self.cache_manager.check_block_full(self.block_tables_list[batch_id][-1]): need_update = True - sequence_lengths_list[batch_id] = start_token_size + generated_token_size + sequence_lengths_list[batch_id] = sequence_lengths_list[batch_id] - self.block_size block_id = block_tables_list[batch_id].pop(1) updated_block_ids.append(block_id) block_tables_list[batch_id].append(-1) @@ -181,46 +181,13 @@ def _make_compact(self) -> None: self._block_tables_helper.fill_(-1) self._current_batch_size = valid_num - def add_seq( - self, - seq: Sequence, - alloc_block_table: torch.Tensor = None, - alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None, - ) -> Union[torch.Tensor, None]: - """Add a single sequence to the batch. - User could opt to provide either a block table or a function to allocate block tables. - - Args: - seq (Sequence): The sequence to be added to the batch - alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence - alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence, - which is expected to reserve blocks and update status of kv-cache manager. - - Returns: - block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager. - None if the sequence cannot be added. - """ - block_table = None - # TODO might consider sorting by length - if self._current_batch_size < self.max_batch_size: - self._sequences_dict[seq.request_id] = seq - self._sequences_indexes[seq.request_id] = self._current_batch_size - self._sequence_lengths[self._current_batch_size] = seq.sentence_len - # NOTE the added seq still require block table allocation by kvcache manager - block_table = self._block_tables[self._current_batch_size - 1] - if alloc_block_table is not None: - # copy block ids from provided block tables - self._block_tables[self._current_batch_size - 1] = alloc_block_table - elif alloc_block_table_fn: - alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item()) - self._current_batch_size += 1 - return block_table - def add_seqs( self, seqs: List[Sequence], alloc_block_tables: torch.Tensor = None, alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None, + need_reused_block_table: bool = False, + streaningllm_prompt_len: int = 0, ) -> Union[torch.Tensor, None]: """Add a list of sequences to the batch. User could opt to provide either block tables or a function to allocate block tables. 
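For readers new to the StreamingLLM scheme, the eviction performed by streamingllm_update_batch above reduces to the sketch below: block 0 of every block table holds the start (attention-sink) tokens and is never touched; once the newest block fills up, the oldest evictable block (index 1) is dropped and the recorded sequence length shrinks by one block. This is a plain-Python paraphrase for review purposes only, with block_is_full standing in for cache_manager.check_block_full and lists of ints standing in for the _block_tables tensor.

def streamingllm_evict_step(seq_lens, block_tables, block_size, block_is_full):
    """Drop the oldest non-sink KV block of each sequence whose newest block is full."""
    evicted = []
    for i, table in enumerate(block_tables):
        if block_is_full(table[-1]):
            seq_lens[i] -= block_size      # the evicted block's tokens leave the attention window
            evicted.append(table.pop(1))   # index 0 holds the sink tokens; index 1 is the oldest evictable block
            table.append(-1)               # keep the table width constant
    return evicted                         # ids handed back to the cache manager for freeing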
@@ -230,7 +197,8 @@ def add_seqs( alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence alloc_block_table_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences, which is expected to reserve blocks and update status of kv-cache manager. - + need_reused_block_table (bool): Whether to reuse cached block tables. + streaningllm_prompt_len (int): The length of sentences used for streamingLLM. Returns: block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager. None if the sequences cannot be added. @@ -243,15 +211,24 @@ def add_seqs( num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs)) block_tables = None if num_seqs_to_add > 0: + # NOTE block tables to be updated by kvcache manager + block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] for i, seq in enumerate(seqs[:num_seqs_to_add]): + if need_reused_block_table: + block_tables[i] = seq.block_table self._sequences_dict[seq.request_id] = seq self._sequences_indexes[seq.request_id] = self._current_batch_size + i # TODO external (rename): modify Sequence.sentence_len to seq_len - self._sequence_lengths[ - self._current_batch_size : self._current_batch_size + num_seqs_to_add - ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) - # NOTE block tables to be updated by kvcache manager - block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] + if need_reused_block_table: + self._sequence_lengths[ + self._current_batch_size : self._current_batch_size + num_seqs_to_add + ] = torch.tensor( + [streaningllm_prompt_len[seq_id] for seq_id in range(num_seqs_to_add)], dtype=torch.int32 + ) + else: + self._sequence_lengths[ + self._current_batch_size : self._current_batch_size + num_seqs_to_add + ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) if alloc_block_tables is not None: # copy block ids from provided block tables self._block_tables[ @@ -269,7 +246,10 @@ def add_seqs( return block_tables def pop_seq_update_batch( - self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None + self, + request_id: int, + free_block_table_fn: Callable[[torch.Tensor], None] = None, + update_cached_sequences_fn: Callable[[Sequence], None] = None, ) -> Tuple[Sequence, Union[torch.Tensor, None]]: """Pop a single sequence by id from the batch, and update the batch bucket status. @@ -277,12 +257,21 @@ def pop_seq_update_batch( request_id (int): The uid of the sequence free_block_table_fn (Callable): The function to free the block table of a sequence, if not provided, then we have to release the block table manually after calling this method - + update_cached_sequences_fn (Callable[[Sequence], None]): When enabling streamingllm, the previous inference sequences will be saved to cached_sequences_dict. + This function is used to update cached_sequences_dict. Returns: A tuple of: seq (Sequence): The target sequence and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks, none if the sequence is not found or free_block_table_fn is provided. """ + + # When update_cached_sequences_fn is not None, it indicates that we have enabled streamingLLM. 
+ # At this point, completed sentences will be stored in cached_sequences_dict and will not + # be released within the current function. + assert ( + free_block_table_fn is None or update_cached_sequences_fn is None + ), f"free_block_table_fn={free_block_table_fn} and update_cached_sequences_fn={update_cached_sequences_fn} cannot be passed simultaneously." + seq: Sequence = self._sequences_dict.get(request_id) block_table = None if seq is not None: @@ -305,7 +294,11 @@ def pop_seq_update_batch( if free_block_table_fn: free_block_table_fn(self._block_tables[seq_b_idx]) else: - block_table = self._block_tables[seq_b_idx].detach().clone() + if update_cached_sequences_fn: + # When enabling streamingllm, save previous inference sequences. + update_cached_sequences_fn(seq, self._block_tables[seq_b_idx]) + else: + block_table = self._block_tables[seq_b_idx].detach().clone() # replace block table of the target seq with that of the last seq in the batch self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx] self._block_tables[last_seq_b_idx].fill_(-1) @@ -313,7 +306,11 @@ def pop_seq_update_batch( if free_block_table_fn: free_block_table_fn(self._block_tables[0]) else: - block_table = self._block_tables[0].detach().clone() + if update_cached_sequences_fn: + # When enabling streamingllm, save previous inference sequences. + update_cached_sequences_fn(seq, self._block_tables[seq_b_idx]) + else: + block_table = self._block_tables[0].detach().clone() self._sequence_lengths[0].fill_(0) self._block_tables[0].fill_(-1) self._sequences_indexes.pop(request_id) @@ -376,17 +373,29 @@ def pop_n_seqs( return seqs, block_tables def pop_finished( - self, free_block_table_fn: Callable[[torch.Tensor], None] = None + self, + free_block_table_fn: Callable[[torch.Tensor], None] = None, + update_cached_sequences_fn: Callable[[Sequence], None] = None, ) -> Tuple[List[Sequence], List[torch.Tensor]]: """Pop finished sequences in the batch and a list of block tables of the finished sequences, if free_block_table_fn is not provided. Args: free_block_table_fn (Callable): The function to free the block table of a single sequence + update_cached_sequences_fn (Callable[[Sequence], None]): When enabling streamingllm, the previous inference sequences will be saved to cached_sequences_dict. + This function is used to update cached_sequences_dict. Returns: A tuple of: finished_seqs (List[Sequence]): The finished sequences, and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences. """ + + # When update_cached_sequences_fn is not None, it indicates that we have enabled streamingLLM. + # At this point, completed sentences will be stored in cached_sequences_dict and will not + # be released within the current function. + assert ( + free_block_table_fn is None and update_cached_sequences_fn is None + ), f"free_block_table_fn={free_block_table_fn} and update_cached_sequences_fn={update_cached_sequences_fn} cannot be passed simultaneously." + finished_seqs = [] finished_block_tables = [] for seq in self._sequences_dict.values(): @@ -397,7 +406,7 @@ def pop_finished( # For now, the performance difference is not significant, so we use the frist method to pop seqs. # Precise evaluations to be done. 
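For context on the two callbacks threaded through pop_seq_update_batch and pop_finished above: with StreamingLLM enabled, a finished sequence's KV blocks are not returned to the pool; the sequence and its block table are handed to update_cached_sequences_fn so that a later request from the same user can resume from them. A minimal paraphrase of that hand-off (not the shipped code) is:

def release_or_cache(seq, block_table, free_block_table_fn=None, update_cached_sequences_fn=None):
    # The two callbacks are mutually exclusive: either free the KV blocks now,
    # or keep them alive in the per-user cache for a follow-up request.
    assert free_block_table_fn is None or update_cached_sequences_fn is None
    if free_block_table_fn is not None:
        free_block_table_fn(block_table)              # normal path: blocks go back to the pool
        return None
    if update_cached_sequences_fn is not None:
        update_cached_sequences_fn(seq, block_table)  # StreamingLLM path: cache blocks for reuse
        return None
    return block_table.detach().clone()               # caller must free the blocks later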
for seq in finished_seqs: - _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn) + _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn, update_cached_sequences_fn) if block_table is not None: finished_block_tables.append(block_table) @@ -480,6 +489,14 @@ def merge(self, other: "BatchBucket") -> List[int]: return unmerged_ids + def has_reused_seqs(self): + """ """ + has_reused = False + for seq in self.seqs_li: + if seq.block_table: + has_reused = True + return has_reused + ########## The following methods are expected to be used in modeling ########### # For compatibility. diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 1515ca6d9eec..7990f65a9b36 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -54,6 +54,7 @@ class InputMetaData(RPC_PARAM): Args: block_tables (torch.Tensor, optional): Sequences' BlockTables Defaults to None. sequence_lengths (torch.Tensor): A tensor containing sequence lengths. + current_prompt_lengths (torch.Tensor): A tensor containing current prompt lengths. fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None. batch_size (int, optional): The current batch size. Defaults to 64. is_prompts (bool, optional): Indicates whether prefill or decoding. Defaults to False(decoding). @@ -70,6 +71,7 @@ class InputMetaData(RPC_PARAM): block_tables: torch.Tensor = None sequence_lengths: torch.Tensor = None + current_prompt_lengths: torch.Tensor = None fd_inter_tensor: FDIntermTensors = None batch_size: int = 64 # current_batch_size is_prompts: bool = False @@ -89,6 +91,7 @@ def to_rpc_param(self) -> Dict[str, any]: return { "block_tables": self.block_tables.tolist(), "sequence_lengths": self.sequence_lengths.tolist(), + "current_prompt_lengths": self.current_prompt_lengths.tolist(), "batch_size": self.batch_size, "is_prompts": self.is_prompts, "use_cuda_kernel": self.use_cuda_kernel, @@ -117,6 +120,9 @@ def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData": sequence_lengths=torch.tensor( rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() ), + current_prompt_lengths=torch.tensor( + rpc_dict["current_prompt_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() + ), batch_size=rpc_dict["batch_size"], is_prompts=rpc_dict["is_prompts"], use_cuda_kernel=rpc_dict["use_cuda_kernel"], @@ -134,6 +140,7 @@ def __repr__(self) -> str: return ( f"InputMetaData(block_tables={self.block_tables}, " f"sequence_lengths={self.sequence_lengths}, " + f"current_prompt_lengths={self.current_prompt_lengths}, " f"fd_inter_tensor={self.fd_inter_tensor}, " f"batch_size={self.batch_size}, " f"is_prompts={self.is_prompts}, " @@ -280,14 +287,21 @@ def _verify_config(self) -> None: "{input_text}" in self.prompt_template ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '" - assert ( - self.start_token_size <= self.block_size - ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}." 
- assert ( - self.generated_token_size % self.block_size == 0 - ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}." - # We assume that start_token_size occupies one block. - self.start_token_size = self.block_size + if self.enable_streamingllm: + assert ( + self.use_cuda_graph == False + ), "We currently do not support using streamingLLM and CUDA graph simultaneously." + assert ( + self.max_input_len <= self.inference_config.generated_token_size + ), f"When enabling streamingLLM, max_input_len={self.max_input_len} must be less or equal than self.inference_config.generated_token_size={self.inference_config.generated_token_size}." + assert ( + self.start_token_size <= self.block_size + ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}." + assert ( + self.generated_token_size % self.block_size == 0 + ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}." + # We assume that start_token_size occupies one block. + self.start_token_size = self.block_size def to_generation_config(self, model_config) -> GenerationConfig: meta_config = { diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 8a65bfc3c481..fbd95829c102 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -89,6 +89,9 @@ def __init__( self.use_cuda_graph = self.inference_config.use_cuda_graph if self.use_cuda_graph: + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and CUDA graph simultaneously." self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool = None # Set during graph capture. if verbose: @@ -197,6 +200,9 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[P @torch.inference_mode() def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): assert self.use_cuda_graph, "please turn on the cuda graph" + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and CUDA graph simultaneously." if self.verbose: self.logger.info("Colossal AI CUDA Graph Capture begin") @@ -348,6 +354,11 @@ def enable_spec_dec( engine.clear_spec_dec() ``` """ + + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and Speculative Decoding simultaneously." + if drafter_model is None and self.drafter is None: raise ValueError("Drafter not initialized. Please provide a Drafter Model") if n_spec_tokens is not None: @@ -555,6 +566,9 @@ def generate( if self.use_spec_dec: assert self.drafter is not None, "Drafter Model is not initialized." + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and Speculative Decoding simultaneously." 
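The StreamingLLM-specific checks added to _verify_config and to the engine above boil down to a handful of constraints on the configuration. A configuration like the following satisfies all of them; field names are taken from this PR, the values are illustrative only, and other required InferenceConfig fields are omitted.

streamingllm_cfg = dict(
    enable_streamingllm=True,
    use_cuda_graph=False,       # StreamingLLM cannot be combined with CUDA graph or speculative decoding
    block_size=16,
    start_token_size=4,         # must be <= block_size; rounded up to one full block internally
    generated_token_size=512,   # must be a multiple of block_size
    max_input_len=256,          # must not exceed generated_token_size
)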
while self.request_handler.check_unfinished_seqs(): output_seqs_list += self.steps_spec_dec() else: @@ -596,6 +610,7 @@ def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str] def add_request( self, + user_ids: Union[List[int], int] = None, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, @@ -605,6 +620,7 @@ def add_request( Add requests. Args: + user_id (List[int], optional): The IDs of the input sequences' owner. request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. @@ -617,6 +633,9 @@ def add_request( block_size = self.inference_config.block_size + if user_ids is not None and not isinstance(user_ids, list): + user_ids = [user_ids] + if request_ids is not None and not isinstance(request_ids, list): request_ids = [request_ids] @@ -653,8 +672,11 @@ def add_request( ), f"The request_id type must be int, but got {type(request_ids[0])}" assert len(request_ids) == prompts_num request_id = request_ids[i] + user_id = user_id[i] else: request_id = next(self.counter) + # Default user_id to request_id if not provided. + user_id = request_id if prompts == None: prompt = None else: @@ -671,25 +693,47 @@ def add_request( self.inference_config.max_output_len >= max_new_tokens ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." - sequence = Sequence( - request_id, - prompt, - prompts_token_ids[i], - block_size, - None, - self.tokenizer.eos_token_id, - self.tokenizer.pad_token_id, - max_output_len=max_new_tokens, - ignore_eos=self.inference_config.ignore_eos, - ) + timestamp = time.time() + cached_sequences_dict = self.request_handler.cached_sequences_dict + + assert len( + prompts_token_ids[i] <= self.inference_config.max_input_len + ), f"The lengths of prompt must be less or equal than max_input_len={self.inference_config.max_input_len}, but got {len(prompts_token_ids[i])}." 
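A hypothetical call showing how the new user_ids argument of add_request is meant to be used: requests that share a user_id can pick up the cached sequence (and its KV blocks) left behind by that user's previous finished request, while a missing user_id simply defaults to the request_id. Everything besides the argument names introduced in this PR is assumed.

# First request: no cache hit, so a fresh Sequence is built and, once finished, cached under user_id 42.
engine.add_request(user_ids=42, prompts="Summarize the paged-KV cache design.")
# ... engine.generate() runs to completion here ...
# Second request from the same user: the cached Sequence is revived via Sequence.reused(...)
# and its block table is reused instead of allocating a new one.
engine.add_request(user_ids=42, prompts="Now relate that design to StreamingLLM.")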
+ + if self.inference_config.enable_streamingllm and user_id in cached_sequences_dict: + sequence = cached_sequences_dict[user_id] + sequence.reused( + request_id, + timestamp, + prompt, + prompts_token_ids[i], + ) + else: + sequence = Sequence( + user_id=user_id, + request_id=request_id, + prompt=prompt, + input_token_id=prompts_token_ids[i], + block_size=block_size, + sample_params=None, + block_table=None, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + max_output_len=max_new_tokens, + ignore_eos=self.inference_config.ignore_eos, + ) + self.request_handler.add_sequence(sequence) def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: input_ids = batch.get_1D_inputs() sequence_lengths = batch.get_sequence_lengths() + current_prompt_lengths = None if batch.is_prompts: - n_tokens = sequence_lengths.sum().item() + if self.inference_config.enable_streamingllm and batch.has_reused_seqs(): + current_prompt_lengths = batch.get_1D_inputs() + n_tokens = current_prompt_lengths.sum().item() else: n_tokens = batch.current_batch_size if batch.use_spec_dec: @@ -716,6 +760,7 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, input_meta_data = InputMetaData( block_tables=batch.get_block_table_tensor(), sequence_lengths=sequence_lengths, + current_prompt_lengths=current_prompt_lengths, fd_inter_tensor=batch.fd_inter_tensor, batch_size=batch.current_batch_size, is_prompts=batch.is_prompts, @@ -750,19 +795,27 @@ def step(self) -> List[str]: input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) if input_meta_data.use_cuda_graph: + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and CUDA graph simultaneously." model_executable = self.graph_runners[input_meta_data.batch_size] else: model_executable = self.model # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + logits = model_executable( + input_token_ids, + output_tensor, + input_meta_data, + self.k_cache, + self.v_cache, + self.inference_config.enable_streamingllm and batch.has_reused_seqs(), + ) if self.inference_config.pad_input: logits = logits[:, -1, :] if self.inference_config.enable_streamingllm: - updated_block_ids = batch.streamingllm_update_batch( - self.inference_config.start_token_size, self.inference_config.generated_token_size - ) + updated_block_ids = batch.streamingllm_update_batch() self.request_handler.streamingllm_free_block_tables(updated_block_ids) next_tokens = search_tokens( diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 512eaea71c7b..14435822647c 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -127,6 +127,15 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo ) // inference_config.block_size head_dim = model_config.hidden_size // model_config.num_attention_heads + if self.inference_config.enable_streamingllm: + # user_id -> sequences + self.cached_sequences_dict = {} + # The current default cache size is the max_batch_size. 
+ self.cache_size = inference_config.max_batch_size + else: + self.cached_sequences_dict = None + self.cache_size = 0 + fd_inter_tensor = FDIntermTensors() if fd_inter_tensor._tensors_initialized: @@ -230,10 +239,17 @@ def schedule(self): for seq in self.running_list.prefill[:num_seqs_to_add]: seq.mark_running() # allocate blocks for the prefill batch - self.prefill_bb.add_seqs( - self.running_list.prefill[:num_seqs_to_add], - alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables, - ) + if self.inference_config.enable_streamingllm and self.prefill_bb.has_reused_seqs(): + self.streamingllm_prefill_alloc( + num_seqs_to_add, + self.cache_manager.allocate_context_from_block_tables, + self.cache_manager.allocate_context_from_non_empty_block_tables, + ) + else: + self.prefill_bb.add_seqs( + self.running_list.prefill[:num_seqs_to_add], + alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables, + ) return self.prefill_bb @@ -242,7 +258,7 @@ def schedule(self): self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size ) if seqs_ids_to_recycle: - seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle) + seqs_to_recycle, _ = self.running_bb.pop_seqs(seqs_ids_to_recycle, self.cached_sequences_dict) for seq in seqs_to_recycle: seq.recycle() self.running_list.remove(seq) @@ -349,19 +365,78 @@ def update(self): # since we want to reuse the memory recorded on the block tables self.prefill_bb.clear(free_block_tables_fn=None) - finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table) + finished_seqs, _ = self.running_bb.pop_finished(cached_sequences_dict=self.update_cached_sequences) for seq in finished_seqs: self.running_list.remove(seq) self.done_list.extend(finished_seqs) return finished_seqs + def update_cached_sequences(self, seq: Sequence, block_table: torch.Tensor): + if seq.user_id not in self.cached_sequences_dict: + if self.cache_size == self.max_batch_size: + outdated_seq = min(self.cached_sequences_dict.values(), key=lambda x: x.timestamp) + self.cache_manager.free_block_table(outdated_seq.block_table) + del self.cached_sequences_dict[outdated_seq.user_id] + seq.block_table = block_table.clone() + self.cached_sequences_dict[seq.user_id] = seq + def streamingllm_free_block_tables(self, updated_block_ids: List[int]): """ Free the block that needs to be swapped out. 
""" self.cache_manager.streamingllm_free_block_tables(updated_block_ids) + def streamingllm_prefill_alloc( + self, + num_seqs_to_add: int = 0, + ): + current_seqs = self.running_list.prefill[:num_seqs_to_add] + + reused_seqs = [seq for seq in current_seqs if seq.block_table is None] + normal_seqs = [seq for seq in current_seqs if seq.block_table is not None] + + if normal_seqs: + self.prefill_bb.add_seqs( + normal_seqs, + alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables, + ) + + if reused_seqs: + block_size = self.inference_config.block_size + start_block_num = self.inference_config.start_token_size // block_size + max_blocks_per_sequence = self.cache_manager.get_max_blocks_per_sequence() + streaningllm_prompt_lens = [] + for seq in reused_seqs: + block_table = seq.block_table + streaningllm_prompt_len = self.cache_manager.get_used_slots(block_table) + vaild_block_num = sum(1 for x in block_table if x > 0) + unused_block_num = max_blocks_per_sequence - vaild_block_num + assert ( + seq.input_len <= self.inference_config.generated_token_size + ), f"When enabling streamingLLM, the length of seq={seq} must be less or equal than self.inference_config.generated_token_size={self.inference_config.generated_token_size}, but got seq.input_len." + need_swap_blocks = (seq.input_len + block_size - 1) // block_size - unused_block_num + + if streaningllm_prompt_len + seq.input_len >= max_blocks_per_sequence: + streaningllm_prompt_len = max_blocks_per_sequence + else: + streaningllm_prompt_len = streaningllm_prompt_len + seq.input_len + + if need_swap_blocks > 0: + block_table[start_block_num : vaild_block_num - need_swap_blocks] = block_table[ + start_block_num + need_swap_blocks : vaild_block_num + ] + block_table[vaild_block_num - need_swap_blocks :] = [-1] * ( + max_blocks_per_sequence - vaild_block_num + need_swap_blocks + ) + + self.prefill_bb.add_seqs( + reused_seqs, + alloc_block_tables_fn=self.cache_manager.allocate_context_from_non_empty_block_tables, + need_reused_block_table=True, + streaningllm_prompt_lens=streaningllm_prompt_lens, + ) + class RPCRequestHandler(RequestHandler): """ diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 378eb2ff9151..97bba2296c9f 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -84,11 +84,12 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> N self.max_blocks_per_sequence = ( config.start_token_size + config.generated_token_size + self.block_size - 1 ) // self.block_size + 1 + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width * 2 else: self.max_blocks_per_sequence = ( self.max_input_length + self.max_output_length + self.block_size - 1 ) // self.block_size - self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Physical cache allocation if config.use_cuda_kernel: @@ -290,6 +291,82 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context block.add_ref() self._allocate_on_block(block, block.block_size) + def allocate_context_from_non_empty_block_tables( + self, block_tables: torch.Tensor, context_lengths: torch.Tensor + ) -> None: + """Allocate logical cache blocks for a batch of sequences during prefill stage from non-empty block_tables. 
+ + Args: + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] + context_lengths (torch.Tensor): [bsz]] + """ + assert block_tables.dim() == 2 + assert block_tables.size(0) == context_lengths.size(0) + + blocks_required = context_lengths // self.block_size + num_blocks_required = torch.sum(blocks_required).item() + assert isinstance(num_blocks_required, int) + if num_blocks_required > self._available_blocks: + self.logger.error( + f"Lacking blocks to allocate. Available blocks {self._available_blocks}; blocks asked {num_blocks_required}." + ) + return + + block_start = torch.sum(block_tables > 0, dim=1) + + for i, block_id in enumerate(block_tables[block_start] - 1): + block: CacheBlock = self._cache_blocks[block_id] + self._allocate_on_block( + block, + block.block_size + if context_lengths[i] % block.block_size == 0 + else context_lengths[i].item() % block.block_size, + ) + + bsz = block_tables.size(0) + # Try contiguous allocation + torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:]) + torch.subtract( + self._block_states_cum[num_blocks_required:], + self._block_states_cum[:-num_blocks_required], + out=self._block_finder[num_blocks_required - 1 :], + ) + end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1) + if end_indexes.numel() > 0: + # contiguous cache exists + end_idx = end_indexes[0].item() + 1 # open interval + start_idx = end_idx - num_blocks_required # closed interval + alloc_block_ids = torch.arange(start_idx, end_idx) + for i in range(bsz): + curr_required = blocks_required[i] + block_tables[i, block_start[i] : curr_required] = torch.arange( + start_idx, start_idx + curr_required, device=block_tables.device + ) + start_idx += curr_required + else: + # non-contiguous cache + available_block_ids = torch.nonzero(self._block_states > 0).view(-1) + alloc_block_ids = available_block_ids[:num_blocks_required] + alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device) + start_idx = 0 + for i in range(bsz): + curr_required = blocks_required[i] + block_tables[i, block_start[i] : curr_required] = alloc_block_ids[start_idx, start_idx + curr_required] + start_idx += curr_required + + # Update cache blocks + self._block_states[alloc_block_ids] = 0 + self._available_blocks -= num_blocks_required + last_block_locs = torch.cumsum(blocks_required, dim=0) - 1 + last_block_locs = last_block_locs.to(device=alloc_block_ids.device) + + for block_id in alloc_block_ids: + if block_id in alloc_block_ids[last_block_locs]: + continue + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + self._allocate_on_block(block, block.block_size) + def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None: """Allocate the logical cache block for a single sequence during decoding stage, and updates the provided block table if a new cache block is needed. 
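The intent of allocate_context_from_non_empty_block_tables above is easier to follow on plain lists: finish filling the partially used last block of the reused table, then append just enough fresh blocks (contiguously when possible) to hold the remaining new-prompt tokens. A simplified sketch under that reading, with free_blocks standing in for the manager's block-state bookkeeping:

def append_context_blocks(block_table, new_tokens, block_size, free_blocks, last_block_free_slots):
    # Tokens that still fit into the last reused block need no new allocation.
    remaining = max(0, new_tokens - last_block_free_slots)
    needed = (remaining + block_size - 1) // block_size
    if needed > len(free_blocks):
        raise RuntimeError("Lacking blocks to allocate.")
    used = sum(1 for b in block_table if b >= 0)       # blocks already holding cached tokens
    new_ids = [free_blocks.pop(0) for _ in range(needed)]
    block_table[used:used + needed] = new_ids          # grow the table right after the reused blocks
    return new_ids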
@@ -466,6 +543,12 @@ def streamingllm_free_block_tables(self, updated_block_ids: List[int]): self._available_blocks += 1 self._block_states[global_block_id] = 1 + def get_used_slots(self, block_table: torch.Tensor): + num_positive_elements = torch.sum(block_table > 0, dim=0).items() + last_block_id = block_table[num_positive_elements - 1] + block: CacheBlock = self._cache_blocks[last_block_id] + return (num_positive_elements - 1) * self.block_size + block.allocated_size + def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """Get the tensor corresponding to the cache block with the prompted id for a specific layer.""" return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx] @@ -518,6 +601,12 @@ def _init_device_caches( v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device)) return k_cache, v_cache + def check_block_full(self, global_block_id: int) -> bool: + if global_block_id < 0: + return False + block: CacheBlock = self._cache_blocks[global_block_id] + return block.available_space <= 0 + class RPCKVCacheManager(KVCacheManager): def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None: @@ -559,11 +648,12 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb self.max_blocks_per_sequence = ( config.start_token_size + config.generated_token_size + self.block_size - 1 ) // self.block_size + 1 + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width * 2 else: self.max_blocks_per_sequence = ( self.max_input_length + self.max_output_length + self.block_size - 1 ) // self.block_size - self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Logical cache blocks allocation self._available_blocks = self.num_blocks diff --git a/colossalai/inference/logit_processors.py b/colossalai/inference/logit_processors.py index ea73f833242e..4aaed83d10e8 100644 --- a/colossalai/inference/logit_processors.py +++ b/colossalai/inference/logit_processors.py @@ -157,6 +157,7 @@ def apply_forced_eos_token_id( select_indexes = [] num_sequences = logits.shape[0] + # NOTE (yuehuayingxueluo): When streamingLLM is enabled, the following logic may cause errors. 
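To make the caveat in the NOTE above concrete: once StreamingLLM starts evicting blocks, sequence_lengths shrinks by a whole block at a time (see streamingllm_update_batch), so it no longer reflects how many tokens were actually generated, and a check of the form sequence_length == max_length - 1 can misfire. The sketch below keys the forced-EOS decision to the per-sequence output length instead; it is an illustration of the pitfall, not a change proposed by this PR.

def force_eos_by_output_len(logits, output_lens, max_output_lens, eos_token_id):
    # output_lens[i] counts tokens generated so far and is unaffected by KV-block eviction.
    for i, (done, budget) in enumerate(zip(output_lens, max_output_lens)):
        if done == budget - 1:              # last permitted decoding step for sequence i
            logits[i, :] = float("-inf")
            logits[i, eos_token_id] = 0.0
    return logits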
sequence_lengths = sequence_lengths[:num_sequences] max_lengths = max_lengths[:num_sequences] for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)): diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index f6f160eb7e96..4b952a6990e9 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -107,6 +107,9 @@ def llama_model_forward( use_cuda_kernel = False logger.warning("CUDA kernel is disabled for speculative-decoding.") + if inputmetadata.current_prompt_lengths: + current_prompt_lengths = inputmetadata.current_prompt_lengths + hidden_states = self.embed_tokens(input_tokens_ids) cu_seqlens = None @@ -133,9 +136,21 @@ def llama_model_forward( total_length = hidden_states.size(0) cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device) sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device) - inference_ops.get_cos_and_sin( - self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts - ) + if current_prompt_lengths: + inference_ops.prefix_get_cos_and_sin( + self._cos_cached, + self._sin_cached, + cos, + sin, + sequence_lengths, + current_prompt_lengths, + kv_seq_len, + inputmetadata.is_prompts, + ) + else: + inference_ops.get_cos_and_sin( + self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts + ) cos_sin = (cos, sin) else: cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts) @@ -157,6 +172,7 @@ def llama_model_forward( is_verifier=inputmetadata.use_spec_dec, tokens_to_verify=tokens_to_verify, sequence_lengths=sequence_lengths, + current_prompt_lengths=current_prompt_lengths, cos_sin=cos_sin, fd_inter_tensor=inputmetadata.fd_inter_tensor, kv_seq_len=kv_seq_len, @@ -186,6 +202,7 @@ def llama_decoder_layer_forward( k_cache: torch.Tensor, v_cache: torch.Tensor, sequence_lengths: torch.Tensor, + current_prompt_lengths: torch.Tensor, cos_sin: Tuple[torch.Tensor], fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, @@ -233,6 +250,7 @@ def llama_decoder_layer_forward( is_verifier=is_verifier, tokens_to_verify=tokens_to_verify, sequence_lengths=sequence_lengths, + current_prompt_lengths=current_prompt_lengths, cos_sin=cos_sin, fd_inter_tensor=fd_inter_tensor, kv_seq_len=kv_seq_len, @@ -486,6 +504,7 @@ def forward( k_cache: torch.Tensor, v_cache: torch.Tensor, sequence_lengths: torch.Tensor, + current_prompt_lengths: torch.Tensor, cos_sin: Tuple[torch.Tensor], fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, @@ -534,26 +553,50 @@ def forward( block_size = k_cache.size(-2) if is_prompts: - if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: - # flash attn 2 currently only supports FP16/BF16. - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) - inference_ops.context_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len - ) + if current_prompt_lengths: + if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: + # flash attn 2 currently only supports FP16/BF16. 
+ inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) + inference_ops.context_kv_cache_memcpy( + key_states, + value_states, + k_cache, + v_cache, + sequence_lengths, + cu_seqlens, + block_tables, + kv_seq_len, + ) - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=kv_seq_len, - max_seqlen_k=kv_seq_len, - dropout_p=0.0, - softmax_scale=sm_scale, - causal=True, - ) - attn_output = attn_output.view(token_nums, -1) + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=kv_seq_len, + max_seqlen_k=kv_seq_len, + dropout_p=0.0, + softmax_scale=sm_scale, + causal=True, + ) + attn_output = attn_output.view(token_nums, -1) + else: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + use_new_kcache_layout=use_cuda_kernel, + ) else: rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) attn_output = context_attention_unpadded( @@ -563,6 +606,7 @@ def forward( k_cache=k_cache, v_cache=v_cache, context_lengths=sequence_lengths, + current_prompt_lengths=current_prompt_lengths, block_tables=block_tables, block_size=block_size, output=output_tensor, diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 1a3094a27e2d..0666da4590c1 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from typing import Any, List +import torch + from colossalai.logging import get_dist_logger logger = get_dist_logger(__name__) @@ -51,7 +53,9 @@ class Sequence: """Store information of input sequence. Args: + user_id (int): The ID of the input sequence's owner. request_id (int): The ID of input sequence. + timestamp: (float): It represents the last time seq was called. prompt (str): The prompt of input sequence. input_token_id (List[int]): The tokens ID of input sequence. block_size (int): The block size of input sequence. @@ -62,13 +66,17 @@ class Sequence: max_output_len (int): Maximum output length. ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. output(str): The output of sequence + """ + user_id: int request_id: int + timestamp: float prompt: str input_token_id: List[int] block_size: int sample_params: Any # SampleParams needs to be imported later. + block_table: torch.Tensor eos_token_id: int pad_token_id: int max_output_len: int = 256 @@ -78,6 +86,7 @@ class Sequence: def __post_init__(self): self.output_token_id = [] + self.streaningllm_prompt_len = 0 self.status = RequestStatus.WAITING @property @@ -94,6 +103,13 @@ def input_len(self) -> int: """ return len(self.input_token_id) + @property + def streaningllm_prompt_len(self) -> int: + """ + Get the length of sentences used for streamingLLM. 
+ """ + return self.streaningllm_prompt_len + @property def output_len(self) -> int: """ @@ -162,6 +178,14 @@ def recycle(self) -> None: is already done but it still in running list" self.status = RequestStatus.RECYCLED + def reused(self, request_id: int, timestamp: float, prompt: str, input_token_id: List[int]) -> None: + self.request_id = request_id + self.timestamp = timestamp + self.prompt = prompt + self.input_token_id = input_token_id + self.output_token_id = [] + self.status = RequestStatus.WAITING + def __repr__(self) -> str: return ( f"(request_id={self.request_id}, " diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index 9c69c4125d62..3f0258d06a17 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -363,6 +363,213 @@ def _fwd_context_paged_attention_kernel_v2( return +# Triton 2.1.0 +# TODO(yuanheng-zhao): This is a temporary dispatch to use the new layout for kcache +# merge `_fwd_context_paged_attention_kernel_v2` with `_fwd_context_paged_attention_kernel` later +# as the kcache layout has been supported in the whole triton flow. +@triton.jit +def _prefix_fwd_context_paged_attention_kernel_v1( + Q, + K, + V, + O, + KCache, # [num_blocks, num_kv_heads, head_dim // x, block_size, x] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim] + BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence] + batch_size, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_kcb, # k cache stride(0) - num_blocks + stride_kch, # k cache stride(1) - num_kv_heads + stride_kcsplit_x, # k cache stride(2) - head_dim // x + stride_kcs, # k cache stride(3) - block_szie + stride_kcd, # k cache stride(4) - x + stride_vcb, # v cache stride(0) - num_blocks + stride_vch, # v cache stride(1) - num_kv_heads + stride_vcbs, # v cache stride(2) - block_size + stride_vcd, # v cache stride(3) - head_dim + stride_bts, + stride_btb, + context_lengths, + current_prompt_lengths, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + KCACHE_X: tl.constexpr, # k stride on the second last dimension + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + if cur_seq_idx >= batch_size: + return + cur_head_idx = tl.program_id(1) + block_start_m = tl.program_id(2) # Br, max_input_len // Block_M + cur_kv_head_idx = cur_head_idx // KV_GROUPS + + # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same + tl.static_assert(BLOCK_M == BLOCK_N) + tl.static_assert(BLOCK_N == BLOCK_SIZE) + + # get the context sequence length from provided context lengths tensor + context_seq_len = tl.load(context_lengths + cur_seq_idx) + # get the current prompt length from provided context lengths tensor + current_prompt_len = tl.load(current_prompt_lengths + cur_seq_idx) + + if block_start_m * BLOCK_M >= context_seq_len: + return + + # NOTE when talking to fused QKV and a nopadding context attention, + # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum` + # could be considered as the start index of the context sequence. + # FIXME might want to explore better way to get the summation of prev seq lengths. + # `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton. 
+ prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(current_prompt_lengths + i) + + offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh + offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh + + computed_blocks = (context_seq_len - current_prompt_len) // BLOCK_SIZE + computed_slots = (context_seq_len - current_prompt_len) % BLOCK_SIZE + + # block table for the context sequence + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + # block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq) + # Consider `block_start_m` as the logical block idx in the context block table, + # as we have BLOCK_M the same size as the block size. + cur_block_idx = block_start_m + cur_block_id = tl.load(block_table_ptr + cur_block_idx * stride_btb) + offsets_dmodel = tl.arange(0, HEAD_DIM) + block_range = tl.arange(0, BLOCK_SIZE) + + if block_start_m >= computed_blocks: + offset_kvcache = cur_block_id * stride_vcb + cur_kv_head_idx * stride_vch + + offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offsets_n = tl.arange(0, BLOCK_N) + + if cur_head_idx % KV_GROUPS == 0: + # Copy k to corresponding cache block + if block_start_m == computed_blocks: + block_range = tl.arange(computed_slots, BLOCK_SIZE) + + X_range = tl.arange(0, KCACHE_X) + # unroll the loop aggressively + for split_x_group_id in tl.static_range(HEAD_DIM // KCACHE_X): + offsets_dmodel_x_partition = tl.arange(split_x_group_id * KCACHE_X, (split_x_group_id + 1) * KCACHE_X) + offsets_k = ( + K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt + ) + k = tl.load(offsets_k, mask=offsets_m[:, None] < current_prompt_len, other=0.0) + # HACK: KCache must be contiguous in order to apply the following offsets calculation + offsets_kcache = ( + KCache + + offset_kvcache + + split_x_group_id * BLOCK_SIZE * KCACHE_X + + block_range[:, None] * KCACHE_X + + X_range[None, :] + ) + tl.store(offsets_kcache, k, mask=block_range[:, None] < current_prompt_len - block_start_m * BLOCK_SIZE) + # Copy v to corresponding cache block + offsets_dmodel = tl.arange(0, HEAD_DIM) # offsets_dmodel + offsets_vt = block_start_m * BLOCK_N + offsets_n + offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_dmodel[:, None] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[None, :] < current_prompt_len, other=0.0) + offsets_vcache = ( + VCache + offset_kvcache + block_range[None, :] * stride_vcbs + offsets_dmodel[:, None] * stride_vcd + ) + tl.store(offsets_vcache, v, mask=block_range[None, :] < current_prompt_len - block_start_m * BLOCK_SIZE) + + prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(context_lengths + i) + + offset_kv_cache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch + + Q_block_ptr = tl.make_block_ptr( + base=Q + offset_q, + shape=(current_prompt_len, HEAD_DIM), + strides=(stride_qt, stride_qd), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + K_block_ptr = tl.make_block_ptr( + base=KCache + offset_kv_cache, + shape=(HEAD_DIM, current_prompt_len), + strides=(stride_kd, stride_kt), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=VCache + offset_kv_cache, + shape=(current_prompt_len, HEAD_DIM), + strides=(stride_vt, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + offset_q, 
+ shape=(current_prompt_len, HEAD_DIM), + strides=(stride_ot, stride_od), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0)) + + for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N): + block_start_n = tl.multiple_of(block_start_n, BLOCK_N) + + k = tl.load(K_block_ptr, boundary_check=(0, 1)) + S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + S_ij += tl.dot(Q_i, k) + S_ij *= sm_scale + S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf")) + + m_ij = tl.max(S_ij, 1) # rowmax(Sij) + m_ij = tl.maximum(m_i, m_ij) # m_ij + S_ij -= m_ij[:, None] + p_ij_hat = tl.exp(S_ij) + scale = tl.exp(m_i - m_ij) + l_ij = scale * l_i + tl.sum(p_ij_hat, 1) + acc = acc * scale[:, None] + + v = tl.load(V_block_ptr, boundary_check=(1, 0)) + p_ij_hat = p_ij_hat.to(v.type.element_ty) + + acc += tl.dot(p_ij_hat, v) + l_i = l_ij + m_i = m_ij + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0)) + + return + + # Triton 2.1.0 @triton.jit def _alibi_fwd_context_paged_attention_kernel( @@ -553,11 +760,12 @@ def context_attention_unpadded( q: torch.Tensor, # [num_tokens, num_heads, head_dim] k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] v: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] - k_cache: torch.Tensor, # [num_blocks, num_kv_heads, block_size, head_dim] + k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim // x, block_size, x] v_cache: torch.Tensor, # [num_blocks, num_kv_heads, block_size, head_dim] context_lengths: torch.Tensor, # [num_seqs] - block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence], - block_size: int, + current_prompt_lengths: torch.Tensor = None, # [num_seqs] + block_tables: torch.Tensor = None, # [num_seqs, max_blocks_per_sequence], + block_size: int = 16, output: torch.Tensor = None, # [num_tokens, num_heads, head_dim] alibi_slopes: torch.Tensor = None, # [num_heads] max_seq_len: int = None, @@ -612,42 +820,86 @@ def context_attention_unpadded( ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready" x = k_cache_shape[4] # Intuition: 16 // dtype_size - _fwd_context_paged_attention_kernel_v2[grid]( - q, - k, - v, - output, - k_cache, - v_cache, - block_tables, - num_seqs, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - output.stride(0), - head_dim, - 1, - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride(3), - block_tables.stride(0), - block_tables.stride(1), - context_lengths, - sm_scale, - KV_GROUPS=num_kv_group, - BLOCK_SIZE=block_size, - HEAD_DIM=Lk, - KCACHE_X=x, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) + if current_prompt_lengths: + _prefix_fwd_context_paged_attention_kernel_v1( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride(4), 
+ v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + current_prompt_lengths, + sm_scale, + KV_GROUPS=num_kv_group, + BLOCK_SIZE=block_size, + HEAD_DIM=Lk, + KCACHE_X=x, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + _fwd_context_paged_attention_kernel_v2[grid]( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + sm_scale, + KV_GROUPS=num_kv_group, + BLOCK_SIZE=block_size, + HEAD_DIM=Lk, + KCACHE_X=x, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) return output if alibi_slopes is not None: diff --git a/examples/inference/llama/benchmark_llama.py b/examples/inference/llama/benchmark_llama.py index 2d24d87adfd1..8b6d471ade5f 100644 --- a/examples/inference/llama/benchmark_llama.py +++ b/examples/inference/llama/benchmark_llama.py @@ -142,6 +142,8 @@ def benchmark_inference(args): block_size=32, tp_size=args.tp_size, use_cuda_kernel=True, + enable_streamingllm=True, + generated_token_size=32, ) engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) elif args.mode == "vllm": diff --git a/examples/inference/llama/run_benchmark.sh b/examples/inference/llama/run_benchmark.sh index 1927159765ba..1f66c6b57f97 100755 --- a/examples/inference/llama/run_benchmark.sh +++ b/examples/inference/llama/run_benchmark.sh @@ -24,8 +24,8 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 # benchmark llama2-7b one single GPU -for input_len in 128 512 1024; do - for output_len in 128 256; do +for input_len in 128; do + for output_len in 256; do for bsz in 16 32 64; do python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${bsz}_${input_len}_${output_len}_${mode}_${GPU}.txt done
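A closing note on the dispatch added to context_attention_unpadded: sequence_lengths counts every token backed by the KV cache (reused blocks plus the new prompt), whereas current_prompt_lengths counts only the tokens fed in this step, and the prefix kernel is selected whenever the latter is supplied. Because truth-testing a multi-element tensor raises in PyTorch, the sketch below uses an explicit None check for that dispatch; the values are illustrative only.

import torch

cached_lens = torch.tensor([96, 64], dtype=torch.int32)   # tokens already sitting in reused KV blocks
new_lens = torch.tensor([32, 48], dtype=torch.int32)      # tokens of the resumed prompts
sequence_lengths = cached_lens + new_lens                  # total KV length per sequence
current_prompt_lengths = new_lens                          # query length per sequence

use_prefix_kernel = current_prompt_lengths is not None     # dispatch guard, written as a None check
n_query_tokens = int(current_prompt_lengths.sum())         # rows of Q/K/V fed to the prefill kernel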