diff --git a/python/flydsl/runtime/device.py b/python/flydsl/runtime/device.py index ed5fd47b2..ac95fc91c 100644 --- a/python/flydsl/runtime/device.py +++ b/python/flydsl/runtime/device.py @@ -91,3 +91,28 @@ def is_rdna_arch(arch: Optional[str] = None) -> bool: if arch.startswith("gfx120"): return True return False + + +# Architectures FlyDSL ships kernels for. Single source of truth for +# "does FlyDSL run here" — keep in sync with the kernels under kernels/. +# Downstream consumers (e.g. aiter) gate availability on this so that +# importing FlyDSL kernels on an unsupported arch is skipped cleanly +# instead of crashing during config registration. +SUPPORTED_ARCHS = frozenset( + { + "gfx942", # MI300A / MI300X (CDNA3) + "gfx950", # MI350 (CDNA4) + "gfx1151", # RDNA3.5 (Strix Halo) + "gfx1201", # RDNA4 + "gfx1250", # MI450 + } +) + + +def is_arch_supported(arch: Optional[str] = None) -> bool: + """Whether FlyDSL ships kernels for *arch* (current GPU if None).""" + if arch is None: + arch = get_rocm_arch() + if not arch: + return False + return arch.lower() in SUPPORTED_ARCHS