Skip to content

[POC] force kv cache recomputation after each weight update #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: simulate_conventional_rl
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions pipelinerl/run_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from vllm.usage.usage_lib import UsageContext
from vllm.worker.multi_step_worker import MultiStepWorker
from vllm.worker.multi_step_model_runner import MultiStepModelRunner
from vllm.core.scheduler import Scheduler


import torch.distributed as dist
Expand All @@ -47,6 +48,22 @@
handler.setFormatter(formatter)
logger.addHandler(handler)

old_schedule_method = Scheduler.schedule
def new_schedule_method(self, *args, **kwargs):
result = old_schedule_method(self, *args, **kwargs)
if getattr(self, "_force_recompute_kv_cache", True):
logger.info(f"Clear the force recompute flag")
self._force_recompute_kv_cache = False
return result
Scheduler.schedule = new_schedule_method

old_can_append_slots = Scheduler._can_append_slots
def new_can_append_slots(self, *args, **kwargs):
if getattr(self, "_force_recompute_kv_cache", True):
logger.info(f"Return False from can_append_slots because of force recompute")
return False
return old_can_append_slots(self, *args, **kwargs)
Scheduler._can_append_slots = new_can_append_slots


def make_worker_class(multi_step: bool):
Expand Down Expand Up @@ -209,17 +226,6 @@ def signal_handler(*_) -> None:
if not args.disable_weight_updates:
weight_update_manager.input_process_groups()

# weight_update_stream = SingleStreamSpec(exp_path=args.exp_root_dir, topic="weight_update_request")
# async def weight_update_receiver():
# async with AsyncStreamReader(weight_update_stream) as reader:
# async for line in reader.read():
# message = TypeAdapter(TrainerMessage).validate_python(line)
# if isinstance(message, WeightUpdateRequest):
# await weight_update_manager.receive_weight_update(message)
# if not args.disable_weight_updates:
# logger.info(f"Create weight update background task")
# asyncio.create_task(weight_update_receiver())

# Run HTTP server
sock_addr = (args.host or "", args.port)
sock = create_server_socket(sock_addr)
Expand All @@ -228,6 +234,8 @@ def signal_handler(*_) -> None:
@app.post("/receive_weight_update")
async def _receive_weight_update(request: WeightUpdateRequest):
await weight_update_manager.receive_weight_update(request)
for scheduler in engine.engine.scheduler:
scheduler._force_recompute_kv_cache = True
return {"status": "ok"}

model_config = await engine.get_model_config()
Expand Down