-
Notifications
You must be signed in to change notification settings - Fork 664
[Fix][Feat] Fix worker sorting with external pg bundles & Support persistent buffer for update_params #4397
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
25793f2
f8c505a
6bfd136
d53dcce
af3c2e1
2cadf8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -384,6 +384,10 @@ def __init__( | |
| logger.warning(f'Failed to create GuidedManager for tokenizer {type(self.tokenizer)}: {e}') | ||
| self.guided_decoding_manager = None | ||
|
|
||
| # update_params_ipc_buffer | ||
| self._update_params_ipc_tensor: torch.Tensor | None = None | ||
| self._update_params_ipc_event: torch.cuda.Event | None = None | ||
|
|
||
| # microbatch | ||
| self.enable_microbatch = self.dist_config.enable_microbatch | ||
| self.enable_microbatch_prefill_batchsize_threshold = \ | ||
|
|
@@ -1159,14 +1163,19 @@ def update_params(self, request: UpdateParamsRequest): | |
| """Update params.""" | ||
|
|
||
| # modified from https://github.com/vllm-project/vllm/blob/v0.8.5/examples/offline_inference/rlhf_utils.py#L82 | ||
| def _construct(item): | ||
| def _construct(item, require_clone: bool = True): | ||
| func, args = item | ||
| args = list(args) | ||
| args[6] = torch.cuda.current_device() # device id. | ||
| # clone() seems necessary otherwise the producer can not release the memory | ||
| return func(*args).clone() | ||
| ipc_tensor = func(*args) | ||
| return ipc_tensor.clone() if require_clone else ipc_tensor | ||
|
|
||
| with self.all_context(): | ||
| # After deserialization, weights is a dict with following keys: | ||
| # - metadata: List[FlattenedTensorMetadata] | ||
| # - flattened_tensor: the flattened tensor for weights, optional | ||
| # - event_ipc_handle: the ipc handle of the event | ||
| # that used to sync stream across processes, optional | ||
| serialized_data = request.serialized_named_tensors | ||
| if isinstance(serialized_data, list): | ||
| serialized_data = serialized_data[self.dist_ctx.tp_group.rank] | ||
|
|
@@ -1175,7 +1184,26 @@ def _construct(item): | |
| if request.load_format == 'flattened_bucket': | ||
| metadata: List[FlattenedTensorMetadata] = weights['metadata'] | ||
| if metadata: | ||
| flattened_tensor: torch.Tensor = _construct(weights['flattened_tensor']) | ||
| if 'flattened_tensor' in weights: | ||
| # Determine if clone is required | ||
| require_clone = weights.get('require_clone', True) | ||
| if 'event_ipc_handle' in weights and not hasattr(torch.cuda.Event, 'from_ipc_handle'): | ||
| # Force clone when IPC event is provided but cannot be used | ||
| require_clone = True | ||
| self._update_params_ipc_tensor = _construct(weights['flattened_tensor'], | ||
| require_clone=require_clone) | ||
| elif self._update_params_ipc_tensor is None: | ||
| raise ValueError( | ||
| 'flattened_tensor is not provided in weights and no cached ipc tensor is available. ' | ||
| 'Please provide flattened_tensor on the first update_params call.') | ||
| if 'event_ipc_handle' in weights and hasattr(torch.cuda.Event, 'from_ipc_handle'): | ||
| self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle( | ||
| device=torch.cuda.current_device(), | ||
| handle=weights['event_ipc_handle'], | ||
| ) | ||
| flattened_tensor: torch.Tensor = self._update_params_ipc_tensor | ||
| if self._update_params_ipc_event is not None: | ||
| self._update_params_ipc_event.wait() | ||
|
Comment on lines
+1199
to
+1206
|
||
| bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata) | ||
| weights = bucket.reconstruct_tensors() | ||
| else: | ||
|
|
@@ -1186,12 +1214,17 @@ def _construct(item): | |
|
|
||
| weights = ModelWeightLoader._rename_weights_iterator(weights, model) | ||
| model.load_weights(weights) | ||
| if self._update_params_ipc_event is not None: | ||
| self._update_params_ipc_event.record() | ||
|
|
||
| if request.finished: | ||
| for _, mod in model.named_modules(): | ||
| if not hasattr(mod, 'update_weights'): | ||
| continue | ||
| mod.update_weights() | ||
| torch.cuda.synchronize() | ||
| self._update_params_ipc_event = None | ||
| self._update_params_ipc_tensor = None | ||
|
|
||
| torch.cuda.empty_cache() | ||
|
|
||
|
|
@@ -1206,6 +1239,9 @@ async def sleep(self, level: int = 1): | |
| device = 'cpu' if level == 1 else 'meta' | ||
| self.patched_model.get_model().to(device=device, non_blocking=True) | ||
| torch.cuda.synchronize() | ||
| # force clean _update_params_ipc tensor and event after all gpu jobs done | ||
| self._update_params_ipc_tensor = None | ||
| self._update_params_ipc_event = None | ||
| torch.cuda.empty_cache() | ||
| self.state.to_sleep.clear() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR description mentions an env var
`RAY_EXTERNAL_PG_BUNDLES`, but the code reads `LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES` via `_envs.ray_external_pg_bundles`. Please align the PR description / documentation (or add an alias) to avoid users setting the wrong variable and not getting the intended behavior.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot I updated the PR description, please review it again.