
Ideas to improve the efficiency #31

Open
mavenlin opened this issue Feb 13, 2025 · 0 comments

mavenlin commented Feb 13, 2025

Some thoughts on possible improvements:

oat/oat/learners/base.py

Lines 582 to 608 in 4540740

for name, param in model.named_parameters():
    count += 1  # empty_cache at last param
    # Fire all vllm engines for broadcast
    if self.strategy.is_rank_0():
        shape = (
            param.shape
            if self.strategy.args.zero_stage != 3
            else param.ds_shape
        )
        futs = [
            actor.futures.update_weight(
                name,
                dtype=torch_type_codec(param.dtype),
                shape=shape,
                empty_cache=count == num_params,
            )
            for actor in self.actors
        ]
    # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
    with deepspeed.zero.GatheredParameters(
        [param], enabled=self.strategy.args.zero_stage == 3
    ):
        if self.strategy.is_rank_0():
            dist.broadcast(param.data, 0, group=self._model_update_group)
            _ = [fut.result() for fut in futs]

  • Instead of calling fut.result() after each param, we could save dispatch latency by issuing update_weight and the broadcast for every param first, and only then waiting on all the futs (see the sketch after this list). My understanding is that these are dispatched as a series of NCCL calls and are executed in the order they are dispatched.
  • It may be possible to broadcast different params from different learners, so that the communication bandwidth is fully used. There are some caveats, though: 1. we may need different communication groups; 2. we need a coordination mechanism to make sure the broadcast / update_weight pairs are issued in the right order. Is it possible to add all the actors to the deepspeed communication group so that they receive parameter updates, but without having them participate in the training?
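
A minimal sketch of the first idea, reusing the names from the snippet above (model, num_params, torch_type_codec, self.actors, self.strategy, self._model_update_group). This is only an illustration of dispatching every update_weight/broadcast pair before waiting, not the current implementation, and it assumes NCCL consumes the calls in the order they were dispatched:

```python
# Hypothetical restructuring of the loop above: fire update_weight and the
# broadcast for every param first, then wait on all futures once at the end.
all_futs = []
count = 0
for name, param in model.named_parameters():
    count += 1
    if self.strategy.is_rank_0():
        shape = (
            param.shape
            if self.strategy.args.zero_stage != 3
            else param.ds_shape
        )
        all_futs.extend(
            actor.futures.update_weight(
                name,
                dtype=torch_type_codec(param.dtype),
                shape=shape,
                empty_cache=count == num_params,
            )
            for actor in self.actors
        )
    with deepspeed.zero.GatheredParameters(
        [param], enabled=self.strategy.args.zero_stage == 3
    ):
        if self.strategy.is_rank_0():
            dist.broadcast(param.data, 0, group=self._model_update_group)

# Wait only after everything has been dispatched; both sides are expected to
# consume the NCCL calls in dispatch order.
if self.strategy.is_rank_0():
    _ = [fut.result() for fut in all_futs]
```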

oat/oat/learners/base.py

Lines 575 to 579 in 4540740

while True:
    time.sleep(0.1)
    actors_busy = [actor.is_generating() for actor in self.actors]
    if not any(actors_busy):
        break

  • Can we avoid this polling? One idea is to make the NCCL broadcast independent of the vLLM update_weight call: the actors receive weight updates in another thread and cache them; then, in the actor's step function, we check this cache and apply the weights when they are available. In this way we maximize the communication-computation overlap. A rough sketch is below.
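
A rough sketch of this actor-side caching pattern. All names here (WeightCache, receiver_loop, apply_pending_weights, recv_next_weight) are hypothetical and only illustrate the threading/caching idea; they are not existing oat or vLLM APIs:

```python
import threading

import torch


class WeightCache:
    """Hypothetical buffer: a background receiver thread stores incoming
    weights here; the actor's step() drains it (illustrative only)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._pending = {}  # param name -> tensor

    def put(self, name, tensor):
        with self._lock:
            self._pending[name] = tensor

    def drain(self):
        with self._lock:
            pending, self._pending = self._pending, {}
        return pending


def receiver_loop(cache, recv_next_weight, stop_event):
    # Background thread: keep receiving broadcast weights (e.g. via NCCL)
    # and cache them instead of blocking the actor's generation loop.
    # `recv_next_weight` is an assumed blocking callable that returns
    # (name, tensor) pairs; it is not an existing oat/vLLM API.
    while not stop_event.is_set():
        name, tensor = recv_next_weight()
        cache.put(name, tensor)


@torch.no_grad()
def apply_pending_weights(model, cache):
    # Called at the start of the actor's step(): apply only the weights that
    # have already arrived, overlapping communication with generation.
    pending = cache.drain()
    if not pending:
        return
    for name, param in model.named_parameters():
        if name in pending:
            param.data.copy_(pending[name])
```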