diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 52eb463e15..362b9b2382 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -515,6 +515,18 @@ def remote_log(self, msg: str): ray.get(self.workers[0].remote_log_end.remote(handle)) def _sort_workers(self, driver_ip: str, workers: List[RayWorkerWrapper]): + """Sort workers.""" + # External bundle handling is only applicable when lmdeploy does NOT own + # the placement group. If lmdeploy owns the PG, we should continue to + # sort workers even if external bundle indices are specified. + if (not _envs.ray_external_pg_bundles) or self.ray_ctx.owned_pg: + return self._sort_workers_by_driver_then_worker_ip(driver_ip, workers) + else: + # do not sort when external bundle indices are specified and the + # placement group is externally managed + return workers + + def _sort_workers_by_driver_then_worker_ip(self, driver_ip: str, workers: List[RayWorkerWrapper]): """Sort workers by ip.""" worker_ips = ray.get([worker.get_node_ip.remote() for worker in workers]) @@ -566,21 +578,33 @@ def get_priority(ip): sorted_workers = [item[0] for item in sorted_worker_ip_map] return sorted_workers - def _valid_bundle_id(self, bundle_id: int): - """Check if a bundle is valid only when self.use_external_ray=True.""" - if (not self.ray_ctx.owned_pg and _envs.ray_external_pg_bundles - and bundle_id not in _envs.ray_external_pg_bundles): - return False - return True - def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict): """Init worker ray.""" device_str = get_device_str() bundle_indices = [] - for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if bundle.get(device_str, 0) and self._valid_bundle_id(bundle_id): - bundle_indices.append(bundle_id) + if not _envs.ray_external_pg_bundles: + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if 
bundle.get(device_str, 0): + bundle_indices.append(bundle_id) + else: + # use externally specified bundle indices, keep the order as well + bundle_indices = _envs.ray_external_pg_bundles.copy() + # validate external bundle indices + num_bundles = len(placement_group.bundle_specs) + for bundle_id in bundle_indices: + if bundle_id < 0 or bundle_id >= num_bundles: + raise ValueError(f'External bundle index {bundle_id} is out of range. ' + f'Placement group has {num_bundles} bundles (valid indices: 0-{num_bundles - 1}).') + bundle = placement_group.bundle_specs[bundle_id] + if not bundle.get(device_str, 0): + raise ValueError( + f'External bundle index {bundle_id} does not have required resource: {device_str}. ' + f'Available resources in this bundle: {dict(bundle)}') attn_tp = self.dist_config.attn_tp + if len(bundle_indices) < attn_tp: + raise ValueError(f'Not enough bundle indices for attention tensor parallelism. ' + f'Required: {attn_tp}, Provided: {len(bundle_indices)} ' + f'(bundle_indices: {bundle_indices}).') bundle_indices = bundle_indices[:attn_tp] workers = list() diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index f9c2919962..fffb0dacdf 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -384,6 +384,10 @@ def __init__( logger.warning(f'Failed to create GuidedManager for tokenizer {type(self.tokenizer)}: {e}') self.guided_decoding_manager = None + # update_params_ipc_buffer + self._update_params_ipc_tensor: torch.Tensor | None = None + self._update_params_ipc_event: torch.cuda.Event | None = None + # microbatch self.enable_microbatch = self.dist_config.enable_microbatch self.enable_microbatch_prefill_batchsize_threshold = \ @@ -1159,14 +1163,19 @@ def update_params(self, request: UpdateParamsRequest): """Update params.""" # modified from https://github.com/vllm-project/vllm/blob/v0.8.5/examples/offline_inference/rlhf_utils.py#L82 - def 
_construct(item): + def _construct(item, require_clone: bool = True): func, args = item args = list(args) args[6] = torch.cuda.current_device() # device id. - # clone() seems necessary otherwise the producer can not release the memory - return func(*args).clone() + ipc_tensor = func(*args) + return ipc_tensor.clone() if require_clone else ipc_tensor with self.all_context(): + # After deserialization, weights is a dict with the following keys: + # - metadata: List[FlattenedTensorMetadata] + # - flattened_tensor: the flattened tensor for weights, optional + # - event_ipc_handle: the ipc handle of the event + # that is used to sync streams across processes, optional serialized_data = request.serialized_named_tensors if isinstance(serialized_data, list): serialized_data = serialized_data[self.dist_ctx.tp_group.rank] @@ -1175,7 +1184,26 @@ def _construct(item): if request.load_format == 'flattened_bucket': metadata: List[FlattenedTensorMetadata] = weights['metadata'] if metadata: - flattened_tensor: torch.Tensor = _construct(weights['flattened_tensor']) + if 'flattened_tensor' in weights: + # Determine if clone is required + require_clone = weights.get('require_clone', True) + if 'event_ipc_handle' in weights and not hasattr(torch.cuda.Event, 'from_ipc_handle'): + # Force clone when IPC event is provided but cannot be used + require_clone = True + self._update_params_ipc_tensor = _construct(weights['flattened_tensor'], + require_clone=require_clone) + elif self._update_params_ipc_tensor is None: + raise ValueError( + 'flattened_tensor is not provided in weights and no cached ipc tensor is available. 
' + 'Please provide flattened_tensor on the first update_params call.') + if 'event_ipc_handle' in weights and hasattr(torch.cuda.Event, 'from_ipc_handle'): + self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle( + device=torch.cuda.current_device(), + handle=weights['event_ipc_handle'], + ) + flattened_tensor: torch.Tensor = self._update_params_ipc_tensor + if self._update_params_ipc_event is not None: + self._update_params_ipc_event.wait() bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata) weights = bucket.reconstruct_tensors() else: @@ -1186,12 +1214,17 @@ def _construct(item): weights = ModelWeightLoader._rename_weights_iterator(weights, model) model.load_weights(weights) + if self._update_params_ipc_event is not None: + self._update_params_ipc_event.record() if request.finished: for _, mod in model.named_modules(): if not hasattr(mod, 'update_weights'): continue mod.update_weights() + torch.cuda.synchronize() + self._update_params_ipc_event = None + self._update_params_ipc_tensor = None torch.cuda.empty_cache() @@ -1206,6 +1239,9 @@ async def sleep(self, level: int = 1): device = 'cpu' if level == 1 else 'meta' self.patched_model.get_model().to(device=device, non_blocking=True) torch.cuda.synchronize() + # force clean _update_params_ipc tensor and event after all gpu jobs done + self._update_params_ipc_tensor = None + self._update_params_ipc_event = None torch.cuda.empty_cache() self.state.to_sleep.clear() diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 779289627d..73b1dcb3cd 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -462,9 +462,9 @@ def serialize_state_dict(state_dict: dict) -> str: from torch.multiprocessing.reductions import reduce_tensor # flattened_tensor - if 'metadata' in state_dict and 'flattened_tensor' in state_dict: + if 'metadata' in state_dict: data = state_dict - if isinstance(data['flattened_tensor'], torch.Tensor): + if 'flattened_tensor' in data and 
isinstance(data['flattened_tensor'], torch.Tensor): data['flattened_tensor'] = reduce_tensor(state_dict['flattened_tensor']) else: data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] @@ -519,6 +519,7 @@ def __init__( num_tensors = len(named_tensors) self.metadata = [None] * num_tensors self.flattened_tensor = [None] * num_tensors + flattened_tensor_list = [None] * num_tensors if num_tensors > 0: if num_tensors > 1: dtypes = [t.dtype for _, t in named_tensors] @@ -527,7 +528,7 @@ def __init__( current_idx = 0 for idx, (name, tensor) in enumerate(named_tensors): - self.flattened_tensor[idx] = tensor.flatten() + flattened_tensor_list[idx] = tensor.flatten() numel = tensor.numel() self.metadata[idx] = FlattenedTensorMetadata(name=name, shape=tensor.shape, @@ -536,8 +537,32 @@ def __init__( end_idx=current_idx + numel, numel=numel) current_idx += numel - - self.flattened_tensor = torch.cat(self.flattened_tensor, dim=0) + if flattened_tensor is None: + self.flattened_tensor = torch.cat(flattened_tensor_list, dim=0) + else: + # Validate user-provided preallocated buffer + if flattened_tensor.dim() != 1: + raise ValueError( + f'flattened_tensor must be a 1-D tensor, but got shape {tuple(flattened_tensor.shape)}') + if flattened_tensor.numel() < current_idx: + raise ValueError('Provided flattened tensor numel is smaller than required numel: ' + f'{flattened_tensor.numel()} < {current_idx}') + # Validate dtype and device compatibility with source tensors + reference_tensor = named_tensors[0][1] + if flattened_tensor.dtype != reference_tensor.dtype: + raise ValueError(f'flattened_tensor dtype {flattened_tensor.dtype} does not match source ' + f'tensors dtype {reference_tensor.dtype}') + if flattened_tensor.device != reference_tensor.device: + raise ValueError(f'flattened_tensor device {flattened_tensor.device} does not match source ' + f'tensors device {reference_tensor.device}') + if not flattened_tensor.is_contiguous(): + raise ValueError('flattened_tensor must be 
contiguous') + total_numel = sum(t.numel() for t in flattened_tensor_list) + if total_numel != current_idx: + raise ValueError('Mismatch between computed and expected flattened size: ' + f'{total_numel} != {current_idx}') + torch.cat(flattened_tensor_list, dim=0, out=flattened_tensor[:current_idx]) + self.flattened_tensor = flattened_tensor else: if flattened_tensor is None or metadata is None: raise ValueError('Must provide either named_tensors or both flattened_tensor and metadata') diff --git a/tests/test_lmdeploy/test_utils.py b/tests/test_lmdeploy/test_utils.py index e95cf2b4f8..060c6a5ea5 100644 --- a/tests/test_lmdeploy/test_utils.py +++ b/tests/test_lmdeploy/test_utils.py @@ -1,6 +1,78 @@ +import torch from transformers import AutoConfig -from lmdeploy.utils import _get_and_verify_max_len +from lmdeploy.utils import FlattenedTensorBucket, _get_and_verify_max_len + + +def test_flattened_tensor_bucket_preallocated_buffer(): + """Test FlattenedTensorBucket with preallocated buffer.""" + if not torch.cuda.is_available(): + print('CUDA not available, skipping test') + return + + # Create test tensors on CUDA + tensor1 = torch.randn(10, 10, dtype=torch.float32, device='cuda') # 100 elements + tensor2 = torch.randn(5, 20, dtype=torch.float32, device='cuda') # 100 elements + named_tensors = [('tensor1', tensor1), ('tensor2', tensor2)] + + # Test 1: Without preallocated buffer (original behavior) + bucket1 = FlattenedTensorBucket(named_tensors=named_tensors) + reconstructed = bucket1.reconstruct_tensors() + assert len(reconstructed) == 2 + # reconstruct_tensors returns List[Tuple[str, torch.Tensor]] + reconstructed_dict = dict(reconstructed) + assert torch.allclose(reconstructed_dict['tensor1'], tensor1) + assert torch.allclose(reconstructed_dict['tensor2'], tensor2) + + # Test 2: With valid preallocated buffer + preallocated = torch.empty(200, dtype=torch.float32, device='cuda') + bucket2 = FlattenedTensorBucket(named_tensors=named_tensors, 
flattened_tensor=preallocated) + assert bucket2.flattened_tensor is preallocated # Should use the same tensor + + # Test 3: With preallocated buffer larger than needed + preallocated_large = torch.empty(500, dtype=torch.float32, device='cuda') + bucket3 = FlattenedTensorBucket(named_tensors=named_tensors, flattened_tensor=preallocated_large) + assert bucket3.flattened_tensor is preallocated_large + + # Test 4: Error case - buffer too small + preallocated_small = torch.empty(50, dtype=torch.float32, device='cuda') # Only 50 elements, need 200 + try: + FlattenedTensorBucket(named_tensors=named_tensors, flattened_tensor=preallocated_small) + assert False, 'Should have raised ValueError' + except ValueError as e: + assert 'smaller than required numel' in str(e) + + # Test 5: Error case - wrong dtype + preallocated_wrong_dtype = torch.empty(200, dtype=torch.float64, device='cuda') + try: + FlattenedTensorBucket(named_tensors=named_tensors, flattened_tensor=preallocated_wrong_dtype) + assert False, 'Should have raised ValueError' + except ValueError as e: + assert 'dtype' in str(e) + + # Test 6: Error case - wrong device (CPU buffer for CUDA tensors) + preallocated_cpu = torch.empty(200, dtype=torch.float32, device='cpu') + try: + FlattenedTensorBucket(named_tensors=named_tensors, flattened_tensor=preallocated_cpu) + assert False, 'Should have raised ValueError' + except ValueError as e: + assert 'device' in str(e) + + # Test 7: Error case - non-contiguous tensor + preallocated_non_contig = torch.empty(400, dtype=torch.float32, device='cuda')[::2] # Strided view + try: + FlattenedTensorBucket(named_tensors=named_tensors, flattened_tensor=preallocated_non_contig) + assert False, 'Should have raised ValueError' + except ValueError as e: + assert 'contiguous' in str(e) + + # Test 8: Error case - not 1-D tensor + preallocated_2d = torch.empty(10, 20, dtype=torch.float32, device='cuda') + try: + FlattenedTensorBucket(named_tensors=named_tensors, 
flattened_tensor=preallocated_2d) + assert False, 'Should have raised ValueError' + except ValueError as e: + assert '1-D tensor' in str(e) def test_get_and_verify_max_len():