Improve turbomind's prefix cache #3332

Closed
lvhan028 wants to merge 43 commits from the improve-tm-prefix-cache branch
Commits
43 commits
750aaa8
add log
lvhan028 Mar 17, 2025
8886124
Merge branch 'main' into improve-tm-prefix-cache
lvhan028 Mar 18, 2025
7b4304a
refactor tm prefix caching
lvhan028 Mar 24, 2025
8be44f8
refactor tm prefix cache
lvhan028 Mar 25, 2025
dfdde01
Merge branch 'dev' into improve-tm-prefix-cache
lvhan028 Mar 25, 2025
fda1e25
fix linting
lvhan028 Mar 25, 2025
a4ffe41
fix linting
lvhan028 Mar 25, 2025
acf4092
combine Get&Create
lvhan028 Mar 27, 2025
a2352d1
update
lvhan028 Mar 27, 2025
1e940df
clear blocks
lvhan028 Mar 27, 2025
533941d
INFO log to DEBUG log
lvhan028 Mar 28, 2025
91d1412
refactor chat.py
lvhan028 Mar 28, 2025
ce08974
unlock the unmatched blocks when id is reused
lvhan028 Mar 28, 2025
3891782
merge main
lvhan028 Mar 31, 2025
9c3ebc8
remove start_flag and end_flag from tm csrc
lvhan028 Mar 31, 2025
d41683a
update output_logits
lvhan028 Apr 1, 2025
70399b4
update
lvhan028 Apr 1, 2025
1b99728
update
lvhan028 Apr 2, 2025
c5a2962
fix api_client
lvhan028 Apr 2, 2025
499b709
remove interactive chat API
lvhan028 Apr 3, 2025
617d317
fix build error on windows platform
lvhan028 Apr 3, 2025
50e56e2
fix chat
lvhan028 Apr 3, 2025
38ea2ae
update generate.ps1
lvhan028 Apr 3, 2025
e1489a5
fix clang-format error
lvhan028 Apr 3, 2025
9d1df28
fix clang-format error
lvhan028 Apr 3, 2025
e2a0c7a
fix vlm chat error
lvhan028 Apr 4, 2025
604b101
merge main
lvhan028 Apr 4, 2025
5e34425
fix get_logits
lvhan028 Apr 4, 2025
1cbdf5a
remove killing from tm csrc
lvhan028 Apr 4, 2025
afd531d
fix clang-format
lvhan028 Apr 6, 2025
3dc9ffa
update
lvhan028 Apr 7, 2025
14eb22a
enable_prefix_caching defaults to True
lvhan028 Apr 7, 2025
7e13a18
merge pt chat.py and tm chat.py
lvhan028 Apr 8, 2025
22cf302
remove pt chat.py and tm chat.py
lvhan028 Apr 8, 2025
8531df8
update
lvhan028 Apr 9, 2025
3ddec13
Merge branch 'default-prefix-cache' into improve-tm-prefix-cache
lvhan028 Apr 9, 2025
f3ef0d4
fix
lvhan028 Apr 9, 2025
87dfbb9
update
lvhan028 Apr 9, 2025
6fd0f56
merge main and resove the conflicts
lvhan028 May 12, 2025
61f2f0a
update
lvhan028 May 12, 2025
89b593f
Merge branch 'main' into improve-tm-prefix-cache
lvhan028 Jul 10, 2025
8785b5a
rollback autotset
lvhan028 Jul 10, 2025
a10b2d6
rollback api_server
lvhan028 Jul 11, 2025
6 changes: 3 additions & 3 deletions benchmark/profile_generation.py
@@ -329,7 +329,7 @@ def parse_args():
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group, default=2048)
- prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
+ prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)
rope_scaling_factor_act = ArgumentHelper.rope_scaling_factor(pt_group)
dtype_act = ArgumentHelper.dtype(pt_group)

@@ -395,7 +395,7 @@ def main():
session_len=session_len,
rope_scaling_factor=args.rope_scaling_factor,
tp=args.tp,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
dtype=args.dtype,
)
elif args.backend == 'pytorch':
@@ -405,7 +405,7 @@ def main():
session_len=session_len,
tp=args.tp,
eager_mode=args.eager_mode,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
dtype=args.dtype,
)
gen_config = GenerationConfig(top_k=args.top_k,
6 changes: 3 additions & 3 deletions benchmark/profile_pipeline_api.py
@@ -258,7 +258,7 @@ def parse_args():
tp_act = ArgumentHelper.tp(pt_group)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
- prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
+ prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)

# turbomind engine args
tb_group = parser.add_argument_group('TurboMind engine argument')
@@ -290,7 +290,7 @@ def main():
quant_policy=args.quant_policy,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
communicator=args.communicator,
)
elif args.backend == 'pytorch':
@@ -301,7 +301,7 @@ def main():
tp=args.tp,
thread_safe=False,
eager_mode=args.eager_mode,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
)

engine = Engine(args.model_path, engine_config, csv=args.csv)
18 changes: 8 additions & 10 deletions benchmark/profile_throughput.py
@@ -148,8 +148,8 @@ def __init__(self, model_path: str, engine_config: Union[PytorchEngineConfig, Tu
self.tm_model = tm_model
self.pbar = None

- async def _inference(self, req_queue: Queue, session_id: int, temperature: float, top_p: float, top_k: int,
- stream_output: bool, skip_tokenize: bool, skip_detokenize: bool, concurrency: int):
+ async def _inference(self, req_queue: Queue, temperature: float, top_p: float, top_k: int, stream_output: bool,
+ skip_tokenize: bool, skip_detokenize: bool):
model_inst = self.tm_model.create_instance()
sess: Session = None
for prompt, _, output_seqlen, cancel_after, sess in iter(req_queue.get_nowait, None):
@@ -166,7 +166,7 @@ async def _inference(self, req_queue: Queue, session_id: int, temperature: float
prev_len = 0
token_ids = input_ids.copy()

- generator = model_inst.async_stream_infer(session_id,
+ generator = model_inst.async_stream_infer(sess.id,
input_ids=input_ids,
gen_config=GenerationConfig(max_new_tokens=output_seqlen,
temperature=temperature,
@@ -193,10 +193,9 @@ async def _inference(self, req_queue: Queue, session_id: int, temperature: float

# for pytorch engine to restart a session
if self.backend == 'pytorch':
- await model_inst.async_end(session_id)
+ await model_inst.async_end(sess.id)

self.pbar.update(1)
- session_id += concurrency

def process_request(self, requests, profiler: Profiler, concurrency, temperature, top_p, top_k, stream_output,
skip_tokenize, skip_detokenize, cancel_rate):
@@ -219,8 +218,7 @@ def process_request(self, requests, profiler: Profiler, concurrency, temperature
# start threads
tasks = []
for i in range(concurrency):
- task = self._inference(req_queue, i, temperature, top_p, top_k, stream_output, skip_tokenize,
- skip_detokenize, concurrency)
+ task = self._inference(req_queue, temperature, top_p, top_k, stream_output, skip_tokenize, skip_detokenize)
tasks.append(task)

async def _gather_tasks(tasks):
@@ -311,7 +309,7 @@ def parse_args():
tp_act = ArgumentHelper.tp(pt_group)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
- prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
+ prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)
quant_policy_act = ArgumentHelper.quant_policy(pt_group, default=0)
dtype_act = ArgumentHelper.dtype(pt_group)

@@ -348,7 +346,7 @@ def main():
quant_policy=args.quant_policy,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
dtype=args.dtype,
communicator=args.communicator,
)
@@ def main():
max_batch_size=args.concurrency,
tp=args.tp,
eager_mode=args.eager_mode,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
quant_policy=args.quant_policy,
dtype=args.dtype,
distributed_executor_backend=args.distributed_executor_backend,
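Note on the profile_throughput.py changes above: because each Session object now carries its own id, _inference no longer needs a session_id parameter or the per-worker session_id += concurrency stride. A minimal sketch of the pattern, using a hypothetical stand-in Session class rather than the benchmark's real one:

# Sketch only: shows why per-session ids remove the manual id bookkeeping.
# `Session` here is a hypothetical stand-in, not the benchmark's actual class.
import itertools
from dataclasses import dataclass, field

_ids = itertools.count()


@dataclass
class Session:
    prompt: str
    id: int = field(default_factory=lambda: next(_ids))  # unique id minted per session


async def _inference(req_queue, infer):
    # Every dequeued request already knows its session id, so concurrent
    # workers can share one queue without striding ids by `concurrency`.
    for sess in iter(req_queue.get_nowait, None):
        await infer(sess.id, sess.prompt)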
9 changes: 5 additions & 4 deletions lmdeploy/api.py
@@ -69,17 +69,18 @@ def pipeline(model_path: str,
model_path = get_model(model_path, download_dir, revision)

task, pipeline_class = get_task(model_path)
- if task == 'vlm':
- if backend_config and backend_config.enable_prefix_caching:
- backend_config.enable_prefix_caching = False
- logger.warning('VLM does not support prefix caching.')

if type(backend_config) is not PytorchEngineConfig:
# set auto backend mode
backend_config = autoget_backend_config(model_path, backend_config)
backend = 'pytorch' if type(backend_config) is PytorchEngineConfig else 'turbomind'
logger.info(f'Using {backend} engine')

+ if task == 'vlm':
+ if backend_config and backend_config.enable_prefix_caching:
+ backend_config.enable_prefix_caching = False
+ logger.warning('VLM does not support prefix caching.')

return pipeline_class(model_path,
backend=backend,
backend_config=backend_config,
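The VLM check in pipeline() is moved below the backend auto-detection so that it runs against the backend config that will actually be used, i.e. the one returned by autoget_backend_config, instead of a user-supplied config that may still be replaced. A condensed sketch of the resulting order (not the full function; the lmdeploy symbols are injected as parameters only to keep the snippet self-contained):

def resolve_backend_config(model_path, backend_config, task, *,
                           autoget_backend_config, PytorchEngineConfig, logger):
    # 1) Resolve the backend first; this may swap the config object entirely.
    if type(backend_config) is not PytorchEngineConfig:
        backend_config = autoget_backend_config(model_path, backend_config)
    # 2) Only then disable prefix caching for VLM tasks, on the resolved config.
    if task == 'vlm' and backend_config and backend_config.enable_prefix_caching:
        backend_config.enable_prefix_caching = False
        logger.warning('VLM does not support prefix caching.')
    return backend_config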
95 changes: 95 additions & 0 deletions lmdeploy/cli/chat.py
@@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
import fire

from lmdeploy import ChatTemplateConfig, GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline
from lmdeploy.archs import autoget_backend


def input_prompt():
"""Input a prompt in the consolo interface."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))


def build_pipe(model_path, backend, **kwargs):
# set enable_prefix_cache
disable_prefix_cache = kwargs.pop('disable_prefix_cache', False)
kwargs.update(enable_prefix_caching=not disable_prefix_cache)
# set engine config
engine_config = None
if backend == 'turbomind':
engine_config = TurbomindEngineConfig()
for key, value in kwargs.items():
if hasattr(TurbomindEngineConfig, key):
setattr(engine_config, key, value)
else:
engine_config = PytorchEngineConfig()
for key, value in kwargs.items():
if hasattr(PytorchEngineConfig, key):
setattr(engine_config, key, value)
if kwargs.get('adapters', None):
from .utils import get_lora_adapters
adapters = get_lora_adapters(kwargs['adapters'])
engine_config.adapters = adapters
# set chat template config
chat_template = kwargs.get('chat_template', None)
chat_template_config = None
if chat_template:
chat_template_config = ChatTemplateConfig(model_name=chat_template)

pipe = pipeline(model_path,
backend_config=engine_config,
chat_template_config=chat_template_config,
log_level='ERROR',
**kwargs)
return pipe


def build_gen_config(**kwargs):
gen_config = GenerationConfig(max_new_tokens=1024, top_k=40, top_p=0.8, temperature=0.8, repetition_penalty=1.0)
for key, value in kwargs.items():
if hasattr(GenerationConfig, key):
setattr(gen_config, key, value)
return gen_config


def main(model_path, backend, **kwargs):
if backend != 'pytorch':
# set auto backend mode
backend = autoget_backend(model_path)

pipe = build_pipe(model_path, backend, **kwargs)
gen_config = build_gen_config(**kwargs)

quit = False
while True:
with pipe.session(gen_config) as sess:
while True:
try:
prompt = input_prompt()
except KeyboardInterrupt:
quit = True
break
if prompt == 'end':
sess.close()
break
if prompt == 'exit':
quit = True
break
resps = sess(prompt)
try:
for resp in resps:
print(resp.text, end='', flush=True)
sess.messages.append(dict(role='assistant', content=resp.text))
except KeyboardInterrupt:
sess.stop()
finally:
print('\ncancelling the conversation')
if quit:
print('exiting...')
break


if __name__ == '__main__':
fire.Fire(main)
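Since the new file ends with fire.Fire(main), the chat loop can also be driven programmatically for a quick interactive smoke test. In the sketch below the model path is a placeholder, and the extra keyword arguments are only applied when the selected engine config defines a field of the same name (as build_pipe does above):

from lmdeploy.cli.chat import main

if __name__ == '__main__':
    # Placeholder model path; any non-'pytorch' backend is re-resolved by
    # autoget_backend(), exactly as main() above does.
    main('/path/to/model',
         backend='turbomind',
         tp=1,              # forwarded to the engine config if it has this field
         session_len=4096)  # likewise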
40 changes: 5 additions & 35 deletions lmdeploy/cli/cli.py
@@ -4,7 +4,7 @@
import os

from ..version import __version__
- from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters
+ from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args


class CLI(object):
@@ -104,7 +104,7 @@ def add_parser_chat():
tp_act = ArgumentHelper.tp(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
- prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
+ prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)

# turbomind args
@@ -218,39 +218,9 @@ def get_gpu_topo():
@staticmethod
def chat(args):
"""Chat with pytorch or turbomind engine."""
- from lmdeploy.archs import autoget_backend
-
- chat_template_config = get_chat_template(args.chat_template)
-
- backend = args.backend
- if backend != 'pytorch':
- # set auto backend mode
- backend = autoget_backend(args.model_path)
-
- if backend == 'pytorch':
- from lmdeploy.messages import PytorchEngineConfig
- from lmdeploy.pytorch.chat import run_chat
-
- adapters = get_lora_adapters(args.adapters)
- engine_config = PytorchEngineConfig(dtype=args.dtype,
- tp=args.tp,
- session_len=args.session_len,
- cache_max_entry_count=args.cache_max_entry_count,
- adapters=adapters,
- enable_prefix_caching=args.enable_prefix_caching,
- device_type=args.device,
- eager_mode=args.eager_mode,
- quant_policy=args.quant_policy)
- run_chat(args.model_path, engine_config, chat_template_config=chat_template_config)
- else:
- from lmdeploy.turbomind.chat import main as run_chat
- kwargs = convert_args(args)
- kwargs.pop('chat_template')
- kwargs.pop('backend')
- kwargs.pop('device')
- kwargs.pop('eager_mode')
- kwargs['chat_template_config'] = chat_template_config
- run_chat(**kwargs)
+ from .chat import main
+ kwargs = convert_args(args)
+ main(**kwargs)

@staticmethod
def add_parsers():
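With the per-backend branches gone, CLI.chat simply converts the parsed argparse namespace into keyword arguments and hands them to the shared entry point in lmdeploy/cli/chat.py. The sketch below assumes convert_args (from .utils) behaves roughly like vars() with private entries dropped; convert_args_sketch is a hypothetical stand-in, not the library helper:

import argparse


def convert_args_sketch(args: argparse.Namespace) -> dict:
    # Assumed behaviour: turn the namespace into plain keyword arguments.
    return {k: v for k, v in vars(args).items() if not k.startswith('_')}


def chat(args: argparse.Namespace):
    from lmdeploy.cli.chat import main  # one entry point for both engines
    main(**convert_args_sketch(args))   # model_path and backend arrive as kwargs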
12 changes: 6 additions & 6 deletions lmdeploy/cli/serve.py
@@ -60,7 +60,7 @@ def add_parser_gradio():
max_batch_size_act = ArgumentHelper.max_batch_size(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
- prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
+ prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)
max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group)
model_format_act = ArgumentHelper.model_format(pt_group)
# turbomind args
@@ -161,7 +161,7 @@ def add_parser_api_server():
max_batch_size_act = ArgumentHelper.max_batch_size(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
- prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
+ prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)
max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)
model_format = ArgumentHelper.model_format(pt_group)
@@ -274,7 +274,7 @@ def gradio(args):
cache_max_entry_count=args.cache_max_entry_count,
block_size=args.cache_block_seq_len,
session_len=args.session_len,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
device_type=args.device,
quant_policy=args.quant_policy,
eager_mode=args.eager_mode,
@@ def gradio(args):
rope_scaling_factor=args.rope_scaling_factor,
cache_max_entry_count=args.cache_max_entry_count,
cache_block_seq_len=args.cache_block_seq_len,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
max_prefill_token_num=args.max_prefill_token_num,
communicator=args.communicator)
chat_template_config = get_chat_template(args.chat_template)
@@ -327,7 +327,7 @@ def api_server(args):
block_size=args.cache_block_seq_len,
session_len=args.session_len,
adapters=adapters,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
device_type=args.device,
quant_policy=args.quant_policy,
eager_mode=args.eager_mode,
@@ def api_server(args):
rope_scaling_factor=args.rope_scaling_factor,
cache_max_entry_count=args.cache_max_entry_count,
cache_block_seq_len=args.cache_block_seq_len,
- enable_prefix_caching=args.enable_prefix_caching,
+ enable_prefix_caching=not args.disable_prefix_caching,
max_prefill_token_num=args.max_prefill_token_num,
communicator=args.communicator)
chat_template_config = get_chat_template(args.chat_template)
6 changes: 3 additions & 3 deletions lmdeploy/cli/utils.py
@@ -476,13 +476,13 @@ def cache_block_seq_len(parser):
'be ignored')

@staticmethod
- def enable_prefix_caching(parser):
+ def disable_prefix_caching(parser):
"""Add argument enable_prefix_caching to parser."""

- return parser.add_argument('--enable-prefix-caching',
+ return parser.add_argument('--disable-prefix-caching',
action='store_true',
default=False,
- help='Enable cache and match prefix')
+ help='Disable prefix caching')

@staticmethod
def num_tokens_per_iter(parser):
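With ArgumentHelper.enable_prefix_caching renamed to disable_prefix_caching, prefix caching becomes opt-out: the parsers register --disable-prefix-caching with default=False, and every call site in this PR passes enable_prefix_caching=not args.disable_prefix_caching into the engine config. A minimal argparse sketch of that default behaviour (EngineConfig is a placeholder for TurbomindEngineConfig/PytorchEngineConfig):

import argparse
from dataclasses import dataclass


@dataclass
class EngineConfig:          # placeholder for the real engine configs
    enable_prefix_caching: bool = True


parser = argparse.ArgumentParser()
parser.add_argument('--disable-prefix-caching', action='store_true', default=False,
                    help='Disable prefix caching')

# No flag given -> prefix caching stays enabled by default.
args = parser.parse_args([])
assert EngineConfig(enable_prefix_caching=not args.disable_prefix_caching).enable_prefix_caching

# Explicit opt-out.
args = parser.parse_args(['--disable-prefix-caching'])
assert not EngineConfig(enable_prefix_caching=not args.disable_prefix_caching).enable_prefix_caching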