Adapt custom allreduce for TensorRT-LLM #2511

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -22,7 +22,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
-    "xgrammar>=0.1.6"]
+    "xgrammar>=0.1.6", "sgl-kernel"]
 srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
47 changes: 12 additions & 35 deletions python/sglang/srt/_custom_ops.py
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import contextlib
 import functools
 import importlib
@@ -14,7 +14,7 @@

 if not is_hpu():
     try:
-        import custom_ar
+        import sgl_kernel
     except ImportError as e:
         logger.warning("Failed to import from custom_ar with %r", e)

@@ -50,46 +50,23 @@ def wrapper(*args, **kwargs):

 # custom ar
 def init_custom_ar(
-    ipc_tensors: List[torch.Tensor],
-    rank_data: torch.Tensor,
-    rank: int,
-    full_nvlink: bool,
+    rank_id: int,
+    world_size: int,
+    buffers: List[int],
+    barrier_in: List[int],
+    barrier_out: List[int],
 ) -> int:
-    return torch.ops._C_vllm_ar.init_custom_ar(
-        ipc_tensors, rank_data, rank, full_nvlink
+    return sgl_kernel.ops.init_custom_reduce(
+        rank_id, world_size, buffers, barrier_in, barrier_out
     )


-def all_reduce(
-    fa: int,
-    inp: torch.Tensor,
-    out: torch.Tensor,
-    reg_buffer: int,
-    reg_buffer_sz_bytes: int,
-) -> None:
-    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+    sgl_kernel.ops.custom_reduce(fa, inp, out)


 def dispose(fa: int) -> None:
-    torch.ops._C_vllm_ar.dispose(fa)
-
-
-def meta_size() -> int:
-    return torch.ops._C_vllm_ar.meta_size()
-
-
-def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
-
-
-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
-
-
-def register_graph_buffers(
-    fa: int, handles: List[List[int]], offsets: List[List[int]]
-) -> None:
-    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+    sgl_kernel.ops.custom_dispose(fa)


 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
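With this change, `_custom_ops.py` keeps the same wrapper names but forwards to `sgl_kernel.ops` and drops the vLLM-specific helpers (`meta_size`, `register_buffer`, `get_graph_buffer_ipc_meta`, `register_graph_buffers`). As a rough orientation, the per-rank lifecycle these wrappers imply looks like the sketch below. This is not part of the PR: it assumes one process per CUDA device with `torch.distributed` already initialized, and `shared(nbytes)` is a hypothetical stand-in for `CustomAllreduce.create_shared_buffer`, which returns one IPC-mapped device pointer per rank as a `List[int]`.

```python
# Sketch only (not part of this PR): the per-rank lifecycle implied by the new
# wrappers. `shared(nbytes)` is a hypothetical stand-in for
# CustomAllreduce.create_shared_buffer and returns one IPC-mapped device
# pointer per rank as a List[int].
import torch

from sglang.srt import _custom_ops as ops

MAX_BYTES = 8 * 1024 * 1024       # staging buffer per rank (assumed size)
BARRIER_BYTES = 8 * (24 + 2) * 8  # mirrors barrier_max_size in the diff below


def demo_rank(rank: int, world_size: int, shared) -> None:
    buffers = shared(MAX_BYTES)         # List[int]: one pointer per rank
    barrier_in = shared(BARRIER_BYTES)
    barrier_out = shared(BARRIER_BYTES)

    # One opaque handle per process; it must be released with dispose().
    fa = ops.init_custom_ar(rank, world_size, buffers, barrier_in, barrier_out)

    inp = torch.ones(1024, dtype=torch.float16, device=f"cuda:{rank}")
    out = torch.empty_like(inp)
    ops.all_reduce(fa, inp, out)        # expect out == inp * world_size on every rank
    torch.cuda.synchronize()

    ops.dispose(fa)
```

In the PR itself this sequencing lives in the `CustomAllreduce` communicator, whose diff follows.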
@@ -21,15 +21,15 @@
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    ops.meta_size()
+    import sgl_kernel
+
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
     custom_ar = False

 logger = logging.getLogger(__name__)


 _P = ParamSpec("_P")
 _R = TypeVar("_R")

@@ -47,7 +47,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:


 @with_nvml_context
-def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+def is_full_nvlink(physical_device_ids: List[int]) -> bool:
     """
     query if the set of gpus are fully connected by nvlink (1 hop)
     """
@@ -196,32 +196,33 @@ def __init__(
             )
             return

-        self.disabled = False
-        # Buffers memory are owned by this Python class and passed to C++.
-        # Meta data composes of two parts: meta data for synchronization and a
-        # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(
-            ops.meta_size() + max_size, group=group
-        )
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs less than 10000 of registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = full_nvlink
+
+        # From TensorRT-LLM getMaxRequiredWorkspaceSize
+        self.max_required_workspace_size = [16 * 1000 * 1000, 8 * 1000 * 1000]
+
+        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+        self.barrier_max_size = 8 * (24 + 2) * 8
+
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.barrier_in_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+        self.barrier_out_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+
         self._ptr = ops.init_custom_ar(
-            self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+            rank,
+            world_size,
+            self.buffer_ptrs,
+            self.barrier_in_ptrs,
+            self.barrier_out_ptrs,
         )
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+        self.disabled = False

     @staticmethod
     def create_shared_buffer(
@@ -258,36 +259,11 @@ def free_shared_buffer(

     @contextmanager
     def capture(self):
-        """
-        The main responsibility of this context manager is the
-        `register_graph_buffers` call at the end of the context.
-        It records all the buffer addresses used in the CUDA graph.
-        """
         try:
             self._IS_CAPTURING = True
             yield
         finally:
             self._IS_CAPTURING = False
-            if not self.disabled:
-                self.register_graph_buffers()
-
-    def register_graph_buffers(self):
-        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-        logger.info("Registering %d cuda graph addresses", len(offset))
-        # We cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(
-                all_data[i], src=rank, group=self.group, device="cpu"
-            )
-        # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
-        ops.register_graph_buffers(self._ptr, handles, offsets)

     def should_custom_ar(self, inp: torch.Tensor):
         if self.disabled:
@@ -300,28 +276,19 @@ def should_custom_ar(self, inp: torch.Tensor):
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if self.world_size == 2 or self.full_nvlink:
-            return inp_size < self.max_size
-        return False
-
-    def all_reduce(
-        self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
-    ):
-        """Performs an out-of-place all reduce.
-
-        If registered is True, this assumes inp's pointer is already
-        IPC-registered. Otherwise, inp is first copied into a pre-registered
-        buffer.
-        """
-        if out is None:
-            out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
-        else:
-            ops.all_reduce(
-                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-            )
-        return out
+        if self.world_size == 2:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[0]
+            )
+
+        if self.full_nvlink:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[1]
+            )
+
+        return False

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
@@ -330,23 +297,25 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                output = torch.empty_like(input)
+                ops.all_reduce(self._ptr, input, output)
+                return output
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
             # Note: outside of cuda graph context, custom allreduce incurs a
             # cost of cudaMemcpy, which should be small (<=1% of overall
             # latency) compared to the performance gain of using custom kernels
-            return self.all_reduce(input, registered=False)
+            output = torch.empty_like(input)
+            ops.all_reduce(self._ptr, input, output)
+            return output

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
             self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.barrier_in_ptrs)
+            self.free_shared_buffer(self.barrier_out_ptrs)
+            self._ptr = 0

     def __del__(self):
         self.close()
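The new size gate in `should_custom_ar` is easy to sanity-check in isolation. The sketch below is not part of the diff; it mirrors the thresholds introduced above, where a tensor is eligible only if it fits the pre-allocated buffer and stays under the TensorRT-LLM workspace limit: 16 MB for a 2-rank group, 8 MB for a larger full-NVLink group. The 8 MiB `max_size` default is an assumption taken from the communicator's constructor.

```python
# Standalone illustration (not part of the diff) of the new dispatch rule in
# should_custom_ar. The constants come straight from the hunks above; the
# 8 MiB max_size default is an assumption.
MAX_REQUIRED_WORKSPACE_SIZE = [16 * 1000 * 1000, 8 * 1000 * 1000]  # 2 ranks / full NVLink
BARRIER_MAX_SIZE = 8 * (24 + 2) * 8  # bytes per barrier buffer


def use_custom_allreduce(
    inp_size: int, world_size: int, full_nvlink: bool, max_size: int = 8 * 1024 * 1024
) -> bool:
    if world_size == 2:
        return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[0]
    if full_nvlink:
        return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[1]
    return False


if __name__ == "__main__":
    size = 8_100_000  # bytes, roughly a 4M-element fp16 tensor
    print(use_custom_allreduce(size, world_size=2, full_nvlink=True))   # True: under 16 MB
    print(use_custom_allreduce(size, world_size=8, full_nvlink=True))   # False: over 8 MB
    print(use_custom_allreduce(size, world_size=4, full_nvlink=False))  # False: no NVLink
    print(BARRIER_MAX_SIZE)  # 1664
```

The last line also shows that each of the two barrier buffers allocated in `__init__` is only 8 * (24 + 2) * 8 = 1664 bytes per rank.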
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -12,6 +12,7 @@
         "sampling/penaltylib",
         "test_abort.py",
         "test_chunked_prefill.py",
+        "test_custom_allreduce.py",
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",