Skip to content

Commit

Permalink
verify streamingllm config
Browse files Browse the repository at this point in the history
  • Loading branch information
isky-cd committed May 23, 2024
1 parent c3f4edb commit d63e068
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions colossalai/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,6 @@ class InferenceConfig(RPC_PARAM):

def __post_init__(self):
    """Dataclass post-init hook: derive capture length, normalize the
    StreamingLLM start-token size, and run full config validation.

    NOTE(review): this is the pre-commit version of the method as shown in
    the diff; the commit under review moves the start_token_size checks
    into _verify_config().
    """
    # Longest sequence the engine may need to capture: prompt + generation.
    self.max_context_len_to_capture = self.max_input_len + self.max_output_len
    # Per https://arxiv.org/pdf/2309.17453 (StreamingLLM), a start_token_size
    # above 4 yields little benefit, so it must fit within one KV-cache block.
    assert (
        self.start_token_size <= self.block_size
    ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}."
    # We assume that start_token_size occupies one block.
    # NOTE(review): this overwrites the user-supplied value right after
    # asserting on it — the assert only gates values larger than block_size;
    # the effective start_token_size is always exactly block_size.
    self.start_token_size = self.block_size
    # Remaining invariants (prompt template, token-size multiples, ...) are
    # checked in _verify_config().
    self._verify_config()

def _verify_config(self) -> None:
Expand Down Expand Up @@ -285,6 +280,15 @@ def _verify_config(self) -> None:
"{input_text}" in self.prompt_template
), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"

assert (
self.start_token_size <= self.block_size
), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}."
assert (
self.generated_token_size % self.block_size == 0
), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}."
# We assume that start_token_size occupies one block.
self.start_token_size = self.block_size

def to_generation_config(self, model_config) -> GenerationConfig:
meta_config = {
"max_length": self.max_input_len + self.max_output_len,
Expand Down

0 comments on commit d63e068

Please sign in to comment.