diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 8a15955bd4f..49ed590d0f8 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -52,6 +52,8 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
+    tp_size: int
+    block_len: int
 
 
 @dataclass
@@ -223,8 +225,8 @@ def __init__(self, engine_id: str):
 
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
-        # Map of engine_id -> agent_name.
-        self._remote_agents: dict[str, str] = {}
+        # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
+        self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
 
         # Metadata.
         self.engine_id = engine_id
@@ -235,20 +237,23 @@ def __init__(self, engine_id: str):
         # KV Caches and nixl tracking data.
         self.kv_caches: dict[str, torch.Tensor] = {}
-        # Map of engine_id -> kv_caches_base_addr
-        self.kv_caches_base_addr: dict[str, list[int]] = {}
+        # Map of engine_id -> kv_caches_base_addr. For TP case, each local
+        # rank will still only pull from a single remote TP worker.
+        self.kv_caches_base_addr: dict[str, list[int]] = dict()
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
 
-        # nixl_prepped_dlist_handle (int).
-        self.src_xfer_side_handle: int = 0
+        # nixl_prepped_dlist_handle. Different dst TP sizes require preparing
+        # xfer layout differently.
+        self.src_xfer_side_handle: dict[int, int] = dict()
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: dict[str, int] = {}
+        self.dst_xfer_side_handles: dict[str, int] = dict()
 
-        # Map of engine_id -> num_blocks.
-        self.dst_num_blocks: dict[str, int] = {}
+        # Map of engine_id -> num_blocks. Remote TP ranks will have the same
+        # number of blocks.
+        self.dst_num_blocks: dict[str, int] = dict()
 
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
@@ -267,6 +272,8 @@ def __init__(self, engine_id: str):
         # Background thread for establishing new connections.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
 
+        self._tp_size = {self.engine_id: self.world_size}
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, rank: int):
@@ -290,6 +297,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
         # NOTE(rob): we need each rank to have a unique port. This
         # hack to keeps us moving. We will switch when moving to etcd
        # or where we have a single ZMQ socket in the scheduler.
+        # TODO get rank port util
         port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
         path = f"tcp://{host}:{port}"
         logger.debug("Starting listening on path: %s", path)
@@ -306,12 +314,12 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
+
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        path = f"tcp://{host}:{port + self.rank}"
-        logger.debug("Querying metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
+
+        def handshake(sock, rank: int) -> NixlAgentMetadata:
             # Send query for the request.
             sock.send(GET_META_MSG)
             metadata_bytes = sock.recv()
@@ -320,13 +328,32 @@ def _nixl_handshake(self, host: str, port: int):
             got_metadata_time = time.perf_counter()
 
             # Register Remote agent.
-            self.add_remote_agent(metadata)
+            self.add_remote_agent(metadata, rank)
             setup_agent_time = time.perf_counter()
 
             logger.debug("NIXL handshake: get metadata took: %s",
                          got_metadata_time - start_time)
             logger.debug("NIXL handshake: add agent took: %s",
                          setup_agent_time - got_metadata_time)
+            return metadata
+
+        # Handshake with remote rank 0 first to get the remote tp_size.
+        path = f"tcp://{host}:{port}"
+        logger.debug("Querying master rank metadata on path: %s", path)
+        with zmq_ctx(zmq.REQ, path) as sock:
+            metadata = handshake(sock, 0)
+
+        # Handshake only with the remote TP rank that the current local rank
+        # will pull from. With homogeneous TP this is the same rank index.
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // metadata.tp_size
+        p_remote_rank = self.rank // d_workers_per_p_worker
+        if p_remote_rank > 0:
+            path = f"tcp://{host}:{port + p_remote_rank}"
+            logger.debug("Querying metadata on path: %s at remote rank %s",
+                         path, p_remote_rank)
+            with zmq_ctx(zmq.REQ, path) as sock:
+                metadata = handshake(sock, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -341,14 +368,20 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            self.block_size, kv_latent_dim = block_shape
+            self.kv_dim = kv_elem_size * kv_latent_dim
         else:
-            # [2 (k and v), num_blocks, ...]
+            # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
            self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            self.block_size, n_kv_heads, head_dim = block_shape
+            # Per-token KV size in bytes for this worker's kv heads.
+            self.kv_dim = kv_elem_size * n_kv_heads * head_dim
 
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
+        # block size in bytes
         self.block_len = kv_elem_size * math.prod(block_shape)
 
         logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
@@ -382,7 +415,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Registering descs: %s", caches_data)
         self.nixl_wrapper.register_memory(descs)
         logger.debug("Done registering descs")
-        self._registered_descs.append(descs)
 
         # After KV Caches registered, listen for new connections.
@@ -391,7 +423,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-        )
+            tp_size=self.world_size,
+            block_len=self.block_len)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -401,49 +434,97 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self._nixl_handshake_listener_t.start()
         ready_event.wait()
 
-    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
+    def add_remote_agent(self,
+                         nixl_agent_meta: NixlAgentMetadata,
+                         remote_rank: int = 0):
         engine_id = nixl_agent_meta.engine_id
-        if engine_id in self._remote_agents:
+        # TODO re-evaluate refreshing for scaling/recovery
+        if (engine_id in self._remote_agents and \
+                remote_rank in self._remote_agents[engine_id]):
             return
 
-        self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
-            nixl_agent_meta.agent_metadata)
-        self.kv_caches_base_addr[
-            engine_id] = nixl_agent_meta.kv_caches_base_addr
+        if engine_id in self._tp_size:
+            assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
+        self._tp_size[engine_id] = nixl_agent_meta.tp_size
+        self._remote_agents[engine_id][
+            remote_rank] = self.nixl_wrapper.add_remote_agent(
+                nixl_agent_meta.agent_metadata)
+
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // self._tp_size[engine_id]
+        assert d_workers_per_p_worker > 0, (
+            "Decode TP cannot be smaller than prefill TP")
+
+        # TODO: also check hidden_dim and kv precision; they must match.
+        remote_block_size = nixl_agent_meta.block_len / (
+            self.kv_dim * d_workers_per_p_worker)
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different block size is not supported")
 
         # Create src descs and xfer side handles.
-        blocks_data = []
-        for base_addr in self.kv_caches_base_addr[self.engine_id]:
-            for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
-                # (addr, len, device id)
-                blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.rank)
-
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
-            "NIXL_INIT_AGENT", descs)
-
-        # Create dst descs and xfer side handles.
+        if d_workers_per_p_worker not in self.src_xfer_side_handle:
+            blocks_data = []
+            for base_addr in self.kv_caches_base_addr[self.engine_id]:
+                # NOTE: with heterogeneous TP, more blocks are prepared than
+                # needed, since self.num_blocks >= nixl_agent_meta.num_blocks.
+                # We could create fewer, but then _get_block_descs_ids would
+                # need to use agent_meta.num_blocks instead of self.num_blocks
+                # for the local descs, making the regular flow less clean.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    for b in range(self.block_size):
+                        head_offset = b * self.kv_dim
+                        addr = base_addr + block_offset + head_offset
+                        # (addr, len, device id)
+                        blocks_data.append((addr, self.kv_dim, self.rank))
+            logger.debug("Created %s blocks for src engine %s and rank %s",
+                         len(blocks_data), self.engine_id, self.rank)
+
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            # NIXL_INIT_AGENT is used to prepare the local descs.
+            self.src_xfer_side_handle[
+                d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist(
+                    "NIXL_INIT_AGENT", descs)
+
+        # Create dst descs and xfer side handles. TP workers have same #blocks.
+        if engine_id in self.dst_num_blocks:
+            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
+        self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
+
         blocks_data = []
-        for base_addr in self.kv_caches_base_addr[engine_id]:
-            for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * self.block_len
-                # (addr, len, device id)
-                blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for dst engine %s and rank %s",
-                     len(blocks_data), engine_id, self.rank)
-
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.dst_xfer_side_handles[
-            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id], descs)
+        # With homogeneous TP, D pulls the whole kv cache from the matching
+        # P rank. With heterogeneous TP, the P KV cache is split along the
+        # kv_head dim into chunks of the D worker's kv_head size (D>P).
+        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
+        p_remote_rank = self.rank // d_workers_per_p_worker
+        # Only register the remote's descriptors if this rank pulls from it.
+        if p_remote_rank == remote_rank:
+            self.kv_caches_base_addr[
+                engine_id] = nixl_agent_meta.kv_caches_base_addr
+            rank_offset = self.rank % d_workers_per_p_worker * self.kv_dim
+            # Register all remote blocks, but only the corresponding kv heads.
+            for base_addr in nixl_agent_meta.kv_caches_base_addr:
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    for b in range(self.block_size):
+                        # Remote kv_dim = local kv_dim * d_workers_per_p_worker
+                        head_offset = b * self.kv_dim * d_workers_per_p_worker
+                        addr = base_addr + block_offset + head_offset
+                        # (addr, len, device id)
+                        blocks_data.append(
+                            (addr + rank_offset, self.kv_dim, remote_rank))
+            logger.debug(
+                "Created %s blocks for dst engine %s with remote rank %s and "
+                "local rank %s",
+                len(blocks_data), engine_id, remote_rank, self.rank)
+
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            self.dst_xfer_side_handles[
+                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                    self._remote_agents[engine_id][remote_rank], descs)
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
@@ -580,6 +661,7 @@ def _read_blocks(
         request_id: str,
     ):
         # NOTE(rob): this takes ~2s. We need to get this off the hotpath.
+        # TODO check remote_rank in here too?
         if dst_engine_id not in self._remote_agents:
             self._nixl_handshake(remote_host, remote_port)
@@ -595,9 +677,15 @@ def _read_blocks(
         assert len(local_block_ids) > 0
         assert len(local_block_ids) == len(remote_block_ids)
 
+        # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from the
+        # corresponding rank. With heterogeneous TP, assuming D>P, the D TP
+        # workers issue xfers to parts of the P worker's remote kv caches.
         # Get side handles.
-        local_xfer_side_handle = self.src_xfer_side_handle
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // self._tp_size[dst_engine_id]
+        local_xfer_side_handle = self.src_xfer_side_handle[
+            d_workers_per_p_worker]
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
         # Get descs ids.
@@ -635,7 +723,9 @@ def _get_block_descs_ids(self, engine_id: str,
         descs_ids: list[int] = []
         for reg_id in region_ids:
             for block_id in block_ids:
-                descs_ids.append(reg_id * num_blocks + block_id)
+                for kv_block in range(self.block_size):
+                    descs_ids.append(reg_id * num_blocks * self.block_size +
+                                     block_id * self.block_size + kv_block)
         return descs_ids
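
Note on the TP mapping used throughout the patch: everything is derived from d_workers_per_p_worker = D_tp // P_tp. The standalone sketch below is not part of the patch; the function name and the divisibility assert are illustrative assumptions that mirror the arithmetic in _nixl_handshake and add_remote_agent, showing which prefill rank a decode rank handshakes with and at which byte offset it reads.

# Illustrative only: mirrors the rank/offset math in _nixl_handshake and
# add_remote_agent. Assumes decode TP is a multiple of prefill TP (D >= P).
def map_decode_rank_to_prefill(d_rank: int, d_tp: int, p_tp: int,
                               kv_dim: int) -> tuple[int, int]:
    """Return (p_remote_rank, rank_offset_bytes) for a decode worker.

    kv_dim plays the role of self.kv_dim: the per-token KV size in bytes
    on the decode worker.
    """
    assert d_tp >= p_tp and d_tp % p_tp == 0
    d_workers_per_p_worker = d_tp // p_tp
    # Which prefill rank this decode rank pulls from.
    p_remote_rank = d_rank // d_workers_per_p_worker
    # Byte offset into the wider remote KV slab selecting this rank's heads.
    rank_offset = d_rank % d_workers_per_p_worker * kv_dim
    return p_remote_rank, rank_offset

# Example: prefill TP=2, decode TP=4, kv_dim=1024 bytes.
# Decode ranks 0,1 pull from prefill rank 0; ranks 2,3 from prefill rank 1.
for d_rank in range(4):
    print(d_rank, map_decode_rank_to_prefill(d_rank, d_tp=4, p_tp=2,
                                             kv_dim=1024))
# 0 (0, 0)
# 1 (0, 1024)
# 2 (1, 0)
# 3 (1, 1024)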
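
The descriptor construction in add_remote_agent follows the same math. The sketch below is illustrative only (helper names are not part of the patch) and assumes the contiguous [num_blocks, block_size, ...] region layout registered above: the local side emits one kv_dim-sized descriptor per (block, token slot), while the remote side walks the same slots with the wider remote stride, shifted by rank_offset.

def local_token_descs(base_addr: int, num_blocks: int, block_len: int,
                      block_size: int, kv_dim: int, rank: int):
    # Local side: one kv_dim-sized descriptor per (block, token slot).
    for block_id in range(num_blocks):
        for b in range(block_size):
            yield (base_addr + block_id * block_len + b * kv_dim, kv_dim, rank)

def remote_token_descs(base_addr: int, num_blocks: int, remote_block_len: int,
                       block_size: int, kv_dim: int,
                       d_workers_per_p_worker: int, rank_offset: int,
                       remote_rank: int):
    # Remote side: same slots, strided by the remote (wider) per-token size
    # and shifted by rank_offset to select this decode rank's kv heads.
    remote_kv_dim = kv_dim * d_workers_per_p_worker
    for block_id in range(num_blocks):
        for b in range(block_size):
            addr = base_addr + block_id * remote_block_len + b * remote_kv_dim
            yield (addr + rank_offset, kv_dim, remote_rank)

Both sides produce the same number of descriptors per block (block_size), which is what keeps block ids aligned between source and destination handles.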
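
The _get_block_descs_ids change follows directly from that layout: each (region, block) now expands into block_size descriptors, laid out region -> block -> token slot. A hypothetical standalone version (a free function instead of the method) with a worked example:

def block_descs_ids(region_ids, block_ids, num_blocks: int,
                    block_size: int) -> list[int]:
    # Descriptor ids are laid out region -> block -> token slot.
    descs_ids: list[int] = []
    for reg_id in region_ids:
        for block_id in block_ids:
            for slot in range(block_size):
                descs_ids.append(reg_id * num_blocks * block_size +
                                 block_id * block_size + slot)
    return descs_ids

# 2 regions of 4 blocks, block_size=2: block 1 of region 0 -> ids [2, 3],
# the same block in region 1 -> ids [10, 11].
assert block_descs_ids([0], [1], num_blocks=4, block_size=2) == [2, 3]
assert block_descs_ids([1], [1], num_blocks=4, block_size=2) == [10, 11]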