[RFC] GPU object store support in Ray Core #51173

Open
richardliaw opened this issue Mar 7, 2025 · 2 comments
Labels
core Issues that should be addressed in Ray Core enhancement Request for new feature and/or capability RFC RFC issues

richardliaw commented Mar 7, 2025

GPU Support in Ray Core

Authors: @stephanie-wang @edoakes

TLDR: We discuss a design for GPU objects (specifically torch.Tensors) in the Ray Core API.

Requirements

The goal of this API proposal is to add support for GPU “objects” and direct GPU-GPU communication in the Ray Core API.

Goals:

  • [P0] Performance: <1ms latency overhead to launch collectives, with support for overlapping compute and communication.
    • Note that this means that this API will be primarily suitable for cases such as inter-node KV-cache transfer or weight-syncing, where the size of the data transfer can amortize the system overhead. It is not suitable for latency-sensitive cases such as tensor-parallel inference.
  • Features: Extend current Ray API with
    • [P0] P2p GPU communication through NCCL and/or CUDA IPC
    • [P1] Collective GPU communication through NCCL
    • [P1] CPU-GPU data movement, via pinned memory
  • [P0] Basic memory management:
    • Garbage collection
    • No caching or spilling (for now). GPU memory will be GCed as early as possible and a user error will be thrown if buffered memory exceeds a configurable threshold
  • [P0] Interoperability
    • User can continue to use SPMD-style collectives in actor definitions
    • Works with all other Ray Core APIs
    • Eventually, other methods of transport such as RDMA, and non-NVIDIA GPUs
    • Works with PyTorch APIs such as torch.compile

(Current) limitations:

  • Actors only - in the future, we could additionally support task-actor and/or task-task communication
  • GPU data limited to torch.Tensors
  • Only the process that creates the actors can specify GPU object transfers between the actors. If the actor handle is passed to another worker, the new caller will not be able to call methods that return or take in a GPU object reference, i.e. GPUObjectRefs cannot be borrowed.
    • Also implies that all actors that use GPU-GPU communication need to be created by the same process.
    • The process that creates the actors could be a task, but for simplicity, we’ll call it the “driver” for the rest of this doc.
    • This simplifies failure handling: if the driver crashes unexpectedly, its child actors will also exit, so we don’t need to worry about dangling references.
  • User needs to be aware of GPU execution to write correct and efficient code. For example:
    • Add type hints to all Ray tasks that return GPU tensors.
    • If torch.Tensors returned by a Ray task are modified by the user after the task returns without proper stream synchronization, the behavior is undefined.
    • Deadlock prevention is not guaranteed if user-defined collectives create a cycle, or if the program invokes user-defined collectives through a different actor handle.

Background

This doc is motivated by recent evidence that Compiled Graphs may have limited applicability to current applications. Their main use cases are online/offline inference (which is currently bottlenecked on vLLM development) and distributed training (which will take a while to develop). Meanwhile, we have other applications such as RLHF that we would like to support. These applications can use the Compiled Graphs API, but doing so requires significant developer effort, and the applications are structured in such a way that the added performance gain is negligible. See Ray Compiled Graphs Q1 2025 update for more information.

Therefore, our goal is to introduce an “interpreted” version of the Ray Compiled Graphs API that enables direct GPU-GPU movement of torch.Tensors between Ray actors. This has been a common user request for almost as long as Ray has been around. This is to support the single-controller dataflow model for orchestrating GPU devices, in contrast to the current options with Ray:

  • Orchestrating p2p and collectives inside actor definitions, using NCCL, torch.distributed or ray.util.collective
  • Fully adopting the new Ray Compiled Graphs API - this option can provide better performance but restricts the user to static dataflow graphs and currently does not play well with non-Compiled Graphs Ray code

Ultimately, the goal is for this API to be consistent with the existing Ray Compiled Graphs API. Ideally, it should be simple to change between these APIs.

Other useful docs:

APIs

See [PUBLIC] RFC: GPU objects in Ray Core API.

Proposed Design

The actor’s creator will be responsible for coordinating the transfers between actors. For simplicity, we will call this creator process the “driver”, although it may not be the driver of the overall Ray job. The driver will order all transfers between actors to ensure that collective operations are scheduled on actors in a consistent order, to avoid deadlock.

Each actor will locally store, in Python, the tensors that it is sending or receiving. We will extend each Ray actor with the following Python state:

  • communicators: Dict[CommID, Communicator]: A map of (NCCL) communicators that the actor is a participant in
  • tensor_store: Dict[Tuple[ObjectID, tensor_index], Tuple[torch.Tensor, int]]: A map from (ObjectID, tensor_index) to the torch.Tensor and its current reference count. tensor_index distinguishes the tensors within an object that contains multiple torch.Tensors
    • The reference count should be > 0 IFF (the driver still has the corresponding GPUObjectRef in scope OR there is a pending communication op that uses the tensor OR there is a pending task that takes this tensor as an argument)
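
The reference-count invariant above can be made concrete with a small sketch. This is only an illustration in plain Python, not the proposed implementation: the class name RefSources and its fields are hypothetical, and a real actor would store a single integer rather than three separate sources.

```python
from dataclasses import dataclass


@dataclass
class RefSources:
    """Hypothetical breakdown of what keeps one tensor_store entry alive."""

    driver_ref_in_scope: bool = True   # driver still holds the GPUObjectRef
    pending_comm_ops: int = 0          # sends/collectives not yet completed
    pending_task_args: int = 0         # queued tasks that take the tensor

    @property
    def ref_count(self) -> int:
        return (int(self.driver_ref_in_scope)
                + self.pending_comm_ops
                + self.pending_task_args)

    @property
    def can_free(self) -> bool:
        # The entry may be deleted exactly when no source remains.
        return self.ref_count == 0


refs = RefSources()                  # created by a task; driver holds the ref
refs.pending_comm_ops += 1           # a send begins
refs.driver_ref_in_scope = False     # driver's reference goes out of scope
alive_during_send = not refs.can_free
refs.pending_comm_ops -= 1           # the send completes
freed_after_send = refs.can_free
```

The point of the breakdown is that the driver dropping its reference does not free the tensor while a send is still in flight.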

Collective group initialization and destruction

Collective group initialization and destruction are accomplished by having the driver send a creation/destruction task to each actor. For example, if the user code looks like this:

# Setup.
A, B, C = [Actor.options(num_gpus=1).remote(i) for i in range(3)]
# Each actor is assigned rank according to its order in the list.
group : NcclGroup = ray.util.collectives.init_group([A, B, C])
# Wait until group is ready, same as in placement groups.
ray.get(group.ready())

Then, during init_group, the driver will launch a pre-defined task to each actor that:

  1. Creates a NCCL communicator, using ray.util.collective
  2. Stores the handle in self.communicators
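
As a rough sketch of the two steps above, the following plain-Python simulation shows the driver assigning ranks by list position and each actor recording the result in its communicators map. FakeActor, create_communicator, and the "group-0" ID are hypothetical stand-ins; no Ray or NCCL calls are made, and a real communicator handle would replace the dict stored here.

```python
from typing import Dict, List


class FakeActor:
    """Stand-in for a Ray actor; only the `communicators` state is modeled."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.communicators: Dict[str, dict] = {}

    def create_communicator(self, comm_id: str, rank: int, world_size: int) -> bool:
        # A real implementation would create an NCCL communicator here.
        self.communicators[comm_id] = {"rank": rank, "world_size": world_size}
        return True

    def destroy_communicator(self, comm_id: str) -> bool:
        return self.communicators.pop(comm_id, None) is not None


def init_group(actors: List[FakeActor], comm_id: str = "group-0") -> List[bool]:
    """Driver side: one creation task per actor; rank is the list position."""
    world_size = len(actors)
    return [actor.create_communicator(comm_id, rank, world_size)
            for rank, actor in enumerate(actors)]


def destroy_group(actors: List[FakeActor], comm_id: str = "group-0") -> List[bool]:
    """Driver side: one destruction task per actor."""
    return [actor.destroy_communicator(comm_id) for actor in actors]


A, B, C = FakeActor("A"), FakeActor("B"), FakeActor("C")
ready = init_group([A, B, C])
ranks = [actor.communicators["group-0"]["rank"] for actor in (A, B, C)]
```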

Example: GPU-GPU communication via NCCL

Suppose we have example code like this that sends a torch.Tensor from actor A to actor B:

import ray
import torch

@ray.remote(num_gpus=1)
class Actor:
  # Proposed API: the returned tensor stays on the GPU and is sent directly to its consumer.
  @ray.method(
    tensor_transport="auto",
    tensor_shape=torch.Size([N]),
  )
  def foo(self):
    return torch.randn(N, device="cuda")

  def bar(self, t: torch.Tensor):
    ...


A, B = Actor.remote(), Actor.remote()
group: NcclGroup = ray.util.collectives.init_group([A, B])
x: GPUObjectRef = A.foo.remote()
y = B.bar.remote(x)
del x
ray.get(y)

In this case, the steps on the driver are:

  1. A.foo.remote():
    1. Driver sends ExecuteTask RPC to A to dispatch A.foo.remote() task.
  2. B.bar.remote(x):
    1. Driver sends BeginSend RPC to A to begin sending the tensor with ID (x.id, 0) to B.
    2. Driver sends BeginRecv RPC to B to begin receiving a tensor of size N, and to store the result in B.tensor_store[(x.id, 0)]
    3. Driver sends ExecuteTask RPC to B to dispatch B.bar.remote() task. Note that due to Ray’s task execution order, this will get ordered after B’s receive task.
  3. del x:
    1. Driver sends DecrementRefCount RPC to A to decrement the tensor’s ref count.
  4. ray.get(y):
    1. The usual Ray Core protocol.
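
The driver-side steps above can be sketched as a function that turns each method call into an ordered list of RPCs. This is a simplified model under stated assumptions: the Rpc tuple, submit_call, and release are hypothetical names, the GPUObjectRef here is a stand-in that only records the object ID and owning actor, and every object is assumed to hold a single tensor (tensor_index 0).

```python
from typing import Any, List, NamedTuple, Tuple


class GPUObjectRef(NamedTuple):
    """Hypothetical stand-in: an object ID plus the actor that holds the data."""
    object_id: str
    owner: str


class Rpc(NamedTuple):
    target: str               # actor that receives the RPC
    kind: str                 # BeginSend, BeginRecv, ExecuteTask, DecrementRefCount
    detail: Tuple[Any, ...]


def submit_call(dst: str, method: str, args: Tuple[Any, ...]) -> List[Rpc]:
    """RPCs the driver issues, in order, for `dst.method.remote(*args)`."""
    rpcs: List[Rpc] = []
    for arg in args:
        if isinstance(arg, GPUObjectRef):
            key = (arg.object_id, 0)  # tensor_index 0: one tensor per object
            rpcs.append(Rpc(arg.owner, "BeginSend", (key, dst)))
            rpcs.append(Rpc(dst, "BeginRecv", (key, arg.owner)))
    # Dispatched last, so it is ordered after the receive on `dst`.
    rpcs.append(Rpc(dst, "ExecuteTask", (method,)))
    return rpcs


def release(ref: GPUObjectRef) -> List[Rpc]:
    """RPC the driver issues when `ref` goes out of scope (`del x`)."""
    return [Rpc(ref.owner, "DecrementRefCount", ((ref.object_id, 0),))]


x = GPUObjectRef(object_id="x", owner="A")
log = submit_call("A", "foo", ()) + submit_call("B", "bar", (x,)) + release(x)
kinds = [(rpc.target, rpc.kind) for rpc in log]
```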

On actor A:

  1. A receives A.foo.remote() ExecuteTask:
    1. A executes foo()
    2. A serializes foo()’s return value, and extracts any GPU tensors. The GPU tensors are replaced with a tensor placeholder.
    3. A asserts that the tensor has size N.
    4. A stores the tensor in self.tensor_store, with initial ref count=1, for the driver’s reference.
  2. A receives BeginSend RPC:
    1. A begins sending the tensor to actor B, using the correct NCCL communicator and B’s rank.
    2. A increments the tensor’s ref count, to indicate that there is a pending send.
  3. A receives DecrementRefCount RPC:
    1. A decrements the tensor’s ref count. If the ref count == 0, delete.
  4. Upon completion of the send to B:
    1. A decrements the tensor’s ref count. If the ref count == 0, delete.

On actor B:

  1. B receives BeginRecv RPC:
    1. B begins receiving the tensor from actor A, using the correct NCCL communicator and A’s rank.
    2. B initializes the tensor ref count to 1, indicating that there is a pending task that requires this tensor as an argument.
  2. B receives ExecuteTask RPC:
    1. NOTE: At this point, the tensor should already have been received.
    2. B deserializes the task’s arguments, replacing any tensor placeholders with the tensor from self.tensor_store.
    3. B decrements the ref count for each tensor placeholder it replaced. If the ref count == 0, delete.
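
To check that the ref counts in the steps for A and B balance out, here is a minimal single-process simulation. It is a sketch only: ActorState, Placeholder, and the handler names are hypothetical, a Python list stands in for the torch.Tensor, and the NCCL transfer is modeled by handing the payload from A to B directly.

```python
from typing import Any, Callable, Dict, List, Tuple

Key = Tuple[str, int]  # (object ID, tensor_index)


class Placeholder:
    """Marks where a GPU tensor was removed from a serialized value."""

    def __init__(self, key: Key) -> None:
        self.key = key


class ActorState:
    def __init__(self) -> None:
        self.tensor_store: Dict[Key, List[Any]] = {}  # key -> [tensor, ref_count]

    def _decref(self, key: Key) -> None:
        entry = self.tensor_store[key]
        entry[1] -= 1
        if entry[1] == 0:
            del self.tensor_store[key]

    # Sender side.
    def store_result(self, key: Key, tensor: Any) -> Placeholder:
        self.tensor_store[key] = [tensor, 1]   # 1 = the driver's reference
        return Placeholder(key)

    def begin_send(self, key: Key) -> Any:
        self.tensor_store[key][1] += 1         # pending send
        return self.tensor_store[key][0]

    def send_complete(self, key: Key) -> None:
        self._decref(key)

    def decrement_ref_count(self, key: Key) -> None:
        self._decref(key)

    # Receiver side.
    def begin_recv(self, key: Key, tensor: Any) -> None:
        self.tensor_store[key] = [tensor, 1]   # 1 = the pending task argument

    def execute_task(self, fn: Callable[..., Any], args: List[Any]) -> Any:
        resolved = []
        for arg in args:
            if isinstance(arg, Placeholder):
                resolved.append(self.tensor_store[arg.key][0])
                self._decref(arg.key)
            else:
                resolved.append(arg)
        return fn(*resolved)


a, b = ActorState(), ActorState()
key = ("x", 0)
a.store_result(key, [1.0, 2.0, 3.0])       # A executes foo()
payload = a.begin_send(key)                # BeginSend on A
b.begin_recv(key, payload)                 # BeginRecv on B
a.decrement_ref_count(key)                 # driver ran `del x`
kept_during_send = key in a.tensor_store   # send still pending
a.send_complete(key)
result = b.execute_task(sum, [Placeholder(key)])
```

After the last step both stores are empty, which is the behavior the protocol aims for: the tensor lives exactly as long as some reference, send, or task still needs it.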

The flow of messages looks something like this. Steps that have the same number can proceed concurrently:

[Figure: message flow between the driver, actor A, and actor B]

The protocol for collective communication is similar. The only difference is that the driver must dispatch to all actors in the group, and we would use a BeginCollective RPC instead of BeginSend/BeginRecv.
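
A minimal sketch of that dispatch, assuming the driver tags each collective with a sequence number so that every actor in the group sees the same order. The Driver class, its inbox attribute, and the op-name strings are hypothetical; only the BeginCollective RPC name comes from the design above.

```python
from itertools import count
from typing import Dict, List, Tuple


class Driver:
    """Hypothetical driver that fans each collective out to the whole group."""

    def __init__(self, group: List[str]) -> None:
        self.group = group
        self._seq = count()
        # Per-actor queue of (sequence number, RPC) in arrival order.
        self.inbox: Dict[str, List[Tuple[int, str]]] = {name: [] for name in group}

    def begin_collective(self, op: str) -> int:
        seq = next(self._seq)
        for name in self.group:
            self.inbox[name].append((seq, "BeginCollective:" + op))
        return seq


driver = Driver(["A", "B", "C"])
driver.begin_collective("allreduce")
driver.begin_collective("broadcast")
orders = [driver.inbox[name] for name in driver.group]
```

Because a single process submits the operations one at a time, all actors observe the collectives in the same order, which is what avoids the mismatched-order deadlock.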

WARNING: Ensuring data consistency

One caveat of this approach is that the user may still have a pointer to the tensor while it is in the tensor_store with transfers or collectives to other actors still pending. This can lead to data inconsistency if the user modifies the tensor before or while it is sent to other actors.

It is also hard to detect whether the user still holds a pointer. Tracking Python references is not sufficient because different torch.Tensors could share the same physical data, etc.

Therefore, the user needs to be careful when sending tensors. Ideally, we should expose an API to allow the user to synchronize with any ongoing sends/collectives, so that they know when it’s safe to write the data. This kind of synchronization would only be possible for actors with concurrency enabled, because otherwise synchronization could hang the actor.

One possibility is to provide a future-like syntax, keyed by torch.Tensor. For example:

@ray.remote(num_gpus=1)
class Actor:
  @ray.method(tensor_transport=group)
  def foo(self):
    self.tensor = torch.randn(N, device="cuda")
    return self.tensor

  def modify_tensor(self):
    # Wait until any ongoing communication ops involving self.tensor have finished.
    # self._tensor_store = {...: copy(self.tensor)}
    ray.wait(self.tensor)
    self.tensor += 1

This program could hang if the GPUObjectRef corresponding to `self.tensor` never goes out of scope at the driver. One way to fix this is to allow copy-on-write: copy self.tensor back into the actor’s tensor storage after a timeout, allowing the user to use the original copy.
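
One way such a wait with a timeout could work inside the actor is sketched below using only the standard library. Everything here is an assumption for illustration: PendingOps and its methods are hypothetical, a Python list stands in for the tensor, and keying by id() is a simplification, since a real version would have to key on the underlying storage to catch tensors that share memory.

```python
import threading
from typing import Any, Dict, Optional


class PendingOps:
    """Hypothetical tracker of in-flight communication ops per tensor."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._pending: Dict[int, int] = {}  # id(tensor) -> number of pending ops

    def begin(self, tensor: Any) -> None:
        with self._cond:
            self._pending[id(tensor)] = self._pending.get(id(tensor), 0) + 1

    def finish(self, tensor: Any) -> None:
        with self._cond:
            self._pending[id(tensor)] -= 1
            if self._pending[id(tensor)] == 0:
                del self._pending[id(tensor)]
            self._cond.notify_all()

    def wait(self, tensor: Any, timeout: Optional[float] = None) -> bool:
        """Return True once no ops are pending, or False if the timeout expires."""
        with self._cond:
            return self._cond.wait_for(
                lambda: id(tensor) not in self._pending, timeout)


ops = PendingOps()
tensor = [0.0, 0.0, 0.0]
ops.begin(tensor)                              # a send starts
timed_out = not ops.wait(tensor, timeout=0.01)
# Simulate the send completing on a background thread shortly afterwards.
threading.Timer(0.05, ops.finish, args=(tensor,)).start()
safe_to_write = ops.wait(tensor, timeout=5.0)
if safe_to_write:
    tensor[0] += 1
```

When wait returns False, the copy-on-write fallback described above would apply: the store takes its own copy and the user continues with the original.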

WARNING: Deadlock prevention

TODO

Dynamic tensor shapes

If the tensor shape is not known, then the driver needs to wait until A has finished and extracted all GPU tensors before sending to B. This looks something like this:

[Figure: message flow when the tensor shape is not known in advance]

If there are multiple tensors in the value, the user can specify them using a “key” into the value. For example, if the returned value is a TensorDict, then the user would use the key strings to distinguish different tensors. Also, the tensor shape(s) can be specified on-the-fly, per method invocation instead of per method definition. For example, the following code specifies the shapes of two different tensors that are nested inside one Python object:

x: GPUObjectRef = A.foo.options(tensor_shape={
  "layer1": (N, M),
  "layer2": (M, O),
}).remote()
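
On the sending actor, the keyed shapes could be checked while the tensors are extracted from the return value. The sketch below is one possible shape of that check, not the proposed implementation: extract_tensors and FakeTensor are hypothetical, and a plain dict stands in for a TensorDict.

```python
from typing import Any, Dict, Tuple

Shape = Tuple[int, ...]


class FakeTensor:
    """Stand-in for torch.Tensor; only `.shape` is modeled."""

    def __init__(self, *shape: int) -> None:
        self.shape = tuple(shape)


def extract_tensors(value: Dict[str, Any],
                    expected: Dict[str, Shape]) -> Dict[str, Any]:
    """Pull out the declared tensors and verify each shape."""
    tensors = {}
    for key, shape in expected.items():
        tensor = value[key]
        if tuple(tensor.shape) != tuple(shape):
            raise ValueError(
                f"tensor {key!r} has shape {tuple(tensor.shape)}, expected {tuple(shape)}")
        tensors[key] = tensor
    return tensors


N, M, O = 4, 8, 2
value = {"layer1": FakeTensor(N, M), "layer2": FakeTensor(M, O), "step": 7}
tensors = extract_tensors(value, {"layer1": (N, M), "layer2": (M, O)})

try:
    extract_tensors(value, {"layer1": (N, N)})
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```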

Memory management

The protocol must hook into Ray Core’s reference counting protocol (C++). In particular, if the driver’s GPUObjectRef goes out of scope, then we should send DecrementRefCount RPCs to the actor(s) that stored the original copy of this object. We can find these actors by storing weak refs to these actors’ handles inside the GPUObjectRef.

We should support configuration of each actor’s maximum allowed GPU memory for its self.tensor_store. If the actor tries to place a GPU object in its store and it would exceed the store’s capacity, then the actor should throw an OutOfMemoryError. This error should get propagated to all downstream tasks.
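
A minimal sketch of the capacity check, assuming sizes are tracked in bytes per stored tensor. BoundedTensorStore is a hypothetical name, and the OutOfMemoryError class is defined locally for the example; which exception class the real implementation would raise is not specified here.

```python
from typing import Dict, Tuple

Key = Tuple[str, int]  # (object ID, tensor_index)


class OutOfMemoryError(Exception):
    """Local stand-in for the error the actor would raise."""


class BoundedTensorStore:
    def __init__(self, capacity_bytes: int) -> None:
        self.capacity_bytes = capacity_bytes
        self._sizes: Dict[Key, int] = {}

    @property
    def used_bytes(self) -> int:
        return sum(self._sizes.values())

    def put(self, key: Key, nbytes: int) -> None:
        # Reject the object instead of spilling or caching.
        if self.used_bytes + nbytes > self.capacity_bytes:
            raise OutOfMemoryError(
                f"storing {nbytes} bytes would exceed the "
                f"{self.capacity_bytes}-byte limit ({self.used_bytes} in use)")
        self._sizes[key] = nbytes

    def free(self, key: Key) -> None:
        del self._sizes[key]


store = BoundedTensorStore(capacity_bytes=100)
store.put(("a", 0), 60)
try:
    store.put(("b", 0), 60)
    rejected = False
except OutOfMemoryError:
    rejected = True
store.free(("a", 0))        # ref count reached 0
store.put(("b", 0), 60)     # now fits
```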

In the future, we can consider more advanced memory management such as:

  • Waiting for current memory to be consumed
  • Offloading to CPU memory and/or disk

The same tensor may be passed as a task argument multiple times to the same actor. If the tensor must be received from a different actor, then we have two options:

  1. Driver asks receiving actor if it still has the copy, then decides whether it needs to trigger another BeginSend/Recv. This requires the driver to remember all actors that may have a copy of a tensor, not just the one that originated the copy.
  2. Driver always triggers another BeginSend/Recv.

We will initially favor the second option because it is simpler, although it is less efficient when a large amount of data must be transferred.

Overlapping compute and communication

This is a critical performance feature in most distributed GPU applications. To support this, we can use a design similar to that of Ray Compiled Graphs: [PUBLIC] Design: Overlapping GPU communication in aDAG. The main difference would be that we cannot rearrange the order of tasks before execution; instead the driver will guarantee a consistent global order by submitting operations one at a time.

To avoid blocking on the CPU, we may need to use Python or potentially C++ multithreading to handle the BeginSend/BeginRecv/BeginCollective RPCs. Also, we may need to rate-limit the pending communication ops to avoid memory buildup.
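
The following sketch shows one way to combine a background thread with a cap on in-flight ops, using only the Python standard library. CommQueue and max_pending are hypothetical names, and the "communication op" is an ordinary Python callable standing in for a send, receive, or collective.

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, List


class CommQueue:
    """Hypothetical queue: runs comm ops off the main thread, at most
    `max_pending` queued or running at once."""

    def __init__(self, max_pending: int) -> None:
        self._slots = threading.BoundedSemaphore(max_pending)
        # A single worker keeps ops in submission order.
        self._pool = ThreadPoolExecutor(max_workers=1)

    def submit(self, op: Callable[..., Any], *args: Any) -> Future:
        # Blocks the caller when too many ops are already pending.
        self._slots.acquire()

        def run() -> Any:
            try:
                return op(*args)
            finally:
                self._slots.release()

        return self._pool.submit(run)

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)


queue = CommQueue(max_pending=2)
completed: List[int] = []
futures = [queue.submit(completed.append, i) for i in range(5)]
for future in futures:
    future.result()
queue.shutdown()
```

A single worker thread is a deliberate choice in this sketch: it preserves the order in which the driver submitted the operations, which the deadlock-avoidance argument relies on.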

Other communication transports

Intra-actor: Skipping serialization/communication

If a GPUObjectRef is passed back to a task on the same actor that created the data, then we can avoid serialization. This optimization is already done in Ray Compiled Graphs but has not been possible in Ray Core because we always serialize the data into the object store.

@ray.remote(num_gpus=1)
class Actor:
  @ray.method(tensor_shape=torch.Size([N]))
  def foo(self):
    return torch.randn(N, device="cuda")

  def bar(self, t: torch.Tensor):
    ...

A = Actor.remote()
x: GPUObjectRef = A.foo.remote()
# One option is to avoid serializing x and pass the tensor directly to A.bar.
y: ObjectRef = A.bar.remote(x)
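
On the driver, this optimization could amount to a check before any transfer is scheduled. The sketch below is an assumption about how that decision might look; plan_transfer and the tuple layout are hypothetical, and only the BeginSend/BeginRecv names come from the design above.

```python
from typing import List, Tuple

Key = Tuple[str, int]  # (object ID, tensor_index)
Rpc = Tuple[str, str, Key, str]  # (target actor, RPC kind, tensor key, peer actor)


def plan_transfer(owner: str, dst: str, key: Key) -> List[Rpc]:
    """Transfer RPCs needed before `dst` can use the tensor held by `owner`."""
    if owner == dst:
        # The tensor is already in this actor's tensor_store: nothing to send.
        return []
    return [(owner, "BeginSend", key, dst), (dst, "BeginRecv", key, owner)]


same_actor = plan_transfer("A", "A", ("x", 0))
cross_actor = plan_transfer("A", "B", ("x", 0))
```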

Intra-node: CUDA memcpy and IPC

TODO

CPU-GPU

TODO

Driver-specific communication

  • Can/should the driver have access to GPUs?
  • ray.put(), ray.get() results from actor

Dynamic/autoscaling actor groups: RDMA / NVSHMEM / etc

NCCL only supports static actor groups. If the membership of the group needs to be changed, e.g., for autoscaling or upon failure, then the NCCL group needs to be recreated. NCCL group creation is also quite slow. This is a known issue for NCCL users in autoscaling environments.

Initially, we plan to use NCCL because it is the industry standard for NVIDIA GPU communication. However, in the future, we can consider adding support for dynamic actor groups. This includes two different features:

  1. Actors could be dynamically added or removed from a NCCL group. Actor failures could be handled smoothly by removing that actor from any NCCL groups it participated in.
  2. Peer-to-peer communications between actors would not require specifying a NCCL group beforehand.

A simple version of Feature 1 is to re-create the NCCL group whenever an actor is added or removed. If this happens relatively infrequently, the performance overhead is acceptable.
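
A sketch of that simple version, assuming ranks are reassigned by position each time the group is rebuilt. ElasticGroup and its generation counter are hypothetical; the comment marks where the real NCCL teardown and setup would happen.

```python
from typing import Dict, List


class ElasticGroup:
    """Hypothetical wrapper that rebuilds the group on every membership change."""

    def __init__(self, members: List[str]) -> None:
        self.generation = 0
        self.ranks: Dict[str, int] = {}
        self._recreate(members)

    def _recreate(self, members: List[str]) -> None:
        # A real version would destroy the old NCCL communicator on every
        # member and create a new one here, which is the slow part.
        self.generation += 1
        self.ranks = {name: rank for rank, name in enumerate(members)}

    def add(self, name: str) -> None:
        self._recreate(list(self.ranks) + [name])

    def remove(self, name: str) -> None:
        self._recreate([member for member in self.ranks if member != name])


group = ElasticGroup(["A", "B", "C"])
group.remove("B")                 # e.g. actor B failed
ranks_after_remove = dict(group.ranks)
group.add("D")                    # e.g. autoscaling added actor D
```

Note that surviving actors can change rank after a rebuild (C moves from rank 2 to rank 1 here), so any cached rank must be refreshed along with the communicator.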

Feature 2 is more challenging. NCCL group (re)creation is more likely to be a bottleneck when there is an elastic group of actors and many possible actor-actor pairs. Options:

  1. A high-level library like UCX. This requires benchmarking to determine overheads.
  2. Build our own transport over low-level primitives like RDMA or NVSHMEM. This will bring up some new complexity around how to:
    1. Set up the connection. Typically this will require some kind of initialization call on both actors to map a shared region of memory.
    2. Schedule the transfer. These are lower-level APIs compared to NCCL and we would likely want to perform chunking ourselves.
    3. Tear down the connection. We may want to cache the connection, but also need to be aware of possible failures, out-of-memory conditions, etc.

Deadlock prevention

TODO

Implementation

Ray Compiled Graphs

Our top priority is to support vLLM. The secondary priority is to support distributed training, the development of which is primarily happening at UW.

To support these applications, we must harden the following features, some of which are shared with this current design proposal:

  • Inter-node communication performance
  • [shared] Compute/communication overlap
  • [shared] Collectives support
  • Usability - remove execution timeouts

Specifically, this also means that we will de-prioritize the following features that were originally planned. This timeline also matches the current proposal better, in that it gives us more time to develop the current proposal before deciding how the two APIs can be flexibly used together.

  • DAG flexibility
    • DAG concurrency - ability to execute multiple DAGs on the same actor
    • Ability to interoperate with normal Ray Core API - allow passing inputs and outputs to non-compiled tasks through Ray object store
    • Ability to run any compiled DAG code in non-compiled mode. Useful for development and debugging.

Project Timeline

Target applications:

  • veRL
    • Data transfer
    • weight syncing
  • Riot SEED RL algorithm (?)
  • Ray Data (?)
    • GPU-GPU actor communication
    • Possibly, CPU-GPU

Initial prototype:

  • P2p example works. Includes:
    • Creation of one collective group
    • GC of torch.Tensors
    • Tensor shape is known
  • + Tensor shape is unknown
  • + Object contains multiple tensors
  • + What level of actor-driver support do we need?
  • veRL prototype

Checkpoint: veRL prototype complete, Ray Data GPU-GPU prototype possible?

Remaining features:

  • Correctness:
    • Ability to synchronize torch.Tensors
  • Features
    • GPUObjectRef is sent to multiple different actor tasks
    • Collectives API supported
    • CPU-GPU support
  • Performance
    • Intra-actor: Skipping serialization/communication
    • Intra-node: CUDA IPC
    • Overlapping compute and communication
@Catch-Bull

In Example: GPU-GPU communication via NCCL, does the tensor represented by GPUObjectRef x have different IDs in Actor A and Actor B? If they are the same, how is the consistency of the two tensors with the same ID in A and B ensured?

Catch-Bull commented Mar 10, 2025

In Example: GPU-GPU communication via NCCL, the Python statement del x triggers the DecrementRefCount RPC to decrement the reference count. How should we handle the case where the single-controller is a detached actor? Should actor A decrement the reference count when it restarts?
