From c35cf984900e7aa9c1f4a40abde473e0928a2d12 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 6 May 2025 08:08:38 +0000 Subject: [PATCH 01/19] one remote agent per remote rank Signed-off-by: nicklucche --- vllm/config.py | 4 + .../kv_connector/v1/nixl_connector.py | 136 +++++++++++------- 2 files changed, 89 insertions(+), 51 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 54eb5e0ef0e..273cc333d00 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3442,6 +3442,10 @@ class KVTransferConfig(BaseModel): # any extra config that the connector may need kv_connector_extra_config: dict[str, Any] = {} + kv_producers_tensor_parallel_size: Optional[int] = None + kv_consumers_tensor_parallel_size: Optional[int] = None + + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 8a15955bd4f..67323e877d6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -14,7 +14,7 @@ from typing_extensions import Optional from vllm import envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.parallel_state import ( @@ -98,7 +98,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker(str(self.engine_id)) + self.connector_worker = NixlConnectorWorker(str(self.engine_id), vllm_config.kv_transfer_config) ############################################################ # Scheduler Side Methods @@ -214,7 +214,7 @@ def build_connector_meta( class NixlConnectorWorker: """Implementation of Worker side methods""" - def __init__(self, engine_id: str): + def __init__(self, engine_id: str, kv_config: KVTransferConfig): if NixlWrapper is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") @@ -223,8 +223,8 @@ def __init__(self, engine_id: str): # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) - # Map of engine_id -> agent_name. - self._remote_agents: dict[str, str] = {} + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. + self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict) # Metadata. self.engine_id = engine_id @@ -236,16 +236,16 @@ def __init__(self, engine_id: str): self.kv_caches: dict[str, torch.Tensor] = {} # Map of engine_id -> kv_caches_base_addr - self.kv_caches_base_addr: dict[str, list[int]] = {} + self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = {} # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 # nixl_prepped_dlist_handle (int). - self.src_xfer_side_handle: int = 0 + self.src_xfer_side_handle: int = -1 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - self.dst_xfer_side_handles: dict[str, int] = {} + self.dst_xfer_side_handles: dict[str, dict[int, int]] = defaultdict(dict) # Map of engine_id -> num_blocks. self.dst_num_blocks: dict[str, int] = {} @@ -267,6 +267,28 @@ def __init__(self, engine_id: str): # Background thread for establishing new connections. 
self._nixl_handshake_listener_t: Optional[threading.Thread] = None + # TODO tp multiplier only works when N/M is a multiple of the other. We can + # refactor later if needs be. + if kv_config.kv_producers_tensor_parallel_size is None or kv_config.kv_consumers_tensor_parallel_size is None: + # Ignore unless both are set + kv_config.kv_producers_tensor_parallel_size = self.world_size + kv_config.kv_consumers_tensor_parallel_size = self.world_size + else: + assert kv_config.kv_producers_tensor_parallel_size >= 0 and \ + kv_config.kv_consumers_tensor_parallel_size >= 0 + assert (kv_config.kv_producers_tensor_parallel_size % + kv_config.kv_consumers_tensor_parallel_size == 0) or \ + (kv_config.kv_consumers_tensor_parallel_size % + kv_config.kv_producers_tensor_parallel_size ==0) + + # Used to skip extra NIXL handshakes + self.is_homogenous_tp = kv_config.kv_producers_tensor_parallel_size == kv_config.kv_consumers_tensor_parallel_size + if self.world_size == kv_config.kv_producers_tensor_parallel_size: + # TODO better name + self._other_world_size = kv_config.kv_consumers_tensor_parallel_size + else: + self._other_world_size = kv_config.kv_producers_tensor_parallel_size + @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, ready_event: threading.Event, rank: int): @@ -278,6 +300,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, # move this into the scheduler rather than worker, since # each rank needs the metadata of all other ranks (whereas # in this setup, each rank only gets one other rank's meta. + # TODO iterate over all ranks to handshake with M. Can we get M from config? encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) @@ -309,24 +332,33 @@ def _nixl_handshake(self, host: str, port: int): # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - path = f"tcp://{host}:{port + self.rank}" - logger.debug("Querying metadata on path: %s", path) - with zmq_ctx(zmq.REQ, path) as sock: - # Send query for the request. - sock.send(GET_META_MSG) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) - got_metadata_time = time.perf_counter() - - # Register Remote agent. - self.add_remote_agent(metadata) - setup_agent_time = time.perf_counter() - - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + # TODO Can we have a fixed number of remote ranks we handshake with? + # Eg with tp multiplier rank0<==>rrank0,1 | rank1<==>rrank2,3 ... + # iterate over all ranks to handshake with M. + for rank_j in range(self._other_world_size): + rank_j = self.rank if self.is_homogenous_tp else rank_j + path = f"tcp://{host}:{port + rank_j}" + logger.debug("Querying metadata on path: %s", path) + with zmq_ctx(zmq.REQ, path) as sock: + # Send query for the request. + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent. 
+ self.add_remote_agent(metadata, rank_j) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + + # Cut it short for homogenous tp and only record one remote rank + if self.is_homogenous_tp: + break def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -375,7 +407,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): region_len = self.num_blocks * self.block_len caches_data.append((base_addr, region_len, self.rank, "")) kv_caches_base_addr.append(base_addr) - self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + # Own kv_caches will only be indexed by self.rank. Remote kv caches will contain info for all workers. + self.kv_caches_base_addr[self.engine_id][self.rank] = kv_caches_base_addr self.num_regions = len(caches_data) descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") @@ -401,49 +434,50 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._nixl_handshake_listener_t.start() ready_event.wait() - def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int=0): engine_id = nixl_agent_meta.engine_id if engine_id in self._remote_agents: return - self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent( + self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) self.kv_caches_base_addr[ - engine_id] = nixl_agent_meta.kv_caches_base_addr - - # Create src descs and xfer side handles. - blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.rank) - - # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) + engine_id][remote_rank] = nixl_agent_meta.kv_caches_base_addr + + # Create src descs and xfer side handles. Local block descr only contains own rank. + if self.src_xfer_side_handle < 0: + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id][self.rank]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) # Create dst descs and xfer side handles. 
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] - for base_addr in self.kv_caches_base_addr[engine_id]: + for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * self.block_len # (addr, len, device id) blocks_data.append( (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for dst engine %s and rank %s", - len(blocks_data), engine_id, self.rank) + logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), engine_id, remote_rank, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") self.dst_xfer_side_handles[ - engine_id] = self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id], descs) + engine_id][remote_rank] = self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id][remote_rank], descs) def get_finished(self) -> tuple[set[str], set[str]]: """ From ec8481746c7cda3d464277b89f90c052af1b7f77 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 6 May 2025 08:54:42 +0000 Subject: [PATCH 02/19] tp_size in metadata and handshake with rank0 first Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 79 +++++++++++-------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 67323e877d6..22db9096cef 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -52,6 +52,7 @@ class NixlAgentMetadata( agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int + tp_size: int @dataclass @@ -232,6 +233,7 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() + # Remote tracking ds only contain one entry for own tp group: engine_id-self.rank # KV Caches and nixl tracking data. self.kv_caches: dict[str, torch.Tensor] = {} @@ -248,7 +250,7 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): self.dst_xfer_side_handles: dict[str, dict[int, int]] = defaultdict(dict) # Map of engine_id -> num_blocks. - self.dst_num_blocks: dict[str, int] = {} + self.dst_num_blocks: dict[str, dict[int, int]] = defaultdict(dict) self._registered_descs: list[Any] = [] # In progress transfers. @@ -283,11 +285,13 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): # Used to skip extra NIXL handshakes self.is_homogenous_tp = kv_config.kv_producers_tensor_parallel_size == kv_config.kv_consumers_tensor_parallel_size - if self.world_size == kv_config.kv_producers_tensor_parallel_size: - # TODO better name - self._other_world_size = kv_config.kv_consumers_tensor_parallel_size - else: - self._other_world_size = kv_config.kv_producers_tensor_parallel_size + # if self.world_size == kv_config.kv_producers_tensor_parallel_size: + # TODO we cant know this because it is spawned in a separate process with some other --tp value so we have to discover. + # Every instance may have a different number of tp workers. + # we can relax this right now and assume all instances are passed world size info from cli + # TODO what we can do is have current rank respond with own world size. 
Then client rank will broadcast world_size of other + self._tp_size = {self.engine_id: self.world_size} + # kv end to local ranks. Or we can have only rank0 of both producer/consumer synch up on said value. Auto-scaling needs a refresh. @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, @@ -333,32 +337,41 @@ def _nixl_handshake(self, host: str, port: int): # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. # TODO Can we have a fixed number of remote ranks we handshake with? - # Eg with tp multiplier rank0<==>rrank0,1 | rank1<==>rrank2,3 ... - # iterate over all ranks to handshake with M. - for rank_j in range(self._other_world_size): - rank_j = self.rank if self.is_homogenous_tp else rank_j + # Ow we could have rank0 send all metadata in a batch. + + def handshake(sock, rank: int)->NixlAgentMetadata: + # Send query for the request. + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent. + self.add_remote_agent(metadata, rank) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + return metadata + + # Handshake with remote agent-rank0 first to get the tp_size of remote + path = f"tcp://{host}:{port}" + logger.debug("Querying master rank metadata on path: %s", path) + with zmq_ctx(zmq.REQ, path) as sock: + metadata = handshake(sock, 0) + + # TODO should we skip this if remote world_size == world_size (homogeneous)? + # Iterate over all other remote ranks to handshake with. + for rank_j in range(1, self._tp_size[metadata.tp_size]): path = f"tcp://{host}:{port + rank_j}" logger.debug("Querying metadata on path: %s", path) with zmq_ctx(zmq.REQ, path) as sock: - # Send query for the request. - sock.send(GET_META_MSG) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) - got_metadata_time = time.perf_counter() - - # Register Remote agent. 
- self.add_remote_agent(metadata, rank_j) - setup_agent_time = time.perf_counter() - - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) - - # Cut it short for homogenous tp and only record one remote rank - if self.is_homogenous_tp: - break + metadata = handshake(sock, rank_j) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -388,7 +401,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) - self.dst_num_blocks[self.engine_id] = self.num_blocks + self.dst_num_blocks[self.engine_id][self.rank] = self.num_blocks self.kv_caches = kv_caches kv_caches_base_addr = [] caches_data = [] @@ -424,6 +437,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, + tp_size=self.world_size ) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( @@ -439,6 +453,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= if engine_id in self._remote_agents: return + if engine_id in self._tp_size: + assert self._tp_size[engine_id] == nixl_agent_meta.tp_size + self._tp_size[engine_id] = nixl_agent_meta.tp_size self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) self.kv_caches_base_addr[ @@ -462,7 +479,7 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= "NIXL_INIT_AGENT", descs) # Create dst descs and xfer side handles. - self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + self.dst_num_blocks[engine_id][remote_rank] = nixl_agent_meta.num_blocks blocks_data = [] for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: for block_id in range(nixl_agent_meta.num_blocks): From 60ab1975a9c3a7d0419fb5cb7034b41402e310c3 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 6 May 2025 09:06:54 +0000 Subject: [PATCH 03/19] todos Signed-off-by: nicklucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 22db9096cef..fac1306cd6e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -195,6 +195,7 @@ def build_connector_meta( self, scheduler_output: SchedulerOutput, ) -> KVConnectorMetadata: + # TODO ignored scheduler for now, I think this ds could be moved at some point meta = NixlConnectorMetadata() # Loop through scheduled reqs and convert to ReqMeta. @@ -462,6 +463,7 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= engine_id][remote_rank] = nixl_agent_meta.kv_caches_base_addr # Create src descs and xfer side handles. Local block descr only contains own rank. 
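The source-side transfer list prepared below is conceptually just one (addr, len, device_id) tuple per block per registered KV region, covering only this worker's own rank. A minimal standalone sketch of that layout (helper name and sizes are illustrative, not part of the patch):

def make_local_block_descs(base_addrs: list[int], num_blocks: int,
                           block_len: int, device_id: int) -> list[tuple[int, int, int]]:
    # Flat descriptor list: one (addr, len, device_id) entry per block per region.
    blocks_data = []
    for base_addr in base_addrs:            # one base address per registered cache region
        for block_id in range(num_blocks):
            blocks_data.append(
                (base_addr + block_id * block_len, block_len, device_id))
    return blocks_data

# e.g. 2 regions x 3 blocks of 4096 B on device 0 -> 6 descriptors
descs = make_local_block_descs([0x1000, 0x9000], num_blocks=3, block_len=4096, device_id=0)
assert len(descs) == 6 and descs[1] == (0x1000 + 4096, 4096, 0)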
+ # TODO we could pull this out if it has nothing to do with remote if self.src_xfer_side_handle < 0: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id][self.rank]: @@ -479,6 +481,7 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= "NIXL_INIT_AGENT", descs) # Create dst descs and xfer side handles. + # TODO likely dont need 'remote_rank' indexing, as ALL tp workers have same num blocks right? self.dst_num_blocks[engine_id][remote_rank] = nixl_agent_meta.num_blocks blocks_data = [] for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: @@ -630,6 +633,9 @@ def _read_blocks( dst_engine_id: str, request_id: str, ): + # TODO right now I am missing the remote rank input: where should I read these blocks from? + # should I map remote_block_ids=>remote_rank? + # NOTE(rob): this takes ~2s. We need to get this off the hotpath. if dst_engine_id not in self._remote_agents: self._nixl_handshake(remote_host, remote_port) From 792bacd657528a80fef0c9dd4db909848b13c15d Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 6 May 2025 12:50:39 +0000 Subject: [PATCH 04/19] dst_num_blocks is engine_id only Signed-off-by: nicklucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index fac1306cd6e..02739ad0bfb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -250,8 +250,9 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): # Map of engine_id -> nixl_prepped_dlist_handle (int)]. self.dst_xfer_side_handles: dict[str, dict[int, int]] = defaultdict(dict) - # Map of engine_id -> num_blocks. - self.dst_num_blocks: dict[str, dict[int, int]] = defaultdict(dict) + # Map of engine_id -> num_blocks. Remote TP ranks will have the same + # amount of blocks. + self.dst_num_blocks: dict[str, int] = dict() self._registered_descs: list[Any] = [] # In progress transfers. @@ -402,7 +403,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) - self.dst_num_blocks[self.engine_id][self.rank] = self.num_blocks + self.dst_num_blocks[self.engine_id] = self.num_blocks self.kv_caches = kv_caches kv_caches_base_addr = [] caches_data = [] @@ -482,7 +483,10 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= # Create dst descs and xfer side handles. # TODO likely dont need 'remote_rank' indexing, as ALL tp workers have same num blocks right? 
- self.dst_num_blocks[engine_id][remote_rank] = nixl_agent_meta.num_blocks + if engine_id in self.dst_num_blocks[engine_id]: + assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: for block_id in range(nixl_agent_meta.num_blocks): @@ -491,7 +495,7 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= blocks_data.append( (base_addr + block_offset, self.block_len, self.rank)) logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", - len(blocks_data), engine_id, remote_rank, self.rank) + len(blocks_data), engine_id, remote_rank, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") From d2ea8c8380c7bedc3d402545580d886d0d50a247 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 6 May 2025 15:35:20 +0000 Subject: [PATCH 05/19] fixes Signed-off-by: nicklucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 02739ad0bfb..79f729caa5c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -195,7 +195,6 @@ def build_connector_meta( self, scheduler_output: SchedulerOutput, ) -> KVConnectorMetadata: - # TODO ignored scheduler for now, I think this ds could be moved at some point meta = NixlConnectorMetadata() # Loop through scheduled reqs and convert to ReqMeta. @@ -239,7 +238,7 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): self.kv_caches: dict[str, torch.Tensor] = {} # Map of engine_id -> kv_caches_base_addr - self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = {} + self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = defaultdict(dict) # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -367,7 +366,7 @@ def handshake(sock, rank: int)->NixlAgentMetadata: # TODO should we skip this if remote world_size == world_size (homogeneous)? # Iterate over all other remote ranks to handshake with. - for rank_j in range(1, self._tp_size[metadata.tp_size]): + for rank_j in range(1, metadata.tp_size): path = f"tcp://{host}:{port + rank_j}" logger.debug("Querying metadata on path: %s", path) with zmq_ctx(zmq.REQ, path) as sock: @@ -437,7 +436,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), - kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.rank], num_blocks=self.num_blocks, tp_size=self.world_size ) @@ -483,7 +482,7 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= # Create dst descs and xfer side handles. # TODO likely dont need 'remote_rank' indexing, as ALL tp workers have same num blocks right? 
- if engine_id in self.dst_num_blocks[engine_id]: + if engine_id in self.dst_num_blocks: assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks From f17092fab0b3989deba01ccf3f3ab682fb812e41 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 6 May 2025 16:10:02 +0000 Subject: [PATCH 06/19] block_len is tp dependent Signed-off-by: nicklucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 79f729caa5c..144e3f83a17 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -52,6 +52,7 @@ class NixlAgentMetadata( agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int + block_len: int # tp_size dependent tp_size: int @@ -438,6 +439,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.rank], num_blocks=self.num_blocks, + block_len=self.block_len, tp_size=self.world_size ) ready_event = threading.Event() @@ -489,10 +491,10 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= blocks_data = [] for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * self.block_len + block_offset = block_id * nixl_agent_meta.block_len # (addr, len, device id) blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) + (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) # TODO remote rank? logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", len(blocks_data), engine_id, remote_rank, self.rank) @@ -637,7 +639,8 @@ def _read_blocks( request_id: str, ): # TODO right now I am missing the remote rank input: where should I read these blocks from? - # should I map remote_block_ids=>remote_rank? + # should I map remote_block_ids=>remote_rank? dst_engine is actually unique per rank! + print("READ_BLOCKS", remote_block_ids, dst_engine_id) # NOTE(rob): this takes ~2s. We need to get this off the hotpath. if dst_engine_id not in self._remote_agents: From ddf4c8e37f1f77db47f30e0b4a7a3f56789a583e Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 6 May 2025 17:55:01 +0000 Subject: [PATCH 07/19] wip Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 144e3f83a17..0f26e71811d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -52,7 +52,6 @@ class NixlAgentMetadata( agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int - block_len: int # tp_size dependent tp_size: int @@ -238,7 +237,8 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): # KV Caches and nixl tracking data. self.kv_caches: dict[str, torch.Tensor] = {} - # Map of engine_id -> kv_caches_base_addr + # Map of engine_id -> kv_caches_base_addr. For TP case, the structure + # is still flat. 
self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = defaultdict(dict) # Number of NIXL regions. Currently one region per cache @@ -246,7 +246,7 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): self.num_regions = 0 # nixl_prepped_dlist_handle (int). - self.src_xfer_side_handle: int = -1 + self.src_xfer_side_handle: dict[int, int] = -1 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. self.dst_xfer_side_handles: dict[str, dict[int, int]] = defaultdict(dict) @@ -439,7 +439,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.rank], num_blocks=self.num_blocks, - block_len=self.block_len, tp_size=self.world_size ) ready_event = threading.Event() @@ -465,36 +464,44 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= engine_id][remote_rank] = nixl_agent_meta.kv_caches_base_addr # Create src descs and xfer side handles. Local block descr only contains own rank. - # TODO we could pull this out if it has nothing to do with remote - if self.src_xfer_side_handle < 0: + tp_multiplier = self._tp_size[self.engine_id] // self._tp_size[engine_id] + assert tp_multiplier > 0, "Decode TP cannot be smaller than prefill TP" + # Different P instances may have different number of tp workers. + # Prepare descriptors based on expected size of xfers. + dst_block_len = self.block_len // tp_multiplier # TP kv_heads splitting + if tp_multiplier not in self.src_xfer_side_handle: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id][self.rank]: for block_id in range(self.num_blocks): block_offset = block_id * self.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) + for i in range(tp_multiplier): + tp_multiplier_offset = i * dst_block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + self.src_xfer_side_handle[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) - # Create dst descs and xfer side handles. - # TODO likely dont need 'remote_rank' indexing, as ALL tp workers have same num blocks right? + # Create dst descs and xfer side handles. TP workers have same #blocks if engine_id in self.dst_num_blocks: assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] + # TODO map only the split belonging to current rank + # With heterogenous TP, prepare the descriptors for the P KV chunk + # belonging to current D rank. Eg. 1P 2D => P KV:[KV_0 | KV_1]. for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_len + block_offset = block_id * dst_block_len # (addr, len, device id) blocks_data.append( - (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) # TODO remote rank? + (base_addr + block_offset, dst_block_len, self.rank)) # TODO remote rank? 
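A toy check of why block_len has to travel in NixlAgentMetadata once TP sizes differ: remote block addresses must be computed with the remote worker's stride, not the local one. The sizes below mirror the TP1-prefill / TP2-decode numbers quoted later in the series; the base address is made up:

local_block_len = 16384          # decode worker's own bytes-per-block (TP=2)
remote_block_len = 32768         # advertised by the TP=1 prefill worker in its metadata
remote_base_addr = 0x7F0000000000

def remote_block_addr(block_id: int) -> int:
    # Addressing a remote block uses the remote stride.
    return remote_base_addr + block_id * remote_block_len

assert remote_block_addr(2) - remote_block_addr(1) == remote_block_len
assert remote_block_len != local_block_len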
logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", len(blocks_data), engine_id, remote_rank, self.rank) @@ -638,10 +645,6 @@ def _read_blocks( dst_engine_id: str, request_id: str, ): - # TODO right now I am missing the remote rank input: where should I read these blocks from? - # should I map remote_block_ids=>remote_rank? dst_engine is actually unique per rank! - print("READ_BLOCKS", remote_block_ids, dst_engine_id) - # NOTE(rob): this takes ~2s. We need to get this off the hotpath. if dst_engine_id not in self._remote_agents: self._nixl_handshake(remote_host, remote_port) From 00392cef9ecd482c3e33832cbc48f3a0db67214d Mon Sep 17 00:00:00 2001 From: nicklucche Date: Wed, 7 May 2025 09:45:06 +0000 Subject: [PATCH 08/19] refactor remote kv_cache splitting and ditch tp_multiplier Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 71 +++++++++++-------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0f26e71811d..a8ab0b0392e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -246,7 +246,7 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): self.num_regions = 0 # nixl_prepped_dlist_handle (int). - self.src_xfer_side_handle: dict[int, int] = -1 + self.src_xfer_side_handle: int = -1 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. self.dst_xfer_side_handles: dict[str, dict[int, int]] = defaultdict(dict) @@ -387,6 +387,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # MLA case. self.num_blocks = first_kv_cache.shape[0] block_rank = 2 # [block_size, latent_dim] + # TODO does this include tp dependent size? block_shape = first_kv_cache.shape[-block_rank:] else: # [2 (k and v), num_blocks, ...] @@ -397,6 +398,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc self.block_len = kv_elem_size * math.prod(block_shape) + print(f"\n\n{self.block_len=}\n\n") logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, first_kv_cache.shape) @@ -463,28 +465,25 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self.kv_caches_base_addr[ engine_id][remote_rank] = nixl_agent_meta.kv_caches_base_addr - # Create src descs and xfer side handles. Local block descr only contains own rank. + # Create src descs and xfer side handles. D workers adding remote P's. tp_multiplier = self._tp_size[self.engine_id] // self._tp_size[engine_id] assert tp_multiplier > 0, "Decode TP cannot be smaller than prefill TP" - # Different P instances may have different number of tp workers. - # Prepare descriptors based on expected size of xfers. 
- dst_block_len = self.block_len // tp_multiplier # TP kv_heads splitting - if tp_multiplier not in self.src_xfer_side_handle: + # TODO may not be needed if block_len is already accounting for tp + # dst_block_len = self.block_len // tp_multiplier + if self.src_xfer_side_handle < 0: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id][self.rank]: for block_id in range(self.num_blocks): block_offset = block_id * self.block_len - for i in range(tp_multiplier): - tp_multiplier_offset = i * dst_block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.src_xfer_side_handle[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist( + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) # Create dst descs and xfer side handles. TP workers have same #blocks @@ -494,22 +493,31 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] # TODO map only the split belonging to current rank - # With heterogenous TP, prepare the descriptors for the P KV chunk - # belonging to current D rank. Eg. 1P 2D => P KV:[KV_0 | KV_1]. - for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * dst_block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, dst_block_len, self.rank)) # TODO remote rank? - logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", - len(blocks_data), engine_id, remote_rank, self.rank) - - # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[ - engine_id][remote_rank] = self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id][remote_rank], descs) + # With heterogenous TP, prepare the descriptors by splitting the P KV + # cache into chunks of D worker's size (D>P). + # Eg. 1P 2D => P KV:[KV_0 | KV_1]. + p_remote_rank = self.rank % nixl_agent_meta.tp_size + # Only register the remote's descriptor if current rank pulls from it + if p_remote_rank == remote_rank: + # TODO in case sizes aren't exactly divisible, we may want to swap + # self.block_len with meta.block_len // tp_multiplier + # (eg when dividing by 3) and handle final block. src_xfer too. + rank_offset = self.rank // nixl_agent_meta.tp_size * self.block_len + for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: + base_addr += rank_offset + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * self.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), engine_id, remote_rank, self.rank) + + # Register with NIXL. 
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[ + engine_id][remote_rank] = self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id][remote_rank], descs) def get_finished(self) -> tuple[set[str], set[str]]: """ @@ -661,10 +669,15 @@ def _read_blocks( assert len(local_block_ids) > 0 assert len(local_block_ids) == len(remote_block_ids) + # TODO you have to make N xfers with heter tp + # With homogeneous TP, each TP worker loads KV from corresponding rank. + # With heterogenous TP, assuming D>P, the D tp workers will have to + # issue xfers to part of the P `p_remote_rank` kv caches. + p_remote_rank = self.rank % self._tp_size[dst_engine_id] # Get side handles. local_xfer_side_handle = self.src_xfer_side_handle - remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][p_remote_rank] # Get descs ids. remote_block_descs_ids = self._get_block_descs_ids( From eb0bdd2cb1dbfea941e579021ef4a46114ea60bc Mon Sep 17 00:00:00 2001 From: nicklucche Date: Wed, 7 May 2025 12:52:55 +0000 Subject: [PATCH 09/19] 2-handshake model with vertical kv cache split Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 99 ++++++++----------- 1 file changed, 42 insertions(+), 57 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a8ab0b0392e..19cd0f749ba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -237,9 +237,9 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): # KV Caches and nixl tracking data. self.kv_caches: dict[str, torch.Tensor] = {} - # Map of engine_id -> kv_caches_base_addr. For TP case, the structure - # is still flat. - self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = defaultdict(dict) + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[str, list[int]] = dict() # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -248,10 +248,10 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): # nixl_prepped_dlist_handle (int). self.src_xfer_side_handle: int = -1 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - self.dst_xfer_side_handles: dict[str, dict[int, int]] = defaultdict(dict) + self.dst_xfer_side_handles: dict[str, int] = dict() # Map of engine_id -> num_blocks. Remote TP ranks will have the same - # amount of blocks. + # number of blocks. self.dst_num_blocks: dict[str, int] = dict() self._registered_descs: list[Any] = [] @@ -270,30 +270,8 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): # Background thread for establishing new connections. self._nixl_handshake_listener_t: Optional[threading.Thread] = None - - # TODO tp multiplier only works when N/M is a multiple of the other. We can - # refactor later if needs be. 
- if kv_config.kv_producers_tensor_parallel_size is None or kv_config.kv_consumers_tensor_parallel_size is None: - # Ignore unless both are set - kv_config.kv_producers_tensor_parallel_size = self.world_size - kv_config.kv_consumers_tensor_parallel_size = self.world_size - else: - assert kv_config.kv_producers_tensor_parallel_size >= 0 and \ - kv_config.kv_consumers_tensor_parallel_size >= 0 - assert (kv_config.kv_producers_tensor_parallel_size % - kv_config.kv_consumers_tensor_parallel_size == 0) or \ - (kv_config.kv_consumers_tensor_parallel_size % - kv_config.kv_producers_tensor_parallel_size ==0) - # Used to skip extra NIXL handshakes - self.is_homogenous_tp = kv_config.kv_producers_tensor_parallel_size == kv_config.kv_consumers_tensor_parallel_size - # if self.world_size == kv_config.kv_producers_tensor_parallel_size: - # TODO we cant know this because it is spawned in a separate process with some other --tp value so we have to discover. - # Every instance may have a different number of tp workers. - # we can relax this right now and assume all instances are passed world size info from cli - # TODO what we can do is have current rank respond with own world size. Then client rank will broadcast world_size of other self._tp_size = {self.engine_id: self.world_size} - # kv end to local ranks. Or we can have only rank0 of both producer/consumer synch up on said value. Auto-scaling needs a refresh. @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, @@ -366,12 +344,15 @@ def handshake(sock, rank: int)->NixlAgentMetadata: metadata = handshake(sock, 0) # TODO should we skip this if remote world_size == world_size (homogeneous)? - # Iterate over all other remote ranks to handshake with. - for rank_j in range(1, metadata.tp_size): - path = f"tcp://{host}:{port + rank_j}" - logger.debug("Querying metadata on path: %s", path) + + # Handshake only with the other TP remote the current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. + p_remote_rank = self.rank % metadata.tp_size + if p_remote_rank > 0: + path = f"tcp://{host}:{port + p_remote_rank}" + logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) with zmq_ctx(zmq.REQ, path) as sock: - metadata = handshake(sock, rank_j) + metadata = handshake(sock, p_remote_rank) @@ -424,9 +405,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): region_len = self.num_blocks * self.block_len caches_data.append((base_addr, region_len, self.rank, "")) kv_caches_base_addr.append(base_addr) - # Own kv_caches will only be indexed by self.rank. Remote kv caches will contain info for all workers. 
- self.kv_caches_base_addr[self.engine_id][self.rank] = kv_caches_base_addr + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) + print("************************BLOCKS SETUP") + print(f"Number of blocks {len(kv_caches_base_addr)=}\n") + print(f"{self.num_blocks=}, {self.block_len=}, {self.num_regions=}\n") descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) @@ -439,7 +422,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), - kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.rank], + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, tp_size=self.world_size ) @@ -454,7 +437,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int=0): engine_id = nixl_agent_meta.engine_id - if engine_id in self._remote_agents: + # TODO re-evaluate refreshing for scaling/recovery + if engine_id in self._remote_agents and remote_rank in self._remote_agents[engine_id]: return if engine_id in self._tp_size: @@ -462,17 +446,12 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self._tp_size[engine_id] = nixl_agent_meta.tp_size self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) - self.kv_caches_base_addr[ - engine_id][remote_rank] = nixl_agent_meta.kv_caches_base_addr - - # Create src descs and xfer side handles. D workers adding remote P's. - tp_multiplier = self._tp_size[self.engine_id] // self._tp_size[engine_id] - assert tp_multiplier > 0, "Decode TP cannot be smaller than prefill TP" - # TODO may not be needed if block_len is already accounting for tp - # dst_block_len = self.block_len // tp_multiplier + + # Create src descs and xfer side handles. + # TODO we could pull this out of remote_agent, only depends on self if self.src_xfer_side_handle < 0: blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id][self.rank]: + for base_addr in self.kv_caches_base_addr[self.engine_id]: for block_id in range(self.num_blocks): block_offset = block_id * self.block_len # (addr, len, device id) @@ -492,18 +471,27 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] - # TODO map only the split belonging to current rank # With heterogenous TP, prepare the descriptors by splitting the P KV # cache into chunks of D worker's size (D>P). - # Eg. 1P 2D => P KV:[KV_0 | KV_1]. + # Eg. 1P 2D => P0 KV:[KV_0 | KV_1] (contiguous view). p_remote_rank = self.rank % nixl_agent_meta.tp_size # Only register the remote's descriptor if current rank pulls from it if p_remote_rank == remote_rank: + self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr + # TODO enforce tp sizes are exact multiples + d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[engine_id] + assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than prefill TP" + # TODO in case sizes aren't exactly divisible, we may want to swap - # self.block_len with meta.block_len // tp_multiplier + # self.block_len with meta.block_len // d_workers_per_p_worker # (eg when dividing by 3) and handle final block. src_xfer too. 
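The rank mapping this chunking relies on can be sanity-checked in isolation. A small sketch, assuming decode TP is an exact multiple of prefill TP (function name and TP values are illustrative):

def pull_assignment(d_rank: int, d_tp: int, p_tp: int) -> tuple[int, int]:
    # Return (prefill rank this decode rank pulls from, chunk index inside its KV region).
    assert d_tp % p_tp == 0, "decode TP assumed to be an exact multiple of prefill TP"
    p_remote_rank = d_rank % p_tp      # which remote worker to read
    chunk_idx = d_rank // p_tp         # which slice of that worker's region to read
    return p_remote_rank, chunk_idx

# PTP=1, DTP=2: both decode ranks read prefill rank 0, taking chunks 0 and 1.
assert [pull_assignment(r, 2, 1) for r in range(2)] == [(0, 0), (0, 1)]
# PTP=2, DTP=4: decode ranks 0,2 read prefill rank 0; ranks 1,3 read prefill rank 1.
assert [pull_assignment(r, 4, 2) for r in range(4)] == [(0, 0), (1, 0), (0, 1), (1, 1)]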
+ # assert nixl_agent_meta.block_len % self.block_len == 0 + + # Split the kv memory inside a nixl region to guarantee each local + # rank is pulling the kv cache of all layers of a remote worker. rank_offset = self.rank // nixl_agent_meta.tp_size * self.block_len - for base_addr in self.kv_caches_base_addr[engine_id][remote_rank]: + print(f"Local Rank {self.rank} remote {remote_rank}: {rank_offset=}, {len(self.kv_caches_base_addr[engine_id])}\n\n") + for base_addr in self.kv_caches_base_addr[engine_id]: base_addr += rank_offset for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * self.block_len @@ -515,8 +503,7 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[ - engine_id][remote_rank] = self.nixl_wrapper.prep_xfer_dlist( + self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( self._remote_agents[engine_id][remote_rank], descs) def get_finished(self) -> tuple[set[str], set[str]]: @@ -669,15 +656,13 @@ def _read_blocks( assert len(local_block_ids) > 0 assert len(local_block_ids) == len(remote_block_ids) - # TODO you have to make N xfers with heter tp - # With homogeneous TP, each TP worker loads KV from corresponding rank. - # With heterogenous TP, assuming D>P, the D tp workers will have to - # issue xfers to part of the P `p_remote_rank` kv caches. - p_remote_rank = self.rank % self._tp_size[dst_engine_id] + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from + # corresponding rank. With heterogenous TP, fixing D>P, the D tp + # workers will issue xfers to parts of the P worker remote kv caches. # Get side handles. local_xfer_side_handle = self.src_xfer_side_handle - remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][p_remote_rank] + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # Get descs ids. remote_block_descs_ids = self._get_block_descs_ids( From 44db4642d8621340deba623b564a456b486a26a7 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Thu, 8 May 2025 14:15:28 +0000 Subject: [PATCH 10/19] still broken Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 71 ++++++++++++------- 1 file changed, 47 insertions(+), 24 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 19cd0f749ba..f9ea84364cd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -53,6 +53,7 @@ class NixlAgentMetadata( kv_caches_base_addr: list[int] num_blocks: int tp_size: int + block_len: int @dataclass @@ -245,8 +246,9 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 - # nixl_prepped_dlist_handle (int). - self.src_xfer_side_handle: int = -1 + # nixl_prepped_dlist_handle. Different dst TP sizes require preparing + # xfer layout differently. + self.src_xfer_side_handle: int = dict() # Map of engine_id -> nixl_prepped_dlist_handle (int)]. self.dst_xfer_side_handles: dict[str, int] = dict() @@ -297,6 +299,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, # NOTE(rob): we need each rank to have a unique port. This # hack to keeps us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. 
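The port arithmetic used on both ends of the side channel can be summarized on its own. A sketch of the two-step scheme, assuming one listener per rank at base_port + rank (names and the example port are placeholders):

def listener_port(base_port: int, own_rank: int) -> int:
    # Each worker's metadata listener binds base_port + its own TP rank.
    return base_port + own_rank

def handshake_ports(base_port: int, d_rank: int, p_tp: int) -> list[int]:
    # Step 1: always query remote rank 0 first; its reply carries the remote TP size (p_tp).
    ports = [base_port]
    # Step 2: query the single remote rank this decode rank will actually pull from.
    p_remote_rank = d_rank % p_tp
    if p_remote_rank > 0:
        ports.append(base_port + p_remote_rank)
    return ports

# e.g. decode rank 3 against a TP=2 prefill: talk to port 5600 (rank 0), then 5601 (rank 1).
assert handshake_ports(5600, d_rank=3, p_tp=2) == [5600, 5601]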
+ # TODO get rank port util port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank path = f"tcp://{host}:{port}" logger.debug("Starting listening on path: %s", path) @@ -316,8 +319,6 @@ def _nixl_handshake(self, host: str, port: int): # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - # TODO Can we have a fixed number of remote ranks we handshake with? - # Ow we could have rank0 send all metadata in a batch. def handshake(sock, rank: int)->NixlAgentMetadata: # Send query for the request. @@ -424,7 +425,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, - tp_size=self.world_size + tp_size=self.world_size, + block_len=self.block_len ) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( @@ -436,6 +438,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ready_event.wait() def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int=0): + # FIXME one other approach I tried is loading half of every remote block instead of half the blocks. Doesnt seem to make much difference engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery if engine_id in self._remote_agents and remote_rank in self._remote_agents[engine_id]: @@ -446,41 +449,49 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self._tp_size[engine_id] = nixl_agent_meta.tp_size self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) + + # TODO enforce tp sizes are exact multiples + d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[engine_id] + assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than prefill TP" + dst_num_blocks_per_local_rank = nixl_agent_meta.num_blocks // d_workers_per_p_worker # Create src descs and xfer side handles. # TODO we could pull this out of remote_agent, only depends on self - if self.src_xfer_side_handle < 0: + if d_workers_per_p_worker not in self.src_xfer_side_handle: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id]: - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len + for block_id in range(dst_num_blocks_per_local_rank): + block_offset = block_id * nixl_agent_meta.block_len # (addr, len, device id) + # use the block size of the dst/P node to make sure regions match blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) + (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + # NIXL_INIT_AGENT to be used for preparations of local descs. + self.src_xfer_side_handle[d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) # Create dst descs and xfer side handles. 
TP workers have same #blocks - if engine_id in self.dst_num_blocks: - assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + # if engine_id in self.dst_num_blocks: + # assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + + # self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + # When D_TP>P_TP, P blocks are split between D workers. Hence we may + # record a fraction of the total num_blocks in P. + self.dst_num_blocks[engine_id] = dst_num_blocks_per_local_rank - self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] # With heterogenous TP, prepare the descriptors by splitting the P KV # cache into chunks of D worker's size (D>P). - # Eg. 1P 2D => P0 KV:[KV_0 | KV_1] (contiguous view). + # Eg. PTP1 DTP2 => P0 KV:[KV_0 | KV_1] (contiguous view). p_remote_rank = self.rank % nixl_agent_meta.tp_size # Only register the remote's descriptor if current rank pulls from it if p_remote_rank == remote_rank: self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr - # TODO enforce tp sizes are exact multiples - d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[engine_id] - assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than prefill TP" # TODO in case sizes aren't exactly divisible, we may want to swap # self.block_len with meta.block_len // d_workers_per_p_worker @@ -489,15 +500,25 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= # Split the kv memory inside a nixl region to guarantee each local # rank is pulling the kv cache of all layers of a remote worker. - rank_offset = self.rank // nixl_agent_meta.tp_size * self.block_len - print(f"Local Rank {self.rank} remote {remote_rank}: {rank_offset=}, {len(self.kv_caches_base_addr[engine_id])}\n\n") - for base_addr in self.kv_caches_base_addr[engine_id]: + # TODO what if the region_len of P and D don't match in size due to some TP overhead??Also this would assume the mem utilized is the same.. + assert d_workers_per_p_worker == 2 + rank_offset = self.rank // nixl_agent_meta.tp_size * nixl_agent_meta.block_len * dst_num_blocks_per_local_rank + print(f"Local Rank {self.rank} remote {remote_rank}: {rank_offset=}/ Remote region_len {nixl_agent_meta.num_blocks*nixl_agent_meta.block_len}\n\n") + print(f"{nixl_agent_meta.num_blocks=}, {dst_num_blocks_per_local_rank=}") + # DECODE TP2 || self.num_blocks=33769, self.block_len=16384, self.num_regions=56 + # PREFILL TP1 || self.num_blocks=17371, self.block_len=32768, self.num_regions=56 + # FIXME assume num_blocks and block_len are actually divisible and all is nice. This needs to be enforced (eg diff mem usage might break) + for base_addr in nixl_agent_meta.kv_caches_base_addr: base_addr += rank_offset - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * self.block_len + # for block_id in range(self.num_blocks): + for block_id in range(dst_num_blocks_per_local_rank): + # block_offset = block_id * self.block_len + block_offset = block_id * nixl_agent_meta.block_len # (addr, len, device id) blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) + (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) + # blocks_data.append( + # (base_addr + block_offset, self.block_len, self.rank)) logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", len(blocks_data), engine_id, remote_rank, self.rank) @@ -641,6 +662,7 @@ def _read_blocks( request_id: str, ): # NOTE(rob): this takes ~2s. 
We need to get this off the hotpath. + # TODO check remote_rank in here too? if dst_engine_id not in self._remote_agents: self._nixl_handshake(remote_host, remote_port) @@ -661,7 +683,8 @@ def _read_blocks( # workers will issue xfers to parts of the P worker remote kv caches. # Get side handles. - local_xfer_side_handle = self.src_xfer_side_handle + d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] + local_xfer_side_handle = self.src_xfer_side_handle[d_workers_per_p_worker] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # Get descs ids. From 52d232592730b4b7ba72779c2ad390872fc3aabe Mon Sep 17 00:00:00 2001 From: nicklucche Date: Fri, 9 May 2025 12:56:38 +0000 Subject: [PATCH 11/19] minor Signed-off-by: nicklucche --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f9ea84364cd..040f5c22ca1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -456,7 +456,6 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= dst_num_blocks_per_local_rank = nixl_agent_meta.num_blocks // d_workers_per_p_worker # Create src descs and xfer side handles. - # TODO we could pull this out of remote_agent, only depends on self if d_workers_per_p_worker not in self.src_xfer_side_handle: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id]: @@ -501,7 +500,6 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= # Split the kv memory inside a nixl region to guarantee each local # rank is pulling the kv cache of all layers of a remote worker. # TODO what if the region_len of P and D don't match in size due to some TP overhead??Also this would assume the mem utilized is the same.. 
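At this point in the series the remote P worker's block pool is partitioned into contiguous block ranges, one range per D worker (a later patch replaces this with a split along the kv_head dimension). A rough Python sketch of the offset arithmetic, outside the patch and with illustrative names:

    def dst_block_range_descs(local_rank: int, p_tp_size: int, d_tp_size: int,
                              remote_base_addrs: list[int],
                              remote_num_blocks: int, remote_block_len: int):
        # How many D workers share one P worker's KV cache.
        d_workers_per_p_worker = d_tp_size // p_tp_size
        # Each D worker only records a fraction of P's blocks.
        blocks_per_d_worker = remote_num_blocks // d_workers_per_p_worker
        # Byte offset of this rank's contiguous chunk inside the remote region.
        rank_offset = (local_rank // p_tp_size) * remote_block_len * blocks_per_d_worker
        blocks_data = []
        for base_addr in remote_base_addrs:
            start = base_addr + rank_offset
            for block_id in range(blocks_per_d_worker):
                # (addr, len, device id), one descriptor per remote block.
                blocks_data.append((start + block_id * remote_block_len,
                                    remote_block_len, local_rank))
        return blocks_data
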
- assert d_workers_per_p_worker == 2 rank_offset = self.rank // nixl_agent_meta.tp_size * nixl_agent_meta.block_len * dst_num_blocks_per_local_rank print(f"Local Rank {self.rank} remote {remote_rank}: {rank_offset=}/ Remote region_len {nixl_agent_meta.num_blocks*nixl_agent_meta.block_len}\n\n") print(f"{nixl_agent_meta.num_blocks=}, {dst_num_blocks_per_local_rank=}") From 8080346df190b3efcba45dc7efe7fe34b3108d65 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Fri, 9 May 2025 14:48:25 +0000 Subject: [PATCH 12/19] revert config changes Signed-off-by: nicklucche --- vllm/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 273cc333d00..4333bedeea2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3442,9 +3442,6 @@ class KVTransferConfig(BaseModel): # any extra config that the connector may need kv_connector_extra_config: dict[str, Any] = {} - kv_producers_tensor_parallel_size: Optional[int] = None - kv_consumers_tensor_parallel_size: Optional[int] = None - def compute_hash(self) -> str: """ From f216e03fed5974924528c2ed95028f77963c2006 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Sat, 10 May 2025 09:54:07 +0000 Subject: [PATCH 13/19] split kv_cache along head dim Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 93 +++++++++---------- 1 file changed, 45 insertions(+), 48 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 040f5c22ca1..afe68873b71 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -286,7 +286,6 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, # move this into the scheduler rather than worker, since # each rank needs the metadata of all other ranks (whereas # in this setup, each rank only gets one other rank's meta. - # TODO iterate over all ranks to handshake with M. Can we get M from config? encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) @@ -369,18 +368,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # MLA case. self.num_blocks = first_kv_cache.shape[0] block_rank = 2 # [block_size, latent_dim] - # TODO does this include tp dependent size? block_shape = first_kv_cache.shape[-block_rank:] + # TODO handle heterogenous TP for MLA else: - # [2 (k and v), num_blocks, ...] + # [2 (k and v), num_blocks, block_size, kv_heads, head_dim] self.num_blocks = first_kv_cache.shape[1] block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] + self.block_size, self.n_kv_heads, self.head_dim = block_shape + # head size in btyes. + self.kv_dim = kv_elem_size * self.n_kv_heads * self.head_dim # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc + # block size in bytes self.block_len = kv_elem_size * math.prod(block_shape) - print(f"\n\n{self.block_len=}\n\n") logger.debug("Registering KV_Caches. 
use_mla: %s, shape %s", use_mla, first_kv_cache.shape) @@ -401,6 +403,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches cache_list = [cache_or_caches] if use_mla else cache_or_caches + # TODO I think current mem layout is fine but double check for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len @@ -416,7 +419,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs) logger.debug("Done registering descs") - self._registered_descs.append(descs) # After KV Caches registered, listen for new connections. @@ -438,7 +440,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ready_event.wait() def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int=0): - # FIXME one other approach I tried is loading half of every remote block instead of half the blocks. Doesnt seem to make much difference engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery if engine_id in self._remote_agents and remote_rank in self._remote_agents[engine_id]: @@ -453,18 +454,21 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= # TODO enforce tp sizes are exact multiples d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[engine_id] assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than prefill TP" - dst_num_blocks_per_local_rank = nixl_agent_meta.num_blocks // d_workers_per_p_worker # Create src descs and xfer side handles. if d_workers_per_p_worker not in self.src_xfer_side_handle: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id]: - for block_id in range(dst_num_blocks_per_local_rank): - block_offset = block_id * nixl_agent_meta.block_len - # (addr, len, device id) - # use the block size of the dst/P node to make sure regions match - blocks_data.append( - (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) + # self.num_blocks > nixl_agent_meta.num_blocks + for block_id in range(nixl_agent_meta.num_blocks): + # block_offset = block_id * nixl_agent_meta.block_len + block_offset = block_id * self.block_len + for b in range(self.block_size): + head_offset = b * self.kv_dim + addr = base_addr + block_offset + head_offset + # (addr, len, device id) + # use the block size of the dst/P node to make sure regions match + blocks_data.append((addr, self.kv_dim, self.rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank) @@ -474,49 +478,34 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self.src_xfer_side_handle[d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) - # Create dst descs and xfer side handles. TP workers have same #blocks - # if engine_id in self.dst_num_blocks: - # assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + # Create dst descs and xfer side handles. TP workers have same #blocks. + if engine_id in self.dst_num_blocks: + assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks - # self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks - # When D_TP>P_TP, P blocks are split between D workers. Hence we may - # record a fraction of the total num_blocks in P. 
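The dst-descriptor construction that follows splits each remote P block along the kv_head dimension: every token entry in a P block holds d_workers_per_p_worker local-sized head slices, and a given D rank registers only its own kv_dim-sized slice. A sketch of the remote-side layout, outside the patch and with illustrative names (the rank-to-slice mapping shown is the one used at this point; the last patch in the series changes it):

    def remote_head_slice_descs(local_rank: int, local_tp_size: int,
                                p_tp_size: int, remote_rank: int,
                                remote_base_addrs: list[int],
                                remote_num_blocks: int, remote_block_len: int,
                                block_size: int, kv_dim: int):
        # How many D workers share one P worker's KV cache.
        d_workers_per_p_worker = local_tp_size // p_tp_size
        # Byte offset of this rank's head slice inside one remote token entry.
        rank_offset = (local_rank // p_tp_size) * kv_dim
        blocks_data = []
        for base_addr in remote_base_addrs:
            for block_id in range(remote_num_blocks):
                block_offset = block_id * remote_block_len
                for tok in range(block_size):
                    # A remote token entry is d_workers_per_p_worker slices wide.
                    head_offset = tok * kv_dim * d_workers_per_p_worker
                    addr = base_addr + block_offset + head_offset + rank_offset
                    # (addr, len, device id), one descriptor per token slice.
                    blocks_data.append((addr, kv_dim, remote_rank))
        return blocks_data

The local (src) side is built the same way, but with head_offset = tok * kv_dim and no rank offset, since a local D block only holds this rank's heads.
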
- self.dst_num_blocks[engine_id] = dst_num_blocks_per_local_rank + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] # With heterogenous TP, prepare the descriptors by splitting the P KV - # cache into chunks of D worker's size (D>P). - # Eg. PTP1 DTP2 => P0 KV:[KV_0 | KV_1] (contiguous view). + # cache along kv_head dim, of D worker's kv_head size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. p_remote_rank = self.rank % nixl_agent_meta.tp_size - # Only register the remote's descriptor if current rank pulls from it + # Only register the remote's descriptors if current rank pulls from it. if p_remote_rank == remote_rank: self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr - - # TODO in case sizes aren't exactly divisible, we may want to swap - # self.block_len with meta.block_len // d_workers_per_p_worker - # (eg when dividing by 3) and handle final block. src_xfer too. - # assert nixl_agent_meta.block_len % self.block_len == 0 - - # Split the kv memory inside a nixl region to guarantee each local - # rank is pulling the kv cache of all layers of a remote worker. - # TODO what if the region_len of P and D don't match in size due to some TP overhead??Also this would assume the mem utilized is the same.. - rank_offset = self.rank // nixl_agent_meta.tp_size * nixl_agent_meta.block_len * dst_num_blocks_per_local_rank - print(f"Local Rank {self.rank} remote {remote_rank}: {rank_offset=}/ Remote region_len {nixl_agent_meta.num_blocks*nixl_agent_meta.block_len}\n\n") - print(f"{nixl_agent_meta.num_blocks=}, {dst_num_blocks_per_local_rank=}") - # DECODE TP2 || self.num_blocks=33769, self.block_len=16384, self.num_regions=56 - # PREFILL TP1 || self.num_blocks=17371, self.block_len=32768, self.num_regions=56 - # FIXME assume num_blocks and block_len are actually divisible and all is nice. This needs to be enforced (eg diff mem usage might break) + rank_offset = self.rank // nixl_agent_meta.tp_size * self.kv_dim + # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: - base_addr += rank_offset - # for block_id in range(self.num_blocks): - for block_id in range(dst_num_blocks_per_local_rank): - # block_offset = block_id * self.block_len + for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) - # blocks_data.append( - # (base_addr + block_offset, self.block_len, self.rank)) + # TODO assume same block_size. Should enforce it on handshake. + for b in range(self.block_size): + # block_offset = block_id * self.block_len + # Remote kv_dim=local kv_dim * d_workers_per_p_worker + head_offset = b * self.kv_dim * d_workers_per_p_worker + addr = base_addr + block_offset + head_offset + # (addr, len, device id) + blocks_data.append( + (addr + rank_offset, self.kv_dim, remote_rank)) logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", len(blocks_data), engine_id, remote_rank, self.rank) @@ -685,12 +674,17 @@ def _read_blocks( local_xfer_side_handle = self.src_xfer_side_handle[d_workers_per_p_worker] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + print(f"BLOCK IDS {local_block_ids=}, {remote_block_ids=}") + # Get descs ids. 
remote_block_descs_ids = self._get_block_descs_ids( dst_engine_id, remote_block_ids) local_block_descs_ids = self._get_block_descs_ids( self.engine_id, local_block_ids) assert len(local_block_descs_ids) == len(remote_block_descs_ids) + + print(f"LOCAL {len(local_block_descs_ids)=}, {local_xfer_side_handle=}") + print(f"REMOTE {len(remote_block_descs_ids)=}, {remote_xfer_side_handle=}") # Prepare transfer with Nixl. handle = self.nixl_wrapper.make_prepped_xfer( @@ -718,9 +712,12 @@ def _get_block_descs_ids(self, engine_id: str, # Compute the desc ids for each block. descs_ids: list[int] = [] + # TODO branch out here here to save on number of descr for homogen tp for reg_id in region_ids: for block_id in block_ids: - descs_ids.append(reg_id * num_blocks + block_id) + # descs_ids.append(reg_id * num_blocks + block_id) + for kv_block in range(self.block_size): + descs_ids.append(reg_id * num_blocks * self.block_size + block_id * self.block_size + kv_block) return descs_ids From 72a4c141f66cc1784adc7b009e4e692aefabee63 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Sat, 10 May 2025 10:52:58 +0000 Subject: [PATCH 14/19] fix descr indexing Signed-off-by: nicklucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index afe68873b71..be71afc95b7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -459,9 +459,12 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= if d_workers_per_p_worker not in self.src_xfer_side_handle: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id]: - # self.num_blocks > nixl_agent_meta.num_blocks - for block_id in range(nixl_agent_meta.num_blocks): - # block_offset = block_id * nixl_agent_meta.block_len + # NOTE Here more blocks that what are transfered are used + # as self.num_blocks >= nixl_agent_meta.num_blocks. We could + # create fewer, but then _get_block_descs_ids needs to select + # nixl_agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. 
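With descriptors now registered per token rather than per block, the flat descriptor id computed by _get_block_descs_ids above becomes a row-major index over (region, block, token). A tiny worked example with made-up sizes, outside the patch:

    def block_descs_ids(region_ids, block_ids, num_blocks, block_size):
        descs_ids = []
        for reg_id in region_ids:
            for block_id in block_ids:
                for tok in range(block_size):
                    descs_ids.append(reg_id * num_blocks * block_size
                                     + block_id * block_size + tok)
        return descs_ids

    # 2 regions, 4 blocks per region, block_size=2, reading block 3 only:
    # region 0 -> ids [6, 7], region 1 -> ids [14, 15].
    assert block_descs_ids([0, 1], [3], 4, 2) == [6, 7, 14, 15]
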
+ for block_id in range(self.num_blocks): block_offset = block_id * self.block_len for b in range(self.block_size): head_offset = b * self.kv_dim From 522f647b59f09df4c306c71f9878fb79b28f532b Mon Sep 17 00:00:00 2001 From: nicklucche Date: Sat, 10 May 2025 18:07:05 +0000 Subject: [PATCH 15/19] clean up Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 49 ++++++++----------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index be71afc95b7..2ff8570ab58 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -14,7 +14,7 @@ from typing_extensions import Optional from vllm import envs -from vllm.config import VllmConfig, KVTransferConfig +from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.parallel_state import ( @@ -100,7 +100,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker(str(self.engine_id), vllm_config.kv_transfer_config) + self.connector_worker = NixlConnectorWorker(str(self.engine_id)) ############################################################ # Scheduler Side Methods @@ -216,7 +216,7 @@ def build_connector_meta( class NixlConnectorWorker: """Implementation of Worker side methods""" - def __init__(self, engine_id: str, kv_config: KVTransferConfig): + def __init__(self, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") @@ -343,8 +343,6 @@ def handshake(sock, rank: int)->NixlAgentMetadata: with zmq_ctx(zmq.REQ, path) as sock: metadata = handshake(sock, 0) - # TODO should we skip this if remote world_size == world_size (homogeneous)? - # Handshake only with the other TP remote the current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. p_remote_rank = self.rank % metadata.tp_size @@ -369,15 +367,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.num_blocks = first_kv_cache.shape[0] block_rank = 2 # [block_size, latent_dim] block_shape = first_kv_cache.shape[-block_rank:] - # TODO handle heterogenous TP for MLA + self.block_size, kv_latent_dim = block_shape + self.kv_dim = kv_elem_size * kv_latent_dim else: # [2 (k and v), num_blocks, block_size, kv_heads, head_dim] self.num_blocks = first_kv_cache.shape[1] block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] - self.block_size, self.n_kv_heads, self.head_dim = block_shape + self.block_size, n_kv_heads, head_dim = block_shape # head size in btyes. 
- self.kv_dim = kv_elem_size * self.n_kv_heads * self.head_dim + self.kv_dim = kv_elem_size * n_kv_heads * head_dim # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc @@ -403,7 +402,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches cache_list = [cache_or_caches] if use_mla else cache_or_caches - # TODO I think current mem layout is fine but double check for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len @@ -411,9 +409,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) - print("************************BLOCKS SETUP") - print(f"Number of blocks {len(kv_caches_base_addr)=}\n") - print(f"{self.num_blocks=}, {self.block_len=}, {self.num_regions=}\n") descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) @@ -451,18 +446,22 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) - # TODO enforce tp sizes are exact multiples d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[engine_id] assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than prefill TP" + # TODO we should also check hidden_dim and kv precision, they must match + remote_block_size = nixl_agent_meta.block_len / (self.kv_dim*d_workers_per_p_worker) + assert self.block_size == remote_block_size, "Remote P worker with " + "different block size is not supported" + # Create src descs and xfer side handles. if d_workers_per_p_worker not in self.src_xfer_side_handle: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id]: - # NOTE Here more blocks that what are transfered are used - # as self.num_blocks >= nixl_agent_meta.num_blocks. We could - # create fewer, but then _get_block_descs_ids needs to select - # nixl_agent_meta.num_blocks instead of self.num_blocks for + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for # local descr, and that makes handling regular flow less clean. for block_id in range(self.num_blocks): block_offset = block_id * self.block_len @@ -470,7 +469,6 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= head_offset = b * self.kv_dim addr = base_addr + block_offset + head_offset # (addr, len, device id) - # use the block size of the dst/P node to make sure regions match blocks_data.append((addr, self.kv_dim, self.rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank) @@ -488,8 +486,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] - # With heterogenous TP, prepare the descriptors by splitting the P KV - # cache along kv_head dim, of D worker's kv_head size (D>P). + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. 
With heterogenous TP, prepare the descriptors by splitting the + # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. p_remote_rank = self.rank % nixl_agent_meta.tp_size # Only register the remote's descriptors if current rank pulls from it. @@ -677,8 +676,6 @@ def _read_blocks( local_xfer_side_handle = self.src_xfer_side_handle[d_workers_per_p_worker] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] - print(f"BLOCK IDS {local_block_ids=}, {remote_block_ids=}") - # Get descs ids. remote_block_descs_ids = self._get_block_descs_ids( dst_engine_id, remote_block_ids) @@ -686,9 +683,6 @@ def _read_blocks( self.engine_id, local_block_ids) assert len(local_block_descs_ids) == len(remote_block_descs_ids) - print(f"LOCAL {len(local_block_descs_ids)=}, {local_xfer_side_handle=}") - print(f"REMOTE {len(remote_block_descs_ids)=}, {remote_xfer_side_handle=}") - # Prepare transfer with Nixl. handle = self.nixl_wrapper.make_prepped_xfer( "READ", @@ -715,12 +709,11 @@ def _get_block_descs_ids(self, engine_id: str, # Compute the desc ids for each block. descs_ids: list[int] = [] - # TODO branch out here here to save on number of descr for homogen tp for reg_id in region_ids: for block_id in block_ids: - # descs_ids.append(reg_id * num_blocks + block_id) for kv_block in range(self.block_size): - descs_ids.append(reg_id * num_blocks * self.block_size + block_id * self.block_size + kv_block) + descs_ids.append(reg_id * num_blocks * self.block_size + + block_id * self.block_size + kv_block) return descs_ids From e4e47497a9a3da55bfb38f39183f7f2ce49a0598 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Sat, 10 May 2025 18:18:21 +0000 Subject: [PATCH 16/19] format Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 114 ++++++++++-------- 1 file changed, 62 insertions(+), 52 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2ff8570ab58..66b80961a15 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -234,12 +234,11 @@ def __init__(self, engine_id: str): self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() - # Remote tracking ds only contain one entry for own tp group: engine_id-self.rank # KV Caches and nixl tracking data. self.kv_caches: dict[str, torch.Tensor] = {} - # Map of engine_id -> kv_caches_base_addr. For TP case, each local - # rank will still only pull from a single remote TP worker. + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. self.kv_caches_base_addr: dict[str, list[int]] = dict() # Number of NIXL regions. Currently one region per cache @@ -247,7 +246,7 @@ def __init__(self, engine_id: str): self.num_regions = 0 # nixl_prepped_dlist_handle. Different dst TP sizes require preparing - # xfer layout differently. + # xfer layout differently. self.src_xfer_side_handle: int = dict() # Map of engine_id -> nixl_prepped_dlist_handle (int)]. self.dst_xfer_side_handles: dict[str, int] = dict() @@ -272,7 +271,7 @@ def __init__(self, engine_id: str): # Background thread for establishing new connections. 
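The handshake in this revision happens in two steps: remote rank 0 is always queried first (its metadata carries the remote TP size), and then, if this worker maps to a different remote rank, that rank's port is queried as well. A sketch of the port selection, outside the patch and with illustrative names (the modulo mapping shown here is later switched to an integer-division form):

    def handshake_ports(local_rank: int, base_port: int, remote_tp_size: int):
        # Always query remote rank 0 first; its metadata reports remote_tp_size.
        ports = [base_port]
        # Remote rank this worker pulls from (modulo mapping used at this point).
        p_remote_rank = local_rank % remote_tp_size
        if p_remote_rank > 0:
            # One ZMQ port per remote rank: base_port + rank.
            ports.append(base_port + p_remote_rank)
        return ports
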
self._nixl_handshake_listener_t: Optional[threading.Thread] = None - + self._tp_size = {self.engine_id: self.world_size} @staticmethod @@ -315,11 +314,12 @@ def _nixl_handshake(self, host: str, port: int): """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() + # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - def handshake(sock, rank: int)->NixlAgentMetadata: + def handshake(sock, rank: int) -> NixlAgentMetadata: # Send query for the request. sock.send(GET_META_MSG) metadata_bytes = sock.recv() @@ -332,27 +332,26 @@ def handshake(sock, rank: int)->NixlAgentMetadata: setup_agent_time = time.perf_counter() logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) + got_metadata_time - start_time) logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + setup_agent_time - got_metadata_time) return metadata - # Handshake with remote agent-rank0 first to get the tp_size of remote + # Handshake with remote agent-rank0 first to get the tp_size of remote path = f"tcp://{host}:{port}" logger.debug("Querying master rank metadata on path: %s", path) with zmq_ctx(zmq.REQ, path) as sock: - metadata = handshake(sock, 0) - - # Handshake only with the other TP remote the current local rank will - # pull from. With homogeneous TP it happens to be the same rank_i. + metadata = handshake(sock, 0) + + # Handshake only with the other TP remote the current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. p_remote_rank = self.rank % metadata.tp_size if p_remote_rank > 0: path = f"tcp://{host}:{port + p_remote_rank}" - logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) + logger.debug("Querying metadata on path: %s at remote rank %s", + path, p_remote_rank) with zmq_ctx(zmq.REQ, path) as sock: metadata = handshake(sock, p_remote_rank) - - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -375,7 +374,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] self.block_size, n_kv_heads, head_dim = block_shape - # head size in btyes. + # head size in bytes. 
self.kv_dim = kv_elem_size * n_kv_heads * head_dim # TODO(tms): self.block_len needs to be per-layer for sliding window, @@ -423,8 +422,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, tp_size=self.world_size, - block_len=self.block_len - ) + block_len=self.block_len) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, @@ -434,23 +432,30 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._nixl_handshake_listener_t.start() ready_event.wait() - def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int=0): + def add_remote_agent(self, + nixl_agent_meta: NixlAgentMetadata, + remote_rank: int = 0): engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery - if engine_id in self._remote_agents and remote_rank in self._remote_agents[engine_id]: + if (engine_id in self._remote_agents and \ + remote_rank in self._remote_agents[engine_id]): return if engine_id in self._tp_size: assert self._tp_size[engine_id] == nixl_agent_meta.tp_size self._tp_size[engine_id] = nixl_agent_meta.tp_size - self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent( - nixl_agent_meta.agent_metadata) - - d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[engine_id] - assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than prefill TP" + self._remote_agents[engine_id][ + remote_rank] = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + + d_workers_per_p_worker = self._tp_size[ + self.engine_id] // self._tp_size[engine_id] + assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than" + " prefill TP" # TODO we should also check hidden_dim and kv precision, they must match - remote_block_size = nixl_agent_meta.block_len / (self.kv_dim*d_workers_per_p_worker) + remote_block_size = nixl_agent_meta.block_len / ( + self.kv_dim * d_workers_per_p_worker) assert self.block_size == remote_block_size, "Remote P worker with " "different block size is not supported" @@ -458,26 +463,27 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= if d_workers_per_p_worker not in self.src_xfer_side_handle: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id]: - # NOTE With heter-TP, more blocks are prepared than what are - # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We - # could create fewer, but then _get_block_descs_ids needs to - # select agent_meta.num_blocks instead of self.num_blocks for - # local descr, and that makes handling regular flow less clean. + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. 
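The block-size check above pins down the geometry relation between the two sides: with the same block_size, dtype and head_dim, a remote P block is exactly d_workers_per_p_worker times wider than a local D block. A small numeric check, outside the patch and with made-up sizes:

    # Illustrative numbers only: an 8-KV-head model, PTP1 vs DTP2, fp16.
    kv_elem_size = 2
    head_dim, block_size = 128, 16
    local_kv_heads, remote_kv_heads = 4, 8
    d_workers_per_p_worker = 2

    local_kv_dim = kv_elem_size * local_kv_heads * head_dim             # bytes per token
    remote_block_len = kv_elem_size * remote_kv_heads * head_dim * block_size

    remote_block_size = remote_block_len / (local_kv_dim * d_workers_per_p_worker)
    assert remote_block_size == block_size
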
for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len + block_offset = block_id * self.block_len for b in range(self.block_size): head_offset = b * self.kv_dim addr = base_addr + block_offset + head_offset # (addr, len, device id) blocks_data.append((addr, self.kv_dim, self.rank)) logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.rank) + len(blocks_data), self.engine_id, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") # NIXL_INIT_AGENT to be used for preparations of local descs. - self.src_xfer_side_handle[d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) + self.src_xfer_side_handle[ + d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: @@ -486,34 +492,36 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int= self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] - # With homogeneous TP, D pulls the whole kv cache from corresponding - # rank. With heterogenous TP, prepare the descriptors by splitting the + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). - # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. p_remote_rank = self.rank % nixl_agent_meta.tp_size # Only register the remote's descriptors if current rank pulls from it. if p_remote_rank == remote_rank: - self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr rank_offset = self.rank // nixl_agent_meta.tp_size * self.kv_dim # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_len - # TODO assume same block_size. Should enforce it on handshake. for b in range(self.block_size): - # block_offset = block_id * self.block_len # Remote kv_dim=local kv_dim * d_workers_per_p_worker head_offset = b * self.kv_dim * d_workers_per_p_worker addr = base_addr + block_offset + head_offset # (addr, len, device id) blocks_data.append( - (addr + rank_offset, self.kv_dim, remote_rank)) - logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", - len(blocks_data), engine_id, remote_rank, self.rank) + (addr + rank_offset, self.kv_dim, remote_rank)) + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and " \ + "local rank %s", + len(blocks_data), engine_id, remote_rank, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + self.dst_xfer_side_handles[ + engine_id] = self.nixl_wrapper.prep_xfer_dlist( self._remote_agents[engine_id][remote_rank], descs) def get_finished(self) -> tuple[set[str], set[str]]: @@ -667,13 +675,15 @@ def _read_blocks( assert len(local_block_ids) > 0 assert len(local_block_ids) == len(remote_block_ids) - # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from - # corresponding rank. 
With heterogenous TP, fixing D>P, the D tp + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from + # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. # Get side handles. - d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] - local_xfer_side_handle = self.src_xfer_side_handle[d_workers_per_p_worker] + d_workers_per_p_worker = self._tp_size[ + self.engine_id] // self._tp_size[dst_engine_id] + local_xfer_side_handle = self.src_xfer_side_handle[ + d_workers_per_p_worker] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # Get descs ids. @@ -682,7 +692,7 @@ def _read_blocks( local_block_descs_ids = self._get_block_descs_ids( self.engine_id, local_block_ids) assert len(local_block_descs_ids) == len(remote_block_descs_ids) - + # Prepare transfer with Nixl. handle = self.nixl_wrapper.make_prepped_xfer( "READ", @@ -712,7 +722,7 @@ def _get_block_descs_ids(self, engine_id: str, for reg_id in region_ids: for block_id in block_ids: for kv_block in range(self.block_size): - descs_ids.append(reg_id * num_blocks * self.block_size + + descs_ids.append(reg_id * num_blocks * self.block_size + block_id * self.block_size + kv_block) return descs_ids From d2ce96abc753aac30957de8afdcf5f45447baf9d Mon Sep 17 00:00:00 2001 From: nicklucche Date: Sat, 10 May 2025 18:31:02 +0000 Subject: [PATCH 17/19] format Signed-off-by: nicklucche --- vllm/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 4333bedeea2..54eb5e0ef0e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3442,7 +3442,6 @@ class KVTransferConfig(BaseModel): # any extra config that the connector may need kv_connector_extra_config: dict[str, Any] = {} - def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, From ca0e15f6d1e17cb1c60487f694cf5d71b8d9c014 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Sat, 10 May 2025 18:45:14 +0000 Subject: [PATCH 18/19] type Signed-off-by: nicklucche --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 66b80961a15..39256156cf9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -247,7 +247,7 @@ def __init__(self, engine_id: str): # nixl_prepped_dlist_handle. Different dst TP sizes require preparing # xfer layout differently. - self.src_xfer_side_handle: int = dict() + self.src_xfer_side_handle: dict[int, int] = dict() # Map of engine_id -> nixl_prepped_dlist_handle (int)]. 
self.dst_xfer_side_handles: dict[str, int] = dict() From 6868c9a71ef6c47c30b74769075a5a2f33d614ee Mon Sep 17 00:00:00 2001 From: nicklucche Date: Mon, 12 May 2025 08:47:26 +0000 Subject: [PATCH 19/19] change remote worker selection indexing; test ptp2-dtp4 Signed-off-by: nicklucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 39256156cf9..49ed590d0f8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -345,7 +345,9 @@ def handshake(sock, rank: int) -> NixlAgentMetadata: # Handshake only with the other TP remote the current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. - p_remote_rank = self.rank % metadata.tp_size + d_workers_per_p_worker = self._tp_size[ + self.engine_id] // metadata.tp_size + p_remote_rank = self.rank // d_workers_per_p_worker if p_remote_rank > 0: path = f"tcp://{host}:{port + p_remote_rank}" logger.debug("Querying metadata on path: %s at remote rank %s", @@ -496,18 +498,18 @@ def add_remote_agent(self, # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - p_remote_rank = self.rank % nixl_agent_meta.tp_size + p_remote_rank = self.rank // d_workers_per_p_worker # Only register the remote's descriptors if current rank pulls from it. if p_remote_rank == remote_rank: self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.rank // nixl_agent_meta.tp_size * self.kv_dim + rank_offset = self.rank % d_workers_per_p_worker * self.kv_dim # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_len for b in range(self.block_size): - # Remote kv_dim=local kv_dim * d_workers_per_p_worker + # Remote kv_dim = local kv_dim * d_workers_per_p_worker head_offset = b * self.kv_dim * d_workers_per_p_worker addr = base_addr + block_offset + head_offset # (addr, len, device id)
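The final patch changes the worker-selection indexing so that consecutive D ranks share one P rank, matching the ptp2-dtp4 case named in the subject. A worked example of the mapping, outside the patch:

    # PTP2 -> DTP4 mapping under the final indexing (illustrative).
    p_tp_size, d_tp_size = 2, 4
    d_workers_per_p_worker = d_tp_size // p_tp_size   # 2

    for d_rank in range(d_tp_size):
        p_remote_rank = d_rank // d_workers_per_p_worker   # P rank to pull from
        head_slice = d_rank % d_workers_per_p_worker       # kv_dim slice of that P block
        print(f"D rank {d_rank} pulls slice {head_slice} from P rank {p_remote_rank}")

    # D rank 0 -> slice 0 of P rank 0      D rank 1 -> slice 1 of P rank 0
    # D rank 2 -> slice 0 of P rank 1      D rank 3 -> slice 1 of P rank 1

The same p_remote_rank also selects which handshake port to query (base port + p_remote_rank), and the head slice fixes rank_offset = head_slice * kv_dim in the remote descriptors.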