diff --git a/iris/drivers/__init__.py b/iris/drivers/__init__.py index 0520e18ab..a46501665 100644 --- a/iris/drivers/__init__.py +++ b/iris/drivers/__init__.py @@ -2,7 +2,7 @@ # Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. """ -Shared driver package types for fabric backends. +Shared driver package types for memory backends. """ from __future__ import annotations @@ -10,14 +10,14 @@ from dataclasses import dataclass from typing import Optional -from iris.drivers.base import BaseFabricDriver +from iris.drivers.base import BaseDriver __all__ = ["DriverStack"] @dataclass class DriverStack: - """Fabric drivers available for a rank.""" + """Driver available for a rank.""" vendor: str - fabric: Optional[BaseFabricDriver] + driver: Optional[BaseDriver] diff --git a/iris/drivers/base.py b/iris/drivers/base.py index 74c747824..4221d19e0 100644 --- a/iris/drivers/base.py +++ b/iris/drivers/base.py @@ -2,21 +2,21 @@ # Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. """ -Abstract base classes, shared dataclasses, and exceptions for fabric drivers. +Abstract base classes, shared dataclasses, and exceptions for memory drivers. """ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any +from typing import Any, Optional from iris.host.distributed.topology import InterconnectLevel __all__ = [ "PeerMapping", "LocalAllocation", - "BaseFabricDriver", + "BaseDriver", "DriverError", "DriverNotSupported", ] @@ -40,6 +40,7 @@ class LocalAllocation: va: int size: int handle: Any + _va_owned: bool = True class DriverError(RuntimeError): @@ -50,23 +51,39 @@ class DriverNotSupported(DriverError): """The current hardware or software stack does not support this driver.""" -class BaseFabricDriver(ABC): - """Cross-node fabric memory sharing (for example NVSwitch or xGMI).""" +class BaseDriver(ABC): + """Generic base class for local and fabric memory drivers.""" @abstractmethod def initialize(self, device_ordinal: int) -> None: """Prepare the driver for a specific local GPU.""" @abstractmethod - def allocate_exportable(self, size: int) -> LocalAllocation: - """Allocate memory that can be shared through the fabric transport.""" + def allocate_exportable( + self, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> LocalAllocation: + """Allocate exportable memory, optionally mapping it at a caller-reserved VA.""" @abstractmethod def export_handle(self, allocation: LocalAllocation) -> bytes: """Export a transport-specific handle for a local allocation.""" @abstractmethod - def import_and_map(self, peer_rank: int, handle_bytes: bytes, size: int) -> PeerMapping: + def import_and_map( + self, + peer_rank: int, + handle_bytes: bytes, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> PeerMapping: """Import a peer handle and map it into the local virtual address space.""" @abstractmethod @@ -76,3 +93,27 @@ def cleanup_import(self, mapping: PeerMapping) -> None: @abstractmethod def cleanup_local(self, allocation: LocalAllocation) -> None: """Release a locally-exported allocation.""" + + @abstractmethod + def get_minimum_granularity(self) -> int: + """Minimum allocation granularity in bytes for this driver+device.""" + + @abstractmethod + def reserve_va(self, size: int, alignment: int = 0) -> int: + """Reserve a virtual address range without backing physical memory.""" + 
+ @abstractmethod + def free_va(self, va: int, size: int) -> None: + """Free a VA range previously returned by reserve_va.""" + + def get_address_range(self, ptr: int) -> tuple[int, int]: + """Return the base VA and size of the allocation containing ptr.""" + raise DriverNotSupported( + f"{type(self).__name__} does not support get_address_range" + ) + + def export_pointer_handle(self, ptr: int, size: int) -> bytes: + """Export a peer handle for an arbitrary device pointer.""" + raise DriverNotSupported( + f"{type(self).__name__} does not support export_pointer_handle" + ) diff --git a/iris/drivers/fabric/amd.py b/iris/drivers/fabric/amd.py index 80fd275b9..a7af84497 100644 --- a/iris/drivers/fabric/amd.py +++ b/iris/drivers/fabric/amd.py @@ -7,26 +7,49 @@ from __future__ import annotations -from iris.drivers.base import BaseFabricDriver, DriverNotSupported, LocalAllocation, PeerMapping +from typing import Optional + +from iris.drivers.base import ( + BaseDriver, + DriverNotSupported, + LocalAllocation, + PeerMapping, +) __all__ = ["AmdFabricDriver"] _NOT_IMPLEMENTED_MESSAGE = "AMD fabric driver not yet implemented" -class AmdFabricDriver(BaseFabricDriver): +class AmdFabricDriver(BaseDriver): """AMD fabric driver placeholder.""" def initialize(self, device_ordinal: int) -> None: raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE) - def allocate_exportable(self, size: int) -> LocalAllocation: + def allocate_exportable( + self, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> LocalAllocation: raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE) def export_handle(self, allocation: LocalAllocation) -> bytes: raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE) - def import_and_map(self, peer_rank: int, handle_bytes: bytes, size: int) -> PeerMapping: + def import_and_map( + self, + peer_rank: int, + handle_bytes: bytes, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> PeerMapping: raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE) def cleanup_import(self, mapping: PeerMapping) -> None: @@ -34,3 +57,12 @@ def cleanup_import(self, mapping: PeerMapping) -> None: def cleanup_local(self, allocation: LocalAllocation) -> None: raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE) + + def get_minimum_granularity(self) -> int: + raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE) + + def reserve_va(self, size: int, alignment: int = 0) -> int: + raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE) + + def free_va(self, va: int, size: int) -> None: + raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE) diff --git a/iris/drivers/fabric/nvidia.py b/iris/drivers/fabric/nvidia.py index b6ae1901d..994135f38 100644 --- a/iris/drivers/fabric/nvidia.py +++ b/iris/drivers/fabric/nvidia.py @@ -9,12 +9,13 @@ import ctypes import logging +from collections.abc import Callable from typing import Any, Optional import torch from iris.drivers.base import ( - BaseFabricDriver, + BaseDriver, DriverError, DriverNotSupported, LocalAllocation, @@ -68,7 +69,10 @@ def _cuda_try(err: int, op_name: str = "CUDA operation") -> None: error_name = str(err) if _cuda_driver is not None and hasattr(_cuda_driver, "cuGetErrorName"): ptr = ctypes.c_char_p() - if _cuda_driver.cuGetErrorName(err, ctypes.byref(ptr)) == CUDA_SUCCESS and ptr.value: + if ( + _cuda_driver.cuGetErrorName(err, ctypes.byref(ptr)) == CUDA_SUCCESS + and ptr.value + ): error_name = ptr.value.decode("utf-8") message = f"{op_name} 
failed with {error_name} ({err})" if err == CUDA_ERROR_NOT_SUPPORTED: @@ -94,12 +98,18 @@ def _normalize_fabric_handle_bytes(raw_handle: Any) -> bytes: data = bytes(raw_handle) except Exception: try: - data = ctypes.string_at(ctypes.addressof(raw_handle), FABRIC_HANDLE_BYTES) + data = ctypes.string_at( + ctypes.addressof(raw_handle), FABRIC_HANDLE_BYTES + ) except Exception as exc: - raise CudaFabricError("Unable to convert fabric handle object to bytes") from exc + raise CudaFabricError( + "Unable to convert fabric handle object to bytes" + ) from exc if len(data) != FABRIC_HANDLE_BYTES: - raise CudaFabricError(f"Fabric handle serialization expected {FABRIC_HANDLE_BYTES} bytes, got {len(data)}") + raise CudaFabricError( + f"Fabric handle serialization expected {FABRIC_HANDLE_BYTES} bytes, got {len(data)}" + ) return data @@ -113,14 +123,16 @@ def _get_required_cuda_symbol(name: str) -> Any: return symbol -def _run_cleanup_steps(*steps) -> None: +def _run_cleanup_steps(*steps: tuple[str, Callable[[], None]]) -> None: first_error = None - for step in steps: + for name, step in steps: try: step() except Exception as exc: if first_error is None: first_error = exc + else: + logger.warning("Secondary cleanup step %s failed: %s", name, exc) if first_error is not None: raise first_error @@ -162,7 +174,9 @@ def _configure_cuda_signatures() -> None: cu_device_get = _get_required_cuda_symbol("cuDeviceGet") cu_device_primary_ctx_retain = _get_required_cuda_symbol("cuDevicePrimaryCtxRetain") cu_ctx_set_current = _get_required_cuda_symbol("cuCtxSetCurrent") - cu_mem_get_allocation_granularity = _get_required_cuda_symbol("cuMemGetAllocationGranularity") + cu_mem_get_allocation_granularity = _get_required_cuda_symbol( + "cuMemGetAllocationGranularity" + ) cu_mem_address_reserve = _get_required_cuda_symbol("cuMemAddressReserve") cu_mem_address_free = _get_required_cuda_symbol("cuMemAddressFree") cu_mem_create = _get_required_cuda_symbol("cuMemCreate") @@ -170,8 +184,12 @@ def _configure_cuda_signatures() -> None: cu_mem_map = _get_required_cuda_symbol("cuMemMap") cu_mem_unmap = _get_required_cuda_symbol("cuMemUnmap") cu_mem_set_access = _get_required_cuda_symbol("cuMemSetAccess") - cu_mem_export_to_shareable_handle = _get_required_cuda_symbol("cuMemExportToShareableHandle") - cu_mem_import_from_shareable_handle = _get_required_cuda_symbol("cuMemImportFromShareableHandle") + cu_mem_export_to_shareable_handle = _get_required_cuda_symbol( + "cuMemExportToShareableHandle" + ) + cu_mem_import_from_shareable_handle = _get_required_cuda_symbol( + "cuMemImportFromShareableHandle" + ) cu_init.argtypes = [ctypes.c_uint] cu_init.restype = ctypes.c_int @@ -262,7 +280,7 @@ def _configure_cuda_signatures() -> None: cu_get_error_name.restype = ctypes.c_int -class NvidiaFabricDriver(BaseFabricDriver): +class NvidiaFabricDriver(BaseDriver): """ NVIDIA CUDA VMM fabric driver. 
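A minimal usage sketch of the extended driver contract (illustrative only, not part of the patch): the caller reserves a VA window once with reserve_va(), places backing memory inside it via allocate_exportable(va=...), and frees the window itself afterwards, since _va_owned is False for caller-supplied VAs. The chunked allocator added later in this patch builds on the same sequence. The device ordinal and sizes below are assumed values.

# Illustrative sketch only; assumes the NvidiaFabricDriver above on device 0.
driver = NvidiaFabricDriver()
driver.initialize(device_ordinal=0)

granularity = driver.get_minimum_granularity()
heap_size = 64 * granularity
heap_va = driver.reserve_va(heap_size)       # address space only, no physical memory yet

# Back the first granule of the caller-owned window. Access defaults to the mapped
# range, and _va_owned is False because the VA came from the caller, not the driver.
alloc = driver.allocate_exportable(granularity, va=heap_va)
handle_bytes = driver.export_handle(alloc)   # transport-specific handle bytes for peers

driver.cleanup_local(alloc)                  # unmap + release; the VA window stays reserved
driver.free_va(heap_va, heap_size)           # caller frees its own reservation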
@@ -305,7 +323,10 @@ def _mem_set_access(self, va: int, size: int) -> None: desc.location.type = _CU_MEM_LOCATION_TYPE_DEVICE desc.location.id = self._device_ordinal desc.flags = _CU_MEM_ACCESS_FLAGS_PROT_READWRITE - _cuda_try(_cuda_driver.cuMemSetAccess(va, size, ctypes.byref(desc), 1), "cuMemSetAccess") + _cuda_try( + _cuda_driver.cuMemSetAccess(va, size, ctypes.byref(desc), 1), + "cuMemSetAccess", + ) def initialize(self, device_ordinal: int) -> None: if _cuda_driver is None: @@ -314,9 +335,14 @@ def initialize(self, device_ordinal: int) -> None: _configure_cuda_signatures() _cuda_try(_cuda_driver.cuInit(0), "cuInit") dev = ctypes.c_int() - _cuda_try(_cuda_driver.cuDeviceGet(ctypes.byref(dev), device_ordinal), "cuDeviceGet") + _cuda_try( + _cuda_driver.cuDeviceGet(ctypes.byref(dev), device_ordinal), "cuDeviceGet" + ) ctx = ctypes.c_void_p() - _cuda_try(_cuda_driver.cuDevicePrimaryCtxRetain(ctypes.byref(ctx), dev.value), "cuDevicePrimaryCtxRetain") + _cuda_try( + _cuda_driver.cuDevicePrimaryCtxRetain(ctypes.byref(ctx), dev.value), + "cuDevicePrimaryCtxRetain", + ) _cuda_try(_cuda_driver.cuCtxSetCurrent(ctx), "cuCtxSetCurrent") self._device_ordinal = device_ordinal self._granularity = None @@ -325,35 +351,67 @@ def initialize(self, device_ordinal: int) -> None: def _check_initialized(self) -> None: if not self._initialized: - raise CudaFabricError("NvidiaFabricDriver not initialized — call initialize() first") + raise CudaFabricError( + "NvidiaFabricDriver not initialized — call initialize() first" + ) - def allocate_exportable(self, size: int) -> LocalAllocation: + def allocate_exportable( + self, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> LocalAllocation: self._check_initialized() + if (access_va is None) != (access_size is None): + raise CudaFabricError("access_va and access_size must be provided together") props = self._make_alloc_props() granularity = self._get_granularity() alloc_size = _round_up(size, granularity) - va = ctypes.c_uint64() + reserved_va = va is None + mapped_va = int(va) if va is not None else 0 handle = ctypes.c_uint64() mapped = False try: + if reserved_va: + reserved = ctypes.c_uint64() + _cuda_try( + _cuda_driver.cuMemAddressReserve( + ctypes.byref(reserved), alloc_size, granularity, 0, 0 + ), + "cuMemAddressReserve", + ) + mapped_va = int(reserved.value) _cuda_try( - _cuda_driver.cuMemAddressReserve(ctypes.byref(va), alloc_size, granularity, 0, 0), - "cuMemAddressReserve", + _cuda_driver.cuMemCreate( + ctypes.byref(handle), alloc_size, ctypes.byref(props), 0 + ), + "cuMemCreate", ) _cuda_try( - _cuda_driver.cuMemCreate(ctypes.byref(handle), alloc_size, ctypes.byref(props), 0), - "cuMemCreate", + _cuda_driver.cuMemMap(mapped_va, alloc_size, 0, handle.value, 0), + "cuMemMap", ) - _cuda_try(_cuda_driver.cuMemMap(va.value, alloc_size, 0, handle.value, 0), "cuMemMap") mapped = True - self._mem_set_access(int(va.value), alloc_size) - return LocalAllocation(va=int(va.value), size=alloc_size, handle=int(handle.value)) + self._mem_set_access( + int(access_va) if access_va is not None else mapped_va, + int(access_size) if access_size is not None else alloc_size, + ) + return LocalAllocation( + va=mapped_va, + size=alloc_size, + handle=int(handle.value), + _va_owned=reserved_va, + ) except Exception: if mapped: try: - _cuda_try(_cuda_driver.cuMemUnmap(va.value, alloc_size), "cuMemUnmap") + _cuda_try( + _cuda_driver.cuMemUnmap(mapped_va, alloc_size), "cuMemUnmap" + ) except Exception: pass if 
handle.value: @@ -361,9 +419,12 @@ def allocate_exportable(self, size: int) -> LocalAllocation: _cuda_try(_cuda_driver.cuMemRelease(handle.value), "cuMemRelease") except Exception: pass - if va.value: + if reserved_va and mapped_va: try: - _cuda_try(_cuda_driver.cuMemAddressFree(va.value, alloc_size), "cuMemAddressFree") + _cuda_try( + _cuda_driver.cuMemAddressFree(mapped_va, alloc_size), + "cuMemAddressFree", + ) except Exception: pass raise @@ -396,59 +457,160 @@ def _import_handle(self, handle_bytes: bytes) -> int: ) return int(imported.value) - def import_and_map(self, peer_rank: int, handle_bytes: bytes, size: int) -> PeerMapping: + def import_and_map( + self, + peer_rank: int, + handle_bytes: bytes, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> PeerMapping: self._check_initialized() + if (access_va is None) != (access_size is None): + raise CudaFabricError("access_va and access_size must be provided together") imported_handle = self._import_handle(handle_bytes) granularity = self._get_granularity() - va = ctypes.c_uint64() - + va_owned = va is None + mapped_va = int(va) if va is not None else 0 mapped = False try: + if va_owned: + reserved = ctypes.c_uint64() + _cuda_try( + _cuda_driver.cuMemAddressReserve( + ctypes.byref(reserved), size, granularity, 0, 0 + ), + "cuMemAddressReserve", + ) + mapped_va = int(reserved.value) _cuda_try( - _cuda_driver.cuMemAddressReserve(ctypes.byref(va), size, granularity, 0, 0), - "cuMemAddressReserve", + _cuda_driver.cuMemMap(mapped_va, size, 0, imported_handle, 0), + "cuMemMap", ) - _cuda_try(_cuda_driver.cuMemMap(va.value, size, 0, imported_handle, 0), "cuMemMap") mapped = True - self._mem_set_access(int(va.value), size) + self._mem_set_access( + int(access_va) if access_va is not None else mapped_va, + int(access_size) if access_size is not None else size, + ) except Exception: if mapped: try: - _cuda_try(_cuda_driver.cuMemUnmap(va.value, size), "cuMemUnmap") + _cuda_try(_cuda_driver.cuMemUnmap(mapped_va, size), "cuMemUnmap") except Exception: pass try: _cuda_try(_cuda_driver.cuMemRelease(imported_handle), "cuMemRelease") except Exception: pass - if va.value: + if va_owned and mapped_va: try: - _cuda_try(_cuda_driver.cuMemAddressFree(va.value, size), "cuMemAddressFree") + _cuda_try( + _cuda_driver.cuMemAddressFree(mapped_va, size), + "cuMemAddressFree", + ) except Exception: pass raise + tag = "driver_va" if va_owned else "caller_va" return PeerMapping( peer_rank=peer_rank, transport=InterconnectLevel.INTRA_RACK_FABRIC, - remote_va=int(va.value), + remote_va=mapped_va, size=size, - _driver_handle=imported_handle, + _driver_handle=(tag, imported_handle), ) def cleanup_import(self, mapping: PeerMapping) -> None: self._check_initialized() - _run_cleanup_steps( - lambda: _cuda_try(_cuda_driver.cuMemUnmap(mapping.remote_va, mapping.size), "cuMemUnmap"), - lambda: _cuda_try(_cuda_driver.cuMemRelease(mapping._driver_handle), "cuMemRelease"), - lambda: _cuda_try(_cuda_driver.cuMemAddressFree(mapping.remote_va, mapping.size), "cuMemAddressFree"), - ) + if ( + isinstance(mapping._driver_handle, tuple) + and len(mapping._driver_handle) == 2 + ): + tag, imported_handle = mapping._driver_handle + else: + tag = "driver_va" + imported_handle = mapping._driver_handle + + steps: list[tuple[str, Callable[[], None]]] = [ + ( + "cuMemUnmap", + lambda: _cuda_try( + _cuda_driver.cuMemUnmap(mapping.remote_va, mapping.size), + "cuMemUnmap", + ), + ), + ( + "cuMemRelease", + lambda: _cuda_try( 
+ _cuda_driver.cuMemRelease(imported_handle), + "cuMemRelease", + ), + ), + ] + if tag == "driver_va": + steps.append( + ( + "cuMemAddressFree", + lambda: _cuda_try( + _cuda_driver.cuMemAddressFree(mapping.remote_va, mapping.size), + "cuMemAddressFree", + ), + ) + ) + _run_cleanup_steps(*steps) def cleanup_local(self, allocation: LocalAllocation) -> None: self._check_initialized() - _run_cleanup_steps( - lambda: _cuda_try(_cuda_driver.cuMemUnmap(allocation.va, allocation.size), "cuMemUnmap"), - lambda: _cuda_try(_cuda_driver.cuMemRelease(allocation.handle), "cuMemRelease"), - lambda: _cuda_try(_cuda_driver.cuMemAddressFree(allocation.va, allocation.size), "cuMemAddressFree"), + steps = [ + ( + "cuMemUnmap", + lambda: _cuda_try( + _cuda_driver.cuMemUnmap(allocation.va, allocation.size), + "cuMemUnmap", + ), + ), + ( + "cuMemRelease", + lambda: _cuda_try( + _cuda_driver.cuMemRelease(allocation.handle), + "cuMemRelease", + ), + ), + ] + if allocation._va_owned: + steps.append( + ( + "cuMemAddressFree", + lambda: _cuda_try( + _cuda_driver.cuMemAddressFree(allocation.va, allocation.size), + "cuMemAddressFree", + ), + ) + ) + _run_cleanup_steps(*steps) + + def get_minimum_granularity(self) -> int: + self._check_initialized() + return self._get_granularity() + + def reserve_va(self, size: int, alignment: int = 0) -> int: + self._check_initialized() + if alignment == 0: + alignment = self._get_granularity() + + reserved = ctypes.c_uint64() + _cuda_try( + _cuda_driver.cuMemAddressReserve( + ctypes.byref(reserved), size, alignment, 0, 0 + ), + "cuMemAddressReserve", ) + return int(reserved.value) + + def free_va(self, va: int, size: int) -> None: + self._check_initialized() + _cuda_try(_cuda_driver.cuMemAddressFree(va, size), "cuMemAddressFree") diff --git a/iris/drivers/factory.py b/iris/drivers/factory.py new file mode 100644 index 000000000..fc6110c4e --- /dev/null +++ b/iris/drivers/factory.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +"""Driver factory: vendor + interconnect -> BaseDriver.""" + +from __future__ import annotations + +from iris.drivers.base import BaseDriver, DriverNotSupported +from iris.drivers.fabric.amd import AmdFabricDriver +from iris.drivers.fabric.nvidia import NvidiaFabricDriver +from iris.drivers.local.amd import LocalHipDriver +from iris.drivers.local.nvidia import LocalCudaDriver +from iris.host.distributed.topology import InterconnectLevel + +__all__ = ["DriverFactory"] + + +class DriverFactory: + """Stateless factory for memory drivers.""" + + @staticmethod + def create_driver(vendor: str, interconnect: InterconnectLevel) -> BaseDriver: + v = vendor.strip().lower() + if v == "nvidia": + if interconnect == InterconnectLevel.INTRA_RACK_FABRIC: + return NvidiaFabricDriver() + if interconnect == InterconnectLevel.INTRA_NODE: + return LocalCudaDriver() + elif v == "amd": + if interconnect == InterconnectLevel.INTRA_RACK_FABRIC: + return AmdFabricDriver() + if interconnect == InterconnectLevel.INTRA_NODE: + return LocalHipDriver() + raise DriverNotSupported( + f"No driver for vendor={vendor!r}, interconnect={interconnect!r}" + ) diff --git a/iris/drivers/local/__init__.py b/iris/drivers/local/__init__.py new file mode 100644 index 000000000..d4b83ba8c --- /dev/null +++ b/iris/drivers/local/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. 
+ +"""Local (intra-node) memory drivers.""" + +from iris.drivers.local.amd import LocalHipDriver +from iris.drivers.local.nvidia import LocalCudaDriver + +__all__ = ["LocalHipDriver", "LocalCudaDriver"] diff --git a/iris/drivers/local/amd.py b/iris/drivers/local/amd.py new file mode 100644 index 000000000..93ee93731 --- /dev/null +++ b/iris/drivers/local/amd.py @@ -0,0 +1,785 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +"""AMD HIP local memory driver.""" + +from __future__ import annotations + +import ctypes +import logging +import os +import struct +from collections.abc import Callable +from typing import Any, Optional + +from iris.drivers.base import ( + BaseDriver, + DriverError, + DriverNotSupported, + LocalAllocation, + PeerMapping, +) +from iris.host.distributed.topology import InterconnectLevel + +logger = logging.getLogger("iris.drivers.local.amd") + +__all__ = [ + "LocalHipError", + "LocalHipNotSupported", + "LocalHipDriver", +] + +_hip = None +try: + _hip = ctypes.cdll.LoadLibrary("libamdhip64.so") +except OSError: + pass + +HIP_SUCCESS = 0 +HIP_ERROR_NOT_SUPPORTED = 801 + +hipMemAllocationTypePinned = 0x1 +hipMemHandleTypePosixFileDescriptor = 0x1 +hipMemLocationTypeDevice = 0x1 +hipMemAllocationGranularityRecommended = 0x1 +hipMemAccessFlagsProtReadWrite = 0x3 +hipExternalMemoryHandleTypeOpaqueFd = 1 + +hipMemGenericAllocationHandle_t = ctypes.c_void_p +hipExternalMemory_t = ctypes.c_void_p + +_AMD_HANDLE_FMT = "=iQQ" +_AMD_HANDLE_BYTES = struct.calcsize(_AMD_HANDLE_FMT) + + +class LocalHipError(DriverError): + """HIP local-memory operation failed.""" + + +class LocalHipNotSupported(DriverNotSupported): + """The local HIP stack does not support this driver.""" + + +class hipMemLocation(ctypes.Structure): + """Structure describing a HIP memory location.""" + + _fields_ = [ + ("type", ctypes.c_int), + ("id", ctypes.c_int), + ] + + +class hipMemAllocationProp(ctypes.Structure): + """Properties for a HIP VMem allocation.""" + + class _allocFlags(ctypes.Structure): + _fields_ = [ + ("smc", ctypes.c_ubyte), + ("l2", ctypes.c_ubyte), + ] + + _fields_ = [ + ("type", ctypes.c_int), + ("requestedHandleType", ctypes.c_int), + ("location", hipMemLocation), + ("win32Handle", ctypes.c_void_p), + ("allocFlags", _allocFlags), + ] + + +class hipMemAccessDesc(ctypes.Structure): + """Access descriptor for a HIP VMem mapping.""" + + _fields_ = [ + ("location", hipMemLocation), + ("flags", ctypes.c_int), + ] + + +class hipExternalMemoryHandleDesc(ctypes.Structure): + """Descriptor for importing HIP external memory from a DMA-BUF FD.""" + + class HandleUnion(ctypes.Union): + _fields_ = [ + ("fd", ctypes.c_int), + ("win32", ctypes.c_void_p * 2), + ] + + _fields_ = [ + ("type", ctypes.c_int), + ("_pad", ctypes.c_int), + ("handle", HandleUnion), + ("size", ctypes.c_ulonglong), + ("flags", ctypes.c_uint), + ("_pad2", ctypes.c_uint), + ("reserved", ctypes.c_uint * 16), + ] + + +class hipExternalMemoryBufferDesc(ctypes.Structure): + """Descriptor for mapping an imported HIP external-memory buffer.""" + + _fields_ = [ + ("offset", ctypes.c_ulonglong), + ("size", ctypes.c_ulonglong), + ("flags", ctypes.c_uint), + ("reserved", ctypes.c_uint * 16), + ] + + +def _get_required_hip_symbol(name: str) -> Any: + if _hip is None: + raise LocalHipNotSupported("libamdhip64.so not found") + + symbol = getattr(_hip, name, None) + if symbol is None: + raise LocalHipNotSupported(f"HIP runtime missing required symbol: {name}") + return symbol + + +def 
_configure_signatures() -> None: + """Configure ctypes signatures for all HIP functions used by this driver.""" + if _hip is None: + return + + hip_set_device = _get_required_hip_symbol("hipSetDevice") + hip_mem_get_allocation_granularity = _get_required_hip_symbol( + "hipMemGetAllocationGranularity" + ) + hip_mem_create = _get_required_hip_symbol("hipMemCreate") + hip_mem_address_reserve = _get_required_hip_symbol("hipMemAddressReserve") + hip_mem_map = _get_required_hip_symbol("hipMemMap") + hip_mem_set_access = _get_required_hip_symbol("hipMemSetAccess") + hip_mem_unmap = _get_required_hip_symbol("hipMemUnmap") + hip_mem_release = _get_required_hip_symbol("hipMemRelease") + hip_mem_address_free = _get_required_hip_symbol("hipMemAddressFree") + hip_mem_get_address_range = _get_required_hip_symbol("hipMemGetAddressRange") + hip_mem_get_handle_for_address_range = _get_required_hip_symbol( + "hipMemGetHandleForAddressRange" + ) + hip_mem_import_from_shareable_handle = _get_required_hip_symbol( + "hipMemImportFromShareableHandle" + ) + hip_import_external_memory = _get_required_hip_symbol("hipImportExternalMemory") + hip_external_memory_get_mapped_buffer = _get_required_hip_symbol( + "hipExternalMemoryGetMappedBuffer" + ) + hip_destroy_external_memory = _get_required_hip_symbol("hipDestroyExternalMemory") + hip_get_error_string = _get_required_hip_symbol("hipGetErrorString") + + hip_set_device.argtypes = [ctypes.c_int] + hip_set_device.restype = ctypes.c_int + + hip_mem_get_allocation_granularity.argtypes = [ + ctypes.POINTER(ctypes.c_size_t), + ctypes.POINTER(hipMemAllocationProp), + ctypes.c_int, + ] + hip_mem_get_allocation_granularity.restype = ctypes.c_int + + hip_mem_create.argtypes = [ + ctypes.POINTER(hipMemGenericAllocationHandle_t), + ctypes.c_size_t, + ctypes.POINTER(hipMemAllocationProp), + ctypes.c_ulonglong, + ] + hip_mem_create.restype = ctypes.c_int + + hip_mem_address_reserve.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_void_p, + ctypes.c_ulonglong, + ] + hip_mem_address_reserve.restype = ctypes.c_int + + hip_mem_map.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_size_t, + hipMemGenericAllocationHandle_t, + ctypes.c_ulonglong, + ] + hip_mem_map.restype = ctypes.c_int + + hip_mem_set_access.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.POINTER(hipMemAccessDesc), + ctypes.c_size_t, + ] + hip_mem_set_access.restype = ctypes.c_int + + hip_mem_unmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t] + hip_mem_unmap.restype = ctypes.c_int + + hip_mem_release.argtypes = [hipMemGenericAllocationHandle_t] + hip_mem_release.restype = ctypes.c_int + + hip_mem_address_free.argtypes = [ctypes.c_void_p, ctypes.c_size_t] + hip_mem_address_free.restype = ctypes.c_int + + hip_mem_get_address_range.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.POINTER(ctypes.c_size_t), + ctypes.c_void_p, + ] + hip_mem_get_address_range.restype = ctypes.c_int + + hip_mem_get_handle_for_address_range.argtypes = [ + ctypes.POINTER(ctypes.c_int), + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_ulonglong, + ] + hip_mem_get_handle_for_address_range.restype = ctypes.c_int + + hip_mem_import_from_shareable_handle.argtypes = [ + ctypes.POINTER(hipMemGenericAllocationHandle_t), + ctypes.c_void_p, + ctypes.c_int, + ] + hip_mem_import_from_shareable_handle.restype = ctypes.c_int + + hip_import_external_memory.argtypes = [ + ctypes.POINTER(hipExternalMemory_t), + ctypes.POINTER(hipExternalMemoryHandleDesc), + ] + 
hip_import_external_memory.restype = ctypes.c_int + + hip_external_memory_get_mapped_buffer.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + hipExternalMemory_t, + ctypes.POINTER(hipExternalMemoryBufferDesc), + ] + hip_external_memory_get_mapped_buffer.restype = ctypes.c_int + + hip_destroy_external_memory.argtypes = [hipExternalMemory_t] + hip_destroy_external_memory.restype = ctypes.c_int + + hip_get_error_string.argtypes = [ctypes.c_int] + hip_get_error_string.restype = ctypes.c_char_p + + +def _hip_try(err: int, op_name: str = "HIP operation") -> None: + """Check a HIP runtime return code and raise a driver exception on error.""" + if err == HIP_SUCCESS: + return + + error_string = str(err) + if _hip is not None and hasattr(_hip, "hipGetErrorString"): + decoded = _hip.hipGetErrorString(ctypes.c_int(err)) + if decoded: + error_string = decoded.decode("utf-8", errors="replace") + + message = f"{op_name} failed with HIP error code {err}: {error_string}" + if err == HIP_ERROR_NOT_SUPPORTED: + raise LocalHipNotSupported(message) + raise LocalHipError(message) + + +def _round_up(value: int, granularity: int) -> int: + if granularity <= 0: + raise ValueError(f"granularity must be > 0, got {granularity}") + return ((value + granularity - 1) // granularity) * granularity + + +def _run_cleanup_steps(*steps: tuple[str, Callable[[], None]]) -> None: + first_error = None + for name, step in steps: + try: + step() + except Exception as exc: + if first_error is None: + first_error = exc + else: + logger.warning("Secondary cleanup step %s failed: %s", name, exc) + if first_error is not None: + raise first_error + + +def _cleanup_after_failure(*steps: tuple[str, Callable[[], None]]) -> None: + for name, step in steps: + try: + step() + except Exception as exc: + logger.warning( + "Cleanup step %s failed after earlier failure: %s", name, exc + ) + + +class LocalHipDriver(BaseDriver): + """ + AMD HIP VMem local driver using DMA-BUF handles for peer import/export. + + hipSetDevice is per-thread; use each driver instance from the thread that + called initialize(). 
+ """ + + def __init__(self) -> None: + self._device_ordinal: int = 0 + self._granularity: Optional[int] = None + self._initialized: bool = False + + def _check_initialized(self) -> None: + if not self._initialized: + raise LocalHipError( + "LocalHipDriver not initialized - call initialize() first" + ) + + def _make_alloc_props(self) -> hipMemAllocationProp: + props = hipMemAllocationProp() + props.type = hipMemAllocationTypePinned + props.requestedHandleType = hipMemHandleTypePosixFileDescriptor + props.location.type = hipMemLocationTypeDevice + props.location.id = self._device_ordinal + props.win32Handle = None + return props + + def _get_granularity(self) -> int: + if self._granularity is not None: + return self._granularity + + props = self._make_alloc_props() + granularity = ctypes.c_size_t() + _hip_try( + _hip.hipMemGetAllocationGranularity( + ctypes.byref(granularity), + ctypes.byref(props), + hipMemAllocationGranularityRecommended, + ), + "hipMemGetAllocationGranularity", + ) + self._granularity = int(granularity.value) + return self._granularity + + def _mem_set_access(self, va: int, size: int) -> None: + desc = hipMemAccessDesc() + desc.location.type = hipMemLocationTypeDevice + desc.location.id = self._device_ordinal + desc.flags = hipMemAccessFlagsProtReadWrite + _hip_try( + _hip.hipMemSetAccess(ctypes.c_void_p(va), size, ctypes.byref(desc), 1), + "hipMemSetAccess", + ) + + def initialize(self, device_ordinal: int) -> None: + """Prepare the HIP runtime and bind this driver instance to one GPU.""" + if _hip is None: + raise LocalHipNotSupported("libamdhip64.so not found") + + _configure_signatures() + _hip_try(_hip.hipSetDevice(device_ordinal), "hipSetDevice") + self._device_ordinal = device_ordinal + self._granularity = None + self._initialized = True + logger.info("LocalHipDriver initialized (device %d)", device_ordinal) + + def allocate_exportable( + self, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> LocalAllocation: + """ + Allocate HIP VMem exportable as a DMA-BUF. + + If va is supplied, the caller must already own a sufficiently large, + granularity-aligned VA range containing [va, va + size). 
+ """ + self._check_initialized() + if (access_va is None) != (access_size is None): + raise LocalHipError("access_va and access_size must be provided together") + props = self._make_alloc_props() + granularity = self._get_granularity() + alloc_size = _round_up(size, granularity) + + reserved_va = va is None + mapped_va = int(va) if va is not None else 0 + handle = hipMemGenericAllocationHandle_t() + mapped = False + + try: + if reserved_va: + reserved = ctypes.c_void_p() + _hip_try( + _hip.hipMemAddressReserve( + ctypes.byref(reserved), alloc_size, granularity, None, 0 + ), + "hipMemAddressReserve", + ) + mapped_va = int(reserved.value) + + _hip_try( + _hip.hipMemCreate( + ctypes.byref(handle), alloc_size, ctypes.byref(props), 0 + ), + "hipMemCreate", + ) + _hip_try( + _hip.hipMemMap(ctypes.c_void_p(mapped_va), alloc_size, 0, handle, 0), + "hipMemMap", + ) + mapped = True + self._mem_set_access( + int(access_va) if access_va is not None else mapped_va, + int(access_size) if access_size is not None else alloc_size, + ) + return LocalAllocation( + va=mapped_va, + size=alloc_size, + handle=int(handle.value), + _va_owned=reserved_va, + ) + except Exception: + steps: list[tuple[str, Callable[[], None]]] = [] + if mapped: + steps.append( + ( + "hipMemUnmap", + lambda: _hip_try( + _hip.hipMemUnmap(ctypes.c_void_p(mapped_va), alloc_size), + "hipMemUnmap", + ), + ) + ) + if handle.value: + steps.append( + ( + "hipMemRelease", + lambda: _hip_try( + _hip.hipMemRelease( + hipMemGenericAllocationHandle_t(handle.value) + ), + "hipMemRelease", + ), + ) + ) + if reserved_va and mapped_va: + steps.append( + ( + "hipMemAddressFree", + lambda: _hip_try( + _hip.hipMemAddressFree( + ctypes.c_void_p(mapped_va), alloc_size + ), + "hipMemAddressFree", + ), + ) + ) + _cleanup_after_failure(*steps) + raise + + def _export_range(self, va: int, size: int) -> bytes: + base_ptr = ctypes.c_void_p() + base_size = ctypes.c_size_t() + allocation_ptr = ctypes.c_void_p(int(va)) + _hip_try( + _hip.hipMemGetAddressRange( + ctypes.byref(base_ptr), ctypes.byref(base_size), allocation_ptr + ), + "hipMemGetAddressRange", + ) + if base_ptr.value is None: + raise LocalHipError("hipMemGetAddressRange returned a null base pointer") + + fd = ctypes.c_int(-1) + _hip_try( + _hip.hipMemGetHandleForAddressRange( + ctypes.byref(fd), allocation_ptr, size, 1, 0 + ), + "hipMemGetHandleForAddressRange", + ) + fd_value = int(fd.value) + + try: + base_va = int(base_ptr.value) + base_size_value = int(base_size.value) + offset = int(va) - base_va + if offset < 0 or offset + size > base_size_value: + raise LocalHipError( + f"Allocation range va={va} size={size} exceeds base range va={base_va} size={base_size_value}" + ) + + return struct.pack(_AMD_HANDLE_FMT, fd_value, offset, base_size_value) + except Exception: + try: + os.close(fd_value) + except OSError: + pass + raise + + def export_handle(self, allocation: LocalAllocation) -> bytes: + """Export a 20-byte DMA-BUF descriptor for a local HIP allocation.""" + self._check_initialized() + return self._export_range(allocation.va, allocation.size) + + def import_and_map( + self, + peer_rank: int, + handle_bytes: bytes, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> PeerMapping: + """Import a DMA-BUF descriptor and map it into local GPU address space.""" + self._check_initialized() + if (access_va is None) != (access_size is None): + raise LocalHipError("access_va and access_size must be provided together") + if 
len(handle_bytes) != _AMD_HANDLE_BYTES: + raise LocalHipError( + f"AMD local handle must be {_AMD_HANDLE_BYTES} bytes, got {len(handle_bytes)}" + ) + + fd, offset, base_size = struct.unpack(_AMD_HANDLE_FMT, handle_bytes) + if size > base_size - offset: + raise LocalHipError( + f"Requested map size {size} exceeds imported base range {base_size} at offset {offset}" + ) + + if va is not None: + mapped_va = int(va) + imported_handle = hipMemGenericAllocationHandle_t() + mapped = False + fd_open = True + try: + _hip_try( + _hip.hipMemImportFromShareableHandle( + ctypes.byref(imported_handle), + ctypes.c_void_p(fd), + hipMemHandleTypePosixFileDescriptor, + ), + "hipMemImportFromShareableHandle", + ) + os.close(fd) + fd_open = False + + _hip_try( + _hip.hipMemMap( + ctypes.c_void_p(mapped_va), size, offset, imported_handle, 0 + ), + "hipMemMap", + ) + mapped = True + self._mem_set_access( + int(access_va) if access_va is not None else mapped_va, + int(access_size) if access_size is not None else size, + ) + return PeerMapping( + peer_rank=peer_rank, + transport=InterconnectLevel.INTRA_NODE, + remote_va=mapped_va, + size=size, + _driver_handle=("vmm", int(imported_handle.value)), + ) + except Exception: + steps: list[tuple[str, Callable[[], None]]] = [] + if mapped: + steps.append( + ( + "hipMemUnmap", + lambda: _hip_try( + _hip.hipMemUnmap(ctypes.c_void_p(mapped_va), size), + "hipMemUnmap", + ), + ) + ) + if imported_handle.value: + steps.append( + ( + "hipMemRelease", + lambda: _hip_try( + _hip.hipMemRelease(imported_handle), + "hipMemRelease", + ), + ) + ) + if fd_open: + steps.append(("os.close", lambda: os.close(fd))) + _cleanup_after_failure(*steps) + raise + + mem_handle_desc = hipExternalMemoryHandleDesc() + mem_handle_desc.type = hipExternalMemoryHandleTypeOpaqueFd + mem_handle_desc.handle.fd = fd + mem_handle_desc.size = base_size + mem_handle_desc.flags = 0 + + ext_mem = hipExternalMemory_t() + try: + # ROCm 7.1+ external memory import is preferred over + # hipMemImportFromShareableHandle to avoid the ROCm 7.0 MemObjMap + # segfault path for imported memory objects. 
+ err = _hip.hipImportExternalMemory( + ctypes.byref(ext_mem), ctypes.byref(mem_handle_desc) + ) + if err != HIP_SUCCESS: + try: + os.close(fd) + except OSError: + pass + _hip_try(err, "hipImportExternalMemory") + + buffer_desc = hipExternalMemoryBufferDesc() + buffer_desc.offset = 0 + buffer_desc.size = base_size + buffer_desc.flags = 0 + + mapped_base = ctypes.c_void_p() + _hip_try( + _hip.hipExternalMemoryGetMappedBuffer( + ctypes.byref(mapped_base), ext_mem, ctypes.byref(buffer_desc) + ), + "hipExternalMemoryGetMappedBuffer", + ) + if mapped_base.value is None: + raise LocalHipError( + "hipExternalMemoryGetMappedBuffer returned a null pointer" + ) + + remote_va = int(mapped_base.value) + int(offset) + return PeerMapping( + peer_rank=peer_rank, + transport=InterconnectLevel.INTRA_NODE, + remote_va=remote_va, + size=size, + _driver_handle=(ext_mem, base_size), + ) + except Exception: + if ext_mem.value: + _cleanup_after_failure( + ( + "hipDestroyExternalMemory", + lambda: _hip_try( + _hip.hipDestroyExternalMemory(ext_mem), + "hipDestroyExternalMemory", + ), + ) + ) + raise + + def cleanup_import(self, mapping: PeerMapping) -> None: + """Release an imported HIP external-memory mapping.""" + self._check_initialized() + if ( + isinstance(mapping._driver_handle, tuple) + and len(mapping._driver_handle) == 2 + and mapping._driver_handle[0] == "vmm" + ): + imported_handle = hipMemGenericAllocationHandle_t(mapping._driver_handle[1]) + _run_cleanup_steps( + ( + "hipMemUnmap", + lambda: _hip_try( + _hip.hipMemUnmap( + ctypes.c_void_p(mapping.remote_va), mapping.size + ), + "hipMemUnmap", + ), + ), + ( + "hipMemRelease", + lambda: _hip_try( + _hip.hipMemRelease(imported_handle), + "hipMemRelease", + ), + ), + ) + return + + ext_mem, _base_size = mapping._driver_handle + try: + _hip_try(_hip.hipDestroyExternalMemory(ext_mem), "hipDestroyExternalMemory") + except Exception: + logger.warning( + "hipDestroyExternalMemory failed during import cleanup", exc_info=True + ) + raise + + def cleanup_local(self, allocation: LocalAllocation) -> None: + """Unmap, release, and free a local HIP VMem allocation.""" + self._check_initialized() + steps = [ + ( + "hipMemUnmap", + lambda: _hip_try( + _hip.hipMemUnmap(ctypes.c_void_p(allocation.va), allocation.size), + "hipMemUnmap", + ), + ), + ( + "hipMemRelease", + lambda: _hip_try( + _hip.hipMemRelease( + hipMemGenericAllocationHandle_t(int(allocation.handle)) + ), + "hipMemRelease", + ), + ), + ] + if allocation._va_owned: + steps.append( + ( + "hipMemAddressFree", + lambda: _hip_try( + _hip.hipMemAddressFree( + ctypes.c_void_p(allocation.va), allocation.size + ), + "hipMemAddressFree", + ), + ) + ) + _run_cleanup_steps(*steps) + + def get_minimum_granularity(self) -> int: + """Return the HIP VMem allocation granularity for this device.""" + self._check_initialized() + return self._get_granularity() + + def reserve_va(self, size: int, alignment: int = 0) -> int: + """Reserve a HIP virtual address range without backing memory.""" + self._check_initialized() + if alignment == 0: + alignment = self._get_granularity() + + reserved = ctypes.c_void_p() + _hip_try( + _hip.hipMemAddressReserve(ctypes.byref(reserved), size, alignment, None, 0), + "hipMemAddressReserve", + ) + if reserved.value is None: + raise LocalHipError("hipMemAddressReserve returned a null VA") + return int(reserved.value) + + def free_va(self, va: int, size: int) -> None: + """Free a HIP VA range previously returned by reserve_va.""" + self._check_initialized() + 
_hip_try(_hip.hipMemAddressFree(ctypes.c_void_p(va), size), "hipMemAddressFree") + + def get_address_range(self, ptr: int) -> tuple[int, int]: + """Return the base allocation range containing a HIP device pointer.""" + self._check_initialized() + base_ptr = ctypes.c_void_p() + base_size = ctypes.c_size_t() + _hip_try( + _hip.hipMemGetAddressRange( + ctypes.byref(base_ptr), + ctypes.byref(base_size), + ctypes.c_void_p(int(ptr)), + ), + "hipMemGetAddressRange", + ) + if base_ptr.value is None: + raise LocalHipError("hipMemGetAddressRange returned a null base pointer") + return int(base_ptr.value), int(base_size.value) + + def export_pointer_handle(self, ptr: int, size: int) -> bytes: + """Export a DMA-BUF descriptor for an arbitrary HIP device pointer range.""" + self._check_initialized() + return self._export_range(int(ptr), int(size)) diff --git a/iris/drivers/local/nvidia.py b/iris/drivers/local/nvidia.py new file mode 100644 index 000000000..d64ae0613 --- /dev/null +++ b/iris/drivers/local/nvidia.py @@ -0,0 +1,659 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +"""NVIDIA CUDA driver-API local memory driver.""" + +from __future__ import annotations + +import ctypes +import logging +import os +import struct +from collections.abc import Callable +from typing import Any, Optional + +from iris.drivers.base import ( + BaseDriver, + DriverError, + DriverNotSupported, + LocalAllocation, + PeerMapping, +) +from iris.host.distributed.topology import InterconnectLevel + +logger = logging.getLogger("iris.drivers.local.nvidia") + +__all__ = [ + "LocalCudaError", + "LocalCudaNotSupported", + "LocalCudaDriver", +] + +_cuda_driver = None +try: + _cuda_driver = ctypes.CDLL("libcuda.so.1") +except OSError: + try: + _cuda_driver = ctypes.CDLL("libcuda.so") + except OSError: + pass + +CUDA_SUCCESS = 0 +CUDA_ERROR_NOT_SUPPORTED = 801 + +_CUDA_HANDLE_FMT = "=i" +_CUDA_HANDLE_BYTES = struct.calcsize(_CUDA_HANDLE_FMT) + +_CU_MEM_ALLOCATION_TYPE_PINNED = 1 +_CU_MEM_LOCATION_TYPE_DEVICE = 1 +_CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = 0x1 +_CU_MEM_ALLOC_GRANULARITY_MINIMUM = 0 +_CU_MEM_ACCESS_FLAGS_PROT_READWRITE = 0x3 + + +class LocalCudaError(DriverError): + """CUDA local VMM operation failed.""" + + +class LocalCudaNotSupported(DriverNotSupported): + """The local CUDA driver stack does not support this driver.""" + + +class _MemLocation(ctypes.Structure): + _fields_ = [("type", ctypes.c_int), ("id", ctypes.c_int)] + + +class _MemAllocationFlags(ctypes.Structure): + _fields_ = [ + ("compressionType", ctypes.c_ubyte), + ("gpuDirectRDMACapable", ctypes.c_ubyte), + ("usage", ctypes.c_ushort), + ("reserved", ctypes.c_ubyte * 4), + ] + + +class _MemAllocationProp(ctypes.Structure): + _fields_ = [ + ("type", ctypes.c_int), + ("requestedHandleTypes", ctypes.c_int), + ("location", _MemLocation), + ("win32HandleMetaData", ctypes.c_void_p), + ("allocFlags", _MemAllocationFlags), + ] + + +class _MemAccessDesc(ctypes.Structure): + _fields_ = [("location", _MemLocation), ("flags", ctypes.c_ulonglong)] + + +def _get_required_cuda_symbol(name: str) -> Any: + if _cuda_driver is None: + raise LocalCudaNotSupported("CUDA driver library (libcuda.so) not found") + + symbol = getattr(_cuda_driver, name, None) + if symbol is None: + raise LocalCudaNotSupported(f"CUDA driver missing required VMM symbol: {name}") + return symbol + + +def _configure_signatures() -> None: + """Configure ctypes signatures for all CUDA driver functions used here.""" + if _cuda_driver is None: + 
return + + cu_init = _get_required_cuda_symbol("cuInit") + cu_device_get = _get_required_cuda_symbol("cuDeviceGet") + cu_device_primary_ctx_retain = _get_required_cuda_symbol("cuDevicePrimaryCtxRetain") + cu_ctx_set_current = _get_required_cuda_symbol("cuCtxSetCurrent") + cu_mem_get_allocation_granularity = _get_required_cuda_symbol( + "cuMemGetAllocationGranularity" + ) + cu_mem_address_reserve = _get_required_cuda_symbol("cuMemAddressReserve") + cu_mem_address_free = _get_required_cuda_symbol("cuMemAddressFree") + cu_mem_create = _get_required_cuda_symbol("cuMemCreate") + cu_mem_release = _get_required_cuda_symbol("cuMemRelease") + cu_mem_map = _get_required_cuda_symbol("cuMemMap") + cu_mem_unmap = _get_required_cuda_symbol("cuMemUnmap") + cu_mem_set_access = _get_required_cuda_symbol("cuMemSetAccess") + cu_mem_export_to_shareable_handle = _get_required_cuda_symbol( + "cuMemExportToShareableHandle" + ) + cu_mem_import_from_shareable_handle = _get_required_cuda_symbol( + "cuMemImportFromShareableHandle" + ) + + cu_init.argtypes = [ctypes.c_uint] + cu_init.restype = ctypes.c_int + + cu_device_get.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int] + cu_device_get.restype = ctypes.c_int + + cu_device_primary_ctx_retain.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_int, + ] + cu_device_primary_ctx_retain.restype = ctypes.c_int + + cu_ctx_set_current.argtypes = [ctypes.c_void_p] + cu_ctx_set_current.restype = ctypes.c_int + + cu_mem_get_allocation_granularity.argtypes = [ + ctypes.POINTER(ctypes.c_size_t), + ctypes.POINTER(_MemAllocationProp), + ctypes.c_int, + ] + cu_mem_get_allocation_granularity.restype = ctypes.c_int + + cu_mem_address_reserve.argtypes = [ + ctypes.POINTER(ctypes.c_uint64), + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_uint64, + ctypes.c_ulonglong, + ] + cu_mem_address_reserve.restype = ctypes.c_int + + cu_mem_address_free.argtypes = [ctypes.c_uint64, ctypes.c_size_t] + cu_mem_address_free.restype = ctypes.c_int + + cu_mem_create.argtypes = [ + ctypes.POINTER(ctypes.c_uint64), + ctypes.c_size_t, + ctypes.POINTER(_MemAllocationProp), + ctypes.c_ulonglong, + ] + cu_mem_create.restype = ctypes.c_int + + cu_mem_release.argtypes = [ctypes.c_uint64] + cu_mem_release.restype = ctypes.c_int + + cu_mem_map.argtypes = [ + ctypes.c_uint64, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_uint64, + ctypes.c_ulonglong, + ] + cu_mem_map.restype = ctypes.c_int + + cu_mem_unmap.argtypes = [ctypes.c_uint64, ctypes.c_size_t] + cu_mem_unmap.restype = ctypes.c_int + + cu_mem_set_access.argtypes = [ + ctypes.c_uint64, + ctypes.c_size_t, + ctypes.POINTER(_MemAccessDesc), + ctypes.c_size_t, + ] + cu_mem_set_access.restype = ctypes.c_int + + cu_mem_export_to_shareable_handle.argtypes = [ + ctypes.c_void_p, + ctypes.c_uint64, + ctypes.c_int, + ctypes.c_ulonglong, + ] + cu_mem_export_to_shareable_handle.restype = ctypes.c_int + + cu_mem_import_from_shareable_handle.argtypes = [ + ctypes.POINTER(ctypes.c_uint64), + ctypes.c_void_p, + ctypes.c_int, + ] + cu_mem_import_from_shareable_handle.restype = ctypes.c_int + + cu_get_error_name = getattr(_cuda_driver, "cuGetErrorName", None) + if cu_get_error_name is not None: + cu_get_error_name.argtypes = [ + ctypes.c_int, + ctypes.POINTER(ctypes.c_char_p), + ] + cu_get_error_name.restype = ctypes.c_int + + +def _cuda_try(err: int, op_name: str = "CUDA operation") -> None: + """Check a CUDA driver return code and raise a driver exception on error.""" + if err == CUDA_SUCCESS: + return + + error_name = str(err) + if _cuda_driver is not 
None and hasattr(_cuda_driver, "cuGetErrorName"): + ptr = ctypes.c_char_p() + if ( + _cuda_driver.cuGetErrorName(err, ctypes.byref(ptr)) == CUDA_SUCCESS + and ptr.value + ): + error_name = ptr.value.decode("utf-8") + + message = f"{op_name} failed with {error_name} ({err})" + if err == CUDA_ERROR_NOT_SUPPORTED: + raise LocalCudaNotSupported(message) + raise LocalCudaError(message) + + +def _round_up(value: int, granularity: int) -> int: + if granularity <= 0: + raise ValueError(f"granularity must be > 0, got {granularity}") + return ((value + granularity - 1) // granularity) * granularity + + +def _normalize_handle_bytes(raw_handle: bytes) -> bytes: + if isinstance(raw_handle, memoryview): + data = raw_handle.tobytes() + elif isinstance(raw_handle, (bytes, bytearray)): + data = bytes(raw_handle) + else: + try: + data = bytes(raw_handle) + except Exception as exc: + raise LocalCudaError( + "Unable to convert POSIX-FD handle object to bytes" + ) from exc + + if len(data) != _CUDA_HANDLE_BYTES: + raise LocalCudaError( + f"CUDA POSIX-FD handle must be {_CUDA_HANDLE_BYTES} bytes, got {len(data)}" + ) + return data + + +def _run_cleanup_steps(*steps: tuple[str, Callable[[], None]]) -> None: + first_error = None + for name, step in steps: + try: + step() + except Exception as exc: + if first_error is None: + first_error = exc + else: + logger.warning("Secondary cleanup step %s failed: %s", name, exc) + if first_error is not None: + raise first_error + + +def _cleanup_after_failure(*steps: tuple[str, Callable[[], None]]) -> None: + for name, step in steps: + try: + step() + except Exception as exc: + logger.warning( + "Cleanup step %s failed after earlier failure: %s", name, exc + ) + + +class LocalCudaDriver(BaseDriver): + """ + NVIDIA CUDA driver-API VMM local driver. + + This driver uses libcuda.so, not the CUDA runtime API. Exported handles are + POSIX file descriptors encoded as bytes; the caller is responsible for + delivering the FD across processes, for example with SCM_RIGHTS. POSIX-FD + handles require source and destination processes to share the same OS + namespace, typically on the same machine. 
+ """ + + def __init__(self) -> None: + self._device_ordinal: int = 0 + self._granularity: Optional[int] = None + self._initialized: bool = False + + def _check_initialized(self) -> None: + if not self._initialized: + raise LocalCudaError( + "LocalCudaDriver not initialized - call initialize() first" + ) + + def _make_alloc_props(self) -> _MemAllocationProp: + props = _MemAllocationProp() + props.type = _CU_MEM_ALLOCATION_TYPE_PINNED + props.requestedHandleTypes = _CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + props.location.type = _CU_MEM_LOCATION_TYPE_DEVICE + props.location.id = self._device_ordinal + props.win32HandleMetaData = None + return props + + def _get_granularity(self) -> int: + if self._granularity is not None: + return self._granularity + + props = self._make_alloc_props() + granularity = ctypes.c_size_t() + _cuda_try( + _cuda_driver.cuMemGetAllocationGranularity( + ctypes.byref(granularity), + ctypes.byref(props), + _CU_MEM_ALLOC_GRANULARITY_MINIMUM, + ), + "cuMemGetAllocationGranularity", + ) + self._granularity = int(granularity.value) + return self._granularity + + def _mem_set_access(self, va: int, size: int) -> None: + desc = _MemAccessDesc() + desc.location.type = _CU_MEM_LOCATION_TYPE_DEVICE + desc.location.id = self._device_ordinal + desc.flags = _CU_MEM_ACCESS_FLAGS_PROT_READWRITE + _cuda_try( + _cuda_driver.cuMemSetAccess(va, size, ctypes.byref(desc), 1), + "cuMemSetAccess", + ) + + def initialize(self, device_ordinal: int) -> None: + """Prepare the CUDA driver context and bind this instance to one GPU.""" + if _cuda_driver is None: + raise LocalCudaNotSupported("CUDA driver library (libcuda.so) not found") + + _configure_signatures() + _cuda_try(_cuda_driver.cuInit(0), "cuInit") + dev = ctypes.c_int() + _cuda_try( + _cuda_driver.cuDeviceGet(ctypes.byref(dev), device_ordinal), "cuDeviceGet" + ) + ctx = ctypes.c_void_p() + _cuda_try( + _cuda_driver.cuDevicePrimaryCtxRetain(ctypes.byref(ctx), dev.value), + "cuDevicePrimaryCtxRetain", + ) + _cuda_try(_cuda_driver.cuCtxSetCurrent(ctx), "cuCtxSetCurrent") + self._device_ordinal = device_ordinal + self._granularity = None + self._initialized = True + logger.info("LocalCudaDriver initialized (device %d)", device_ordinal) + + def allocate_exportable( + self, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> LocalAllocation: + """ + Allocate CUDA VMM memory exportable as a POSIX FD. + + If va is supplied, the caller must already own a sufficiently large, + granularity-aligned VA range containing [va, va + size). 
+ """ + self._check_initialized() + if (access_va is None) != (access_size is None): + raise LocalCudaError("access_va and access_size must be provided together") + props = self._make_alloc_props() + granularity = self._get_granularity() + alloc_size = _round_up(size, granularity) + + reserved_va = va is None + mapped_va = int(va) if va is not None else 0 + handle = ctypes.c_uint64() + mapped = False + + try: + if reserved_va: + reserved = ctypes.c_uint64() + _cuda_try( + _cuda_driver.cuMemAddressReserve( + ctypes.byref(reserved), alloc_size, granularity, 0, 0 + ), + "cuMemAddressReserve", + ) + mapped_va = int(reserved.value) + _cuda_try( + _cuda_driver.cuMemCreate( + ctypes.byref(handle), alloc_size, ctypes.byref(props), 0 + ), + "cuMemCreate", + ) + _cuda_try( + _cuda_driver.cuMemMap(mapped_va, alloc_size, 0, handle.value, 0), + "cuMemMap", + ) + mapped = True + self._mem_set_access( + int(access_va) if access_va is not None else mapped_va, + int(access_size) if access_size is not None else alloc_size, + ) + return LocalAllocation( + va=mapped_va, + size=alloc_size, + handle=int(handle.value), + _va_owned=reserved_va, + ) + except Exception: + steps: list[tuple[str, Callable[[], None]]] = [] + if mapped: + steps.append( + ( + "cuMemUnmap", + lambda: _cuda_try( + _cuda_driver.cuMemUnmap(mapped_va, alloc_size), "cuMemUnmap" + ), + ) + ) + if handle.value: + steps.append( + ( + "cuMemRelease", + lambda: _cuda_try( + _cuda_driver.cuMemRelease(handle.value), "cuMemRelease" + ), + ) + ) + if reserved_va and mapped_va: + steps.append( + ( + "cuMemAddressFree", + lambda: _cuda_try( + _cuda_driver.cuMemAddressFree(mapped_va, alloc_size), + "cuMemAddressFree", + ), + ) + ) + _cleanup_after_failure(*steps) + raise + + def export_handle(self, allocation: LocalAllocation) -> bytes: + """Export a 4-byte native-endian POSIX-FD descriptor for a local allocation.""" + self._check_initialized() + fd = ctypes.c_int(-1) + _cuda_try( + _cuda_driver.cuMemExportToShareableHandle( + ctypes.byref(fd), + int(allocation.handle), + _CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, + 0, + ), + "cuMemExportToShareableHandle", + ) + return struct.pack(_CUDA_HANDLE_FMT, int(fd.value)) + + def _import_handle(self, handle_bytes: bytes) -> int: + handle_bytes = _normalize_handle_bytes(handle_bytes) + fd_value = struct.unpack(_CUDA_HANDLE_FMT, handle_bytes)[0] + imported = ctypes.c_uint64() + err = _cuda_driver.cuMemImportFromShareableHandle( + ctypes.byref(imported), + ctypes.c_void_p(fd_value), + _CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, + ) + if err != CUDA_SUCCESS: + try: + os.close(fd_value) + except OSError: + pass + _cuda_try(err, "cuMemImportFromShareableHandle") + os.close(fd_value) + return int(imported.value) + + def import_and_map( + self, + peer_rank: int, + handle_bytes: bytes, + size: int, + va: Optional[int] = None, + *, + access_va: Optional[int] = None, + access_size: Optional[int] = None, + ) -> PeerMapping: + """Import a POSIX-FD handle and map it into local CUDA VMM VA space.""" + self._check_initialized() + if (access_va is None) != (access_size is None): + raise LocalCudaError("access_va and access_size must be provided together") + imported_handle = self._import_handle(handle_bytes) + + granularity = self._get_granularity() + va_owned = va is None + mapped_va = int(va) if va is not None else 0 + mapped = False + try: + if va_owned: + reserved = ctypes.c_uint64() + _cuda_try( + _cuda_driver.cuMemAddressReserve( + ctypes.byref(reserved), size, granularity, 0, 0 + ), + "cuMemAddressReserve", + ) + mapped_va 
= int(reserved.value) + _cuda_try( + _cuda_driver.cuMemMap(mapped_va, size, 0, imported_handle, 0), + "cuMemMap", + ) + mapped = True + self._mem_set_access( + int(access_va) if access_va is not None else mapped_va, + int(access_size) if access_size is not None else size, + ) + except Exception: + steps: list[tuple[str, Callable[[], None]]] = [] + if mapped: + steps.append( + ( + "cuMemUnmap", + lambda: _cuda_try( + _cuda_driver.cuMemUnmap(mapped_va, size), "cuMemUnmap" + ), + ) + ) + steps.append( + ( + "cuMemRelease", + lambda: _cuda_try( + _cuda_driver.cuMemRelease(imported_handle), "cuMemRelease" + ), + ) + ) + if va_owned and mapped_va: + steps.append( + ( + "cuMemAddressFree", + lambda: _cuda_try( + _cuda_driver.cuMemAddressFree(mapped_va, size), + "cuMemAddressFree", + ), + ) + ) + _cleanup_after_failure(*steps) + raise + + tag = "driver_va" if va_owned else "caller_va" + return PeerMapping( + peer_rank=peer_rank, + transport=InterconnectLevel.INTRA_NODE, + remote_va=mapped_va, + size=size, + _driver_handle=(tag, imported_handle), + ) + + def cleanup_import(self, mapping: PeerMapping) -> None: + """Unmap, release, and free an imported CUDA VMM mapping.""" + self._check_initialized() + if ( + isinstance(mapping._driver_handle, tuple) + and len(mapping._driver_handle) == 2 + ): + tag, imported_handle = mapping._driver_handle + else: + tag = "driver_va" + imported_handle = mapping._driver_handle + + steps: list[tuple[str, Callable[[], None]]] = [ + ( + "cuMemUnmap", + lambda: _cuda_try( + _cuda_driver.cuMemUnmap(mapping.remote_va, mapping.size), + "cuMemUnmap", + ), + ), + ( + "cuMemRelease", + lambda: _cuda_try( + _cuda_driver.cuMemRelease(imported_handle), "cuMemRelease" + ), + ), + ] + if tag == "driver_va": + steps.append( + ( + "cuMemAddressFree", + lambda: _cuda_try( + _cuda_driver.cuMemAddressFree(mapping.remote_va, mapping.size), + "cuMemAddressFree", + ), + ) + ) + _run_cleanup_steps(*steps) + + def cleanup_local(self, allocation: LocalAllocation) -> None: + """Unmap, release, and conditionally free a local CUDA VMM allocation.""" + self._check_initialized() + steps = [ + ( + "cuMemUnmap", + lambda: _cuda_try( + _cuda_driver.cuMemUnmap(allocation.va, allocation.size), + "cuMemUnmap", + ), + ), + ( + "cuMemRelease", + lambda: _cuda_try( + _cuda_driver.cuMemRelease(allocation.handle), "cuMemRelease" + ), + ), + ] + if allocation._va_owned: + steps.append( + ( + "cuMemAddressFree", + lambda: _cuda_try( + _cuda_driver.cuMemAddressFree(allocation.va, allocation.size), + "cuMemAddressFree", + ), + ) + ) + _run_cleanup_steps(*steps) + + def get_minimum_granularity(self) -> int: + """Return the CUDA VMM allocation granularity for this device.""" + self._check_initialized() + return self._get_granularity() + + def reserve_va(self, size: int, alignment: int = 0) -> int: + """Reserve a CUDA virtual address range without backing memory.""" + self._check_initialized() + if alignment == 0: + alignment = self._get_granularity() + + reserved = ctypes.c_uint64() + _cuda_try( + _cuda_driver.cuMemAddressReserve( + ctypes.byref(reserved), size, alignment, 0, 0 + ), + "cuMemAddressReserve", + ) + return int(reserved.value) + + def free_va(self, va: int, size: int) -> None: + """Free a CUDA VA range previously returned by reserve_va.""" + self._check_initialized() + _cuda_try(_cuda_driver.cuMemAddressFree(va, size), "cuMemAddressFree") diff --git a/iris/host/distributed/topology.py b/iris/host/distributed/topology.py index c1123aafe..cbf20ce22 100644 --- a/iris/host/distributed/topology.py +++ 
b/iris/host/distributed/topology.py @@ -1122,20 +1122,10 @@ def __init__(self, iris_ctx=None): if num_gpus <= 0: raise RuntimeError("TopologyDiscovery requires at least one GPU") - # Use LOCAL_RANK (set by torchrun/SLURM) for per-node GPU assignment. - # This is more robust than global_rank % num_gpus, which breaks when - # ranks aren't distributed in a way that aligns with device_count - # (e.g., 2 nodes with 8 GPUs each but only 4 ranks per node). - # The % num_gpus clamp handles isolation (LOCAL_RANK=3, device_count=1). - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - self.gpu_id = local_rank % num_gpus - # MUST set device BEFORE init_process_group — NCCL needs a CUDA - # device assigned to this process, otherwise all ranks fight over - # GPU 0 and init either fails or produces world_size=1. - torch.cuda.set_device(self.gpu_id) if dist.is_initialized(): self.rank = dist.get_rank() self.world_size = dist.get_world_size() + self.gpu_id = torch.cuda.current_device() else: raise RuntimeError("TopologyDiscovery requires an initialized distributed process group.") diff --git a/iris/host/memory/allocators/__init__.py b/iris/host/memory/allocators/__init__.py index 916d2a7a7..d25f6fb7a 100644 --- a/iris/host/memory/allocators/__init__.py +++ b/iris/host/memory/allocators/__init__.py @@ -5,3 +5,4 @@ from .base import BaseAllocator # noqa: F401 from .torch_allocator import TorchAllocator # noqa: F401 from .vmem_allocator import VMemAllocator # noqa: F401 +from .vmem_chunked_allocator import VMemChunkedAllocator # noqa: F401 diff --git a/iris/host/memory/allocators/vmem_chunked_allocator.py b/iris/host/memory/allocators/vmem_chunked_allocator.py new file mode 100644 index 000000000..68915a397 --- /dev/null +++ b/iris/host/memory/allocators/vmem_chunked_allocator.py @@ -0,0 +1,700 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Chunked VMem allocator with power-of-two free lists and GC-based deallocation. + +Design: +- Reserve large VA range up front (cheap, just address space) +- Map physical memory in large chunks (e.g. 
256 MiB) +- Driver applies access once per chunk (not per allocation) +- Sub-allocate from chunks with bump pointer +- Power-of-two free lists for O(1) alloc/free reuse +- GC via weakref finalizers on tensor.untyped_storage() +- Free/reuse is pure bookkeeping (no driver calls, no physical remap) + +Cost model: +- Init: ~170us per chunk (create + map + set_access for 1 device) +- Per-allocation: 0us (bump or free-list pop, no driver calls) +- Per-free: 0us (push to free list, no driver calls) +- Chunk growth: ~170us (rare, every chunk_size bytes) +""" + +import math +import logging +import weakref +from collections import defaultdict, deque +from dataclasses import dataclass +from threading import Lock +from typing import List, Optional + +import torch + +from .base import BaseAllocator +from iris.drivers.base import LocalAllocation, PeerMapping +from iris.drivers.factory import DriverFactory +from iris.host.distributed.topology import ( + InterconnectLevel, + TopologyMap, + _detect_vendor, +) + +logger = logging.getLogger("iris.host.memory.allocators.vmem_chunked_allocator") + + +# Module-level CUDAArrayInterface to avoid repeated class creation +class _CUDAArrayInterface: + __slots__ = ("ptr", "nbytes") + + def __init__(self, ptr, nbytes): + self.ptr = ptr + self.nbytes = nbytes + + @property + def __cuda_array_interface__(self): + return { + "shape": (self.nbytes,), + "typestr": "|u1", + "data": (self.ptr, False), + "version": 3, + } + + +# Cached element sizes to avoid torch.tensor([], dtype=...).element_size() overhead +_DTYPE_ELEMENT_SIZE = {} + + +def _element_size(dtype): + if dtype not in _DTYPE_ELEMENT_SIZE: + _DTYPE_ELEMENT_SIZE[dtype] = torch.tensor([], dtype=dtype).element_size() + return _DTYPE_ELEMENT_SIZE[dtype] + + +def _next_power_of_two(n): + """Return the smallest power of two >= n.""" + if n <= 0: + return 1 + return 1 << (n - 1).bit_length() + + +def _is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +@dataclass +class _SharedRegion: + """Exported heap region tracked for peer refresh.""" + + va: int + size: int + allocation: Optional[LocalAllocation] = None + + +class VMemChunkedAllocator(BaseAllocator): + """ + Chunked VMem allocator with power-of-two free lists. + + Physical memory is allocated in large chunks that are mapped once and + never remapped. Free/reuse is pure bookkeeping -- no driver calls, no + physical remap, no peer coordination. + + Driver tier selection: this allocator picks ONE driver per rank based on + the topology. If any peer is in the same fabric domain on a different + host, the fabric driver is used for ALL peers including local ones. + Mixed-tier jobs work correctly but local peers may incur a small overhead + vs a pure intra-node driver. A future per-peer driver selection would + require a higher-level coordinator. The chosen tier is exposed via the + `transport_tier` property and the underlying driver via the `driver` + attribute; both are part of the public contract for orchestration layers. + + Lifetime: this allocator's GC finalizers, attached to every allocated + tensor's storage, hold a strong reference to the allocator itself. As a + consequence, del allocator does NOT destroy the allocator while any tensor + allocated by it is still live. To release resources deterministically, call + allocator.close() explicitly, or use the allocator as a context manager: + + with VMemChunkedAllocator(...) as alloc: + tensor = alloc.allocate(...) + ... 
+ # alloc.close() called automatically here, even if tensor is still alive. + + The tensor must not be used after close(). The __del__ method is a + best-effort backup; do not rely on it for release ordering. + + Args: + heap_size: Initial heap size in bytes (best effort; will grow if exceeded) + device_id: GPU device ID + cur_rank: Current process rank + num_ranks: Total number of ranks + chunk_size: Size of each physical chunk in bytes (default 256 MiB) + va_size: Total VA reservation size (default 64 GiB) + topology: Optional cluster topology used to select local vs fabric driver + """ + + # Default chunk size: 256 MiB + DEFAULT_CHUNK_SIZE = 256 * 1024 * 1024 + # Default VA size: 0 means auto-size (128 GiB) + DEFAULT_VA_SIZE = 0 + + def __init__( + self, + heap_size: int, + device_id: int, + cur_rank: int, + num_ranks: int, + chunk_size: int = DEFAULT_CHUNK_SIZE, + va_size: int = DEFAULT_VA_SIZE, + *, + topology: Optional[TopologyMap] = None, + ) -> None: + super().__init__(heap_size, device_id, cur_rank, num_ranks) + self.device = torch.device(f"cuda:{device_id}") + self.lock = Lock() + self._closed = False + self.driver = None + + # Collections initialized first so close() is always safe to call, + # even if a later __init__ step raises. + self.bump = 0 + self.free_lists = defaultdict(list) + self.alloc_sizes = {} + self._pending_free = deque() + self.chunks: List[LocalAllocation] = [] + self._shared_regions: List[_SharedRegion] = [] + self._peer_mappings: List[PeerMapping] = [] + self._imported_heap_mappings: List[PeerMapping] = [] + self.mapped_extent = 0 + self.base_va = 0 + self.granularity = 0 + + vendor = _detect_vendor() + if vendor == "unknown": + raise RuntimeError( + "VMemChunkedAllocator: could not detect GPU vendor; no compatible driver available" + ) + + interconnect = InterconnectLevel.INTRA_NODE + if topology is not None: + own_info = topology.gpu_info.get(cur_rank) + if own_info is None: + logger.warning( + "Rank %d not found in topology.gpu_info; defaulting to INTRA_NODE driver. " + "This may indicate a topology/rank-assignment mismatch.", + cur_rank, + ) + else: + own_domain = own_info.fabric_info.domain_key + if own_domain: + for peer_rank, peer_info in topology.gpu_info.items(): + if peer_rank == cur_rank: + continue + if ( + peer_info.hostname != own_info.hostname + and peer_info.fabric_info.domain_key == own_domain + ): + interconnect = InterconnectLevel.INTRA_RACK_FABRIC + logger.info( + "Rank %d using fabric driver: peer %d on host %s shares fabric domain %s", + cur_rank, + peer_rank, + peer_info.hostname, + own_domain, + ) + break + + self.driver = DriverFactory.create_driver(vendor, interconnect) + self.driver.initialize(device_id) + self._interconnect = interconnect + logger.info( + "VMemChunkedAllocator initialized: vendor=%s, interconnect=%s, device=%d, rank=%d/%d", + vendor, + interconnect.name, + device_id, + cur_rank, + num_ranks, + ) + self.granularity = self.driver.get_minimum_granularity() + if not _is_power_of_two(self.granularity): + raise RuntimeError( + f"VMemChunkedAllocator: driver granularity {self.granularity} " + f"is not a power of two; bitmask alignment math will not work. " + f"This indicates a driver bug." + ) + + # Chunk configuration -- cap at heap_size to avoid overshooting VA + effective_chunk = min(chunk_size, max(heap_size, self.granularity)) + self.chunk_size = max(effective_chunk, self.granularity) + # granularity is a power of two (asserted above), so the bitmask trick + # (x + a - 1) & ~(a - 1) is safe for this alignment. 
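+        # Worked example (illustrative only -- the real granularity is
+        # device-dependent; 2 MiB is just a common CUDA/HIP VMM value):
+        #   granularity = 2 MiB = 0x200000
+        #   chunk_size  = 256 MiB + 1 = 0x10000001
+        #   (0x10000001 + 0x1FFFFF) & ~0x1FFFFF = 0x10200000 (258 MiB),
+        #   i.e. rounded up to the next 2 MiB boundary. A non-power-of-two
+        #   granularity would make ~(granularity - 1) an incorrect mask,
+        #   which is why such a value is rejected above.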
+ self.chunk_size = (self.chunk_size + self.granularity - 1) & ~( + self.granularity - 1 + ) + if not _is_power_of_two(self.chunk_size): + raise RuntimeError( + f"VMemChunkedAllocator: computed chunk_size {self.chunk_size} " + f"is not a power of two. Pass a power-of-two chunk_size to the " + f"constructor (default {self.DEFAULT_CHUNK_SIZE} is safe)." + ) + + # VA reservation -- just address space, no physical memory cost. + # Default: 128 GiB (plenty of room for growth + imports). + if va_size == 0: + va_size = 128 * 1024 * 1024 * 1024 # 128 GiB + # va_size = max(va_size, heap_size * 4) -- gives 4x headroom for growth + # and imports. For very large heaps (>100 GiB) this can exceed sensible + # VA budgets; if reserve_va fails, the caller should pass an explicit + # va_size. Not capped here because the right cap depends on workload. + self.va_size = max(va_size, heap_size * 4) + # chunk_size is a power of two (asserted above), so bitmask alignment is safe. + self.va_size = (self.va_size + self.chunk_size - 1) & ~(self.chunk_size - 1) + self.base_va = self.driver.reserve_va(self.va_size, self.granularity) + + self.min_alignment = max(self.granularity, 1024) + + # Pre-allocate initial chunks to cover heap_size + n_initial_chunks = max(1, math.ceil(heap_size / self.chunk_size)) + for _ in range(n_initial_chunks): + self._grow_chunk() + + def _grow_chunk(self): + """Map a new physical chunk into the VA range.""" + if self.mapped_extent + self.chunk_size > self.va_size: + raise RuntimeError( + f"VMemChunkedAllocator: VA space exhausted. " + f"mapped_extent={self.mapped_extent}, " + f"chunk_size={self.chunk_size}, va_size={self.va_size}" + ) + + target_va = self.base_va + self.mapped_extent + alloc_kwargs = {} + if self.driver.__class__.__name__ == "LocalHipDriver": + alloc_kwargs = { + "access_va": self.base_va, + "access_size": self.mapped_extent + self.chunk_size, + } + allocation = self.driver.allocate_exportable( + self.chunk_size, va=target_va, **alloc_kwargs + ) + self.chunks.append(allocation) + self._shared_regions.append( + _SharedRegion(va=allocation.va, size=allocation.size, allocation=allocation) + ) + self.mapped_extent += self.chunk_size + + def _process_pending_frees(self): + """Process pending frees from GC finalizers. Call with lock held.""" + while self._pending_free: + offset, size_class = self._pending_free.popleft() + self.free_lists[size_class].append(offset) + + def _free_callback(self, offset, size_class): + """Called by weakref finalizer when a tensor's storage is GC'd. + + NOTE: this runs without holding self.lock. That is intentional -- GC + finalizers can fire from any thread, and acquiring a lock from a + finalizer risks deadlock with the thread the GC interrupted. Safety + relies on: + - deque.append being atomic under the GIL + - stale appends after close() being benign because the deque is cleared + and never read again. + """ + if self._closed: + return + self._pending_free.append((offset, size_class)) + + def get_base_address(self) -> int: + return self.base_va + + def get_minimum_allocation_size(self) -> int: + return self.granularity + + def get_device(self) -> torch.device: + return self.device + + @property + def transport_tier(self) -> InterconnectLevel: + """Return the interconnect tier this allocator's driver operates over. + + Used by orchestration layers to choose between local FD-based and + fabric-handle peer setup. Stable for the allocator lifetime. 
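+
+        Example (illustrative; the two exchange helpers are hypothetical --
+        SymmetricHeap performs this same check to choose its peer-refresh
+        path):
+
+            if allocator.transport_tier == InterconnectLevel.INTRA_RACK_FABRIC:
+                exchange_fabric_handles_via_collectives()
+            else:
+                exchange_fds_over_unix_sockets()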
+ """ + return self._interconnect + + def allocate( + self, num_elements: int, dtype: torch.dtype, alignment: int = 1024 + ) -> torch.Tensor: + with self.lock: + if num_elements == 0: + return torch.empty(0, dtype=dtype, device=self.device) + + self._process_pending_frees() + + elem_size = _element_size(dtype) + size_bytes = num_elements * elem_size + # Minimum allocation is one granule + size_bytes = max(size_bytes, self.granularity) + # Round to next power of two for free-list bucketing + size_class = _next_power_of_two(size_bytes) + # Ensure alignment to granularity + aligned_size = max(size_class, self.min_alignment) + aligned_size = (aligned_size + self.granularity - 1) & ~( + self.granularity - 1 + ) + + # Try free list first + if self.free_lists[aligned_size]: + offset = self.free_lists[aligned_size].pop() + else: + # Bump allocate + # Align the bump pointer + aligned_bump = (self.bump + self.min_alignment - 1) & ~( + self.min_alignment - 1 + ) + + needed = aligned_bump + aligned_size - self.mapped_extent + if needed > 0: + n_new_chunks = math.ceil(needed / self.chunk_size) + if ( + self.mapped_extent + n_new_chunks * self.chunk_size + > self.va_size + ): + raise RuntimeError( + f"VMemChunkedAllocator: requested allocation needs " + f"{n_new_chunks} chunks ({n_new_chunks * self.chunk_size} bytes), " + f"but only {self.va_size - self.mapped_extent} bytes of VA remain." + ) + for _ in range(n_new_chunks): + self._grow_chunk() + + offset = aligned_bump + self.bump = aligned_bump + aligned_size + + # Track for free + self.alloc_sizes[offset] = aligned_size + + # Create tensor via CUDAArrayInterface + va = self.base_va + offset + interface_size = (aligned_size // elem_size) * elem_size + iface = _CUDAArrayInterface(va, interface_size) + tensor_bytes = torch.as_tensor(iface, device=self.device) + full = tensor_bytes.view(dtype) + tensor = full.narrow(0, 0, num_elements) + + # Attach GC weak ref for automatic free + weakref.finalize( + tensor.untyped_storage(), + self._free_callback, + offset, + aligned_size, + ) + + return tensor + + def owns_tensor(self, tensor: torch.Tensor) -> bool: + if not tensor.is_cuda: + return False + if tensor.numel() == 0: + return True + ptr = tensor.data_ptr() + if self.base_va <= ptr < self.base_va + self.va_size: + return True + for mapping in self._peer_mappings: + if mapping.remote_va <= ptr < mapping.remote_va + mapping.size: + return True + return False + + def get_allocation_chunks(self): + """ + Get list of exported heap regions for peer sharing. + + This includes both ordinary chunk-backed heap regions and imported + external tensors that have been permanently mapped into the heap for + RMA-safe symmetric addressing. + + Each call invokes driver export once per returned region, which on AMD + allocates a fresh DMA-BUF file descriptor per chunk. The caller OWNS + those FDs and must close them after delivering them to peers (e.g. via + SCM_RIGHTS or pidfd_getfd). Calling this method multiple times will + leak FDs unless the caller closes the previous batch. + + This method is intended to be called once per allocator lifetime during + peer setup. If you need to re-export after a chunk grow, cache the + previous handle_bytes externally and only call this for new chunks. + + Returns: + List of (chunk_index, offset, size, handle_bytes) tuples, + where handle_bytes is the serialized peer handle for the chunk. 
+ """ + return self.get_allocation_chunks_since(0) + + def get_allocation_chunks_since(self, start_index: int): + """ + Get exported heap regions at index >= start_index, with their handles. + + Used by orchestration layers that have already imported the first + `start_index` regions and only need handles for newly-added ones. + Each call invokes driver.export_handle once per returned chunk; for + drivers that allocate kernel resources (FDs) at export time, this + avoids re-exporting chunks the caller has already processed. + + Args: + start_index: chunks at indices [0, start_index) are skipped. + + Returns: + List of (chunk_index, offset, size, handle_bytes) tuples for + chunks at index >= start_index. Indices are absolute (matching + get_allocation_chunks's output), not relative to start_index. + + Raises: + ValueError: if start_index is negative or > current chunk count. + """ + if start_index < 0: + raise ValueError(f"start_index must be non-negative, got {start_index}") + if start_index > len(self._shared_regions): + raise ValueError( + f"start_index {start_index} exceeds exported region count {len(self._shared_regions)}" + ) + + result = [] + for i in range(start_index, len(self._shared_regions)): + region = self._shared_regions[i] + offset = region.va - self.base_va + if region.allocation is not None: + handle_bytes = self.driver.export_handle(region.allocation) + else: + handle_bytes = self.driver.export_pointer_handle(region.va, region.size) + result.append((i, offset, region.size, handle_bytes)) + return result + + def get_num_chunks(self): + """Return the number of exported heap regions.""" + return len(self._shared_regions) + + def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: + """Import an external tensor into the symmetric heap (zero-copy). + + The imported tensor shares physical memory with the original -- writes + to one are visible in the other. On the chunked allocator path, the + imported allocation is mapped into the heap's VA layout and retained + until allocator.close() so peer translation remains valid for RMA. + + Raises: + DriverNotSupported: This operation requires DMA-BUF support and is + currently AMD-only. On NVIDIA, the local driver does not + implement export_pointer_handle for arbitrary device pointers, + and this method will raise. + RuntimeError: If the input tensor is not on a CUDA/HIP device or + is not contiguous. + """ + with self.lock: + if not external_tensor.is_cuda: + raise RuntimeError("Can only import CUDA/HIP tensors") + if not external_tensor.is_contiguous(): + raise RuntimeError( + "Only contiguous tensors can be imported; call .contiguous() before as_symmetric()" + ) + + external_ptr = external_tensor.data_ptr() + tensor_size = external_tensor.numel() * external_tensor.element_size() + alloc_base, alloc_size = self.driver.get_address_range(external_ptr) + offset_in_alloc = external_ptr - alloc_base + aligned_alloc_size = (alloc_size + self.granularity - 1) & ~( + self.granularity - 1 + ) + + target_offset = (self.mapped_extent + self.granularity - 1) & ~( + self.granularity - 1 + ) + if target_offset + aligned_alloc_size > self.va_size: + raise RuntimeError( + f"VMemChunkedAllocator: imported tensor needs {aligned_alloc_size} bytes " + f"at offset {target_offset}, but only " + f"{self.va_size - target_offset} bytes of VA remain." 
+ ) + + target_base_va = self.base_va + target_offset + handle_bytes = self.driver.export_pointer_handle(alloc_base, alloc_size) + import_kwargs = {} + if self.driver.__class__.__name__ == "LocalHipDriver": + import_kwargs = { + "access_va": self.base_va, + "access_size": target_offset + aligned_alloc_size, + } + mapping = self.driver.import_and_map( + self.cur_rank, + handle_bytes, + aligned_alloc_size, + va=target_base_va, + **import_kwargs, + ) + self._imported_heap_mappings.append(mapping) + self._shared_regions.append( + _SharedRegion( + va=target_base_va, size=aligned_alloc_size, allocation=None + ) + ) + self.mapped_extent = target_offset + aligned_alloc_size + self.bump = max(self.bump, self.mapped_extent) + + tensor_va = target_base_va + offset_in_alloc + iface = _CUDAArrayInterface(tensor_va, tensor_size) + tensor_bytes = torch.as_tensor(iface, device=self.device) + return tensor_bytes.view(external_tensor.dtype).reshape( + external_tensor.shape + ) + + def _import_release_callback(self, mapping: PeerMapping) -> None: + """Called when an imported tensor's storage is GC'd. + + Unlike _free_callback, this finalizer DOES take self.lock. The lock, + combined with using self._peer_mappings.remove(mapping) as the gate, + is what makes cleanup race-free against close() and release_peer_chunk: + only one code path can successfully remove a given mapping from the + list, and that code path owns the cleanup_import call. + + The self._closed check before the lock is a fast-path optimization only + -- it is NOT a correctness gate. On weakly-ordered architectures the + finalizer thread may not see _closed=True until the lock acquire forces + a memory barrier. That's fine: when it does acquire the lock, it will + find the mapping already removed by close() and return via the + ValueError path. + + Acquiring a lock from a finalizer normally risks deadlock, but here: + - imported tensors are not allocated frequently (a few per process) + - the lock is held only for one driver call + one list mutation + - the hot allocate path uses a different finalizer (_free_callback) + which does NOT take the lock, so allocate can never block this + finalizer indirectly. + """ + if self._closed: + return + with self.lock: + try: + self._peer_mappings.remove(mapping) + except ValueError: + return + try: + self.driver.cleanup_import(mapping) + except Exception as exc: + logger.warning("cleanup_import failed in finalizer: %s", exc) + + def import_peer_chunk(self, peer_rank: int, handle_bytes: bytes, size: int) -> int: + """ + Import a serialized chunk handle from a peer rank. + + Returns the local virtual address where the chunk was mapped. + The caller is responsible for calling release_peer_chunk when done. + """ + with self.lock: + mapping = self.driver.import_and_map(peer_rank, handle_bytes, size, va=None) + self._peer_mappings.append(mapping) + return mapping.remote_va + + def release_peer_chunk(self, remote_va: int) -> None: + """Release a peer chunk previously imported via import_peer_chunk. + + Idempotent -- calling this on a remote_va that's already been released + (or that was never imported via import_peer_chunk) is a no-op and logs + at DEBUG. Do NOT use this for tensors returned by import_external_tensor; + those remain mapped until allocator.close() so their heap offsets stay + valid for peer RMA. 
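+
+        Example (illustrative pairing with import_peer_chunk):
+
+            va = alloc.import_peer_chunk(peer, handle_bytes, size)
+            # ... issue loads/stores against the peer chunk via va ...
+            alloc.release_peer_chunk(va)
+            alloc.release_peer_chunk(va)  # second call is a harmless no-op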
+ """ + with self.lock: + for i, mapping in enumerate(self._peer_mappings): + if mapping.remote_va == remote_va: + try: + self.driver.cleanup_import(mapping) + finally: + self._peer_mappings.pop(i) + return + logger.debug( + "release_peer_chunk: no mapping at va=0x%x (already released or never imported via import_peer_chunk)", + remote_va, + ) + + def get_stats(self): + """Return allocator statistics.""" + total_free = sum(len(v) for v in self.free_lists.values()) + free_bytes = sum( + size_class * len(offsets) for size_class, offsets in self.free_lists.items() + ) + return { + "num_chunks": len(self.chunks), + "mapped_bytes": self.mapped_extent, + "bump": self.bump, + "num_active_allocs": len(self.alloc_sizes), + "num_free_blocks": total_free, + "free_bytes": free_bytes, + "va_size": self.va_size, + "chunk_size": self.chunk_size, + "granularity": self.granularity, + } + + def close(self): + """Release all VMem resources.""" + if self._closed: + return + self._closed = True + + if self.driver is None: + return + + with self.lock: + # Clear allocator bookkeeping before releasing mappings/chunks. + self._pending_free.clear() + self.free_lists.clear() + self.alloc_sizes.clear() + + try: + torch.cuda.synchronize(self.device) + except Exception: + pass + + # Release imported peer mappings + for mapping in self._peer_mappings: + try: + self.driver.cleanup_import(mapping) + except Exception: + pass + self._peer_mappings.clear() + + # Release imported symmetric mappings that were inserted into the + # heap VA layout via import_external_tensor. + for mapping in self._imported_heap_mappings: + try: + self.driver.cleanup_import(mapping) + except Exception: + pass + self._imported_heap_mappings.clear() + self._shared_regions.clear() + + # Release locally-mapped chunks. Each chunk has _va_owned=False, + # so cleanup_local will unmap and release the physical handle + # but will NOT free VA + for alloc in self.chunks: + try: + self.driver.cleanup_local(alloc) + except Exception: + pass + self.chunks.clear() + + # Free the master VA reservation last. + if self.base_va: + try: + self.driver.free_va(self.base_va, self.va_size) + except Exception: + pass + self.base_va = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + return False + + def __del__(self): + try: + self.close() + except Exception: + pass diff --git a/iris/host/memory/symmetric_heap.py b/iris/host/memory/symmetric_heap.py index 955eeb462..bd4964679 100644 --- a/iris/host/memory/symmetric_heap.py +++ b/iris/host/memory/symmetric_heap.py @@ -9,17 +9,71 @@ """ import logging +import os +import struct import numpy as np import torch -import os from iris.host.logging.logging import _log_rank, logger -from iris.host.memory.allocators import TorchAllocator, VMemAllocator +from iris.host.memory.allocators import TorchAllocator, VMemAllocator, VMemChunkedAllocator +from iris.drivers.base import PeerMapping from iris.host.distributed.fd_passing import setup_fd_infrastructure from iris.host.distributed.helpers import distributed_allgather from iris.host.platform.utils import is_simulation_env +logger = logging.getLogger("iris.host.memory.symmetric_heap") + +# Layout of LocalHipDriver handle bytes: native int FD, then uint64 offset, +# then uint64 base_size. See LocalHipDriver.export_handle in +# iris/drivers/local/amd.py. 
+_LOCAL_HIP_HANDLE_FMT = "=iQQ" +_LOCAL_HIP_HANDLE_BYTES = struct.calcsize(_LOCAL_HIP_HANDLE_FMT) +_LOCAL_CUDA_HANDLE_FMT = "=i" +_LOCAL_CUDA_HANDLE_BYTES = struct.calcsize(_LOCAL_CUDA_HANDLE_FMT) + + +def _extract_fd_from_local_handle(handle_bytes: bytes) -> int: + """Extract the FD from a local-driver handle byte string.""" + if len(handle_bytes) == _LOCAL_HIP_HANDLE_BYTES: + fd, _offset, _base_size = struct.unpack(_LOCAL_HIP_HANDLE_FMT, handle_bytes) + return fd + if len(handle_bytes) == _LOCAL_CUDA_HANDLE_BYTES: + (fd,) = struct.unpack(_LOCAL_CUDA_HANDLE_FMT, handle_bytes) + return fd + raise RuntimeError( + "Unsupported local driver handle size: " + f"expected {_LOCAL_HIP_HANDLE_BYTES} (HIP) or {_LOCAL_CUDA_HANDLE_BYTES} " + f"(CUDA), got {len(handle_bytes)}" + ) + + +def _replace_fd_in_local_handle(handle_bytes: bytes, new_fd: int) -> bytes: + """Return a copy of handle_bytes with the FD field replaced.""" + if len(handle_bytes) == _LOCAL_HIP_HANDLE_BYTES: + _old_fd, offset, base_size = struct.unpack(_LOCAL_HIP_HANDLE_FMT, handle_bytes) + return struct.pack(_LOCAL_HIP_HANDLE_FMT, new_fd, offset, base_size) + if len(handle_bytes) == _LOCAL_CUDA_HANDLE_BYTES: + return struct.pack(_LOCAL_CUDA_HANDLE_FMT, new_fd) + raise RuntimeError( + "Unsupported local driver handle size: " + f"expected {_LOCAL_HIP_HANDLE_BYTES} (HIP) or {_LOCAL_CUDA_HANDLE_BYTES} " + f"(CUDA), got {len(handle_bytes)}" + ) + + +def _validate_local_handle(handle_bytes: bytes) -> None: + """Validate that handle_bytes matches a supported local-driver layout.""" + if len(handle_bytes) not in ( + _LOCAL_HIP_HANDLE_BYTES, + _LOCAL_CUDA_HANDLE_BYTES, + ): + raise RuntimeError( + "Unsupported local driver handle size: " + f"expected {_LOCAL_HIP_HANDLE_BYTES} (HIP) or {_LOCAL_CUDA_HANDLE_BYTES} " + f"(CUDA), got {len(handle_bytes)}" + ) + class SymmetricHeap: """ @@ -28,9 +82,17 @@ class SymmetricHeap: Manages distributed memory with symmetric addressing across ranks, handling all allocator coordination and memory sharing internally. - Supports multiple allocator backends: 'torch' (default) and 'vmem'. + Supports multiple allocator backends: 'torch' (default), 'vmem', and + 'vmem_chunked'. """ + _PEER_REFRESH_FAILURE_MESSAGE = ( + "SymmetricHeap peer refresh failed unrecoverably. " + "Peer refresh is intentionally fatal because partial peer mappings " + "or VA reservations may remain; destroy the Iris context or restart " + "the process instead of retrying." 
+ ) + def __init__( self, heap_size: int, @@ -47,7 +109,8 @@ def __init__( device_id: GPU device ID cur_rank: Current process rank num_ranks: Total number of ranks - allocator_type: Type of allocator ("torch" or "vmem"); default "torch" + allocator_type: Type of allocator ("torch", "vmem", or + "vmem_chunked"); default "torch" Raises: ValueError: If allocator_type is not supported @@ -56,6 +119,7 @@ def __init__( self.device_id = device_id self.cur_rank = cur_rank self.num_ranks = num_ranks + self._peer_va_ranges = {} allocator_type = os.environ.get("IRIS_ALLOCATOR", allocator_type).lower() _log_rank( logging.INFO, @@ -74,10 +138,39 @@ def __init__( self.allocator = TorchAllocator(heap_size, device_id, cur_rank, num_ranks) elif allocator_type == "vmem": self.allocator = VMemAllocator(heap_size, device_id, cur_rank, num_ranks) + elif allocator_type == "vmem_chunked": + from iris.host.distributed.topology import TopologyDiscovery + + try: + topology = TopologyDiscovery().discover() + except Exception as exc: + logger.warning( + "TopologyDiscovery.discover() failed (%s); VMemChunkedAllocator will default to INTRA_NODE driver.", + exc, + ) + topology = None + self.allocator = VMemChunkedAllocator( + heap_size, + device_id, + cur_rank, + num_ranks, + topology=topology, + ) else: - raise ValueError(f"Unknown allocator type: {allocator_type}. Supported: 'torch', 'vmem'") + raise ValueError( + f"Unknown allocator type: {allocator_type}. Supported: 'torch', 'vmem', 'vmem_chunked'" + ) - self.fd_conns = setup_fd_infrastructure(cur_rank, num_ranks) + needs_fd_infra = True + if isinstance(self.allocator, VMemChunkedAllocator): + from iris.host.distributed.topology import InterconnectLevel + + if self.allocator.transport_tier == InterconnectLevel.INTRA_RACK_FABRIC: + needs_fd_infra = False + + self.fd_conns = ( + setup_fd_infrastructure(cur_rank, num_ranks) if needs_fd_infra else None + ) device = self.allocator.get_device() # Use int64 instead of uint64 for gloo backend compatibility @@ -85,11 +178,16 @@ def __init__( heap_bases_array = np.zeros(self.num_ranks, dtype=np.int64) # Create on CPU first, then move to device to avoid FFM ioctl issue if is_simulation_env(): - self.heap_bases = torch.tensor(heap_bases_array, device="cpu", dtype=torch.int64) + self.heap_bases = torch.tensor( + heap_bases_array, device="cpu", dtype=torch.int64 + ) self.heap_bases = self.heap_bases.to(device) else: - self.heap_bases = torch.tensor(heap_bases_array, device=device, dtype=torch.int64) + self.heap_bases = torch.tensor( + heap_bases_array, device=device, dtype=torch.int64 + ) + self._peer_refresh_failed = False self.refresh_peer_access() def close_fd_conns(self): @@ -109,7 +207,18 @@ def close_fd_conns(self): pass self.fd_conns = None - def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024) -> torch.Tensor: + def _ensure_peer_refresh_healthy(self) -> None: + if self._peer_refresh_failed: + raise RuntimeError(self._PEER_REFRESH_FAILURE_MESSAGE) + + def _peer_va_size(self) -> int: + if isinstance(self.allocator, (VMemAllocator, VMemChunkedAllocator)): + return self.allocator.va_size + return self.heap_size + + def allocate( + self, num_elements: int, dtype: torch.dtype, alignment: int = 1024 + ) -> torch.Tensor: """ Allocate a tensor on the symmetric heap. 
@@ -137,13 +246,27 @@ def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024) rank=self.cur_rank, num_ranks=self.num_ranks, ) + + self._ensure_peer_refresh_healthy() + min_bytes = self.allocator.get_minimum_allocation_size() element_size = torch.tensor([], dtype=dtype).element_size() min_elements = max(1, (min_bytes + element_size - 1) // element_size) actual_elements = max(num_elements, min_elements) + + is_chunked = isinstance(self.allocator, VMemChunkedAllocator) + chunks_before = self.allocator.get_num_chunks() if is_chunked else 0 + tensor = self.allocator.allocate(actual_elements, dtype, alignment) tensor = tensor[:num_elements] - self.refresh_peer_access() + + if is_chunked: + chunks_after = self.allocator.get_num_chunks() + if chunks_after > chunks_before: + self.refresh_peer_access() + else: + self.refresh_peer_access() + return tensor def get_device(self) -> torch.device: @@ -154,13 +277,38 @@ def on_symmetric_heap(self, tensor: torch.Tensor) -> bool: """ Check if a tensor is allocated on the symmetric heap. + Returns True for any tensor whose memory falls inside either: + - the local allocator's own heap (delegates to allocator.owns_tensor), or + - any peer-VA region the symmetric heap has imported via + _refresh_peer_access_chunked / _refresh_peer_access_fabric. + + The peer-region check is needed because the symmetric heap calls + driver.import_and_map directly (rather than through the allocator's + import_peer_chunk method), so peer mappings are tracked on the + symmetric heap, not the allocator. Without this check, peer-imported + tensors would incorrectly report as NOT on the symmetric heap. + Args: tensor: PyTorch tensor to check Returns: True if tensor is on the symmetric heap, False otherwise """ - return self.allocator.owns_tensor(tensor) + if self.allocator.owns_tensor(tensor): + return True + + if not tensor.is_cuda or tensor.numel() == 0: + return False + + if not self._peer_va_ranges: + return False + + ptr = tensor.data_ptr() + va_size = self._peer_va_size() + for peer_va_base in self._peer_va_ranges.values(): + if peer_va_base <= ptr < peer_va_base + va_size: + return True + return False def is_symmetric(self, tensor: torch.Tensor) -> bool: """ @@ -186,25 +334,31 @@ def is_symmetric(self, tensor: torch.Tensor) -> bool: def get_heap_bases(self) -> torch.Tensor: """Get heap base addresses for all ranks as a tensor.""" + self._ensure_peer_refresh_healthy() return self.heap_bases def refresh_peer_access(self): """ - Refresh peer DMA-BUF imports using segmented export/import. + Refresh peer imports for the active allocator backend. Collective: all ranks must call together. Do not cache heap_bases. + + Failure policy: refresh is not recoverable. The low-level refresh paths + are intentionally not transactional; if any error is raised, partial + imports or VA reservations may remain. This method marks the heap as + failed and rejects future heap operations instead of allowing retries + against possibly-stale mapping state. 
""" + self._ensure_peer_refresh_healthy() + try: + self._refresh_peer_access_impl() + except Exception as exc: + self._peer_refresh_failed = True + logger.critical(self._PEER_REFRESH_FAILURE_MESSAGE, exc_info=True) + raise RuntimeError(self._PEER_REFRESH_FAILURE_MESSAGE) from exc + + def _refresh_peer_access_impl(self): + """Implementation for refresh_peer_access; caller owns fatal handling.""" import torch.distributed as dist - from iris.host.distributed.fd_passing import send_fd, recv_fd - from iris.host.platform.hip import ( - export_dmabuf_handle, - mem_import_from_shareable_handle, - mem_map, - mem_set_access, - mem_address_reserve, - hipMemAccessDesc, - hipMemLocationTypeDevice, - hipMemAccessFlagsProtReadWrite, - ) _log_rank( logging.DEBUG, @@ -221,28 +375,333 @@ def refresh_peer_access(self): my_base = self.allocator.get_base_address() # Use int64 instead of uint64 to avoid gloo issues with all_gather_object local_base_arr = np.array([my_base], dtype=np.int64) - all_bases_arr = distributed_allgather(local_base_arr).reshape(self.num_ranks).astype(np.int64) + all_bases_arr = ( + distributed_allgather(local_base_arr) + .reshape(self.num_ranks) + .astype(np.int64) + ) self.heap_bases[self.cur_rank] = int(all_bases_arr[self.cur_rank]) - if self.num_ranks == 1 or self.fd_conns is None: + if self.num_ranks == 1: return - if not hasattr(self.allocator, "get_allocation_segments"): - if hasattr(self.allocator, "establish_peer_access"): - # In simulation, all ranks share the same device, so skip peer access setup - from iris.host.platform.utils import is_simulation_env + # Dispatch to allocator-specific peer access path + if isinstance(self.allocator, VMemChunkedAllocator): + from iris.host.distributed.topology import InterconnectLevel - if is_simulation_env(): - # Just set heap_bases directly from all_bases_arr - for r in range(self.num_ranks): - self.heap_bases[r] = int(all_bases_arr[r]) - else: - all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)} - self.allocator.establish_peer_access(all_bases, self.fd_conns) - for r in range(self.num_ranks): - self.heap_bases[r] = int(self.allocator.heap_bases_array[r]) + if self.allocator.transport_tier == InterconnectLevel.INTRA_RACK_FABRIC: + self._refresh_peer_access_fabric(dist) + else: + if self.fd_conns is None: + return + self._refresh_peer_access_chunked(dist) + elif hasattr(self.allocator, "get_allocation_segments"): + if self.fd_conns is None: + return + self._refresh_peer_access_segmented(dist) + elif hasattr(self.allocator, "establish_peer_access"): + if self.fd_conns is None: + return + self._refresh_peer_access_torch(dist, all_bases_arr) + else: return + if dist.is_initialized(): + dist.barrier() + + def _refresh_peer_access_torch(self, dist, all_bases_arr): + """Peer access for TorchAllocator (IPC-based).""" + from iris.host.platform.utils import is_simulation_env + + if is_simulation_env(): + for r in range(self.num_ranks): + self.heap_bases[r] = int(all_bases_arr[r]) + else: + all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)} + self.allocator.establish_peer_access(all_bases, self.fd_conns) + for r in range(self.num_ranks): + self.heap_bases[r] = int(self.allocator.heap_bases_array[r]) + + def _get_collective_device(self, dist) -> torch.device: + """Choose a safe device for small distributed tensor collectives.""" + if not dist.is_initialized(): + return self.allocator.get_device() + try: + backend = str(dist.get_backend()).lower() + except Exception: + backend = "gloo" + if backend == "nccl": + 
return self.allocator.get_device() + return torch.device("cpu") + + def _ensure_chunk_tracking_state(self) -> None: + if not hasattr(self, "_shared_chunk_counts"): + self._shared_chunk_counts = [0] * self.num_ranks + if not hasattr(self, "_peer_va_ranges"): + self._peer_va_ranges = {} + if not hasattr(self, "_peer_imported_mappings"): + self._peer_imported_mappings = {} + + def _gather_total_chunk_counts(self, dist, local_total: int) -> list[int]: + count_device = self._get_collective_device(dist) + local_total_tensor = torch.tensor( + [local_total], dtype=torch.int64, device=count_device + ) + gathered_counts = [ + torch.zeros(1, dtype=torch.int64, device=count_device) + for _ in range(self.num_ranks) + ] + dist.all_gather(gathered_counts, local_total_tensor) + return [int(count.item()) for count in gathered_counts] + + def _refresh_peer_access_chunked(self, dist): + """ + Peer access for VMemChunkedAllocator on local-host (FD-based) transports. + + Uses Unix-domain sockets for FD passing, but routes export/import + through the driver abstraction. The handle-byte layout is driver- + specific; this method only extracts and replaces the embedded FD + field for socket transport. + + Supports both LocalHipDriver's 20-byte DMA-BUF layout and + LocalCudaDriver's 4-byte POSIX-FD layout. + """ + from iris.host.distributed.fd_passing import recv_fd, send_fd + + self._ensure_chunk_tracking_state() + + my_total = self.allocator.get_num_chunks() + total_chunks_per_rank = self._gather_total_chunk_counts(dist, my_total) + my_total = total_chunks_per_rank[self.cur_rank] + my_already_shared = self._shared_chunk_counts[self.cur_rank] + new_chunk_records = self.allocator.get_allocation_chunks_since( + my_already_shared + ) + + my_metadata = [ + (offset, size, handle_bytes) + for (_idx, offset, size, handle_bytes) in new_chunk_records + ] + gathered_metadata = [None] * self.num_ranks + dist.all_gather_object(gathered_metadata, my_metadata) + + new_fds = [ + _extract_fd_from_local_handle(handle_bytes) + for (_offset, _size, handle_bytes) in my_metadata + ] + + pending_cloned_fds = [] + try: + for peer, sock in self.fd_conns.items(): + if peer == self.cur_rank: + continue + + if peer not in self._peer_va_ranges: + peer_va_base = self.allocator.driver.reserve_va( + self.allocator.va_size, + self.allocator.granularity, + ) + self._peer_va_ranges[peer] = peer_va_base + else: + peer_va_base = self._peer_va_ranges[peer] + self._peer_imported_mappings.setdefault(peer, []) + + peer_metadata = gathered_metadata[peer] + if peer_metadata is None: + raise RuntimeError(f"Missing chunk metadata for peer {peer}") + + peer_new_count = ( + total_chunks_per_rank[peer] - self._shared_chunk_counts[peer] + ) + if peer_new_count != len(peer_metadata): + raise RuntimeError( + f"Chunk metadata count mismatch for peer {peer}: " + f"expected {peer_new_count}, got {len(peer_metadata)}" + ) + + for peer_offset, peer_size, peer_handle_bytes in peer_metadata: + if peer_offset + peer_size > self.allocator.va_size: + raise RuntimeError( + f"Peer {peer} chunk extends beyond va_size: " + f"offset={peer_offset}, size={peer_size}, " + f"va_size={self.allocator.va_size}" + ) + _validate_local_handle(peer_handle_bytes) + + cloned_fds = [] + if self.cur_rank > peer: + for my_fd in new_fds: + send_fd(sock, my_fd) + for _ in range(peer_new_count): + cloned_fd, _ = recv_fd(sock) + cloned_fds.append(cloned_fd) + pending_cloned_fds.append(cloned_fd) + else: + for _ in range(peer_new_count): + cloned_fd, _ = recv_fd(sock) + cloned_fds.append(cloned_fd) 
+ pending_cloned_fds.append(cloned_fd) + for my_fd in new_fds: + send_fd(sock, my_fd) + + for cloned_fd, (peer_offset, peer_size, peer_handle_bytes) in zip( + cloned_fds, peer_metadata + ): + # After the upfront metadata validation above, handing the + # FD to import_and_map transfers ownership to the driver. + # Remove it from the local pending list first so the + # cleanup path below only closes never-consumed FDs. + pending_cloned_fds.pop(0) + reconstructed_handle = _replace_fd_in_local_handle( + peer_handle_bytes, cloned_fd + ) + import_kwargs = {} + if len(peer_handle_bytes) == _LOCAL_HIP_HANDLE_BYTES: + import_kwargs = { + "access_va": peer_va_base, + "access_size": peer_offset + peer_size, + } + mapping = self.allocator.driver.import_and_map( + peer, + reconstructed_handle, + peer_size, + va=peer_va_base + peer_offset, + **import_kwargs, + ) + self._peer_imported_mappings[peer].append(mapping) + + self._shared_chunk_counts[peer] = total_chunks_per_rank[peer] + self.heap_bases[peer] = peer_va_base + finally: + for fd in new_fds: + try: + os.close(fd) + except OSError: + pass + for fd in pending_cloned_fds: + try: + os.close(fd) + except OSError: + pass + + self._shared_chunk_counts[self.cur_rank] = my_total + + def _refresh_peer_access_fabric(self, dist): + """ + Peer access for VMemChunkedAllocator using fabric handles. + + Works across hosts in the same fabric domain. Handle bytes travel + through torch.distributed collectives instead of AF_UNIX sockets. + + Note: there is no FD cleanup in this method (cf. + _refresh_peer_access_chunked's try/finally). Fabric handles are + plain bytes, not kernel resources, so there is nothing to close. + """ + handle_bytes_size = 64 + record_bytes = 8 + 8 + handle_bytes_size + + self._ensure_chunk_tracking_state() + + my_total = self.allocator.get_num_chunks() + total_chunks_per_rank = self._gather_total_chunk_counts(dist, my_total) + my_total = total_chunks_per_rank[self.cur_rank] + my_already_shared = self._shared_chunk_counts[self.cur_rank] + new_chunks = self.allocator.get_allocation_chunks_since(my_already_shared) + + payload = bytearray(len(new_chunks) * record_bytes) + for index, (_chunk_idx, offset, size, handle_bytes) in enumerate(new_chunks): + if len(handle_bytes) != handle_bytes_size: + raise RuntimeError( + f"Expected {handle_bytes_size}-byte fabric handle, got {len(handle_bytes)}" + ) + record_offset = index * record_bytes + payload[record_offset : record_offset + 8] = int(offset).to_bytes( + 8, "little", signed=False + ) + payload[record_offset + 8 : record_offset + 16] = int(size).to_bytes( + 8, "little", signed=False + ) + payload[record_offset + 16 : record_offset + record_bytes] = handle_bytes + + gathered_payloads = [None] * self.num_ranks + dist.all_gather_object(gathered_payloads, bytes(payload)) + + for peer in range(self.num_ranks): + if peer == self.cur_rank: + self._shared_chunk_counts[peer] = total_chunks_per_rank[peer] + continue + + peer_new_count = ( + total_chunks_per_rank[peer] - self._shared_chunk_counts[peer] + ) + if peer_new_count == 0: + if peer in self._peer_va_ranges: + self.heap_bases[peer] = self._peer_va_ranges[peer] + continue + + peer_payload = gathered_payloads[peer] or b"" + expected_size = peer_new_count * record_bytes + if len(peer_payload) != expected_size: + raise RuntimeError( + f"Fabric payload size mismatch for peer {peer}: expected {expected_size}, got {len(peer_payload)}" + ) + + if peer not in self._peer_va_ranges: + peer_va_base = self.allocator.driver.reserve_va( + self.allocator.va_size, + 
self.allocator.granularity, + ) + self._peer_va_ranges[peer] = peer_va_base + else: + peer_va_base = self._peer_va_ranges[peer] + self._peer_imported_mappings.setdefault(peer, []) + + for index in range(peer_new_count): + record_offset = index * record_bytes + chunk_offset = int.from_bytes( + peer_payload[record_offset : record_offset + 8], + "little", + signed=False, + ) + chunk_size = int.from_bytes( + peer_payload[record_offset + 8 : record_offset + 16], + "little", + signed=False, + ) + handle_bytes = peer_payload[ + record_offset + 16 : record_offset + record_bytes + ] + if chunk_offset + chunk_size > self.allocator.va_size: + raise RuntimeError( + f"Peer {peer} chunk extends beyond va_size: " + f"offset={chunk_offset}, size={chunk_size}, " + f"va_size={self.allocator.va_size}" + ) + mapping = self.allocator.driver.import_and_map( + peer, + handle_bytes, + chunk_size, + va=peer_va_base + chunk_offset, + ) + self._peer_imported_mappings[peer].append(mapping) + + self._shared_chunk_counts[peer] = total_chunks_per_rank[peer] + self.heap_bases[peer] = peer_va_base + + def _refresh_peer_access_segmented(self, dist): + """Peer access for VMemAllocator (segment-based, legacy).""" + from iris.host.distributed.fd_passing import recv_fd, send_fd + from iris.host.platform.hip import ( + export_dmabuf_handle, + mem_import_from_shareable_handle, + mem_map, + mem_set_access, + mem_address_reserve, + hipMemAccessDesc, + hipMemLocationTypeDevice, + hipMemAccessFlagsProtReadWrite, + ) + my_segments = self.allocator.get_allocation_segments() my_exported_fds = [] for offset, size, va in my_segments: @@ -270,14 +729,15 @@ def refresh_peer_access(self): self._peer_va_ranges = {} if peer not in self._peer_va_ranges: - peer_va_base = mem_address_reserve(self.heap_size, self.allocator.granularity, 0) + peer_va_base = mem_address_reserve( + self.heap_size, self.allocator.granularity, 0 + ) self._peer_va_ranges[peer] = peer_va_base else: peer_va_base = self._peer_va_ranges[peer] peer_fds = [] for seg_idx, (my_fd, my_size, my_offset) in enumerate(my_exported_fds): - # Exchange FDs (higher rank sends first to avoid deadlock) if self.cur_rank > peer: send_fd(sock, my_fd) peer_fd, _ = recv_fd(sock) @@ -295,24 +755,29 @@ def refresh_peer_access(self): self._peer_imported_segments = {} if peer not in self._peer_imported_segments: self._peer_imported_segments[peer] = set() + if not hasattr(self, "_peer_imported_mappings"): + self._peer_imported_mappings = {} + if peer not in self._peer_imported_mappings: + self._peer_imported_mappings[peer] = [] for peer_fd, segment_size, offset in peer_fds: segment_key = (offset, segment_size) if segment_key in self._peer_imported_segments[peer]: - import os - os.close(peer_fd) continue imported_handle = mem_import_from_shareable_handle(peer_fd) - import os - os.close(peer_fd) peer_va = peer_va_base + offset mem_map(peer_va, segment_size, 0, imported_handle) self._peer_imported_segments[peer].add(segment_key) + # Track for cleanup + self._peer_imported_mappings[peer].append( + (imported_handle, peer_va, segment_size) + ) + new_cumulative = offset + segment_size if new_cumulative > cumulative_size: cumulative_size = new_cumulative @@ -322,8 +787,6 @@ def refresh_peer_access(self): self.heap_bases[peer] = peer_va_base for fd, _, _ in my_exported_fds: - import os - os.close(fd) if logger.isEnabledFor(logging.DEBUG): @@ -338,6 +801,7 @@ def refresh_peer_access(self): if dist.is_initialized(): dist.barrier() + def as_symmetric(self, external_tensor: torch.Tensor) -> torch.Tensor: """ 
Place an external PyTorch tensor on the symmetric heap. @@ -355,9 +819,79 @@ def as_symmetric(self, external_tensor: torch.Tensor) -> torch.Tensor: Raises: RuntimeError: If allocator doesn't support imports or import fails """ + self._ensure_peer_refresh_healthy() + if not hasattr(self.allocator, "import_external_tensor"): - raise RuntimeError(f"{type(self.allocator).__name__} does not support as_symmetric().") + raise RuntimeError( + f"{type(self.allocator).__name__} does not support as_symmetric()." + ) imported = self.allocator.import_external_tensor(external_tensor) self.refresh_peer_access() return imported + + def close(self): + """Release resources held by the symmetric heap.""" + try: + import torch + + if hasattr(self, "device_id"): + torch.cuda.synchronize(self.device_id) + except Exception: + pass + + if hasattr(self, "_peer_imported_mappings"): + has_driver = hasattr(self.allocator, "driver") + for peer, mappings in self._peer_imported_mappings.items(): + for entry in mappings: + try: + if has_driver and isinstance(entry, PeerMapping): + self.allocator.driver.cleanup_import(entry) + else: + from iris.host.platform.hip import mem_release, mem_unmap + + handle, va, size = entry + mem_unmap(va, size) + mem_release(handle) + except Exception as exc: + logger.warning("cleanup failed for peer %d: %s", peer, exc) + self._peer_imported_mappings.clear() + + # Free peer VA ranges (now safe -- all mappings have been removed) + if self._peer_va_ranges: + va_size = self._peer_va_size() + has_driver = hasattr(self.allocator, "driver") + if has_driver: + for peer, va_base in self._peer_va_ranges.items(): + try: + self.allocator.driver.free_va(va_base, va_size) + except Exception as exc: + logger.warning("free_va failed for peer %d: %s", peer, exc) + else: + from iris.host.platform.hip import mem_address_free + + for peer, va_base in self._peer_va_ranges.items(): + try: + mem_address_free(va_base, va_size) + except Exception as exc: + logger.warning("VA free failed for peer %d: %s", peer, exc) + self._peer_va_ranges.clear() + + # Close the allocator + if hasattr(self, "allocator") and hasattr(self.allocator, "close"): + self.allocator.close() + + # Close FD sockets + if hasattr(self, "fd_conns") and self.fd_conns: + for peer, sock in self.fd_conns.items(): + try: + sock.close() + except Exception: + pass + self.fd_conns = None + + def __del__(self): + try: + self.close() + except Exception: + pass diff --git a/tests/unittests/test_vmem_chunked_allocator.py b/tests/unittests/test_vmem_chunked_allocator.py new file mode 100644 index 000000000..3d4fca42d --- /dev/null +++ b/tests/unittests/test_vmem_chunked_allocator.py @@ -0,0 +1,417 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for VMemChunkedAllocator. 
+ +Tests cover: +- Basic allocation and data integrity +- Multiple allocations and non-overlap +- Power-of-two free list reuse via GC +- Chunk growth on overflow +- as_symmetric (import external tensor) +- owns_tensor detection +- heap_bases stability +- Cross-rank RMA (peer memory access) +- OOM handling +- Allocator stats +- Thread safety +""" + +import gc +import threading + +import pytest +import torch + +import iris + + +ALLOC_TYPE = "vmem_chunked" + + +def test_chunked_creation(): + """Test that chunked VMem allocator can be created.""" + ctx = iris.iris(1 << 20, allocator_type=ALLOC_TYPE) + assert ctx.cur_rank >= 0 + assert ctx.num_ranks >= 1 + assert ctx.heap_size == 1 << 20 + + +def test_chunked_basic_allocation(): + """Test basic allocation and data integrity.""" + ctx = iris.iris(1 << 20, allocator_type=ALLOC_TYPE) + tensor = ctx.zeros(1024, dtype=torch.float32) + + assert tensor.shape == (1024,) + assert tensor.device.type == "cuda" + assert torch.all(tensor == 0) + + tensor.fill_(42.0) + assert torch.all(tensor == 42.0) + + +def test_chunked_multiple_allocations(): + """Test multiple allocations don't overlap.""" + ctx = iris.iris(4 << 20, allocator_type=ALLOC_TYPE) + + tensors = [] + for i in range(20): + t = ctx.zeros(256, dtype=torch.float32) + t.fill_(float(i)) + tensors.append(t) + + # Verify each tensor retains its value (no overlap) + for i, t in enumerate(tensors): + assert torch.all(t == float(i)), f"tensor {i} corrupted: expected {float(i)}, got {t[0].item()}" + + +def test_chunked_dtypes(): + """Test allocation with various dtypes.""" + ctx = iris.iris(4 << 20, allocator_type=ALLOC_TYPE) + + for dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64]: + t = ctx.zeros(100, dtype=dtype) + assert t.dtype == dtype + assert t.shape == (100,) + + +def test_chunked_zero_elements(): + """Test zero-element allocation.""" + ctx = iris.iris(1 << 20, allocator_type=ALLOC_TYPE) + t = ctx.zeros(0, dtype=torch.float32) + assert t.numel() == 0 + assert t.shape == (0,) + + +def test_chunked_gc_free_reuse(): + """Test that freed memory is reused via GC.""" + ctx = iris.iris(4 << 20, allocator_type=ALLOC_TYPE) + alloc = ctx.heap.allocator + + # Allocate a tensor + t = ctx.zeros(1024, dtype=torch.float32) + ptr1 = t.data_ptr() + + # Drop the tensor and trigger GC + del t + gc.collect() + torch.cuda.synchronize() + + # Allocate again -- should reuse the freed block + t2 = ctx.zeros(1024, dtype=torch.float32) + ptr2 = t2.data_ptr() + + # The reused block should be at the same offset (same free list bucket) + assert ptr2 == ptr1, f"Expected reuse at 0x{ptr1:x}, got 0x{ptr2:x}" + + +def test_chunked_gc_multiple_reuse(): + """Test multiple rounds of alloc-free-reuse.""" + ctx = iris.iris(4 << 20, allocator_type=ALLOC_TYPE) + + for _ in range(10): + t = ctx.zeros(512, dtype=torch.float32) + t.fill_(99.0) + torch.cuda.synchronize() + assert torch.all(t == 99.0) + del t + gc.collect() + + +def test_chunked_free_list_size_classes(): + """Test that different sizes use different free list buckets. + + SymmetricHeap.allocate() bumps element counts to at least + granularity / element_size, so we must pick sizes that remain in + distinct power-of-two buckets after that rounding. 
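+
+    For example (illustrative -- the real granularity is device-dependent),
+    with a 2 MiB granularity min_elems is 524288 float32 elements, so the
+    three sizes below land in the 2 MiB, 8 MiB, and 32 MiB buckets.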
+ """ + ctx = iris.iris(256 << 20, allocator_type=ALLOC_TYPE) + alloc = ctx.heap.allocator + elem_size = 4 # float32 + + # Compute the minimum element count (the floor imposed by SymmetricHeap) + min_elems = max(1, (alloc.granularity + elem_size - 1) // elem_size) + + # Pick three sizes that land in clearly different power-of-two buckets: + # small = 1x granularity (min_elems elements) + # medium = 4x granularity (min_elems * 4 elements) + # large = 16x granularity (min_elems * 16 elements) + size_small = min_elems + size_medium = min_elems * 4 + size_large = min_elems * 16 + + small = ctx.zeros(size_small, dtype=torch.float32) + medium = ctx.zeros(size_medium, dtype=torch.float32) + large = ctx.zeros(size_large, dtype=torch.float32) + + small_ptr = small.data_ptr() + medium_ptr = medium.data_ptr() + large_ptr = large.data_ptr() + + # Free all + del small, medium, large + gc.collect() + torch.cuda.synchronize() + + # Re-allocate -- each should reuse from its size class + small2 = ctx.zeros(size_small, dtype=torch.float32) + medium2 = ctx.zeros(size_medium, dtype=torch.float32) + large2 = ctx.zeros(size_large, dtype=torch.float32) + + assert small2.data_ptr() == small_ptr + assert medium2.data_ptr() == medium_ptr + assert large2.data_ptr() == large_ptr + + +def test_chunked_chunk_growth(): + """Test that allocator grows chunks when needed.""" + # Small chunk size to force growth + chunk_size = 1 << 20 # 1 MiB chunks + ctx = iris.iris( + 1 << 20, + allocator_type=ALLOC_TYPE, + ) + alloc = ctx.heap.allocator + + initial_chunks = alloc.get_num_chunks() + + # Allocate more than one chunk's worth + tensors = [] + total = 0 + while alloc.get_num_chunks() <= initial_chunks: + t = ctx.zeros(32768, dtype=torch.float32) # 128 KiB each + tensors.append(t) + total += 32768 * 4 + + assert alloc.get_num_chunks() > initial_chunks + + +def test_chunked_owns_tensor(): + """Test owns_tensor detection.""" + ctx = iris.iris(1 << 20, allocator_type=ALLOC_TYPE) + + heap_tensor = ctx.zeros(100, dtype=torch.float32) + assert ctx.heap.allocator.owns_tensor(heap_tensor) + + external_tensor = torch.zeros(100, dtype=torch.float32, device=ctx.device) + assert not ctx.heap.allocator.owns_tensor(external_tensor) + + del heap_tensor, external_tensor + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def test_chunked_heap_bases(): + """Test that heap bases are stable and properly set.""" + ctx = iris.iris(1 << 20, allocator_type=ALLOC_TYPE) + + assert ctx.heap_bases.shape == (ctx.num_ranks,) + base = int(ctx.heap_bases[ctx.cur_rank].item()) + assert base > 0 + + # Allocate several tensors -- base should not change + for _ in range(10): + ctx.zeros(100, dtype=torch.float32) + + assert int(ctx.heap_bases[ctx.cur_rank].item()) == base + + +def test_chunked_heap_bases_multirank(): + """Test heap bases across ranks.""" + ctx = iris.iris(1 << 20, allocator_type=ALLOC_TYPE) + + if ctx.num_ranks > 1: + for peer in range(ctx.num_ranks): + if peer != ctx.cur_rank: + assert int(ctx.heap_bases[peer].item()) > 0 + assert int(ctx.heap_bases[peer].item()) != int(ctx.heap_bases[ctx.cur_rank].item()) + + +def test_chunked_import_external_tensor(): + """Test as_symmetric (import external tensor).""" + ctx = iris.iris(4 << 20, allocator_type=ALLOC_TYPE) + + original = torch.randn(100, dtype=torch.float32, device=ctx.device) + original_data = original.clone() + + imported = ctx.as_symmetric(original) + + # Should have same data + assert torch.allclose(imported, original_data) + + # Shared memory -- writes visible both ways + 
+    imported.fill_(42.0)
+    assert torch.all(original == 42.0)
+
+    original.fill_(99.0)
+    assert torch.all(imported == 99.0)
+
+
+def test_chunked_import_tensor_survives_ctx():
+    """Test that original tensor survives ctx destruction."""
+    original = torch.randn(100, dtype=torch.float32, device="cuda")
+    original_data = original.clone()
+
+    ctx = iris.iris(4 << 20, allocator_type=ALLOC_TYPE)
+    imported = ctx.as_symmetric(original)
+    assert torch.allclose(imported, original_data)
+
+    del ctx, imported
+    gc.collect()
+    torch.cuda.synchronize()
+
+    # Original should still be valid
+    assert torch.all(original == original_data)
+    original.fill_(123.0)
+    assert torch.all(original == 123.0)
+
+
+def test_chunked_stats():
+    """Test allocator statistics."""
+    ctx = iris.iris(4 << 20, allocator_type=ALLOC_TYPE)
+    alloc = ctx.heap.allocator
+
+    stats = alloc.get_stats()
+    assert stats["num_chunks"] >= 1
+    assert stats["mapped_bytes"] > 0
+    assert stats["va_size"] > 0
+    assert stats["granularity"] > 0
+
+    # Allocate some tensors
+    t1 = ctx.zeros(100, dtype=torch.float32)
+    t2 = ctx.zeros(200, dtype=torch.float32)
+
+    stats = alloc.get_stats()
+    assert stats["num_active_allocs"] >= 2
+    assert stats["bump"] > 0
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
+def test_chunked_multirank_exchange():
+    """Test FD exchange and peer access across ranks."""
+    ctx = iris.iris(1 << 20, allocator_type=ALLOC_TYPE)
+
+    if ctx.num_ranks < 2:
+        pytest.skip("Requires at least 2 ranks")
+
+    tensor = ctx.zeros(1024, dtype=torch.float32)
+    tensor.fill_(float(ctx.cur_rank * 100))
+
+    ctx.barrier()
+
+    # Verify peer heap bases are set
+    for peer in range(ctx.num_ranks):
+        if peer != ctx.cur_rank:
+            assert int(ctx.heap_bases[peer].item()) > 0
+
+    # Verify local data still intact after exchange
+    assert torch.all(tensor == float(ctx.cur_rank * 100))
+
+
+def test_chunked_thread_safety():
+    """Test concurrent allocations from multiple threads."""
+    ctx = iris.iris(16 << 20, allocator_type=ALLOC_TYPE)
+    alloc = ctx.heap.allocator
+    results = []
+    errors = []
+
+    def alloc_free_loop(thread_id, n):
+        try:
+            for i in range(n):
+                t = alloc.allocate(100, torch.float32)
+                t.fill_(float(thread_id * 1000 + i))
+                torch.cuda.synchronize()
+                val = t[0].item()
+                assert val == float(thread_id * 1000 + i)
+                results.append(val)
+                del t
+                gc.collect()
+        except Exception as e:
+            errors.append((thread_id, e))
+
+    threads = []
+    for tid in range(4):
+        t = threading.Thread(target=alloc_free_loop, args=(tid, 10))
+        threads.append(t)
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    assert len(errors) == 0, f"Thread errors: {errors}"
+    assert len(results) == 40
+
+
+def test_chunked_close():
+    """Test explicit close releases resources."""
+    ctx = iris.iris(1 << 20, allocator_type=ALLOC_TYPE)
+    alloc = ctx.heap.allocator
+
+    t = ctx.zeros(100, dtype=torch.float32)
+    del t
+    gc.collect()
+
+    alloc.close()
+    assert alloc._closed
+    assert len(alloc.chunks) == 0
+
+
+def test_chunked_no_refresh_on_reuse():
+    """Test that reusing freed memory does NOT trigger refresh_peer_access."""
+    ctx = iris.iris(4 << 20, allocator_type=ALLOC_TYPE)
+    alloc = ctx.heap.allocator
+
+    # Track initial chunk count
+    initial_chunks = alloc.get_num_chunks()
+
+    # Alloc-free-reuse cycle should not grow chunks
+    for _ in range(20):
+        t = ctx.zeros(256, dtype=torch.float32)
+        t.fill_(1.0)
+        torch.cuda.synchronize()
+        del t
+        gc.collect()
+
+    assert alloc.get_num_chunks() == initial_chunks
+
+
+def test_chunked_large_allocation():
+    """Test allocation larger than default alignment."""
+    ctx = iris.iris(64 << 20, allocator_type=ALLOC_TYPE)
+
+    # 4 MiB tensor
+    t = ctx.zeros(1024 * 1024, dtype=torch.float32)
+    assert t.shape == (1024 * 1024,)
+    t.fill_(7.0)
+    torch.cuda.synchronize()
+    assert torch.all(t == 7.0)
+
+
+def test_chunked_mixed_alloc_free_pattern():
+    """Test interleaved alloc and free with varying sizes."""
+    ctx = iris.iris(32 << 20, allocator_type=ALLOC_TYPE)
+
+    active = []
+    for i in range(50):
+        size = (i % 5 + 1) * 100
+        t = ctx.zeros(size, dtype=torch.float32)
+        t.fill_(float(i))
+        active.append((float(i), t))
+
+        # Every 3rd iteration, free the oldest live tensor
+        if i % 3 == 0 and active:
+            del active[0]
+            gc.collect()
+
+    # Verify the surviving tensors still hold their fill values
+    torch.cuda.synchronize()
+    for expected, t in active:
+        assert torch.all(t == expected)
+
+
+if __name__ == "__main__":
+    test_chunked_creation()
+    test_chunked_basic_allocation()
+    test_chunked_multiple_allocations()
+    test_chunked_gc_free_reuse()