Skip to content

Commit

Permalink
Provide default max model length (#1224)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Sep 28, 2023
1 parent 6f88f76 commit f936657
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
18 changes: 11 additions & 7 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,6 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size

def get_max_model_len(self) -> int:
return self.max_model_len

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
Expand Down Expand Up @@ -378,10 +375,17 @@ def _get_and_verify_max_len(
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)
if derived_max_model_len == float("inf"):
raise ValueError(
"The model's config.json must contain one of the following keys "
"to determine the original maximum length of the model: "
f"{possible_keys}")
if max_model_len is not None:
# If max_model_len is specified, we use it.
return max_model_len

default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
f"{possible_keys}. Assuming the model's maximum length is "
f"{default_max_len}.")
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def create_engine_configs(
self.worker_use_ray)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.get_max_model_len())
model_config.max_model_len)
return model_config, cache_config, parallel_config, scheduler_config


Expand Down
1 change: 1 addition & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
f"revision={model_config.revision}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, "
f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
Expand Down
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.get_max_model_len()
max_model_len = engine_model_config.max_model_len

# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer,
Expand Down

0 comments on commit f936657

Please sign in to comment.