[Bench] Json mode bench (#2552)
* [Bench] Json mode bench

This PR refactors the MLC bench to enable JSON mode in the dataset.

* upd

* fix lint
cyx-6 committed Jun 12, 2024
1 parent dcece51 commit 07c92b0
Showing 2 changed files with 53 additions and 26 deletions.
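The dataset format that the new `json_prompts_path` option expects is described in the docstring below: a `.jsonl` file whose lines each carry a `messages` list and a `response_format` dict. Here is a minimal sketch of producing one such line; the schema and message content are illustrative placeholders, not taken from the MLC test data:

```python
import json

# One JSON-mode dataset entry, matching the format documented for
# json_prompts_path below: a "messages" list plus a "response_format"
# whose "schema" holds a JSON schema (values here are placeholders).
entry = {
    "messages": [
        {"role": "user", "content": "Return the capital of France as JSON."}
    ],
    "response_format": {
        "type": "json_object",
        "schema": {
            "type": "object",
            "properties": {"capital": {"type": "string"}},
            "required": ["capital"],
        },
    },
}

# Each line of the dataset file is a single JSON object (JSONL).
with open("json_mode_prompts.jsonl", "a", encoding="utf-8") as file:
    file.write(json.dumps(entry) + "\n")
```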
60 changes: 44 additions & 16 deletions python/mlc_llm/bench/prompts.py
@@ -1,6 +1,8 @@
"""MLC LLM bench prompts generator"""

import json
import random
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional

@@ -18,6 +20,7 @@ class PromptsGenerator: # pylint: disable=too-few-public-methods
def __init__(
self,
prompts_path: Optional[str] = None,
json_prompts_path: Optional[str] = None,
tokenizer: Optional[Any] = None,
seed: Optional[int] = 11111,
) -> None:
@@ -32,6 +35,11 @@ def __init__(
or a .jsonl file where each line is a JSON object formatted as
{"prompt": "prompt text", "prompt_tokens": 10}.
json_prompts_path : Optional[str]
The path to the file containing the source json prompts. This file is a
.jsonl file where each line is a JSON object formatted as
{"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}.
tokenizer : Optional[Any]
The tokenizer object to use for tokenizing the prompts.
@@ -66,6 +74,22 @@ def __init__(
prompt_line = file.readline()
prompt_tokens = self._count_tokens(prompt_line)
self.prompts.append({"prompt": prompt_line, "prompt_tokens": prompt_tokens})
if json_prompts_path:
self.json_prompts = defaultdict(list)
with open(json_prompts_path, "r", encoding="utf-8") as file:
for line in file:
json_line = json.loads(line)
assert (
"messages" in json_line
), "The messages field is required in the JSONL file."
assert (
"response_format" in json_line
), "The response_format field is required in the JSONL file."
self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append(
json_line["messages"]
)
else:
self.json_prompts = None

def _count_tokens(self, text: str) -> int:
"""Get the number of tokens.
@@ -82,40 +106,44 @@ def _count_tokens(self, text: str) -> int:
"""
return len(self.tokenizer.encode(text))

def generate_prompt(self, tokens_mean: int, tokens_stddev: Optional[int] = 0) -> str:
def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""
Generates a prompt that closely matches the desired token count.
Generates a prompt based on the params, e.g. prompt_tokens, response_format.
Parameters
----------
token_mean : int
params : Dict[str, Any]
The desired mean number of tokens in the prompt.
token_stddev : Optional[int]
The desired standard deviation of tokens in the prompt.
Returns
-------
out: str
A prompt string with the specified number of tokens.
override_params: Dict[str, Any]
The params to override the original request, e.g. messages, response_format.
"""
if "response_format" in params:
response_format = params["response_format"]
if response_format.get("type") == "json_object":
if response_format.get("schema") in self.json_prompts:
assert len(self.json_prompts[response_format["schema"]]) > 0
return {"messages": random.choice(self.json_prompts[response_format["schema"]])}
schema, prompts = random.choice(list(self.json_prompts.items()))
response_format["schema"] = schema
return {"messages": random.choice(prompts), "response_format": response_format}
tokens_mean = params.get("prompt_tokens", 128)
assert tokens_mean > 0, "The mean number of tokens must be greater than 0."
out_prompt_tokens = (
int(random.gauss(tokens_mean, tokens_stddev)) if tokens_stddev else tokens_mean
)
if out_prompt_tokens <= 0:
out_prompt_tokens = tokens_mean
remaining_prompt_tokens = out_prompt_tokens
remaining_prompt_tokens = tokens_mean
result_prompt = ""
override_params = None
while remaining_prompt_tokens > 0:
prompt_dict = random.choice(self.prompts)
cur_prompt_tokens = prompt_dict["prompt_tokens"]
cur_prompt = prompt_dict["prompt"]
if override_params is None:
override_params = prompt_dict["override_params"]
if remaining_prompt_tokens - cur_prompt_tokens < 0:
result_prompt += cur_prompt[:remaining_prompt_tokens]
remaining_prompt_tokens = 0
break
result_prompt += cur_prompt
remaining_prompt_tokens -= cur_prompt_tokens
self._count_tokens(result_prompt)
return result_prompt
return {"messages": [{"role": "system", "content": result_prompt}]}
19 changes: 9 additions & 10 deletions python/mlc_llm/bench/request.py
@@ -1,4 +1,5 @@
"""MLC LLM Bench Request"""

import json
import os
import time
@@ -45,6 +46,8 @@ class OpenAIRequestSender: # pylint: disable=too-many-instance-attributes
The client to use for sending requests.
include_server_metrics : Optional[bool]
Specifies if server metrics should be included, default is False.
prompt_generator : Optional[PromptsGenerator]
The prompt generator for missing messages fields.
Attributes
----------
@@ -60,6 +63,7 @@ def __init__( # pylint: disable=too-many-arguments
timeout: Optional[float] = None,
client: Optional[Any] = None,
include_server_metrics: Optional[bool] = False,
prompt_generator: Optional[PromptsGenerator] = None,
) -> None:
import aiohttp # pylint: disable=import-outside-toplevel,import-error
from transformers import ( # pylint: disable=import-outside-toplevel,import-error
@@ -69,7 +73,7 @@ def __init__( # pylint: disable=too-many-arguments
self.stream = stream
self.timeout = timeout
self.tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
self.prompt_generator = PromptsGenerator()
self.prompt_generator = PromptsGenerator() if prompt_generator is None else prompt_generator
self.request_records: List[RequestRecords] = []
self.client = client if client else aiohttp.ClientSession()
self.include_server_metrics = include_server_metrics
@@ -88,15 +92,10 @@ async def __call__( # pylint: disable=too-many-locals, too-many-branches, too-m
self, params: Dict[str, Any] = None
) -> None:
if "messages" not in params:
prompt_tokens = 128
if "prompt_tokens" in params:
prompt_tokens = params["prompt_tokens"]
else:
logger.warning("A random prompt with %d tokens will be generated.", prompt_tokens)
prompt = self.prompt_generator.generate_prompt(prompt_tokens)
params["messages"] = [{"role": "system", "content": prompt}]
else:
prompt = params["messages"][-1]["content"]
override_params = self.prompt_generator.generate_prompt(params)
assert "messages" in override_params, "override params must contain messages field"
params.update(override_params)
prompt = params["messages"][-1]["content"]
chat_params = self._get_chat_completion_params(params)
if "stream" not in chat_params:
chat_params["stream"] = self.stream
