[WIP]: vl prefix caching #3389

Open · wants to merge 11 commits into base: main
6 changes: 1 addition & 5 deletions lmdeploy/api.py
@@ -68,11 +68,7 @@ def pipeline(model_path: str,
if backend_config is not None else None
model_path = get_model(model_path, download_dir, revision)

-task, pipeline_class = get_task(model_path)
-if task == 'vlm':
-if backend_config and backend_config.enable_prefix_caching:
-backend_config.enable_prefix_caching = False
-logger.warning('VLM does not support prefix caching.')
+_, pipeline_class = get_task(model_path)

if type(backend_config) is not PytorchEngineConfig:
# set auto backend mode
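With the VLM guard above removed, a caller can request prefix caching for a vision-language model directly through the backend config. A minimal usage sketch (the model path is only a placeholder):

```python
from lmdeploy import PytorchEngineConfig, pipeline

# Usage sketch only; the model path is a placeholder.
backend_config = PytorchEngineConfig(enable_prefix_caching=True)
pipe = pipeline('OpenGVLab/InternVL2-8B', backend_config=backend_config)
```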
10 changes: 10 additions & 0 deletions lmdeploy/pytorch/engine/cache_engine.py
@@ -244,6 +244,16 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
"""
self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst)

def copy_to(self, src_to_dst: Dict[int, int], cache_type: str = 'gpu') -> None:
"""Copy cache.

Args:
src_to_dst (Dict[int, int]): Map between src and dst.
cache_type (str): cache type, either 'cpu' or 'gpu'.
"""
target_cache = self.full_gpu_cache if cache_type == 'gpu' else self.full_cpu_cache
self._swap(target_cache, target_cache, src_to_dst)

@classmethod
def get_cache_block_size(cls,
block_size: int,
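For clarity, copy_to reuses the _swap path but with the same cache as both source and destination, so it duplicates blocks within one device rather than moving them between CPU and GPU. A toy illustration of those semantics (shapes and block ids are invented; plain tensors stand in for the real per-layer caches):

```python
import torch

# Toy stand-in for one layer's full GPU cache: [num_blocks, block_numel].
gpu_cache = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)

# copy_to({src: dst}) duplicates blocks within the same cache, unlike
# swap_in/swap_out, which move blocks between the CPU and GPU caches.
copy_map = {3: 6, 1: 7}  # invented block ids
for src, dst in copy_map.items():
    gpu_cache[dst].copy_(gpu_cache[src])

assert torch.equal(gpu_cache[6], gpu_cache[3])
assert torch.equal(gpu_cache[7], gpu_cache[1])
```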
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -795,6 +795,7 @@ def __make_dummy_inputs():
inputs=ModelInputs.make_dummy(1, is_decoding=not prefill),
swap_in_map=dict(),
swap_out_map=dict(),
copy_map=dict(),
loop_count=num_loops,
is_dummy=True,
sync_long_context=False,
@@ -823,6 +824,7 @@ def __make_dummy_inputs():
running = scheduler_output.running
swap_in_map = scheduler_output.swap_in_map
swap_out_map = scheduler_output.swap_out_map
copy_map = scheduler_output.copy_map

if self.should_execute_dummy_batch and len(running) == 0:
return __make_dummy_inputs()
@@ -844,6 +846,7 @@ def __make_dummy_inputs():
inputs=inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map,
copy_map=copy_map,
loop_count=num_loops,
all_ids=all_ids,
guided_input_ids=guided_input_ids,
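For orientation, the engine now threads a third map from the scheduler down to the model agent. A self-contained sketch of the shape of that output (only running, swap_in_map, swap_out_map and copy_map mirror names in the diff; the dataclass itself is illustrative, not lmdeploy's actual definition):

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SchedulerOutputSketch:
    """Illustrative stand-in for the object returned by the scheduler."""
    running: List = field(default_factory=list)                 # sequences to run this step
    swap_in_map: Dict[int, int] = field(default_factory=dict)   # CPU block -> GPU block
    swap_out_map: Dict[int, int] = field(default_factory=dict)  # GPU block -> CPU block
    copy_map: Dict[int, int] = field(default_factory=dict)      # GPU block -> GPU block (prefix reuse)
```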
29 changes: 19 additions & 10 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -28,7 +28,7 @@ def msg_with_rank(rank: int, msg: str):
return f'rank[{rank}] - {msg}'


-def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):
+def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict, copy_map: dict):
"""perform cache swapping."""
issued_cache_op = False
if len(swap_in_map) > 0:
@@ -37,7 +37,9 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: d
if len(swap_out_map) > 0:
cache_engine.swap_out(swap_out_map)
issued_cache_op = True

if len(copy_map) > 0:
cache_engine.copy_to(copy_map)
issued_cache_op = True
if issued_cache_op:
cache_engine.events.wait()

@@ -135,7 +137,7 @@ def all_context(self):
def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
raise NotImplementedError('NotImplemented.')

-async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
"""model forward.

Args:
@@ -200,6 +202,7 @@ async def _async_model_forward(
inputs: ModelInputs,
swap_in_map: Dict,
swap_out_map: Dict,
copy_map: Dict,
return_logits: bool,
sync_long_context: bool,
):
@@ -241,12 +244,15 @@ def get_output(self):

async def __forward(inputs):
"""forward."""
-nonlocal swap_done, swap_in_map, swap_out_map
+nonlocal swap_done, swap_in_map, swap_out_map, copy_map
if swap_done:
-return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
+return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict(), copy_map=dict())
else:
swap_done = True
-return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+return await self.async_forward(inputs,
+                                swap_in_map=swap_in_map,
+                                swap_out_map=swap_out_map,
+                                copy_map=copy_map)

async def __long_context_single_forward(new_inputs, max_seqlen: int):
"""one large sequence."""
@@ -334,6 +340,7 @@ async def _async_step_background(
inputs: ModelInputs,
swap_in_map: Dict,
swap_out_map: Dict,
copy_map: Dict,
loop_count: int,
all_ids: torch.Tensor = None,
guided_input_ids: torch.Tensor = None,
@@ -422,6 +429,7 @@ async def __await_distworker(worker, timeout: float = 0.001):
inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map,
copy_map=copy_map,
return_logits=return_logits,
sync_long_context=sync_long_context,
)
@@ -469,6 +477,7 @@ async def __await_distworker(worker, timeout: float = 0.001):
if is_decoding and idx < loop_count - 1:
swap_in_map = dict()
swap_out_map = dict()
copy_map = dict()
inputs.model_metas = model_metas
__update_inputs(next_token_ids)

@@ -639,8 +648,8 @@ def build_cache_engine(self):

self.cache_engine = CacheEngine(self.cache_config, self.model_config, world_size=tp)

-def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
-cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
+cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)
output = model_forward(
self.patched_model,
inputs,
Expand All @@ -649,15 +658,15 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map:
)
return output

-async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
"""model forward.

Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
-output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)
await asyncio.sleep(0)
return output

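Taken together, the agent can now issue three kinds of block operations before a forward pass. A self-contained sketch of that control flow (the protocol below is illustrative; only the method names mirror CacheEngine):

```python
from typing import Dict, Protocol


class CacheEngineLike(Protocol):
    """Illustrative protocol; the real CacheEngine lives in cache_engine.py."""

    def swap_in(self, src_to_dst: Dict[int, int]) -> None: ...

    def swap_out(self, src_to_dst: Dict[int, int]) -> None: ...

    def copy_to(self, src_to_dst: Dict[int, int], cache_type: str = 'gpu') -> None: ...


def issue_cache_ops(engine: CacheEngineLike, swap_in_map: Dict[int, int],
                    swap_out_map: Dict[int, int], copy_map: Dict[int, int]) -> bool:
    """Mirror of the updated cache_swapping; returns True if any op was issued."""
    issued = False
    if swap_in_map:
        engine.swap_in(swap_in_map)    # CPU -> GPU
        issued = True
    if swap_out_map:
        engine.swap_out(swap_out_map)  # GPU -> CPU
        issued = True
    if copy_map:
        engine.copy_to(copy_map)       # GPU -> GPU, new in this PR
        issued = True
    return issued                      # caller waits on the cache events if True
```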
86 changes: 76 additions & 10 deletions lmdeploy/pytorch/messages.py
@@ -371,21 +371,67 @@ def __init__(self, multimodals: MultiModalInputs):
if multimodals is None:
multimodals = dict()
self.multimodals = multimodals
self._init_mm_ranges()

def _init_mm_ranges(self):
"""init mm ranges and sort it."""
mm_ranges = []
for _, modal_datas in self.multimodals.items():
for modal_data in modal_datas:
data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None))
mm_ranges.append(data)
mm_ranges.sort(key=lambda x: x[1])
self._mm_ranges = mm_ranges

@property
def mm_ranges(self):
"""mm_ranges."""
return self._mm_ranges

def get_datas(self, start=0, end=-1):
"""get multimodals from prompts position [start, end)."""
outs = dict()
-test_range = range(start, end)
for modal_type, modal_datas in self.multimodals.items():
data = []
for modal_data in modal_datas:
-if (modal_data.start not in test_range and modal_data.end not in test_range):
-continue
-data.append(modal_data)
+if modal_data.start < end and modal_data.end > start:
+data.append(modal_data)
if len(data) > 0:
outs[modal_type] = data
return outs

def get_step(self, step: int) -> int:
"""get step that before a whole image."""
real_step = step
for start, end, _ in self._mm_ranges:
if start <= real_step < end:
real_step = start
return real_step

def has_data(self, start: int, end: int) -> bool:
"""whether has multimodal data in [start, end)"""
return any([s < end and e > start for s, e, _ in self._mm_ranges])

def get_hash_values(self, start: int, end: int):
"""get multimodals hash values that from [start, end)"""
mm_hash_values = []
multimodal_ends = []

for mm_start, mm_end, hash_value in self._mm_ranges:
# the mm range intersect with the target range
if mm_start < end and mm_end > start:
mm_hash_values.append(hash_value)
# the mm end in the target range
if start < mm_end <= end:
cur_data = (tuple(mm_hash_values), mm_end)
multimodal_ends.append(cur_data)

if len(mm_hash_values) == 0:
mm_hash_values = None
else:
mm_hash_values = tuple(mm_hash_values)
return mm_hash_values, multimodal_ends

def add_inputs(self, input_mms: MultiModalInputs):
"""add new inputs."""
for modal_type, vals in input_mms.items():
Expand All @@ -394,9 +440,17 @@ def add_inputs(self, input_mms: MultiModalInputs):
else:
self.multimodals[modal_type] = vals

-def empty(self):
+# update mm_ranges
+for modal_data in vals:
+data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None))
+self._mm_ranges.append(data)
+
+# sort mm_ranges
+self._mm_ranges.sort(key=lambda x: x[1])
+
+def empty(self) -> bool:
if len(self.multimodals) == 0:
-return 0
+return True

return all(len(vals) == 0 for vals in self.multimodals)
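To make the interval bookkeeping concrete, here is a small, self-contained illustration of the get_hash_values logic with invented data:

```python
# Invented data: two images occupying prompt token ranges [5, 15) and [20, 30).
mm_ranges = [(5, 15, 'img_a_hash'), (20, 30, 'img_b_hash')]


def get_hash_values_sketch(start, end):
    """Same logic as HistoryMultiModals.get_hash_values, on the toy data above."""
    mm_hash_values, multimodal_ends = [], []
    for mm_start, mm_end, hash_value in mm_ranges:
        if mm_start < end and mm_end > start:    # image overlaps [start, end)
            mm_hash_values.append(hash_value)
        if start < mm_end <= end:                # image finishes inside the range
            multimodal_ends.append((tuple(mm_hash_values), mm_end))
    return (tuple(mm_hash_values) or None), multimodal_ends


print(get_hash_values_sketch(0, 16))  # (('img_a_hash',), [(('img_a_hash',), 15)])
print(get_hash_values_sketch(0, 25))  # (('img_a_hash', 'img_b_hash'), [(('img_a_hash',), 15)])
```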

@@ -582,7 +636,7 @@ def update_token_ids(self,

# update multimodals
if multimodals is not None:
-multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_all_ids)
+multimodals = HistoryMultiModals.update_multimodals(multimodals, self._num_history_ids)
self.history_multimodals.add_inputs(multimodals)

# cross
@@ -610,11 +664,11 @@ def set_step(self, step: int):
"""set step."""
num_all_ids = self.num_all_ids
# update step for vlm
-if len(self.history_embeddings) > 0:
-new_step, self._num_history_images, self._num_images = \
-self.history_embeddings.get_step(step)
+if self.history_multimodals is not None:
+new_step = self.history_multimodals.get_step(step)
assert 0 <= new_step <= step
step = new_step

self._num_history_ids = step
self._num_token_ids = num_all_ids - step
self.num_ignored_history = min(step, self.num_ignored_history)
@@ -625,3 +679,15 @@
if self.history_multimodals is not None:
self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids)
self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids)

def __repr__(self):
return (f'SchedulerSequence(seq_id={self.seq_id}, session_id={self.session_id}, '
f'status={self.status}, arrive_time={self.arrive_time}, '
f'return_logits={self.return_logits}, sampling_param={self.sampling_param}, '
f'num_history_tokens={self.history_len}, num_all_tokens={self.num_all_ids}, '
f'num_new_tokens={self.num_new_tokens}, all_token_ids={self.all_ids}, '
f'mm_ranges={self.history_multimodals.mm_ranges}, '
f'num_gpu_blocks={self.num_blocks}, gpu_blocks={self.logical_blocks.get_real_blocks()}, '
f'last_shared_node={getattr(self.logical_blocks, "last_shared_node", None)})')

__str__ = __repr__
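Similarly, a quick illustration of how get_step snaps a requested step back to the start of an image, so a reused prefix never cuts an image's tokens in half (data invented):

```python
# Invented data: one image spans prompt tokens [10, 74).
mm_ranges = [(10, 74, 'img_hash')]


def get_step_sketch(step):
    """Same rule as HistoryMultiModals.get_step."""
    real_step = step
    for start, end, _ in mm_ranges:
        if start <= real_step < end:
            real_step = start
    return real_step


assert get_step_sketch(5) == 5     # before the image: unchanged
assert get_step_sketch(40) == 10   # inside the image: snap back to its start
assert get_step_sketch(74) == 74   # at or after the image end: unchanged
```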
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/chatglm2.py
@@ -852,13 +852,14 @@ def preprocess_input(self,
offset = input_mm['offset']
num_pad = input_mm['image_tokens']
image_token_id = input_mm['image_token_id']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
-meta=dict(image_token_id=image_token_id))
+meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
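The same hash_value plumbing is repeated in each model file below: every image now carries a content hash in its MultiModalTensor meta, which the scheduler can use as part of a prefix-cache key. A rough sketch of the idea (the hashing scheme and the token id here are invented; the real hash_value is produced upstream by the VL preprocessor):

```python
import hashlib

import torch


def image_hash(pixel_values: torch.Tensor) -> str:
    """Illustrative content hash for an image tensor; not lmdeploy's actual scheme."""
    return hashlib.sha256(pixel_values.cpu().numpy().tobytes()).hexdigest()


# Sketch of what preprocess_input now attaches per image (token id is invented).
pixel_values = torch.zeros(3, 224, 224)
meta = dict(image_token_id=92546, hash_value=image_hash(pixel_values))
```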
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/cogvlm.py
@@ -889,13 +889,14 @@ def preprocess_input(self, input_ids: List[int], input_multimodals=None, **kwarg
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
-meta=dict(image_token_id=image_token_id))
+meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/models/deepseek_vl2.py
@@ -109,6 +109,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()

self.ctx_mgr = ctx_mgr

# ----------- vision encoder ------------
@@ -144,7 +145,7 @@ def __init__(self,
# ----------- language model ------------
language_config = config.language_config
self.language = DeepseekV2ForCausalLM(config=language_config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)

self.config = language_config
# ----------- input processor ------------
self.input_processor = DeepSeekVLV2InputProcessor(config, dtype)

@@ -434,6 +435,7 @@ def preprocess_input(self,
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
images_spatial_crop = input_mm.get('images_spatial_crop', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()
Expand All @@ -443,6 +445,7 @@ def preprocess_input(self,
end=offset + num_pad,
meta=dict(
image_token_id=image_token_id,
hash_value=hash_value,
images_spatial_crop=images_spatial_crop,
))

3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/gemma3_vl.py
@@ -105,13 +105,14 @@ def preprocess_input(self,
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
-meta=dict(image_token_id=image_token_id))
+meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/internvl.py
@@ -519,13 +519,14 @@ def preprocess_input(self,
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
-meta=dict(image_token_id=image_token_id))
+meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(