diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index cf5f9a0db..48399d6da 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -12,61 +12,48 @@ Iris is a Triton-based framework for Remote Memory Access (RMA) operations on AM **FOLLOW THESE INSTRUCTIONS EXACTLY. Reference these instructions first before using search or bash commands.** -## Prerequisites +## GPU Environment -- **GPU**: AMD GPUs with ROCm compatibility (tested on MI300X, MI350X & MI355X) - > **Note**: See below for instructions on development without AMD GPU access -- **ROCm/HIP Toolkit**: Required for building C++/HIP components -- **Docker/Apptainer**: Recommended for containerized development +**You have AMD GPU access.** Your runner has 8x MI325X GPUs (gfx942) with ROCm installed. Do NOT skip GPU operations or say you lack GPU access. You can and should: -## Build +- Run benchmarks and post results (charts, tables) in PR comments +- Run tests with `torchrun --nproc_per_node=N` for multi-GPU tests +- Use `rocm-smi` to verify GPU status +- Run `python -c "import torch; print(torch.cuda.device_count())"` to confirm GPU count -### Docker Development Environment (Recommended) +When asked to run a benchmark, **run it and post the output**. Do not say you cannot. + +### Running multi-GPU tests and benchmarks + +Multi-GPU tests require `torch.distributed` initialization before pytest: ```bash -# Build and start development container (takes 45-60 minutes - NEVER CANCEL) -docker compose up --build -d +# Single GPU +pytest tests/unittests/ -v --tb=short -# Attach to running container -docker attach iris-dev +# Multi-GPU (N = number of GPUs) +torchrun --nproc_per_node=N -m pytest tests/ -v --tb=short -# Install Iris in development mode -cd iris && pip install -e ".[dev]" +# Benchmarks use iris.bench framework +torchrun --nproc_per_node=8 benchmark/ops/bench_.py ``` -### Alternative Docker Setup -```bash -# Build Docker image manually -./docker/build.sh # Takes 45-60 minutes +### iris.bench framework -# Run container -./docker/run.sh +Benchmarks use the declarative `iris.bench` framework. See existing `benchmark/ops/bench_*.py` files for examples. Output includes latency, throughput, and bandwidth tables. When posting benchmark results in PR comments, format as markdown tables. -# Install Iris -cd iris && pip install -e ".[dev]" -``` +## Prerequisites -### Apptainer Setup -```bash -# Build and run Apptainer image -./apptainer/build.sh -./apptainer/run.sh +- **GPU**: AMD GPUs with ROCm compatibility (tested on MI300X, MI325X, MI350X & MI355X) +- **ROCm/HIP Toolkit**: Required for building C++/HIP components +- **Docker/Apptainer**: Recommended for containerized development -# Install Iris -pip install -e ".[dev]" -``` +## Build -### Local Development (Not Recommended) +iris is already installed in your environment via `pip install -e .` in the setup steps. You do not need to build or install anything. If you need to reinstall after modifying `setup.py` or C extensions: ```bash -# Requires ROCm/HIP toolkit installation pip install -e ".[dev]" ``` -### Development Without AMD GPU -If you don't have access to AMD GPUs, you can still contribute to the project: -- **Code Editing**: Start editing code directly in your local environment -- **CI Testing**: The project has comprehensive CI pipelines that will test your changes automatically. You can check the CI logs if your changes fail to understand what went wrong. -- **Local Validation**: Run linting and formatting locally: `ruff check . --fix && ruff format .` - ## Run ### Testing diff --git a/.gitignore b/.gitignore index d8f9754f7..0bc6bbc55 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,8 @@ omni*.pdf slurm*.out *.egg-info +*.backup +*.with_chunked examples/gemm/results/* asm/ @@ -57,4 +59,8 @@ gpucore.* logs/ *.cap hsakmt_counters.csv -core \ No newline at end of file +core +.intellikit/ +.github/agents/docs/benchmark-results/ +.github/agents/ +docs/benchmark-results/*.png diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..41cf4672e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "ext/shader_sdma"] + path = ext/shader_sdma + url = https://github.com/AARInternal/shader_sdma.git diff --git a/benchmark/ops/all_gather_matmul/__init__.py b/benchmark/ops/all_gather_matmul/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmark/ops/all_gather_matmul/auto_config.py b/benchmark/ops/all_gather_matmul/auto_config.py new file mode 100644 index 000000000..0e8990886 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/auto_config.py @@ -0,0 +1,582 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Auto-selection mechanism for fused AG+MM kernel configurations. + +Given problem dimensions (M, N, K), transpose mode, world_size, and GPU +architecture, this module selects the best known configuration or returns +a sensible default. For world sizes where iris AG+MM is known to lose +against PyTorch (ws<8), the default disables iris and signals fallback. + +Config files live under: + configs/{arch}/{transpose}/ws{N}.json + +Each config file contains: + - Per-shape champion configs with all kernel parameters in a flat "params" dict + - A "default_params" dict with architecture-appropriate defaults + - Params include FusedConfig fields (block_size_m, etc.) and HBM buffer + kernel params (k_per_flag, num_fetch_sms, num_warps, num_stages, etc.) + +Transpose coverage: + The iris AG+MM kernel (`_fused_all_gather_matmul_kernel`) uses stride-based + addressing (`stride_am, stride_ak, stride_bk, stride_bn`), so transpose + layouts are handled implicitly by tensor strides. Config files exist for + all four layouts (NN, TN, NT, TT) under each architecture directory. + Only NN has per-shape champion configs from benchmarking (3,489 trials). + TN/NT/TT files contain heuristic defaults only (empty shapes dict) and are + marked enabled at ws>=8 to allow heuristic fallback. All transposes at ws<8 + are disabled (NO-GO based on NN benchmarks). + +Usage: + >>> from auto_config import select_ag_mm_config + >>> result = select_ag_mm_config(M=131072, N=16384, K=16384, world_size=8) + >>> if result.enabled: + ... config = result.to_fused_config() + ... hbm_params = result.hbm_buffer_params # k_per_flag, num_fetch_sms, etc. + ... shmem.ops.all_gather_matmul(output, A, B, config=config) + ... else: + ... # Fallback to PyTorch all_gather + matmul + ... ... + + >>> # List all regression test sizes + >>> from auto_config import load_regression_sizes + >>> sizes = load_regression_sizes() +""" + +import json +import os +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from iris.ops.config import FusedConfig + +# Config files live alongside this module +_CONFIGS_DIR = Path(__file__).parent / "configs" + +# FusedConfig field names — everything else in "params" is an HBM buffer param +_FUSED_CONFIG_FIELDS = {f.name for f in FusedConfig.__dataclass_fields__.values()} + +# HBM buffer param names (kernel launch params, not FusedConfig fields) +_HBM_BUFFER_FIELDS = { + "k_per_flag", + "num_fetch_sms", + "num_fetch_stages", + "first_stage_fetch_sms", + "fetch_block_m", + "fetch_block_k", + "num_warps", + "num_stages", +} + +# In-memory cache: (arch, transpose, world_size) -> loaded JSON data +_config_cache: Dict[Tuple[str, str, int], dict] = {} + +# Cached GPU architecture detection result +_detected_arch: Optional[str] = None + +# Supported transpose modes. The AG+MM kernel only supports NN layout. +# TN/NT/TT would require kernel-level changes to permute strides. +SUPPORTED_TRANSPOSES = ("NN",) + +# Supported GPU architectures with tuned configs +SUPPORTED_ARCHITECTURES = ("mi300x", "mi355x") + +# Map gfx target IDs to architecture names used in config paths +_GFX_TO_ARCH = { + "gfx942": "mi300x", # MI300X, MI300A + "gfx950": "mi355x", # MI355X +} + + +def detect_gpu_arch() -> str: + """Auto-detect GPU architecture from the current system. + + Detection order: + 1. IRIS_GPU_ARCH environment variable (override) + 2. rocm-smi --showproductname parsing + 3. rocminfo gfx target parsing + 4. Falls back to "mi300x" (most common deployment target) + + Returns: + Architecture string (e.g., "mi300x") suitable for config lookup. + """ + global _detected_arch + if _detected_arch is not None: + return _detected_arch + + # 1. Environment variable override + env_arch = os.environ.get("IRIS_GPU_ARCH", "").strip().lower() + if env_arch: + _detected_arch = env_arch + return _detected_arch + + # 2. Try rocminfo for gfx target + try: + result = subprocess.run( + ["rocminfo"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + for line in result.stdout.splitlines(): + line_stripped = line.strip().lower() + if "name:" in line_stripped and "gfx" in line_stripped: + for gfx_id, arch_name in _GFX_TO_ARCH.items(): + if gfx_id in line_stripped: + _detected_arch = arch_name + return _detected_arch + except (FileNotFoundError, subprocess.TimeoutExpired, OSError): + pass + + # 3. Fallback to MI300X (most common deployment target) + _detected_arch = "mi300x" + return _detected_arch + + +@dataclass +class AutoConfigResult: + """Result of auto-config lookup. + + Attributes: + enabled: If False, iris AG+MM should NOT be used; fallback to PyTorch. + config_params: Dict of FusedConfig parameters (only valid if enabled=True). + hbm_buffer_params: Dict of HBM buffer-specific kernel params + (k_per_flag, num_fetch_sms, num_fetch_stages, first_stage_fetch_sms). + source: Human-readable description of where this config came from. + shape_key: The MxNxK key that matched (None if heuristic/default). + expected_iris_ms: Expected kernel time in ms on target GPU (None if unknown). + """ + + enabled: bool = False + config_params: Dict = field(default_factory=dict) + hbm_buffer_params: Dict = field(default_factory=dict) + source: str = "default" + shape_key: Optional[str] = None + expected_iris_ms: Optional[float] = None + + def to_fused_config(self) -> FusedConfig: + """Convert to FusedConfig for use with iris.ops functions. + + Raises: + RuntimeError: If this config is disabled (enabled=False). + """ + if not self.enabled: + raise RuntimeError( + f"Cannot create FusedConfig: iris AG+MM is disabled for this " + f"configuration. Reason: {self.source}. " + f"Use PyTorch all_gather + matmul instead." + ) + # Filter to only fields FusedConfig accepts + valid_fields = {f.name for f in FusedConfig.__dataclass_fields__.values()} + filtered = {k: v for k, v in self.config_params.items() if k in valid_fields} + return FusedConfig(**filtered) + + +def _split_params(params: Dict) -> Tuple[Dict, Dict]: + """Split a flat params dict into (config_params, hbm_buffer_params). + + FusedConfig fields go into config_params. + Everything else (num_warps, num_stages, k_per_flag, etc.) goes into hbm_buffer_params. + """ + config_params = {} + hbm_params = {} + for k, v in params.items(): + if k in _FUSED_CONFIG_FIELDS: + config_params[k] = v + else: + hbm_params[k] = v + return config_params, hbm_params + + +def _extract_shape_params(shape_data: Dict) -> Tuple[Dict, Dict]: + """Extract config_params and hbm_buffer_params from shape data. + + Supports both the new flat "params" format and the legacy split + "config" + "hbm_buffer_params" format for backward compatibility. + """ + if "params" in shape_data: + return _split_params(shape_data["params"]) + return shape_data.get("config", {}), shape_data.get("hbm_buffer_params", {}) + + +def _extract_default_params(data: Dict) -> Optional[Tuple[Dict, Dict]]: + """Extract default config_params and hbm_buffer_params from file-level defaults. + + Supports both "default_params" (flat) and legacy "default_config" + "default_hbm_buffer_params". + Returns None if no defaults are available. + """ + if "default_params" in data and data["default_params"] is not None: + return _split_params(data["default_params"]) + default_config = data.get("default_config") + if default_config: + return default_config, data.get("default_hbm_buffer_params", {}) + return None + + +def _load_config_file(arch: str, transpose: str, world_size: int) -> Optional[dict]: + """Load and cache a config JSON file. + + Args: + arch: GPU architecture identifier (e.g., "mi300x"). + transpose: Transpose mode (e.g., "NN", "NT", "TN", "TT"). + world_size: Number of ranks. + + Returns: + Parsed JSON dict, or None if file doesn't exist. + """ + cache_key = (arch, transpose, world_size) + if cache_key in _config_cache: + return _config_cache[cache_key] + + config_path = _CONFIGS_DIR / arch / transpose / f"ws{world_size}.json" + if not config_path.exists(): + _config_cache[cache_key] = None + return None + + with open(config_path, "r") as f: + data = json.load(f) + + _config_cache[cache_key] = data + return data + + +def _load_default_config() -> dict: + """Load the global default config.""" + default_path = _CONFIGS_DIR / "default_config.json" + if default_path.exists(): + with open(default_path, "r") as f: + return json.load(f) + return {} + + +def _find_nearest_shape(M: int, N: int, K: int, shapes: dict, tolerance: float = 0.15) -> Optional[str]: + """Find the nearest matching shape in the config database. + + Uses log-space geometric distance to find shapes that are structurally + similar (within `tolerance` ratio per dimension). This avoids falling + back to heuristic when the user's problem is close to a champion shape. + + Args: + M, N, K: Target dimensions. + shapes: Dict of shape_key -> shape_data from the config file. + tolerance: Max fractional distance per dimension (default 15%). + + Returns: + The shape_key of the nearest match, or None if no shape is close enough. + """ + import math + + best_key = None + best_dist = float("inf") + + for shape_key, shape_data in shapes.items(): + sm, sn, sk = shape_data["M"], shape_data["N"], shape_data["K"] + + # Check per-dimension ratio tolerance + if sm == 0 or sn == 0 or sk == 0: + continue + rm = abs(M - sm) / sm + rn = abs(N - sn) / sn + rk = abs(K - sk) / sk + + if rm > tolerance or rn > tolerance or rk > tolerance: + continue + + # Geometric distance in log space + dist = math.sqrt( + math.log(max(M, 1) / max(sm, 1)) ** 2 + + math.log(max(N, 1) / max(sn, 1)) ** 2 + + math.log(max(K, 1) / max(sk, 1)) ** 2 + ) + if dist < best_dist: + best_dist = dist + best_key = shape_key + + return best_key + + +def _apply_heuristic(M: int, N: int, K: int, arch: str = "mi300x") -> Tuple[Dict, Dict]: + """Apply heuristic rules to generate config + HBM buffer params. + + Based on optimization data: + - MI300X: 3,489 measured trials + - MI355X: Optuna TPE + broad sweep + + Args: + M: Rows dimension. + N: Columns dimension. + K: Reduction dimension. + arch: GPU architecture for arch-specific heuristics. + + Returns: + Tuple of (config_params dict, hbm_buffer_params dict). + """ + bk = 64 + num_k_blocks = K // bk + + if arch == "mi355x": + bm = 256 + num_m_tiles = M // bm + gm = 4 if M <= 32768 else 8 + config_params = { + "block_size_m": bm, + "block_size_n": 256, + "block_size_k": bk, + "group_size_m": gm, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": True, + } + kpf = 8 if num_k_blocks <= 512 else 16 + while num_k_blocks % kpf != 0 and kpf > 1: + kpf //= 2 + hbm_params = { + "k_per_flag": kpf, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52, + } + return config_params, hbm_params + + # MI300X heuristics + if M <= 16384: + bm = 128 + else: + bm = 256 + + num_m_tiles = M // bm + + if M <= 8192: + gm = 8 + elif M <= 16384: + gm = 16 + else: + gm = 24 + + config_params = { + "block_size_m": bm, + "block_size_n": 256, + "block_size_k": bk, + "group_size_m": gm, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": True, + } + + if num_k_blocks >= 512: + kpf = 64 + elif num_k_blocks >= 128: + kpf = 16 + elif num_k_blocks >= 64: + kpf = 8 + else: + kpf = 4 + while num_k_blocks % kpf != 0 and kpf > 1: + kpf //= 2 + + if num_m_tiles <= 8: + fs = 4 + elif num_m_tiles <= 32: + fs = 16 + elif num_m_tiles <= 128: + fs = 32 + else: + fs = 52 + + if num_m_tiles >= 512: + nfs = 4 + elif num_m_tiles >= 64: + nfs = 2 + else: + nfs = 1 + + hbm_params = { + "k_per_flag": kpf, + "num_fetch_sms": fs, + "num_fetch_stages": nfs, + "first_stage_fetch_sms": 64, + } + + return config_params, hbm_params + + +def select_ag_mm_config( + M: int, + N: int, + K: int, + world_size: int, + transpose: str = "NN", + arch: str = "auto", +) -> AutoConfigResult: + """Select the best AG+MM config for the given problem. + + Lookup order: + 1. Exact shape match in benchmark/ops/all_gather_matmul/configs/{arch}/{transpose}/ws{world_size}.json + 2. Heuristic-based config from the same file's defaults + 3. Global default from benchmark/ops/all_gather_matmul/configs/default_config.json + + For world sizes where iris is known to lose (ws<8 on MI300X), returns + a disabled result signaling fallback to PyTorch. + + Args: + M: Number of rows (or M_local * world_size for AG+MM). + N: Number of columns. + K: Reduction dimension. + world_size: Number of ranks in the communicator. + transpose: Transpose mode ("NN", "NT", "TN", "TT"). Default "NN". + arch: GPU architecture ("mi300x", etc.) or "auto" to auto-detect. + Default "auto". Set IRIS_GPU_ARCH env var to override. + + Returns: + AutoConfigResult with .enabled indicating whether to use iris, + .to_fused_config() to get the FusedConfig if enabled, and + .hbm_buffer_params with kernel-specific parameters. + + Example: + >>> result = select_ag_mm_config(131072, 16384, 16384, world_size=8) + >>> result.enabled + True + >>> config = result.to_fused_config() + >>> result.hbm_buffer_params + {'k_per_flag': 32, 'num_fetch_sms': 4, 'num_fetch_stages': 64, 'first_stage_fetch_sms': 52} + + >>> result = select_ag_mm_config(4096, 4096, 4096, world_size=2) + >>> result.enabled + False + """ + transpose = transpose.upper() + if arch == "auto": + arch = detect_gpu_arch() + else: + arch = arch.lower() + + # Step 1: Try to load the specific config file + data = _load_config_file(arch, transpose, world_size) + + if data is not None: + # Check if this world_size is enabled + if not data.get("enabled", True): + return AutoConfigResult( + enabled=False, + source=f"Disabled by config: {arch}/{transpose}/ws{world_size}.json — {data.get('reason', 'no reason given')}", + ) + + # Look for exact shape match + shape_key = f"{M}x{N}x{K}" + shapes = data.get("shapes", {}) + if shape_key in shapes: + shape_data = shapes[shape_key] + cfg, hbm = _extract_shape_params(shape_data) + return AutoConfigResult( + enabled=True, + config_params=cfg, + hbm_buffer_params=hbm, + source=f"Exact match: {arch}/{transpose}/ws{world_size}.json [{shape_data.get('label', shape_key)}]", + shape_key=shape_key, + expected_iris_ms=shape_data.get("expected_iris_ms"), + ) + + # No exact match — try nearest champion shape (within 15% per dim) + nearest_key = _find_nearest_shape(M, N, K, shapes) + if nearest_key is not None: + nearest_data = shapes[nearest_key] + cfg, hbm = _extract_shape_params(nearest_data) + return AutoConfigResult( + enabled=True, + config_params=cfg, + hbm_buffer_params=hbm, + source=f"Nearest match: {arch}/{transpose}/ws{world_size}.json [{nearest_data.get('label', nearest_key)}] (target {M}x{N}x{K} ≈ {nearest_key})", + shape_key=nearest_key, + expected_iris_ms=nearest_data.get("expected_iris_ms"), + ) + + # No nearby match — use heuristic + file defaults + defaults = _extract_default_params(data) + if defaults is not None: + file_default_config, file_default_hbm = defaults + heuristic_config, heuristic_hbm = _apply_heuristic(M, N, K, arch=arch) + merged_config = {**file_default_config, **heuristic_config} + merged_hbm = {**file_default_hbm, **heuristic_hbm} + return AutoConfigResult( + enabled=True, + config_params=merged_config, + hbm_buffer_params=merged_hbm, + source=f"Heuristic (no exact shape match in {arch}/{transpose}/ws{world_size}.json)", + ) + + # Step 2: No config file found — check global default + default_data = _load_default_config() + ws_gate = default_data.get("world_size_gate", {}) + min_ws = ws_gate.get("min_world_size", 8) + + if world_size < min_ws: + return AutoConfigResult( + enabled=False, + source=f"world_size={world_size} < min_world_size={min_ws} (global default). {ws_gate.get('reason', '')}", + ) + + # World size OK but no specific config — apply heuristic + heuristic_config, heuristic_hbm = _apply_heuristic(M, N, K, arch=arch) + return AutoConfigResult( + enabled=True, + config_params=heuristic_config, + hbm_buffer_params=heuristic_hbm, + source=f"Heuristic fallback (no config file for {arch}/{transpose}/ws{world_size})", + ) + + +def list_known_shapes( + world_size: int, + transpose: str = "NN", + arch: str = "mi300x", +) -> list: + """List all known shape configurations for a given world_size/transpose/arch. + + Returns: + List of dicts with keys: shape_key, label, M, N, K, expected_iris_ms. + """ + data = _load_config_file(arch, transpose.upper(), world_size) + if data is None or not data.get("enabled", True): + return [] + + result = [] + for shape_key, shape_data in data.get("shapes", {}).items(): + result.append( + { + "shape_key": shape_key, + "label": shape_data.get("label", ""), + "M": shape_data["M"], + "N": shape_data["N"], + "K": shape_data["K"], + "expected_iris_ms": shape_data.get("expected_iris_ms"), + } + ) + + result.sort(key=lambda x: x.get("expected_iris_ms") or float("inf")) + return result + + +def load_regression_sizes() -> List[Dict]: + """Load regression test sizes from the JSON config file. + + Returns: + List of regression size dicts, each with: name, M, N, K, tier, + description, world_sizes, expected, regression_threshold_pct. + """ + reg_path = _CONFIGS_DIR / "regression_sizes.json" + if not reg_path.exists(): + return [] + with open(reg_path, "r") as f: + data = json.load(f) + return data.get("sizes", []) + + +def clear_config_cache(): + """Clear the in-memory config cache. Useful after modifying config files.""" + _config_cache.clear() diff --git a/benchmark/ops/all_gather_matmul/configs/default_config.json b/benchmark/ops/all_gather_matmul/configs/default_config.json new file mode 100644 index 000000000..ff96ac1f0 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/default_config.json @@ -0,0 +1,27 @@ +{ + "_meta": { + "description": "Global default fallback config for AG+MM operations. Disables iris AG+MM for ws<8 (fallback to PyTorch).", + "source": "benchmarking on MI300X (gfx942), 3489 measured trials", + "date": "2026-04-13" + }, + "world_size_gate": { + "min_world_size": 8, + "reason": "ws=2 best 0.89x, ws=4 best 0.86x vs PyTorch. Only ws>=8 is production-ready." + }, + "config": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": true + }, + "hbm_buffer_params": { + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws2.json new file mode 100644 index 000000000..f5f0fd87e --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws2.json @@ -0,0 +1,159 @@ +{ + "_meta": { + "description": "AG+MM ws=2 on MI300X \u2014 DISABLED (loses vs PyTorch on all shapes)", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=2 AG transfers from 1 peer only. GEMM dominates latency. Fetch SM overhead exceeds overlap benefit. LDS overflow forces ns=1, imposing 15-35% penalty.", + "shapes": { + "8192x8192x262144": { + "label": "g5", + "M": 8192, + "N": 8192, + "K": 262144, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 186.062 + }, + "16384x16384x131072": { + "label": "g1", + "M": 16384, + "N": 16384, + "K": 131072, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 16, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 8, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 153.042 + }, + "4096x14336x4096": { + "label": "mixtral_gate", + "M": 4096, + "N": 14336, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 2.334 + }, + "4096x11008x4096": { + "label": "llama7b_gate", + "M": 4096, + "N": 11008, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 1.784 + }, + "4096x4096x4096": { + "label": "pow2_4k", + "M": 4096, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 1.109 + }, + "5120x13824x5120": { + "label": "llama13b_gate", + "M": 5120, + "N": 13824, + "K": 5120, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 4.144 + }, + "4096x4096x11008": { + "label": "llama7b_down", + "M": 4096, + "N": 4096, + "K": 11008, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 2.477 + } + }, + "default_params": null +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws4.json new file mode 100644 index 000000000..30b7a6bef --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws4.json @@ -0,0 +1,201 @@ +{ + "_meta": { + "description": "AG+MM ws=4 on MI300X \u2014 DISABLED (loses vs PyTorch on all shapes)", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=4 loses on all tested shapes. K=4096 shapes crash at ns=2 due to LDS overflow (65540>65536). ns=1 workaround constrains pipelining depth below break-even.", + "shapes": { + "262144x8192x8192": { + "label": "g6", + "M": 262144, + "N": 8192, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 64, + "num_fetch_sms": 52, + "num_fetch_stages": 4, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 161.027 + }, + "8192x8192x262144": { + "label": "g5", + "M": 8192, + "N": 8192, + "K": 262144, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 167.944 + }, + "131072x16384x16384": { + "label": "g2", + "M": 131072, + "N": 16384, + "K": 16384, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 64, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 209.556 + }, + "16384x16384x131072": { + "label": "g1", + "M": 16384, + "N": 16384, + "K": 131072, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 16, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 16, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 239.757 + }, + "4096x14336x4096": { + "label": "mixtral_gate", + "M": 4096, + "N": 14336, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 2.192 + }, + "4096x11008x4096": { + "label": "llama7b_gate", + "M": 4096, + "N": 11008, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 2.163 + }, + "4096x4096x4096": { + "label": "pow2_4k", + "M": 4096, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 1.494 + }, + "5120x13824x5120": { + "label": "llama13b_gate", + "M": 5120, + "N": 13824, + "K": 5120, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 3.257 + }, + "4096x4096x11008": { + "label": "llama7b_down", + "M": 4096, + "N": 4096, + "K": 11008, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 2.578 + } + }, + "default_params": null +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws8.json new file mode 100644 index 000000000..2c518f1df --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws8.json @@ -0,0 +1,290 @@ +{ + "_meta": { + "description": "Champion configs for HBM buffer AG+MM ws=8 on MI300X (gfx942)", + "source": "sweep (3489 trials), optimize-loop iter3", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM" + }, + "enabled": true, + "shapes": { + "262144x8192x8192": { + "label": "g6", + "description": "Llama-70B MLP hidden x hidden \u2014 M-dominant", + "M": 262144, + "N": 8192, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 64, + "num_fetch_sms": 52, + "num_fetch_stages": 4, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 139.069 + }, + "131072x16384x16384": { + "label": "g2", + "description": "Llama MLP variant \u2014 balanced large", + "M": 131072, + "N": 16384, + "K": 16384, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 64, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 167.345 + }, + "147456x28672x4096": { + "label": "g14", + "description": "Llama-70B up-projection medium batch", + "M": 147456, + "N": 28672, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 59, + "num_fetch_stages": 36, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 74.244 + }, + "229376x28672x4096": { + "label": "g16", + "description": "Llama-70B up-projection mid batch", + "M": 229376, + "N": 28672, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 56, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 114.265 + }, + "327680x28672x4096": { + "label": "g15", + "description": "Llama-70B up-projection large batch", + "M": 327680, + "N": 28672, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 32, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 162.136 + }, + "8192x8192x262144": { + "label": "g5", + "description": "K-dominant square", + "M": 8192, + "N": 8192, + "K": 262144, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 217.725 + }, + "16384x16384x131072": { + "label": "g1", + "description": "K-dominant large", + "M": 16384, + "N": 16384, + "K": 131072, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 16, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 16, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 223.748 + }, + "196608x18432x16384": { + "label": "g9", + "description": "Large balanced shape", + "M": 196608, + "N": 18432, + "K": 16384, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 32 + }, + "expected_iris_ms": 266.608 + }, + "262144x28672x8192": { + "label": "g8", + "description": "Large wide shape", + "M": 262144, + "N": 28672, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 128, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 32 + }, + "expected_iris_ms": 278.546 + }, + "4096x14336x4096": { + "label": "mixtral_gate", + "description": "Mixtral gate projection", + "M": 4096, + "N": 14336, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "expected_iris_ms": 1.933 + }, + "4096x11008x4096": { + "label": "llama7b_gate", + "description": "Llama-7B gate projection", + "M": 4096, + "N": 11008, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 128, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "expected_iris_ms": 1.946 + }, + "4096x4096x4096": { + "label": "pow2_4k", + "description": "Small power-of-2 square shape", + "M": 4096, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "expected_iris_ms": 1.512 + } + }, + "default_params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws2.json new file mode 100644 index 000000000..897be3f2c --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws2.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=2 NT transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=2 loses vs PyTorch on all tested shapes. LDS overflow forces ns=1, imposing 15-35% perf penalty." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws4.json new file mode 100644 index 000000000..cc1f9d297 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws4.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=4 NT transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=4 loses vs PyTorch on all tested shapes. Best measured: 0.856x. LDS overflow at K=4096." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws8.json new file mode 100644 index 000000000..873cb76e1 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws8.json @@ -0,0 +1,31 @@ +{ + "_meta": { + "description": "AG+MM ws=8 NT transpose on MI300X — heuristic defaults (no per-shape benchmarks yet)", + "source": "heuristic extrapolation from NN transpose champion data", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13", + "data_tag": "heuristic", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM, B transposed (K×N → N×K)" + }, + "enabled": true, + "shapes": {}, + "default_config": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": true + }, + "default_hbm_buffer_params": { + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "heuristic_rules": { + "note": "Uses same heuristic as NN transpose. Shape-specific tuning pending." + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws2.json new file mode 100644 index 000000000..2fe67e154 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws2.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=2 TN transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=2 loses vs PyTorch on all tested shapes. LDS overflow forces ns=1, imposing 15-35% perf penalty." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws4.json new file mode 100644 index 000000000..c8977d5f0 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws4.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=4 TN transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=4 loses vs PyTorch on all tested shapes. Best measured: 0.856x. LDS overflow at K=4096." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws8.json new file mode 100644 index 000000000..df9a5b3f9 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws8.json @@ -0,0 +1,31 @@ +{ + "_meta": { + "description": "AG+MM ws=8 TN transpose on MI300X — heuristic defaults (no per-shape benchmarks yet)", + "source": "heuristic extrapolation from NN transpose champion data", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13", + "data_tag": "heuristic", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM, A transposed (M×K → K×M)" + }, + "enabled": true, + "shapes": {}, + "default_config": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": true + }, + "default_hbm_buffer_params": { + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "heuristic_rules": { + "note": "Uses same heuristic as NN transpose. Shape-specific tuning pending." + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws2.json new file mode 100644 index 000000000..cc2c2497c --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws2.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=2 TT transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=2 loses vs PyTorch on all tested shapes. LDS overflow forces ns=1, imposing 15-35% perf penalty." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws4.json new file mode 100644 index 000000000..55ee5f423 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws4.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=4 TT transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=4 loses vs PyTorch on all tested shapes. Best measured: 0.856x. LDS overflow at K=4096." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws8.json new file mode 100644 index 000000000..a184b41a4 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws8.json @@ -0,0 +1,31 @@ +{ + "_meta": { + "description": "AG+MM ws=8 TT transpose on MI300X — heuristic defaults (no per-shape benchmarks yet)", + "source": "heuristic extrapolation from NN transpose champion data", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13", + "data_tag": "heuristic", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM, both A and B transposed" + }, + "enabled": true, + "shapes": {}, + "default_config": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": true + }, + "default_hbm_buffer_params": { + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "heuristic_rules": { + "note": "Uses same heuristic as NN transpose. Shape-specific tuning pending." + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws2.json new file mode 100644 index 000000000..9c07592f1 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws2.json @@ -0,0 +1,17 @@ +{ + "_meta": { + "description": "AG+MM ws=2 on MI355X (gfx950) — defaults only, needs tuning", + "gpu": "AMD Instinct MI355X (gfx950)", + "date": "2026-04-15", + "validated": "unvalidated — no shape-specific tuning yet" + }, + "enabled": true, + "shapes": {}, + "default_params": { + "block_size_m": 256, "block_size_n": 256, "block_size_k": 64, + "group_size_m": 4, "num_xcds": 8, "allow_tf32": true, + "num_warps": 8, "num_stages": 2, + "k_per_flag": 16, "num_fetch_sms": 4, + "num_fetch_stages": 1, "first_stage_fetch_sms": 52 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws4.json new file mode 100644 index 000000000..3d64610a2 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws4.json @@ -0,0 +1,17 @@ +{ + "_meta": { + "description": "AG+MM ws=4 on MI355X (gfx950) — defaults only, needs tuning", + "gpu": "AMD Instinct MI355X (gfx950)", + "date": "2026-04-15", + "validated": "unvalidated — no shape-specific tuning yet" + }, + "enabled": true, + "shapes": {}, + "default_params": { + "block_size_m": 256, "block_size_n": 256, "block_size_k": 64, + "group_size_m": 4, "num_xcds": 8, "allow_tf32": true, + "num_warps": 8, "num_stages": 2, + "k_per_flag": 16, "num_fetch_sms": 4, + "num_fetch_stages": 1, "first_stage_fetch_sms": 52 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws8.json new file mode 100644 index 000000000..17fa7051a --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws8.json @@ -0,0 +1,246 @@ +{ + "_meta": { + "description": "Champion configs for HBM buffer AG+MM ws=8 on MI355X (gfx950)", + "source": "Optuna TPE + broad sweep", + "gpu": "AMD Instinct MI355X (gfx950)", + "date": "2026-04-15", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM" + }, + "enabled": true, + "shapes": { + "262144x8192x8192": { + "label": "g6", + "description": "Llama-70B MLP hidden x hidden \u2014 M-dominant", + "M": 262144, + "N": 8192, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 2, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 37.558 + }, + "131072x4096x4096": { + "label": "a3", + "description": "Output proj, 128K seq, 4K hidden", + "M": 131072, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 16, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 6.304 + }, + "65536x8192x28672": { + "label": "f4", + "description": "Llama 70B down, 64K seq", + "M": 65536, + "N": 8192, + "K": 28672, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 35.959 + }, + "32768x8192x8192": { + "label": "l1", + "description": "Training batch 32K", + "M": 32768, + "N": 8192, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 2, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 5.445 + }, + "65536x8192x4096": { + "label": "a2", + "description": "QKV proj, 64K seq, 8K hidden", + "M": 65536, + "N": 8192, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 2, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 2, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 5.775 + }, + "32768x4096x14336": { + "label": "f2", + "description": "Llama 7B down, 32K seq", + "M": 32768, + "N": 4096, + "K": 14336, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 5.961 + }, + "32768x4096x4096": { + "label": "a1", + "description": "QKV proj, 32K seq, 4K hidden", + "M": 32768, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 4, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 256 + }, + "expected_iris_ms": 1.864 + }, + "65536x28672x8192": { + "label": "f3", + "description": "Llama 70B gate/up, 64K seq", + "M": 65536, + "N": 28672, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 32.303 + }, + "16384x28672x4096": { + "label": "m2", + "description": "Large FFN up (Llama 70B-like)", + "M": 16384, + "N": 28672, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 2, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 256 + }, + "expected_iris_ms": 4.658 + }, + "32768x14336x4096": { + "label": "f1", + "description": "Llama 7B gate/up, 32K seq", + "M": 32768, + "N": 14336, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 4.86 + } + }, + "default_params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 4, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/regression_sizes.json b/benchmark/ops/all_gather_matmul/configs/regression_sizes.json new file mode 100644 index 000000000..40a497b0a --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/regression_sizes.json @@ -0,0 +1,101 @@ +{ + "_meta": { + "description": "Regression test sizes for HBM buffer AG+MM kernel across ws=2/4/8", + "source": "28 measured shapes (12 ws=8, 9 ws=4, 7 ws=2) from 3489 trials", + "gpu_target": "MI300X (gfx942)", + "date": "2026-04-13", + "usage": "from iris.ops import load_regression_sizes" + }, + "sizes": [ + { + "name": "g2_ws8", + "label": "g2", + "description": "Llama MLP variant — balanced large (highest speedup)", + "M": 131072, "N": 16384, "K": 16384, + "tier": "champion", + "world_sizes": [8], + "expected": {"ws8_speedup": 1.343, "ws8_tflops": 420.5}, + "regression_threshold_pct": 10 + }, + { + "name": "g15_ws8", + "label": "g15", + "description": "Llama-70B up-projection large batch — highest TFLOPS", + "M": 327680, "N": 28672, "K": 4096, + "tier": "champion", + "world_sizes": [8], + "expected": {"ws8_speedup": 1.284, "ws8_tflops": 474.7}, + "regression_threshold_pct": 10 + }, + { + "name": "g14_ws8", + "label": "g14", + "description": "Llama-70B up-projection medium batch", + "M": 147456, "N": 28672, "K": 4096, + "tier": "champion", + "world_sizes": [8], + "expected": {"ws8_speedup": 1.288, "ws8_tflops": 466.5}, + "regression_threshold_pct": 10 + }, + { + "name": "g16_ws8", + "label": "g16", + "description": "Llama-70B up-projection mid batch", + "M": 229376, "N": 28672, "K": 4096, + "tier": "champion", + "world_sizes": [8], + "expected": {"ws8_speedup": 1.277, "ws8_tflops": 471.5}, + "regression_threshold_pct": 10 + }, + { + "name": "g5_ws8", + "label": "g5", + "description": "K-dominant square — M-small, needs bm=128", + "M": 8192, "N": 8192, "K": 262144, + "tier": "champion", + "world_sizes": [8], + "expected": {"ws8_speedup": 1.224, "ws8_tflops": 161.6}, + "regression_threshold_pct": 10 + }, + { + "name": "g6_ws8", + "label": "g6", + "description": "Llama-70B MLP hidden x hidden — M-dominant", + "M": 262144, "N": 8192, "K": 8192, + "tier": "champion", + "world_sizes": [8], + "expected": {"ws8_speedup": 1.200, "ws8_tflops": 253.0}, + "regression_threshold_pct": 10 + }, + { + "name": "g1_ws8", + "label": "g1", + "description": "K-dominant large — parity shape", + "M": 16384, "N": 16384, "K": 131072, + "tier": "champion", + "world_sizes": [8], + "expected": {"ws8_speedup": 1.136, "ws8_tflops": 314.5}, + "regression_threshold_pct": 10 + }, + { + "name": "g5_ws2_disabled", + "label": "g5", + "description": "Best ws=2 shape — still loses vs PyTorch (0.887x). Verifies fallback.", + "M": 8192, "N": 8192, "K": 262144, + "tier": "disabled", + "world_sizes": [2], + "expected": {"ws2_speedup": 0.887, "ws2_disabled": true}, + "regression_threshold_pct": null + }, + { + "name": "g6_ws4_disabled", + "label": "g6", + "description": "Best ws=4 shape — still loses vs PyTorch (0.856x). Verifies fallback.", + "M": 262144, "N": 8192, "K": 8192, + "tier": "disabled", + "world_sizes": [4], + "expected": {"ws4_speedup": 0.856, "ws4_disabled": true}, + "regression_threshold_pct": null + } + ] +} diff --git a/benchmark/ops/all_gather_matmul/derive_params_copy_engine.py b/benchmark/ops/all_gather_matmul/derive_params_copy_engine.py new file mode 100755 index 000000000..7f513bacb --- /dev/null +++ b/benchmark/ops/all_gather_matmul/derive_params_copy_engine.py @@ -0,0 +1,596 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Parameter derivation for all_gather_matmul_copy_engine. + +This variant uses SDMA (copy engine) for data movement instead of fetch workgroups: + - Host (or device SDMA WGs) orchestrate remote tile transfers + - GEMM workgroups poll flags and compute, no fetch WGs + - Batch-based orchestration with coarse per-(batch, K-flag-group) synchronization + +Key differences from HBM buffer variant (derive_params.py): + - No fetcher workgroups (GEMM only) + - Two modes: host-initiated (default) vs device-initiated SDMA + - Batch-based transfer scheduling instead of wave-based pipeline + - Coarser synchronization granularity + +Usage: + python derive_params_copy_engine.py -m 16384 -n 2048 -k 16384 --world_size 8 + python derive_params_copy_engine.py -m 16384 -n 2048 -k 16384 -v --mode device +""" + +import argparse +import math + +# ── MI300X hardware defaults ────────────────────────────────────────────── +DEFAULT_NUM_CUS = 304 +DEFAULT_PEAK_TFLOPS_FP16 = 1300.0 +DEFAULT_HBM_BW_GBPS = 5300.0 +DEFAULT_L2_SIZE_BYTES = 256 * 1024 * 1024 +DEFAULT_NUM_XCDS = 8 +DEFAULT_WORLD_SIZE = 8 + +# Calibrated from MI300X trace data: the ratio of measured wall time to +# the CU-work-queue lower bound. Captures WG dispatch overhead, +# cross-XCD coherence latency, and pipeline bubble effects. +DEFAULT_SCHEDULING_FACTOR = 4.5 + + +# SDMA/copy engine specific latencies (calibrated from MI300X traces) +DEFAULT_SDMA_LATENCY_US = 2.0 # SDMA packet submission latency +DEFAULT_HOST_POST_OVERHEAD_US = 0.5 # Host API overhead per transfer +DEFAULT_DEVICE_POST_OVERHEAD_US = 1.0 # Device SDMA WG posting overhead +DEFAULT_FLAG_POLL_LATENCY_US = 0.1 # Flag detection latency + +# Performance parameters (to be calibrated) +DEFAULT_TFLOPS_ACHIEVED_RATIO = 0.85 # Achieved vs peak TFLOPS +DEFAULT_XGMI_BW_GBPS = 896.0 # Total XGMI bandwidth + + +# ── Block size heuristics (matches HBM buffer logic) ───────────────────────── + + +def _choose_block_sizes(M, N, K, K_local): + """Heuristic tile-size selection for MI300X MFMA. + + Matches the logic from derive_params.py to ensure shared memory limits + are respected and block sizes are consistent across variants. + """ + bk = 64 + + bm = 256 if M >= 8192 else 128 + while M % bm != 0 and bm > 64: + bm //= 2 + + if N >= 512: + bn = 256 + elif N >= 256: + bn = 256 if N % 256 == 0 else 128 + else: + bn = 128 + while N % bn != 0 and bn > 32: + bn //= 2 + + while K % bk != 0 and bk > 16: + bk //= 2 + while K_local % bk != 0 and bk > 16: + bk //= 2 + + nw = 8 if bm * bn >= 256 * 256 else 4 + return bm, bn, bk, nw + + +def _choose_k_per_flag(num_k_blocks, num_k_blocks_local, target_groups=8): + """Pick k_per_flag so that flag groups align to rank boundaries when possible. + + Identical logic to HBM buffer variant - rank alignment is important for + efficient all-gather patterns regardless of the transfer mechanism. + """ + if num_k_blocks % num_k_blocks_local == 0: + candidate = num_k_blocks_local + groups = num_k_blocks // candidate + if groups >= 4: + return candidate + + kpf = max(1, num_k_blocks // target_groups) + while num_k_blocks % kpf != 0 and kpf > 1: + kpf -= 1 + return kpf + + +def _choose_m_tiles_per_batch(num_m_tiles, num_n_tiles, tile_gemm_us, tile_transfer_us): + """Choose m_tiles_per_batch to minimize exposed communication time. + + Wave-based model: + - Wave 0: Transfer for first batch is fully exposed + - Wave 1+: If gemm_time >= transfer_time per batch, transfer is hidden + + Args: + num_m_tiles: Total M tiles + num_n_tiles: Total N tiles + tile_gemm_us: Per-tile GEMM time (microseconds) + tile_transfer_us: Per-M-tile transfer time (microseconds) + + Returns: + Optimal m_tiles_per_batch + """ + best_batch_size = num_m_tiles # Default: all tiles in one batch + best_time_us = float("inf") + + # Try divisors of num_m_tiles + candidates = [] + for d in range(1, num_m_tiles + 1): + if num_m_tiles % d == 0: + candidates.append(num_m_tiles // d) + + for m_batch in candidates: + num_batches = num_m_tiles // m_batch + + # Per-batch transfer time (for m_batch tiles across all K) + transfer_per_batch_us = m_batch * tile_transfer_us + + # Per-batch GEMM time (m_batch × num_n_tiles tiles) + gemm_per_batch_us = m_batch * num_n_tiles * tile_gemm_us + + # Wave 0: First batch - transfer exposed + wave0_us = transfer_per_batch_us + gemm_per_batch_us + + # Wave 1+: Remaining batches + if num_batches > 1: + if gemm_per_batch_us >= transfer_per_batch_us: + # Transfer hidden by GEMM + wave_rest_us = (num_batches - 1) * gemm_per_batch_us + else: + # Transfer exposed + wave_rest_us = (num_batches - 1) * (transfer_per_batch_us + gemm_per_batch_us) + else: + wave_rest_us = 0 + + total_time_us = wave0_us + wave_rest_us + + if total_time_us < best_time_us: + best_time_us = total_time_us + best_batch_size = m_batch + + return best_batch_size + + +# ── Roofline model (reused from HBM buffer) ────────────────────────────────── + + +def _tile_roofline(bm, bn, bk, M, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size): + """Compute achievable per-CU TFLOPS from tile arithmetic intensity. + + Identical to HBM buffer variant - roofline analysis is architecture-independent. + """ + tile_flops = 2 * bm * bn * bk + a_bytes = bm * bk * dtype_bytes + b_bytes = bk * bn * dtype_bytes + + b_total = K * N * dtype_bytes + staged_a_total = M * K * dtype_bytes + b_in_l2 = (staged_a_total <= l2_size) and (b_total <= l2_size) + + hbm_bytes = a_bytes + (0 if b_in_l2 else b_bytes) + intensity = tile_flops / max(hbm_bytes, 1) + + ridge = peak_tflops * 1e3 / hbm_bw_gbps + if intensity >= ridge: + roofline = peak_tflops + else: + roofline = hbm_bw_gbps * intensity / 1e3 + + return roofline, intensity, ridge, b_in_l2 + + +# ── Per-WG execution time models ──────────────────────────────────────── + + +def _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups, roofline_tflops, num_cus): + """Estimate per-WG GEMM execution time in microseconds. + + Identical to HBM buffer variant (lines 192-213 of derive_params.py). + """ + total_flops = 2 * bm * bn * K + per_cu_tflops = roofline_tflops / num_cus + + # Roofline-ideal per-WG time + ideal_us = total_flops / (per_cu_tflops * 1e6) + + # Single-occupancy overhead: imperfect latency hiding, instruction + # scheduling gaps, cross-XCD coherence on staged_a reads. + # Calibrated from MI300X traces: actual/ideal ≈ 1.2-1.3. + occupancy_factor = 1.25 if bm * bn >= 256 * 256 else 1.10 + + # Flag polling: acquire-semantics atomic per flag group + flag_us = num_flag_groups * 2.5 + + return ideal_us * occupancy_factor + flag_us + + +# ── Per-SDMA-WG timing (for device-initiated mode) ─────────────────────────── + + +def _sdma_wg_time_us(num_transfers_per_wg, bytes_per_transfer, sdma_latency_us, device_overhead_us, xgmi_bw_gbps): + """Estimate per-SDMA-WG execution time in microseconds. + + Analogous to _fetch_wg_time_us in HBM buffer variant. Each SDMA WG posts + transfers for one remote rank. + + Args: + num_transfers_per_wg: Number of transfers this WG must post + bytes_per_transfer: Size of each transfer in bytes + sdma_latency_us: SDMA packet submission latency + device_overhead_us: Device-side posting overhead per transfer + xgmi_bw_gbps: Available XGMI bandwidth (shared across all SDMA WGs) + + Returns: + Estimated execution time in microseconds + """ + # Posting overhead: WG serially posts all its transfers + post_time_us = num_transfers_per_wg * (device_overhead_us + sdma_latency_us) + + # Bandwidth time: SDMA engine executes transfers + # (Note: bandwidth is shared across all SDMA WGs; this is per-WG share) + total_bytes = num_transfers_per_wg * bytes_per_transfer + bandwidth_time_us = total_bytes / (xgmi_bw_gbps * 1e3) + + # Conservative: assume posting and transfer happen sequentially + # (In reality, some overlap possible depending on SDMA engine scheduling) + return post_time_us + bandwidth_time_us + + +# ── Kernel time estimation (reused from HBM buffer) ────────────────────────── + + +def _estimate_kernel_time(total_gemm_wgs, gemm_wg_us, total_sdma_wgs, sdma_wg_us, num_cus, scheduling_factor): + """Estimate kernel wall-clock time from the CU work queue model. + + Identical to HBM buffer variant. total_CU_work / num_CUs gives the ideal + (work-conserving) lower bound. The scheduling_factor captures GPU dispatch + overhead, cross-XCD coherence, and pipeline bubble effects. + """ + total_cu_work_us = total_gemm_wgs * gemm_wg_us + total_sdma_wgs * sdma_wg_us + + ideal_ms = total_cu_work_us / num_cus / 1e3 + estimated_ms = ideal_ms * scheduling_factor + return estimated_ms, ideal_ms + + +def derive( + M, + N, + K, + world_size, + link_bw, + num_cus, + peak_tflops, + hbm_bw_gbps, + l2_size, + scheduling_factor, + dtype_bytes, + device_initiated=None, +): + """Derive optimal parameters for all_gather_matmul_copy_engine. + + Matches the interface of derive_params.py (HBM buffer variant) but returns + copy-engine-specific parameters. + + Args: + M, N, K: GEMM dimensions (K is total across all ranks) + world_size: Number of GPUs + link_bw: XGMI bandwidth (not used - using hardcoded XGMI_BW_GBPS) + num_cus: Number of compute units + peak_tflops: Peak TFLOPS (not used - using hardcoded value) + hbm_bw_gbps: HBM bandwidth (not used - using hardcoded value) + l2_size: L2 cache size + scheduling_factor: Not used in copy engine (no work queue model) + dtype_bytes: 2 for fp16/bf16, 4 for fp32 + device_initiated: None (auto-select), True (force device mode), False (force host mode) + + Returns: + dict with kernel parameters and performance estimates + """ + K_local = K // world_size + + # 1. Tile sizes (matches HBM buffer logic) + bm, bn, bk, nw = _choose_block_sizes(M, N, K, K_local) + gm = 4 + num_m_tiles = M // bm + num_tiles_n = math.ceil(N / bn) + num_k_blocks = K // bk + num_k_blocks_local = K_local // bk + + # 2. Per-tile roofline (reused from HBM buffer) + roofline_tflops, intensity, ridge, b_in_l2 = _tile_roofline( + bm, bn, bk, M, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size + ) + + # 3. Communication model (link-limited) + total_remote_bytes = M * K_local * (world_size - 1) * dtype_bytes + total_link_bw = link_bw * (world_size - 1) + comm_time_ms = total_remote_bytes / (total_link_bw * 1e9) * 1e3 + + # 4. Compute model (roofline-limited) + total_flops = 2 * M * N * K + compute_time_ms = total_flops / (roofline_tflops * 1e12) * 1e3 + + ratio = comm_time_ms / compute_time_ms if compute_time_ms > 0 else 999 + + # 5. k_per_flag selection (matches HBM buffer) + kpf = _choose_k_per_flag(num_k_blocks, num_k_blocks_local) + num_flag_groups_k = num_k_blocks // kpf + + # 6. Per-tile GEMM time + gemm_wg_us_val = _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups_k, roofline_tflops, num_cus) + + # 7. Per-M-tile transfer time + # Each M-tile needs K_local data from (world_size-1) ranks + bytes_per_m_tile = bm * K_local * dtype_bytes * (world_size - 1) + transfer_bw_gbps = DEFAULT_XGMI_BW_GBPS / math.sqrt(world_size) # Congestion model + tile_transfer_us = bytes_per_m_tile / (transfer_bw_gbps * 1e3) + + # 8. m_tiles_per_batch selection (minimize exposed communication) + m_tiles_per_batch = _choose_m_tiles_per_batch(num_m_tiles, num_tiles_n, gemm_wg_us_val, tile_transfer_us) + + # Sanity check + assert num_m_tiles % m_tiles_per_batch == 0, ( + f"m_tiles_per_batch={m_tiles_per_batch} must divide num_m_tiles={num_m_tiles}" + ) + num_batches = num_m_tiles // m_tiles_per_batch + + # 9. device_initiated default + if device_initiated is None: + device_initiated = True + + # 10. Grid geometry + total_gemm_wgs = num_m_tiles * num_tiles_n + num_sdma_wgs = world_size - 1 if device_initiated else 0 + + # 11. Per-SDMA-WG time (for device-initiated mode) + if device_initiated: + num_transfers = num_batches * num_flag_groups_k * (world_size - 1) + transfers_per_sdma_wg = num_transfers // num_sdma_wgs + batch_m = m_tiles_per_batch * bm + flag_group_k = kpf * bk + bytes_per_transfer = batch_m * flag_group_k * dtype_bytes + sdma_wg_us = _sdma_wg_time_us( + transfers_per_sdma_wg, + bytes_per_transfer, + DEFAULT_SDMA_LATENCY_US, + DEFAULT_DEVICE_POST_OVERHEAD_US, + DEFAULT_XGMI_BW_GBPS, + ) + else: + sdma_wg_us = 0 + + # 12. Kernel time estimate (CU-work model) + if device_initiated: + est_kernel_ms, est_ideal_ms = _estimate_kernel_time( + total_gemm_wgs, gemm_wg_us_val, num_sdma_wgs, sdma_wg_us, num_cus, scheduling_factor + ) + else: + # Host-initiated: no device WG orchestration + est_kernel_ms = compute_time_ms + est_ideal_ms = compute_time_ms + + # 13. Pipeline time (from wave model) + transfer_per_batch_us = m_tiles_per_batch * tile_transfer_us + gemm_per_batch_us = m_tiles_per_batch * num_tiles_n * gemm_wg_us_val + wave0_us = transfer_per_batch_us + gemm_per_batch_us + if num_batches > 1: + if gemm_per_batch_us >= transfer_per_batch_us: + wave_rest_us = (num_batches - 1) * gemm_per_batch_us + else: + wave_rest_us = (num_batches - 1) * (transfer_per_batch_us + gemm_per_batch_us) + else: + wave_rest_us = 0 + pipeline_time_ms = (wave0_us + wave_rest_us) / 1000 + overlap_efficiency = compute_time_ms / pipeline_time_ms if pipeline_time_ms > 0 else 0 + + # 13. Staged A size + + # 14. Standalone GEMM estimate (rocBLAS-class efficiency for comparison) + standalone_gemm_eff = 0.30 + standalone_tflops = roofline_tflops * standalone_gemm_eff + standalone_gemm_ms = total_flops / (standalone_tflops * 1e12) * 1e3 + pytorch_est_ms = comm_time_ms + standalone_gemm_ms + + staged_a_gb = M * K * dtype_bytes / (1024**3) + + return dict( + block_size_m=bm, + block_size_n=bn, + block_size_k=bk, + group_size_m=gm, + num_warps=nw, + k_per_flag=kpf, + m_tiles_per_batch=m_tiles_per_batch, + device_initiated=device_initiated, + # derived + K_local=K_local, + num_m_tiles=num_m_tiles, + num_tiles_n=num_tiles_n, + num_k_blocks=num_k_blocks, + num_flag_groups_k=num_flag_groups_k, + num_batches=num_batches, + num_k_flag_groups=num_flag_groups_k, + # roofline + roofline_tflops=roofline_tflops, + tile_intensity=intensity, + ridge_point=ridge, + b_in_l2=b_in_l2, + # per-WG timing + gemm_wg_us=gemm_wg_us_val, + sdma_wg_us=sdma_wg_us, + # grid + total_gemm_wgs=total_gemm_wgs, + num_sdma_wgs=num_sdma_wgs, + # estimates + total_remote_bytes=total_remote_bytes, + total_link_bw=total_link_bw, + comm_time_ms=comm_time_ms, + total_flops=total_flops, + compute_time_ms=compute_time_ms, + ratio=ratio, + est_kernel_ms=est_kernel_ms, + est_ideal_ms=est_ideal_ms, + standalone_gemm_ms=standalone_gemm_ms, + pytorch_est_ms=pytorch_est_ms, + staged_a_gb=staged_a_gb, + scheduling_factor=scheduling_factor, + # copy engine performance estimates + pipeline_time_ms=pipeline_time_ms, + overlap_efficiency=overlap_efficiency, + transfer_time_per_batch_us=transfer_per_batch_us, + gemm_time_per_batch_us=gemm_per_batch_us, + ) + + +# ── Formatting helpers ─────────────────────────────────────────────────── + + +def _fmt_bytes(n): + if n >= 1024**3: + return f"{n / 1024**3:.2f} GB" + if n >= 1024**2: + return f"{n / 1024**2:.1f} MB" + return f"{n / 1024:.1f} KB" + + +def _fmt_flops(n): + if n >= 1e15: + return f"{n / 1e15:.2f} PFLOPs" + return f"{n / 1e12:.2f} TFLOPs" + + +def _fmt_tflops(t): + return f"{t:.0f} TFLOPS" + + +# ── Analysis output ────────────────────────────────────────────────────── + + +def print_analysis(p, M, N, K_local, world_size): + """Print detailed performance analysis (matches HBM buffer format).""" + K = K_local * world_size + + print("\n" + "=" * 80) + print("ALL_GATHER_MATMUL_COPY_ENGINE: PERFORMANCE MODEL & DERIVED PARAMETERS") + print("=" * 80) + + print("\nProblem Size:") + print(f" M = {M}, N = {N}, K_local = {K_local}, K_total = {K}, world_size = {world_size}") + print(f" Total GEMM: ({M}, {K}) @ ({K}, {N}) = ({M}, {N})") + print(f" Staged A buffer: {p['staged_a_gb']:.2f} GB") + + print("\nBlock Sizes:") + print(f" BLOCK_M = {p['block_size_m']}, BLOCK_N = {p['block_size_n']}, BLOCK_K = {p['block_size_k']}") + print(f" num_warps = {p['num_warps']}") + + print("\nGEMM Analysis (Roofline):") + print(f" Arithmetic Intensity: {p['tile_intensity']:.2f} FLOPs/byte") + print(f" Ridge Point: {p['ridge_point']:.2f} FLOPs/byte") + print(f" B in L2: {p['b_in_l2']}") + print(f" Achieved TFLOPS: {p['roofline_tflops']:.0f}") + print(f" Compute Time: {p['compute_time_ms']:.2f} ms") + + print("\nCommunication Analysis:") + print(f" Total Recv: {p['total_remote_bytes'] / 1e9:.2f} GB per rank") + print(f" Comm Time: {p['comm_time_ms']:.2f} ms") + + print("\nCopy Engine Parameters:") + print(f" k_per_flag: {p['k_per_flag']}") + print(f" m_tiles_per_batch: {p['m_tiles_per_batch']}") + print(f" num_batches: {p['num_batches']}") + print(f" device_initiated: {p['device_initiated']}") + + print("\nPipeline Performance:") + print(f" Pipeline Time: {p['pipeline_time_ms']:.2f} ms") + print(f" Overlap Efficiency: {p['overlap_efficiency'] * 100:.1f}%") + + # Comparison + baseline = p["comm_time_ms"] + p["compute_time_ms"] + speedup = baseline / p["pipeline_time_ms"] if p["pipeline_time_ms"] > 0 else 0 + print(f"\nSpeedup vs Sequential (AllGather→GEMM): {speedup:.2f}x") + print(f" Sequential: {baseline:.2f} ms") + print(f" Pipelined: {p['pipeline_time_ms']:.2f} ms") + print("=" * 80 + "\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Derive optimal parameters for all_gather_matmul_copy_engine", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, required=True, help="M dimension (rows of output)") + parser.add_argument("-n", type=int, required=True, help="N dimension (cols of output)") + parser.add_argument("-k", type=int, required=True, help="K dimension (total reduction dim)") + parser.add_argument("-w", "--world_size", type=int, default=DEFAULT_WORLD_SIZE, help="Number of GPUs") + parser.add_argument("--mode", type=str, default="auto", choices=["auto", "host", "device"]) + parser.add_argument( + "--link_bw", + type=float, + default=None, + help="Per-link XGMI bandwidth in GB/s (one direction). Omit to auto-profile via GPU-to-GPU copies.", + ) + parser.add_argument("--num_cus", type=int, default=DEFAULT_NUM_CUS, help="Number of compute units") + parser.add_argument("--peak_tflops", type=float, default=DEFAULT_PEAK_TFLOPS_FP16, help="Peak fp16 TFLOPS") + parser.add_argument("--hbm_bw", type=float, default=DEFAULT_HBM_BW_GBPS, help="HBM bandwidth in GB/s") + parser.add_argument( + "--scheduling_factor", + type=float, + default=DEFAULT_SCHEDULING_FACTOR, + help="CU scheduling overhead factor (calibrated from traces)", + ) + + args, passthrough = parser.parse_known_args() + + if args.k % args.world_size != 0: + parser.error(f"K ({args.k}) must be divisible by world_size ({args.world_size})") + + # Convert mode string to device_initiated parameter + if args.mode == "auto": + device_initiated = None + elif args.mode == "device": + device_initiated = True + else: # host + device_initiated = False + + p = derive( + args.m, + args.n, + args.k, + args.world_size, + args.link_bw, + args.num_cus, + args.peak_tflops, + args.hbm_bw, + DEFAULT_L2_SIZE_BYTES, + args.scheduling_factor, + dtype_bytes=2, + device_initiated=device_initiated, + ) + + # Print analysis + print_analysis(p, args.m, args.n, args.k_local, args.world_size) + + # Print benchmark command + print("Benchmark command:") + print(f" torchrun --nproc_per_node={args.world_size} \\") + print(" benchmark/ops/all_gather_matmul/benchmark_copy_engine.py \\") + print(f" -m {args.m} -n {args.n} -k {args.k_local} \\") + print( + f" --block_size_m {p['block_size_m']} --block_size_n {p['block_size_n']} --block_size_k {p['block_size_k']} \\" + ) + print(f" --k_per_flag {p['k_per_flag']} \\") + print(f" --m_tiles_per_batch {p['m_tiles_per_batch']} \\") + if p["device_initiated"]: + print(" --device_initiated \\") + print(" --benchmark --validate") + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/all_gather_matmul/tune_copy_engine.py b/benchmark/ops/all_gather_matmul/tune_copy_engine.py new file mode 100644 index 000000000..c6db908b3 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/tune_copy_engine.py @@ -0,0 +1,608 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tune ``m_tiles_per_batch`` for ``all_gather_matmul_copy_engine``. + +Unlike the more general matmul tuners, this script keeps the GEMM tile geometry +under the tritonBLAS selector and only sweeps the batch size used by the +host/device copy-engine path. +""" + +import argparse +import importlib.util +import json +import math +import os +import subprocess +import sys +import time +from datetime import datetime +from pathlib import Path + +import torch + +from tritonblas.matmul import _make_matmul_selector + + +BENCHMARK_TARGETS = { + "device": { + "script": "benchmark/ops/all_gather_matmul/benchmark_copy_engine.py", + "operation_name": "all_gather_matmul_copy_engine", + "output_stem": "tune_copy_engine", + "flags": ["--force-device-initiated"], + }, + "host": { + "script": "benchmark/ops/all_gather_matmul/benchmark_copy_engine.py", + "operation_name": "all_gather_matmul_host_copy_engine", + "output_stem": "tune_host_copy_engine", + "flags": ["--force-host-initiated", "--no-trace"], + }, +} + + +def _load_sweep_dimension_configs(): + """Load the shared sweep dimension list from benchmark/ops/sweep_benchmarks.py.""" + sweep_path = Path(__file__).resolve().parents[1] / "sweep_benchmarks.py" + module_name = "_shared_sweep_benchmarks" + spec = importlib.util.spec_from_file_location(module_name, sweep_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load sweep benchmark config from {sweep_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + dimension_configs = [] + for config in module.DIMENSION_CONFIGS: + if isinstance(config, dict): + dimension_configs.append((config["m_local"], config["n"], config["k"])) + else: + dimension_configs.append(tuple(config)) + return dimension_configs + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Tune m_tiles_per_batch for all_gather_matmul_copy_engine.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="M_local dimension") + parser.add_argument("-n", type=int, default=2048, help="N dimension") + parser.add_argument("-k", type=int, default=131072, help="K dimension") + parser.add_argument( + "--benchmark", + type=str, + default="device", + choices=sorted(BENCHMARK_TARGETS.keys()), + help="Which copy-engine benchmark to tune", + ) + parser.add_argument( + "--use_sweep_dimensions", + action="store_true", + default=True, + help="Use the shared dimension list from benchmark/ops/sweep_benchmarks.py", + ) + parser.add_argument( + "--single_shape", + dest="use_sweep_dimensions", + action="store_false", + help="Only tune the single shape given by -m/-n/-k", + ) + parser.add_argument("--nproc", type=int, default=8, help="Number of ranks / GPUs") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype passed through to the benchmark", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size passed to the benchmark") + parser.add_argument("--num_sms", type=int, default=None, help="Optional NUM_SMS override for the benchmark") + parser.add_argument("--num_xcds", type=int, default=None, help="Optional NUM_XCDS override for the benchmark") + parser.add_argument( + "--m_tiles_per_batch", + type=int, + nargs="+", + default=None, + help="Explicit sweep values. If omitted, derive a candidate set from selector geometry.", + ) + parser.add_argument( + "--all_values", + action="store_true", + help="Sweep every value from 1..num_m_tiles instead of the heuristic candidate set", + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory (auto-generated if unset)") + parser.add_argument("--dry_run", action="store_true", help="Print the candidate list and exit") + parser.add_argument("--skip_validation", action="store_true", help="Skip validation for faster sweeps") + parser.add_argument("--timeout", type=int, default=600, help="Per-run timeout in seconds") + return parser.parse_args() + + +def _dtype_from_name(name: str) -> torch.dtype: + return { + "fp16": torch.float16, + "fp32": torch.float32, + "bf16": torch.bfloat16, + }[name] + + +def _selector_metadata(m_local: int, n: int, k: int, dtype: torch.dtype): + device = torch.device("cuda:0") + selector = _make_matmul_selector( + m_local, + n, + k, + dtype, + dtype, + dtype, + device, + streamk=False, + ) + + block_m = selector.block_m + block_n = selector.block_n + block_k = selector.block_k + group_size_m = selector.group_m + num_stages = getattr(selector, "num_stages", 2) + waves_per_eu = getattr(selector, "waves_per_eu", 0) + active_cus = getattr(selector, "_ACTIVE_CU", None) + if active_cus is None: + active_cus = getattr(selector._hardware, "N_CU", getattr(selector._hardware, "NUM_XCD", 1)) + + num_tiles_m = math.ceil(m_local / block_m) + num_tiles_n = math.ceil(n / block_n) + tiles_per_group = max(1, group_size_m * num_tiles_n) + groups_per_wave = max(1, active_cus // tiles_per_group) + m_tiles_per_wave = min(num_tiles_m, groups_per_wave * group_size_m) + + return { + "selector": selector, + "block_size_m": block_m, + "block_size_n": block_n, + "block_size_k": block_k, + "group_size_m": group_size_m, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "active_cus": active_cus, + "num_tiles_m": num_tiles_m, + "num_tiles_n": num_tiles_n, + "tiles_per_group": tiles_per_group, + "groups_per_wave": groups_per_wave, + "m_tiles_per_wave": m_tiles_per_wave, + } + + +def _candidate_values( + num_tiles_m: int, group_size_m: int, groups_per_wave: int, m_tiles_per_wave: int, sweep_all: bool +): + if sweep_all: + return list(range(1, num_tiles_m + 1)) + + values = {1, num_tiles_m} + + power = 1 + while power <= num_tiles_m: + values.add(power) + power *= 2 + + # Keep a small number of shape-aware anchors even in sparse mode. + for candidate in (group_size_m, groups_per_wave, m_tiles_per_wave): + if 1 <= candidate <= num_tiles_m: + values.add(candidate) + + return sorted(values) + + +def _build_command(args, output_path: str, m_tiles_per_batch: int): + target = BENCHMARK_TARGETS[args.benchmark] + cmd = [ + "torchrun", + "--nproc_per_node", + str(args.nproc), + target["script"], + "-m", + str(args.m), + "-n", + str(args.n), + "-k", + str(args.k), + "--datatype", + args.datatype, + "--heap_size", + str(args.heap_size), + "--m_tiles_per_batch", + str(m_tiles_per_batch), + "--output_file", + output_path, + "-b", + ] + cmd.extend(target["flags"]) + + if not args.skip_validation: + cmd.append("-v") + if args.num_sms is not None: + cmd.extend(["--num_sms", str(args.num_sms)]) + if args.num_xcds is not None: + cmd.extend(["--num_xcds", str(args.num_xcds)]) + + return cmd + + +def _parse_json_output(json_path: Path): + result = { + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": None, + "group_size_m": None, + "block_size_m": None, + "block_size_n": None, + "block_size_k": None, + "m_tiles_per_batch": None, + "output_tile_size_m": None, + "output_tile_size_n": None, + "output_tile_size_k": None, + "num_stages": None, + "waves_per_eu": None, + "active_cus": None, + "num_tiles_m": None, + "num_tiles_n": None, + "tiles_per_group": None, + "groups_per_wave": None, + "m_tiles_per_wave": None, + "m_tiles_first_wave": None, + "schedule_iterations": None, + "num_batches": None, + "last_batch_m_tiles": None, + "m_tiles_per_batch_over_wave": None, + "gemm_wg_us": None, + "scatter_wg_us": None, + "bottleneck": None, + "ratio": None, + "roofline_tflops": None, + "intensity": None, + } + + try: + with open(json_path, "r") as f: + data = json.load(f) + + result["iris_ms"] = data.get("avg_ms") + result["iris_tflops"] = data.get("tflops") + result["iris_bw_gbps"] = data.get("bandwidth_gbps") + result["validation"] = "PASSED" if data.get("success") is True else ("FAILED" if "success" in data else None) + result["group_size_m"] = data.get("group_size_m") + result["block_size_m"] = data.get("block_size_m") + result["block_size_n"] = data.get("block_size_n") + result["block_size_k"] = data.get("block_size_k") + result["m_tiles_per_batch"] = data.get("m_tiles_per_batch") + result["output_tile_size_m"] = data.get("output_tile_size_m", data.get("block_size_m")) + result["output_tile_size_n"] = data.get("output_tile_size_n", data.get("block_size_n")) + result["output_tile_size_k"] = data.get("output_tile_size_k", data.get("block_size_k")) + result["num_stages"] = data.get("num_stages") + result["waves_per_eu"] = data.get("waves_per_eu") + result["active_cus"] = data.get("active_cus") + result["num_tiles_m"] = data.get("num_tiles_m") + result["num_tiles_n"] = data.get("num_tiles_n") + result["tiles_per_group"] = data.get("tiles_per_group") + result["groups_per_wave"] = data.get("groups_per_wave") + result["m_tiles_per_wave"] = data.get("m_tiles_per_wave") + result["m_tiles_first_wave"] = data.get("m_tiles_first_wave") + result["schedule_iterations"] = data.get("schedule_iterations") + result["num_batches"] = data.get("num_batches") + result["last_batch_m_tiles"] = data.get("last_batch_m_tiles") + result["m_tiles_per_batch_over_wave"] = data.get("m_tiles_per_batch_over_wave") + result["gemm_wg_us"] = data.get("gemm_wg_us") + result["scatter_wg_us"] = data.get("scatter_wg_us") + result["bottleneck"] = data.get("bottleneck") + result["ratio"] = data.get("ratio") + result["roofline_tflops"] = data.get("roofline_tflops") + result["intensity"] = data.get("intensity") + except Exception: + pass + + return result + + +def _print_selector_summary(meta, candidates): + tile_shape = f"{meta['block_size_m']}x{meta['block_size_n']}x{meta['block_size_k']}" + print("\nSelector-derived geometry") + print(f" output tile size : {tile_shape}") + print(f" group_size_m : {meta['group_size_m']}") + print(f" num_stages : {meta['num_stages']}") + print(f" waves_per_eu : {meta['waves_per_eu']}") + print(f" active CUs : {meta['active_cus']}") + print(f" tile grid : {meta['num_tiles_m']} M-tiles x {meta['num_tiles_n']} N-tiles") + print(f" tiles per group : {meta['tiles_per_group']}") + print(f" groups per wave/stage : {meta['groups_per_wave']}") + print(f" M-tiles per wave/stage : {meta['m_tiles_per_wave']}") + print(f" sweep m_tiles_per_batch : {candidates}") + + +def _shape_tag(m_local: int, n: int, k: int): + return f"M{m_local}_N{n}_K{k}" + + +def main(): + args = parse_args() + dtype = _dtype_from_name(args.datatype) + if args.use_sweep_dimensions: + dimension_configs = _load_sweep_dimension_configs() + else: + dimension_configs = [(args.m, args.n, args.k)] + + if args.output_dir: + output_dir = Path(args.output_dir) + else: + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + output_stem = BENCHMARK_TARGETS[args.benchmark]["output_stem"] + output_dir = Path(f"benchmark/ops/all_gather_matmul/{output_stem}_{ts}") + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n{'=' * 100}") + print(f" {BENCHMARK_TARGETS[args.benchmark]['operation_name']} — m_tiles_per_batch tuning") + if args.use_sweep_dimensions: + print(f" Shapes: sweep_benchmarks.py DIMENSION_CONFIGS ({len(dimension_configs)} shapes)") + else: + print(f" M_local={args.m} N={args.n} K={args.k} nproc={args.nproc} dtype={args.datatype}") + print(f" benchmark={args.benchmark} nproc={args.nproc} dtype={args.datatype}") + print(f" Output dir: {output_dir}") + print(f" Validation: {'OFF' if args.skip_validation else 'ON'}") + print(f"{'=' * 100}") + + if args.dry_run: + print("") + for m_local, n, k in dimension_configs: + meta = _selector_metadata(m_local, n, k, dtype) + if args.m_tiles_per_batch is not None: + candidates = sorted({value for value in args.m_tiles_per_batch if 1 <= value <= meta["num_tiles_m"]}) + else: + candidates = _candidate_values( + meta["num_tiles_m"], + meta["group_size_m"], + meta["groups_per_wave"], + meta["m_tiles_per_wave"], + args.all_values, + ) + print(f"Shape {_shape_tag(m_local, n, k)}") + _print_selector_summary(meta, candidates) + print("") + print("Dry run only; no benchmarks launched.") + return + + env = os.environ.copy() + env["HSA_NO_SCRATCH_RECLAIM"] = "1" + + results = [] + total_start = time.time() + + for shape_idx, (m_local, n, k) in enumerate(dimension_configs, start=1): + meta = _selector_metadata(m_local, n, k, dtype) + if args.m_tiles_per_batch is not None: + candidates = sorted({value for value in args.m_tiles_per_batch if 1 <= value <= meta["num_tiles_m"]}) + else: + candidates = _candidate_values( + meta["num_tiles_m"], + meta["group_size_m"], + meta["groups_per_wave"], + meta["m_tiles_per_wave"], + args.all_values, + ) + + if not candidates: + raise ValueError(f"No valid m_tiles_per_batch values to test for shape {_shape_tag(m_local, n, k)}") + + shape_tag = _shape_tag(m_local, n, k) + shape_output_dir = output_dir / shape_tag + shape_output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n{'=' * 100}") + print(f"[{shape_idx}/{len(dimension_configs)}] Shape {shape_tag}") + _print_selector_summary(meta, candidates) + + for idx, m_tiles_per_batch in enumerate(candidates, start=1): + label = f"{shape_tag}_mtpb{m_tiles_per_batch}" + json_path = shape_output_dir / f"results_mtpb{m_tiles_per_batch}.json" + log_path = shape_output_dir / f"log_mtpb{m_tiles_per_batch}.txt" + cmd_args = argparse.Namespace(**vars(args)) + cmd_args.m = m_local + cmd_args.n = n + cmd_args.k = k + cmd = _build_command(cmd_args, str(json_path), m_tiles_per_batch) + cmd_str = " ".join(cmd) + + print(f"\n{'-' * 80}") + print(f"[{idx}/{len(candidates)}] m_tiles_per_batch={m_tiles_per_batch}") + print(f" $ HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}") + + started = time.time() + try: + proc = subprocess.run( + cmd, + env=env, + capture_output=True, + text=True, + timeout=args.timeout, + ) + elapsed = time.time() - started + parsed = _parse_json_output(json_path) + json_ok = json_path.exists() + + results.append( + { + "shape": {"m_local": m_local, "n": n, "k": k}, + "shape_tag": shape_tag, + "label": label, + "m_tiles_per_batch": m_tiles_per_batch, + "iris_ms": parsed["iris_ms"], + "iris_tflops": parsed["iris_tflops"], + "iris_bw_gbps": parsed["iris_bw_gbps"], + "validation": parsed["validation"], + "benchmark_json": parsed, + "returncode": proc.returncode, + "elapsed_s": round(elapsed, 1), + "json_path": str(json_path) if json_ok else None, + } + ) + + summary = [] + if parsed["iris_tflops"] is not None: + summary.append(f"{parsed['iris_tflops']:.2f} TFLOPS") + if parsed["iris_ms"] is not None: + summary.append(f"{parsed['iris_ms']:.3f} ms") + if parsed["iris_bw_gbps"] is not None: + summary.append(f"{parsed['iris_bw_gbps']:.1f} GB/s") + if parsed["validation"] is not None: + summary.append(f"valid={parsed['validation']}") + summary.append("json=OK" if json_ok else "json=MISSING") + if proc.returncode != 0: + summary.append(f"EXIT={proc.returncode}") + print(f" => {' | '.join(summary)} ({elapsed:.0f}s)") + + with open(log_path, "w") as f: + f.write(f"COMMAND: HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}\n") + f.write(f"EXIT CODE: {proc.returncode}\n") + f.write(f"ELAPSED: {elapsed:.1f}s\n\n") + f.write("=== STDOUT ===\n") + f.write(proc.stdout) + f.write("\n=== STDERR ===\n") + f.write(proc.stderr) + + except subprocess.TimeoutExpired as exc: + elapsed = time.time() - started + results.append( + { + "shape": {"m_local": m_local, "n": n, "k": k}, + "shape_tag": shape_tag, + "label": label, + "m_tiles_per_batch": m_tiles_per_batch, + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": "TIMEOUT", + "benchmark_json": {}, + "returncode": -1, + "elapsed_s": round(elapsed, 1), + "json_path": None, + } + ) + print(f" => TIMEOUT after {args.timeout}s") + with open(log_path, "w") as f: + f.write(f"COMMAND: HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}\n") + f.write(f"TIMEOUT: {args.timeout}s\n\n") + f.write(getattr(exc, "stdout", "") or "") + f.write("\n") + f.write(getattr(exc, "stderr", "") or "") + + total_elapsed = time.time() - total_start + + print(f"\n{'=' * 112}") + print(f" TUNING RESULTS | {len(dimension_configs)} shapes | {len(results)} runs in {total_elapsed:.0f}s") + print(f"{'=' * 112}") + print( + f" {'#':>3} {'Shape':<24} {'m_tiles_per_batch':>17} {'ms':>8} {'TFLOPS':>8} " + f"{'GB/s':>8} {'Valid':>8} {'JSON':>4}" + ) + print(f" {'-' * 108}") + + for idx, result in enumerate(results, start=1): + ms_s = f"{result['iris_ms']:.3f}" if result["iris_ms"] is not None else "--" + tf_s = f"{result['iris_tflops']:.2f}" if result["iris_tflops"] is not None else "--" + bw_s = f"{result['iris_bw_gbps']:.1f}" if result["iris_bw_gbps"] is not None else "--" + valid_s = (result["validation"] or "--")[:8] + json_s = "Y" if result["json_path"] else "N" + best_tag = "" + if result["iris_tflops"] is not None: + best_value = max((x["iris_tflops"] for x in results if x["iris_tflops"] is not None), default=None) + if best_value is not None and result["iris_tflops"] == best_value: + best_tag = " *" + + print( + f" {idx:>3} {result['shape_tag']:<24} {result['m_tiles_per_batch']:>17} {ms_s:>8} {tf_s:>8} " + f"{bw_s:>8} {valid_s:>8} {json_s:>4}{best_tag}" + ) + + valid_results = [result for result in results if result["iris_tflops"] is not None] + if valid_results: + best = max(valid_results, key=lambda result: result["iris_tflops"]) + worst = min(valid_results, key=lambda result: result["iris_tflops"]) + best_json = best["benchmark_json"] + tile_m = best_json.get("output_tile_size_m") or meta["block_size_m"] + tile_n = best_json.get("output_tile_size_n") or meta["block_size_n"] + tile_k = best_json.get("output_tile_size_k") or meta["block_size_k"] + best_group_size_m = best_json.get("group_size_m") or meta["group_size_m"] + best_m_tiles_per_wave = best_json.get("m_tiles_per_wave") or meta["m_tiles_per_wave"] + tile_shape = f"{tile_m}x{tile_n}x{tile_k}" + + print("\nBest configuration") + print(f" m_tiles_per_batch : {best['m_tiles_per_batch']}") + print(f" avg_ms : {best['iris_ms']:.3f}") + print(f" tflops : {best['iris_tflops']:.2f}") + print(f" bandwidth_gbps : {best['iris_bw_gbps']:.1f}") + print(f" output tile size : {tile_shape}") + print(f" group_size_m : {best_group_size_m}") + print(f" M-tiles per wave/stage : {best_m_tiles_per_wave}") + if best_json.get("groups_per_wave") is not None: + print(f" groups per wave/stage : {best_json['groups_per_wave']}") + if best_json.get("num_batches") is not None: + print(f" num_batches : {best_json['num_batches']}") + if best_json.get("last_batch_m_tiles") is not None: + print(f" last_batch_m_tiles : {best_json['last_batch_m_tiles']}") + if best_json.get("ratio") is not None: + print(f" scatter/gemm ratio : {best_json['ratio']:.2f}x") + if best_json.get("bottleneck") is not None: + print(f" bottleneck : {best_json['bottleneck']}") + + print("\nSpread") + print( + f" best : {best['iris_tflops']:.2f} TFLOPS @ m_tiles_per_batch={best['m_tiles_per_batch']}" + ) + print( + f" worst : {worst['iris_tflops']:.2f} TFLOPS @ m_tiles_per_batch={worst['m_tiles_per_batch']}" + ) + if worst["iris_tflops"] and best["iris_tflops"]: + print(f" best / worst : {best['iris_tflops'] / worst['iris_tflops']:.2f}x") + + results_path = output_dir / "results.json" + with open(results_path, "w") as f: + json.dump( + { + "meta": { + "dimension_configs": [{"m_local": m_local, "n": n, "k": k} for m_local, n, k in dimension_configs], + "use_sweep_dimensions": args.use_sweep_dimensions, + "benchmark": args.benchmark, + "benchmark_script": BENCHMARK_TARGETS[args.benchmark]["script"], + "nproc": args.nproc, + "datatype": args.datatype, + "timestamp": datetime.now().isoformat(), + "total_elapsed_s": round(total_elapsed, 1), + "candidate_generation": { + "block_size_m": meta["block_size_m"], + "block_size_n": meta["block_size_n"], + "block_size_k": meta["block_size_k"], + "group_size_m": meta["group_size_m"], + "num_stages": meta["num_stages"], + "waves_per_eu": meta["waves_per_eu"], + "active_cus": meta["active_cus"], + "num_tiles_m": meta["num_tiles_m"], + "num_tiles_n": meta["num_tiles_n"], + "tiles_per_group": meta["tiles_per_group"], + "groups_per_wave": meta["groups_per_wave"], + "m_tiles_per_wave": meta["m_tiles_per_wave"], + }, + "candidates": candidates, + }, + "results": results, + }, + f, + indent=2, + ) + + print(f"\nSummary JSON : {results_path}") + print(f"Per-run JSON : {output_dir}/results_*.json") + print(f"Per-run logs : {output_dir}/log_*.txt") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/bench_all_gather_matmul.py b/benchmark/ops/bench_all_gather_matmul.py index 9a50d3180..6a5879361 100644 --- a/benchmark/ops/bench_all_gather_matmul.py +++ b/benchmark/ops/bench_all_gather_matmul.py @@ -2,11 +2,28 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. -"""Benchmark for fused all-gather + GEMM (iris.ops).""" +"""Benchmark for all-gather + GEMM: RCCL baseline vs iris HBM-buffer prefetch. + +The HBM-buffer benchmark automatically loads tuned kernel parameters from +configs/{arch}/{transpose}/ws{N}.json when available. Run with --list-configs +to see which shapes have tuned configs for the current GPU. +""" + +import sys +import os import torch +import torch.distributed as dist +import tritonblas import iris.bench as bench -from iris.ops import FusedConfig, all_gather_matmul_preamble +from iris.ops import FusedConfig +from iris.ops.all_gather_matmul_hbm_buffer import ( + all_gather_matmul_hbm_buffer as _hbm_buffer, + all_gather_matmul_hbm_buffer_preamble, +) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "all_gather_matmul")) +from auto_config import select_ag_mm_config @bench.register @@ -16,24 +33,160 @@ @bench.axis("K", [8192]) @bench.axis("dtype", [torch.float16]) def all_gather_matmul(state, ctx): + """Iris fused all-gather + GEMM baseline.""" M, N, K = state["M"], state["N"], state["K"] dtype = state["dtype"] world_size = ctx.get_num_ranks() + rank = ctx.get_rank() K_local = K // world_size - A_sharded = ctx.zeros((M, K_local), dtype=dtype) - A_sharded.fill_(1.0) - B = torch.randn((K, N), device="cuda", dtype=dtype) - C = torch.zeros((M, N), device="cuda", dtype=dtype) - + torch.manual_seed(123 + rank) + A_sharded = ctx.randn((M, K_local), dtype=dtype) + torch.manual_seed(456) + B = ctx.randn((K, N), dtype=dtype) + C = ctx.zeros((M, N), dtype=dtype) config = FusedConfig() - workspace = all_gather_matmul_preamble(ctx, A_sharded, B, config) state.set_flops(2 * M * N * K) state.set_bytes((world_size - 1) * M * K_local * A_sharded.element_size()) state.exec( - lambda: ctx.ops.all_gather_matmul(C, A_sharded, B, config=config, workspace=workspace), + lambda: ctx.ops.all_gather_matmul(C, A_sharded, B, config=config), + ) + + +@bench.register +@bench.axis("num_ranks", [2, 4, 8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def rccl_all_gather_matmul(state, ctx): + """PyTorch/RCCL baseline: all_gather + torch.cat + torch.mm.""" + M, N, K = state["M"], state["N"], state["K"] + dtype = state["dtype"] + world_size = dist.get_world_size() + rank = ctx.get_rank() + K_local = K // world_size + + torch.manual_seed(123 + rank) + A_sharded = ctx.randn((M, K_local), dtype=dtype) + torch.manual_seed(456) + B = ctx.randn((K, N), dtype=dtype) + A_gathered_parts = [ctx.zeros((M, K_local), dtype=dtype) for _ in range(world_size)] + A_gathered = ctx.zeros((M, K), dtype=dtype) + C = ctx.zeros((M, N), dtype=dtype) + + state.set_flops(2 * M * N * K) + state.set_bytes((world_size - 1) * M * K_local * A_sharded.element_size()) + + state.exec( + lambda: ( + dist.all_gather(A_gathered_parts, A_sharded), + A_gathered.copy_(torch.cat(A_gathered_parts, dim=1)), + torch.mm(A_gathered, B, out=C), + ), + ) + + +@bench.register +@bench.axis("num_ranks", [2, 4, 8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def tritonblas_rccl_all_gather_matmul(state, ctx): + """RCCL all_gather + tritonBLAS matmul baseline.""" + M, N, K = state["M"], state["N"], state["K"] + dtype = state["dtype"] + world_size = dist.get_world_size() + rank = ctx.get_rank() + K_local = K // world_size + + torch.manual_seed(123 + rank) + A_sharded = ctx.randn((M, K_local), dtype=dtype) + torch.manual_seed(456) + B = ctx.randn((K, N), dtype=dtype) + A_gathered_parts = [ctx.zeros((M, K_local), dtype=dtype) for _ in range(world_size)] + A_gathered = ctx.zeros((M, K), dtype=dtype) + C = ctx.zeros((M, N), dtype=dtype) + selector = tritonblas.OrigamiMatmulSelector( + M, + N, + K, + A_gathered.dtype, + B.dtype, + C.dtype, + A_gathered.device, + ) + config = tritonblas.matmul_preamble(selector) + + state.set_flops(2 * M * N * K) + state.set_bytes((world_size - 1) * M * K_local * A_sharded.element_size()) + + state.exec( + lambda: ( + dist.all_gather(A_gathered_parts, A_sharded), + A_gathered.copy_(torch.cat(A_gathered_parts, dim=1)), + tritonblas.matmul_lt(A_gathered, B, C, selector, config), + ), + ) + + +@bench.register +@bench.axis("num_ranks", [2, 4, 8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def all_gather_matmul_hbm_buffer(state, ctx): + """Iris HBM-buffer AG+MM with auto-tuned config from configs/ JSON files.""" + M, N, K = state["M"], state["N"], state["K"] + dtype = state["dtype"] + world_size = ctx.get_num_ranks() + rank = ctx.get_rank() + K_local = K // world_size + + result = select_ag_mm_config(M, N, K, world_size=world_size) + config = result.to_fused_config() + hbm = result.hbm_buffer_params + + torch.manual_seed(123 + rank) + A_sharded = ctx.randn((M, K_local), dtype=dtype) + torch.manual_seed(456) + B = ctx.randn((K, N), dtype=dtype) + C = ctx.zeros((M, N), dtype=dtype) + + workspace = all_gather_matmul_hbm_buffer_preamble( + ctx, + A_sharded, + B, + config, + k_per_flag=hbm.get("k_per_flag", 8), + ) + + state.set_flops(2 * M * N * K) + state.set_bytes((world_size - 1) * M * K_local * A_sharded.element_size()) + + state.exec( + lambda: _hbm_buffer( + ctx, + C, + A_sharded, + B, + config=config, + workspace=workspace, + num_fetch_sms=hbm.get("num_fetch_sms", 16), + k_per_flag=hbm.get("k_per_flag", 8), + fetch_block_m=hbm.get("fetch_block_m"), + fetch_block_k=hbm.get("fetch_block_k"), + num_warps=hbm.get("num_warps", 8), + num_stages=hbm.get("num_stages", 2), + num_fetch_stages=hbm.get("num_fetch_stages"), + first_stage_fetch_sms=hbm.get("first_stage_fetch_sms"), + ), + # TODO get ride of preamble + preamble_fn=lambda: (C.zero_(), workspace.locks.zero_()), ) diff --git a/benchmark/ops/bench_all_gather_matmul_copy_engine.py b/benchmark/ops/bench_all_gather_matmul_copy_engine.py new file mode 100644 index 000000000..24bf70b88 --- /dev/null +++ b/benchmark/ops/bench_all_gather_matmul_copy_engine.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +"""Benchmarks for all-gather + GEMM copy-engine variants.""" + +import os + +import torch +import iris.bench as bench + +from iris.ops import FusedConfig +from iris.ops.all_gather_matmul_copy_engine import ( + all_gather_matmul_copy_engine as _copy_engine, + all_gather_matmul_copy_engine_preamble, +) +from tritonblas.matmul import _make_matmul_selector + + +def _selector_and_config(M: int, N: int, K: int, dtype: torch.dtype, device: torch.device) -> tuple: + selector = _make_matmul_selector( + M, + N, + K, + dtype, + dtype, + dtype, + device, + streamk=False, + ) + config = FusedConfig( + block_size_m=selector.block_m, + block_size_n=selector.block_n, + block_size_k=selector.block_k, + group_size_m=selector.group_m, + num_xcds=max(1, int(getattr(selector, "num_sms", 1))), + ) + return selector, config + + +def _register_copy_engine(state, ctx, *, device_initiated: bool, host_transfer_backend: str = "anvil") -> None: + M, N, K = state["M"], state["N"], state["K"] + dtype = state["dtype"] + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + if K % world_size != 0: + state.skip(f"K={K} must be divisible by world_size={world_size}") + + K_local = K // world_size + device = torch.device(f"cuda:{torch.cuda.current_device()}") + selector, config = _selector_and_config(M, N, K, dtype, device) + + if M % config.block_size_m != 0: + state.skip(f"M={M} must be divisible by block_size_m={config.block_size_m}") + if K % config.block_size_k != 0: + state.skip(f"K={K} must be divisible by block_size_k={config.block_size_k}") + if K_local % config.block_size_k != 0: + state.skip(f"K_local={K_local} must be divisible by block_size_k={config.block_size_k}") + + m_tiles_per_batch = config.group_size_m + k_per_flag = 4 + host_transfer_backend = os.environ.get("IRIS_BENCH_HOST_TRANSFER_BACKEND", host_transfer_backend) + + A_sharded = ctx.zeros((M, K_local), dtype=dtype) + torch.manual_seed(123 + rank) + A_sharded_data = torch.randn((M, K_local), device="cuda", dtype=dtype) + A_sharded.copy_(A_sharded_data) + torch.manual_seed(456) + B = torch.randn((K, N), device="cuda", dtype=dtype) + C = ctx.zeros((M, N), dtype=dtype) + + workspace = all_gather_matmul_copy_engine_preamble( + ctx, + A_sharded, + B, + config, + k_per_flag=k_per_flag, + m_tiles_per_batch=m_tiles_per_batch, + ) + workspace.selector = selector + + flag_iteration = [0] + + def _run(): + _copy_engine( + ctx, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + flag_iteration=flag_iteration[0], + k_per_flag=k_per_flag, + m_tiles_per_batch=m_tiles_per_batch, + device_initiated=device_initiated, + host_transfer_backend=host_transfer_backend, + ) + flag_iteration[0] += 1 + + state.set_flops(2 * M * N * K) + state.set_bytes((world_size - 1) * M * K_local * A_sharded.element_size()) + state.add_counter("group_size_m", float(config.group_size_m)) + state.add_counter("m_tiles_per_batch", float(m_tiles_per_batch)) + state.add_counter("device_initiated", 1.0 if device_initiated else 0.0) + state.add_counter("host_transfer_backend_hip_memcpy", 1.0 if host_transfer_backend == "hip_memcpy" else 0.0) + + state.exec(_run, preamble_fn=lambda: C.zero_()) + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def all_gather_matmul_copy_engine_host(state, ctx): + _register_copy_engine(state, ctx, device_initiated=False) + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def all_gather_matmul_copy_engine_host_hip_memcpy(state, ctx): + _register_copy_engine(state, ctx, device_initiated=False, host_transfer_backend="hip_memcpy") + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def all_gather_matmul_copy_engine_device(state, ctx): + _register_copy_engine(state, ctx, device_initiated=True) + + +if __name__ == "__main__": + bench.main() diff --git a/benchmark/ops/bench_matmul.py b/benchmark/ops/bench_matmul.py new file mode 100644 index 000000000..72a183ec9 --- /dev/null +++ b/benchmark/ops/bench_matmul.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +"""Benchmarks for standalone GEMM.""" + +import torch +import iris.bench as bench + +from iris.ops import FusedConfig +from iris.ops.matmul import matmul as _matmul +from iris.ops.matmul import matmul_preamble as _matmul_preamble + + +def _register_local_matmul(state, ctx, *, m_key: str, pytorch: bool) -> None: + M, N, K = state[m_key], state["N"], state["K"] + dtype = state["dtype"] + rank = ctx.get_rank() + + state.set_flops(2 * M * N * K) + state.set_bytes(((M * K) + (K * N) + (M * N)) * torch.tensor([], dtype=dtype).element_size()) + + torch.manual_seed(123 + rank) + A_data = torch.randn((M, K), device="cuda", dtype=dtype) + torch.manual_seed(456) + B_data = torch.randn((K, N), device="cuda", dtype=dtype) + + if pytorch: + C_torch = torch.empty((M, N), device="cuda", dtype=dtype) + state.exec(lambda: torch.mm(A_data, B_data, out=C_torch)) + else: + A = ctx.zeros((M, K), dtype=dtype) + A.copy_(A_data) + C = ctx.zeros((M, N), dtype=dtype) + + workspace = _matmul_preamble(ctx, A, B_data, FusedConfig()) + state.exec( + lambda: _matmul(ctx, C, A, B_data, workspace=workspace), + preamble_fn=lambda: C.zero_(), + ) + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M_local", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def matmul_only_local(state, ctx): + _register_local_matmul(state, ctx, m_key="M_local", pytorch=False) + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M_local", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def pytorch_matmul_only_local(state, ctx): + _register_local_matmul(state, ctx, m_key="M_local", pytorch=True) + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def matmul_only(state, ctx): + _register_local_matmul(state, ctx, m_key="M", pytorch=False) + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def pytorch_matmul_only(state, ctx): + _register_local_matmul(state, ctx, m_key="M", pytorch=True) + + +if __name__ == "__main__": + bench.main() diff --git a/benchmark/ops/bench_matmul_all_gather.py b/benchmark/ops/bench_matmul_all_gather.py index fd7dff480..cd9ca5bbd 100644 --- a/benchmark/ops/bench_matmul_all_gather.py +++ b/benchmark/ops/bench_matmul_all_gather.py @@ -2,30 +2,28 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. -"""Benchmark for fused GEMM + all-gather (iris.ops).""" +"""Benchmarks for GEMM + all-gather and related baselines.""" import torch +import torch.distributed as dist +import tritonblas import iris.bench as bench + from iris.ops import FusedConfig -@bench.register -@bench.axis("num_ranks", [2, 4, 8]) -@bench.axis("M_local", [1024, 4096, 16384]) -@bench.axis("N", [3584]) -@bench.axis("K", [8192]) -@bench.axis("dtype", [torch.float16]) -def matmul_all_gather(state, ctx): +def _register_fused_matmul_all_gather(state, ctx) -> None: M_local, N, K = state["M_local"], state["N"], state["K"] dtype = state["dtype"] world_size = ctx.get_num_ranks() + rank = ctx.get_rank() M = M_local * world_size - A = ctx.zeros((M_local, K), dtype=dtype) - A.fill_(1.0) - B = torch.randn((K, N), device="cuda", dtype=dtype) + torch.manual_seed(123 + rank) + A = ctx.randn((M_local, K), dtype=dtype) + torch.manual_seed(456) + B = ctx.randn((K, N), dtype=dtype) C = ctx.zeros((M, N), dtype=dtype) - config = FusedConfig() state.set_flops(2 * M_local * N * K) @@ -33,9 +31,100 @@ def matmul_all_gather(state, ctx): state.exec( lambda: ctx.ops.matmul_all_gather(C, A, B, config=config), - preamble_fn=lambda: C.zero_(), ) +def _register_pytorch_matmul_all_gather(state, ctx) -> None: + M_local, N, K = state["M_local"], state["N"], state["K"] + dtype = state["dtype"] + world_size = ctx.get_num_ranks() + rank = ctx.get_rank() + M = M_local * world_size + + torch.manual_seed(123 + rank) + A = ctx.randn((M_local, K), dtype=dtype) + torch.manual_seed(456) + B = ctx.randn((K, N), dtype=dtype) + C_local = ctx.zeros((M_local, N), dtype=dtype) + C = ctx.zeros((M, N), dtype=dtype) + + state.set_flops(2 * M_local * N * K) + state.set_bytes((world_size - 1) * M_local * N * A.element_size()) + + state.exec( + lambda: ( + torch.mm(A, B, out=C_local), + dist.all_gather_into_tensor(C, C_local), + ), + ) + + +def _register_tritonblas_matmul_all_gather(state, ctx) -> None: + M_local, N, K = state["M_local"], state["N"], state["K"] + dtype = state["dtype"] + world_size = ctx.get_num_ranks() + rank = ctx.get_rank() + M = M_local * world_size + + torch.manual_seed(123 + rank) + A = ctx.randn((M_local, K), dtype=dtype) + + torch.manual_seed(456) + B = ctx.randn((K, N), dtype=dtype) + + C_local = ctx.zeros((M_local, N), dtype=dtype) + C = ctx.zeros((M, N), dtype=dtype) + selector = tritonblas.OrigamiMatmulSelector( + M_local, + N, + K, + A.dtype, + B.dtype, + C_local.dtype, + A.device, + ) + config = tritonblas.matmul_preamble(selector) + + state.set_flops(2 * M_local * N * K) + state.set_bytes((world_size - 1) * M_local * N * A.element_size()) + + state.exec( + lambda: ( + tritonblas.matmul_lt(A, B, C_local, selector, config), + dist.all_gather_into_tensor(C, C_local), + ), + ) + + +@bench.register +@bench.axis("num_ranks", [2, 4, 8]) +@bench.axis("M_local", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def pytorch_matmul_all_gather(state, ctx): + _register_pytorch_matmul_all_gather(state, ctx) + + +@bench.register +@bench.axis("num_ranks", [2, 4, 8]) +@bench.axis("M_local", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def tritonblas_matmul_all_gather(state, ctx): + _register_tritonblas_matmul_all_gather(state, ctx) + + +@bench.register +@bench.axis("num_ranks", [2, 4, 8]) +@bench.axis("M_local", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def matmul_all_gather(state, ctx): + _register_fused_matmul_all_gather(state, ctx) + + if __name__ == "__main__": bench.main() diff --git a/benchmark/ops/bench_matmul_all_gather_copy_engine.py b/benchmark/ops/bench_matmul_all_gather_copy_engine.py new file mode 100644 index 000000000..b2d6a1fed --- /dev/null +++ b/benchmark/ops/bench_matmul_all_gather_copy_engine.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +"""Benchmarks for fused GEMM + all-gather copy-engine variants.""" + +import torch +import iris.bench as bench + +from iris.ops.matmul_all_gather_copy_engine import ( + matmul_all_gather_copy_engine as _device_copy_engine, + matmul_all_gather_copy_engine_preamble as _device_preamble, +) +from iris.ops.matmul_all_gather_host_copy_engine import ( + matmul_all_gather_host_copy_engine as _host_copy_engine, + matmul_all_gather_host_copy_engine_preamble as _host_preamble, +) +from tritonblas.matmul import _make_matmul_selector + + +def _make_selector(M_local: int, N: int, K: int, dtype: torch.dtype, device: torch.device): + return _make_matmul_selector( + M_local, + N, + K, + dtype, + dtype, + dtype, + device, + streamk=False, + ) + + +def _register_copy_engine(state, ctx, *, device_initiated: bool) -> None: + M_local, N, K = state["M_local"], state["N"], state["K"] + dtype = state["dtype"] + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + M = M_local * world_size + + device = torch.device(f"cuda:{torch.cuda.current_device()}") + selector = _make_selector(M_local, N, K, dtype, device) + + if M_local % selector.block_m != 0: + state.skip(f"M_local={M_local} must be divisible by block_size_m={selector.block_m}") + if K % selector.block_k != 0: + state.skip(f"K={K} must be divisible by block_size_k={selector.block_k}") + + torch.manual_seed(123 + rank) + A = ctx.randn((M_local, K), dtype=dtype) + torch.manual_seed(456) + B = ctx.randn((K, N), dtype=dtype) + C = ctx.zeros((M, N), dtype=dtype) + + flag_iteration = [0] + + if device_initiated: + workspace = _device_preamble( + ctx, + A, + B, + selector=selector, + ) + + def _run(): + _device_copy_engine( + ctx, + C, + A, + B, + async_op=False, + workspace=workspace, + flag_iteration=flag_iteration[0], + ) + flag_iteration[0] += 1 + + else: + workspace = _host_preamble( + ctx, + A, + B, + trace=False, + selector=selector, + ) + + def _run(): + _host_copy_engine( + ctx, + C, + A, + B, + async_op=False, + workspace=workspace, + flag_iteration=flag_iteration[0], + trace=False, + ) + flag_iteration[0] += 1 + + state.set_flops(2 * M_local * N * K) + state.set_bytes((world_size - 1) * M_local * N * A.element_size()) + state.add_counter("group_size_m", float(selector.group_m)) + state.add_counter("m_tiles_per_batch", float(workspace.m_tiles_per_batch)) + state.add_counter("device_initiated", 1.0 if device_initiated else 0.0) + + state.exec(_run) + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M_local", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def matmul_all_gather_copy_engine_host(state, ctx): + _register_copy_engine(state, ctx, device_initiated=False) + + +@bench.register +@bench.axis("num_ranks", [8]) +@bench.axis("M_local", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def matmul_all_gather_copy_engine_device(state, ctx): + _register_copy_engine(state, ctx, device_initiated=True) + + +if __name__ == "__main__": + bench.main() diff --git a/benchmark/ops/compare_benchmarks.py b/benchmark/ops/compare_benchmarks.py new file mode 100755 index 000000000..d3f7ac9b7 --- /dev/null +++ b/benchmark/ops/compare_benchmarks.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +""" +Compare two benchmark JSON files and summarize tflops improvements/regressions. +Only compares entries where both have "success": true. +""" + +import json +import sys +from typing import Dict, List, Tuple + + +def load_json(filepath: str) -> List[Dict]: + """Load and parse a JSON benchmark file.""" + with open(filepath, "r") as f: + return json.load(f) + + +def get_benchmark_key(config: Dict) -> Tuple: + """Create a unique key for a benchmark configuration.""" + return (config["M"], config["N"], config["K"], config["operation"]) + + +def compare_benchmarks(baseline_file: str, new_file: str) -> None: + """Compare two benchmark files and print tflops improvements/regressions.""" + + baseline_data = load_json(baseline_file) + new_data = load_json(new_file) + + # Create lookup dictionaries + baseline_dict = {get_benchmark_key(config): config for config in baseline_data} + new_dict = {get_benchmark_key(config): config for config in new_data} + + # Find common benchmark configurations + common_keys = set(baseline_dict.keys()) & set(new_dict.keys()) + + if not common_keys: + print("No common benchmark configurations found between the two files.") + return + + print(f"\n{'=' * 100}") + print(f"BENCHMARK COMPARISON: {baseline_file} vs {new_file}") + print(f"{'=' * 100}\n") + + # Track statistics + improvements = [] + regressions = [] + no_change = [] + + # Sort keys for consistent output + sorted_keys = sorted(common_keys) + + for key in sorted_keys: + baseline_config = baseline_dict[key] + new_config = new_dict[key] + + m, n, k, operation = key + + print(f"\n{'─' * 100}") + print(f"Configuration: M={m}, N={n}, K={k}, operation={operation}") + print(f"{'─' * 100}") + + # Compare each benchmark variant + baseline_benchmarks = baseline_config.get("benchmarks", {}) + new_benchmarks = new_config.get("benchmarks", {}) + + common_variants = set(baseline_benchmarks.keys()) & set(new_benchmarks.keys()) + + if not common_variants: + print(" No common benchmark variants found.") + continue + + variant_results = [] + + for variant in sorted(common_variants): + baseline_variant = baseline_benchmarks[variant] + new_variant = new_benchmarks[variant] + + # Only compare if both have success=true + baseline_success = baseline_variant.get("success", False) + new_success = new_variant.get("success", False) + + if not (baseline_success and new_success): + status_msg = [] + if not baseline_success: + status_msg.append("baseline failed") + if not new_success: + status_msg.append("new failed") + print(f" {variant:25s} - SKIPPED ({', '.join(status_msg)})") + continue + + baseline_tflops = baseline_variant.get("tflops", 0) + new_tflops = new_variant.get("tflops", 0) + + if baseline_tflops == 0: + print(f" {variant:25s} - SKIPPED (baseline tflops is 0)") + continue + + # Calculate change + delta_tflops = new_tflops - baseline_tflops + percent_change = (delta_tflops / baseline_tflops) * 100 + + # Determine status + if abs(percent_change) < 0.01: # Less than 0.01% change + status = "→" + color_code = "" + no_change.append((key, variant, baseline_tflops, new_tflops, percent_change)) + elif percent_change > 0: + status = "↑" + color_code = "+" + improvements.append((key, variant, baseline_tflops, new_tflops, percent_change)) + else: + status = "↓" + color_code = "-" + regressions.append((key, variant, baseline_tflops, new_tflops, percent_change)) + + variant_results.append( + { + "variant": variant, + "status": status, + "baseline_tflops": baseline_tflops, + "new_tflops": new_tflops, + "delta_tflops": delta_tflops, + "percent_change": percent_change, + } + ) + + print( + f" {variant:25s} {status} {baseline_tflops:10.2f} → {new_tflops:10.2f} TFLOPs " + f"({color_code}{percent_change:+7.2f}%)" + ) + + # Print summary + print(f"\n{'=' * 100}") + print("SUMMARY") + print(f"{'=' * 100}\n") + + total_comparisons = len(improvements) + len(regressions) + len(no_change) + + print(f"Total comparisons: {total_comparisons}") + print( + f"Improvements: {len(improvements)} ({len(improvements) / total_comparisons * 100:.1f}%)" + if total_comparisons > 0 + else "Improvements: 0" + ) + print( + f"Regressions: {len(regressions)} ({len(regressions) / total_comparisons * 100:.1f}%)" + if total_comparisons > 0 + else "Regressions: 0" + ) + print( + f"No change: {len(no_change)} ({len(no_change) / total_comparisons * 100:.1f}%)" + if total_comparisons > 0 + else "No change: 0" + ) + + if improvements: + print(f"\n{'─' * 100}") + print("TOP IMPROVEMENTS:") + print(f"{'─' * 100}") + # Sort by absolute improvement + improvements.sort(key=lambda x: x[4], reverse=True) + for i, (key, variant, baseline, new, pct) in enumerate(improvements[:10], 1): + m, n, k, op = key + print( + f"{i:2d}. {variant:25s} M={m:6d} N={n:6d} K={k:6d}: {baseline:8.2f} → {new:8.2f} TFLOPs (+{pct:.2f}%)" + ) + + if regressions: + print(f"\n{'─' * 100}") + print("TOP REGRESSIONS:") + print(f"{'─' * 100}") + # Sort by absolute regression + regressions.sort(key=lambda x: x[4]) + for i, (key, variant, baseline, new, pct) in enumerate(regressions[:10], 1): + m, n, k, op = key + print(f"{i:2d}. {variant:25s} M={m:6d} N={n:6d} K={k:6d}: {baseline:8.2f} → {new:8.2f} TFLOPs ({pct:.2f}%)") + + print() + + +def main(): + if len(sys.argv) != 3: + print("Usage: python compare_benchmarks.py ") + print() + print("Example:") + print(" python compare_benchmarks.py baseline_results.json new_results.json") + sys.exit(1) + + baseline_file = sys.argv[1] + new_file = sys.argv[2] + + try: + compare_benchmarks(baseline_file, new_file) + except FileNotFoundError as e: + print(f"Error: {e}") + sys.exit(1) + except json.JSONDecodeError as e: + print(f"Error parsing JSON: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/create_sweep_markdown_table.py b/benchmark/ops/create_sweep_markdown_table.py new file mode 100644 index 000000000..f5fa78cf5 --- /dev/null +++ b/benchmark/ops/create_sweep_markdown_table.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Create a markdown table from benchmark_sweep_results.json. + +Example: + python benchmark/ops/create_sweep_markdown_table.py \ + benchmark/ops/matmul_all_gather/benchmark_sweep_results.json \ + --benchmark host_copy_engine +""" + +import argparse +import json +from pathlib import Path + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Create a markdown table from benchmark_sweep_results.json", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "results_json", + type=Path, + help="Path to benchmark_sweep_results.json", + ) + parser.add_argument( + "--benchmark", + required=True, + help="Benchmark key to extract, for example host_copy_engine or copy_engine_host", + ) + parser.add_argument( + "--pytorch-benchmark", + default="pytorchbaseline", + help="PyTorch reference benchmark key used for the speedup calculation", + ) + parser.add_argument( + "--sort-by", + choices=("shape", "label", "speedup", "tflops"), + default="shape", + help="How to sort rows in the output table", + ) + parser.add_argument( + "--descending", + action="store_true", + help="Sort in descending order", + ) + return parser.parse_args() + + +def _as_float(value): + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _shape_string(row: dict) -> str: + return f"{row.get('M')}x{row.get('N')}x{row.get('K')}" + + +def _label_string(row: dict) -> str: + label = row.get("label") + if label: + return str(label) + return _shape_string(row) + + +def _extract_table_rows(data: list[dict], benchmark_key: str, pytorch_key: str) -> list[dict]: + rows = [] + for row in data: + benchmarks = row.get("benchmarks") or {} + bench = benchmarks.get(benchmark_key) or {} + pytorch = benchmarks.get(pytorch_key) or {} + + bench_tflops = _as_float(bench.get("tflops")) + pytorch_tflops = _as_float(pytorch.get("tflops")) + + if bench_tflops is None: + continue + + speedup = None + if pytorch_tflops not in (None, 0.0): + speedup = bench_tflops / pytorch_tflops + + rows.append( + { + "shape": _shape_string(row), + "label": _label_string(row), + "speedup": speedup, + "tflops": bench_tflops, + } + ) + return rows + + +def _sort_rows(rows: list[dict], sort_by: str, descending: bool) -> list[dict]: + if sort_by == "shape": + key_fn = lambda row: tuple(int(dim) for dim in row["shape"].split("x")) + elif sort_by == "label": + key_fn = lambda row: row["label"] + elif sort_by == "speedup": + key_fn = lambda row: ( + row["speedup"] is not None, + row["speedup"] if row["speedup"] is not None else float("-inf"), + ) + else: + key_fn = lambda row: row["tflops"] + + return sorted(rows, key=key_fn, reverse=descending) + + +def _format_speedup(value) -> str: + if value is None: + return "--" + return f"{value:.2f}x" + + +def _format_tflops(value) -> str: + if value is None: + return "--" + return f"{value:.1f}" + + +def main() -> None: + args = parse_args() + + with open(args.results_json, "r") as f: + data = json.load(f) + + if not isinstance(data, list): + raise SystemExit("Expected a top-level JSON list from benchmark_sweep_results.json") + + rows = _extract_table_rows(data, args.benchmark, args.pytorch_benchmark) + rows = _sort_rows(rows, args.sort_by, args.descending) + + print("| Shape | Label | Speedup vs PyTorch | TFLOPS |") + print("|---|---|---:|---:|") + for row in rows: + print( + f"| {row['shape']} | {row['label']} | {_format_speedup(row['speedup'])} | {_format_tflops(row['tflops'])} |" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_gather/benchmark_matmul.py b/benchmark/ops/matmul_all_gather/benchmark_matmul.py new file mode 100644 index 000000000..57c930200 --- /dev/null +++ b/benchmark/ops/matmul_all_gather/benchmark_matmul.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul operation. + +This benchmark showcases the GEMM operation where each rank +computes a local matmul. +""" + +import os +import torch +import torch.distributed as dist +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops.matmul import ( + matmul, + matmul_preamble, +) +from iris.ops import FusedConfig + +# NOTE: derive_params is no longer needed since iris now uses tritonBLAS, +# which automatically selects optimal parameters via Origami heuristics. +# The block size arguments are kept for API compatibility but are ignored. + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows per rank in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Tensor datatype", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--num_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument("--block_size_m", type=int, default=None, help="Block size M (model-derived if omitted)") + parser.add_argument("--block_size_n", type=int, default=None, help="Block size N (model-derived if omitted)") + parser.add_argument("--block_size_k", type=int, default=None, help="Block size K (model-derived if omitted)") + parser.add_argument("--group_size_m", type=int, default=None, help="Group size M (model-derived if omitted)") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto if None)") + parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") + parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") + parser.add_argument( + "--output_file", + type=str, + default="matmul.json", + help="Output file", + ) + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (all_gather_into_tensor + matmul) for comparison", + ) + + return vars(parser.parse_args()) + + +def _worker(args: dict): + """Worker function for PyTorch distributed execution.""" + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend) + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + datatype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + datatype = datatype_map.get(args["datatype"], torch.float16) + # Note: tritonBLAS automatically selects optimal parameters via Origami + if rank == 0: + shmem.info("Using tritonBLAS backend with automatic parameter selection (Origami)") + + M = args["m"] + N = args["n"] + K = args["k"] + + # Create config + # Note: block_size_* and group_size_m are ignored by tritonBLAS backend + # tritonBLAS uses Origami to automatically select optimal parameters + config_kwargs = {} + if args["num_sms"] is not None: + config_kwargs["num_sms"] = args["num_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul") + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON + # Note: block sizes are now chosen by tritonBLAS Origami heuristics + json_writer.add_field("backend", "tritonblas") + json_writer.add_field("num_sms", config.num_sms if hasattr(config, "num_sms") else None) + json_writer.add_field("num_xcds", config.num_xcds if hasattr(config, "num_xcds") else None) + + # Create input and output tensors + # A_local is M x K, output is M x N (local matmul, no gather) + A_local = shmem.zeros((M, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + + # Fill inputs with deterministic values + # Each rank has different A_local, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M, K), dtype=datatype, device=f"cuda:{rank}") + A_local.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # Expected + expected_tensor = None + if args["validate"]: + # Plain matmul: just A_local @ B (local computation, no gather) + expected_tensor = torch.matmul(A_local_data, B_data) + + # Pre-allocate workspace + workspace = matmul_preamble(shmem, A_local, B, config) + + # ── Timing ─────────────────────────────────────────────────────────── + comm_stream = torch.cuda.Stream() + start_ev = torch.cuda.Event(enable_timing=True) + end_ev = torch.cuda.Event(enable_timing=True) + total_ms = 0.0 + num_experiments = 0 + + num_warps = args["num_warps"] + num_stages = args["num_stages"] + + def run_experiment(): + nonlocal total_ms, num_experiments + shmem.barrier() + + with torch.cuda.stream(comm_stream): + start_ev.record() + matmul( + shmem, + C, + A_local, + B, + config=config, + async_op=False, + workspace=workspace, + num_warps=num_warps, + num_stages=num_stages, + ) + end_ev.record() + num_experiments += 1 + shmem.barrier() + total_ms += start_ev.elapsed_time(end_ev) + + shmem.barrier() + + # ── Validate ───────────────────────────────────────────────────────── + if args["validate"]: + shmem.info("Validating...") + C.zero_() + shmem.barrier() + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + rtol = 1e-2 if datatype == torch.float16 else 1e-5 + success = torch.allclose(C, expected_tensor, atol=atol, rtol=rtol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation FAILED, max diff: {max_diff}") + else: + shmem.info("Validation PASSED!") + shmem.barrier() + json_writer.add_field("success", success) + + # ── Benchmark ──────────────────────────────────────────────────────── + if args["benchmark"]: + if args.get("single_run"): + n_warmup, n_repeat = 0, 1 + else: + n_warmup, n_repeat = 25, 100 + + # Warmup + total_ms = 0.0 + num_experiments = 0 + if n_warmup > 0: + iris.do_bench(run_experiment, shmem.barrier, n_warmup=n_warmup, n_repeat=1) + + total_ms = 0.0 + num_experiments = 0 + C.zero_() + shmem.barrier() + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=0, n_repeat=n_repeat) + avg_ms = total_ms / num_experiments if num_experiments > 0 else 0 + + total_flops = 2 * M * N * K + tflops = (total_flops * 1e-12) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + element_size = torch.tensor([], dtype=datatype).element_size() + # Plain matmul has no communication, just local compute + input_bytes = (M * K + K * N) * element_size + output_bytes = M * N * element_size + total_bytes = input_bytes + output_bytes + total_bytes_gb = total_bytes / (1024**3) + bw_gbps = (total_bytes / (1024**3)) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + + shmem.info( + f"Matmul (M={M}, N={N}, K={K}, dtype={args['datatype']}): " + f"{avg_ms:.3f} ms, {tflops:.3f} TFLOPS, {bw_gbps:.3f} GB/s (HBM)" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bw_gbps) + json_writer.add_field("avg_ms", avg_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (all_gather_into_tensor + matmul) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A = torch.randn(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + # pytorch_A_gathered = torch.zeros(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + # dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + torch.matmul(pytorch_A, pytorch_B, out=pytorch_C) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + # Calculate bandwidth for all-gather part + element_size = torch.tensor([], dtype=datatype).element_size() + input_bytes = M * K * element_size + total_bytes = input_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + def run_pytorch_experiment(): + # dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + torch.matmul(pytorch_A, pytorch_B, out=pytorch_C) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch all_gather_into_tensor+matmul (M={M}, K={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + print("Starting matmul benchmark...") + args = parse_args() + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + _worker(args) + else: + print( + "Please run with torchrun:\n" + " torchrun --nproc_per_node=N " + "benchmark/ops/matmul_all_gather/benchmark_matmul.py [OPTIONS]" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_gather/derive_params.py b/benchmark/ops/matmul_all_gather/derive_params.py new file mode 100644 index 000000000..f8404631b --- /dev/null +++ b/benchmark/ops/matmul_all_gather/derive_params.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Parameter derivation for matmul_all_gather_copy_engine. + +Given a problem size (M, N, K), world size, derives kernel parameters +for the fused GEMM → scatter pattern where each rank: +1. Computes local GEMM: (M_local, K) @ (K, N) → (M_local, N) +2. Scatters result tiles to all other ranks via SDMA + +Key differences from all_gather_matmul: +- SMALLER local GEMM (M_local instead of M_total) +- SCATTER communication (send results vs gather inputs) +- PERSISTENT kernel with per-tile GEMM+scatter fusion +- SDMA overhead dominates (not GEMM) + +Usage: + python derive_params.py -m 131072 -n 2048 -k 16384 + python derive_params.py -m 16384 -n 2048 -k 16384 --link_bw 50 +""" + +import argparse +import math +import time + +# ── MI300X hardware defaults (COPIED from all_gather_matmul/derive_params.py) ── +DEFAULT_NUM_CUS = 304 +DEFAULT_PEAK_TFLOPS_FP16 = 1300.0 +DEFAULT_HBM_BW_GBPS = 5300.0 +DEFAULT_L2_SIZE_BYTES = 256 * 1024 * 1024 +DEFAULT_NUM_XCDS = 8 +DEFAULT_WORLD_SIZE = 8 + +# Calibrated from MI300X trace data +DEFAULT_SCHEDULING_FACTOR = 4.5 + +# SDMA latency parameters (specific to copy engine) +SDMA_LATENCY_US = 2.0 +DEVICE_POST_OVERHEAD_US = 1.0 +HOST_POST_OVERHEAD_US = 0.5 +FLAG_POLL_LATENCY_US = 0.1 +REMOTE_WRITE_SLOWDOWN_PER_RANK = 0.05 +BIDIRECTIONAL_TRAFFIC_FACTOR = 1.5 + + +def profile_link_bandwidth(world_size=DEFAULT_WORLD_SIZE): + """Measure per-link unidirectional XGMI bandwidth. + + COPIED from all_gather_matmul/derive_params.py + """ + import torch + + n_gpus = torch.cuda.device_count() + if n_gpus < 2: + raise RuntimeError( + f"Need >= 2 visible GPUs for bandwidth profiling, found {n_gpus}. Pass --link_bw explicitly instead." + ) + + n_peers = min(world_size, n_gpus) - 1 + size_bytes = 256 * 1024 * 1024 + numel = size_bytes // 2 + warmup_iters = 10 + timed_iters = 40 + + print(f"\n── Link Bandwidth Profiling {'─' * 43}") + print(f" GPUs visible: {n_gpus}") + print(f" Testing: GPU 0 → GPUs 1..{n_peers}") + print(f" Transfer size: {size_bytes // (1024**2)} MB × {timed_iters} iterations\n") + + src = torch.empty(numel, dtype=torch.float16, device="cuda:0").normal_() + bandwidths = [] + + for peer in range(1, n_peers + 1): + dst = torch.empty(numel, dtype=torch.float16, device=f"cuda:{peer}") + + for _ in range(warmup_iters): + dst.copy_(src) + torch.cuda.synchronize(0) + torch.cuda.synchronize(peer) + + t_start = time.perf_counter() + for _ in range(timed_iters): + dst.copy_(src) + torch.cuda.synchronize(peer) + elapsed_s = time.perf_counter() - t_start + + bw = size_bytes * timed_iters / elapsed_s / 1e9 + bandwidths.append(bw) + print(f" GPU 0 → GPU {peer}: {bw:6.1f} GB/s") + + del dst + + del src + torch.cuda.empty_cache() + + bw_min = min(bandwidths) + bw_max = max(bandwidths) + bw_avg = sum(bandwidths) / len(bandwidths) + print(f"\n min = {bw_min:.1f} avg = {bw_avg:.1f} max = {bw_max:.1f} GB/s") + print(f" Using conservative (min): {bw_min:.1f} GB/s per link") + + return bw_min + + +# ── Tile / block size heuristics (COPIED from all_gather_matmul/derive_params.py) ── + + +def _choose_block_sizes(M_local, N, K): + """Heuristic tile-size selection for MI300X MFMA.""" + bk = 64 + + bm = 256 if M_local >= 8192 else 128 + while M_local % bm != 0 and bm > 64: + bm //= 2 + + if N >= 512: + bn = 256 + elif N >= 256: + bn = 256 if N % 256 == 0 else 128 + else: + bn = 128 + while N % bn != 0 and bn > 32: + bn //= 2 + + while K % bk != 0 and bk > 16: + bk //= 2 + + nw = 8 if bm * bn >= 256 * 256 else 4 + return bm, bn, bk, nw + + +# ── Per-tile roofline model (COPIED from all_gather_matmul/derive_params.py) ── + + +def _tile_roofline(bm, bn, bk, M_local, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size): + """Compute achievable per-CU TFLOPS from tile arithmetic intensity. + + For matmul_all_gather: A and B are both local (no remote reads), + so we only check if they fit in L2. + """ + tile_flops = 2 * bm * bn * bk + a_bytes = bm * bk * dtype_bytes + b_bytes = bk * bn * dtype_bytes + + a_total = M_local * K * dtype_bytes + b_total = K * N * dtype_bytes + + # Both A and B are local - check if they fit in L2 + b_in_l2 = (a_total + b_total) <= l2_size + + hbm_bytes = a_bytes + (0 if b_in_l2 else b_bytes) + intensity = tile_flops / max(hbm_bytes, 1) + + ridge = peak_tflops * 1e3 / hbm_bw_gbps + if intensity >= ridge: + roofline = peak_tflops + else: + roofline = hbm_bw_gbps * intensity / 1e3 + + return roofline, intensity, ridge, b_in_l2 + + +# ── matmul_all_gather specific models ── + + +def _gemm_wg_time_us(bm, bn, bk, K, roofline_tflops, num_cus): + """Estimate per-WG local GEMM execution time. + + Each rank does (M_local, K) @ (K, N) where M_local is local to that rank. + """ + num_k_blocks = K // bk + total_flops = 2 * bm * bn * K + per_cu_tflops = roofline_tflops / num_cus + + # Roofline-ideal per-WG time + ideal_us = total_flops / (per_cu_tflops * 1e6) + + # Single-occupancy overhead + occupancy_factor = 1.25 if bm * bn >= 256 * 256 else 1.10 + + # Signaling overhead per output tile + signal_us = 2.5 # TODO use parameters + return ideal_us * occupancy_factor + signal_us + + +def _scatter_sdma_time_us(bm, bn, world_size, link_bw, dtype_bytes): + """Estimate per-WG scatter time for one output tile. + + Scatters (bm × bn) tile to (world_size - 1) remote ranks via SDMA. + """ + tile_bytes = bm * bn * dtype_bytes + scatters_per_tile = world_size - 1 + + # Effective XGMI bandwidth with bidirectional traffic + # All ranks scatter simultaneously + effective_bw = link_bw / (math.sqrt(world_size) * BIDIRECTIONAL_TRAFFIC_FACTOR) + + # XGMI transfer time + xgmi_us = (tile_bytes * scatters_per_tile) / (effective_bw * 1e3) + + # SDMA posting overhead (iris.put per remote rank) + sdma_overhead_us = scatters_per_tile * (DEVICE_POST_OVERHEAD_US + SDMA_LATENCY_US) + + # Remote write contention + remote_write_slowdown = 1 + (world_size - 1) * REMOTE_WRITE_SLOWDOWN_PER_RANK + remote_write_us = (tile_bytes / (DEFAULT_HBM_BW_GBPS * 1e3)) * remote_write_slowdown + + # Total scatter cost (serialized) + total_us = xgmi_us + sdma_overhead_us + remote_write_us + + return total_us + + +def _estimate_kernel_time(num_tiles, gemm_wg_us, scatter_wg_us, num_cus, scheduling_factor): + """Estimate kernel time for persistent fused GEMM+scatter. + + Persistent kernel: NUM_CUS tiles in flight, each doing GEMM then scatter. + """ + # Per-tile time (serialized GEMM + scatter) + tile_time_us = gemm_wg_us + scatter_wg_us + + # Persistent kernel work queue model + total_work_us = num_tiles * tile_time_us + ideal_time_us = total_work_us / num_cus + + # Apply scheduling overhead + kernel_time_us = ideal_time_us * scheduling_factor + + return kernel_time_us + + +# ── Main derivation function ── + + +def derive( + M, + N, + K, + world_size=DEFAULT_WORLD_SIZE, + link_bw=50.0, + num_cus=DEFAULT_NUM_CUS, + peak_tflops=DEFAULT_PEAK_TFLOPS_FP16, + hbm_bw_gbps=DEFAULT_HBM_BW_GBPS, + l2_size=DEFAULT_L2_SIZE_BYTES, + scheduling_factor=DEFAULT_SCHEDULING_FACTOR, + dtype_bytes=2, +): + """Derive optimal parameters for matmul_all_gather_copy_engine. + + Args: + M, N, K: Problem dimensions for a SINGLE rank + M is M_local (sharded), total M across all ranks = M * world_size + K is full K dimension (NOT sharded) + world_size: Number of ranks + link_bw: XGMI link bandwidth (GB/s per link) + ...hardware params... + + Returns: + dict with kernel parameters and performance estimates + """ + M_local = M # Input M is already the local dimension + M_total = M_local * world_size + + # 1. Tile sizes + bm, bn, bk, nw = _choose_block_sizes(M_local, N, K) + # TODO overwrite for testing + # bm = bm // 2 + # bn = bn // 2 + # if xgmi bound set to 1 + gm = 4 # M-dimension grouping for L2 cache reuse (matches all_gather_matmul) + + # 2. Per-tile roofline + roofline_tflops, intensity, ridge, b_in_l2 = _tile_roofline( + bm, bn, bk, M_local, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size + ) + + # Number of output tiles (per rank) + num_m_tiles = M_local // bm + num_n_tiles = (N + bn - 1) // bn + total_tiles = num_m_tiles * num_n_tiles + + # M-tiles fully computer per wave of num_cus + tiles_per_group = gm * num_n_tiles + groups_completing_per_wave = num_cus // tiles_per_group + m_tiles_per_batch = groups_completing_per_wave + print(f"block: {bm}x{bn}x{bk}") + print(f"m_tiles_per_batch {m_tiles_per_batch:.2f}") + + # Calculate transfer time per group + # TODO add latency and signal overhead + bytes_per_group = tiles_per_group * bm * bn * dtype_bytes # 256 * 128 * 256 * 2 = 16777216 + transfer_time_per_group_us = bytes_per_group / (link_bw * 1e3) + print(f"transfer_time_per_group_us {transfer_time_per_group_us:.2f}") + transfer_time_per_wave_us = transfer_time_per_group_us * groups_completing_per_wave + print(f"transfer_time_per_wave_us {transfer_time_per_wave_us:.2f}") + + # Per-WG times + gemm_wg_us_val = _gemm_wg_time_us(bm, bn, bk, K, roofline_tflops, num_cus) + scatter_sdma_us_val = _scatter_sdma_time_us(bm, bn, world_size, link_bw, dtype_bytes) + + # Kernel time estimate + kernel_time_us = _estimate_kernel_time(total_tiles, gemm_wg_us_val, scatter_sdma_us_val, num_cus, scheduling_factor) + + # Sequential baseline (GEMM then separate scatter) + total_flops = 2 * M_local * N * K + gemm_only_us = (total_flops / 1e12) / (peak_tflops * 0.85) * 1e6 + + total_scatter_bytes = M_local * N * dtype_bytes * (world_size - 1) + effective_scatter_bw = link_bw / (math.sqrt(world_size) * BIDIRECTIONAL_TRAFFIC_FACTOR) + scatter_only_us = (total_scatter_bytes / 1e9) / effective_scatter_bw * 1e6 + sequential_us = gemm_only_us + scatter_only_us + + # Speedup + speedup = sequential_us / kernel_time_us + + return dict( + block_size_m=bm, + block_size_n=bn, + block_size_k=bk, + group_size_m=gm, + num_warps=nw, + device_initiated=True, # Always use device-initiated for persistent kernel + # derived + M_local=M_local, + M_total=M_total, + num_m_tiles=num_m_tiles, + num_tiles_n=num_n_tiles, + total_tiles=total_tiles, + # roofline + roofline_tflops=roofline_tflops, + tile_intensity=intensity, + ridge_point=ridge, + b_in_l2=b_in_l2, + # per-WG timing + gemm_wg_us=gemm_wg_us_val, + scatter_wg_us=scatter_sdma_us_val, + tile_time_us=gemm_wg_us_val + scatter_sdma_us_val, + # estimates + kernel_time_us=kernel_time_us, + kernel_time_ms=kernel_time_us / 1000, + sequential_us=sequential_us, + sequential_ms=sequential_us / 1000, + speedup=speedup, + ) + + +# ── CLI ── + + +def main(): + parser = argparse.ArgumentParser( + description="Derive parameters for matmul_all_gather_copy_engine", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, required=True, help="M dimension (M_local, rows per rank)") + parser.add_argument("-n", type=int, required=True, help="N dimension (columns)") + parser.add_argument("-k", type=int, required=True, help="K dimension (full K, NOT sharded)") + parser.add_argument("-w", "--world_size", type=int, default=DEFAULT_WORLD_SIZE, help="Number of ranks") + parser.add_argument("--link_bw", type=float, default=None, help="XGMI link BW (GB/s). Auto-profile if None.") + parser.add_argument("--dtype", choices=["fp16", "fp32", "bf16"], default="fp16", help="Data type") + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + + args = parser.parse_args() + + dtype_bytes = 2 if args.dtype in ["fp16", "bf16"] else 4 + + # Profile link bandwidth if not provided + if args.link_bw is None: + try: + args.link_bw = profile_link_bandwidth(args.world_size) + except Exception as e: + print(f"Auto-profiling failed ({e}), using default 50 GB/s") + args.link_bw = 50.0 + + # Derive parameters + params = derive( + args.m, + args.n, + args.k, + args.world_size, + link_bw=args.link_bw, + dtype_bytes=dtype_bytes, + ) + + # Print results + print("\n" + "=" * 80) + print("MATMUL_ALL_GATHER_COPY_ENGINE: DERIVED PARAMETERS") + print("=" * 80) + + print("\nProblem:") + print(f" M_local = {params['M_local']}, M_total = {params['M_total']}, N = {args.n}, K = {args.k}") + print(f" world_size = {args.world_size}") + print(f" Local GEMM: ({params['M_local']}, {args.k}) @ ({args.k}, {args.n})") + + print("\nDerived Kernel Parameters:") + print(f" block_size_m: {params['block_size_m']}") + print(f" block_size_n: {params['block_size_n']}") + print(f" block_size_k: {params['block_size_k']}") + print(f" num_warps: {params['num_warps']}") + print(f" device_initiated: {params['device_initiated']}") + + print("\nPerformance Model:") + print(f" Roofline TFLOPS: {params['roofline_tflops']:.1f}") + print(f" Arith. Intensity: {params['arithmetic_intensity']:.1f} FLOPs/byte") + print(f" B in L2: {params['b_in_l2']}") + print(f" GEMM per tile: {params['gemm_wg_us']:.2f} μs") + print(f" Scatter per tile: {params['scatter_wg_us']:.2f} μs") + print(f" Total per tile: {params['tile_time_us']:.2f} μs") + print(f" Total tiles: {params['total_tiles']}") + + print("\nEstimated Times:") + print(f" Kernel (fused): {params['kernel_time_ms']:.2f} ms") + print(f" Sequential: {params['sequential_ms']:.2f} ms (GEMM then scatter)") + print(f" Speedup: {params['speedup']:.2f}x") + + if params["scatter_wg_us"] > params["gemm_wg_us"]: + ratio = params["scatter_wg_us"] / params["gemm_wg_us"] + print(f"\n ⚠ Scatter dominates ({ratio:.1f}x slower than GEMM per tile)") + print(" → Communication-bound workload") + else: + print("\n ✓ GEMM dominates") + print(" → Compute-bound workload") + + print("=" * 80) + + print("\nBenchmark command:") + print(f" torchrun --nproc_per_node={args.world_size} \\") + print(" benchmark/ops/matmul_all_gather/benchmark_copy_engine.py \\") + print(f" -m {params['M_local']} -n {args.n} -k {args.k} \\") + print(f" --block_size_m {params['block_size_m']} \\") + print(f" --block_size_n {params['block_size_n']} \\") + print(f" --block_size_k {params['block_size_k']} \\") + print(" --device_initiated \\") + print(" --benchmark --validate") + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_gather/plot_sweep_results.py b/benchmark/ops/matmul_all_gather/plot_sweep_results.py new file mode 100755 index 000000000..8656ccd07 --- /dev/null +++ b/benchmark/ops/matmul_all_gather/plot_sweep_results.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Plot benchmark sweep results from benchmark_sweep_results.json. + +Creates a grouped bar chart comparing TFLOPS across different benchmark variants +and dimension configurations. + +Usage: + python plot_sweep_results.py [--input benchmark_sweep_results.json] [--output plot.png] +""" + +import json +import argparse +from pathlib import Path +from enum import Enum +import matplotlib.pyplot as plt +import numpy as np + + +class BenchmarkType(Enum): + """Canonical benchmark types.""" + + IRIS_FUSED = "iris_fused" + IRIS_OPTIMIZED = "iris_optimized" + TRITONBLAS_RCCL = "tritonblas_rccl" + TRITON_DEVICE_SDMA = "triton_device_sdma" + TRITON_HOST_SDMA = "triton_host_sdma" + TRITON_GEMM_ONLY = "triton_gemm_only" + PYTORCH_RCCL = "pytorch_rccl" + PYTORCH_GEMM_ONLY = "pytorch_gemm_only" + + +# Normalize benchmark names to canonical types +BENCHMARK_NAME_MAP = { + "baseline": BenchmarkType.IRIS_FUSED, + "hbm_buffer": BenchmarkType.IRIS_OPTIMIZED, + "tritonblas_rccl": BenchmarkType.TRITONBLAS_RCCL, + "tritonblas_rcclbaseline": BenchmarkType.TRITONBLAS_RCCL, + "device_copy_engine": BenchmarkType.TRITON_DEVICE_SDMA, + "copy_engine_device": BenchmarkType.TRITON_DEVICE_SDMA, + "copy_engine": BenchmarkType.TRITON_DEVICE_SDMA, + "host_copy_engine": BenchmarkType.TRITON_HOST_SDMA, + "copy_engine_host": BenchmarkType.TRITON_HOST_SDMA, + "matmul_only": BenchmarkType.TRITON_GEMM_ONLY, + "matmulonly": BenchmarkType.TRITON_GEMM_ONLY, + "pytorchbaseline": BenchmarkType.PYTORCH_RCCL, + "pytorch_baseline": BenchmarkType.PYTORCH_RCCL, + "pytorchmatmul_only": BenchmarkType.PYTORCH_GEMM_ONLY, + "pytorch_matmul_only": BenchmarkType.PYTORCH_GEMM_ONLY, + "pytorchmatmulonly": BenchmarkType.PYTORCH_GEMM_ONLY, +} + +# Display labels for each benchmark type +BENCHMARK_LABELS = { + BenchmarkType.IRIS_FUSED: "Iris baseline", + BenchmarkType.IRIS_OPTIMIZED: "Iris optimized fused kernel", + BenchmarkType.TRITONBLAS_RCCL: "TritonBlas + RCCL", + BenchmarkType.TRITON_DEVICE_SDMA: "TritonBlas + device-initiated SDMA", + BenchmarkType.TRITON_HOST_SDMA: "TritonBlas + host-initiated SDMA", + BenchmarkType.TRITON_GEMM_ONLY: "TritonBlas (GEMM only)", + BenchmarkType.PYTORCH_RCCL: "Pytorch + RCCL", + BenchmarkType.PYTORCH_GEMM_ONLY: "Pytorch (GEMM only)", +} + +# Colors for each benchmark type +BENCHMARK_COLORS = { + BenchmarkType.IRIS_FUSED: "#2E7D32", # Dark Green + BenchmarkType.IRIS_OPTIMIZED: "#66BB6A", # Light Green + BenchmarkType.TRITONBLAS_RCCL: "#26A69A", # Teal + BenchmarkType.TRITON_DEVICE_SDMA: "#82E8FF", # Light Blue + BenchmarkType.TRITON_HOST_SDMA: "#1976D2", # Blue + BenchmarkType.TRITON_GEMM_ONLY: "#7B1FA2", # Purple + BenchmarkType.PYTORCH_RCCL: "#F57C00", # Orange + BenchmarkType.PYTORCH_GEMM_ONLY: "#FFB74D", # Light Orange +} + +# Preferred display order +BENCHMARK_ORDER = [ + BenchmarkType.IRIS_FUSED, + BenchmarkType.IRIS_OPTIMIZED, + BenchmarkType.TRITONBLAS_RCCL, + BenchmarkType.TRITON_DEVICE_SDMA, + BenchmarkType.TRITON_HOST_SDMA, + BenchmarkType.TRITON_GEMM_ONLY, + BenchmarkType.PYTORCH_RCCL, + BenchmarkType.PYTORCH_GEMM_ONLY, +] + + +def extract_tflops(benchmark_data, benchmark_name): + """Extract TFLOPS value from benchmark result based on benchmark type.""" + if benchmark_data.get("status") == "FAILED": + return None + + # Try multiple possible TFLOPS keys in order of preference + possible_keys = [ + "tflops", + "host_copy_engine_tflops", + "pytorch_tflops", + "copy_engine_tflops", + ] + + for key in possible_keys: + value = benchmark_data.get(key) + if value is not None: + return value + + return None + + +def canonical_operation_name(operation_name): + """Map variant-specific operation names to the canonical sweep operation.""" + if not operation_name: + return None + + if "all_gather_matmul" in operation_name: + return "all_gather_matmul" + + if "matmul_all_gather" in operation_name: + return "matmul_all_gather" + + return operation_name + + +def plot_sweep_results(input_file, output_file, device="MI300X"): + """Create grouped bar chart from sweep results.""" + + # Load results + with open(input_file, "r") as f: + results = json.load(f) + + if not results: + print("No results found in input file") + return + + # Detect operation type from the top-level row first, then fall back to + # benchmark-specific operation names for older result files. + operation = "matmul_all_gather" # default + for result in results: + top_level_operation = canonical_operation_name(result.get("operation")) + if top_level_operation: + operation = top_level_operation + break + + if "benchmarks" in result: + for bench_data in result["benchmarks"].values(): + if "operation" in bench_data: + operation = canonical_operation_name(bench_data["operation"]) + break + if operation != "matmul_all_gather": + break + + # Extract dimension configurations and normalize benchmark names + dim_configs = [] + benchmark_types = set() + + for result in results: + m, n, k = result["M"], result["N"], result["K"] + dim_label = f"{m}×{n}×{k}" + dim_configs.append((m, n, k, dim_label)) + + # Normalize benchmark names to canonical types + for bench_name in result["benchmarks"].keys(): + if bench_name in BENCHMARK_NAME_MAP: + benchmark_types.add(BENCHMARK_NAME_MAP[bench_name]) + + # Sort by preferred order + def sort_key(bench_type): + try: + return BENCHMARK_ORDER.index(bench_type) + except ValueError: + return len(BENCHMARK_ORDER) # Put unknowns at the end + + benchmark_types = sorted(benchmark_types, key=sort_key) + + # Prepare data for plotting - organize by benchmark type + data = {bench_type: [] for bench_type in benchmark_types} + + for result in results: + for bench_type in benchmark_types: + # Find all raw benchmark names that map to this type + matching_names = [ + name + for name, btype in BENCHMARK_NAME_MAP.items() + if btype == bench_type and name in result["benchmarks"] + ] + + if matching_names: + # Use the first matching benchmark name + bench_name = matching_names[0] + tflops = extract_tflops(result["benchmarks"][bench_name], bench_name) + data[bench_type].append(tflops if tflops is not None else 0) + else: + data[bench_type].append(0) + + # Create plot + fig, ax = plt.subplots(figsize=(14, 8)) + + x = np.arange(len(dim_configs)) + width = 0.8 / len(benchmark_types) # Width of bars + + # Plot bars for each benchmark type + for i, bench_type in enumerate(benchmark_types): + offset = width * (i - len(benchmark_types) / 2 + 0.5) + values = data[bench_type] + color = BENCHMARK_COLORS.get(bench_type, f"C{i}") + display_label = BENCHMARK_LABELS.get(bench_type, str(bench_type)) + + bars = ax.bar( + x + offset, + values, + width, + label=display_label, + color=color, + alpha=0.8, + edgecolor="black", + linewidth=0.5, + ) + + # Add value labels on top of bars (only for non-zero values) + for j, (bar, val) in enumerate(zip(bars, values)): + if val > 0: + height = bar.get_height() + ax.text( + bar.get_x() + bar.get_width() / 2, + height, + f"{val:.1f}", + ha="center", + va="bottom", + fontsize=7, + rotation=0, + ) + + # Customize plot + ax.set_xlabel("Dimension Configuration (M×N×K)", fontsize=12, fontweight="bold") + ax.set_ylabel("TFLOPS", fontsize=12, fontweight="bold") + # Set title based on operation type + if operation == "all_gather_matmul": + title = f"All-Gather-Matmul Benchmark Sweep: TFLOPS Comparison ({device})" + else: + title = f"Matmul-All-Gather Benchmark Sweep: TFLOPS Comparison ({device})" + ax.set_title(title, fontsize=14, fontweight="bold", pad=20) + ax.set_xticks(x) + ax.set_xticklabels([label for _, _, _, label in dim_configs], rotation=45, ha="right") + ax.legend(loc="upper left", fontsize=10) + ax.grid(axis="y", alpha=0.3, linestyle="--") + + # Set y-axis to start from 0 + ax.set_ylim(bottom=0) + + plt.tight_layout() + + # Save figure + plt.savefig(output_file, dpi=300, bbox_inches="tight") + print(f"Plot saved to: {output_file}") + + # Print summary statistics in formatted table + print("\n" + "=" * 120) + print("SUMMARY STATISTICS") + print("=" * 120) + print() + + # Build table data + table_data = [] + for dim_config in dim_configs: + m, n, k, label = dim_config + idx = dim_configs.index(dim_config) + + row = { + "M": m, + "N": n, + "K": k, + "Config": label, + } + + # Add TFLOPS values and validation status for each benchmark type + for bench_type in benchmark_types: + val = data[bench_type][idx] + label = BENCHMARK_LABELS.get(bench_type, str(bench_type)) + + if val > 0: + row[label] = f"{val:.2f}" + else: + row[label] = "FAILED" + + # Add validation status column - find the raw benchmark name + matching_names = [ + name + for name, btype in BENCHMARK_NAME_MAP.items() + if btype == bench_type and name in results[idx]["benchmarks"] + ] + + validation_key = f"{label}_validation" + if matching_names: + bench_data = results[idx]["benchmarks"].get(matching_names[0], {}) + if bench_data.get("success") is True: + row[validation_key] = "✓" + elif bench_data.get("success") is False: + row[validation_key] = "✗" + else: + row[validation_key] = "-" + else: + row[validation_key] = "-" + + # Find best performer + valid_values = [ + (bench_type, data[bench_type][idx]) for bench_type in benchmark_types if data[bench_type][idx] > 0 + ] + if valid_values: + best_type, best_val = max(valid_values, key=lambda x: x[1]) + best_label = BENCHMARK_LABELS.get(best_type, str(best_type)) + row["Best"] = f"{best_label} ({best_val:.2f})" + else: + row["Best"] = "NONE" + + table_data.append(row) + + # Calculate column widths + benchmark_labels = [BENCHMARK_LABELS.get(bt, str(bt)) for bt in benchmark_types] + validation_headers = [f"{label}_validation" for label in benchmark_labels] + headers = ["M", "N", "K", "Config"] + benchmark_labels + validation_headers + ["Best"] + col_widths = {} + for header in headers: + col_widths[header] = len(header) + for row in table_data: + col_widths[header] = max(col_widths[header], len(str(row[header]))) + + # Print header with cleaner labels + header_parts = [] + for h in headers: + # Format validation headers as "bench (✓/✗)" + if h.endswith("_validation"): + bench_name = h.replace("_validation", "") + display_name = f"{bench_name} (val)" + else: + display_name = h + header_parts.append(display_name.ljust(col_widths[h])) + print(" ".join(header_parts)) + + # Print separator + sep_parts = [] + for h in headers: + sep_parts.append("-" * col_widths[h]) + print(" ".join(sep_parts)) + + # Print rows + for row in table_data: + row_parts = [] + for h in headers: + row_parts.append(str(row[h]).ljust(col_widths[h])) + print(" ".join(row_parts)) + + print("\n" + "=" * 120) + + +def main(): + parser = argparse.ArgumentParser( + description="Plot benchmark sweep results", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--input", + type=str, + default="benchmark/ops/matmul_all_gather/benchmark_sweep_results.json", + help="Input JSON file with sweep results", + ) + parser.add_argument( + "--output", + type=str, + default="benchmark/ops/matmul_all_gather/sweep_results_plot.png", + help="Output PNG file for plot", + ) + parser.add_argument( + "--device", + type=str, + default="MI300X", + help="Device name to display in plot title", + ) + + args = parser.parse_args() + + input_path = Path(args.input) + if not input_path.exists(): + print(f"Error: Input file not found: {input_path}") + return 1 + + plot_sweep_results(input_path, args.output, args.device) + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/benchmark/ops/matmul_all_gather/print_tune_copy_engine_summary.py b/benchmark/ops/matmul_all_gather/print_tune_copy_engine_summary.py new file mode 100644 index 000000000..1dc6d89eb --- /dev/null +++ b/benchmark/ops/matmul_all_gather/print_tune_copy_engine_summary.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Print a per-shape summary from tune_copy_engine aggregated results. + +The input is the top-level ``results.json`` produced by +``benchmark/ops/matmul_all_gather/tune_copy_engine.py``. +""" + +import argparse +import json +from collections import defaultdict +from pathlib import Path + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Print per-shape summaries from tune_copy_engine results.json", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "results_json", + type=Path, + help="Path to the aggregated results.json emitted by tune_copy_engine.py", + ) + return parser.parse_args() + + +def _load_results(path: Path): + with open(path, "r") as f: + return json.load(f) + + +def _as_float(value): + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _format_float(value, precision): + numeric = _as_float(value) + if numeric is None: + return "--" + return f"{numeric:.{precision}f}" + + +def _sort_key(result): + tflops = _as_float(result.get("iris_tflops")) + return (tflops is not None, tflops if tflops is not None else float("-inf")) + + +def _record_field(result, key): + benchmark_json = result.get("benchmark_json") or {} + if key in benchmark_json: + return benchmark_json.get(key) + return result.get(key) + + +def _print_shape_summary(shape_tag, records): + records = sorted(records, key=_sort_key, reverse=True) + shape = records[0].get("shape", {}) + print("=" * 170) + print(f"{shape_tag} | M_local={shape.get('m_local')} N={shape.get('n')} K={shape.get('k')}") + print("=" * 170) + print( + f"{'#':>3} {'m_tiles_per_batch':>17} {'ms':>9} {'TFLOPS':>9} {'Valid':>8} " + f"{'batches':>7} {'group_m':>7} {'tiles_m':>7} {'tiles_n':>7} {'tiles/wave':>10} " + f"{'tiles/1st':>10} {'iters':>7} " + f"{'block_m':>7} {'block_n':>7} {'block_k':>7} " + f"{'tiles/group':>11} {'gemm_us':>9} {'scatter_us':>11} {'ratio':>8}" + ) + print("-" * 170) + + for idx, record in enumerate(records, start=1): + print( + f"{idx:>3} " + f"{str(_record_field(record, 'm_tiles_per_batch')):>17} " + f"{_format_float(record.get('iris_ms'), 3):>9} " + f"{_format_float(record.get('iris_tflops'), 2):>9} " + f"{str(record.get('validation') or '--'):>8} " + f"{str(_record_field(record, 'num_batches') or '--'):>7} " + f"{str(_record_field(record, 'group_size_m') or '--'):>7} " + f"{str(_record_field(record, 'num_tiles_m') or '--'):>7} " + f"{str(_record_field(record, 'num_tiles_n') or '--'):>7} " + f"{str(_record_field(record, 'm_tiles_per_wave') or '--'):>10} " + f"{str(_record_field(record, 'm_tiles_first_wave') or '--'):>10} " + f"{str(_record_field(record, 'schedule_iterations') or '--'):>7} " + f"{str(_record_field(record, 'block_size_m') or _record_field(record, 'output_tile_size_m') or '--'):>7} " + f"{str(_record_field(record, 'block_size_n') or _record_field(record, 'output_tile_size_n') or '--'):>7} " + f"{str(_record_field(record, 'block_size_k') or _record_field(record, 'output_tile_size_k') or '--'):>7} " + f"{str(_record_field(record, 'tiles_per_group') or '--'):>11} " + f"{_format_float(_record_field(record, 'gemm_wg_us'), 2):>9} " + f"{_format_float(_record_field(record, 'scatter_wg_us'), 2):>11} " + f"{_format_float(_record_field(record, 'ratio'), 3):>8}" + ) + print("") + + +def main(): + args = parse_args() + data = _load_results(args.results_json) + results = data.get("results", []) + + if not results: + raise SystemExit("No results found in input JSON") + + by_shape = defaultdict(list) + for result in results: + by_shape[result.get("shape_tag", "UNKNOWN_SHAPE")].append(result) + + print(f"Input: {args.results_json}") + print(f"Shapes: {len(by_shape)}") + print(f"Runs: {len(results)}") + print("") + + for shape_tag in sorted(by_shape): + _print_shape_summary(shape_tag, by_shape[shape_tag]) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_gather/tune_copy_engine.py b/benchmark/ops/matmul_all_gather/tune_copy_engine.py new file mode 100644 index 000000000..426b80b33 --- /dev/null +++ b/benchmark/ops/matmul_all_gather/tune_copy_engine.py @@ -0,0 +1,601 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tune ``m_tiles_per_batch`` for ``matmul_all_gather_copy_engine``. + +Unlike the more general matmul tuners, this script keeps the GEMM tile geometry +under the tritonBLAS selector and only sweeps the batch size used by the +host/device copy-engine path. +""" + +import argparse +import importlib.util +import json +import math +import os +import subprocess +import sys +import time +from datetime import datetime +from pathlib import Path + +import torch + +from tritonblas.matmul import _make_matmul_selector + + +BENCHMARK_TARGETS = { + "device": { + "script": "benchmark/ops/matmul_all_gather/benchmark_copy_engine.py", + "operation_name": "matmul_all_gather_copy_engine", + "output_stem": "tune_copy_engine", + }, + "host": { + "script": "benchmark/ops/matmul_all_gather/benchmark_host_copy_engine.py", + "operation_name": "matmul_all_gather_host_copy_engine", + "output_stem": "tune_host_copy_engine", + }, +} + + +def _load_sweep_dimension_configs(): + """Load the shared sweep dimension list from benchmark/ops/sweep_benchmarks.py.""" + sweep_path = Path(__file__).resolve().parents[1] / "sweep_benchmarks.py" + module_name = "_shared_sweep_benchmarks" + spec = importlib.util.spec_from_file_location(module_name, sweep_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load sweep benchmark config from {sweep_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + dimension_configs = [] + for config in module.DIMENSION_CONFIGS: + if isinstance(config, dict): + dimension_configs.append((config["m_local"], config["n"], config["k"])) + else: + dimension_configs.append(tuple(config)) + return dimension_configs + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Tune m_tiles_per_batch for matmul_all_gather_copy_engine.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="M_local dimension") + parser.add_argument("-n", type=int, default=2048, help="N dimension") + parser.add_argument("-k", type=int, default=131072, help="K dimension") + parser.add_argument( + "--benchmark", + type=str, + default="device", + choices=sorted(BENCHMARK_TARGETS.keys()), + help="Which copy-engine benchmark to tune", + ) + parser.add_argument( + "--use_sweep_dimensions", + action="store_true", + default=True, + help="Use the shared dimension list from benchmark/ops/sweep_benchmarks.py", + ) + parser.add_argument( + "--single_shape", + dest="use_sweep_dimensions", + action="store_false", + help="Only tune the single shape given by -m/-n/-k", + ) + parser.add_argument("--nproc", type=int, default=8, help="Number of ranks / GPUs") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype passed through to the benchmark", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size passed to the benchmark") + parser.add_argument("--num_sms", type=int, default=None, help="Optional NUM_SMS override for the benchmark") + parser.add_argument("--num_xcds", type=int, default=None, help="Optional NUM_XCDS override for the benchmark") + parser.add_argument( + "--m_tiles_per_batch", + type=int, + nargs="+", + default=None, + help="Explicit sweep values. If omitted, derive a candidate set from selector geometry.", + ) + parser.add_argument( + "--all_values", + action="store_true", + help="Sweep every value from 1..num_m_tiles instead of the heuristic candidate set", + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory (auto-generated if unset)") + parser.add_argument("--dry_run", action="store_true", help="Print the candidate list and exit") + parser.add_argument("--skip_validation", action="store_true", help="Skip validation for faster sweeps") + parser.add_argument("--timeout", type=int, default=600, help="Per-run timeout in seconds") + return parser.parse_args() + + +def _dtype_from_name(name: str) -> torch.dtype: + return { + "fp16": torch.float16, + "fp32": torch.float32, + "bf16": torch.bfloat16, + }[name] + + +def _selector_metadata(m_local: int, n: int, k: int, dtype: torch.dtype): + device = torch.device("cuda:0") + selector = _make_matmul_selector( + m_local, + n, + k, + dtype, + dtype, + dtype, + device, + streamk=False, + ) + + block_m = selector.block_m + block_n = selector.block_n + block_k = selector.block_k + group_size_m = selector.group_m + num_stages = getattr(selector, "num_stages", 2) + waves_per_eu = getattr(selector, "waves_per_eu", 0) + active_cus = getattr(selector, "_ACTIVE_CU", None) + if active_cus is None: + active_cus = getattr(selector._hardware, "N_CU", getattr(selector._hardware, "NUM_XCD", 1)) + + num_tiles_m = math.ceil(m_local / block_m) + num_tiles_n = math.ceil(n / block_n) + tiles_per_group = max(1, group_size_m * num_tiles_n) + groups_per_wave = max(1, active_cus // tiles_per_group) + m_tiles_per_wave = min(num_tiles_m, groups_per_wave * group_size_m) + + return { + "selector": selector, + "block_size_m": block_m, + "block_size_n": block_n, + "block_size_k": block_k, + "group_size_m": group_size_m, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "active_cus": active_cus, + "num_tiles_m": num_tiles_m, + "num_tiles_n": num_tiles_n, + "tiles_per_group": tiles_per_group, + "groups_per_wave": groups_per_wave, + "m_tiles_per_wave": m_tiles_per_wave, + } + + +def _candidate_values( + num_tiles_m: int, group_size_m: int, groups_per_wave: int, m_tiles_per_wave: int, sweep_all: bool +): + if sweep_all: + return list(range(1, num_tiles_m + 1)) + + values = {1, num_tiles_m} + + power = 1 + while power <= num_tiles_m: + values.add(power) + power *= 2 + + # Keep a small number of shape-aware anchors even in sparse mode. + for candidate in (group_size_m, groups_per_wave, m_tiles_per_wave): + if 1 <= candidate <= num_tiles_m: + values.add(candidate) + + return sorted(values) + + +def _build_command(args, output_path: str, m_tiles_per_batch: int): + target = BENCHMARK_TARGETS[args.benchmark] + cmd = [ + "torchrun", + "--nproc_per_node", + str(args.nproc), + target["script"], + "-m", + str(args.m), + "-n", + str(args.n), + "-k", + str(args.k), + "--datatype", + args.datatype, + "--heap_size", + str(args.heap_size), + "--m_tiles_per_batch", + str(m_tiles_per_batch), + "--output_file", + output_path, + "-b", + ] + + if not args.skip_validation: + cmd.append("-v") + if args.num_sms is not None: + cmd.extend(["--num_sms", str(args.num_sms)]) + if args.num_xcds is not None: + cmd.extend(["--num_xcds", str(args.num_xcds)]) + + return cmd + + +def _parse_json_output(json_path: Path): + result = { + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": None, + "group_size_m": None, + "block_size_m": None, + "block_size_n": None, + "block_size_k": None, + "m_tiles_per_batch": None, + "output_tile_size_m": None, + "output_tile_size_n": None, + "output_tile_size_k": None, + "num_stages": None, + "waves_per_eu": None, + "active_cus": None, + "num_tiles_m": None, + "num_tiles_n": None, + "tiles_per_group": None, + "groups_per_wave": None, + "m_tiles_per_wave": None, + "num_batches": None, + "last_batch_m_tiles": None, + "m_tiles_per_batch_over_wave": None, + "gemm_wg_us": None, + "scatter_wg_us": None, + "bottleneck": None, + "ratio": None, + "roofline_tflops": None, + "intensity": None, + } + + try: + with open(json_path, "r") as f: + data = json.load(f) + + result["iris_ms"] = data.get("avg_ms") + result["iris_tflops"] = data.get("tflops") + result["iris_bw_gbps"] = data.get("bandwidth_gbps") + result["validation"] = "PASSED" if data.get("success") is True else ("FAILED" if "success" in data else None) + result["group_size_m"] = data.get("group_size_m") + result["block_size_m"] = data.get("block_size_m") + result["block_size_n"] = data.get("block_size_n") + result["block_size_k"] = data.get("block_size_k") + result["m_tiles_per_batch"] = data.get("m_tiles_per_batch") + result["output_tile_size_m"] = data.get("output_tile_size_m", data.get("block_size_m")) + result["output_tile_size_n"] = data.get("output_tile_size_n", data.get("block_size_n")) + result["output_tile_size_k"] = data.get("output_tile_size_k", data.get("block_size_k")) + result["num_stages"] = data.get("num_stages") + result["waves_per_eu"] = data.get("waves_per_eu") + result["active_cus"] = data.get("active_cus") + result["num_tiles_m"] = data.get("num_tiles_m") + result["num_tiles_n"] = data.get("num_tiles_n") + result["tiles_per_group"] = data.get("tiles_per_group") + result["groups_per_wave"] = data.get("groups_per_wave") + result["m_tiles_per_wave"] = data.get("m_tiles_per_wave") + result["num_batches"] = data.get("num_batches") + result["last_batch_m_tiles"] = data.get("last_batch_m_tiles") + result["m_tiles_per_batch_over_wave"] = data.get("m_tiles_per_batch_over_wave") + result["gemm_wg_us"] = data.get("gemm_wg_us") + result["scatter_wg_us"] = data.get("scatter_wg_us") + result["bottleneck"] = data.get("bottleneck") + result["ratio"] = data.get("ratio") + result["roofline_tflops"] = data.get("roofline_tflops") + result["intensity"] = data.get("intensity") + except Exception: + pass + + return result + + +def _print_selector_summary(meta, candidates): + tile_shape = f"{meta['block_size_m']}x{meta['block_size_n']}x{meta['block_size_k']}" + print("\nSelector-derived geometry") + print(f" output tile size : {tile_shape}") + print(f" group_size_m : {meta['group_size_m']}") + print(f" num_stages : {meta['num_stages']}") + print(f" waves_per_eu : {meta['waves_per_eu']}") + print(f" active CUs : {meta['active_cus']}") + print(f" tile grid : {meta['num_tiles_m']} M-tiles x {meta['num_tiles_n']} N-tiles") + print(f" tiles per group : {meta['tiles_per_group']}") + print(f" groups per wave/stage : {meta['groups_per_wave']}") + print(f" M-tiles per wave/stage : {meta['m_tiles_per_wave']}") + print(f" sweep m_tiles_per_batch : {candidates}") + + +def _shape_tag(m_local: int, n: int, k: int): + return f"M{m_local}_N{n}_K{k}" + + +def main(): + args = parse_args() + dtype = _dtype_from_name(args.datatype) + if args.use_sweep_dimensions: + dimension_configs = _load_sweep_dimension_configs() + else: + dimension_configs = [(args.m, args.n, args.k)] + + if args.output_dir: + output_dir = Path(args.output_dir) + else: + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + output_stem = BENCHMARK_TARGETS[args.benchmark]["output_stem"] + output_dir = Path(f"benchmark/ops/matmul_all_gather/{output_stem}_{ts}") + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n{'=' * 100}") + print(f" {BENCHMARK_TARGETS[args.benchmark]['operation_name']} — m_tiles_per_batch tuning") + if args.use_sweep_dimensions: + print(f" Shapes: sweep_benchmarks.py DIMENSION_CONFIGS ({len(dimension_configs)} shapes)") + else: + print(f" M_local={args.m} N={args.n} K={args.k} nproc={args.nproc} dtype={args.datatype}") + print(f" benchmark={args.benchmark} nproc={args.nproc} dtype={args.datatype}") + print(f" Output dir: {output_dir}") + print(f" Validation: {'OFF' if args.skip_validation else 'ON'}") + print(f"{'=' * 100}") + + if args.dry_run: + print("") + for m_local, n, k in dimension_configs: + meta = _selector_metadata(m_local, n, k, dtype) + if args.m_tiles_per_batch is not None: + candidates = sorted({value for value in args.m_tiles_per_batch if 1 <= value <= meta["num_tiles_m"]}) + else: + candidates = _candidate_values( + meta["num_tiles_m"], + meta["group_size_m"], + meta["groups_per_wave"], + meta["m_tiles_per_wave"], + args.all_values, + ) + print(f"Shape {_shape_tag(m_local, n, k)}") + _print_selector_summary(meta, candidates) + print("") + print("Dry run only; no benchmarks launched.") + return + + env = os.environ.copy() + env["HSA_NO_SCRATCH_RECLAIM"] = "1" + + results = [] + total_start = time.time() + + for shape_idx, (m_local, n, k) in enumerate(dimension_configs, start=1): + meta = _selector_metadata(m_local, n, k, dtype) + if args.m_tiles_per_batch is not None: + candidates = sorted({value for value in args.m_tiles_per_batch if 1 <= value <= meta["num_tiles_m"]}) + else: + candidates = _candidate_values( + meta["num_tiles_m"], + meta["group_size_m"], + meta["groups_per_wave"], + meta["m_tiles_per_wave"], + args.all_values, + ) + + if not candidates: + raise ValueError(f"No valid m_tiles_per_batch values to test for shape {_shape_tag(m_local, n, k)}") + + shape_tag = _shape_tag(m_local, n, k) + shape_output_dir = output_dir / shape_tag + shape_output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n{'=' * 100}") + print(f"[{shape_idx}/{len(dimension_configs)}] Shape {shape_tag}") + _print_selector_summary(meta, candidates) + + for idx, m_tiles_per_batch in enumerate(candidates, start=1): + label = f"{shape_tag}_mtpb{m_tiles_per_batch}" + json_path = shape_output_dir / f"results_mtpb{m_tiles_per_batch}.json" + log_path = shape_output_dir / f"log_mtpb{m_tiles_per_batch}.txt" + cmd_args = argparse.Namespace(**vars(args)) + cmd_args.m = m_local + cmd_args.n = n + cmd_args.k = k + cmd = _build_command(cmd_args, str(json_path), m_tiles_per_batch) + cmd_str = " ".join(cmd) + + print(f"\n{'-' * 80}") + print(f"[{idx}/{len(candidates)}] m_tiles_per_batch={m_tiles_per_batch}") + print(f" $ HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}") + + started = time.time() + try: + proc = subprocess.run( + cmd, + env=env, + capture_output=True, + text=True, + timeout=args.timeout, + ) + elapsed = time.time() - started + parsed = _parse_json_output(json_path) + json_ok = json_path.exists() + + results.append( + { + "shape": {"m_local": m_local, "n": n, "k": k}, + "shape_tag": shape_tag, + "label": label, + "m_tiles_per_batch": m_tiles_per_batch, + "iris_ms": parsed["iris_ms"], + "iris_tflops": parsed["iris_tflops"], + "iris_bw_gbps": parsed["iris_bw_gbps"], + "validation": parsed["validation"], + "benchmark_json": parsed, + "returncode": proc.returncode, + "elapsed_s": round(elapsed, 1), + "json_path": str(json_path) if json_ok else None, + } + ) + + summary = [] + if parsed["iris_tflops"] is not None: + summary.append(f"{parsed['iris_tflops']:.2f} TFLOPS") + if parsed["iris_ms"] is not None: + summary.append(f"{parsed['iris_ms']:.3f} ms") + if parsed["iris_bw_gbps"] is not None: + summary.append(f"{parsed['iris_bw_gbps']:.1f} GB/s") + if parsed["validation"] is not None: + summary.append(f"valid={parsed['validation']}") + summary.append("json=OK" if json_ok else "json=MISSING") + if proc.returncode != 0: + summary.append(f"EXIT={proc.returncode}") + print(f" => {' | '.join(summary)} ({elapsed:.0f}s)") + + with open(log_path, "w") as f: + f.write(f"COMMAND: HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}\n") + f.write(f"EXIT CODE: {proc.returncode}\n") + f.write(f"ELAPSED: {elapsed:.1f}s\n\n") + f.write("=== STDOUT ===\n") + f.write(proc.stdout) + f.write("\n=== STDERR ===\n") + f.write(proc.stderr) + + except subprocess.TimeoutExpired as exc: + elapsed = time.time() - started + results.append( + { + "shape": {"m_local": m_local, "n": n, "k": k}, + "shape_tag": shape_tag, + "label": label, + "m_tiles_per_batch": m_tiles_per_batch, + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": "TIMEOUT", + "benchmark_json": {}, + "returncode": -1, + "elapsed_s": round(elapsed, 1), + "json_path": None, + } + ) + print(f" => TIMEOUT after {args.timeout}s") + with open(log_path, "w") as f: + f.write(f"COMMAND: HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}\n") + f.write(f"TIMEOUT: {args.timeout}s\n\n") + f.write(getattr(exc, "stdout", "") or "") + f.write("\n") + f.write(getattr(exc, "stderr", "") or "") + + total_elapsed = time.time() - total_start + + print(f"\n{'=' * 112}") + print(f" TUNING RESULTS | {len(dimension_configs)} shapes | {len(results)} runs in {total_elapsed:.0f}s") + print(f"{'=' * 112}") + print( + f" {'#':>3} {'Shape':<24} {'m_tiles_per_batch':>17} {'ms':>8} {'TFLOPS':>8} " + f"{'GB/s':>8} {'Valid':>8} {'JSON':>4}" + ) + print(f" {'-' * 108}") + + for idx, result in enumerate(results, start=1): + ms_s = f"{result['iris_ms']:.3f}" if result["iris_ms"] is not None else "--" + tf_s = f"{result['iris_tflops']:.2f}" if result["iris_tflops"] is not None else "--" + bw_s = f"{result['iris_bw_gbps']:.1f}" if result["iris_bw_gbps"] is not None else "--" + valid_s = (result["validation"] or "--")[:8] + json_s = "Y" if result["json_path"] else "N" + best_tag = "" + if result["iris_tflops"] is not None: + best_value = max((x["iris_tflops"] for x in results if x["iris_tflops"] is not None), default=None) + if best_value is not None and result["iris_tflops"] == best_value: + best_tag = " *" + + print( + f" {idx:>3} {result['shape_tag']:<24} {result['m_tiles_per_batch']:>17} {ms_s:>8} {tf_s:>8} " + f"{bw_s:>8} {valid_s:>8} {json_s:>4}{best_tag}" + ) + + valid_results = [result for result in results if result["iris_tflops"] is not None] + if valid_results: + best = max(valid_results, key=lambda result: result["iris_tflops"]) + worst = min(valid_results, key=lambda result: result["iris_tflops"]) + best_json = best["benchmark_json"] + tile_m = best_json.get("output_tile_size_m") or meta["block_size_m"] + tile_n = best_json.get("output_tile_size_n") or meta["block_size_n"] + tile_k = best_json.get("output_tile_size_k") or meta["block_size_k"] + best_group_size_m = best_json.get("group_size_m") or meta["group_size_m"] + best_m_tiles_per_wave = best_json.get("m_tiles_per_wave") or meta["m_tiles_per_wave"] + tile_shape = f"{tile_m}x{tile_n}x{tile_k}" + + print("\nBest configuration") + print(f" m_tiles_per_batch : {best['m_tiles_per_batch']}") + print(f" avg_ms : {best['iris_ms']:.3f}") + print(f" tflops : {best['iris_tflops']:.2f}") + print(f" bandwidth_gbps : {best['iris_bw_gbps']:.1f}") + print(f" output tile size : {tile_shape}") + print(f" group_size_m : {best_group_size_m}") + print(f" M-tiles per wave/stage : {best_m_tiles_per_wave}") + if best_json.get("groups_per_wave") is not None: + print(f" groups per wave/stage : {best_json['groups_per_wave']}") + if best_json.get("num_batches") is not None: + print(f" num_batches : {best_json['num_batches']}") + if best_json.get("last_batch_m_tiles") is not None: + print(f" last_batch_m_tiles : {best_json['last_batch_m_tiles']}") + if best_json.get("ratio") is not None: + print(f" scatter/gemm ratio : {best_json['ratio']:.2f}x") + if best_json.get("bottleneck") is not None: + print(f" bottleneck : {best_json['bottleneck']}") + + print("\nSpread") + print( + f" best : {best['iris_tflops']:.2f} TFLOPS @ m_tiles_per_batch={best['m_tiles_per_batch']}" + ) + print( + f" worst : {worst['iris_tflops']:.2f} TFLOPS @ m_tiles_per_batch={worst['m_tiles_per_batch']}" + ) + if worst["iris_tflops"] and best["iris_tflops"]: + print(f" best / worst : {best['iris_tflops'] / worst['iris_tflops']:.2f}x") + + results_path = output_dir / "results.json" + with open(results_path, "w") as f: + json.dump( + { + "meta": { + "dimension_configs": [{"m_local": m_local, "n": n, "k": k} for m_local, n, k in dimension_configs], + "use_sweep_dimensions": args.use_sweep_dimensions, + "benchmark": args.benchmark, + "benchmark_script": BENCHMARK_TARGETS[args.benchmark]["script"], + "nproc": args.nproc, + "datatype": args.datatype, + "timestamp": datetime.now().isoformat(), + "total_elapsed_s": round(total_elapsed, 1), + "candidate_generation": { + "block_size_m": meta["block_size_m"], + "block_size_n": meta["block_size_n"], + "block_size_k": meta["block_size_k"], + "group_size_m": meta["group_size_m"], + "num_stages": meta["num_stages"], + "waves_per_eu": meta["waves_per_eu"], + "active_cus": meta["active_cus"], + "num_tiles_m": meta["num_tiles_m"], + "num_tiles_n": meta["num_tiles_n"], + "tiles_per_group": meta["tiles_per_group"], + "groups_per_wave": meta["groups_per_wave"], + "m_tiles_per_wave": meta["m_tiles_per_wave"], + }, + "candidates": candidates, + }, + "results": results, + }, + f, + indent=2, + ) + + print(f"\nSummary JSON : {results_path}") + print(f"Per-run JSON : {output_dir}/results_*.json") + print(f"Per-run logs : {output_dir}/log_*.txt") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/print_tritonblas_tile_schedule.py b/benchmark/ops/print_tritonblas_tile_schedule.py new file mode 100644 index 000000000..1b5187fb3 --- /dev/null +++ b/benchmark/ops/print_tritonblas_tile_schedule.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Print the tritonBLAS launch-wave GEMM tile schedule. + +This mirrors the current tritonBLAS matmul path used by the copy-engine code: +1. launch one workgroup per output tile +2. remap program IDs with chiplet_transform_chunked() +3. let hardware schedule those workgroups in waves of active CUs +4. map remapped tile IDs to (pid_m, pid_n) using GROUP_SIZE_M swizzling + +The output is grouped into hardware waves of `wave_size` workgroups so it is +easy to see which tile coordinates are active in the first 304 WGs, second +304 WGs, and so on. +""" + +import argparse +import importlib.util +from pathlib import Path +import sys + +_SCRIPT_PATH = Path(__file__).resolve() +_HELPER_PATH = None +for _parent in (_SCRIPT_PATH.parent, *_SCRIPT_PATH.parents): + _candidate = _parent / "iris" / "ops" / "tritonblas_launch_wave_schedule.py" + if _candidate.is_file(): + _HELPER_PATH = _candidate + break +if _HELPER_PATH is None: + raise FileNotFoundError( + f"Unable to locate iris/ops/tritonblas_launch_wave_schedule.py starting from {_SCRIPT_PATH}" + ) + +_SPEC = importlib.util.spec_from_file_location("tritonblas_launch_wave_schedule", str(_HELPER_PATH)) +if _SPEC is None or _SPEC.loader is None: + raise ImportError(f"Unable to load helper module from {_HELPER_PATH}") +_MODULE = importlib.util.module_from_spec(_SPEC) +sys.modules[_SPEC.name] = _MODULE +_SPEC.loader.exec_module(_MODULE) + +build_launch_wave_plan = _MODULE.build_launch_wave_plan +ceil_div = _MODULE.ceil_div +chiplet_transform_chunked = _MODULE.chiplet_transform_chunked +default_chunk_size = _MODULE.default_chunk_size +grouped_tile_coords = _MODULE.grouped_tile_coords + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Print tritonBLAS XCD-aware launch-wave GEMM tile schedule by iteration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--m", type=int, required=True, help="Problem M dimension") + parser.add_argument("--n", type=int, required=True, help="Problem N dimension") + parser.add_argument("--block-m", type=int, required=True, help="BLOCK_SIZE_M") + parser.add_argument("--block-n", type=int, required=True, help="BLOCK_SIZE_N") + parser.add_argument("--group-size-m", type=int, required=True, help="GROUP_SIZE_M") + parser.add_argument("--wave-size", type=int, default=304, help="Active workgroups / CUs per hardware wave") + parser.add_argument("--num-xcds", type=int, default=8, help="Number of XCDs used by the transform") + parser.add_argument( + "--merge-order", + type=str, + choices=("column", "row"), + default="column", + help="How to coalesce neighboring tiles into transfer rectangles.", + ) + parser.add_argument( + "--launch-grid", + type=int, + default=None, + help="Optional explicit kernel launch grid. Defaults to total output tiles.", + ) + parser.add_argument( + "--chunk-size", + type=int, + default=None, + help="Optional explicit XCD chunk size. Defaults to min(group_size_m^2, total_tiles // num_xcds).", + ) + parser.add_argument( + "--iterations", + type=int, + default=None, + help="How many waves of workgroups to print. Defaults to all waves.", + ) + parser.add_argument( + "--summary-only", + action="store_true", + help="Print only the compact per-iteration summaries, not every workgroup entry.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + num_tiles_m = ceil_div(args.m, args.block_m) + num_tiles_n = ceil_div(args.n, args.block_n) + total_tiles = num_tiles_m * num_tiles_n + launch_grid = total_tiles if args.launch_grid is None else args.launch_grid + + if args.chunk_size is None: + chunk_size = default_chunk_size(total_tiles, args.group_size_m, args.num_xcds) + else: + chunk_size = args.chunk_size + + plan = build_launch_wave_plan( + num_tiles_m=num_tiles_m, + num_tiles_n=num_tiles_n, + group_size_m=args.group_size_m, + launch_grid=launch_grid, + wave_size=args.wave_size, + num_xcds=args.num_xcds, + chunk_size=chunk_size, + merge_order=args.merge_order, + ) + + max_iterations = plan.num_waves + iterations = max_iterations if args.iterations is None else min(args.iterations, max_iterations) + + print(f"Shape : M={args.m} N={args.n}") + print(f"Tile shape : BLOCK_M={args.block_m} BLOCK_N={args.block_n}") + print(f"Tile grid : num_tiles_m={num_tiles_m} num_tiles_n={num_tiles_n} total_tiles={total_tiles}") + print(f"Group size : group_size_m={args.group_size_m} tiles_per_group={args.group_size_m * num_tiles_n}") + print(f"Hardware : wave_size={args.wave_size} num_xcds={args.num_xcds} chunk_size={chunk_size}") + print(f"Merge order : {args.merge_order}") + print(f"Launch grid : {launch_grid}") + print(f"Iterations : printing {iterations} / {max_iterations}") + print(f"Transfers : {len(plan.transfers)}") + print() + + for iteration in range(iterations): + entries: list[tuple[int, int, int, int, int, int]] = [] + iteration_transfers = [transfer for transfer in plan.transfers if transfer.wave_id == iteration] + pid_start = iteration * args.wave_size + pid_end = min(pid_start + args.wave_size, launch_grid) + for pid in range(pid_start, pid_end): + transformed_pid = chiplet_transform_chunked(pid, launch_grid, args.num_xcds, chunk_size) + tile_id = transformed_pid + if tile_id >= total_tiles: + continue + pid_m, pid_n, group_id = grouped_tile_coords(tile_id, num_tiles_m, num_tiles_n, args.group_size_m) + xcd = pid % args.num_xcds if args.num_xcds > 0 else 0 + entries.append((pid, xcd, transformed_pid, tile_id, pid_m, pid_n, group_id)) + + if not entries: + break + + unique_groups = sorted({entry[6] for entry in entries}) + m_min = min(entry[4] for entry in entries) + m_max = max(entry[4] for entry in entries) + n_min = min(entry[5] for entry in entries) + n_max = max(entry[5] for entry in entries) + print( + f"Iteration {iteration:2d}: {len(entries):3d} tiles " + f"groups={unique_groups} m=[{m_min},{m_max}] n=[{n_min},{n_max}] " + f"transfers={len(iteration_transfers)}" + ) + for transfer in iteration_transfers: + print( + f" transfer group={transfer.group_id:2d} " + f"m=[{transfer.m_tile_start},{transfer.m_tile_start + transfer.m_tile_count - 1}] " + f"n=[{transfer.n_tile_start},{transfer.n_tile_start + transfer.n_tile_count - 1}] " + f"shape={transfer.m_tile_count}x{transfer.n_tile_count}" + ) + + if args.summary_only: + continue + + print(" pid xcd remap tile_id m n group") + for pid, xcd, transformed_pid, tile_id, pid_m, pid_n, group_id in entries: + print(f" {pid:3d} {xcd:3d} {transformed_pid:5d} {tile_id:7d} {pid_m:2d} {pid_n:2d} {group_id:5d}") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/sweep_bench.py b/benchmark/ops/sweep_bench.py new file mode 100644 index 000000000..fe9c2198a --- /dev/null +++ b/benchmark/ops/sweep_bench.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Unified sweep benchmark script for matmul and all-gather operations. + +Runs benchmarks across all permutations of M, N, K dimensions. +Supports both operation types via --operation argument. + +Usage: + python sweep_bench.py --operation matmul_all_gather + python sweep_bench.py --operation all_gather_matmul +""" + +import argparse +import json +import os +import signal +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional + + +# Project root (2 levels up from this script) +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent + +# Dimension configurations to test. +# Each entry contains M_local (per-rank M), N, K, and an optional label. +DIMENSION_CONFIGS = [ + {"m_local": 2048, "n": 2048, "k": 16384, "label": "M2048_N2048_K16384"}, + {"m_local": 2048, "n": 16384, "k": 2048, "label": "M2048_N16384_K2048"}, + {"m_local": 2048, "n": 16384, "k": 16384, "label": "M2048_N16384_K16384"}, + {"m_local": 2048, "n": 16384, "k": 65536, "label": "M2048_N16384_K65536"}, + {"m_local": 2048, "n": 131072, "k": 16384, "label": "M2048_N131072_K16384"}, + {"m_local": 16384, "n": 2048, "k": 2048, "label": "M16384_N2048_K2048"}, + {"m_local": 16384, "n": 2048, "k": 16384, "label": "M16384_N2048_K16384"}, + {"m_local": 16384, "n": 2048, "k": 131072, "label": "M16384_N2048_K131072"}, + {"m_local": 16384, "n": 16384, "k": 2048, "label": "M16384_N16384_K2048"}, + {"m_local": 131072, "n": 2048, "k": 16384, "label": "M131072_N2048_K16384"}, + {"m_local": 131072, "n": 16384, "k": 16384, "label": "g2"}, + {"m_local": 147456, "n": 28672, "k": 4096, "label": "g14"}, + {"m_local": 327680, "n": 28672, "k": 4096, "label": "g15"}, # run out of heap memory + {"m_local": 229376, "n": 28672, "k": 4096, "label": "g16"}, + {"m_local": 8192, "n": 8192, "k": 262144, "label": "g5"}, + {"m_local": 262144, "n": 8192, "k": 8192, "label": "g6"}, + {"m_local": 16384, "n": 16384, "k": 131072, "label": "g1"}, + {"m_local": 262144, "n": 28672, "k": 8192, "label": "g8"}, # run out of heap memory + {"m_local": 196608, "n": 18432, "k": 16384, "label": "g9"}, + {"m_local": 4096, "n": 14336, "k": 4096, "label": "mixtral_gate"}, + {"m_local": 4096, "n": 11008, "k": 4096, "label": "llama7b_gate"}, + {"m_local": 4096, "n": 4096, "k": 4096, "label": "pow2_4k"}, + {"m_local": 1024, "n": 3584, "k": 8192, "label": "M1024_N3584_K8192"}, + {"m_local": 4096, "n": 3584, "k": 8192, "label": "M4096_N3584_K8192"}, + {"m_local": 16384, "n": 3584, "k": 8192, "label": "M16384_N3584_K8192"}, +] + +# Benchmark configurations per operation type +BENCHMARK_CONFIGS = { + "matmul_all_gather": { + "pytorchbaseline": { + "script": "benchmark/ops/bench_matmul_all_gather.py", + "benchmark_filter": "^pytorch_matmul_all_gather$", + "axes": {"m": "M_local", "n": "N", "k": "K"}, + }, + "tritonblas_rcclbaseline": { + "script": "benchmark/ops/bench_matmul_all_gather.py", + "benchmark_filter": "^tritonblas_matmul_all_gather$", + "axes": {"m": "M_local", "n": "N", "k": "K"}, + }, + "baseline": { + "script": "benchmark/ops/bench_matmul_all_gather.py", + "benchmark_filter": "^matmul_all_gather$", + "axes": {"m": "M_local", "n": "N", "k": "K"}, + }, + "host_copy_engine": { + "script": "benchmark/ops/bench_matmul_all_gather_copy_engine.py", + "benchmark_filter": "^matmul_all_gather_copy_engine_host$", + "axes": {"m": "M_local", "n": "N", "k": "K"}, + }, + "device_copy_engine": { + "script": "benchmark/ops/bench_matmul_all_gather_copy_engine.py", + "benchmark_filter": "^matmul_all_gather_copy_engine_device$", + "axes": {"m": "M_local", "n": "N", "k": "K"}, + }, + "matmul_only": { + "script": "benchmark/ops/bench_matmul.py", + "benchmark_filter": "^matmul_only_local$", + "axes": {"m": "M_local", "n": "N", "k": "K"}, + }, + "pytorchmatmul_only": { + "script": "benchmark/ops/bench_matmul.py", + "benchmark_filter": "^pytorch_matmul_only_local$", + "axes": {"m": "M_local", "n": "N", "k": "K"}, + }, + }, + "all_gather_matmul": { + "baseline": { + "script": "benchmark/ops/bench_all_gather_matmul.py", + "benchmark_filter": "^all_gather_matmul$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + "pytorchbaseline": { + "script": "benchmark/ops/bench_all_gather_matmul.py", + "benchmark_filter": "^rccl_all_gather_matmul$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + "tritonblas_rcclbaseline": { + "script": "benchmark/ops/bench_all_gather_matmul.py", + "benchmark_filter": "^tritonblas_rccl_all_gather_matmul$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + "hbm_buffer": { + "script": "benchmark/ops/bench_all_gather_matmul.py", + "benchmark_filter": "^all_gather_matmul_hbm_buffer$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + "copy_engine_host": { + "script": "benchmark/ops/bench_all_gather_matmul_copy_engine.py", + "benchmark_filter": "^all_gather_matmul_copy_engine_host$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + "copy_engine_host_hip_memcpy": { + "script": "benchmark/ops/bench_all_gather_matmul_copy_engine.py", + "benchmark_filter": "^all_gather_matmul_copy_engine_host_hip_memcpy$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + "copy_engine_device": { + "script": "benchmark/ops/bench_all_gather_matmul_copy_engine.py", + "benchmark_filter": "^all_gather_matmul_copy_engine_device$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + "matmul_only": { + "script": "benchmark/ops/bench_matmul.py", + "benchmark_filter": "^matmul_only$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + "pytorchmatmul_only": { + "script": "benchmark/ops/bench_matmul.py", + "benchmark_filter": "^pytorch_matmul_only$", + "axes": {"m": "M", "n": "N", "k": "K"}, + }, + }, +} + +TIMEOUT_SECONDS = 150 +NUM_GPUS = 8 + + +def log(msg: str): + """Log to stderr to keep stdout clean for JSON.""" + print(msg, file=sys.stderr, flush=True) + + +def _dimension_values(config: Dict[str, Any]) -> tuple[int, int, int]: + return int(config["m_local"]), int(config["n"]), int(config["k"]) + + +def _dimension_label(config: Dict[str, Any]) -> str: + return str(config.get("label") or f"M{config['m_local']}_N{config['n']}_K{config['k']}") + + +def _run_bench_benchmark( + benchmark_name: str, + script: str, + m: int, + n: int, + k: int, + benchmark_filter: str, + axes: Dict[str, str], +) -> Optional[Dict[str, Any]]: + with tempfile.NamedTemporaryFile( + mode="w", + suffix=f"_{benchmark_name}_M{m}_N{n}_K{k}.json", + dir=str(PROJECT_ROOT), + delete=False, + ) as tmp_file: + benchmark_out = tmp_file.name + + cmd = [ + sys.executable, + script, + "--benchmark_format=json", + f"--benchmark_out={benchmark_out}", + f"--benchmark_filter={benchmark_filter}", + f"--axis_num_ranks={NUM_GPUS}", + f"--axis_{axes['m']}={m}", + f"--axis_{axes['n']}={n}", + f"--axis_{axes['k']}={k}", + "--axis_dtype=fp16", + ] + + log(f" Running {benchmark_name}: M={m}, N={n}, K={k}") + log(f" Command: {' '.join(cmd)}") + + process = None + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=str(PROJECT_ROOT), + preexec_fn=os.setsid, + ) + stdout, stderr = process.communicate(timeout=TIMEOUT_SECONDS) + result = subprocess.CompletedProcess(cmd, process.returncode, stdout, stderr) + + if result.returncode != 0: + log(" ✗ Failed: Non-zero return code") + log(f" Return code: {result.returncode}") + error_log_file = PROJECT_ROOT / f"benchmark_error_{benchmark_name}_M{m}_N{n}_K{k}.log" + with open(error_log_file, "w") as f: + f.write(f"Benchmark: {benchmark_name}\n") + f.write(f"Dimensions: M={m}, N={n}, K={k}\n") + f.write(f"Command: {' '.join(cmd)}\n") + f.write(f"Return code: {result.returncode}\n\n") + f.write("=" * 80 + "\n") + f.write("STDOUT:\n") + f.write("=" * 80 + "\n") + f.write(result.stdout) + f.write("\n" + "=" * 80 + "\n") + f.write("STDERR:\n") + f.write("=" * 80 + "\n") + f.write(result.stderr) + log(f" Full output saved to: {error_log_file}") + lines = (result.stdout + result.stderr).strip().split("\n") + log(" Last output lines:") + for line in lines[-5:]: + log(f" {line}") + return None + + with open(benchmark_out, "r") as f: + records = json.load(f) + if not isinstance(records, list) or not records: + log(" ✗ Failed: bench JSON output was empty") + return None + + record = next((r for r in records if not r.get("skipped")), None) + if record is None: + skip_reason = records[0].get("skip_reason", "") + log(f" ✗ Failed: benchmark was skipped ({skip_reason})") + return {"status": "SKIPPED", "skip_reason": skip_reason} + + params = record.get("params", {}) + counters = record.get("counters", {}) + data = { + "world_size": record.get("world_size"), + "operation": record.get("benchmark"), + "m": int(params.get(axes["m"], m)), + "n": int(params.get(axes["n"], n)), + "k": int(params.get(axes["k"], k)), + "datatype": params.get("dtype", "float16"), + "total_ms": record.get("gpu_time_ms"), + "gpu_time_ms": record.get("gpu_time_ms"), + "all_times_ms": record.get("all_times_ms", []), + "bandwidth_gbps": record.get("bandwidth_gbps"), + "tflops": record.get("tflops"), + } + data.update(counters) + log(" ✓ Success: Loaded bench JSON results") + return data + + except subprocess.TimeoutExpired as timeout_err: + log(f" ✗ Timeout after {TIMEOUT_SECONDS}s - killing process group") + partial_stdout = timeout_err.stdout.decode("utf-8", errors="replace") if timeout_err.stdout else "" + partial_stderr = timeout_err.stderr.decode("utf-8", errors="replace") if timeout_err.stderr else "" + if process: + try: + os.killpg(os.getpgid(process.pid), signal.SIGTERM) + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + log(" Process didn't terminate, force killing...") + os.killpg(os.getpgid(process.pid), signal.SIGKILL) + process.wait() + except ProcessLookupError: + pass + error_log_file = PROJECT_ROOT / f"benchmark_timeout_{benchmark_name}_M{m}_N{n}_K{k}.log" + with open(error_log_file, "w") as f: + f.write(f"Benchmark: {benchmark_name}\n") + f.write(f"Dimensions: M={m}, N={n}, K={k}\n") + f.write(f"Command: {' '.join(cmd)}\n") + f.write(f"Status: TIMEOUT after {TIMEOUT_SECONDS}s\n\n") + f.write("=" * 80 + "\n") + f.write("PARTIAL STDOUT (before timeout):\n") + f.write("=" * 80 + "\n") + f.write(partial_stdout) + f.write("\n" + "=" * 80 + "\n") + f.write("PARTIAL STDERR (before timeout):\n") + f.write("=" * 80 + "\n") + f.write(partial_stderr) + log(f" Timeout logged to: {error_log_file}") + return None + except json.JSONDecodeError as e: + log(f" ✗ Error: Failed to parse bench JSON: {e}") + return None + except Exception as e: + log(f" ✗ Error: {e}") + return None + finally: + try: + os.remove(benchmark_out) + except OSError: + pass + + +def run_benchmark( + benchmark_name: str, + bench_config: Dict[str, Any], + m: int, + n: int, + k: int, +) -> Optional[Dict[str, Any]]: + return _run_bench_benchmark( + benchmark_name, + bench_config["script"], + m, + n, + k, + bench_config["benchmark_filter"], + bench_config.get("axes", {"m": "M", "n": "N", "k": "K"}), + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Run sweep benchmarks for matmul and all-gather operations", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--operation", + type=str, + required=True, + choices=["matmul_all_gather", "all_gather_matmul"], + help="Operation type to benchmark", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output JSON file path (default: benchmark/ops/{operation}/benchmark_sweep_results.json)", + ) + args = parser.parse_args() + + operation = args.operation + benchmarks = BENCHMARK_CONFIGS[operation] + + # Determine output file + if args.output: + output_file = Path(args.output) + else: + output_file = PROJECT_ROOT / f"benchmark/ops/{operation}/benchmark_sweep_results.json" + + log("=" * 80) + log(f"{operation.upper().replace('_', '-')} Benchmark Sweep") + log("=" * 80) + log(f"Dimension configurations: {len(DIMENSION_CONFIGS)}") + for config in DIMENSION_CONFIGS: + m, n, k = _dimension_values(config) + log(f" - {_dimension_label(config)}: M={m}, N={n}, K={k}") + log(f"Benchmarks per configuration: {len(benchmarks)}") + log(f"Total benchmarks: {len(DIMENSION_CONFIGS) * len(benchmarks)}") + log(f"Timeout per benchmark: {TIMEOUT_SECONDS}s") + log(f"GPUs: {NUM_GPUS}") + log(f"Output file: {output_file}") + log("=" * 80) + log("") + + results = [] + + log(f"Running {len(DIMENSION_CONFIGS)} dimension configurations...\n") + + for idx, config in enumerate(DIMENSION_CONFIGS, 1): + m, n, k = _dimension_values(config) + label = _dimension_label(config) + log(f"[{idx}/{len(DIMENSION_CONFIGS)}] Testing {label}: M={m}, N={n}, K={k}") + + row = {"label": label, "M": m, "N": n, "K": k, "operation": operation, "benchmarks": {}} + + # Run each benchmark variant + for bench_key, bench_config in benchmarks.items(): + result = run_benchmark( + benchmark_name=bench_key, + bench_config=bench_config, + m=m, + n=n, + k=k, + ) + + if result is not None: + row["benchmarks"][bench_key] = result + else: + row["benchmarks"][bench_key] = {"status": "FAILED"} + + results.append(row) + log("") + + # Write JSON file + log(f"Writing results to {output_file}...") + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + + log(f"✓ Results saved to {output_file}\n") + + log("\n" + "=" * 80) + log("Benchmark sweep complete!") + log("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/sweep_benchmarks.py b/benchmark/ops/sweep_benchmarks.py new file mode 100644 index 000000000..1629fbf69 --- /dev/null +++ b/benchmark/ops/sweep_benchmarks.py @@ -0,0 +1,402 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Unified sweep benchmark script for matmul and all-gather operations. + +Runs benchmarks across all permutations of M, N, K dimensions. +Supports both operation types via --operation argument. + +Usage: + python sweep_benchmarks.py --operation matmul_all_gather + python sweep_benchmarks.py --operation all_gather_matmul +""" + +import subprocess +import sys +import json +import os +import signal +import argparse +from pathlib import Path +from typing import Optional, Dict, Any + + +# Project root (2 levels up from this script) +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent + +# Dimension configurations to test. +# Each entry contains M_local (per-rank M), N, K, and an optional label. +DIMENSION_CONFIGS = [ + {"m_local": 2048, "n": 2048, "k": 16384, "label": "M2048_N2048_K16384"}, + {"m_local": 2048, "n": 16384, "k": 2048, "label": "M2048_N16384_K2048"}, + {"m_local": 2048, "n": 16384, "k": 16384, "label": "M2048_N16384_K16384"}, + {"m_local": 2048, "n": 16384, "k": 65536, "label": "M2048_N16384_K65536"}, + {"m_local": 2048, "n": 131072, "k": 16384, "label": "M2048_N131072_K16384"}, + {"m_local": 16384, "n": 2048, "k": 2048, "label": "M16384_N2048_K2048"}, + {"m_local": 16384, "n": 2048, "k": 16384, "label": "M16384_N2048_K16384"}, + {"m_local": 16384, "n": 2048, "k": 131072, "label": "M16384_N2048_K131072"}, + {"m_local": 16384, "n": 16384, "k": 2048, "label": "M16384_N16384_K2048"}, + {"m_local": 131072, "n": 2048, "k": 16384, "label": "M131072_N2048_K16384"}, + {"m_local": 131072, "n": 16384, "k": 16384, "label": "g2"}, + {"m_local": 147456, "n": 28672, "k": 4096, "label": "g14"}, + # {"m_local": 327680, "n": 28672, "k": 4096, "label": "g15"}, # run out of heap memory + {"m_local": 229376, "n": 28672, "k": 4096, "label": "g16"}, + {"m_local": 8192, "n": 8192, "k": 262144, "label": "g5"}, + {"m_local": 262144, "n": 8192, "k": 8192, "label": "g6"}, + {"m_local": 16384, "n": 16384, "k": 131072, "label": "g1"}, + {"m_local": 262144, "n": 28672, "k": 8192, "label": "g8"}, # run out of heap memory + {"m_local": 196608, "n": 18432, "k": 16384, "label": "g9"}, + {"m_local": 4096, "n": 14336, "k": 4096, "label": "mixtral_gate"}, + {"m_local": 4096, "n": 11008, "k": 4096, "label": "llama7b_gate"}, + {"m_local": 4096, "n": 4096, "k": 4096, "label": "pow2_4k"}, + {"m_local": 1024, "n": 3584, "k": 8192, "label": "M1024_N3584_K8192"}, + {"m_local": 4096, "n": 3584, "k": 8192, "label": "M4096_N3584_K8192"}, + {"m_local": 16384, "n": 3584, "k": 8192, "label": "M16384_N3584_K8192"}, +] + +# Benchmark configurations per operation type +BENCHMARK_CONFIGS = { + "matmul_all_gather": { + "baseline": { + "script": "benchmark/ops/matmul_all_gather/benchmark.py", + "extra_args": ["--benchmark_pytorch"], + "output_file": "matmul_all_gather_baseline.json", + "extract_multiple": True, + }, + "host_copy_engine": { + "script": "benchmark/ops/matmul_all_gather/benchmark_host_copy_engine.py", + "extra_args": ["--no-trace"], + "output_file": "matmul_all_gather_host_copy_engine.json", + "extract_multiple": False, + }, + "device_copy_engine": { + "script": "benchmark/ops/matmul_all_gather/benchmark_copy_engine.py", + "extra_args": [], + "output_file": "matmul_all_gather_device_copy_engine.json", + "extract_multiple": False, + }, + "matmul_only": { + "script": "benchmark/ops/matmul_all_gather/benchmark_matmul.py", + "extra_args": ["--benchmark_pytorch"], + "output_file": "matmul_only.json", + "extract_multiple": True, + }, + }, + "all_gather_matmul": { + "baseline": { + "script": "benchmark/ops/all_gather_matmul/benchmark_torchrun.py", + "extra_args": ["--benchmark_pytorch"], + "output_file": "all_gather_matmul_baseline.json", + "extract_multiple": True, + }, + "hbm_buffer": { + "script": "benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py", + "extra_args": [], + "output_file": "all_gather_matmul_hbm_buffer.json", + "extract_multiple": False, + }, + "copy_engine_host": { + "script": "benchmark/ops/all_gather_matmul/benchmark_copy_engine.py", + "extra_args": ["--force-host-initiated", "--no-trace"], + "output_file": "all_gather_matmul_host_copy_engine.json", + "extract_multiple": False, + }, + "copy_engine_device": { + "script": "benchmark/ops/all_gather_matmul/benchmark_copy_engine.py", + "extra_args": ["--force-device-initiated", "--no-trace"], + "output_file": "all_gather_matmul_device_copy_engine.json", + "extract_multiple": False, + }, + "matmul_only": { + "script": "benchmark/ops/matmul_all_gather/benchmark_matmul.py", + "extra_args": ["--benchmark_pytorch"], + "output_file": "matmul_only.json", + "extract_multiple": True, + }, + }, +} + +TIMEOUT_SECONDS = 150 +NUM_GPUS = 8 + + +def log(msg: str): + """Log to stderr to keep stdout clean for JSON.""" + print(msg, file=sys.stderr, flush=True) + + +def _dimension_values(config: Dict[str, Any]) -> tuple[int, int, int]: + return int(config["m_local"]), int(config["n"]), int(config["k"]) + + +def _dimension_label(config: Dict[str, Any]) -> str: + return str(config.get("label") or f"M{config['m_local']}_N{config['n']}_K{config['k']}") + + +def run_benchmark( + benchmark_name: str, + script: str, + m: int, + n: int, + k: int, + extra_args: list, + output_file: str, +) -> Optional[Dict[str, Any]]: + """ + Run a single benchmark and extract results from JSON output file. + + Returns: + Dictionary containing benchmark results, or None if failed/timeout. + """ + cmd = [ + "torchrun", + f"--nproc_per_node={NUM_GPUS}", + script, + "-m", + str(m), + "-n", + str(n), + "-k", + str(k), + "--validate", + "--benchmark", + "--output_file", + output_file, + ] + extra_args + + log(f" Running {benchmark_name}: M={m}, N={n}, K={k}") + log(f" Command: {' '.join(cmd)}") + + process = None + try: + # Start process in new process group so we can kill all children + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=str(PROJECT_ROOT), + preexec_fn=os.setsid, # Create new process group + ) + stdout, stderr = process.communicate(timeout=TIMEOUT_SECONDS) + result = subprocess.CompletedProcess(cmd, process.returncode, stdout, stderr) + + if result.returncode != 0: + log(" ✗ Failed: Non-zero return code") + log(f" Return code: {result.returncode}") + + # Save full output to error log + error_log_file = PROJECT_ROOT / f"benchmark_error_{benchmark_name}_M{m}_N{n}_K{k}.log" + with open(error_log_file, "w") as f: + f.write(f"Benchmark: {benchmark_name}\n") + f.write(f"Dimensions: M={m}, N={n}, K={k}\n") + f.write(f"Command: {' '.join(cmd)}\n") + f.write(f"Return code: {result.returncode}\n\n") + f.write("=" * 80 + "\n") + f.write("STDOUT:\n") + f.write("=" * 80 + "\n") + f.write(result.stdout) + f.write("\n" + "=" * 80 + "\n") + f.write("STDERR:\n") + f.write("=" * 80 + "\n") + f.write(result.stderr) + + log(f" Full output saved to: {error_log_file}") + + # Show last few lines for quick diagnosis + output = result.stdout + result.stderr + lines = output.strip().split("\n") + log(" Last output lines:") + for line in lines[-5:]: + log(f" {line}") + return None + + # Read the JSON output file + json_path = PROJECT_ROOT / output_file + if not json_path.exists(): + log(f" ✗ Failed: JSON output file not found: {json_path}") + return None + + with open(json_path, "r") as f: + data = json.load(f) + + # Check validation status + validation_status = "" + if "success" in data: + if data["success"]: + validation_status = " (validation: PASSED)" + else: + validation_status = " (validation: FAILED)" + + log(f" ✓ Success: Loaded JSON results{validation_status}") + + return data + + except subprocess.TimeoutExpired as timeout_err: + log(f" ✗ Timeout after {TIMEOUT_SECONDS}s - killing process group") + + # Capture any partial output from the timeout exception (decode bytes to str) + partial_stdout = timeout_err.stdout.decode("utf-8", errors="replace") if timeout_err.stdout else "" + partial_stderr = timeout_err.stderr.decode("utf-8", errors="replace") if timeout_err.stderr else "" + + if process: + try: + # Kill entire process group (torchrun + all child processes) + os.killpg(os.getpgid(process.pid), signal.SIGTERM) + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + # Force kill if graceful termination fails + log(" Process didn't terminate, force killing...") + os.killpg(os.getpgid(process.pid), signal.SIGKILL) + process.wait() + except ProcessLookupError: + # Process already died + pass + + # Log timeout error with partial output + error_log_file = PROJECT_ROOT / f"benchmark_timeout_{benchmark_name}_M{m}_N{n}_K{k}.log" + with open(error_log_file, "w") as f: + f.write(f"Benchmark: {benchmark_name}\n") + f.write(f"Dimensions: M={m}, N={n}, K={k}\n") + f.write(f"Command: {' '.join(cmd)}\n") + f.write(f"Status: TIMEOUT after {TIMEOUT_SECONDS}s\n\n") + f.write("=" * 80 + "\n") + f.write("PARTIAL STDOUT (before timeout):\n") + f.write("=" * 80 + "\n") + f.write(partial_stdout) + f.write("\n" + "=" * 80 + "\n") + f.write("PARTIAL STDERR (before timeout):\n") + f.write("=" * 80 + "\n") + f.write(partial_stderr) + log(f" Timeout logged to: {error_log_file}") + + return None + except json.JSONDecodeError as e: + log(f" ✗ Error: Failed to parse JSON: {e}") + return None + except Exception as e: + log(f" ✗ Error: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Run sweep benchmarks for matmul and all-gather operations", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--operation", + type=str, + required=True, + choices=["matmul_all_gather", "all_gather_matmul"], + help="Operation type to benchmark", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output JSON file path (default: benchmark/ops/{operation}/benchmark_sweep_results.json)", + ) + args = parser.parse_args() + + operation = args.operation + benchmarks = BENCHMARK_CONFIGS[operation] + + # Determine output file + if args.output: + output_file = Path(args.output) + else: + output_file = PROJECT_ROOT / f"benchmark/ops/{operation}/benchmark_sweep_results.json" + + log("=" * 80) + log(f"{operation.upper().replace('_', '-')} Benchmark Sweep") + log("=" * 80) + log(f"Dimension configurations: {len(DIMENSION_CONFIGS)}") + for config in DIMENSION_CONFIGS: + m, n, k = _dimension_values(config) + log(f" - {_dimension_label(config)}: M={m}, N={n}, K={k}") + log(f"Benchmarks per configuration: {len(benchmarks)}") + log(f"Total benchmarks: {len(DIMENSION_CONFIGS) * len(benchmarks)}") + log(f"Timeout per benchmark: {TIMEOUT_SECONDS}s") + log(f"GPUs: {NUM_GPUS}") + log(f"Output file: {output_file}") + log("=" * 80) + log("") + + results = [] + + log(f"Running {len(DIMENSION_CONFIGS)} dimension configurations...\n") + + for idx, config in enumerate(DIMENSION_CONFIGS, 1): + m, n, k = _dimension_values(config) + label = _dimension_label(config) + log(f"[{idx}/{len(DIMENSION_CONFIGS)}] Testing {label}: M={m}, N={n}, K={k}") + + row = {"label": label, "M": m, "N": n, "K": k, "operation": operation, "benchmarks": {}} + + # Run each benchmark variant + for bench_key, bench_config in benchmarks.items(): + pytorch_bench_key = "pytorch" + bench_key + result = run_benchmark( + benchmark_name=bench_key, + script=bench_config["script"], + m=m, + n=n, + k=k, + extra_args=bench_config["extra_args"], + output_file=bench_config["output_file"], + ) + + if result is not None: + # Check if this benchmark produces multiple results + if bench_config.get("extract_multiple", False): + # Extract baseline results (tflops, etc.) + baseline_result = {k: v for k, v in result.items() if not k.startswith("pytorch_")} + row["benchmarks"][bench_key] = baseline_result + + # Extract pytorch results (pytorch_tflops, etc.) + if "pytorch_tflops" in result: + pytorch_result = { + "tflops": result.get("pytorch_tflops"), + "bandwidth_gbps": result.get("pytorch_bandwidth_gbps"), + "total_ms": result.get("pytorch_ms"), + } + # Copy common fields + for field in ["world_size", "operation", "m", "n", "k", "datatype"]: + if field in result: + pytorch_result[field] = result[field] + row["benchmarks"][pytorch_bench_key] = pytorch_result + else: + row["benchmarks"][pytorch_bench_key] = {"status": "FAILED"} + else: + # Single result benchmark + row["benchmarks"][bench_key] = result + else: + # Failed benchmark + if bench_config.get("extract_multiple", False): + row["benchmarks"][bench_key] = {"status": "FAILED"} + row["benchmarks"][pytorch_bench_key] = {"status": "FAILED"} + else: + row["benchmarks"][bench_key] = {"status": "FAILED"} + + results.append(row) + log("") + + # Write JSON file + log(f"Writing results to {output_file}...") + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + + log(f"✓ Results saved to {output_file}\n") + + log("\n" + "=" * 80) + log("Benchmark sweep complete!") + log("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/validate_sweep.py b/benchmark/ops/validate_sweep.py new file mode 100644 index 000000000..be3faabb7 --- /dev/null +++ b/benchmark/ops/validate_sweep.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Distributed validation sweep for the same shapes used by benchmark sweeps. + +This runs the pytest-based distributed correctness tests under torchrun and +records pass/fail/skip results in a JSON file, separate from performance data. +""" + +import argparse +import json +import os +import shlex +import signal +import subprocess +import sys +from pathlib import Path +from typing import Any + + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +TIMEOUT_SECONDS = 180 +NUM_GPUS = 8 +DEFAULT_HEAP_SIZE = 1 << 34 +DEFAULT_ELEMENT_SIZE_BYTES = 2 +HEAP_HEADROOM_FACTOR = 1.25 +HEAP_ALIGNMENT_BYTES = 1 << 30 + + +DIMENSION_CONFIGS = [ + # {"m_local": 2048, "n": 2048, "k": 16384, "label": "M2048_N2048_K16384"}, + # {"m_local": 2048, "n": 16384, "k": 2048, "label": "M2048_N16384_K2048"}, + # {"m_local": 2048, "n": 16384, "k": 16384, "label": "M2048_N16384_K16384"}, + # {"m_local": 2048, "n": 16384, "k": 65536, "label": "M2048_N16384_K65536"}, + # {"m_local": 2048, "n": 131072, "k": 16384, "label": "M2048_N131072_K16384"}, + # {"m_local": 16384, "n": 2048, "k": 2048, "label": "M16384_N2048_K2048"}, + # {"m_local": 16384, "n": 2048, "k": 16384, "label": "M16384_N2048_K16384"}, + # {"m_local": 16384, "n": 2048, "k": 131072, "label": "M16384_N2048_K131072"}, + # {"m_local": 16384, "n": 16384, "k": 2048, "label": "M16384_N16384_K2048"}, + {"m_local": 131072, "n": 2048, "k": 16384, "label": "M131072_N2048_K16384"}, + {"m_local": 131072, "n": 16384, "k": 16384, "label": "g2"}, + {"m_local": 147456, "n": 28672, "k": 4096, "label": "g14"}, + {"m_local": 327680, "n": 28672, "k": 4096, "label": "g15"}, + {"m_local": 229376, "n": 28672, "k": 4096, "label": "g16"}, + {"m_local": 8192, "n": 8192, "k": 262144, "label": "g5"}, + {"m_local": 262144, "n": 8192, "k": 8192, "label": "g6"}, + {"m_local": 16384, "n": 16384, "k": 131072, "label": "g1"}, + {"m_local": 262144, "n": 28672, "k": 8192, "label": "g8"}, + {"m_local": 196608, "n": 18432, "k": 16384, "label": "g9"}, + {"m_local": 4096, "n": 14336, "k": 4096, "label": "mixtral_gate"}, + {"m_local": 4096, "n": 11008, "k": 4096, "label": "llama7b_gate"}, + {"m_local": 4096, "n": 4096, "k": 4096, "label": "pow2_4k"}, + {"m_local": 1024, "n": 3584, "k": 8192, "label": "M1024_N3584_K8192"}, + {"m_local": 4096, "n": 3584, "k": 8192, "label": "M4096_N3584_K8192"}, + {"m_local": 16384, "n": 3584, "k": 8192, "label": "M16384_N3584_K8192"}, +] + + +VALIDATION_TESTS = { + "matmul_all_gather": [ + { + "name": "baseline", + "path": "tests/ops/test_matmul_all_gather.py", + "pytest_k": "test_matmul_all_gather and not test_tritonblas", + }, + { + "name": "tritonblas_rcclbaseline", + "path": "tests/ops/test_matmul_all_gather.py", + "pytest_k": "test_tritonblas_rccl_matmul_all_gather", + }, + { + "name": "host_copy_engine", + "path": "tests/ops/test_matmul_all_gather_copy_engine.py", + "pytest_k": "test_matmul_all_gather_copy_engine", + "env": {"IRIS_TEST_COPY_ENGINE_MODE": "host"}, + }, + { + "name": "device_copy_engine", + "path": "tests/ops/test_matmul_all_gather_copy_engine.py", + "pytest_k": "test_matmul_all_gather_copy_engine", + "env": {"IRIS_TEST_COPY_ENGINE_MODE": "device"}, + }, + ], + "all_gather_matmul": [ + { + "name": "baseline", + "path": "tests/ops/test_all_gather_matmul.py", + "pytest_k": "test_all_gather_matmul_baseline", + }, + { + "name": "tritonblas_rcclbaseline", + "path": "tests/ops/test_all_gather_matmul.py", + "pytest_k": "test_tritonblas_rccl_all_gather_matmul", + }, + { + "name": "hbm_buffer", + "path": "tests/ops/test_all_gather_matmul.py", + "pytest_k": "test_all_gather_matmul_hbm_buffer and not test_all_gather_matmul_hbm_buffer_with_bias", + }, + { + "name": "host_copy_engine", + "path": "tests/ops/test_all_gather_matmul_copy_engine.py", + "pytest_k": "test_all_gather_matmul_copy_engine", + "env": {"IRIS_TEST_COPY_ENGINE_MODE": "host"}, + }, + { + "name": "copy_engine_host_hip_memcpy", + "path": "tests/ops/test_all_gather_matmul_copy_engine.py", + "pytest_k": "test_all_gather_matmul_copy_engine", + "env": { + "IRIS_TEST_COPY_ENGINE_MODE": "host", + "IRIS_TEST_HOST_TRANSFER_BACKEND": "hip_memcpy", + }, + }, + { + "name": "device_copy_engine", + "path": "tests/ops/test_all_gather_matmul_copy_engine.py", + "pytest_k": "test_all_gather_matmul_copy_engine", + "env": {"IRIS_TEST_COPY_ENGINE_MODE": "device"}, + }, + ], +} + + +def log(msg: str) -> None: + print(msg, file=sys.stderr, flush=True) + + +def _coerce_subprocess_output(value: str | bytes | None) -> str: + if value is None: + return "" + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + return value + + +def _format_repro_command(cmd: list[str], env: dict[str, str]) -> str: + env_keys = [ + "PYTORCH_ALLOC_CONF", + "IRIS_TEST_M", + "IRIS_TEST_N", + "IRIS_TEST_K", + "IRIS_TEST_K_LOCAL", + "IRIS_TEST_HEAP_SIZE", + "IRIS_TEST_COPY_ENGINE_MODE", + "IRIS_TEST_HOST_TRANSFER_BACKEND", + ] + env_prefix = " ".join(shlex.quote(f"{key}={env[key]}") for key in env_keys if key in env) + cmd_str = shlex.join(cmd) + if env_prefix: + return f"cd {shlex.quote(str(PROJECT_ROOT))} && {env_prefix} {cmd_str}" + return f"cd {shlex.quote(str(PROJECT_ROOT))} && {cmd_str}" + + +def _shape_env(config: dict[str, Any], operation: str) -> dict[str, str]: + m_local = int(config["m_local"]) + n = int(config["n"]) + k = int(config["k"]) + + env = { + "IRIS_TEST_N": str(n), + } + + if operation == "matmul_all_gather": + env["IRIS_TEST_M"] = str(m_local * NUM_GPUS) + env["IRIS_TEST_K"] = str(k) + else: + if k % NUM_GPUS != 0: + env["IRIS_TEST_INVALID"] = "1" + env["IRIS_TEST_M"] = str(m_local) + env["IRIS_TEST_K_LOCAL"] = str(k // NUM_GPUS) + return env + + +def _estimate_heap_bytes(operation: str, test_name: str, shape_cfg: dict[str, Any]) -> int | None: + m_local = int(shape_cfg["m_local"]) + n = int(shape_cfg["n"]) + k = int(shape_cfg["k"]) + elem = DEFAULT_ELEMENT_SIZE_BYTES + + if operation == "matmul_all_gather": + m_total = m_local * NUM_GPUS + # Matches the test allocations: + # A_local (M_local, K), B (K, N), output (M_total, N) + # Output is ALWAYS allocated from heap via shmem.zeros() + total = (m_local * k + k * n + m_total * n) * elem + if test_name == "tritonblas_rcclbaseline": + # The direct tritonBLAS+RCCL path also materializes local C_local (M_local, N). + total += (m_local * n) * elem + return total + + if operation == "all_gather_matmul": + if k % NUM_GPUS != 0: + return None + + k_local = k // NUM_GPUS + # Common allocations across the validation tests: + # A_sharded (M, K_local), B (K, N), output (M, N) + total = (m_local * k_local + k * n + m_local * n) * elem + + if test_name in {"hbm_buffer", "host_copy_engine", "copy_engine_host_hip_memcpy", "device_copy_engine"}: + # Both HBM-buffer and copy-engine variants allocate staged_a as (M, K). + total += (m_local * k) * elem + + if test_name == "hbm_buffer": + block_m = 64 + block_k = 32 + k_per_flag = 8 + num_m_tiles = (m_local + block_m - 1) // block_m + num_k_blocks_local = k_local // block_k + num_flag_groups_k = (num_k_blocks_local + k_per_flag - 1) // k_per_flag + total += num_m_tiles * num_flag_groups_k * 4 + elif test_name in {"host_copy_engine", "copy_engine_host_hip_memcpy", "device_copy_engine"}: + block_m = 64 + block_n = 64 + num_m_tiles = (m_local + block_m - 1) // block_m + num_tiles_n = (n + block_n - 1) // block_n + total_tiles = num_m_tiles * num_tiles_n + num_batches = num_m_tiles + total += total_tiles * 4 + total += num_batches * 4 + + return total + + return None + + +def _round_up(value: int, alignment: int) -> int: + return ((value + alignment - 1) // alignment) * alignment + + +def _run_validation_test(operation: str, test_cfg: dict[str, str], shape_cfg: dict[str, Any]) -> dict[str, Any]: + label = shape_cfg["label"] + env = os.environ.copy() + env.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + env.update(_shape_env(shape_cfg, operation)) + env.update(test_cfg.get("env", {})) + estimated_heap_bytes = _estimate_heap_bytes(operation, test_cfg["name"], shape_cfg) + requested_heap_size = int(env.get("IRIS_TEST_HEAP_SIZE", DEFAULT_HEAP_SIZE)) + if estimated_heap_bytes is not None: + requested_heap_size = max( + requested_heap_size, + _round_up(int(estimated_heap_bytes * HEAP_HEADROOM_FACTOR), HEAP_ALIGNMENT_BYTES), + ) + env["IRIS_TEST_HEAP_SIZE"] = str(requested_heap_size) + + if env.get("IRIS_TEST_INVALID") == "1": + return { + "status": "SKIPPED", + "reason": f"K={shape_cfg['k']} not divisible by world_size={NUM_GPUS}", + } + + cmd = [ + "torchrun", + f"--nproc_per_node={NUM_GPUS}", + str(PROJECT_ROOT / "tests/run_tests_distributed.py"), + "-q", + str(PROJECT_ROOT / test_cfg["path"]), + ] + if test_cfg.get("pytest_k"): + cmd.extend(["-k", test_cfg["pytest_k"]]) + repro_command = _format_repro_command(cmd, env) + + log(f" Validating {test_cfg['name']}: {label}") + log(f" Heap size: {requested_heap_size / (1024**3):.1f} GiB") + log(f" Command: {' '.join(cmd)}") + + process = None + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=str(PROJECT_ROOT), + env=env, + preexec_fn=os.setsid, + ) + stdout, stderr = process.communicate(timeout=TIMEOUT_SECONDS) + + if process.returncode == 0: + return { + "status": "PASSED", + "heap_size_bytes": requested_heap_size, + } + + error_log_file = PROJECT_ROOT / f"validation_error_{operation}_{test_cfg['name']}_{label}.log" + with open(error_log_file, "w") as f: + f.write(stdout) + f.write("\n") + f.write(stderr) + lines = (stdout + stderr).strip().split("\n") + for line in lines[-5:]: + log(f" {line}") + return { + "status": "FAILED", + "log": str(error_log_file), + "heap_size_bytes": requested_heap_size, + "command": repro_command, + } + + except subprocess.TimeoutExpired as timeout_err: + if process is not None: + try: + os.killpg(os.getpgid(process.pid), signal.SIGTERM) + except ProcessLookupError: + pass + partial_stdout = _coerce_subprocess_output(timeout_err.stdout) + partial_stderr = _coerce_subprocess_output(timeout_err.stderr) + error_log_file = PROJECT_ROOT / f"validation_timeout_{operation}_{test_cfg['name']}_{label}.log" + with open(error_log_file, "w") as f: + f.write(partial_stdout) + f.write("\n") + f.write(partial_stderr) + lines = (partial_stdout + partial_stderr).strip().split("\n") if (partial_stdout or partial_stderr) else [] + for line in lines[-5:]: + log(f" {line}") + return { + "status": "TIMEOUT", + "log": str(error_log_file), + "heap_size_bytes": requested_heap_size, + "command": repro_command, + } + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run distributed validation sweep for benchmark shapes", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--operation", + type=str, + required=True, + choices=["matmul_all_gather", "all_gather_matmul"], + help="Operation family to validate", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output JSON file path", + ) + args = parser.parse_args() + + output_file = ( + Path(args.output) + if args.output + else PROJECT_ROOT / f"benchmark/ops/validation_sweep_results_{args.operation}.json" + ) + + tests = VALIDATION_TESTS[args.operation] + results = [] + + for cfg in DIMENSION_CONFIGS: + row = { + "label": cfg["label"], + "M": int(cfg["m_local"]) * NUM_GPUS if args.operation == "matmul_all_gather" else int(cfg["m_local"]), + "N": int(cfg["n"]), + "K": int(cfg["k"]), + "operation": args.operation, + "validations": {}, + } + log(f"Testing {cfg['label']}") + for test_cfg in tests: + row["validations"][test_cfg["name"]] = _run_validation_test(args.operation, test_cfg, cfg) + results.append(row) + + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + log(f"Saved validation results to {output_file}") + + +if __name__ == "__main__": + main() diff --git a/docs/benchmark-results/latency_comparison.png b/docs/benchmark-results/latency_comparison.png new file mode 100644 index 000000000..288fad091 Binary files /dev/null and b/docs/benchmark-results/latency_comparison.png differ diff --git a/docs/benchmark-results/tflops_comparison.png b/docs/benchmark-results/tflops_comparison.png new file mode 100644 index 000000000..c33582ec7 Binary files /dev/null and b/docs/benchmark-results/tflops_comparison.png differ diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py new file mode 100644 index 000000000..677b68465 --- /dev/null +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import argparse + +import torch +import torch.distributed as dist +import triton +import triton.language as tl +import random + +from mpi4py import MPI + +import iris + + +@triton.jit +def producer_kernel( + source_buffer, # tl.tensor: pointer to source data + target_buffer, # tl.tensor: pointer to target data + flag, # tl.tensor: pointer to flags + buffer_size, # int32: total number of elements + producer_rank: tl.constexpr, + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers + copy_engine_handle_ptr, +): + pid = tl.program_id(0) + + # Compute start index of this block + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Guard for out-of-bounds accesses + mask = offsets < buffer_size + + # Put chunk into remote buffer + iris.put( + source_buffer + offsets, + target_buffer + offsets, + producer_rank, + consumer_rank, + heap_bases_ptr, + copy_engine_handle_ptr, + mask=mask, + USE_COPY_ENGINE=True, + ) + + # Set flag to signal completion + iris.signal_ce(flag + pid, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr) + + +@triton.jit +def consumer_kernel( + buffer, # tl.tensor: pointer to shared buffer (read from target_rank) + flag, # tl.tensor: sync flag per block + buffer_size, # int32: total number of elements + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < buffer_size + + # Spin-wait until writer sets flag[pid] = 1 + # zero_u64 = tl.zeros((1,), tl.uint64) + # one_u64 = tl.full((1,), 1, tl.uint64) + done = 0 # zero_u64 + while done == 0: + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) + + # Read from the target buffer (written by producer) + values = tl.load(buffer + offsets, mask=mask) + + # Do something with values... + # (Here you might write to output, do computation, etc.) + values = values * 2 + + # Store chunk to target buffer + tl.store( + buffer + offsets, + values, + mask=mask, + ) + + # Optionally reset the flag for next iteration + tl.store(flag + pid, 0) + + +torch.manual_seed(123) +random.seed(123) + + +def torch_dtype_from_str(datatype: str) -> torch.dtype: + dtype_map = { + "fp16": torch.float16, + "fp32": torch.float32, + "int8": torch.int8, + "bf16": torch.bfloat16, + } + try: + return dtype_map[datatype] + except KeyError: + print(f"Unknown datatype: {datatype}") + exit(1) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse Message Passing configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-t", + "--datatype", + type=str, + default="fp32", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") + parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") + + parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Main benchmark logic + shmem = iris.iris(args["heap_size"]) + dtype = torch_dtype_from_str(args["datatype"]) + cur_rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Allocate source and destination buffers on the symmetric heap + destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + + if world_size != 2: + raise ValueError("This example requires exactly two processes.") + + producer_rank = 0 + consumer_rank = 1 + + n_elements = source_buffer.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + num_blocks = triton.cdiv(n_elements, args["block_size"]) + + # Allocate flags on the symmetric heap + flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) + + if cur_rank == producer_rank: + shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") + kk = producer_kernel[grid]( + source_buffer, + destination_buffer, + flags, + n_elements, + producer_rank, + consumer_rank, + args["block_size"], + shmem.get_heap_bases(), + shmem.get_copy_engine_handle(consumer_rank), + ) + else: + shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") + kk = consumer_kernel[grid]( + destination_buffer, flags, n_elements, consumer_rank, args["block_size"], shmem.get_heap_bases() + ) + shmem.barrier() + shmem.info(f"Rank {cur_rank} has finished sending/receiving data.") + shmem.info("Validating output...") + + success = True + if cur_rank == consumer_rank: + expected = source_buffer * 2 + diff_mask = ~torch.isclose(destination_buffer, expected, atol=1) + breaking_indices = torch.nonzero(diff_mask, as_tuple=False) + + if not torch.allclose(destination_buffer, expected, atol=1): + max_diff = (destination_buffer - expected).abs().max().item() + shmem.info(f"Max absolute difference: {max_diff}") + for idx in breaking_indices: + idx = tuple(idx.tolist()) + computed_val = destination_buffer[idx] + expected_val = expected[idx] + shmem.info(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}") + success = False + break + + if success: + shmem.info("Validation successful.") + else: + shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}") + + shmem.barrier() + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + comm = MPI.COMM_WORLD # Communicator for all processes + rank = comm.Get_rank() # Get the rank of the current process + num_ranks = comm.Get_size() # Total number of processes + # TODO local_rank + torch.cuda.set_device(rank) + + # Synchronize all processes + comm.barrier() + + init_url = "tcp://127.0.0.1:29500" + + _worker(rank, num_ranks, init_url, args) + + +if __name__ == "__main__": + main() diff --git a/examples/06_message_passing/message_passing_host_initiated.py b/examples/06_message_passing/message_passing_host_initiated.py new file mode 100644 index 000000000..b1b589225 --- /dev/null +++ b/examples/06_message_passing/message_passing_host_initiated.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +""" +Host-Initiated Message Passing Example + +This example demonstrates message passing where the producer (GPU 0) is +controlled by the HOST (Python/CPU) instead of a device kernel, while +the consumer (GPU 1) remains a device kernel. + +Key difference from message_passing_put.py: +- Producer: Host uses anvil to initiate SDMA transfers from Python +- Consumer: Same device kernel waiting for data + +This shows how to orchestrate GPU-to-GPU transfers from Python without +requiring kernel launches on the source GPU. +""" + +import argparse + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import triton.language as tl +import random + +import iris + + +@triton.jit +def consumer_kernel( + buffer, # tl.tensor: pointer to shared buffer (read from target_rank) + flag, # tl.tensor: sync flag per block + buffer_size, # int32: total number of elements + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < buffer_size + + # Spin-wait until writer sets flag[pid] = 1 + done = 0 + while done == 0: + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) + + # Read from the target buffer (written by producer) + values = tl.load(buffer + offsets, mask=mask) + + # Do something with values... + # (Here you might write to output, do computation, etc.) + values = values * 2 + + # Store chunk to target buffer + tl.store( + buffer + offsets, + values, + mask=mask, + ) + + # Optionally reset the flag for next iteration + tl.store(flag + pid, 0) + + +torch.manual_seed(123) +random.seed(123) + + +def torch_dtype_from_str(datatype: str) -> torch.dtype: + dtype_map = { + "fp16": torch.float16, + "fp32": torch.float32, + "int8": torch.int8, + "bf16": torch.bfloat16, + } + try: + return dtype_map[datatype] + except KeyError: + print(f"Unknown datatype: {datatype}") + exit(1) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Host-Initiated SDMA Message Passing Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-t", + "--datatype", + type=str, + default="fp32", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") + parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") + parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Main benchmark logic + shmem = iris.iris(args["heap_size"]) + dtype = torch_dtype_from_str(args["datatype"]) + cur_rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Allocate source and destination buffers on the symmetric heap + destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + + if world_size != 2: + raise ValueError("This example requires exactly two processes.") + + producer_rank = 0 + consumer_rank = 1 + + n_elements = source_buffer.numel() + # Use fixed block size for both producer and consumer + BLOCK_SIZE = args["block_size"] + num_blocks = triton.cdiv(n_elements, BLOCK_SIZE) + grid = (num_blocks,) + + # Allocate flags on the symmetric heap + flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) + + if cur_rank == producer_rank: + shmem.info(f"Rank {cur_rank} (HOST) is sending data to rank {consumer_rank}.") + # Initialize CUDA context even though we're doing host-side operations + # This is needed for the barrier to work + torch.cuda.current_device() + + # Create host-initiated SDMA connection separate from iris's device connection + # This allows the host to orchestrate transfers without kernel launches + anvil_lib = shmem.copy_engines # Reuse iris's anvil instance + anvil_lib.connect(producer_rank, consumer_rank, num_channels=1, allocate_on_host=True) + + # Host-initiated transfer: send data block by block + elem_size = source_buffer.element_size() + + import time + + start_time = time.time() + + for block_id in range(num_blocks): + block_start = block_id * BLOCK_SIZE + block_end = min(block_start + BLOCK_SIZE, n_elements) + block_len = block_end - block_start + + # Calculate byte offsets + src_offset = block_start * elem_size + dst_offset = block_start * elem_size + size_bytes = block_len * elem_size + + # Translate destination buffer address from producer to consumer address space + dst_local_addr = destination_buffer.data_ptr() + dst_offset + dst_remote_addr = shmem.translate(dst_local_addr, producer_rank, consumer_rank) + + # Transfer data block using SDMA + anvil_lib.host_put( + producer_rank, consumer_rank, 0, source_buffer.data_ptr() + src_offset, dst_remote_addr, size_bytes + ) + + # Signal completion with atomic add - translate flag address + flag_local_addr = flags.data_ptr() + block_id * 4 # 4 bytes for int32 + flag_remote_addr = shmem.translate(flag_local_addr, producer_rank, consumer_rank) + + anvil_lib.host_atomic_add_32(producer_rank, consumer_rank, 0, flag_remote_addr, 1) + + end_time = time.time() + elapsed_ms = (end_time - start_time) * 1000 + shmem.info( + f"Host SDMA loop took {elapsed_ms:.2f} ms for {num_blocks} blocks ({elapsed_ms / num_blocks:.2f} ms/block)" + ) + + # Synchronize to ensure all transfers complete + # TODO use quiet() + + else: + shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") + kk = consumer_kernel[grid]( + destination_buffer, flags, n_elements, consumer_rank, BLOCK_SIZE, shmem.get_heap_bases() + ) + shmem.barrier() + shmem.info(f"Rank {cur_rank} has finished sending/receiving data.") + shmem.info("Validating output...") + + success = True + if cur_rank == consumer_rank: + expected = source_buffer * 2 + diff_mask = ~torch.isclose(destination_buffer, expected, atol=1) + breaking_indices = torch.nonzero(diff_mask, as_tuple=False) + + if not torch.allclose(destination_buffer, expected, atol=1): + max_diff = (destination_buffer - expected).abs().max().item() + shmem.info(f"Max absolute difference: {max_diff}") + for idx in breaking_indices: + idx = tuple(idx.tolist()) + computed_val = destination_buffer[idx] + expected_val = expected[idx] + shmem.info(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}") + success = False + break + + if success: + shmem.info("Validation successful.") + else: + shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}") + + shmem.barrier() + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 4f7269695..c0c4d7b51 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -23,6 +23,8 @@ def producer_kernel( consumer_rank: tl.constexpr, BLOCK_SIZE: tl.constexpr, heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers + copy_engine_handle_ptr, + USE_COPY_ENGINE: tl.constexpr, ): pid = tl.program_id(0) @@ -34,10 +36,30 @@ def producer_kernel( mask = offsets < buffer_size # Put chunk into remote buffer - iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, mask=mask) + iris.put( + source_buffer + offsets, + target_buffer + offsets, + producer_rank, + consumer_rank, + heap_bases_ptr, + copy_engine_handle_ptr, + mask=mask, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) # Set flag to signal completion - iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") + # iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, sem="release", scope="sys") + iris.atomic_add( + flag + pid, + 1, + producer_rank, + consumer_rank, + heap_bases_ptr, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_handle_ptr, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) @triton.jit @@ -113,9 +135,11 @@ def parse_args(): ) parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") - parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -138,12 +162,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() # Allocate source and destination buffers on the symmetric heap - source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) + destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) if dtype.is_floating_point: - destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) else: ii = torch.iinfo(dtype) - destination_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) if world_size != 2: raise ValueError("This example requires exactly two processes.") @@ -158,6 +182,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate flags on the symmetric heap flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) + # Get copy engine context + # copy_engine_ctx = shmem.get_copy_engine_handle(consumer_rank) if args["use_copy_engine"] and cur_rank == producer_rank else None + copy_engine_ctx = shmem.get_copy_engine_ctx() + if cur_rank == producer_rank: shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") kk = producer_kernel[grid]( @@ -169,6 +197,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): consumer_rank, args["block_size"], shmem.get_heap_bases(), + copy_engine_ctx, + USE_COPY_ENGINE=args["use_copy_engine"], ) else: shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") @@ -199,7 +229,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if success: shmem.info("Validation successful.") else: - shmem.info("Validation failed.") + shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}") shmem.barrier() diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index 994c10cad..c515df52c 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -58,6 +58,9 @@ def parse_args(): help="Number of SMs for persistent GEMM algorithm (default: auto-detected)", ) parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -124,11 +127,15 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N + # Get copy engine context + copy_engine_ctx = shmem.get_copy_engine_ctx() + bias = None gemm_stream = torch.cuda.Stream() json_writer.add_field("gemm_sms", args["gemm_sms"]) + json_writer.add_field("total_tiles", total_tiles) kernel_timing = { "gemm": { @@ -142,11 +149,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate Timestamps timestamps = Timestamps(num_tiles=total_tiles) + # Allocate flags for synchronization (one flag per SM per rank) + flags = shmem.zeros((args["gemm_sms"] * world_size,), device="cuda", dtype=torch.int32) + def run_experiment(): nonlocal local_C nonlocal global_C nonlocal kernel_timing + # Reset flags to zero before each experiment + flags.zero_() + shmem.barrier() if args["trace_tiles"]: @@ -163,6 +176,7 @@ def run_experiment(): local_C, global_C, bias, + flags, rank, world_size, args["gemm_sms"], @@ -174,6 +188,8 @@ def run_experiment(): shmem.get_heap_bases(), "gfx942", args["trace_tiles"], + args["use_copy_engine"], + copy_engine_ctx, timestamps.mm_begin_timestamp, timestamps.mm_end_timestamp, ) diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index 937835d6f..78d4fba6a 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -9,6 +9,11 @@ import iris +@triton.jit +def wait_cnt(): + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + + @triton.jit() def persistent_gemm_all_scatter( A, @@ -16,6 +21,7 @@ def persistent_gemm_all_scatter( C, c_global, bias_ptr, + flags, M, N, K, @@ -40,13 +46,15 @@ def persistent_gemm_all_scatter( cur_rank: tl.constexpr, world_size: tl.constexpr, COLLECT_TIMESTAMPS: tl.constexpr = False, + USE_COPY_ENGINE: tl.constexpr = False, + copy_engine_ctx: tl.tensor = None, mm_begin_timestamp_ptr: tl.tensor = None, mm_end_timestamp_ptr: tl.tensor = None, ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + # if NUM_XCDS != 1: + # pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -60,6 +68,7 @@ def persistent_gemm_all_scatter( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + # Process all tiles for this SM for tile_id in range(pid, total_tiles, NUM_SMS): if COLLECT_TIMESTAMPS: timestamp = read_realtime() @@ -132,17 +141,66 @@ def persistent_gemm_all_scatter( timestamp = read_realtime() tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) - # Store data to the global result using puts - for remote_rank in range(world_size): - if remote_rank == cur_rank: - # For the current rank, we can use store - tl.store(c_global + global_offset, c, mask=sub_mask) - else: - iris.store( - c_global + global_offset, - c, - cur_rank, - remote_rank, - heap_bases, - mask=sub_mask, - ) + if USE_COPY_ENGINE: + # Store locally first + tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") + wait_cnt() + tl.debug_barrier() + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.put( + c_global + global_offset, + c_global + global_offset, + cur_rank, + remote_rank, + heap_bases, + copy_engine_ctx, + stride_tm=stride_cm_global, + stride_tn=stride_cn_global, + stride_fm=stride_cm_global, + stride_fn=stride_cn_global, + mask=sub_mask, + USE_COPY_ENGINE=True, + IS_2D_COPY=True, + from_base_ptr=c_global, + to_base_ptr=c_global, + ) + + else: + # Store data to the global result using puts + for remote_rank in range(world_size): + if remote_rank == cur_rank: + # For the current rank, we can use store + tl.store(c_global + global_offset, c, mask=sub_mask) + else: + iris.store( + c_global + global_offset, + c, + cur_rank, + remote_rank, + heap_bases, + mask=sub_mask, + ) + + # After all tiles are processed, signal and wait once per SM + tl.debug_barrier() + # Signal other ranks that all our puts/stores are complete + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.atomic_add( + flags + (pid * world_size) + cur_rank, + 1, + cur_rank, + remote_rank, + heap_bases, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_ctx, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) + + # Wait for other ranks to signal us + for remote_rank in range(world_size): + if remote_rank != cur_rank: + while tl.load(flags + (pid * world_size) + remote_rank, cache_modifier=".cv", volatile=True) != 1: + pass diff --git a/examples/07_gemm_all_scatter/matmul_wrapper.py b/examples/07_gemm_all_scatter/matmul_wrapper.py index 5d8adb589..3f6d3e0d6 100644 --- a/examples/07_gemm_all_scatter/matmul_wrapper.py +++ b/examples/07_gemm_all_scatter/matmul_wrapper.py @@ -44,6 +44,7 @@ def _call( c: torch.Tensor, c_global: torch.Tensor, bias: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, num_sms: int, @@ -55,6 +56,8 @@ def _call( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -86,6 +89,7 @@ def _call( c, c_global, bias, + flags, M, N, K, @@ -115,6 +119,8 @@ def _call( cur_rank=rank, world_size=world_size, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp_ptr=mm_begin_timestamp, mm_end_timestamp_ptr=mm_end_timestamp, ) @@ -133,6 +139,7 @@ def forward( c: torch.Tensor, c_global: torch.Tensor, bias: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, num_sms: int, @@ -144,6 +151,8 @@ def forward( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -153,6 +162,7 @@ def forward( c=c, c_global=c_global, bias=bias, + flags=flags, rank=rank, world_size=world_size, num_sms=num_sms, @@ -164,6 +174,8 @@ def forward( heap_bases_ptr=heap_bases_ptr, arch=arch, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp=mm_begin_timestamp, mm_end_timestamp=mm_end_timestamp, ) diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index 910ebdd6f..cbd5433b9 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -67,6 +67,9 @@ def parse_args(): ) parser.add_argument("--num_stages", type=int, default=2, help="Number of stages") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -134,7 +137,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + comm_sms = args["num_sms"] - args["gemm_sms"] + flags = shmem.zeros((comm_sms, world_size), device="cuda", dtype=torch.uint32) + + # Get copy engine context + copy_engine_ctx = shmem.get_copy_engine_ctx() bias = None @@ -182,6 +190,7 @@ def run_experiment(): global_C, bias, locks, + flags, rank, world_size, args["gemm_sms"], @@ -194,6 +203,8 @@ def run_experiment(): shmem.get_heap_bases(), "gfx942", args["trace_tiles"], + args["use_copy_engine"], + copy_engine_ctx, timestamps.mm_begin_timestamp, timestamps.mm_end_timestamp, ) @@ -241,7 +252,6 @@ def run_experiment(): # Wait for all to finish validation shmem.barrier() - shmem.info("Validating local C...") json_writer.add_field("success", success) diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index 643e84f90..c17eb2371 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -9,6 +9,11 @@ import iris +@triton.jit +def wait_cnt(): + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + + @triton.jit() def persistent_gemm_all_scatter_wg_specialization( A, @@ -17,6 +22,7 @@ def persistent_gemm_all_scatter_wg_specialization( c_global, bias_ptr, locks, + flags, M, N, K, @@ -24,8 +30,8 @@ def persistent_gemm_all_scatter_wg_specialization( stride_ak, stride_bk, stride_bn, - stride_cm, - stride_cn, + stride_cm, # unused + stride_cn, # unused stride_cm_global, stride_cn_global, stride_bias, @@ -42,6 +48,8 @@ def persistent_gemm_all_scatter_wg_specialization( cur_rank: tl.constexpr, world_size: tl.constexpr, COLLECT_TIMESTAMPS: tl.constexpr = False, + USE_COPY_ENGINE: tl.constexpr = False, + copy_engine_ctx: tl.tensor = None, mm_begin_timestamp_ptr: tl.tensor = None, mm_end_timestamp_ptr: tl.tensor = None, ): @@ -67,6 +75,9 @@ def persistent_gemm_all_scatter_wg_specialization( # and another that performs the communication. Uses persistent- # kernel. if pid < GEMM_SMS: + # tl.device_print("GEMM_SMS: ", GEMM_SMS) + # tl.device_print("GEMM pid: ", pid) + for tile_id in range(pid, total_tiles, GEMM_SMS): if COLLECT_TIMESTAMPS: timestamp = read_realtime() @@ -140,11 +151,15 @@ def persistent_gemm_all_scatter_wg_specialization( tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") + wait_cnt() + tl.debug_barrier() tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu") else: # pid >= GEMM_SMS COMM_SMS = NUM_SMS - GEMM_SMS pid = pid - GEMM_SMS + # tl.device_print("COMM_SMS: ", COMM_SMS) + # tl.device_print("COMM pid: ", pid) for tile_id in range(pid, total_tiles, COMM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = tile_id // num_pid_in_group @@ -176,5 +191,37 @@ def persistent_gemm_all_scatter_wg_specialization( cur_rank, remote_rank, heap_bases, + copy_engine_ctx, + stride_tm=stride_cm_global, + stride_tn=stride_cn_global, + stride_fm=stride_cm_global, + stride_fn=stride_cn_global, mask=sub_mask, + USE_COPY_ENGINE=USE_COPY_ENGINE, + IS_2D_COPY=True, + from_base_ptr=c_global, + to_base_ptr=c_global, ) + tl.debug_barrier() + # Signal other ranks + for remote_rank in range(world_size): + if remote_rank != cur_rank: + # print("Issue atomic_add") + iris.atomic_add( + flags + (pid * world_size) + cur_rank, + 1, + cur_rank, + remote_rank, + heap_bases, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_ctx, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) + # print("Start waiting") + # Wait for other ranks to signal us + for remote_rank in range(world_size): + if remote_rank != cur_rank: + while tl.load(flags + (pid * world_size) + remote_rank, cache_modifier=".cv", volatile=True) != 1: + pass + # print("done waiting") diff --git a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py index 1d46297a4..135313fb4 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py @@ -47,6 +47,7 @@ def _call( c_global: torch.Tensor, bias: torch.Tensor, locks: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, gemm_sms: int, @@ -59,6 +60,8 @@ def _call( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -91,6 +94,7 @@ def _call( c_global, bias, locks, + flags, M, N, K, @@ -121,6 +125,8 @@ def _call( cur_rank=rank, world_size=world_size, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp_ptr=mm_begin_timestamp, mm_end_timestamp_ptr=mm_end_timestamp, ) @@ -140,6 +146,7 @@ def forward( c_global: torch.Tensor, bias: torch.Tensor, locks: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, gemm_sms: int, @@ -152,6 +159,8 @@ def forward( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -162,6 +171,7 @@ def forward( c_global=c_global, bias=bias, locks=locks, + flags=flags, rank=rank, world_size=world_size, gemm_sms=gemm_sms, @@ -174,6 +184,8 @@ def forward( heap_bases_ptr=heap_bases_ptr, arch=arch, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp=mm_begin_timestamp, mm_end_timestamp=mm_end_timestamp, ) diff --git a/ext/shader_sdma b/ext/shader_sdma new file mode 160000 index 000000000..2b354bf87 --- /dev/null +++ b/ext/shader_sdma @@ -0,0 +1 @@ +Subproject commit 2b354bf87250264acd179f7e69b995a5e53827a0 diff --git a/iris/__init__.py b/iris/__init__.py index 3ec70efa8..6a8943f04 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -51,6 +51,14 @@ copy, get, put, + translate_ptr, + wait_then_put_rect, + wait_then_put_rects, + wait_then_put_signal_rect, + wait_then_put_signal_rects, + quiet, + put_signal, + put_signal_rect, atomic_add, atomic_cas, atomic_xchg, @@ -97,6 +105,14 @@ "copy", "get", "put", + "translate_ptr", + "wait_then_put_rect", + "wait_then_put_rects", + "wait_then_put_signal_rect", + "wait_then_put_signal_rects", + "quiet", + "put_signal", + "put_signal_rect", "atomic_add", "atomic_cas", "atomic_xchg", diff --git a/iris/fd_passing.py b/iris/fd_passing.py index 4e8c13f44..1f71290fa 100644 --- a/iris/fd_passing.py +++ b/iris/fd_passing.py @@ -140,6 +140,55 @@ def setup_fd_mesh(rank: int, world_size: int, all_paths: Dict[int, str]) -> Dict return conns +def _allgather_paths_tensor(my_path: str, num_ranks: int): + """ + Exchange socket paths across ranks using a fixed-size tensor all_gather. + + Uses ``dist.all_gather`` with a fixed-size int8 tensor instead of + ``dist.all_gather_object`` to avoid injecting extra NCCL collective + calls (``all_gather_object`` internally issues two NCCL all_gathers for + size+data). At ws<8 the additional collectives can interleave with + data-plane ``all_gather_into_tensor`` calls on the same process group, + causing a rank-asymmetric collective ordering deadlock. + + AF_UNIX paths are at most 108 bytes; we use a 256-byte buffer for safety. + """ + import torch + import torch.distributed as dist + + _PATH_BUF_LEN = 256 + path_bytes = my_path.encode("utf-8") + if len(path_bytes) >= _PATH_BUF_LEN: + raise ValueError(f"Socket path too long ({len(path_bytes)} bytes, max {_PATH_BUF_LEN - 1}): {my_path}") + + # Encode into a fixed-size uint8 tensor (CPU for gloo, GPU for nccl). + # uint8 matches the [0,255] byte range; NCCL supports it natively. + buf = torch.zeros(_PATH_BUF_LEN, dtype=torch.uint8) + for i, b in enumerate(path_bytes): + buf[i] = b + + backend = str(dist.get_backend()).lower() + if backend == "nccl" and torch.cuda.is_available(): + device = torch.device("cuda", torch.cuda.current_device()) + buf = buf.to(device) + # else: keep on CPU (gloo) + + gathered = [torch.zeros_like(buf) for _ in range(num_ranks)] + dist.all_gather(gathered, buf) + + all_paths = {} + for r in range(num_ranks): + raw = gathered[r].cpu().tolist() + # Find null terminator (first 0) + try: + end = raw.index(0) + except ValueError: + end = _PATH_BUF_LEN + all_paths[r] = bytes(raw[:end]).decode("utf-8") + + return all_paths + + def setup_fd_infrastructure(cur_rank: int, num_ranks: int): """ Setup FD passing infrastructure for multi-rank communication. @@ -156,15 +205,17 @@ def setup_fd_infrastructure(cur_rank: int, num_ranks: int): if num_ranks <= 1: return None - import torch.distributed as dist from iris._distributed_helpers import distributed_barrier # Setup socket mesh for FD passing prefix = "iris-dmabuf" my_path = make_rank_sock_path(prefix, cur_rank) - obj_list = [None for _ in range(num_ranks)] - dist.all_gather_object(obj_list, my_path) - all_paths = {r: obj_list[r] for r in range(num_ranks)} + + # Use tensor-based all_gather instead of all_gather_object to avoid + # injecting extra NCCL collectives that can deadlock with data-plane + # all_gather_into_tensor at ws<8 (see _allgather_paths_tensor docstring). + all_paths = _allgather_paths_tensor(my_path, num_ranks) + distributed_barrier() fd_conns = setup_fd_mesh(cur_rank, num_ranks, all_paths) distributed_barrier() diff --git a/iris/hip.py b/iris/hip.py index e6dc598d8..5af6b5899 100644 --- a/iris/hip.py +++ b/iris/hip.py @@ -256,6 +256,119 @@ def hip_free(ptr): gpu_try(gpu_runtime.cudaFree(ptr)) +def create_stream(non_blocking: bool = True): + """Create a HIP/CUDA stream and return the opaque stream handle as an int.""" + stream = ctypes.c_void_p() + flags = 1 if non_blocking else 0 + + if _is_amd_backend: + gpu_runtime.hipStreamCreateWithFlags.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_uint, + ] + gpu_runtime.hipStreamCreateWithFlags.restype = ctypes.c_int + gpu_try(gpu_runtime.hipStreamCreateWithFlags(ctypes.byref(stream), flags)) + else: + gpu_runtime.cudaStreamCreateWithFlags.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_uint, + ] + gpu_runtime.cudaStreamCreateWithFlags.restype = ctypes.c_int + gpu_try(gpu_runtime.cudaStreamCreateWithFlags(ctypes.byref(stream), flags)) + + return stream.value + + +def destroy_stream(stream): + """Destroy a HIP/CUDA stream created by create_stream().""" + stream_arg = ctypes.c_void_p(stream) + if _is_amd_backend: + gpu_runtime.hipStreamDestroy.argtypes = [ctypes.c_void_p] + gpu_runtime.hipStreamDestroy.restype = ctypes.c_int + gpu_try(gpu_runtime.hipStreamDestroy(stream_arg)) + else: + gpu_runtime.cudaStreamDestroy.argtypes = [ctypes.c_void_p] + gpu_runtime.cudaStreamDestroy.restype = ctypes.c_int + gpu_try(gpu_runtime.cudaStreamDestroy(stream_arg)) + + +def stream_synchronize(stream): + """Synchronize a HIP/CUDA stream handle.""" + stream_arg = ctypes.c_void_p(stream) + if _is_amd_backend: + gpu_runtime.hipStreamSynchronize.argtypes = [ctypes.c_void_p] + gpu_runtime.hipStreamSynchronize.restype = ctypes.c_int + gpu_try(gpu_runtime.hipStreamSynchronize(stream_arg)) + else: + gpu_runtime.cudaStreamSynchronize.argtypes = [ctypes.c_void_p] + gpu_runtime.cudaStreamSynchronize.restype = ctypes.c_int + gpu_try(gpu_runtime.cudaStreamSynchronize(stream_arg)) + + +def memcpy_2d_async( + dst_ptr: int, + dst_pitch: int, + src_ptr: int, + src_pitch: int, + width_bytes: int, + height: int, + *, + stream=None, +): + """Launch an async device-to-device 2D memcpy on a HIP/CUDA stream.""" + memcpy_device_to_device = 3 + stream_arg = ctypes.c_void_p(0 if stream is None else stream) + + if _is_amd_backend: + gpu_runtime.hipMemcpy2DAsync.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_void_p, + ] + gpu_runtime.hipMemcpy2DAsync.restype = ctypes.c_int + gpu_try( + gpu_runtime.hipMemcpy2DAsync( + ctypes.c_void_p(dst_ptr), + dst_pitch, + ctypes.c_void_p(src_ptr), + src_pitch, + width_bytes, + height, + memcpy_device_to_device, + stream_arg, + ) + ) + else: + gpu_runtime.cudaMemcpy2DAsync.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_void_p, + ] + gpu_runtime.cudaMemcpy2DAsync.restype = ctypes.c_int + gpu_try( + gpu_runtime.cudaMemcpy2DAsync( + ctypes.c_void_p(dst_ptr), + dst_pitch, + ctypes.c_void_p(src_ptr), + src_pitch, + width_bytes, + height, + memcpy_device_to_device, + stream_arg, + ) + ) + + def export_dmabuf_handle(ptr, size): """ Export a DMA-BUF file descriptor for a memory range. diff --git a/iris/iris.py b/iris/iris.py index 8c750ba67..5fc4c03f3 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -55,6 +55,8 @@ get_cu_count, count_devices, ) + +import anvil from iris.symmetric_heap import SymmetricHeap import numpy as np from typing import Any @@ -111,6 +113,10 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): self.device = f"cuda:{gpu_id}" self.heap_bases = self.heap.get_heap_bases() + # Pre-fetch heap_bases to CPU for host-side address translation + # This avoids needing to copy from GPU during SDMA operations + self.heap_bases_cpu = self.heap_bases.cpu().numpy() + if is_simulation_env(): import json @@ -130,6 +136,38 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): distributed_barrier() + # initialize copy engines + self.copy_engines = anvil.AnvilLib.get_instance() + self.copy_engines.init() + + # connect to all peers (including local) + # TODO only connect local ranks + # TODO get size + context_size = 6 + self.copy_engines_device_ctx = torch.zeros((num_ranks, context_size), dtype=torch.uint64, device=self.device) + + for rank in range(num_ranks): + # Device-initiated queues + self.copy_engines.connect(cur_rank, rank, 1, allocate_on_host=False) + # Host-initiated queues + self.copy_engines.connect(cur_rank, rank, 1, allocate_on_host=True) + + queue = self.copy_engines.get_sdma_queue(cur_rank, rank, 0) + handle = queue.device_ctx() + self.info(f"---- Queue {rank} ------------") + self.info(f"queue_buf {handle.queue_buf:#x} at {id(handle.queue_buf):#x}") + self.info(f"rptr {handle.rptr:#x} at {id(handle.rptr):#x}") + self.info(f"wptr {handle.wptr:#x} at {id(handle.wptr):#x}") + self.info(f"doorbell {handle.doorbell:#x} at {id(handle.doorbell):#x}") + self.info(f"cached_write_ptr {handle.cached_wptr:#x} at {id(handle.cached_wptr):#x}") + self.info(f"committed_write_ptr {handle.committed_wptr:#x} at {id(handle.committed_wptr):#x}") + + self.copy_engines_device_ctx[rank][0] = handle.queue_buf + self.copy_engines_device_ctx[rank][1] = handle.rptr + self.copy_engines_device_ctx[rank][2] = handle.wptr + self.copy_engines_device_ctx[rank][3] = handle.doorbell + self.copy_engines_device_ctx[rank][4] = handle.cached_wptr + self.copy_engines_device_ctx[rank][5] = handle.committed_wptr # Initialize CCL interface self.ccl = self.CCL(self) @@ -906,6 +944,288 @@ def get_heap_bases(self): """ return self.heap_bases + def get_copy_engine_ctx(self): + return self.copy_engines_device_ctx + + def translate(self, ptr: int, from_rank: int, to_rank: int) -> int: + """ + Translate a pointer address from one rank's address space to another. + + This is useful for host-side SDMA operations where you need to convert + peer-mapped addresses to the target GPU's local address space. + + Args: + ptr (int): The pointer address in from_rank's address space + from_rank (int): Source rank (address space of ptr) + to_rank (int): Target rank (desired address space) + + Returns: + int: Translated pointer address in to_rank's address space + + Example: + >>> ctx = iris.iris() + >>> buffer = ctx.zeros(1024, dtype=torch.float32) + >>> # Translate buffer address from rank 0 to rank 1's address space + >>> remote_addr = ctx.translate(buffer.data_ptr(), 0, 1) + >>> ctx.copy_engines.host_put(0, 1, 0, src_ptr, remote_addr, size) + """ + # Use pre-cached CPU copy to avoid GPU->CPU transfer on every call + from_base = int(self.heap_bases_cpu[from_rank]) + to_base = int(self.heap_bases_cpu[to_rank]) + offset = ptr - from_base + return to_base + offset + + def put( + self, + src_tensor: torch.Tensor, + dst_rank: int, + dst_tensor: torch.Tensor = None, + wait_flag: torch.Tensor = None, + wait_value: int = None, + signal_flag: torch.Tensor = None, + signal_value: int = 1, + async_op: bool = False, + channel: int = 0, + ): + """ + One-sided put operation with optional wait (POLL) and signal (ATOMIC). + + Supports: + - Simple copy: put(src, dst_rank) + - Copy + signal: put(src, dst_rank, signal_flag=flag) + - Wait + copy: put(src, dst_rank, wait_flag=flag, wait_value=N) + - Wait + copy + signal: put(src, dst_rank, wait_flag=..., signal_flag=...) + + Args: + src_tensor: Source tensor (local, must be symmetric) + dst_rank: Destination rank + dst_tensor: Destination tensor (symmetric). If None, uses src_tensor. + wait_flag: Optional LOCAL flag tensor to poll before transfer (POLL packet) + wait_value: Expected value for wait_flag + signal_flag: Optional flag tensor to atomic-add on REMOTE rank after transfer (will be translated) + signal_value: Value to add to signal_flag (default 1) + async_op: If True, don't wait for completion + channel: SDMA channel to use + + Examples: + >>> # Simple copy + >>> shmem.put(data, dst_rank=1) + + >>> # Copy with completion signal + >>> shmem.put(data, dst_rank=1, signal_flag=completion_flag) + + >>> # Wait for ready signal, then copy + >>> shmem.put(data, dst_rank=1, wait_flag=ready_flag, wait_value=1) + + >>> # Full pipeline: wait, copy, signal + >>> shmem.put(data, dst_rank=1, + ... wait_flag=batch_ready, wait_value=256, + ... signal_flag=transfer_done, signal_value=1) + """ + if dst_tensor is None: + dst_tensor = src_tensor + + src_rank = self.get_rank() + src_ptr = src_tensor.data_ptr() + dst_ptr = self.translate(dst_tensor.data_ptr(), src_rank, dst_rank) + size = src_tensor.numel() * src_tensor.element_size() + + # Determine which SDMA packet combination to use + has_wait = wait_flag is not None + has_signal = signal_flag is not None + + if has_wait and has_signal: + # POLL + COPY + ATOMIC (two submissions) + wait_ptr = wait_flag.data_ptr() + signal_ptr = self.translate(signal_flag.data_ptr(), src_rank, dst_rank) + + # First: POLL + COPY + self.copy_engines.host_wait_flag_then_put( + src_rank, dst_rank, channel, wait_ptr, wait_value, src_ptr, dst_ptr, size + ) + # Then: ATOMIC + self.copy_engines.host_atomic_add(src_rank, dst_rank, channel, signal_ptr, signal_value) + + elif has_wait: + # POLL + COPY + wait_ptr = wait_flag.data_ptr() + self.copy_engines.host_wait_flag_then_put( + src_rank, dst_rank, channel, wait_ptr, wait_value, src_ptr, dst_ptr, size + ) + + elif has_signal: + # COPY + ATOMIC (combined in one submission) + signal_ptr = self.translate(signal_flag.data_ptr(), src_rank, dst_rank) + self.copy_engines.host_put_signal( + src_rank, dst_rank, channel, src_ptr, dst_ptr, size, signal_ptr, signal_value + ) + + else: + # Simple COPY + self.copy_engines.host_put(src_rank, dst_rank, channel, src_ptr, dst_ptr, size) + + if not async_op: + self.copy_engines.host_quiet(src_rank, dst_rank, channel) + + def put_tile( + self, + tile, + dst_rank: int, + dst_ptr: int, + dst_stride: int, + wait_flag: int = None, + wait_value: int = None, + signal_flag: int = None, + signal_value: int = 1, + async_op: bool = False, + channel: int = 0, + ): + """ + 2D tile transfer with optional wait/signal (sub-window copy). + + Low-level API - caller provides pre-translated pointers for performance. + + Args: + tile: Pre-configured anvil.Tile object with data pointer and dimensions set + dst_rank: Destination rank + dst_ptr: Destination pointer (already translated to remote address space) + dst_stride: Destination row stride in bytes + wait_flag: Optional LOCAL flag pointer to poll before transfer + wait_value: Expected value for wait_flag + signal_flag: Optional REMOTE flag pointer to atomic-add after transfer (already translated) + signal_value: Value to add to signal_flag + async_op: If True, don't wait for completion + channel: SDMA channel to use + + Examples: + >>> import anvil + >>> tile = anvil.Tile() + >>> tile.pid_m = 0 + >>> tile.pid_n = 0 + >>> tile.block_m = 256 + >>> tile.block_n = 256 + >>> tile.elem_size = A.element_size() + >>> tile.src_stride = A.stride(0) * tile.elem_size + >>> tile.data = A.data_ptr() + >>> dst_ptr = shmem.translate(A.data_ptr(), src_rank, dst_rank) + >>> dst_stride = A.stride(0) * tile.elem_size + >>> wait_ptr = flag.data_ptr() + >>> signal_ptr = shmem.translate(flag.data_ptr(), src_rank, dst_rank) + >>> shmem.put_tile(tile, dst_rank=1, dst_ptr=dst_ptr, dst_stride=dst_stride, + ... wait_flag=wait_ptr, wait_value=256, signal_flag=signal_ptr) + """ + src_rank = self.get_rank() + + has_wait = wait_flag is not None + has_signal = signal_flag is not None + + if has_wait and has_signal: + # POLL + SUB_WINDOW_COPY + ATOMIC (two submissions) + self.copy_engines.host_wait_flag_then_put_tile( + src_rank, dst_rank, channel, wait_flag, wait_value, tile, dst_ptr, dst_stride + ) + self.copy_engines.host_atomic_add_32(src_rank, dst_rank, channel, signal_flag, signal_value) + + elif has_wait: + # POLL + SUB_WINDOW_COPY + self.copy_engines.host_wait_flag_then_put_tile( + src_rank, dst_rank, channel, wait_flag, wait_value, tile, dst_ptr, dst_stride + ) + + elif has_signal: + # SUB_WINDOW_COPY + ATOMIC + self.copy_engines.host_put_tile_signal( + src_rank, dst_rank, channel, tile, dst_ptr, dst_stride, signal_flag, signal_value + ) + + else: + # Simple SUB_WINDOW_COPY + self.copy_engines.host_put_tile(src_rank, dst_rank, channel, tile, dst_ptr, dst_stride) + + if not async_op: + self.copy_engines.host_quiet(src_rank, dst_rank, channel) + + def put_tiles( + self, + tiles, + dst_rank: int, + dst_ptrs, + dst_strides, + wait_flag: int = None, + wait_value: int = None, + signal_flag: int = None, + signal_value: int = 1, + async_op: bool = False, + channel: int = 0, + ): + """ + Batched 2D tile transfer with optional shared wait/signal. + + Args: + tiles: Sequence of pre-configured anvil.Tile objects + dst_rank: Destination rank + dst_ptrs: Sequence of translated destination pointers + dst_strides: Sequence of destination row strides in bytes + wait_flag: Optional LOCAL flag pointer to poll before all transfers + wait_value: Expected value for wait_flag + signal_flag: Optional REMOTE flag pointer to atomic-add after all transfers + signal_value: Value to add to signal_flag + async_op: If True, don't wait for completion + channel: SDMA channel to use + """ + src_rank = self.get_rank() + + if len(tiles) != len(dst_ptrs) or len(tiles) != len(dst_strides): + raise ValueError("tiles, dst_ptrs, and dst_strides must have the same length") + + has_wait = wait_flag is not None + has_signal = signal_flag is not None + + if has_wait: + self.copy_engines.host_wait_flag_then_put_tiles( + src_rank, dst_rank, channel, wait_flag, wait_value, tiles, dst_ptrs, dst_strides + ) + if has_signal: + self.copy_engines.host_atomic_add_32(src_rank, dst_rank, channel, signal_flag, signal_value) + else: + for tile, dst_ptr, dst_stride in zip(tiles, dst_ptrs, dst_strides): + self.put_tile( + tile, + dst_rank=dst_rank, + dst_ptr=dst_ptr, + dst_stride=dst_stride, + signal_flag=None, + async_op=True, + channel=channel, + ) + if has_signal: + self.copy_engines.host_atomic_add_32(src_rank, dst_rank, channel, signal_flag, signal_value) + + if not async_op: + self.copy_engines.host_quiet(src_rank, dst_rank, channel) + + def quiet(self, dst_rank: int = None, channel: int = 0): + """ + Wait for all outstanding SDMA operations to complete. + + Args: + dst_rank: If specified, wait only for ops to this rank. + If None, wait for ops to all ranks. + channel: SDMA channel + + Example: + >>> shmem.put(tensor, dst_rank=1, async_op=True) + >>> shmem.quiet(dst_rank=1) # Wait for completion + >>> shmem.quiet() # Wait for all ranks + """ + src_rank = self.get_rank() + if dst_rank is not None: + self.copy_engines.host_quiet(src_rank, dst_rank, channel) + else: + # Quiet to all ranks + for rank in range(self.get_num_ranks()): + self.copy_engines.host_quiet(src_rank, rank, channel) + def _build_device_context(self): """ Build and cache the device context tensor. @@ -975,7 +1295,7 @@ def get_device_context(self): """ return self._device_context - def barrier(self, stream=None, group=None): + def barrier(self, stream=None, group=None, sync_copy_engine=False): """ Synchronize ranks within the specified group and their CUDA devices. @@ -987,11 +1307,14 @@ def barrier(self, stream=None, group=None): stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). group (ProcessGroup, optional): The process group to synchronize. If None, uses the default process group (all ranks). + sync_copy_engine (bool, optional): If True, also wait for all outstanding SDMA operations to complete. + Default is False. Example: >>> ctx = iris.iris(1 << 20) >>> ctx.barrier() # Synchronize all ranks >>> ctx.barrier(group=my_group) # Synchronize only ranks in my_group + >>> ctx.barrier(sync_copy_engine=True) # Synchronize GPU + SDMA """ # Wait for all GPUs to finish work if stream is None: @@ -999,6 +1322,10 @@ def barrier(self, stream=None, group=None): else: stream.synchronize() + # Wait for SDMA operations if requested + if sync_copy_engine: + self.quiet() + # Distributed barrier distributed_barrier(group=group) @@ -1331,6 +1658,18 @@ def __translate(ptr, from_rank, to_rank, heap_bases, hint: tl.constexpr = None): return translated_ptr +@triton.jit +def translate_ptr(ptr, from_rank, to_rank, heap_bases, hint: tl.constexpr = None): + """ + Public device-side pointer translation helper. + + This is a thin wrapper around the internal translation routine so Triton + kernels importing the top-level `iris` package can access address-space + translation without depending on a private symbol name. + """ + return __translate(ptr, from_rank, to_rank, heap_bases, hint) + + @aggregate class DeviceContext: """ @@ -1581,6 +1920,7 @@ def store(self, pointer, value, to_rank, mask=None, cache_modifier=None, hint: t value (Block): The tensor of elements to be stored. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. @@ -1623,6 +1963,7 @@ def get( to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer to local memory in current rank where the data will be written. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. @@ -1672,6 +2013,7 @@ def put( to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references memory in `to_rank`. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. @@ -2259,11 +2601,20 @@ def put( from_rank, to_rank, heap_bases, + copy_engine_ctx: tl.tensor, + stride_tm: tl.constexpr = 0, + stride_tn: tl.constexpr = 0, + stride_fm: tl.constexpr = 0, + stride_fn: tl.constexpr = 0, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None, hint: tl.constexpr = None, + USE_COPY_ENGINE: tl.constexpr = False, + IS_2D_COPY: tl.constexpr = False, + from_base_ptr=None, + to_base_ptr=None, ): """ Copies data from the current rank's local memory to the specified rank's memory. @@ -2271,15 +2622,24 @@ def put( rank's `from_ptr`, translating the `to_ptr` from the current rank's address space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. + Supports both 1D (flat/linear) and 2D (tiled) copies: + - 1D copies: Used when stride_tm == 0 and stride_fm == 0 (default), uses linear SDMA packets + - 2D copies: Used when strides are non-zero, uses sub-window SDMA packets for better performance The load is **always local** (reading from the current rank's own ``from_ptr``), while the store is **remote** when ``from_rank != to_rank`` (writing to a peer GPU). Args: from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. - to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. from_rank (int): The current rank ID from which to read the data. - to_rank (int): The `to_rank` ID to which the data will be written. + to_rank (int): The rank ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + copy_engine_ctx (tl.tensor): Copy engine context for SDMA operations. + stride_tm (int, optional): Stride in M dimension for destination buffer (in elements). Default: 0 (flat copy). + stride_tn (int, optional): Stride in N dimension for destination buffer (in elements). Default: 0. + stride_fm (int, optional): Stride in M dimension for source buffer (in elements). Default: 0 (flat copy). + stride_fn (int, optional): Stride in N dimension for source buffer (in elements). Default: 0. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load/copy data at that index. Defaults to None. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. @@ -2296,27 +2656,270 @@ def put( - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). + USE_COPY_ENGINE (tl.constexpr, optional): Whether to use SDMA copy engine. Defaults to False (uses regular load/store). + from_base_ptr (triton.PointerType, optional): Base pointer of the source buffer. Required for 2D copies when USE_COPY_ENGINE is True. + to_base_ptr (triton.PointerType, optional): Base pointer of the destination buffer. Required for 2D copies when USE_COPY_ENGINE is True. Returns: None - Example: + Examples: + 1D (flat) copy: >>> @triton.jit - >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> def kernel(local_ptr, remote_ptr, heap_bases, copy_engine_ctx): >>> from_rank = 0 >>> to_rank = 1 - >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) + >>> offsets = tl.arange(0, 256) + >>> iris.put(local_ptr + offsets, remote_ptr + offsets, + >>> from_rank, to_rank, heap_bases, copy_engine_ctx, + >>> mask=offsets < 256, USE_COPY_ENGINE=True) + + 2D (tiled) copy: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases, copy_engine_ctx, base_ptr): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, + >>> stride_tm=1024, stride_fm=1024, + >>> mask=mask, USE_COPY_ENGINE=True, + >>> from_base_ptr=base_ptr, to_base_ptr=base_ptr) """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) - data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + if not USE_COPY_ENGINE: + data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + + tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) + else: + ctx = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + # dst_ptr_val = tl.min(translated_to_ptr.to(tl.uint64), axis=-1) + dst_ptr_val0 = tl.min(translated_to_ptr.to(tl.uint64)) + # Extract source address (min of pointer block where data is stored) + src_ptr_u64 = from_ptr.to(tl.uint64) + # src_ptr_val = tl.min(src_ptr_u64, axis=-1) + src_ptr_val0 = tl.min(src_ptr_u64) + # max_src_ptr = tl.max(src_ptr_u64, axis=0) + + # Infer element size from pointer type + # src_ptr is a block of pointers with a specific element type (e.g., pointer) + # The pointer dtype tells us the element type, which has a known size + # Map Triton dtypes to their byte sizes + ptr_dtype = from_ptr.dtype.element_ty # Get the element type that the pointer points to + + # Get element size in bytes from the dtype + # tl.float16 -> 2, tl.float32 -> 4, tl.float64 -> 8, etc. + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + element_size_bytes = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32 or ptr_dtype == tl.uint32: + element_size_bytes = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64 or ptr_dtype == tl.uint64: + element_size_bytes = 8 + elif ptr_dtype == tl.int8 or ptr_dtype == tl.uint8: + element_size_bytes = 1 + elif ptr_dtype == tl.int16 or ptr_dtype == tl.uint16: + element_size_bytes = 2 + else: + # Default to 4 bytes for unknown types + element_size_bytes = 4 + + # Determine packet size based on copy type + # Linear copy packet: 32 bytes for 1D, Sub-window copy packet: 80 bytes for 2D + # IS_2D_COPY is a compile-time constant for proper branch elimination + mask_int = mask.to(tl.int32) + command_in_bytes_u32 = 80 if IS_2D_COPY else 32 + command_in_bytes = command_in_bytes_u32.to(tl.uint64) + + # Acquire space in the queue + base, offset = anvil.acquire_fadd( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) + + # Write padding NOPs if we wrapped around + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + # Place the appropriate packet type + packet_offset_bytes = base + offset + + if not IS_2D_COPY: + # For 1D copies, mask is 1D, so just sum all elements + num_elements = tl.sum(mask_int, axis=0) + size_bytes = (num_elements * element_size_bytes).to(tl.uint32) + + # Place linear copy packet for 1D/flat copies + anvil.place_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + size_bytes, + src_ptr_val0, + dst_ptr_val0, + ) + else: + # For 2D copies, mask is 2D [M, N], use axis operations + num_elements_per_stride = tl.max(tl.sum(mask_int, axis=-1)) + num_strides = tl.max(tl.sum(mask_int, axis=0)) + size_bytes = (num_elements_per_stride * element_size_bytes).to(tl.uint32) + src_stride = (stride_fm * element_size_bytes).to(tl.uint32) + dst_stride = (stride_tm * element_size_bytes).to(tl.uint32) + + # Place sub-window copy packet for 2D tiled copies + # Calculate base addresses and offsets for sub-window copy + src_base = from_base_ptr.to(tl.uint64) + dst_base = __translate(to_base_ptr, from_rank, to_rank, heap_bases).to(tl.uint64) + + # Calculate tile offset from base + tile_offset_bytes = src_ptr_val0 - src_base + src_y_val = (tile_offset_bytes // src_stride).to(tl.uint32) + src_x_val = (tile_offset_bytes % src_stride).to(tl.uint32) + + tile_offset_bytes_dst = dst_ptr_val0 - dst_base + dst_y_val = (tile_offset_bytes_dst // dst_stride).to(tl.uint32) + dst_x_val = (tile_offset_bytes_dst % dst_stride).to(tl.uint32) + + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + src_base, + dst_base, + tile_width=size_bytes, + tile_height=num_strides, + src_buffer_pitch=src_stride, + dst_buffer_pitch=dst_stride, + src_x=src_x_val, + src_y=src_y_val, + dst_x=dst_x_val, + dst_y=dst_y_val, + ) + + # Submit the command to the queue + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) - tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) + +@triton.jit +def nontemporal_store(addr, value): + tl.inline_asm_elementwise( + asm="""flat_store_dwordx2 $1 $2 sc0 nt; s_waitcnt vmcnt(0)""", + constraints=("=r,v,v"), # =r used for dummy return to satisfy compiler requirement + args=[addr, value], + dtype=tl.int32, # return not used + is_pure=False, + pack=1, + ) + + +# TODO rename or add nt +@triton.jit +def nontemporal_load(addr): + val = tl.inline_asm_elementwise( + asm="""flat_load_dwordx2 $0 $1 sc0 sc1; s_waitcnt vmcnt(0)""", + constraints=("=v,v"), + args=[addr], + dtype=tl.uint64, + is_pure=False, + pack=1, + ) + return val + + +@triton.jit +def nontemporal_atomic_add(addr, value): + old = tl.inline_asm_elementwise( + asm="""flat_atomic_add_x2 $0 $1 sc0 sc1; s_waitcnt vmcnt(0)""", + constraints=("=v,v,v"), + args=[addr, value], + dtype=tl.uint64, + is_pure=False, + pack=1, + ) + return old + + +# @triton.jit +# def nontemporal_compare_exchange(addr, cmp_low, cmp_high, val_low, val_high): +# # data_128bit = tl.cat([cmp_low, cmp_high, val_low, val_high]) +# data_128bit = tl.make_vector([cmp_low, cmp_high, val_low, val_high], type=tl.uint32) +# old = tl.inline_asm_elementwise( +# asm="""flat_atomic_cmpswap_x2 $0 $1 $2 sc0 nt; s_waitcnt vmcnt(0)""", +# constraints=("=v,v,v"), +# args=[addr, data_128bit], +# dtype=tl.uint64, +# is_pure=False, +# pack=1, +# ) +# return True # TODO if old == cmp else False + + +# @triton.jit +# def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): +# """ +# Copies data from the current rank's local memory to the specified rank's memory. +# This function performs a memory write operation by loading data from the current +# rank's `from_ptr`, translating the `to_ptr` from the current rank's address +# space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. +# If the `to_rank` is the same as the current rank, this function performs a local copy operation. + +# Args: +# from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. +# to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. +# from_rank (int): The current rank ID from which to read the data. +# to_rank (int): The `to_rank` ID to which the data will be written. +# heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. +# mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + +# Returns: +# None + +# Example: +# >>> @triton.jit +# >>> def kernel(local_ptr, remote_ptr, heap_bases): +# >>> from_rank = 0 +# >>> to_rank = 1 +# >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) +# """ + +# handle = ce_handle # iris.get_copy_engine_handle(to_rank) +# queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) +# read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) +# write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) +# doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) +# cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) +# committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) + +# translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) +# dst_ptr_val = translated_to_ptr.to(tl.uint64) + +# command_in_bytes = 32 +# # Acquire space +# base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) + +# # Place command packet +# slot_ptr_u32 = queue_ptr_u32 + (base // 4) +# anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) + +# # Submit command +# anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) @triton.jit def atomic_add( - pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None + pointer, + val, + from_rank, + to_rank, + heap_bases, + mask=None, + sem=None, + scope=None, + hint: tl.constexpr = None, + copy_engine_ctx=None, + USE_COPY_ENGINE: tl.constexpr = False, ): """ Performs an atomic add at the specified rank's memory location. @@ -2350,7 +2953,646 @@ def atomic_add( >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) - return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + if not USE_COPY_ENGINE: + return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + else: + handle = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) + + dst_ptr_val = translated_ptr.to(tl.uint64) + + command_in_bytes = 32 + # Acquire space (returns base index and wraparound offset) + base, offset = anvil.acquire_fadd( + # base = anvil.acquire( + queue_ptr_u32, + read_ptr, + write_ptr, + doorbell_ptr, + cached_write_ptr, + committed_write_ptr, + command_in_bytes, + ) + # tl.device_print("offset ", offset) + + # Write padding NOPs if we wrapped around + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + # Calculate packet position (base + offset for wraparound) + packet_offset_bytes = base + offset + + # Place command packet + anvil.place_atomic_packet(queue_ptr_u32, packet_offset_bytes, dst_ptr_val, val) + + # Submit command + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + + +@triton.jit +def put_signal( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + copy_engine_ctx: tl.tensor, + flag_ptr, + flag_value, + stride_tm: tl.constexpr, + stride_fm: tl.constexpr, + mask=None, + hint: tl.constexpr = None, + from_base_ptr=None, + to_base_ptr=None, +): + """ + Combines 2D copy (put) with atomic_add signal in one SDMA submission. + + This is equivalent to calling put() followed by atomic_add(), but batches both + operations into a single SDMA queue submission for better performance. + + Args: + from_ptr: Source pointer in current rank's local memory + to_ptr: Destination pointer (will be translated to to_rank's address space) + from_rank: Current rank ID + to_rank: Remote rank ID to write to + heap_bases: Array of heap base addresses for all ranks + copy_engine_ctx: Copy engine context for SDMA operations + flag_ptr: Pointer to flag location for signaling + flag_value: Value to atomically add to flag (typically 1) + stride_tm: Destination row stride in elements + stride_fm: Source row stride in elements + mask: 2D mask indicating which elements to copy + hint: Vectorization hint for translated pointers + from_base_ptr: Base pointer of source buffer (required for 2D) + to_base_ptr: Base pointer of destination buffer (required for 2D) + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(local_A, remote_staged_a, flags, heap_bases, ctx, base_ptr): + >>> # Copy tile and signal completion + >>> m_offs = tl.arange(0, 256)[:, None] + >>> k_offs = tl.arange(0, 64)[None, :] + >>> mask = (m_offs < 256) & (k_offs < 64) + >>> iris.put_signal(local_A + offsets, remote_staged_a + offsets, + >>> from_rank=0, to_rank=1, heap_bases=heap_bases, + >>> copy_engine_ctx=ctx, flag_ptr=flags, flag_value=1, + >>> stride_tm=1024, stride_fm=64, mask=mask, + >>> from_base_ptr=base_ptr, to_base_ptr=base_ptr) + """ + # Translate destination pointer + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) + translated_flag_ptr = __translate(flag_ptr, from_rank, to_rank, heap_bases, hint) + + # Get copy engine context for target rank + ctx = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + # Extract addresses + dst_ptr_val = tl.min(translated_to_ptr.to(tl.uint64)) + src_ptr_u64 = from_ptr.to(tl.uint64) + src_ptr_val = tl.min(src_ptr_u64) + + # Get element size from pointer type + ptr_dtype = from_ptr.dtype.element_ty + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + element_size_bytes = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32 or ptr_dtype == tl.uint32: + element_size_bytes = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64 or ptr_dtype == tl.uint64: + element_size_bytes = 8 + elif ptr_dtype == tl.int8 or ptr_dtype == tl.uint8: + element_size_bytes = 1 + elif ptr_dtype == tl.int16 or ptr_dtype == tl.uint16: + element_size_bytes = 2 + else: + element_size_bytes = 4 + + # Reserve space for BOTH packets: SUB_WINDOW_COPY (80 bytes) + ATOMIC (32 bytes) = 112 bytes + command_in_bytes = 112 + + # Acquire space in queue + base, offset = anvil.acquire_fadd( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) + + # Write padding NOPs if we wrapped around + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + packet_offset_bytes = base + offset + + # Calculate 2D copy parameters from mask + mask_int = mask.to(tl.int32) + num_elements_per_stride = tl.max(tl.sum(mask_int, axis=-1)) + num_strides = tl.max(tl.sum(mask_int, axis=0)) + size_bytes = (num_elements_per_stride * element_size_bytes).to(tl.uint32) + src_stride = (stride_fm * element_size_bytes).to(tl.uint32) + dst_stride = (stride_tm * element_size_bytes).to(tl.uint32) + + # Calculate base addresses and offsets + src_base = from_base_ptr.to(tl.uint64) + dst_base = __translate(to_base_ptr, from_rank, to_rank, heap_bases).to(tl.uint64) + + tile_offset_bytes = src_ptr_val - src_base + src_y_val = (tile_offset_bytes // src_stride).to(tl.uint32) + src_x_val = (tile_offset_bytes % src_stride).to(tl.uint32) + + tile_offset_bytes_dst = dst_ptr_val - dst_base + dst_y_val = (tile_offset_bytes_dst // dst_stride).to(tl.uint32) + dst_x_val = (tile_offset_bytes_dst % dst_stride).to(tl.uint32) + + # Place SUB_WINDOW_COPY packet (80 bytes) + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + src_base, + dst_base, + tile_width=size_bytes, + tile_height=num_strides, + src_buffer_pitch=src_stride, + dst_buffer_pitch=dst_stride, + src_x=src_x_val, + src_y=src_y_val, + dst_x=dst_x_val, + dst_y=dst_y_val, + ) + + # Place ATOMIC packet immediately after (32 bytes) + atomic_offset_bytes = packet_offset_bytes + 80 + flag_dst_ptr_val = translated_flag_ptr.to(tl.uint64) + anvil.place_atomic_packet(queue_ptr_u32, atomic_offset_bytes, flag_dst_ptr_val, flag_value) + + # Submit both packets in one doorbell ring + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + + +@triton.jit +def put_signal_rect( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + copy_engine_ctx: tl.tensor, + flag_ptr, + flag_value, + width_bytes: tl.constexpr, + height: tl.constexpr, + src_pitch: tl.constexpr, + dst_pitch: tl.constexpr, + hint: tl.constexpr = None, +): + """ + Combines 2D rectangular copy (put) with atomic_add signal in one SDMA submission. + + Unlike put_signal(), this function takes explicit dimensions instead of a mask, + allowing arbitrarily large tiles without hitting Triton's tensor size limit. + + Args: + from_ptr: Source base pointer (scalar) in current rank's local memory + to_ptr: Destination base pointer (scalar) - will be translated to to_rank's address space + from_rank: Current rank ID + to_rank: Remote rank ID to write to + heap_bases: Array of heap base addresses for all ranks + copy_engine_ctx: Copy engine context for SDMA operations + flag_ptr: Pointer to flag location for signaling + flag_value: Value to atomically add to flag (typically 1) + width_bytes: Width of rectangle in bytes + height: Height of rectangle in rows + src_pitch: Source row stride in bytes + dst_pitch: Destination row stride in bytes + hint: Vectorization hint for translated pointers + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(A_sharded, staged_a, flags, heap_bases, ctx): + >>> # Transfer 256 rows × 1024 bytes (128 elements × 4 bytes × 2 K-blocks) + >>> src_ptr = A_sharded + m_offset * stride_am + k_offset * stride_ak + >>> dst_ptr = staged_a + m_offset * stride_sa_m + k_offset * stride_sa_k + >>> iris.put_signal_rect( + >>> src_ptr, dst_ptr, 0, 1, heap_bases, ctx, flags, 1, + >>> width_bytes=1024, height=256, + >>> src_pitch=stride_am * 2, dst_pitch=stride_sa_m * 2 + >>> ) + """ + # Translate destination pointers + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) + translated_flag_ptr = __translate(flag_ptr, from_rank, to_rank, heap_bases, hint) + + # Get copy engine context for target rank + ctx = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + # Extract addresses (scalar pointers) + src_ptr_val = from_ptr.to(tl.uint64) + dst_ptr_val = translated_to_ptr.to(tl.uint64) + flag_dst_ptr_val = translated_flag_ptr.to(tl.uint64) + + # Reserve space for BOTH packets: SUB_WINDOW_COPY (80 bytes) + ATOMIC (32 bytes) = 112 bytes + command_in_bytes = 112 + + # Acquire space in queue + base, offset = anvil.acquire_fadd( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) + + # Write padding NOPs if we wrapped around + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + packet_offset_bytes = base + offset + + # Place SUB_WINDOW_COPY packet (80 bytes) + # Using base pointers directly (no offset calculation needed) + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + src_ptr_val, + dst_ptr_val, + tile_width=width_bytes, + tile_height=height, + src_buffer_pitch=src_pitch, + dst_buffer_pitch=dst_pitch, + src_x=0, # Offset already baked into pointers + src_y=0, + dst_x=0, + dst_y=0, + ) + + # Place ATOMIC packet immediately after (32 bytes) + atomic_offset_bytes = packet_offset_bytes + 80 + anvil.place_atomic_packet(queue_ptr_u32, atomic_offset_bytes, flag_dst_ptr_val, flag_value) + + # Submit both packets in one doorbell ring + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + + +@triton.jit +def wait_then_put_rect( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + copy_engine_ctx: tl.tensor, + wait_flag_ptr, + wait_value, + width_bytes: tl.constexpr, + height: tl.constexpr, + src_pitch: tl.constexpr, + dst_pitch: tl.constexpr, + hint: tl.constexpr = None, +): + """ + Enqueue a POLL_REGMEM followed by a 2D SUB_WINDOW_COPY in one SDMA submission. + + This is the device-side counterpart to the host wait-then-put-tile path. + The SDMA queue waits on a local flag and performs the copy autonomously + after the producer has completed the corresponding batch. + """ + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) + + ctx = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + poll_packet_bytes = 24 + copy_packet_bytes = 80 + command_in_bytes = poll_packet_bytes + copy_packet_bytes + + base, offset = anvil.acquire_fadd( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + packet_offset_bytes = base + offset + anvil.place_poll_regmem_packet( + queue_ptr_u32, + packet_offset_bytes, + wait_flag_ptr.to(tl.uint64), + wait_value, + ) + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + packet_offset_bytes + poll_packet_bytes, + from_ptr.to(tl.uint64), + translated_to_ptr.to(tl.uint64), + tile_width=width_bytes, + tile_height=height, + src_buffer_pitch=src_pitch, + dst_buffer_pitch=dst_pitch, + src_x=0, + src_y=0, + dst_x=0, + dst_y=0, + ) + + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + + +@triton.jit +def wait_then_put_rects( + from_base_ptr, + to_base_ptr, + from_rank, + to_rank, + heap_bases, + copy_engine_ctx: tl.tensor, + wait_flag_ptr, + wait_value, + transfer_row_offsets, + transfer_col_offsets, + transfer_width_bytes, + transfer_heights, + transfer_start, + transfer_count, + stride_n_bytes, + src_pitch: tl.constexpr, + dst_pitch: tl.constexpr, + MAX_RECTS: tl.constexpr, + hint: tl.constexpr = None, +): + """ + Enqueue one POLL_REGMEM followed by many 2D SUB_WINDOW_COPY packets. + + The copy list is provided as flattened metadata arrays plus a per-wave + start/count pair so the poster can submit an entire wave with one queue + reservation and one doorbell ring. + """ + translated_to_base_ptr = __translate(to_base_ptr, from_rank, to_rank, heap_bases, hint) + + ctx = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + poll_packet_bytes = 24 + copy_packet_bytes = 80 + command_in_bytes = poll_packet_bytes + transfer_count * copy_packet_bytes + + base, offset = anvil.acquire_fadd( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + packet_offset_bytes = base + offset + anvil.place_poll_regmem_packet( + queue_ptr_u32, + packet_offset_bytes, + wait_flag_ptr.to(tl.uint64), + wait_value, + ) + + from_base_val = from_base_ptr.to(tl.uint64) + to_base_val = translated_to_base_ptr.to(tl.uint64) + + for i in range(MAX_RECTS): + if i < transfer_count: + transfer_idx = transfer_start + i + row_offset = tl.load(transfer_row_offsets + transfer_idx) + col_offset = tl.load(transfer_col_offsets + transfer_idx) + width_bytes = tl.load(transfer_width_bytes + transfer_idx) + height = tl.load(transfer_heights + transfer_idx) + byte_offset = (row_offset.to(tl.uint64) * src_pitch) + (col_offset.to(tl.uint64) * stride_n_bytes) + copy_offset_bytes = packet_offset_bytes + poll_packet_bytes + i * copy_packet_bytes + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + copy_offset_bytes, + from_base_val + byte_offset, + to_base_val + byte_offset, + tile_width=width_bytes, + tile_height=height, + src_buffer_pitch=src_pitch, + dst_buffer_pitch=dst_pitch, + src_x=0, + src_y=0, + dst_x=0, + dst_y=0, + ) + + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + + +@triton.jit +def wait_then_put_signal_rect( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + copy_engine_ctx: tl.tensor, + wait_flag_ptr, + wait_value, + signal_flag_ptr, + signal_value, + width_bytes: tl.constexpr, + height: tl.constexpr, + src_pitch: tl.constexpr, + dst_pitch: tl.constexpr, + hint: tl.constexpr = None, +): + """ + Enqueue POLL_REGMEM + 2D SUB_WINDOW_COPY + ATOMIC in one SDMA submission. + + This is the device-side counterpart to host-side wait/copy/signal flows and is + useful for marking receiver-visible completion after the final copy in a queue. + """ + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) + translated_signal_ptr = __translate(signal_flag_ptr, from_rank, to_rank, heap_bases, hint) + + ctx = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + poll_packet_bytes = 24 + copy_packet_bytes = 80 + atomic_packet_bytes = 32 + command_in_bytes = poll_packet_bytes + copy_packet_bytes + atomic_packet_bytes + + base, offset = anvil.acquire_fadd( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + packet_offset_bytes = base + offset + anvil.place_poll_regmem_packet( + queue_ptr_u32, + packet_offset_bytes, + wait_flag_ptr.to(tl.uint64), + wait_value, + ) + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + packet_offset_bytes + poll_packet_bytes, + from_ptr.to(tl.uint64), + translated_to_ptr.to(tl.uint64), + tile_width=width_bytes, + tile_height=height, + src_buffer_pitch=src_pitch, + dst_buffer_pitch=dst_pitch, + src_x=0, + src_y=0, + dst_x=0, + dst_y=0, + ) + anvil.place_atomic_packet( + queue_ptr_u32, + packet_offset_bytes + poll_packet_bytes + copy_packet_bytes, + translated_signal_ptr.to(tl.uint64), + signal_value, + ) + + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + + +@triton.jit +def wait_then_put_signal_rects( + from_base_ptr, + to_base_ptr, + from_rank, + to_rank, + heap_bases, + copy_engine_ctx: tl.tensor, + wait_flag_ptr, + wait_value, + signal_flag_ptr, + signal_value, + transfer_row_offsets, + transfer_col_offsets, + transfer_width_bytes, + transfer_heights, + transfer_start, + transfer_count, + stride_n_bytes, + src_pitch: tl.constexpr, + dst_pitch: tl.constexpr, + MAX_RECTS: tl.constexpr, + hint: tl.constexpr = None, +): + """ + Enqueue one POLL_REGMEM, many 2D SUB_WINDOW_COPY packets, and one ATOMIC. + """ + translated_to_base_ptr = __translate(to_base_ptr, from_rank, to_rank, heap_bases, hint) + translated_signal_ptr = __translate(signal_flag_ptr, from_rank, to_rank, heap_bases, hint) + + ctx = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + poll_packet_bytes = 24 + copy_packet_bytes = 80 + atomic_packet_bytes = 32 + command_in_bytes = poll_packet_bytes + transfer_count * copy_packet_bytes + atomic_packet_bytes + + base, offset = anvil.acquire_fadd( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + packet_offset_bytes = base + offset + anvil.place_poll_regmem_packet( + queue_ptr_u32, + packet_offset_bytes, + wait_flag_ptr.to(tl.uint64), + wait_value, + ) + + from_base_val = from_base_ptr.to(tl.uint64) + to_base_val = translated_to_base_ptr.to(tl.uint64) + for i in range(MAX_RECTS): + if i < transfer_count: + transfer_idx = transfer_start + i + row_offset = tl.load(transfer_row_offsets + transfer_idx) + col_offset = tl.load(transfer_col_offsets + transfer_idx) + width_bytes = tl.load(transfer_width_bytes + transfer_idx) + height = tl.load(transfer_heights + transfer_idx) + byte_offset = (row_offset.to(tl.uint64) * src_pitch) + (col_offset.to(tl.uint64) * stride_n_bytes) + copy_offset_bytes = packet_offset_bytes + poll_packet_bytes + i * copy_packet_bytes + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + copy_offset_bytes, + from_base_val + byte_offset, + to_base_val + byte_offset, + tile_width=width_bytes, + tile_height=height, + src_buffer_pitch=src_pitch, + dst_buffer_pitch=dst_pitch, + src_x=0, + src_y=0, + dst_x=0, + dst_y=0, + ) + + anvil.place_atomic_packet( + queue_ptr_u32, + packet_offset_bytes + poll_packet_bytes + transfer_count * copy_packet_bytes, + translated_signal_ptr.to(tl.uint64), + signal_value, + ) + + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + + +@triton.jit +def quiet(copy_engine_ctx: tl.tensor, to_rank): + """ + Device-side equivalent of host_quiet for a single destination queue. + + Waits until the hardware read pointer catches up to the queue's committed + write pointer, meaning all packets submitted to that SDMA queue have + completed. + """ + ctx = copy_engine_ctx + (6 * to_rank) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + target_wptr = tl.load(committed_write_ptr, cache_modifier=".cv", volatile=True) + while tl.load(read_ptr, cache_modifier=".cv", volatile=True) != target_wptr: + pass + + # tl.debug_barrier() @triton.jit diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py index e0d12ba51..647ff91de 100644 --- a/iris/ops/__init__.py +++ b/iris/ops/__init__.py @@ -36,6 +36,7 @@ # from .matmul import matmul # Simple single-GPU GEMM - TODO: implement from .matmul_all_reduce import matmul_all_reduce, matmul_all_reduce_preamble from .all_gather_matmul import all_gather_matmul, all_gather_matmul_preamble +from .all_gather_matmul_hbm_buffer import all_gather_matmul_hbm_buffer, all_gather_matmul_hbm_buffer_preamble from .matmul_all_gather import matmul_all_gather from .matmul_reduce_scatter import matmul_reduce_scatter, matmul_reduce_scatter_preamble @@ -180,6 +181,8 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, "matmul_all_reduce_preamble", "all_gather_matmul", "all_gather_matmul_preamble", + "all_gather_matmul_hbm_buffer", + "all_gather_matmul_hbm_buffer_preamble", "matmul_all_gather", "matmul_reduce_scatter", "matmul_reduce_scatter_preamble", diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 5d700206c..cfff038ab 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -153,7 +153,7 @@ def _fused_all_gather_matmul_kernel( # Store result using tritonblas Tile rm, rn = out_tile.indices() - C_ptr = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + C_ptr = C + rm.to(tl.int64)[:, None] * stride_cm + rn.to(tl.int64)[None, :] * stride_cn mask = (rm[:, None] < M) & (rn[None, :] < N) tl.store(C_ptr, c, mask=mask) @@ -164,7 +164,7 @@ def all_gather_matmul_preamble( B: torch.Tensor, config: Optional[FusedConfig] = None, ) -> FusedWorkspace: - """Allocate workspace for all_gather_matmul (none needed for pull pattern).""" + """Allocate workspace for all_gather_matmul.""" if config is None: config = FusedConfig() @@ -175,7 +175,7 @@ def all_gather_matmul_preamble( expected_K = world_size * K_local assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - return FusedWorkspace( + ws = FusedWorkspace( operation="all_gather_matmul", shape=(M, N, K), dtype=A_sharded.dtype, @@ -183,6 +183,8 @@ def all_gather_matmul_preamble( prepared=True, ) + return ws + def all_gather_matmul( shmem, @@ -208,17 +210,6 @@ def all_gather_matmul( assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" # Validate problem size against block sizes - assert M >= config.block_size_m, ( - f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems." - ) - assert K_local >= config.block_size_k, ( - f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k}). " - f"Use smaller block sizes for small problems." - ) - assert N >= config.block_size_n, ( - f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." - ) - if workspace is None: workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) @@ -245,7 +236,8 @@ def all_gather_matmul( even_k = K_local % config.block_size_k == 0 num_k_blocks_local = (K_local + config.block_size_k - 1) // config.block_size_k - # Launch single fused kernel + num_tiles_m = (M + config.block_size_m - 1) // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n grid = (num_sms,) _fused_all_gather_matmul_kernel[grid]( A_sharded, diff --git a/iris/ops/all_gather_matmul_copy_engine.py b/iris/ops/all_gather_matmul_copy_engine.py new file mode 100644 index 000000000..5195fc3b6 --- /dev/null +++ b/iris/ops/all_gather_matmul_copy_engine.py @@ -0,0 +1,903 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Fused All-Gather + GEMM using copy engine (SDMA) for data movement. + +Key differences from HBM buffer variant: +- SMs only perform GEMM (no fetcher workgroups) +- Host orchestrates SDMA transfers of remote tiles to staged_a buffer +- GEMM processes local K-blocks first (from A_sharded), then remote K-blocks (from staged_a) +- Flags only track remote tiles, updated by copy engine via host_atomic_add_32 +""" + +from typing import Optional +import torch +import torch.distributed as dist +import triton +import triton.language as tl +import iris +import iris.hip as hip +import iris.x +from tritonblas.matmul import persistent_matmul_lt, create_wait_config +from tritonblas.kernels.stages import ( + Tile as StageTile, + GemmContext, + chiplet_transform_chunked, + make_bias_view, + make_input_view, + make_output_view, + make_wait_view, +) + + +from iris.tracing.events import TraceEvent +from .workspace import FusedWorkspace + +# Import Tile class from anvil module +try: + import anvil + + Tile = anvil.Tile +except (ImportError, AttributeError): + Tile = None # Will raise error later if needed + + +@triton.jit +def _batch_poster_kernel( + A_sharded, + staged_a, + flags_ptr, + flag_iteration, + M, + K_local, + stride_am, + stride_sa_m, + stride_sa_k, + context_tensor: tl.tensor, + heap_bases_ptr: tl.tensor, + copy_engine_ctx: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + NUM_M_TILES: tl.constexpr, + M_TILES_PER_BATCH: tl.constexpr, + TRACE: tl.constexpr, +): + """Post one SDMA transfer per (batch, rank) including local copy.""" + zero = tl.program_id(0) * 0 + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size, tracing=TRACE) + + pid = tl.program_id(0) + dst_rank = pid + + if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().wg_sdma, + target_rank=dst_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=zero, + ) + + ptr_dtype = A_sharded.dtype.element_ty + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + elem_size = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32: + elem_size = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64: + elem_size = 8 + else: + elem_size = 4 + + num_batches = (NUM_M_TILES + M_TILES_PER_BATCH - 1) // M_TILES_PER_BATCH + rows_per_batch = M_TILES_PER_BATCH * BLOCK_SIZE_M + + for batch_id in range(num_batches): + src_m_offset = batch_id * rows_per_batch + remaining_rows = M - src_m_offset + tile_height = tl.minimum(remaining_rows, rows_per_batch) + + src_m_offset_i64 = (src_m_offset + 0 * stride_am).to(tl.int64) + stride_am_i64 = (stride_am + 0 * src_m_offset).to(tl.int64) + src_ptr = A_sharded + src_m_offset_i64 * stride_am_i64 + + dst_m_offset = src_m_offset + dst_k_offset = cur_rank * K_local + dst_m_offset_i64 = (dst_m_offset + 0 * stride_sa_m).to(tl.int64) + dst_k_offset_i64 = (dst_k_offset + 0 * stride_sa_k).to(tl.int64) + stride_sa_m_i64 = (stride_sa_m + 0 * dst_m_offset).to(tl.int64) + stride_sa_k_i64 = (stride_sa_k + 0 * dst_m_offset).to(tl.int64) + dst_ptr = staged_a + (dst_m_offset_i64 * stride_sa_m_i64 + dst_k_offset_i64 * stride_sa_k_i64) + + tile_width_bytes = K_local * elem_size + src_pitch_bytes = stride_am * elem_size + dst_pitch_bytes = stride_sa_m * elem_size + + iris.put_signal_rect( + src_ptr, + dst_ptr, + cur_rank, + dst_rank, + heap_bases_ptr, + copy_engine_ctx, + flags_ptr + batch_id, + 1, + width_bytes=tile_width_bytes, + height=tile_height, + src_pitch=src_pitch_bytes, + dst_pitch=dst_pitch_bytes, + ) + + if TRACE: + ctx.tracing.record_event_end(_trace_handle) + + +@triton.jit +def _nonpersistent_xcd_comm_gemm_kernel( + A_sharded, + staged_a, + B, + C, + bias_ptr, + wait_ptr, + wait_expected_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_sa_m, + stride_sa_k, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + context_tensor: tl.tensor, + heap_bases_ptr: tl.tensor, + copy_engine_ctx: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_XCDS: tl.constexpr, + COMPUTE_WGS: tl.constexpr, + M_TILES_PER_BATCH: tl.constexpr = 1, + COMM_WGS: tl.constexpr = 8, + WAIT_NUM: tl.constexpr = 0, + WAIT_MAP_TYPE: tl.constexpr = 0, + WAIT_BLOCK_GROUP_M: tl.constexpr = 1, + WAIT_BLOCK_GROUP_N: tl.constexpr = 1, + WAIT_EXPECTED_INC: tl.constexpr = 1, + BIAS: tl.constexpr = False, + EVEN_K: tl.constexpr = True, + ALLOW_TF32: tl.constexpr = True, + TRACE: tl.constexpr = False, +): + """Template kernel: reserve leading comm WGs, remap compute WGs across XCDs. + + This kernel is intentionally not wired into the production launch path yet. + It sketches the structure needed for: + + - ``COMM_WGS`` front-loaded workgroups, typically one per XCD + - XCD-aware remapping over the compute-only PID space + - non-persistent GEMM execution (one compute WG per tile) + + The comm branch mirrors the current host path's transfer pattern: + + - one poster WG per remote rank when available + - batched copies across M tiles + - full local K shard copied into the rank's global-K slot in ``staged_a`` + - one readiness flag signal per batch + + The compute branch is functional and can be used as a reference for + launch-time experimentation. + """ + raw_pid = tl.program_id(0) + zero = raw_pid * 0 + + if raw_pid < COMM_WGS: + # Map poster WG to one rank (including local). Extra COMM_WGS beyond + # world_size simply go idle. + dst_rank = raw_pid + if dst_rank >= world_size: + return + + ctx = None + if TRACE: + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size, tracing=True) + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().wg_sdma, + target_rank=dst_rank, + address=wait_ptr + tl.arange(0, 1), + pid_m=raw_pid, + pid_n=zero, + ) + + # Element size for pointer arithmetic + ptr_dtype = A_sharded.dtype.element_ty + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + elem_size = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32: + elem_size = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64: + elem_size = 8 + else: + elem_size = 4 + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_batches = (num_pid_m + M_TILES_PER_BATCH - 1) // M_TILES_PER_BATCH + rows_per_batch = M_TILES_PER_BATCH * BLOCK_SIZE_M + + # Mirror the current host path: for each batch, transfer the entire + # local K shard into the current rank's global-K slot in staged_a and + # signal one per-batch readiness flag. + for batch_id in range(num_batches): + src_m_offset = batch_id * rows_per_batch + remaining_rows = M - src_m_offset + tile_height = tl.minimum(remaining_rows, rows_per_batch) + + src_m_offset_i64 = (src_m_offset + 0 * stride_am).to(tl.int64) + stride_am_i64 = (stride_am + 0 * src_m_offset).to(tl.int64) + src_ptr = A_sharded + src_m_offset_i64 * stride_am_i64 + dst_m_offset = src_m_offset + dst_k_offset = cur_rank * K_local + dst_m_offset_i64 = (dst_m_offset + 0 * stride_sa_m).to(tl.int64) + dst_k_offset_i64 = (dst_k_offset + 0 * stride_sa_k).to(tl.int64) + stride_sa_m_i64 = (stride_sa_m + 0 * dst_m_offset).to(tl.int64) + stride_sa_k_i64 = (stride_sa_k + 0 * dst_m_offset).to(tl.int64) + dst_ptr = staged_a + (dst_m_offset_i64 * stride_sa_m_i64 + dst_k_offset_i64 * stride_sa_k_i64) + + tile_width_bytes = K_local * elem_size + src_pitch_bytes = stride_am * elem_size + dst_pitch_bytes = stride_sa_m * elem_size + + iris.put_signal_rect( + src_ptr, + dst_ptr, + cur_rank, + dst_rank, + heap_bases_ptr, + copy_engine_ctx, + wait_ptr + batch_id, + 1, + width_bytes=tile_width_bytes, + height=tile_height, + src_pitch=src_pitch_bytes, + dst_pitch=dst_pitch_bytes, + ) + + if TRACE and ctx is not None: + ctx.tracing.record_event_end(_trace_handle) + return + + compute_pid = raw_pid - COMM_WGS + if compute_pid >= COMPUTE_WGS: + return + + if NUM_XCDS != 1: + compute_pid = chiplet_transform_chunked( + compute_pid, + COMPUTE_WGS, + NUM_XCDS, + GROUP_SIZE_M * GROUP_SIZE_M, + ) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + tile_id = compute_pid + if tile_id >= total_tiles: + return + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tensorA = make_input_view(staged_a, M, K, stride_sa_m, stride_sa_k) + tensorB = make_input_view(B, K, N, stride_bk, stride_bn) + tensorC = make_output_view(C, M, N, stride_cm, stride_cn) + bias_view = make_bias_view(bias_ptr, N, stride_bias) if BIAS else None + wait_view = make_wait_view(wait_ptr, wait_expected_ptr) if WAIT_NUM > 0 else None + out_tile = StageTile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N) + + if wait_view is not None: + wait_view.wait_for_tile( + out_tile, + M, + N, + num_flags=WAIT_NUM, + map_type=WAIT_MAP_TYPE, + block_group_m=WAIT_BLOCK_GROUP_M, + block_group_n=WAIT_BLOCK_GROUP_N, + expected_inc=WAIT_EXPECTED_INC, + ) + + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + ctx = GemmContext( + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + COMPUTE_WGS, + NUM_XCDS, + GROUP_SIZE_M, + GROUP_SIZE_M * GROUP_SIZE_M, + None, + None, + acc_dtype, + ALLOW_TF32, + EVEN_K, + False, + ) + acc = ctx.reduce_axis(tensorA, tensorB, out_tile) + tensorC.store(acc, out_tile, bias=bias_view) + + +# ========================================================================== +# Python API +# ========================================================================== + + +def all_gather_matmul_copy_engine_preamble( + shmem, + A_sharded: torch.Tensor, + B: torch.Tensor, + selector=None, + k_per_flag: int = 4, + m_tiles_per_batch: Optional[int] = None, + staged_a_layout: str = "k_contiguous", +) -> FusedWorkspace: + """ + Allocate workspace for copy engine variant. + + Args: + staged_a_layout: "k_contiguous" (default, row-major (M,K)) or + "m_contiguous" (col-major, stored as (K,M) transposed). + """ + from tritonblas.matmul import _make_matmul_selector + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + + assert world_size * K_local == K + + if selector is None: + selector = _make_matmul_selector( + M, N, K, A_sharded.dtype, B.dtype, A_sharded.dtype, A_sharded.device, streamk=False + ) + + assert K_local % selector.block_k == 0 + assert K % selector.block_k == 0 + assert M % selector.block_m == 0 + + num_m_tiles = M // selector.block_m + num_tiles_n = (N + selector.block_n - 1) // selector.block_n + total_tiles = num_m_tiles * num_tiles_n + + if m_tiles_per_batch is None: + # Auto-calculate optimal m_tiles_per_batch + active_cus = getattr(selector, "_ACTIVE_CU", None) + if active_cus is None or active_cus <= 0: + active_cus = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + tiles_per_group = max(1, selector.group_m * num_tiles_n) + groups_per_wave = max(1, int(active_cus) // tiles_per_group) + m_tiles_per_batch = max(1, min(num_m_tiles, groups_per_wave * selector.group_m)) + + num_batches = (num_m_tiles + m_tiles_per_batch - 1) // m_tiles_per_batch + num_flags = num_batches + ws = FusedWorkspace( + operation="all_gather_matmul_copy_engine", + shape=(M, N, K), + dtype=A_sharded.dtype, + world_size=world_size, + variant=f"copy_engine_{staged_a_layout}_wave_aware_noninterleaved", + prepared=True, + ) + + # Allocate staged_a - full K dimension (NON-INTERLEAVED like HBM buffer) + # Each rank's K-blocks are stored contiguously for efficient bulk SDMA + if staged_a_layout == "m_contiguous": + storage = shmem.zeros((K, M), dtype=A_sharded.dtype) + ws.aux_buffer = storage.T # (M, K) with M-contiguous + else: + ws.aux_buffer = shmem.zeros((M, K), dtype=A_sharded.dtype) + + # Allocate per-batch flags + ws.locks = shmem.zeros((num_flags,), dtype=torch.int32) + ws.wait_expected = shmem.zeros((total_tiles,), dtype=torch.int32) + + # Store metadata + ws.selector = selector + ws.m_tiles_per_batch = m_tiles_per_batch + ws.num_m_tiles = num_m_tiles + ws.num_batches = num_batches + + shmem.info( + f"Allocated {num_flags} per-batch flags " + f"(tiles={num_m_tiles}, " + f"{m_tiles_per_batch} M-tiles per batch) " + f"flags buffer at 0x{ws.locks.data_ptr():x}" + ) + + # Share pointers across ranks for SDMA addressing + # Need: A_sharded (source), staged_a (destination), flags (signaling) + A_sharded_ptr_tensor = torch.tensor([A_sharded.data_ptr()], dtype=torch.int64, device="cuda") + A_sharded_ptrs = [torch.zeros(1, dtype=torch.int64, device="cuda") for _ in range(world_size)] + dist.all_gather(A_sharded_ptrs, A_sharded_ptr_tensor) + + staged_a_ptr_tensor = torch.tensor([ws.aux_buffer.data_ptr()], dtype=torch.int64, device="cuda") + staged_a_ptrs = [torch.zeros(1, dtype=torch.int64, device="cuda") for _ in range(world_size)] + dist.all_gather(staged_a_ptrs, staged_a_ptr_tensor) + + flags_ptr_tensor = torch.tensor([ws.locks.data_ptr()], dtype=torch.int64, device="cuda") + flags_ptrs = [torch.zeros(1, dtype=torch.int64, device="cuda") for _ in range(world_size)] + dist.all_gather(flags_ptrs, flags_ptr_tensor) + + # Store all remote pointers in workspace + ws.remote_pointers = { + "A_sharded": [ptr.item() for ptr in A_sharded_ptrs], + "staged_a": [ptr.item() for ptr in staged_a_ptrs], + "flags": [ptr.item() for ptr in flags_ptrs], + } + + # Note: heap_bases are already cached in shmem.heap_bases_cpu (done in iris.py __init__) + + buffer_mb = M * K * A_sharded.element_size() / (1024**2) + sa_stride_m, sa_stride_k = ws.aux_buffer.stride() + shmem.info( + f"Copy Engine: staged_a=({M},{K}) [{buffer_mb:.1f} MB] " + f"layout={staged_a_layout} strides=({sa_stride_m},{sa_stride_k}), " + f"NON-INTERLEAVED: each rank's K-blocks contiguous" + ) + + shmem.barrier() + return ws + + +_WG_GEMM = 15 +_WG_GEMM_WAIT = 16 +_WG_SDMA = 17 + + +def _extract_wg_trace(shmem, grid_size, num_tiles, **metadata): + """Reconstruct per-tile and per-SDMA-WG trace arrays from DeviceTracing events. + + For copy_engine with device-initiated mode: + - grid_size includes both GEMM tiles + SDMA WGs + - num_tiles is the number of GEMM tiles only + - SDMA WGs are stored separately in sdma_* arrays + """ + import numpy as np + + bufs = shmem.tracing.trace_buffers + n = min(shmem.tracing.trace_counter.item(), shmem.tracing.max_events) + + event_ids = bufs["event_id"][:n].cpu().numpy() + pid_ms = bufs["pid_m"][:n].cpu().numpy() # tile_id or SDMA WG pid + timestamps = bufs["timestamp"][:n].cpu().numpy().astype(np.int64) + end_ts = bufs["duration_cycles"][:n].cpu().numpy().astype(np.int64) + xcc_ids = bufs["xcc_id"][:n].cpu().numpy().astype(np.int32) + pid_ns = bufs["pid_n"][:n].cpu().numpy() + + # GEMM tile traces + starts = torch.zeros(num_tiles, dtype=torch.int64) + ends = torch.zeros(num_tiles, dtype=torch.int64) + waits = torch.zeros(num_tiles, dtype=torch.int64) + xcds = torch.zeros(num_tiles, dtype=torch.int32) + + # SDMA WG traces (if device-initiated) + num_sdma = grid_size - num_tiles # Number of SDMA WGs + sdma_starts = torch.zeros(num_sdma, dtype=torch.int64) if num_sdma > 0 else None + sdma_ends = torch.zeros(num_sdma, dtype=torch.int64) if num_sdma > 0 else None + sdma_xcds = torch.zeros(num_sdma, dtype=torch.int32) if num_sdma > 0 else None + + for i in range(n): + eid = int(event_ids[i]) + pid = int(pid_ms[i]) + + if eid == _WG_GEMM: + starts[pid] = int(timestamps[i]) + ends[pid] = int(end_ts[i]) + xcds[pid] = int(xcc_ids[i]) + elif eid == _WG_GEMM_WAIT: + waits[pid] = int(pid_ns[i]) + elif eid == _WG_SDMA: + if num_sdma > 0: + sdma_starts[pid] = int(timestamps[i]) + sdma_ends[pid] = int(end_ts[i]) + sdma_xcds[pid] = int(xcc_ids[i]) + + result = {"start": starts, "end": ends, "wait": waits, "xcd": xcds, "grid_size": num_tiles, **metadata} + if num_sdma > 0: + result.update( + { + "sdma_start": sdma_starts, + "sdma_end": sdma_ends, + "sdma_xcd": sdma_xcds, + "num_sdma": num_sdma, + } + ) + return result + + +def all_gather_matmul_copy_engine( + shmem, + output_tensor: torch.Tensor, + A_sharded: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + workspace: Optional[FusedWorkspace] = None, + flag_iteration: int = 0, + k_per_flag: int = 4, + staged_a_layout: str = "k_contiguous", + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, + trace: bool = False, + verbose: bool = False, + device_initiated: bool = False, + host_transfer_backend: str = "anvil", +) -> FusedWorkspace: + """ + All-gather + matmul with copy engine orchestrating remote tile transfers. + + Key differences from HBM buffer: + - No fetcher workgroups (only GEMM) + - Host uses SDMA to copy remote tiles (default) OR device WGs initiate SDMA (device_initiated=True) + - GEMM processes local tiles first, then remote tiles + + Args: + staged_a_layout: Buffer layout for gathered A. + "k_contiguous" — (M,K) row-major, K is fast dim. + "m_contiguous" — (M,K) with M as fast dim. + device_initiated: If True, use device-side WGs to initiate SDMA transfers instead of host. + host_transfer_backend: Host-side transfer backend when device_initiated=False. + "anvil" uses host SDMA queue submission. + "hip_memcpy" uses hipMemcpy2DAsync followed by a remote flag signal. + k_per_flag: Retained for call compatibility; ignored by the current per-batch design. + """ + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + assert world_size * K_local == K + assert output_tensor.shape == (M, N) + + if host_transfer_backend not in {"anvil", "hip_memcpy"}: + raise ValueError( + f"Unsupported host_transfer_backend={host_transfer_backend!r}; expected 'anvil' or 'hip_memcpy'" + ) + if host_transfer_backend == "hip_memcpy" and staged_a_layout != "k_contiguous": + raise NotImplementedError( + "host_transfer_backend='hip_memcpy' currently requires staged_a_layout='k_contiguous'" + ) + + if workspace is None: + workspace = all_gather_matmul_copy_engine_preamble( + shmem, A_sharded, B, k_per_flag=k_per_flag, staged_a_layout=staged_a_layout + ) + + selector = workspace.selector + m_tiles_per_batch = workspace.m_tiles_per_batch + + assert M % selector.block_m == 0 + assert K % selector.block_k == 0 + assert K_local % selector.block_k == 0 + + num_k_blocks_local = K_local // selector.block_k + num_m_tiles = M // selector.block_m + num_tiles_n = (N + selector.block_n - 1) // selector.block_n + + # Local K-blocks will be copied via SDMA (host or device initiated) + stride_am, stride_ak = A_sharded.stride() + stride_bk, stride_bn = B.stride() + stride_cm, stride_cn = output_tensor.stride() + stride_sa_m, stride_sa_k = workspace.aux_buffer.stride() + + if bias is not None: + assert bias.shape[0] == M + bias_ptr = bias + stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 + use_bias = True + else: + bias_ptr = output_tensor + stride_bias = 1 + use_bias = False + + # num_m_tiles, num_tiles_n already calculated above + total_tiles = num_m_tiles * num_tiles_n + + # if trace: + # if rank == 0: + # shmem.info( + # "Tracing is not yet supported for the persistent tritonBLAS path here; running without trace capture." + # ) + # trace = False + + # Auto-detect num_sms from device if not specified + num_comm_wgs = selector._hardware.NUM_XCD if device_initiated else 0 + gemm_tiles = total_tiles + + # m_tiles_per_batch was already set above before calling preamble + # Calculate number of batches + num_batches = (num_m_tiles + m_tiles_per_batch - 1) // m_tiles_per_batch + + launch_kwargs = {"matrix_instr_nonkdim": 16} + if num_warps is not None: + launch_kwargs["num_warps"] = num_warps + if num_stages is not None: + launch_kwargs["num_stages"] = num_stages + + # ====================================================================== + # Host orchestration: SDMA copy setup + # ====================================================================== + anvil_lib = shmem.copy_engines + torch.cuda.current_device() # Initialize CUDA context + + # SDMA queues already connected during iris init + if verbose and rank == 0: + shmem.info(f"[Rank {rank}] Copy engines connected, launching kernel...") + shmem.info( + f"Kernel params: num_m_tiles={num_m_tiles}, " + f"num_tiles_n={num_tiles_n}, num_k_blocks_local={num_k_blocks_local}, " + f"group_size_m={selector.group_m}, m_tiles_per_batch={m_tiles_per_batch}" + ) + shmem.info( + f"Pointers: A_sharded=0x{A_sharded.data_ptr():x}, " + f"B=0x{B.data_ptr():x}, C=0x{output_tensor.data_ptr():x}, " + f"bias_ptr=0x{bias_ptr.data_ptr():x}, " + f"staged_a=0x{workspace.aux_buffer.data_ptr():x}, " + f"flags=0x{workspace.locks.data_ptr():x} (n={workspace.locks.numel()})" + ) + + if verbose and rank == 0: + shmem.info( + "Launching kernel: " + f"gemm_tiles={gemm_tiles}, device_initiated={device_initiated}, " + f"host_transfer_backend={host_transfer_backend}, sdma_wgs={num_comm_wgs}" + ) + + tb_block_m = selector.block_m + wait_config = create_wait_config( + wait_buffer=workspace.locks, + expected_buffer=workspace.wait_expected, + expected_inc=world_size, + map_type="block", + block_group_m=m_tiles_per_batch, + block_group_n=num_tiles_n, + ) + + if use_bias: + import warnings + + warnings.warn( + "Bias is not yet supported in the persistent tritonBLAS path for all_gather_matmul_copy_engine. " + "Ignoring bias for this launch." + ) + + # ====================================================================== + # Launch kernel and orchestrate SDMA transfers + # ====================================================================== + + import time + + sdma_start_time = time.perf_counter() + + if device_initiated: + # Device-initiated: keep the known-good split path in production. + # The combined COMM_WGS+GEMM kernel remains in this file as an + # experimental option, but we do not use it by default. + poster_grid = world_size + _batch_poster_kernel[(poster_grid,)]( + A_sharded, + workspace.aux_buffer, + workspace.locks, + flag_iteration, + M, + K_local, + stride_am, + stride_sa_m, + stride_sa_k, + shmem.get_device_context(), + shmem.get_heap_bases(), + shmem.get_copy_engine_ctx(), + rank, + world_size, + selector.block_m, + num_m_tiles, + m_tiles_per_batch, + False, + ) + persistent_matmul_lt( + workspace.aux_buffer, + B, + output_tensor, + selector, + bias=None, + wait_config=wait_config, + ) + else: + # Host-initiated: + # - anvil backend: launch GEMM first, then post batches to overlap SDMA + # - hip_memcpy backend: pre-post all batches before GEMM launch to avoid + # potential deadlock if HIP peer copies do not make forward progress + # while the persistent GEMM is monopolizing the device. + elem_size = A_sharded.element_size() + staged_a_base_addr = workspace.aux_buffer.data_ptr() + flags_base_addr = workspace.locks.data_ptr() + tile_transfer_count = 0 + hip_copy_stream = None + if host_transfer_backend == "hip_memcpy": + hip_copy_stream = getattr(workspace, "host_copy_stream", None) + if hip_copy_stream is None: + hip_copy_stream = hip.create_stream(non_blocking=True) + workspace.host_copy_stream = hip_copy_stream + + def post_host_batch(batch_id: int, m_tile_start: int, num_m_tiles_in_batch: int) -> None: + nonlocal tile_transfer_count + + for dst_rank in range(world_size): + flag_idx = batch_id + flag_addr_local = flags_base_addr + flag_idx * 4 + flag_addr_remote = shmem.translate(flag_addr_local, rank, dst_rank) + + tile = Tile() + tile.pid_m = 0 + tile.pid_n = 0 + tile.block_m = num_m_tiles_in_batch * selector.block_m + tile.block_n = K_local + tile.elem_size = elem_size + tile.src_stride = stride_am * elem_size + # Source is the local shard, so batches only advance in M. + src_offset_bytes = (m_tile_start * selector.block_m * stride_am) * elem_size + tile.data = A_sharded.data_ptr() + src_offset_bytes + + # Destination is this rank's global-K slot inside staged_a. + dst_offset_bytes = ( + m_tile_start * selector.block_m * stride_sa_m + rank * K_local * stride_sa_k + ) * elem_size + dst_ptr_local = staged_a_base_addr + dst_offset_bytes + dst_ptr_remote = shmem.translate(dst_ptr_local, rank, dst_rank) + + if host_transfer_backend == "hip_memcpy": + hip.memcpy_2d_async( + dst_ptr_remote, + stride_sa_m * elem_size, + tile.data, + tile.src_stride, + K_local * elem_size, + num_m_tiles_in_batch * selector.block_m, + stream=hip_copy_stream, + ) + # Preserve the existing readiness semantics: only signal the + # batch once the copy for this destination rank has completed. + hip.stream_synchronize(hip_copy_stream) + anvil_lib.host_atomic_add_32(rank, dst_rank, 0, flag_addr_remote, 1) + else: + anvil_lib.host_put_tile_signal( + rank, + dst_rank, + 0, + tile, + dst_ptr_remote, + stride_sa_m * elem_size, + flag_addr_remote, + 1, + ) + tile_transfer_count += 1 + + if verbose and batch_id == 0 and dst_rank == (rank + 1) % world_size: + shmem.info( + f"[Rank {rank}→{dst_rank}] Signaled batch={batch_id} flag_idx={flag_idx} " + f"({num_m_tiles_in_batch} rows × full local K shard)" + ) + + if verbose and rank == 0: + num_batches_calc = workspace.locks.numel() + shmem.info( + f"[Rank {rank}] Starting SDMA loop (batched M-tile transfers)... " + f"num_tiles_m={num_m_tiles}, num_k_blocks_local={num_k_blocks_local}, " + f"m_tiles_per_batch={m_tiles_per_batch}, tb_block_m={tb_block_m}" + ) + shmem.info(f"[Rank {rank}] Will transfer in {num_batches_calc} batches of {m_tiles_per_batch} M-tiles each") + + # TODO not always faster + # Prime batch 0 before GEMM launch so the first released tile-group can + # start immediately instead of stalling on an empty wait queue. + # TODO issue with time measurement + # first_batch_tiles = min(m_tiles_per_batch, num_m_tiles) + # post_host_batch(0, 0, first_batch_tiles) + + if host_transfer_backend == "hip_memcpy": + # Conservative ordering for HIP peer memcpy: make the gathered A + # fully ready before the persistent GEMM begins waiting on flags. + batch_id = 0 + for m_tile_start in range(0, num_m_tiles, m_tiles_per_batch): + m_tile_end = min(m_tile_start + m_tiles_per_batch, num_m_tiles) + num_m_tiles_in_batch = m_tile_end - m_tile_start + post_host_batch(batch_id, m_tile_start, num_m_tiles_in_batch) + batch_id += 1 + + if verbose and rank == 0: + shmem.info(f"[Rank {rank}] Launching GEMM kernel after pre-posting all HIP memcpy batches...") + + persistent_matmul_lt( + workspace.aux_buffer, + B, + output_tensor, + selector, + bias=None, + wait_config=wait_config, + ) + else: + if verbose and rank == 0: + shmem.info(f"[Rank {rank}] Launching GEMM kernel after pre-posting batch 0...") + + persistent_matmul_lt( + workspace.aux_buffer, + B, + output_tensor, + selector, + bias=None, + wait_config=wait_config, + ) + + # Post the remaining batches while GEMM is already running. + batch_id = 0 + for m_tile_start in range(0, num_m_tiles, m_tiles_per_batch): + m_tile_end = min(m_tile_start + m_tiles_per_batch, num_m_tiles) + num_m_tiles_in_batch = m_tile_end - m_tile_start + post_host_batch(batch_id, m_tile_start, num_m_tiles_in_batch) + batch_id += 1 + + sdma_end_post_time = time.perf_counter() + + if verbose: + # Ensure all SDMA operations complete + for dst_rank in range(world_size): + anvil_lib.host_quiet(rank, dst_rank, 0) + sdma_end_time = time.perf_counter() + + post_ms = (sdma_end_post_time - sdma_start_time) * 1000.0 + quiet_ms = (sdma_end_time - sdma_end_post_time) * 1000.0 + total_ms = (sdma_end_time - sdma_start_time) * 1000.0 + shmem.info( + f"[Rank {rank}] SDMA complete. " + f"Post: {post_ms:.2f}ms, Quiet: {quiet_ms:.2f}ms, Total: {total_ms:.2f}ms, " + f"transfers={tile_transfer_count}" + ) + sample_count = min(8, workspace.locks.numel()) + sample_flags = workspace.locks[:sample_count].cpu().tolist() + shmem.info( + f"[Rank {rank}] Flag sample after SDMA quiet: " + f"expected_inc={world_size}, flags[:{sample_count}]={sample_flags}" + ) + + # ====================================================================== + # Synchronize + # ====================================================================== + if not async_op: + torch.cuda.synchronize() # Wait for kernel completion + shmem.barrier() + + # if trace: + # torch.cuda.synchronize() + # total_tiles = num_m_tiles * num_tiles_n + # workspace.trace_data = _extract_wg_trace( + # shmem, + # grid_size, + # total_tiles, + # num_m_tiles=num_m_tiles, + # num_tiles_n=num_tiles_n, + # ) + + return workspace diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py new file mode 100644 index 000000000..a42cf6f25 --- /dev/null +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -0,0 +1,727 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Fused All-Gather + GEMM using a local HBM staging buffer with dedicated +fetcher and GEMM workgroups, launched data-parallel. + +Supports configurable staged_a buffer layout (M-contiguous or K-contiguous) +and B layout to match optimal tritonblas conventions (TN, TT, NT, NN). +""" + +from typing import Optional +import torch +import triton +import triton.language as tl +import iris +import iris.x + +from iris.tracing.events import TraceEvent +from .config import FusedConfig +from .workspace import FusedWorkspace + + +# ────────────────────────────────────────────────────────────────────── +# Auto-config: shape-adaptive parameter selection for HBM buffer kernel +# Source: K-021 sweep data (1076+ trials, 7 verified champion shapes) +# ────────────────────────────────────────────────────────────────────── + +# Verified champion configs from IRIS-0018/0019 sweeps + optimize-loop iter3. +# Key: (M, N, K) -> dict of kernel params that beat PyTorch. +_CHAMPION_CONFIGS = { + (262144, 8192, 8192): dict( + bm=256, + bn=256, + bk=64, + gm=24, + kpf=64, + fs=52, + nfs=128, + fsf=304, + ), + (131072, 16384, 16384): dict( + bm=256, + bn=256, + bk=64, + gm=24, + kpf=32, + fs=4, + nfs=64, + fsf=52, + ), + (147456, 28672, 4096): dict( + bm=256, + bn=256, + bk=64, + gm=24, + kpf=16, + fs=59, + nfs=36, + fsf=52, + ), + (229376, 28672, 4096): dict( + bm=256, + bn=256, + bk=64, + gm=24, + kpf=16, + fs=4, + nfs=56, + fsf=52, + ), + (327680, 28672, 4096): dict( + bm=256, + bn=256, + bk=64, + gm=24, + kpf=16, + fs=4, + nfs=32, + fsf=52, + ), + (8192, 8192, 262144): dict( + bm=128, + bn=256, + bk=64, + gm=8, + kpf=32, + fs=4, + nfs=8, + fsf=52, + ), + (16384, 16384, 131072): dict( + bm=128, + bn=256, + bk=64, + gm=16, + kpf=16, + fs=16, + nfs=8, + fsf=52, + ), +} + + +def _auto_config(M: int, N: int, K: int, world_size: int = 8): + """ + Select optimal HBM buffer kernel parameters for a given shape. + + Returns (FusedConfig, k_per_flag, num_fetch_sms, num_fetch_stages, + first_stage_fetch_sms) — ready to pass to the kernel. + + Priority order: + 1. Exact match in champion configs (verified 1.12-1.44x vs PyTorch) + 2. Shape-heuristic derivation from 1076-trial sweep principles + + Heuristics (from K-021 sweep analysis): + - k_per_flag is the #1 knob (52% of perf range). Maximize it. + - bm=256 for M%256==0 and M>=8K; bm=128 otherwise + - bn=256 always (bn=128 is 15-35% worse) + - bk=64 always (bk=128 exceeds 64KB LDS on MI300X) + - num_stages=2 always (num_stages=3 crashes — 98KB LDS needed) + - num_warps=8 always (fewer warps = 22% worse) + - group_size_m: 1 for small M, 24 for large M (L2 locality) + """ + key = (M, N, K) + if key in _CHAMPION_CONFIGS: + c = _CHAMPION_CONFIGS[key] + # Validate kpf for this world_size + num_k_blocks = K // c["bk"] + kpf = c["kpf"] + while num_k_blocks % kpf != 0 and kpf > 1: + kpf //= 2 + config = FusedConfig( + block_size_m=c["bm"], + block_size_n=c["bn"], + block_size_k=c["bk"], + group_size_m=c["gm"], + ) + return config, kpf, c["fs"], c["nfs"], c["fsf"] + + # Derive from heuristics + num_k_blocks = K // 64 + + # Block sizes + bm = 256 if (M % 256 == 0 and M >= 8192) else 128 + num_m_tiles = M // bm + + # k_per_flag: maximize for throughput + if num_k_blocks >= 512: + kpf = 64 + elif num_k_blocks >= 128: + kpf = 16 + elif num_k_blocks >= 64: + kpf = 8 + else: + kpf = 4 + while num_k_blocks % kpf != 0 and kpf > 1: + kpf //= 2 + + # num_fetch_sms: scale with M-tiles (more tiles → more fetchers) + if num_m_tiles <= 8: + fs = 4 + elif num_m_tiles <= 32: + fs = 16 + elif num_m_tiles <= 128: + fs = 32 + else: + fs = 52 + + # num_fetch_stages + if num_m_tiles >= 512: + nfs = 4 + elif num_m_tiles >= 64: + nfs = 2 + else: + nfs = 1 + + # group_size_m + gm = 24 if num_m_tiles >= 64 else (8 if num_m_tiles >= 16 else 1) + + config = FusedConfig( + block_size_m=bm, + block_size_n=256, + block_size_k=64, + group_size_m=gm, + ) + return config, kpf, fs, nfs, 64 + + +@triton.jit +def _hbm_buffer_all_gather_matmul_kernel( + A_sharded, + B, + C, + bias_ptr, + staged_a, + flags_ptr, + flag_iteration, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_sa_m, # staged_a stride in M dim + stride_sa_k, # staged_a stride in K dim + stride_bias, + context_tensor: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_FETCH_SMS: tl.constexpr, + NUM_M_TILES: tl.constexpr, + NUM_TILES_N: tl.constexpr, + NUM_K_BLOCKS: tl.constexpr, + NUM_K_BLOCKS_LOCAL: tl.constexpr, + K_PER_FLAG: tl.constexpr, + NUM_FLAG_GROUPS_K: tl.constexpr, + TOTAL_GATHER_TILES: tl.constexpr, + BIAS: tl.constexpr, + ALLOW_TF32: tl.constexpr, + NUM_FETCH_STAGES: tl.constexpr, + GEMM_TILES_PER_STAGE: tl.constexpr, + FIRST_STAGE_FETCH_SMS: tl.constexpr, + TRACE: tl.constexpr, +): + pid = tl.program_id(0) + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + zero = tl.program_id(0) * 0 + + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size, tracing=TRACE) + + # Interleaved layout with asymmetric first stage: + # [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ... + # P = FIRST_STAGE_FETCH_SMS, F = NUM_FETCH_SMS, G = GEMM_TILES_PER_STAGE + FIRST_STAGE_SIZE: tl.constexpr = FIRST_STAGE_FETCH_SMS + GEMM_TILES_PER_STAGE + REST_STAGE_SIZE: tl.constexpr = NUM_FETCH_SMS + GEMM_TILES_PER_STAGE + M_PER_STAGE: tl.constexpr = (NUM_M_TILES + NUM_FETCH_STAGES - 1) // NUM_FETCH_STAGES + + # Two-phase decode: stage 0 has a different size than subsequent stages + if pid < FIRST_STAGE_SIZE: + my_stage = zero + local_pid = pid + fetch_threshold = zero + FIRST_STAGE_FETCH_SMS + else: + adjusted = pid - FIRST_STAGE_SIZE + my_stage = 1 + adjusted // REST_STAGE_SIZE + local_pid = adjusted % REST_STAGE_SIZE + fetch_threshold = zero + NUM_FETCH_SMS + + if local_pid < fetch_threshold: + # ============================================================== + # FETCHER — stage 0 uses FIRST_STAGE_FETCH_SMS WGs, + # later stages use NUM_FETCH_SMS WGs + # ============================================================== + stage_pid = local_pid + + if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().fetch, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=my_stage, + ) + + src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) + + tiles_per_m_group = NUM_FLAG_GROUPS_K * GROUP_SIZE_M + + for const_stage in range(NUM_FETCH_STAGES): + if my_stage == const_stage: + stage_fetch_sms = FIRST_STAGE_FETCH_SMS if const_stage == 0 else NUM_FETCH_SMS + stage_m_start = const_stage * M_PER_STAGE + stage_m_count = min(M_PER_STAGE, NUM_M_TILES - stage_m_start) + total_fg_stage = NUM_FLAG_GROUPS_K * stage_m_count + + for fg_idx in range(stage_pid, total_fg_stage, stage_fetch_sms): + m_group = fg_idx // tiles_per_m_group + within_group = fg_idx % tiles_per_m_group + # TODO there is a bug if #M-Tiles is not multiple of GROUP_SIZE_M + k_flag_group = within_group // GROUP_SIZE_M + m_in_group = within_group % GROUP_SIZE_M + m_tile = stage_m_start + m_group * GROUP_SIZE_M + m_in_group + m_tile = min(m_tile, NUM_M_TILES - 1) + k_block_start = k_flag_group * K_PER_FLAG + + rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + + for k_off in range(K_PER_FLAG): + k_block_global = k_block_start + k_off + + src_rank_idx = k_block_global // NUM_K_BLOCKS_LOCAL + k_block_local = k_block_global % NUM_K_BLOCKS_LOCAL + + pid_m_t = zero + m_tile + tile_k_t = zero + k_block_local + k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) + + rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + staged_ptrs = ( + staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk.to(tl.int64)[None, :] * stride_sa_k + ) + + for compile_rank in range(world_size): + if src_rank_idx == compile_rank: + a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx, hint=(1, BLOCK_SIZE_K)) + tl.store(staged_ptrs, a_tile, cache_modifier=".cg") + + flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group + tl.debug_barrier() + tl.atomic_add(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + + if TRACE: + ctx.tracing.record_event_end(_trace_handle) + + else: + # ============================================================== + # GEMM — gemm_local_id indexes into this stage's M-tile range + # ============================================================== + gemm_local_id = local_pid - fetch_threshold + stage_m_start = my_stage * M_PER_STAGE + + num_pid_in_group = GROUP_SIZE_M * NUM_TILES_N + group_id = gemm_local_id // num_pid_in_group + first_pid_m = stage_m_start + group_id * GROUP_SIZE_M + first_pid_m = min(first_pid_m, NUM_M_TILES - 1) + group_sz = min(NUM_M_TILES - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((gemm_local_id % num_pid_in_group) % group_sz) + pid_n = (gemm_local_id % num_pid_in_group) // group_sz + pid_m = min(pid_m, NUM_M_TILES - 1) + + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_SIZE_N), BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().compute, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=my_stage, + ) + + expected_flag_value = flag_iteration + 1 + + for k_fg in range(NUM_FLAG_GROUPS_K): + if TRACE: + _wait_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().wait, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=k_fg, + ) + + flag_idx = pid_m * NUM_FLAG_GROUPS_K + k_fg + while tl.atomic_add(flags_ptr + flag_idx, 0, sem="acquire", scope="gpu") < expected_flag_value: + pass + + if TRACE: + ctx.tracing.record_event_end(_wait_handle) + + k_block_base = k_fg * K_PER_FLAG + for k_off in range(K_PER_FLAG): + k_block = k_block_base + k_off + rk = k_block * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + + a_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk.to(tl.int64)[None, :] * stride_sa_k + a = tl.load(a_ptrs) + + B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(B_ptrs) + + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + if BIAS: + bias_val = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) + acc = acc + bias_val[:, None] + + c = acc.to(C.type.element_ty) + stride_cm_i64 = tl.cast(stride_cm, tl.int64) + stride_cn_i64 = tl.cast(stride_cn, tl.int64) + C_ptrs = C + rm.to(tl.int64)[:, None] * stride_cm_i64 + rn.to(tl.int64)[None, :] * stride_cn_i64 + c_mask = (rm[:, None] < M) & (rn[None, :] < N) + tl.store(C_ptrs, c, mask=c_mask, cache_modifier=".wt") + + if TRACE: + ctx.tracing.record_event_end(_trace_handle) + + +# ========================================================================== +# Python API +# ========================================================================== + + +def all_gather_matmul_hbm_buffer_preamble( + ctx, + A_sharded: torch.Tensor, + B: torch.Tensor, + config: Optional[FusedConfig] = None, + k_per_flag: Optional[int] = None, + staged_a_layout: str = "k_contiguous", +) -> FusedWorkspace: + """ + Allocate workspace. + + Args: + staged_a_layout: "k_contiguous" (default, row-major (M,K)) or + "m_contiguous" (col-major, stored as (K,M) transposed). + """ + M, K_local = A_sharded.shape + K, N = B.shape + world_size = ctx.get_num_ranks() + + if config is None: + auto_cfg, auto_kpf, _, _, _ = _auto_config(M, N, K, world_size) + config = auto_cfg + if k_per_flag is None: + k_per_flag = auto_kpf + if k_per_flag is None: + k_per_flag = 8 # Safety default; see K-021 best_configs.json for peak perf + + assert world_size * K_local == K + assert K_local % config.block_size_k == 0 + assert K % config.block_size_k == 0 + assert M % config.block_size_m == 0 + + num_m_tiles = M // config.block_size_m + num_k_blocks = K // config.block_size_k + assert num_k_blocks % k_per_flag == 0 + num_flag_groups_k = num_k_blocks // k_per_flag + + ws = FusedWorkspace( + operation="all_gather_matmul_hbm_buffer", + shape=(M, N, K), + dtype=A_sharded.dtype, + world_size=world_size, + variant=f"hbm_buffer_{staged_a_layout}", + prepared=True, + ) + + if staged_a_layout == "m_contiguous": + # Allocate (K, M) row-major, .T gives (M, K) with stride_m=1, stride_k=M + storage = ctx.zeros((K, M), dtype=A_sharded.dtype) + ws.aux_buffer = storage.T # (M, K) view, M-contiguous + else: + # Default: (M, K) row-major, stride_m=K, stride_k=1 + ws.aux_buffer = ctx.zeros((M, K), dtype=A_sharded.dtype) + + ws.locks = ctx.zeros((num_m_tiles * num_flag_groups_k,), dtype=torch.int32) + + buffer_mb = M * K * A_sharded.element_size() / (1024**2) + sa_stride_m, sa_stride_k = ws.aux_buffer.stride() + ctx.info( + f"HBM buffer: staged_a=({M},{K}) [{buffer_mb:.1f} MB] " + f"layout={staged_a_layout} strides=({sa_stride_m},{sa_stride_k}), " + f"flags={num_m_tiles}x{num_flag_groups_k}, k_per_flag={k_per_flag}" + ) + + ctx.barrier() + return ws + + +_EID_FETCH = 1024 # TraceEvent().fetch +_EID_COMPUTE = 2048 # TraceEvent().compute +_EID_WAIT = 3072 # TraceEvent().wait + + +def _extract_wg_trace(ctx, grid_size, **metadata): + """Reconstruct per-workgroup trace arrays from DeviceTracing events.""" + import numpy as np + + bufs = ctx.tracing.trace_buffers + n = min(ctx.tracing.trace_counter.item(), ctx.tracing.max_events) + + event_ids = bufs["event_id"][:n].cpu().numpy() + pids = bufs["pid"][:n].cpu().numpy() + timestamps = bufs["timestamp"][:n].cpu().numpy().astype(np.int64) + # Note: despite the field name, "duration_cycles" stores the absolute end timestamp + # (set by record_event_end). The actual duration is end_ts - start_ts. + end_timestamps = bufs["duration_cycles"][:n].cpu().numpy().astype(np.int64) + xcc_ids = bufs["xcc_id"][:n].cpu().numpy().astype(np.int32) + + starts = torch.zeros(grid_size, dtype=torch.int64) + ends = torch.zeros(grid_size, dtype=torch.int64) + waits = torch.zeros(grid_size, dtype=torch.int64) + xcds = torch.zeros(grid_size, dtype=torch.int32) + + for i in range(n): + eid = int(event_ids[i]) + wg = int(pids[i]) + if wg >= grid_size: + continue + if eid == _EID_FETCH or eid == _EID_COMPUTE: + starts[wg] = int(timestamps[i]) + ends[wg] = int(end_timestamps[i]) + xcds[wg] = int(xcc_ids[i]) + elif eid == _EID_WAIT: + waits[wg] += int(end_timestamps[i]) - int(timestamps[i]) + + return {"start": starts, "end": ends, "wait": waits, "xcd": xcds, "grid_size": grid_size, **metadata} + + +def all_gather_matmul_hbm_buffer( + ctx, + output_tensor: torch.Tensor, + A_sharded: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + config: Optional[FusedConfig] = None, + workspace: Optional[FusedWorkspace] = None, + flag_iteration: int = 0, + num_fetch_sms: Optional[int] = None, + k_per_flag: Optional[int] = None, + fetch_block_m: Optional[int] = None, + fetch_block_k: Optional[int] = None, + staged_a_layout: str = "k_contiguous", + num_warps: Optional[int] = 8, + num_stages: Optional[int] = 2, + num_fetch_stages: Optional[int] = None, + first_stage_fetch_sms: Optional[int] = None, + trace: bool = False, +) -> FusedWorkspace: + """ + All-gather + matmul with dedicated fetcher/GEMM workgroups. + + When ``config`` is None, uses ``_auto_config()`` to select shape-optimal + parameters from verified sweep data (K-021). This gives up to 1.44× + speedup over PyTorch on champion shapes without any manual tuning. + + Args: + staged_a_layout: Buffer layout for gathered A. + "k_contiguous" — (M,K) row-major, K is fast dim. Matches NN convention. + "m_contiguous" — (M,K) with M as fast dim. Matches TN convention (best for tritonblas). + """ + M, K_local = A_sharded.shape + K, N = B.shape + world_size = ctx.get_num_ranks() + + if config is None: + # Shape-adaptive auto-config from K-021 sweep data + auto_cfg, auto_kpf, auto_fs, auto_nfs, auto_fsf = _auto_config(M, N, K, world_size) + config = auto_cfg + if k_per_flag is None: + k_per_flag = auto_kpf + if num_fetch_sms is None: + num_fetch_sms = auto_fs + if num_fetch_stages is None: + num_fetch_stages = auto_nfs + if first_stage_fetch_sms is None: + first_stage_fetch_sms = auto_fsf + + # Apply defaults for any remaining None values (when config is explicit + # but some params are left at None). + # kpf=8 is the safety default: +4.3% vs kpf=16 on g6 (IRIS-0018, 934 trials) + # and avoids kpf=16 validation failures on 2/8 ranks at M=262144. + # For peak performance on known shapes, use best_configs.json from K-021. + if k_per_flag is None: + k_per_flag = 8 + if num_fetch_sms is None: + num_fetch_sms = 32 + if num_fetch_stages is None: + num_fetch_stages = 1 + if first_stage_fetch_sms is None: + first_stage_fetch_sms = 256 + + rank = ctx.get_rank() + + assert world_size * K_local == K + assert output_tensor.shape == (M, N) + assert M % config.block_size_m == 0 + assert K % config.block_size_k == 0 + assert K_local % config.block_size_k == 0 + + if fetch_block_m is None: + fetch_block_m = config.block_size_m + if fetch_block_k is None: + fetch_block_k = config.block_size_k + + num_k_blocks = K // config.block_size_k + assert num_k_blocks % k_per_flag == 0 + + if workspace is None: + workspace = all_gather_matmul_hbm_buffer_preamble(ctx, A_sharded, B, config, k_per_flag, staged_a_layout) + + stride_am, stride_ak = A_sharded.stride() + stride_bk, stride_bn = B.stride() + stride_cm, stride_cn = output_tensor.stride() + stride_sa_m, stride_sa_k = workspace.aux_buffer.stride() + + if bias is not None: + assert bias.shape[0] == M + bias_ptr = bias + stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 + use_bias = True + else: + bias_ptr = output_tensor + stride_bias = 1 + use_bias = False + + device = A_sharded.device + num_sms = config.num_sms + if num_sms is None: + props = torch.cuda.get_device_properties(device) + num_sms = props.multi_processor_count + + num_m_tiles = M // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + total_gemm_tiles = num_m_tiles * num_tiles_n + num_k_blocks_local = K_local // config.block_size_k + num_flag_groups_k = num_k_blocks // k_per_flag + total_gather_tiles = num_m_tiles * num_k_blocks + + if num_fetch_sms is None: + num_fetch_sms = max(1, num_sms // 10) + assert 0 < num_fetch_sms + assert num_fetch_stages >= 1 + + # First stage can use more fetcher WGs to fill the first GPU wave + if first_stage_fetch_sms is None: + first_stage_fetch_sms = num_fetch_sms + + # Interleaved layout: [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ... + m_per_stage = (num_m_tiles + num_fetch_stages - 1) // num_fetch_stages + gemm_tiles_per_stage = m_per_stage * num_tiles_n + first_stage_size = first_stage_fetch_sms + gemm_tiles_per_stage + rest_stage_size = num_fetch_sms + gemm_tiles_per_stage + total_fetch_wgs = first_stage_fetch_sms + num_fetch_sms * max(0, num_fetch_stages - 1) + grid_size = first_stage_size + rest_stage_size * max(0, num_fetch_stages - 1) + + if trace: + max_trace_events = grid_size * 4 + if not ctx.tracing.enabled: + ctx.tracing.enable(max_events=max_trace_events) + else: + ctx.tracing.reset() + + launch_kwargs = {"matrix_instr_nonkdim": 16} + if num_warps is not None: + launch_kwargs["num_warps"] = num_warps + if num_stages is not None: + launch_kwargs["num_stages"] = num_stages + + _hbm_buffer_all_gather_matmul_kernel[(grid_size,)]( + A_sharded, + B, + output_tensor, + bias_ptr, + workspace.aux_buffer, + workspace.locks, + flag_iteration, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_sa_m, + stride_sa_k, + stride_bias, + ctx.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_fetch_sms, + num_m_tiles, + num_tiles_n, + num_k_blocks, + num_k_blocks_local, + k_per_flag, + num_flag_groups_k, + total_gather_tiles, + use_bias, + config.allow_tf32, + num_fetch_stages, + gemm_tiles_per_stage, + first_stage_fetch_sms, + trace, + **launch_kwargs, + ) + + if not async_op: + ctx.barrier() + + if trace: + torch.cuda.synchronize() + workspace.trace_data = _extract_wg_trace( + ctx, + grid_size, + num_fetch_sms=num_fetch_sms, + num_fetch_stages=num_fetch_stages, + total_fetch_wgs=total_fetch_wgs, + num_m_tiles=num_m_tiles, + num_tiles_n=num_tiles_n, + first_stage_fetch_sms=first_stage_fetch_sms, + first_stage_size=first_stage_size, + rest_stage_size=rest_stage_size, + gemm_tiles_per_stage=gemm_tiles_per_stage, + ) + + return workspace diff --git a/iris/ops/config.py b/iris/ops/config.py index 3ca085c31..530df7816 100644 --- a/iris/ops/config.py +++ b/iris/ops/config.py @@ -32,7 +32,7 @@ class FusedConfig: CCL Parameters (for operations that need collective communication): all_reduce_variant: All-reduce algorithm variant. Options: "atomic", "ring", - "one_shot", "two_shot", "spinlock". Default: "one_shot". + "one_shot", "two_shot", "spinlock". Default: "two_shot". all_reduce_num_rings: Number of concurrent rings (for ring variant). Default: 1. Example: diff --git a/iris/ops/matmul.py b/iris/ops/matmul.py new file mode 100644 index 000000000..b91a514c9 --- /dev/null +++ b/iris/ops/matmul.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Local GEMM operation using tritonBLAS. + +Each rank has input A (M x K) and computes C = A @ B locally. +Output is local (M x N), not gathered across ranks. +""" + +from typing import Optional +import torch + +# Use tritonBLAS for optimized GEMM +from tritonblas.matmul import persistent_matmul_lt, _make_matmul_selector +from tritonblas.config import matmul_preamble as tritonblas_preamble + +from .config import FusedConfig +from .workspace import FusedWorkspace + + +# Removed custom kernel - now using tritonBLAS's optimized persistent_matmul + + +def matmul_preamble( + shmem, + A: torch.Tensor, + B: torch.Tensor, + config: Optional[FusedConfig] = None, +) -> FusedWorkspace: + """Allocate workspace for local matmul (none needed).""" + if config is None: + config = FusedConfig() + + M, K = A.shape + K2, N = B.shape + world_size = shmem.get_num_ranks() + + assert K == K2, f"Inner dimensions must match: A has K={K}, B has K={K2}" + + # No workspace needed for local matmul + return FusedWorkspace( + operation="matmul", + shape=(M, N, K), + dtype=A.dtype, + world_size=world_size, + prepared=True, + ) + + +def matmul( + shmem, + output_tensor: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + config: Optional[FusedConfig] = None, + workspace: Optional[FusedWorkspace] = None, + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, +) -> FusedWorkspace: + """ + Local matrix multiplication using tritonBLAS. + + Computes: output = A @ B + bias (local computation only) + + Each rank computes its own local matmul independently. + + Args: + shmem: Iris shmem context + output_tensor: Output tensor C of shape (M, N) + A: Input matrix A of shape (M, K) + B: Input matrix B of shape (K, N) + bias: Optional bias vector (M,) - broadcast across N dimension + async_op: If False, performs barrier at end + config: Optional FusedConfig for tuning + workspace: Optional pre-allocated workspace + num_warps: Optional number of warps (ignored - tritonBLAS chooses) + num_stages: Optional pipeline stages (ignored - tritonBLAS chooses) + + Returns: + FusedWorkspace object + """ + if config is None: + config = FusedConfig() + + M_local, K = A.shape + K2, N = B.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + assert K == K2, f"Inner dimensions must match: A has K={K}, B has K={K2}" + + M = M_local + assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" + + # Allocate workspace if not provided + if workspace is None: + workspace = matmul_preamble(shmem, A, B, config) + + # Create tritonBLAS selector to choose optimal block sizes + selector = _make_matmul_selector( + M, + N, + K, + A.dtype, + B.dtype, + output_tensor.dtype, + A.device, + streamk=False, # Use persistent kernel + ) + + # Use tritonBLAS with work-stealing for better performance + use_work_stealing = config.work_stealing if hasattr(config, "work_stealing") else False + tritonblas_config = None + + if use_work_stealing: + # Allocate tritonBLAS work-stealing buffers + tritonblas_config = tritonblas_preamble(selector) + tritonblas_config.reset(streamk=False, work_stealing=True) + + # Call tritonBLAS persistent matmul + # Note: tritonBLAS expects bias as (N,) not (M,), so we need to handle this + if bias is not None: + # iris bias is (M,) - needs to be broadcast across N + # tritonBLAS bias is (N,) - broadcast across M + # For now, warn if bias is used - needs different handling + import warnings + + warnings.warn( + "iris matmul bias (M,) is not directly compatible with tritonBLAS bias (N,). " + "Bias will be ignored in this tritonBLAS integration. " + "Consider adding bias manually after matmul." + ) + bias = None + + persistent_matmul_lt( + A, + B, + output_tensor, + selector, + config=tritonblas_config, + bias=bias, # Will be None for now due to dimension mismatch + work_stealing=use_work_stealing, + ) + + if not async_op: + shmem.barrier() + + return workspace diff --git a/iris/ops/matmul_all_gather.py b/iris/ops/matmul_all_gather.py index ad42ac041..bf34e31b8 100644 --- a/iris/ops/matmul_all_gather.py +++ b/iris/ops/matmul_all_gather.py @@ -17,7 +17,6 @@ import iris import iris.x -from tritonblas.kernels.stages import GemmContext, ScheduleContext, make_tensor_view from .config import FusedConfig from .workspace import FusedWorkspace @@ -49,6 +48,9 @@ def _fused_matmul_all_gather_kernel( GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr, NUM_XCDS: tl.constexpr, + NUM_M_TILES: tl.constexpr, + NUM_TILES_N: tl.constexpr, + NUM_K_BLOCKS: tl.constexpr, BIAS: tl.constexpr, EVEN_K: tl.constexpr, ALLOW_TF32: tl.constexpr, @@ -59,37 +61,56 @@ def _fused_matmul_all_gather_kernel( Computes local GEMM tile and immediately scatters to all ranks. No intermediate buffer needed - direct from registers to remote memory. """ - # ═══════════════════════════════════════════════════════════════════════ - # Create tritonblas views, context, and scheduler for GEMM - # ═══════════════════════════════════════════════════════════════════════ - tensorA = make_tensor_view(A, M_local, K, stride_am, stride_ak) - tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn) - gemm_ctx = GemmContext( - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_K, - num_sms=NUM_SMS, - num_xcds=NUM_XCDS, - group_size_m=GROUP_SIZE_M, - even_k=EVEN_K, - allow_tf32=ALLOW_TF32, - ) - sched = ScheduleContext(M_local, N, K, gemm_ctx) + pid = tl.program_id(0) # Persistent loop over local tiles using scheduler - start, total, stride = sched.persistent_tile_range() + start = pid + total = NUM_M_TILES * NUM_TILES_N + stride = NUM_SMS for tile_id in range(start, total, stride): - # Get tile coordinates with swizzling from scheduler - out_tile = sched.get_tile_from_idx(tile_id) - - # GEMM using tritonblas stages - acc = gemm_ctx.reduce_axis(tensorA, tensorB, out_tile) - - # Add bias if provided + # Wave-aware tile assignment (similar to hbm_buffer's group-based assignment) + num_pid_in_group = GROUP_SIZE_M * NUM_TILES_N + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + first_pid_m = min(first_pid_m, NUM_M_TILES - 1) + group_sz = min(NUM_M_TILES - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_sz) + pid_n = (tile_id % num_pid_in_group) // group_sz + pid_m = min(pid_m, NUM_M_TILES - 1) + + # M and N tile indices + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Initialize accumulator for this tile (must be inside the persistent loop!) + acc_dtype = tl.int32 if C_gathered.type.element_ty == tl.int8 else tl.float32 + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + for k_block_idx in range(NUM_K_BLOCKS): + # Load A from selected buffer + rk = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + a_ptrs = A + rm.to(tl.int64)[:, None] * stride_am + rk[None, :] * stride_ak + a = tl.load(a_ptrs) + + # Load B at global K position + B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(B_ptrs) + + # Accumulate + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + # ================================================================== + # Write output + # ================================================================== if BIAS: - rm, _ = out_tile.indices() - bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M_local, other=0.0) - acc = acc + bias_vector[:, None] + bias_val = tl.load(bias_ptr + rm * stride_bias, mask=rm < M_local, other=0.0) + acc = acc + bias_val[:, None] # Convert to output dtype c = acc.to(C_gathered.type.element_ty) @@ -97,7 +118,7 @@ def _fused_matmul_all_gather_kernel( # Create DeviceContext and destination TensorView for all-gather ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) dst_view = iris.x.make_tensor_view(C_gathered, M, N, stride_cm_gathered, stride_cn_gathered) - tile_obj = iris.x.Tile(out_tile.pid_m, out_tile.pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) + tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) # Scatter this tile to all ranks using iris.x.all_gather # dim=0 means scatter along M dimension (rows) @@ -217,6 +238,12 @@ def matmul_all_gather( even_k = K % config.block_size_k == 0 + # Calculate number of tiles + num_k_blocks = (K + config.block_size_k - 1) // config.block_size_k + num_tiles_m = (M_local + config.block_size_m - 1) // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + num_tiles = num_tiles_m * num_tiles_n + # Launch single fused kernel grid = (num_sms,) _fused_matmul_all_gather_kernel[grid]( @@ -244,6 +271,9 @@ def matmul_all_gather( config.group_size_m, num_sms, config.num_xcds, + num_tiles_m, + num_tiles_n, + num_k_blocks, use_bias, even_k, config.allow_tf32, diff --git a/iris/ops/matmul_all_gather_copy_engine.py b/iris/ops/matmul_all_gather_copy_engine.py new file mode 100644 index 000000000..9ad8fb50f --- /dev/null +++ b/iris/ops/matmul_all_gather_copy_engine.py @@ -0,0 +1,489 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Fused GEMM + All-Gather operation using SDMA (copy engine) for scatter. + +Each rank has a row-sharded input A_local (M_local x K) and computes C_local = A_local @ B. +Then scatters C_local tiles to form the full C (M x N) where M = world_size * M_local. + +This variant uses SDMA hardware for data movement instead of compute shader scatter. +""" + +from typing import Optional +import time +import torch +import triton +import triton.language as tl +import iris + +from .workspace import FusedWorkspace +from tritonblas.matmul import persistent_matmul_lt, create_counter_config +from .tritonblas_launch_wave_schedule import build_launch_wave_plan + + +@triton.jit() +def wait_cnt(): + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + + +@triton.jit() +def _launch_wave_wait_poster_kernel( + C_local_base, + flags, + completion_signals, + flag_iteration, + wave_tile_counts, + wave_transfer_offsets, + wave_transfer_counts, + transfer_row_offsets, + transfer_col_offsets, + transfer_width_bytes, + transfer_heights, + heap_bases: tl.tensor, + copy_engine_ctx: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + NUM_WAVES: tl.constexpr, + MAX_RECTS_PER_WAVE: tl.constexpr, + SRC_PITCH: tl.constexpr, + DST_PITCH: tl.constexpr, + STRIDE_N: tl.constexpr, +): + dst_rank = tl.program_id(0) + if dst_rank >= world_size or dst_rank == cur_rank: + return + + ptr_dtype = C_local_base.dtype.element_ty + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + elem_size = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32: + elem_size = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64: + elem_size = 8 + else: + elem_size = 4 + + for wave_id in range(NUM_WAVES): + transfer_start = tl.load(wave_transfer_offsets + wave_id) + transfer_count = tl.load(wave_transfer_counts + wave_id) + if transfer_count != 0: + wait_value = (flag_iteration + 1) * tl.load(wave_tile_counts + wave_id) + is_last_wave = wave_id == (NUM_WAVES - 1) + + if is_last_wave: + iris.wait_then_put_signal_rects( + C_local_base, + C_local_base, + cur_rank, + dst_rank, + heap_bases, + copy_engine_ctx, + flags + wave_id, + wait_value, + completion_signals + cur_rank, + 1, + transfer_row_offsets, + transfer_col_offsets, + transfer_width_bytes, + transfer_heights, + transfer_start, + transfer_count, + STRIDE_N * elem_size, + SRC_PITCH * elem_size, + DST_PITCH * elem_size, + MAX_RECTS_PER_WAVE, + ) + else: + iris.wait_then_put_rects( + C_local_base, + C_local_base, + cur_rank, + dst_rank, + heap_bases, + copy_engine_ctx, + flags + wave_id, + wait_value, + transfer_row_offsets, + transfer_col_offsets, + transfer_width_bytes, + transfer_heights, + transfer_start, + transfer_count, + STRIDE_N * elem_size, + SRC_PITCH * elem_size, + DST_PITCH * elem_size, + MAX_RECTS_PER_WAVE, + ) + + +@triton.jit() +def _wait_completion_signals_kernel( + completion_signals, + expected_value, + cur_rank: tl.constexpr, + world_size: tl.constexpr, +): + src_rank = tl.program_id(0) + if src_rank >= world_size or src_rank == cur_rank: + return + while tl.load(completion_signals + src_rank, cache_modifier=".cv", volatile=True) < expected_value: + pass + + +def _selector_wave_size(selector, device: torch.device) -> int: + wave_size = getattr(selector, "_ACTIVE_CU", None) + if wave_size is None or wave_size <= 0: + wave_size = torch.cuda.get_device_properties(device).multi_processor_count + return int(wave_size) + + +def _ensure_tritonblas_launch_wave_workspace( + shmem, + workspace: FusedWorkspace, + selector, + device: torch.device, + m_local: int, + n: int, + world_size: int, + elem_size: int, +): + num_tiles_m = triton.cdiv(m_local, selector.block_m) + num_tiles_n = triton.cdiv(n, selector.block_n) + total_tiles = num_tiles_m * num_tiles_n + wave_size = _selector_wave_size(selector, device) + num_xcds = max(1, int(getattr(selector, "num_sms", 1))) + plan_key = (num_tiles_m, num_tiles_n, selector.group_m, total_tiles, wave_size, num_xcds) + + if getattr(workspace, "launch_wave_plan_key", None) != plan_key: + plan = build_launch_wave_plan( + num_tiles_m=num_tiles_m, + num_tiles_n=num_tiles_n, + group_size_m=selector.group_m, + launch_grid=total_tiles, + wave_size=wave_size, + num_xcds=num_xcds, + ) + workspace.launch_wave_plan = plan + workspace.launch_wave_plan_key = plan_key + workspace.locks = shmem.zeros((plan.num_waves,), dtype=torch.int32) + transfers_by_wave = [[] for _ in range(plan.num_waves)] + for transfer in plan.transfers: + transfers_by_wave[transfer.wave_id].append(transfer) + + wave_transfer_offsets = [] + wave_transfer_counts = [] + transfer_row_offsets = [] + transfer_col_offsets = [] + transfer_width_bytes = [] + transfer_heights = [] + max_rects_per_wave = 0 + running_offset = 0 + + for wave_transfers in transfers_by_wave: + wave_transfer_offsets.append(running_offset) + wave_transfer_counts.append(len(wave_transfers)) + max_rects_per_wave = max(max_rects_per_wave, len(wave_transfers)) + for transfer in wave_transfers: + row_offset = transfer.m_tile_start * selector.block_m + col_offset = transfer.n_tile_start * selector.block_n + batch_height = min(transfer.m_tile_count * selector.block_m, m_local - row_offset) + batch_width = min(transfer.n_tile_count * selector.block_n, n - col_offset) + transfer_row_offsets.append(row_offset) + transfer_col_offsets.append(col_offset) + transfer_width_bytes.append(batch_width * elem_size) + transfer_heights.append(batch_height) + running_offset += len(wave_transfers) + + workspace.wave_transfer_offsets = torch.tensor(wave_transfer_offsets, device=device, dtype=torch.int32) + workspace.wave_transfer_counts = torch.tensor(wave_transfer_counts, device=device, dtype=torch.int32) + workspace.transfer_row_offsets = torch.tensor(transfer_row_offsets, device=device, dtype=torch.int32) + workspace.transfer_col_offsets = torch.tensor(transfer_col_offsets, device=device, dtype=torch.int32) + workspace.transfer_width_bytes = torch.tensor(transfer_width_bytes, device=device, dtype=torch.int32) + workspace.transfer_heights = torch.tensor(transfer_heights, device=device, dtype=torch.int32) + workspace.wave_tile_counts = torch.tensor(plan.wave_tile_counts, device=device, dtype=torch.int32) + workspace.num_tiles_m = num_tiles_m + workspace.num_tiles_n = num_tiles_n + workspace.num_waves = plan.num_waves + workspace.num_batches = plan.num_waves + workspace.num_transfers = len(plan.transfers) + workspace.max_rects_per_wave = max_rects_per_wave + if getattr(workspace, "completion_signals", None) is None or workspace.completion_signals.numel() != world_size: + workspace.completion_signals = shmem.zeros((world_size,), dtype=torch.int32) + return workspace.launch_wave_plan + + +def _auto_m_tiles_per_batch(selector, M_local: int, N: int) -> int: + """Auto-calculate optimal m_tiles_per_batch based on selector and shape.""" + num_tiles_m = (M_local + selector.block_m - 1) // selector.block_m + num_tiles_n = (N + selector.block_n - 1) // selector.block_n + active_cus = getattr(selector, "_ACTIVE_CU", None) + if active_cus is None or active_cus <= 0: + active_cus = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + tiles_per_group = max(1, selector.group_m * num_tiles_n) + groups_per_wave = max(1, int(active_cus) // tiles_per_group) + return max(1, min(num_tiles_m, groups_per_wave * selector.group_m)) + + +def matmul_all_gather_copy_engine_preamble( + shmem, + A: torch.Tensor, + B: torch.Tensor, + m_tiles_per_batch: Optional[int] = None, + selector=None, +) -> FusedWorkspace: + """Allocate workspace for matmul_all_gather_copy_engine including per-batch flags.""" + from tritonblas.matmul import _make_matmul_selector + + M_local, K = A.shape + K2, N = B.shape + world_size = shmem.get_num_ranks() + + assert K == K2, f"Inner dimensions must match: A has K={K}, B has K={K2}" + + M = M_local * world_size + + # Create selector for block size configuration if not provided + if selector is None: + selector = _make_matmul_selector( + M_local, + N, + K, + A.dtype, + B.dtype, + A.dtype, + A.device, + streamk=False, + ) + + # Auto-calculate optimal m_tiles_per_batch if not provided + if m_tiles_per_batch is None: + m_tiles_per_batch = _auto_m_tiles_per_batch(selector, M_local, N) + + # Calculate number of tiles based on selector + num_tiles_m = (M_local + selector.block_m - 1) // selector.block_m + num_tiles_n = (N + selector.block_n - 1) // selector.block_n + + num_tiles = num_tiles_m * num_tiles_n + + # Calculate number of batches + num_batches = (num_tiles_m + m_tiles_per_batch - 1) // m_tiles_per_batch + + ws = FusedWorkspace( + operation="matmul_all_gather_copy_engine", + shape=(M, N, K), + dtype=A.dtype, + world_size=world_size, + prepared=True, + ) + + # Allocate one readiness counter per M-batch. + ws.locks = shmem.zeros((num_batches,), dtype=torch.int32) + ws.completion_signals = shmem.zeros((world_size,), dtype=torch.int32) + + # Store metadata for later use + ws.selector = selector + ws.num_tiles_m = num_tiles_m + ws.num_tiles_n = num_tiles_n + ws.num_batches = num_batches + ws.m_tiles_per_batch = m_tiles_per_batch + + return ws + + +def matmul_all_gather_copy_engine( + shmem, + output_tensor: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + workspace: Optional[FusedWorkspace] = None, + flag_iteration: int = 0, + verbose: bool = False, +) -> FusedWorkspace: + """ + Fused matrix multiplication and all-gather using SDMA (copy engine) for scatter. + + Computes: output = all_gather(A @ B + bias) along M dimension + + Each rank has A of shape (M_local, K) where M_local = M / world_size. + The operation computes C_local = A @ B on each rank and uses SDMA hardware + to scatter the tiles to all ranks (all-gather pattern). + + Args: + shmem: Iris shmem context + output_tensor: Output tensor C of shape (M, N) where M = M_local * world_size + A: Input matrix A of shape (M_local, K) + B: Input matrix B of shape (K, N) + bias: Optional bias vector (M_local,) + async_op: If False, performs barrier at end + workspace: Optional pre-allocated workspace + flag_iteration: Launch generation for cumulative batch counters. + Batch readiness counters are not reset each iteration; + the poster waits for the generation-adjusted target. + verbose: If True, print poster/main/quiet timing breakdown in sync mode + + Returns: + FusedWorkspace object + """ + M_local, K = A.shape + K2, N = B.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + assert K == K2, f"Inner dimensions must match: A has K={K}, B has K={K2}" + + M = M_local * world_size + assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" + + timing_events = None + cpu_timing = None + if verbose: + current_stream = torch.cuda.current_stream(device=A.device) + timing_events = { + "stream": current_stream, + "poster_start": torch.cuda.Event(enable_timing=True), + "poster_end": torch.cuda.Event(enable_timing=True), + "main_start": torch.cuda.Event(enable_timing=True), + "main_end": torch.cuda.Event(enable_timing=True), + "quiet_start": torch.cuda.Event(enable_timing=True), + "quiet_end": torch.cuda.Event(enable_timing=True), + } + cpu_timing = { + "total_start": time.perf_counter(), + "poster_launch_ms": 0.0, + "main_launch_ms": 0.0, + "quiet_cpu_ms": 0.0, + "sync_wait_ms": 0.0, + } + + # Allocate workspace if not provided + if workspace is None: + workspace = matmul_all_gather_copy_engine_preamble(shmem, A, B) + + stride_cm, stride_cn = output_tensor.stride() + + selector = workspace.selector + + launch_wave_plan = _ensure_tritonblas_launch_wave_workspace( + shmem, + workspace, + selector, + A.device, + M_local, + N, + world_size, + output_tensor.element_size(), + ) + + # Get metadata from workspace after any schedule planning. + num_batches = workspace.num_batches + + if timing_events is not None: + timing_events["poster_start"].record(timing_events["stream"]) + poster_launch_start = time.perf_counter() if cpu_timing is not None else None + poster_grid = (world_size,) + c_local_base = output_tensor[rank * M_local :, :] + _launch_wave_wait_poster_kernel[poster_grid]( + c_local_base, + workspace.locks, + workspace.completion_signals, + flag_iteration, + workspace.wave_tile_counts, + workspace.wave_transfer_offsets, + workspace.wave_transfer_counts, + workspace.transfer_row_offsets, + workspace.transfer_col_offsets, + workspace.transfer_width_bytes, + workspace.transfer_heights, + shmem.get_heap_bases(), + shmem.get_copy_engine_ctx(), + rank, + world_size, + workspace.num_waves, + workspace.max_rects_per_wave, + stride_cm, + stride_cm, + stride_cn, + ) + if cpu_timing is not None: + cpu_timing["poster_launch_ms"] = (time.perf_counter() - poster_launch_start) * 1000.0 + if timing_events is not None: + timing_events["poster_end"].record(timing_events["stream"]) + + # Launch GEMM after poster submission so SDMA can wait autonomously. + if timing_events is not None: + timing_events["main_start"].record(timing_events["stream"]) + main_launch_start = time.perf_counter() if cpu_timing is not None else None + if bias is not None: + import warnings + + warnings.warn( + "Bias is not yet supported in the tritonBLAS SignalView path for " + "matmul_all_gather_copy_engine. Ignoring bias for this launch." + ) + + counter_config = create_counter_config( + workspace.locks, + map_type="launch_wave", + block_group_m=launch_wave_plan.wave_size, + ) + c_local_view = output_tensor[rank * M_local : (rank + 1) * M_local, :] + persistent_matmul_lt( + A, + B, + c_local_view, + selector, + bias=None, + work_stealing=False, + counter_config=counter_config, + ) + if cpu_timing is not None: + cpu_timing["main_launch_ms"] = (time.perf_counter() - main_launch_start) * 1000.0 + if timing_events is not None: + timing_events["main_end"].record(timing_events["stream"]) + + if not async_op: + wait_cpu_start = time.perf_counter() if cpu_timing is not None else None + if timing_events is not None: + timing_events["quiet_start"].record(timing_events["stream"]) + _wait_completion_signals_kernel[(world_size,)]( + workspace.completion_signals, + flag_iteration + 1, + rank, + world_size, + ) + if timing_events is not None: + timing_events["quiet_end"].record(timing_events["stream"]) + sync_wait_start = time.perf_counter() if cpu_timing is not None else None + shmem.barrier() + if cpu_timing is not None: + cpu_timing["sync_wait_ms"] = (time.perf_counter() - sync_wait_start) * 1000.0 + cpu_timing["quiet_cpu_ms"] = (time.perf_counter() - wait_cpu_start) * 1000.0 + + if verbose and rank == 0: + poster_ms = timing_events["poster_start"].elapsed_time(timing_events["poster_end"]) + main_ms = timing_events["main_start"].elapsed_time(timing_events["main_end"]) + quiet_ms = timing_events["quiet_start"].elapsed_time(timing_events["quiet_end"]) + gpu_total_ms = poster_ms + main_ms + quiet_ms + cpu_launch_total_ms = cpu_timing["poster_launch_ms"] + cpu_timing["main_launch_ms"] + cpu_total_ms = (time.perf_counter() - cpu_timing["total_start"]) * 1000.0 + tile_transfer_count = (world_size - 1) * getattr(workspace, "num_transfers", num_batches) + shmem.info( + f"[Rank {rank}] Device copy-engine GPU timing. " + f"Poster: {poster_ms:.2f}ms, Main: {main_ms:.2f}ms, Wait: {quiet_ms:.2f}ms, " + f"GPU total: {gpu_total_ms:.2f}ms" + ) + shmem.info( + f"[Rank {rank}] Device copy-engine CPU timing. " + f"Poster launch: {cpu_timing['poster_launch_ms']:.2f}ms, " + f"Main launch: {cpu_timing['main_launch_ms']:.2f}ms, " + f"Launch total: {cpu_launch_total_ms:.2f}ms, " + f"Wait: {cpu_timing['quiet_cpu_ms']:.2f}ms, " + f"Barrier wait: {cpu_timing['sync_wait_ms']:.2f}ms, " + f"CPU total: {cpu_total_ms:.2f}ms, " + f"transfers={tile_transfer_count}" + ) + + return workspace diff --git a/iris/ops/matmul_all_gather_host_copy_engine.py b/iris/ops/matmul_all_gather_host_copy_engine.py new file mode 100644 index 000000000..1d6786094 --- /dev/null +++ b/iris/ops/matmul_all_gather_host_copy_engine.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Fused GEMM + All-Gather operation using host-initiated SDMA with POLL packets. + +Each rank has a row-sharded input A_local (M_local x K) and computes C_local = A_local @ B. +The host pre-queues SDMA POLL+COPY packets for scatter, then the device kernel just stores +tiles to local HBM and sets flags to trigger the pre-queued transfers. + +This is more efficient than device-initiated SDMA because: +- SDMA queue setup happens once on host (not per-tile) +- Device kernel is lightweight (store + set flag) +- SDMA hardware automatically performs scatter when flags are set + +This implementation supports two backends: +- Custom Triton kernel (legacy, controlled by use_tritonblas=False) +- tritonBLAS with SignalView (default, use_tritonblas=True) +""" + +from typing import Optional +import torch +import triton +import triton.language as tl + +from .workspace import FusedWorkspace + +# Import tritonBLAS +from tritonblas.matmul import persistent_matmul_lt +from tritonblas.matmul import create_counter_config +from tritonblas.matmul import _make_matmul_selector +from .tritonblas_launch_wave_schedule import build_launch_wave_plan + +# Import Tile class from anvil module +try: + import anvil + + Tile = anvil.Tile +except (ImportError, AttributeError): + Tile = None # Will raise error later if needed + + +@triton.jit() +def wait_cnt(): + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + + +# Event IDs (must match iris.tracing.events.TraceEvent) +_WG_GEMM = 15 + + +@triton.jit() +def _wait_completion_signals_kernel( + completion_signals, + expected_value, + cur_rank: tl.constexpr, + world_size: tl.constexpr, +): + src_rank = tl.program_id(0) + if src_rank >= world_size or src_rank == cur_rank: + return + # while tl.atomic_add(completion_signals + src_rank, 0, sem="acquire", scope="sys") < expected_value: + while tl.load(completion_signals + src_rank, cache_modifier=".cv", volatile=True) < expected_value: + pass + + +def _extract_wg_trace(shmem, grid_size, num_tiles, sdma_timestamps=None, **metadata): + """Extract per-tile trace data from DeviceTracing events. + + For host-initiated SDMA: + - Each tile generates trace events (not per workgroup) + - SDMA timestamps captured by host via timestamp packets + """ + import numpy as np + + bufs = shmem.tracing.trace_buffers + n = min(shmem.tracing.trace_counter.item(), shmem.tracing.max_events) + + event_ids = bufs["event_id"][:n].cpu().numpy() + pid_ms = bufs["pid_m"][:n].cpu().numpy() # tile_id (not workgroup pid) + timestamps = bufs["timestamp"][:n].cpu().numpy().astype(np.int64) + end_ts = bufs["duration_cycles"][:n].cpu().numpy().astype(np.int64) + xcc_ids = bufs["xcc_id"][:n].cpu().numpy().astype(np.int32) + + # Per-tile traces + starts = torch.zeros(num_tiles, dtype=torch.int64) + ends = torch.zeros(num_tiles, dtype=torch.int64) + waits = torch.zeros(num_tiles, dtype=torch.int64) # Not used but needed for plot + xcds = torch.zeros(num_tiles, dtype=torch.int32) + + for i in range(n): + eid = int(event_ids[i]) + tile_id = int(pid_ms[i]) + + if eid == _WG_GEMM and tile_id < num_tiles: + starts[tile_id] = int(timestamps[i]) + ends[tile_id] = int(end_ts[i]) + xcds[tile_id] = int(xcc_ids[i]) + + result = {"start": starts, "end": ends, "wait": waits, "xcd": xcds, "grid_size": num_tiles, **metadata} + + # Add SDMA timestamps if available (world_size x 2: start/end per rank) + if sdma_timestamps is not None: + result["sdma_timestamps"] = sdma_timestamps.cpu() + + return result + + +def _auto_m_tiles_per_batch(selector, M_local: int, N: int) -> int: + """Auto-calculate optimal m_tiles_per_batch based on selector and shape.""" + num_tiles_m = (M_local + selector.block_m - 1) // selector.block_m + num_tiles_n = (N + selector.block_n - 1) // selector.block_n + active_cus = getattr(selector, "_ACTIVE_CU", None) + if active_cus is None or active_cus <= 0: + import torch + + active_cus = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + tiles_per_group = max(1, selector.group_m * num_tiles_n) + groups_per_wave = max(1, int(active_cus) // tiles_per_group) + return max(1, min(num_tiles_m, groups_per_wave * selector.group_m)) + + +def matmul_all_gather_host_copy_engine_preamble( + shmem, + A: torch.Tensor, + B: torch.Tensor, + m_tiles_per_batch: Optional[int] = None, + trace: bool = False, + selector=None, +) -> FusedWorkspace: + """Allocate workspace for matmul_all_gather_host_copy_engine.""" + M_local, K = A.shape + K2, N = B.shape + world_size = shmem.get_num_ranks() + + assert K == K2, f"Inner dimensions must match: A has K={K}, B has K={K2}" + + M = M_local * world_size + + ws = FusedWorkspace( + operation="matmul_all_gather_host_copy_engine", + shape=(M, N, K), + dtype=A.dtype, + world_size=world_size, + prepared=True, + ) + + if selector is None: + selector = _make_matmul_selector( + M_local, + N, + K, + A.dtype, + B.dtype, + A.dtype, + A.device, + streamk=False, + ) + + # Auto-calculate optimal m_tiles_per_batch if not provided + if m_tiles_per_batch is None: + m_tiles_per_batch = _auto_m_tiles_per_batch(selector, M_local, N) + num_tiles_m = triton.cdiv(M_local, selector.block_m) + num_tiles_n = triton.cdiv(N, selector.block_n) + launch_wave_plan = build_launch_wave_plan( + num_tiles_m=num_tiles_m, + num_tiles_n=num_tiles_n, + group_size_m=selector.group_m, + launch_grid=num_tiles_m * num_tiles_n, + wave_size=selector._ACTIVE_CU, + num_xcds=selector.num_sms, + ) + ws.selector = selector + ws.launch_wave_plan = launch_wave_plan + ws.locks = shmem.zeros((launch_wave_plan.num_waves,), dtype=torch.int32) + ws.num_tiles_m = num_tiles_m + ws.num_tiles_n = num_tiles_n + ws.num_batches = launch_wave_plan.num_waves + ws.num_waves = launch_wave_plan.num_waves + ws.m_tiles_per_batch = m_tiles_per_batch + + ws.completion_signals = shmem.zeros((world_size,), dtype=torch.int32) + + return ws + + +def matmul_all_gather_host_copy_engine( + shmem, + output_tensor: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + workspace: Optional[FusedWorkspace] = None, + flag_iteration: int = 0, + trace: bool = False, + verbose: bool = False, +) -> FusedWorkspace: + """ + Fused matrix multiplication and all-gather using host-initiated SDMA with POLL packets. + + Computes: output = all_gather(A @ B + bias) along M dimension + + Each rank has A of shape (M_local, K) where M_local = M / world_size. + The host pre-queues SDMA POLL+COPY packets for all tiles and ranks. + The device kernel computes tiles, stores to local HBM, then sets flags. + SDMA hardware automatically performs scatter when flags are set. + + Args: + shmem: Iris shmem context + output_tensor: Output tensor C of shape (M, N) where M = M_local * world_size + A: Input matrix A of shape (M_local, K) + B: Input matrix B of shape (K, N) + bias: Optional bias vector (M_local,) + async_op: If False, performs barrier at end + workspace: Optional pre-allocated workspace + + Returns: + FusedWorkspace object + """ + M_local, K = A.shape + K2, N = B.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + assert K == K2, f"Inner dimensions must match: A has K={K}, B has K={K2}" + + M = M_local * world_size + assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" + + # Allocate workspace if not provided + if workspace is None: + workspace = matmul_all_gather_host_copy_engine_preamble(shmem, A, B, trace=trace) + + stride_cm, stride_cn = output_tensor.stride() + device = A.device + + selector = workspace.selector + launch_wave_plan = getattr(workspace, "launch_wave_plan", None) + num_tiles_m = workspace.num_tiles_m + num_tiles_n = workspace.num_tiles_n + num_batches = workspace.num_batches + if launch_wave_plan is None: + raise ValueError("workspace.launch_wave_plan must be initialized in preamble") + + # ═══════════════════════════════════════════════════════════════════════ + # Device Phase: Compute GEMM + store + set flags + # ═══════════════════════════════════════════════════════════════════════ + # tritonBLAS path with SignalView + selector = workspace.selector + + # Create a view of output_tensor at this rank's position + C_local_view = output_tensor[rank * M_local : (rank + 1) * M_local, :] + + counter_config = create_counter_config( + workspace.locks, + map_type="launch_wave", + block_group_m=launch_wave_plan.wave_size, + ) + + # Use work-stealing if enabled + tritonblas_config = None + + # Warn about bias + if bias is not None: + import warnings + + warnings.warn("Bias is not yet supported in tritonBLAS integration. Consider adding bias manually after GEMM.") + + # Launch tritonBLAS GEMM with SignalView + persistent_matmul_lt( + A, + B, + C_local_view, + selector, + config=tritonblas_config, + bias=None, + work_stealing=False, + counter_config=counter_config, + ) + + # ═══════════════════════════════════════════════════════════════════════ + # Host Phase: Enqueue SDMA POLL+COPY packets for all tiles + # (While kernel is running in parallel on device) + # ═══════════════════════════════════════════════════════════════════════ + import time + + element_size = output_tensor.element_size() + anvil_lib = shmem.copy_engines + + if verbose and rank == 0: + shmem.info( + f"[Rank {rank}] Starting SDMA loop (launch-wave transfers)... " + f"num_m_tiles={num_tiles_m}, num_tiles_n={num_tiles_n}, " + f"wave_size={launch_wave_plan.wave_size}" + ) + shmem.info( + f"[Rank {rank}] Will transfer in {launch_wave_plan.num_waves} waves " + f"across {len(launch_wave_plan.transfers)} rects" + ) + + sdma_start_time = time.perf_counter() + tile_transfer_count = 0 + + # Get block size from workspace selector + block_size_m = workspace.selector.block_m + block_size_n = workspace.selector.block_n + + signal_ptr_local = workspace.completion_signals.data_ptr() + rank * workspace.completion_signals.element_size() + transfers_by_wave = [[] for _ in range(launch_wave_plan.num_waves)] + for transfer in launch_wave_plan.transfers: + transfers_by_wave[transfer.wave_id].append(transfer) + + for wave_id, wave_transfers in enumerate(transfers_by_wave): + if not wave_transfers: + continue + + expected_flag_value = (flag_iteration + 1) * launch_wave_plan.wave_tile_counts[wave_id] + wait_flag_ptr = workspace.locks.data_ptr() + wave_id * workspace.locks.element_size() + is_last_wave = wave_id == (launch_wave_plan.num_waves - 1) + + tiles = [] + dst_ptrs_local = [] + dst_strides = [] + + for transfer in wave_transfers: + m_start = transfer.m_tile_start * block_size_m + n_start = transfer.n_tile_start * block_size_n + batch_height = min(transfer.m_tile_count * block_size_m, M_local - m_start) + batch_width = min(transfer.n_tile_count * block_size_n, N - n_start) + + tile_obj = Tile() + tile_obj.pid_m = 0 + tile_obj.pid_n = 0 + tile_obj.block_m = batch_height + tile_obj.block_n = batch_width + tile_obj.elem_size = element_size + tile_obj.src_stride = stride_cm * element_size + + src_offset = (m_start + rank * M_local) * stride_cm + n_start * stride_cn + tile_obj.data = output_tensor.data_ptr() + src_offset * element_size + dst_offset_local = (m_start + rank * M_local) * stride_cm + n_start * stride_cn + + tiles.append(tile_obj) + dst_ptrs_local.append(output_tensor.data_ptr() + dst_offset_local * element_size) + dst_strides.append(stride_cm * element_size) + + for remote_rank in range(world_size): + if remote_rank == rank: + continue + + dst_ptrs_remote = [shmem.translate(dst_ptr_local, rank, remote_rank) for dst_ptr_local in dst_ptrs_local] + signal_ptr_remote = None + if is_last_wave: + signal_ptr_remote = shmem.translate(signal_ptr_local, rank, remote_rank) + + shmem.put_tiles( + tiles, + dst_rank=remote_rank, + dst_ptrs=dst_ptrs_remote, + dst_strides=dst_strides, + wait_flag=wait_flag_ptr, + wait_value=expected_flag_value, + signal_flag=signal_ptr_remote, + signal_value=1, + async_op=True, + channel=0, + ) + tile_transfer_count += len(wave_transfers) + + if verbose and wave_id < 2 and remote_rank == (rank + 1) % world_size: + shmem.info( + f"[Rank {rank}→{remote_rank}] Queued wave={wave_id} " + f"transfers={len(wave_transfers)} tiles={launch_wave_plan.wave_tile_counts[wave_id]}" + ) + + sdma_end_post_time = time.perf_counter() + + if not async_op: + _wait_completion_signals_kernel[(world_size,)]( + workspace.completion_signals, + flag_iteration + 1, + rank, + world_size, + ) + sdma_end_time = time.perf_counter() + shmem.barrier() + if verbose and rank == 0: + post_ms = (sdma_end_post_time - sdma_start_time) * 1000.0 + quiet_ms = (sdma_end_time - sdma_end_post_time) * 1000.0 + total_ms = (sdma_end_time - sdma_start_time) * 1000.0 + shmem.info( + f"[Rank {rank}] SDMA complete. " + f"Post: {post_ms:.2f}ms, Wait: {quiet_ms:.2f}ms, Total: {total_ms:.2f}ms, " + f"transfers={tile_transfer_count}" + ) + + return workspace diff --git a/iris/ops/tritonblas_launch_wave_schedule.py b/iris/ops/tritonblas_launch_wave_schedule.py new file mode 100644 index 000000000..e52a68438 --- /dev/null +++ b/iris/ops/tritonblas_launch_wave_schedule.py @@ -0,0 +1,301 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Helpers for describing the tritonBLAS launch-wave tile schedule. + +The current tritonBLAS GEMM path launches the full logical grid, applies the +chunked XCD remap to each launched program id, and each launched program +produces one tile. This helper mirrors that launch order and coalesces the +tiles issued by each hardware wave of ``wave_size`` launched programs into one +or more rectangular transfers. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def chiplet_transform_chunked(pid: int, num_workgroups: int, num_xcds: int, chunk_size: int) -> int: + if num_xcds <= 1 or chunk_size <= 0: + return pid + if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size): + return pid + + local_pid = pid // num_xcds + chunk_idx = local_pid // chunk_size + pos_in_chunk = local_pid % chunk_size + xcd = pid % num_xcds + return chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk + + +def default_chunk_size(total_tiles: int, group_size_m: int, num_xcds: int) -> int: + chunk_size = group_size_m * group_size_m + if num_xcds > 0: + chunk_size = min(chunk_size, max(1, total_tiles // num_xcds)) + return max(1, chunk_size) + + +def grouped_tile_coords(tile_id: int, num_tiles_m: int, num_tiles_n: int, group_size_m: int) -> tuple[int, int, int]: + num_pid_in_group = group_size_m * num_tiles_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * group_size_m + actual_group_size_m = min(num_tiles_m - first_pid_m, group_size_m) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % actual_group_size_m) + pid_n = (tile_id % num_pid_in_group) // actual_group_size_m + return pid_m, pid_n, group_id + + +@dataclass(frozen=True) +class LaunchWaveTransfer: + wave_id: int + group_id: int + m_tile_start: int + n_tile_start: int + m_tile_count: int + n_tile_count: int + tile_count: int + + +@dataclass(frozen=True) +class LaunchWavePlan: + num_tiles_m: int + num_tiles_n: int + total_tiles: int + launch_grid: int + wave_size: int + num_xcds: int + chunk_size: int + num_waves: int + wave_tile_counts: tuple[int, ...] + transfers: tuple[LaunchWaveTransfer, ...] + + +def _coalesce_group_columns( + wave_id: int, + group_id: int, + first_pid_m: int, + columns: dict[int, set[int]], +) -> list[LaunchWaveTransfer]: + transfers: list[LaunchWaveTransfer] = [] + merged: list[tuple[int, int, int, int]] = [] + + for pid_n in sorted(columns): + local_ms = sorted(columns[pid_n]) + if not local_ms: + continue + + seg_start = local_ms[0] + seg_prev = local_ms[0] + for local_m in local_ms[1:]: + if local_m == seg_prev + 1: + seg_prev = local_m + continue + merged.append((pid_n, seg_start, seg_prev - seg_start + 1, 1)) + seg_start = local_m + seg_prev = local_m + merged.append((pid_n, seg_start, seg_prev - seg_start + 1, 1)) + + for pid_n, local_m_start, m_tile_count, n_tile_count in merged: + if transfers: + prev = transfers[-1] + if ( + prev.group_id == group_id + and prev.n_tile_start + prev.n_tile_count == pid_n + and prev.m_tile_start == first_pid_m + local_m_start + and prev.m_tile_count == m_tile_count + ): + transfers[-1] = LaunchWaveTransfer( + wave_id=prev.wave_id, + group_id=prev.group_id, + m_tile_start=prev.m_tile_start, + n_tile_start=prev.n_tile_start, + m_tile_count=prev.m_tile_count, + n_tile_count=prev.n_tile_count + n_tile_count, + tile_count=prev.tile_count + m_tile_count * n_tile_count, + ) + continue + + transfers.append( + LaunchWaveTransfer( + wave_id=wave_id, + group_id=group_id, + m_tile_start=first_pid_m + local_m_start, + n_tile_start=pid_n, + m_tile_count=m_tile_count, + n_tile_count=n_tile_count, + tile_count=m_tile_count * n_tile_count, + ) + ) + + return transfers + + +def _coalesce_group_rows( + wave_id: int, + group_id: int, + first_pid_m: int, + columns: dict[int, set[int]], +) -> list[LaunchWaveTransfer]: + rows: dict[int, set[int]] = {} + for pid_n, local_ms in columns.items(): + for local_m in local_ms: + rows.setdefault(local_m, set()).add(pid_n) + + transfers: list[LaunchWaveTransfer] = [] + merged: list[tuple[int, int, int, int]] = [] + + for local_m in sorted(rows): + ns = sorted(rows[local_m]) + if not ns: + continue + + seg_start = ns[0] + seg_prev = ns[0] + for pid_n in ns[1:]: + if pid_n == seg_prev + 1: + seg_prev = pid_n + continue + merged.append((local_m, seg_start, 1, seg_prev - seg_start + 1)) + seg_start = pid_n + seg_prev = pid_n + merged.append((local_m, seg_start, 1, seg_prev - seg_start + 1)) + + for local_m_start, n_tile_start, m_tile_count, n_tile_count in merged: + if transfers: + prev = transfers[-1] + if ( + prev.group_id == group_id + and prev.m_tile_start + prev.m_tile_count == first_pid_m + local_m_start + and prev.n_tile_start == n_tile_start + and prev.n_tile_count == n_tile_count + ): + transfers[-1] = LaunchWaveTransfer( + wave_id=prev.wave_id, + group_id=prev.group_id, + m_tile_start=prev.m_tile_start, + n_tile_start=prev.n_tile_start, + m_tile_count=prev.m_tile_count + m_tile_count, + n_tile_count=prev.n_tile_count, + tile_count=prev.tile_count + m_tile_count * n_tile_count, + ) + continue + + transfers.append( + LaunchWaveTransfer( + wave_id=wave_id, + group_id=group_id, + m_tile_start=first_pid_m + local_m_start, + n_tile_start=n_tile_start, + m_tile_count=m_tile_count, + n_tile_count=n_tile_count, + tile_count=m_tile_count * n_tile_count, + ) + ) + + return transfers + + +def build_launch_wave_plan( + num_tiles_m: int, + num_tiles_n: int, + group_size_m: int, + launch_grid: int, + wave_size: int, + num_xcds: int, + chunk_size: int | None = None, + merge_order: str = "column", +) -> LaunchWavePlan: + total_tiles = num_tiles_m * num_tiles_n + if launch_grid <= 0: + raise ValueError("launch_grid must be positive") + if wave_size <= 0: + raise ValueError("wave_size must be positive") + if total_tiles <= 0: + raise ValueError("tile grid must be non-empty") + if merge_order not in {"column", "row"}: + raise ValueError("merge_order must be 'column' or 'row'") + + if chunk_size is None: + chunk_size = default_chunk_size(total_tiles, group_size_m, num_xcds) + + launch_grid = max(launch_grid, total_tiles) + num_waves = ceil_div(launch_grid, wave_size) + wave_tile_counts: list[int] = [] + transfers: list[LaunchWaveTransfer] = [] + + launched_tile_ids = [ + chiplet_transform_chunked(pid, launch_grid, num_xcds, chunk_size) for pid in range(launch_grid) + ] + + for wave_id in range(num_waves): + groups: dict[int, tuple[int, dict[int, set[int]]]] = {} + tiles_in_wave = 0 + + wave_program_start = wave_id * wave_size + wave_program_end = min(wave_program_start + wave_size, launch_grid) + for launch_pid in range(wave_program_start, wave_program_end): + tile_id = launched_tile_ids[launch_pid] + if tile_id >= total_tiles: + continue + + pid_m, pid_n, group_id = grouped_tile_coords(tile_id, num_tiles_m, num_tiles_n, group_size_m) + first_pid_m = group_id * group_size_m + columns = groups.setdefault(group_id, (first_pid_m, {}))[1] + local_m = pid_m - first_pid_m + columns.setdefault(pid_n, set()).add(local_m) + tiles_in_wave += 1 + + wave_tile_counts.append(tiles_in_wave) + if tiles_in_wave == 0: + continue + + for group_id in sorted(groups): + first_pid_m, columns = groups[group_id] + if merge_order == "row": + transfers.extend(_coalesce_group_rows(wave_id, group_id, first_pid_m, columns)) + else: + transfers.extend(_coalesce_group_columns(wave_id, group_id, first_pid_m, columns)) + + return LaunchWavePlan( + num_tiles_m=num_tiles_m, + num_tiles_n=num_tiles_n, + total_tiles=total_tiles, + launch_grid=launch_grid, + wave_size=wave_size, + num_xcds=num_xcds, + chunk_size=chunk_size, + num_waves=num_waves, + wave_tile_counts=tuple(wave_tile_counts), + transfers=tuple(transfers), + ) + + +def build_launch_wave_plan_for_shape( + m: int, + n: int, + block_m: int, + block_n: int, + group_size_m: int, + launch_grid: int, + wave_size: int, + num_xcds: int, + chunk_size: int | None = None, + merge_order: str = "column", +) -> LaunchWavePlan: + num_tiles_m = ceil_div(m, block_m) + num_tiles_n = ceil_div(n, block_n) + return build_launch_wave_plan( + num_tiles_m=num_tiles_m, + num_tiles_n=num_tiles_n, + group_size_m=group_size_m, + launch_grid=launch_grid, + wave_size=wave_size, + num_xcds=num_xcds, + chunk_size=chunk_size, + merge_order=merge_order, + ) diff --git a/iris/ops/workspace.py b/iris/ops/workspace.py index a9c7cb616..855e29862 100644 --- a/iris/ops/workspace.py +++ b/iris/ops/workspace.py @@ -42,6 +42,20 @@ class FusedWorkspace: aux_buffer: Optional[torch.Tensor] = None # Generic buffer for intermediate results locks: Optional[torch.Tensor] = None # Synchronization primitives + # Push variant workspace + a_inbox: Optional[torch.Tensor] = None # (world_size, M, K_local) inbox buffer + signal_flags: Optional[torch.Tensor] = None # (world_size, world_size, m_tiles, k_tiles) + completion_signals: Optional[torch.Tensor] = None # (world_size,) sender-completion slots + + # Launch-wave copy-engine metadata + wave_tile_counts: Optional[torch.Tensor] = None + wave_transfer_offsets: Optional[torch.Tensor] = None + wave_transfer_counts: Optional[torch.Tensor] = None + transfer_row_offsets: Optional[torch.Tensor] = None + transfer_col_offsets: Optional[torch.Tensor] = None + transfer_width_bytes: Optional[torch.Tensor] = None + transfer_heights: Optional[torch.Tensor] = None + prepared: bool = False def matches( @@ -82,4 +96,14 @@ def clear(self): """Free all allocated buffers.""" self.aux_buffer = None self.locks = None + self.a_inbox = None + self.signal_flags = None + self.completion_signals = None + self.wave_tile_counts = None + self.wave_transfer_offsets = None + self.wave_transfer_counts = None + self.transfer_row_offsets = None + self.transfer_col_offsets = None + self.transfer_width_bytes = None + self.transfer_heights = None self.prepared = False diff --git a/iris/tracing/core.py b/iris/tracing/core.py index 317fc0bbf..57c007625 100644 --- a/iris/tracing/core.py +++ b/iris/tracing/core.py @@ -208,7 +208,7 @@ def export(self, filename="trace.json", merge=False): "traceEvents": trace_events, "displayTimeUnit": "ns", "metadata": { - "schema_version": "1.1", + "schema_version": "1.2", "num_events": num_events, "rank": self.iris.cur_rank, "world_size": self.iris.num_ranks, @@ -285,7 +285,7 @@ def export(self, filename="trace.json", merge=False): "traceEvents": all_events, "displayTimeUnit": "ns", "metadata": { - "schema_version": "1.1", + "schema_version": "1.2", "total_events": len(all_events), "max_events": self.max_events, "time_unit": "cycles (s_memrealtime @ 100MHz)", diff --git a/iris/tracing/events.py b/iris/tracing/events.py index 4838c09d6..9ff7ef98f 100644 --- a/iris/tracing/events.py +++ b/iris/tracing/events.py @@ -2,6 +2,12 @@ Trace event type IDs and Triton-side enumeration. EVENT_NAMES and TraceEvent must stay in sync: same IDs for the same operations. + +Event ID ranges: + 0–1023 iris ops (data movement, atomics) + 1024–2047 user data movement (fetch/prefetch) + 2048–3071 user compute (compute, reduce) + 3072–4095 synchronization (wait, barrier) """ import triton @@ -12,6 +18,7 @@ # Event type IDs to names mapping (used for export / display). # Keep in sync with TraceEvent below. EVENT_NAMES = { + # iris ops (0–1023) 0: "load", 1: "store", 2: "get", @@ -26,45 +33,58 @@ 11: "atomic_or", 12: "atomic_min", 13: "atomic_max", + # User data movement (1024–2047) + 1024: "fetch", + # User compute (2048–3071) + 2048: "compute", + 2049: "reduce", + # Synchronization (3072–4095) + 3072: "wait", + 3073: "barrier", } @aggregate class TraceEvent: """ - Trace event type enumeration for iris remote memory operations. + Trace event type enumeration for iris operations and kernel instrumentation. + + Event ID ranges: + 0–1023 iris ops (data movement, atomics) + 1024–2047 user data movement (fetch/prefetch) + 2048–3071 user compute (compute, reduce) + 3072–4095 synchronization (wait, barrier) Usage: >>> ctx.record_event(event_id=TraceEvent().put, target_rank=1, address=ptr) Available event types: - Data Movement: + iris ops (0–1023): - load (0): Remote load operation - store (1): Remote store operation - get (2): Remote read (pull from remote to local) - put (3): Remote write (push from local to remote) - copy (4): Peer-to-peer copy between ranks + - atomic_add (5) .. atomic_max (13): Atomic operations + + User data movement (1024–2047): + - fetch (1024): Prefetching / staging data - Atomic Operations: - - atomic_add (5): Atomic addition - - atomic_sub (6): Atomic subtraction - - atomic_cas (7): Atomic compare-and-swap - - atomic_xchg (8): Atomic exchange - - atomic_xor (9): Atomic XOR - - atomic_and (10): Atomic AND - - atomic_or (11): Atomic OR - - atomic_min (12): Atomic minimum - - atomic_max (13): Atomic maximum + User compute (2048–3071): + - compute (2048): Kernel compute phase (GEMM, FFT, etc.) + - reduce (2049): Reduction operation + + Synchronization (3072–4095): + - wait (3072): Stalled on a dependency + - barrier (3073): Synchronization point """ - # Data movement operations + # iris ops (0–1023) load: tl.constexpr store: tl.constexpr get: tl.constexpr put: tl.constexpr copy: tl.constexpr - - # Atomic operations atomic_add: tl.constexpr atomic_sub: tl.constexpr atomic_cas: tl.constexpr @@ -75,16 +95,30 @@ class TraceEvent: atomic_min: tl.constexpr atomic_max: tl.constexpr + # Workgroup-level profiling events + wg_fetch: tl.constexpr + wg_gemm: tl.constexpr + wg_gemm_wait: tl.constexpr + wg_sdma: tl.constexpr + # User data movement (1024–2047) + fetch: tl.constexpr + + # User compute (2048–3071) + compute: tl.constexpr + reduce: tl.constexpr + + # Synchronization (3072–4095) + wait: tl.constexpr + barrier: tl.constexpr + @triton.constexpr_function def __init__(self): - # Data movement + # iris ops (0–1023) self.load = tl.constexpr(0) self.store = tl.constexpr(1) self.get = tl.constexpr(2) self.put = tl.constexpr(3) self.copy = tl.constexpr(4) - - # Atomics self.atomic_add = tl.constexpr(5) self.atomic_sub = tl.constexpr(6) self.atomic_cas = tl.constexpr(7) @@ -94,3 +128,19 @@ def __init__(self): self.atomic_or = tl.constexpr(11) self.atomic_min = tl.constexpr(12) self.atomic_max = tl.constexpr(13) + + # Workgroup-level profiling + self.wg_fetch = tl.constexpr(14) + self.wg_gemm = tl.constexpr(15) + self.wg_gemm_wait = tl.constexpr(16) + self.wg_sdma = tl.constexpr(17) + # User data movement (1024–2047) + self.fetch = tl.constexpr(1024) + + # User compute (2048–3071) + self.compute = tl.constexpr(2048) + self.reduce = tl.constexpr(2049) + + # Synchronization (3072–4095) + self.wait = tl.constexpr(3072) + self.barrier = tl.constexpr(3073) diff --git a/iris/x/core.py b/iris/x/core.py index fee50918e..198c46e1a 100644 --- a/iris/x/core.py +++ b/iris/x/core.py @@ -78,7 +78,7 @@ def tile_ptr(ptr, M, N, stride_m, stride_n, pid_m, pid_n, BLOCK_SIZE_M: tl.const iris.load/iris.store for remote access. """ rm, rn, mask = tile_layout(pid_m, pid_n, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N) - offset = rm[:, None] * stride_m + rn[None, :] * stride_n + offset = rm.to(tl.int64)[:, None] * stride_m.to(tl.int64) + rn.to(tl.int64)[None, :] * stride_n.to(tl.int64) tile_ptr = ptr + offset tile_ptr = tl.multiple_of(tile_ptr, (BLOCK_SIZE_M, BLOCK_SIZE_N)) return tile_ptr, mask @@ -99,7 +99,9 @@ def offset_ptr(ptr, stride_m, stride_n, offset_m, offset_n): Returns: New pointer with offset applied """ - return ptr + offset_m * stride_m + offset_n * stride_n + offset_m_t = offset_m + 0 * stride_m + offset_n_t = offset_n + 0 * stride_n + return ptr + offset_m_t.to(tl.int64) * stride_m.to(tl.int64) + offset_n_t.to(tl.int64) * stride_n.to(tl.int64) @aggregate @@ -360,7 +362,9 @@ def tile_ptr_from_indices(self, rm, rn, block_m: tl.constexpr, block_n: tl.const mask = (rm[:, None] < self.M) & (rn[None, :] < self.N) # Compute pointer offsets - offset = rm[:, None] * self.stride_m + rn[None, :] * self.stride_n + offset = rm.to(tl.int64)[:, None] * self.stride_m.to(tl.int64) + rn.to(tl.int64)[None, :] * self.stride_n.to( + tl.int64 + ) tile_ptr = self.ptr + offset tile_ptr = tl.multiple_of(tile_ptr, (block_m, block_n)) diff --git a/iris/x/gather.py b/iris/x/gather.py index ca8bd4f9c..4e2b10cc9 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -24,6 +24,7 @@ def gather( src_view: TensorView, source_rank: tl.constexpr, ctx: DeviceContext, + hint: tl.constexpr = None, ): """ Tile-level gather from a specific rank. @@ -37,6 +38,9 @@ def gather( src_view: TensorView for source tensor on source_rank. source_rank: Specific rank to load from (constexpr). ctx: DeviceContext with rank, world_size, and heap_bases. + hint: Vectorization hint passed to tl.multiple_of / tl.max_contiguous on + the translated pointer. Use a scalar (e.g. 16) or a tuple + (e.g. (1, 16)) to indicate alignment. Defaults to None (no hint). Returns: Loaded tile data as a tensor. @@ -61,6 +65,7 @@ def gather( source_rank, # from_rank (source rank) ctx.heap_bases, mask=mask, + hint=hint, ) return tile_data diff --git a/test.py b/test.py new file mode 100644 index 000000000..dab152afc --- /dev/null +++ b/test.py @@ -0,0 +1,21 @@ +import sys + +sys.path.append("./iris/experimental") + +import my_module as anvil + +print("Get isntance") + +instance = anvil.AnvilLib.get_instance() +print("initialize") +instance.init() + +print("Connect 0 to 1") + +instance.connect(0, 1, 1) + +queue = instance.get_sdma_queue(0, 1, 0) + +# handle = queue.device_handle() + +handle = anvil.get_handle_as_tensor(queue) diff --git a/tests/ops/test_all_gather_matmul.py b/tests/ops/test_all_gather_matmul.py index 193505011..00841d729 100644 --- a/tests/ops/test_all_gather_matmul.py +++ b/tests/ops/test_all_gather_matmul.py @@ -1,18 +1,64 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ -Tests for fused all_gather + matmul operation. +Tests for fused all_gather + matmul operations. Each rank has A_sharded (M x K_local), B is replicated. The operation gathers A from all ranks and computes C = A_gathered @ B. +Covers both the baseline pull kernel and the HBM-buffered kernel. """ import pytest import torch import torch.distributed as dist +import tritonblas import iris +import os +from iris.ops.all_gather_matmul_hbm_buffer import ( + all_gather_matmul_hbm_buffer, + all_gather_matmul_hbm_buffer_preamble, +) +from iris.ops.config import FusedConfig + + +def _param_shapes(): + if "IRIS_TEST_M" in os.environ: + return [ + ( + int(os.environ["IRIS_TEST_M"]), + int(os.environ["IRIS_TEST_K_LOCAL"]), + int(os.environ["IRIS_TEST_N"]), + ) + ] + return [ + (128, 32, 64), + (256, 64, 128), + ] + + +def _heap_size() -> int: + return int(os.environ.get("IRIS_TEST_HEAP_SIZE", 1 << 34)) + + +def _make_reference(rank, world_size, M, K_local, N, dtype): + """Build a torch reference output for all_gather + matmul.""" + device = f"cuda:{rank}" + K = K_local * world_size + + torch.manual_seed(42 + rank) + A_sharded = torch.randn(M, K_local, dtype=dtype, device=device) + + torch.manual_seed(123) + B = torch.randn(K, N, dtype=dtype, device=device) + + A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=device) for _ in range(world_size)] + dist.all_gather(A_gathered_list, A_sharded) + A_gathered_ref = torch.cat(A_gathered_list, dim=1) + ref_output = torch.matmul(A_gathered_ref, B) + torch.cuda.synchronize() + return A_sharded, B, ref_output @pytest.mark.parametrize( @@ -23,86 +69,233 @@ ) @pytest.mark.parametrize( "M,K_local,N", + _param_shapes(), +) +def test_all_gather_matmul_baseline(dtype, atol, rtol, M, K_local, N): + """Test baseline all_gather_matmul against torch all_gather + matmul.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = _heap_size() + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + K = K_local * world_size + + min_block_size = 32 + if M < min_block_size or K_local < min_block_size or N < min_block_size: + pytest.skip(f"Problem too small for min block size {min_block_size}") + + A_sharded, B, ref_output = _make_reference(rank, world_size, M, K_local, N, dtype) + device = f"cuda:{rank}" + + A_sharded_shmem = ctx.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = ctx.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + output = ctx.zeros((M, N), dtype=dtype) + + ctx.barrier() + + config = ( + FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + if M <= 256 or K_local <= 64 or N <= 128 + else FusedConfig() + ) + + assert M >= config.block_size_m + assert K_local >= config.block_size_k + assert N >= config.block_size_n + + ctx.ops.all_gather_matmul(output, A_sharded_shmem, B_shmem, config=config) + + torch.cuda.synchronize() + ctx.barrier() + + max_diff = (output - ref_output).abs().max().item() + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol}" + ) + + +@pytest.mark.parametrize( + "dtype, atol, rtol", [ - (128, 32, 64), - (256, 64, 128), + (torch.float16, 1e-2, 1e-2), ], ) -def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): - """Test all_gather_matmul against torch all_gather + matmul.""" +@pytest.mark.parametrize( + "M,K_local,N", + _param_shapes(), +) +def test_tritonblas_rccl_all_gather_matmul(dtype, atol, rtol, M, K_local, N): + """Test RCCL all_gather + tritonBLAS matmul against torch reference.""" if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") - heap_size = 2**33 - shmem = iris.iris(heap_size) - rank = shmem.get_rank() - world_size = shmem.get_num_ranks() - - K = K_local * world_size # Full K dimension + heap_size = _heap_size() + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + device = f"cuda:{rank}" - # Skip if problem size is too small for world_size or block sizes - # With default or custom configs, we need at least one tile - min_block_size = 32 # Smallest block size we use - if M < min_block_size: - pytest.skip(f"M={M} too small (need >= {min_block_size})") - if K_local < min_block_size: - pytest.skip(f"K_local={K_local} too small (need >= {min_block_size})") - if N < min_block_size: - pytest.skip(f"N={N} too small (need >= {min_block_size})") + K = K_local * world_size + A_sharded, B, ref_output = _make_reference(rank, world_size, M, K_local, N, dtype) - # Seed for reproducibility - different seed per rank for A_sharded - torch.manual_seed(42 + rank) - A_sharded = torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}") + A_gathered_parts = [torch.empty((M, K_local), dtype=dtype, device=device) for _ in range(world_size)] + A_gathered = torch.empty((M, K), dtype=dtype, device=device) + output = ctx.zeros((M, N), dtype=dtype) + selector = tritonblas.OrigamiMatmulSelector( + M, + N, + K, + A_gathered.dtype, + B.dtype, + output.dtype, + A_gathered.device, + ) + config = tritonblas.matmul_preamble(selector) - # B must be identical on all ranks - torch.manual_seed(123) - B = torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}") + dist.all_gather(A_gathered_parts, A_sharded) + A_gathered = torch.cat(A_gathered_parts, dim=1) + tritonblas.matmul_lt(A_gathered, B, output, selector, config) - # Reference: torch all_gather + matmul - A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] - dist.all_gather(A_gathered_list, A_sharded) - A_gathered_ref = torch.cat(A_gathered_list, dim=1) # (M, K) - ref_output = torch.matmul(A_gathered_ref, B) torch.cuda.synchronize() - # Create shmem tensors directly - A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype) + max_diff = (output - ref_output).abs().max().item() + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol} (tritonblas+rccl, M={M}, K_local={K_local}, N={N})" + ) + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 1e-2, 1e-2), + ], +) +@pytest.mark.parametrize( + "M,K_local,N", + _param_shapes(), +) +@pytest.mark.parametrize( + "staged_a_layout", + [ + "k_contiguous", + "m_contiguous", + ], +) +def test_all_gather_matmul_hbm_buffer(dtype, atol, rtol, M, K_local, N, staged_a_layout): + """Test all_gather_matmul_hbm_buffer against torch all_gather + matmul.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = _heap_size() + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + K = K_local * world_size + + A_sharded, B, ref_output = _make_reference(rank, world_size, M, K_local, N, dtype) + + A_sharded_shmem = ctx.zeros((M, K_local), dtype=dtype) A_sharded_shmem.copy_(A_sharded) - B_shmem = shmem.zeros((K, N), dtype=dtype) + B_shmem = ctx.zeros((K, N), dtype=dtype) B_shmem.copy_(B) - output = shmem.zeros((M, N), dtype=dtype) - - shmem.barrier() + output = ctx.zeros((M, N), dtype=dtype) - # Run fused all_gather + matmul using shmem.ops API - from iris.ops.config import FusedConfig + ctx.barrier() - # Use appropriate block sizes based on problem size - # For small problems, use smaller blocks - if M <= 256 or K_local <= 64 or N <= 128: - config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) - else: - config = FusedConfig() + config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) - # Validate config against problem size - assert M >= config.block_size_m, f"M ({M}) must be >= block_size_m ({config.block_size_m})" - assert K_local >= config.block_size_k, f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k})" - assert N >= config.block_size_n, f"N ({N}) must be >= block_size_n ({config.block_size_n})" + workspace = all_gather_matmul_hbm_buffer_preamble( + ctx, A_sharded_shmem, B_shmem, config=config, staged_a_layout=staged_a_layout + ) - shmem.ops.all_gather_matmul(output, A_sharded_shmem, B_shmem, config=config) + all_gather_matmul_hbm_buffer( + ctx, + output, + A_sharded_shmem, + B_shmem, + config=config, + workspace=workspace, + staged_a_layout=staged_a_layout, + trace=False, + ) torch.cuda.synchronize() - shmem.barrier() + ctx.barrier() max_diff = (output - ref_output).abs().max().item() + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol} " + f"(staged_a_layout={staged_a_layout}, M={M}, K_local={K_local}, N={N})" + ) + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 1e-2, 1e-2), + ], +) +@pytest.mark.parametrize( + "M,K_local,N", + _param_shapes(), +) +def test_all_gather_matmul_hbm_buffer_with_bias(dtype, atol, rtol, M, K_local, N): + """Test all_gather_matmul_hbm_buffer with a bias vector.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = _heap_size() + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + K = K_local * world_size + + A_sharded, B, ref_output_no_bias = _make_reference(rank, world_size, M, K_local, N, dtype) + device = f"cuda:{rank}" + + torch.manual_seed(77) + bias = torch.randn(M, dtype=dtype, device=device) + ref_output = ref_output_no_bias + bias[:, None] + + A_sharded_shmem = ctx.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = ctx.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + bias_shmem = ctx.zeros((M,), dtype=dtype) + bias_shmem.copy_(bias) + output = ctx.zeros((M, N), dtype=dtype) + + ctx.barrier() + + config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + + all_gather_matmul_hbm_buffer( + ctx, + output, + A_sharded_shmem, + B_shmem, + bias=bias_shmem, + config=config, + trace=False, + ) + + torch.cuda.synchronize() + ctx.barrier() + + max_diff = (output - ref_output).abs().max().item() assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( - f"Rank {rank}: Max diff {max_diff}, expected < {atol}" + f"Rank {rank}: Max diff {max_diff}, expected < {atol} (with bias)" ) if __name__ == "__main__": - # For quick debugging import sys if not dist.is_initialized(): @@ -111,7 +304,4 @@ def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): rank = dist.get_rank() torch.cuda.set_device(rank) - - print(f"[Rank {rank}] Testing all_gather_matmul...") - test_all_gather_matmul(torch.float16, 128, 32, 64) - print(f"[Rank {rank}] ✓ Test passed!") + print(f"[Rank {rank}] Tests in this file require pytest + torchrun. See tests/run_tests_distributed.py") diff --git a/tests/ops/test_all_gather_matmul_copy_engine.py b/tests/ops/test_all_gather_matmul_copy_engine.py new file mode 100644 index 000000000..bc01514d0 --- /dev/null +++ b/tests/ops/test_all_gather_matmul_copy_engine.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for all_gather_matmul_copy_engine. + +Each rank owns A_sharded (M, K_local), gathers the K dimension across ranks, +and computes C = A_gathered @ B. This file exercises both the host-initiated +and device-initiated copy-engine paths against a torch reference. +""" + +import pytest +import torch +import torch.distributed as dist + +import iris +import os +from iris.ops.all_gather_matmul_copy_engine import ( + all_gather_matmul_copy_engine, + all_gather_matmul_copy_engine_preamble, +) +from tritonblas.matmul import _make_matmul_selector + + +def _param_shapes(): + if "IRIS_TEST_M" in os.environ: + return [ + ( + int(os.environ["IRIS_TEST_M"]), + int(os.environ["IRIS_TEST_K_LOCAL"]), + int(os.environ["IRIS_TEST_N"]), + ) + ] + return [(256, 128, 256)] + + +def _device_initiated_modes(): + mode = os.environ.get("IRIS_TEST_COPY_ENGINE_MODE") + if mode == "host": + return [False] + if mode == "device": + return [True] + return [False, True] + + +def _host_transfer_backends(): + backend = os.environ.get("IRIS_TEST_HOST_TRANSFER_BACKEND") + if backend: + return [backend] + return ["anvil"] + + +def _heap_size() -> int: + return int(os.environ.get("IRIS_TEST_HEAP_SIZE", 1 << 34)) + + +def _make_reference(rank, world_size, M, K_local, N, dtype): + """Build a torch reference output for all_gather + matmul.""" + device = f"cuda:{rank}" + K = K_local * world_size + + torch.manual_seed(42 + rank) + A_sharded = torch.randn(M, K_local, dtype=dtype, device=device) + + torch.manual_seed(123) + B = torch.randn(K, N, dtype=dtype, device=device) + + A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=device) for _ in range(world_size)] + dist.all_gather(A_gathered_list, A_sharded) + A_gathered_ref = torch.cat(A_gathered_list, dim=1) + ref_output = torch.matmul(A_gathered_ref, B) + torch.cuda.synchronize() + return A_sharded, B, ref_output + + +def _make_selector(M, N, K, dtype, device): + return _make_matmul_selector( + M, + N, + K, + dtype, + dtype, + dtype, + device, + streamk=False, + ) + + +@pytest.mark.parametrize("dtype, atol, rtol", [(torch.float16, 5e-2, 5e-2)]) +@pytest.mark.parametrize("device_initiated", _device_initiated_modes()) +@pytest.mark.parametrize("host_transfer_backend", _host_transfer_backends()) +@pytest.mark.parametrize("M,K_local,N", _param_shapes()) +def test_all_gather_matmul_copy_engine(dtype, atol, rtol, device_initiated, host_transfer_backend, M, K_local, N): + """Test all_gather_matmul_copy_engine against torch all_gather + matmul.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = _heap_size() + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + K = K_local * world_size + + A_sharded, B, ref_output = _make_reference(rank, world_size, M, K_local, N, dtype) + selector = _make_selector(M, N, K, dtype, B.device) + + if M % selector.block_m != 0: + pytest.skip(f"M={M} must be divisible by block_m={selector.block_m}") + if K % selector.block_k != 0: + pytest.skip(f"K={K} must be divisible by block_k={selector.block_k}") + if K_local % selector.block_k != 0: + pytest.skip(f"K_local={K_local} must be divisible by block_k={selector.block_k}") + + A_sharded_shmem = ctx.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = ctx.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + output = ctx.zeros((M, N), dtype=dtype) + + workspace = all_gather_matmul_copy_engine_preamble( + ctx, + A_sharded_shmem, + B_shmem, + selector=selector, + k_per_flag=4, + ) + + ctx.barrier() + + all_gather_matmul_copy_engine( + ctx, + output, + A_sharded_shmem, + B_shmem, + workspace=workspace, + k_per_flag=4, + device_initiated=device_initiated, + host_transfer_backend=host_transfer_backend, + trace=False, + ) + + torch.cuda.synchronize() + ctx.barrier() + + max_diff = (output - ref_output).abs().max().item() + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol} " + f"(device_initiated={device_initiated}, host_transfer_backend={host_transfer_backend}, " + f"M={M}, K_local={K_local}, N={N})" + ) + + +if __name__ == "__main__": + import sys + + if not dist.is_initialized(): + print("Run with: torchrun --nproc_per_node=2 tests/ops/test_all_gather_matmul_copy_engine.py") + sys.exit(1) + + rank = dist.get_rank() + torch.cuda.set_device(rank) + print(f"[Rank {rank}] Tests in this file require pytest + torchrun. See tests/run_tests_distributed.py") diff --git a/tests/ops/test_matmul_all_gather.py b/tests/ops/test_matmul_all_gather.py index ad2f37c4d..4e921cf6b 100644 --- a/tests/ops/test_matmul_all_gather.py +++ b/tests/ops/test_matmul_all_gather.py @@ -11,7 +11,140 @@ import pytest import torch import torch.distributed as dist +import tritonblas import iris +import os + + +def _param_shapes(): + if "IRIS_TEST_M" in os.environ: + return [ + ( + int(os.environ["IRIS_TEST_M"]), + int(os.environ["IRIS_TEST_N"]), + int(os.environ["IRIS_TEST_K"]), + ) + ] + return [ + (64, 64, 32), + (512, 256, 512), + (1024, 2048, 1024), + ] + + +def _heap_size() -> int: + return int(os.environ.get("IRIS_TEST_HEAP_SIZE", 1 << 34)) + + +def _full_validation_threshold_bytes() -> int: + return int(os.environ.get("IRIS_TEST_FULL_VALIDATION_THRESHOLD_BYTES", 2 << 30)) + + +def _should_stream_validation(rows_per_rank: int, n: int, dtype: torch.dtype) -> bool: + element_size = torch.tensor([], dtype=dtype).element_size() + local_ref_bytes = rows_per_rank * n * element_size + return local_ref_bytes > _full_validation_threshold_bytes() + + +def _validation_rows_per_chunk(rows_per_rank: int, n: int, dtype: torch.dtype) -> int: + if "IRIS_TEST_VALIDATION_ROWS_PER_CHUNK" in os.environ: + return max(1, min(rows_per_rank, int(os.environ["IRIS_TEST_VALIDATION_ROWS_PER_CHUNK"]))) + element_size = torch.tensor([], dtype=dtype).element_size() + target_bytes = 2 << 30 + rows_per_chunk = max(1, target_bytes // max(1, n * element_size)) + return max(1, min(rows_per_rank, rows_per_chunk)) + + +def _validation_cols_per_chunk(rows_per_chunk: int, n: int, dtype: torch.dtype) -> int: + if "IRIS_TEST_VALIDATION_COLS_PER_CHUNK" in os.environ: + return max(1, min(n, int(os.environ["IRIS_TEST_VALIDATION_COLS_PER_CHUNK"]))) + element_size = torch.tensor([], dtype=dtype).element_size() + target_bytes = 128 << 20 + cols_per_chunk = max(1, target_bytes // max(1, rows_per_chunk * element_size * 4)) + return max(1, min(n, cols_per_chunk)) + + +def _assert_close_chunked(output_chunk, ref_chunk, atol, rtol, src_rank, row_start): + cols_per_chunk = _validation_cols_per_chunk(output_chunk.shape[0], output_chunk.shape[1], output_chunk.dtype) + + for col_start in range(0, output_chunk.shape[1], cols_per_chunk): + col_end = min(col_start + cols_per_chunk, output_chunk.shape[1]) + output_slice = output_chunk[:, col_start:col_end] + ref_slice = ref_chunk[:, col_start:col_end] + abs_diff = torch.abs(output_slice - ref_slice) + tolerance = atol + rtol * torch.abs(ref_slice) + mismatch = abs_diff > tolerance + + if torch.any(mismatch): + mismatch_idx = torch.nonzero(mismatch, as_tuple=False)[0] + local_row = int(mismatch_idx[0].item()) + local_col = int(mismatch_idx[1].item()) + global_row = row_start + local_row + global_col = col_start + local_col + max_diff = torch.max(abs_diff).item() + output_val = output_slice[local_row, local_col].item() + ref_val = ref_slice[local_row, local_col].item() + pytest.fail( + f"Mismatch in gathered rows from src_rank={src_rank} at row={global_row}, col={global_col}: " + f"output={output_val}, ref={ref_val}, max_diff={max_diff}, expected within atol={atol}, rtol={rtol}\n" + f"Rank validation failed for shmem.ops.matmul_all_gather" + ) + + +def _assert_gathered_rows_match_dense(output, local_ref, rank, world_size, atol, rtol): + recv_chunk = None + rows_per_rank = local_ref.shape[0] + + for src_rank in range(world_size): + if rank == src_rank: + ref_chunk = local_ref + else: + if recv_chunk is None: + recv_chunk = torch.empty_like(local_ref) + ref_chunk = recv_chunk + + dist.broadcast(ref_chunk, src=src_rank) + row_start = src_rank * rows_per_rank + row_end = row_start + rows_per_rank + output_chunk = output[row_start:row_end] + + if not torch.allclose(output_chunk, ref_chunk, atol=atol, rtol=rtol): + max_diff = torch.max(torch.abs(output_chunk - ref_chunk)).item() + pytest.fail( + f"Max difference in gathered rows from src_rank={src_rank}: {max_diff}, expected < {atol}\n" + f"Rank {rank}: shmem.ops.matmul_all_gather output doesn't match reference" + ) + + +def _assert_gathered_rows_match_streamed(output, A_local, B, rank, world_size, atol, rtol): + rows_per_rank = A_local.shape[0] + rows_per_chunk = _validation_rows_per_chunk(rows_per_rank, output.shape[1], output.dtype) + + for src_rank in range(world_size): + for local_row_start in range(0, rows_per_rank, rows_per_chunk): + local_row_end = min(local_row_start + rows_per_chunk, rows_per_rank) + chunk_rows = local_row_end - local_row_start + + if rank == src_rank: + ref_chunk = torch.matmul(A_local[local_row_start:local_row_end], B) + else: + ref_chunk = torch.empty((chunk_rows, output.shape[1]), dtype=output.dtype, device=output.device) + + dist.broadcast(ref_chunk, src=src_rank) + global_row_start = src_rank * rows_per_rank + local_row_start + global_row_end = global_row_start + chunk_rows + output_chunk = output[global_row_start:global_row_end] + _assert_close_chunked(output_chunk, ref_chunk, atol, rtol, src_rank, global_row_start) + + +def _assert_gathered_rows_match(output, A_local, B, rank, world_size, atol, rtol): + if _should_stream_validation(A_local.shape[0], output.shape[1], output.dtype): + _assert_gathered_rows_match_streamed(output, A_local, B, rank, world_size, atol, rtol) + return + + local_ref = torch.matmul(A_local, B) + torch.cuda.synchronize() + _assert_gathered_rows_match_dense(output, local_ref, rank, world_size, atol, rtol) @pytest.mark.parametrize( @@ -24,18 +157,14 @@ ) @pytest.mark.parametrize( "M, N, K", - [ - (64, 64, 32), - (512, 256, 512), - (1024, 2048, 1024), - ], + _param_shapes(), ) def test_matmul_all_gather(dtype, atol, rtol, M, N, K): """Test matmul_all_gather using shmem.ops API with proper config.""" if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") - heap_size = 2**33 # 8GB + heap_size = _heap_size() shmem = iris.iris(heap_size) rank = shmem.get_rank() world_size = shmem.get_num_ranks() @@ -61,15 +190,6 @@ def test_matmul_all_gather(dtype, atol, rtol, M, N, K): B = shmem.randn((K, N), dtype=dtype) output = shmem.zeros((M, N), dtype=dtype) - # Reference: compute local GEMM, then all-gather along M dimension - A_ref = A_local.clone() - B_ref = B.clone() - C_local_ref = torch.matmul(A_ref, B_ref) - C_gathered_list = [torch.zeros(M_local, N, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] - dist.all_gather(C_gathered_list, C_local_ref) - pytorch_output = torch.cat(C_gathered_list, dim=0) # Concatenate along M dimension - torch.cuda.synchronize() - shmem.barrier() # Use appropriate block sizes based on problem size @@ -102,18 +222,77 @@ def test_matmul_all_gather(dtype, atol, rtol, M, N, K): torch.cuda.synchronize() shmem.barrier() - max_diff = torch.abs(output - pytorch_output).max().item() - - assert torch.allclose(output, pytorch_output, atol=atol, rtol=rtol), ( - f"Max difference: {max_diff}, expected < {atol}\n" - f"Rank {rank}: shmem.ops.matmul_all_gather output doesn't match reference" - ) + _assert_gathered_rows_match(output, A_local, B, rank, world_size, atol, rtol) if rank == 0: print(f"✓ matmul_all_gather test passed: {dtype}, M={M}, N={N}, K={K}") shmem.barrier() + del output + del B + del A_local del shmem import gc gc.collect() + torch.cuda.empty_cache() + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 0.5, 0.01), + (torch.bfloat16, 0.5, 0.01), + ], +) +@pytest.mark.parametrize( + "M, N, K", + _param_shapes(), +) +def test_tritonblas_rccl_matmul_all_gather(dtype, atol, rtol, M, N, K): + """Test tritonBLAS matmul + RCCL all_gather against a dense local reference.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = _heap_size() + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + if M % world_size != 0: + pytest.skip(f"M={M} not divisible by world_size={world_size}") + + M_local = M // world_size + min_block_size = 32 + if M_local < min_block_size: + pytest.skip(f"M_local={M_local} too small for world_size={world_size} (need >= {min_block_size})") + if K < min_block_size: + pytest.skip(f"K={K} too small (need >= {min_block_size})") + if N < min_block_size: + pytest.skip(f"N={N} too small (need >= {min_block_size})") + + torch.manual_seed(123 + rank) + A_local = shmem.randn((M_local, K), dtype=dtype) + torch.manual_seed(456) + B = shmem.randn((K, N), dtype=dtype) + C_local = shmem.zeros((M_local, N), dtype=dtype) + output = shmem.zeros((M, N), dtype=dtype) + selector = tritonblas.OrigamiMatmulSelector( + M_local, + N, + K, + A_local.dtype, + B.dtype, + C_local.dtype, + A_local.device, + ) + config = tritonblas.matmul_preamble(selector) + + shmem.barrier() + tritonblas.matmul_lt(A_local, B, C_local, selector, config) + dist.all_gather_into_tensor(output, C_local) + + torch.cuda.synchronize() + shmem.barrier() + + _assert_gathered_rows_match(output, A_local, B, rank, world_size, atol, rtol) diff --git a/tests/ops/test_matmul_all_gather_copy_engine.py b/tests/ops/test_matmul_all_gather_copy_engine.py new file mode 100644 index 000000000..2ccca1190 --- /dev/null +++ b/tests/ops/test_matmul_all_gather_copy_engine.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for matmul_all_gather_copy_engine. + +Each rank computes C_local = A_local @ B and the copy engine scatters the +result tiles so every rank observes the gathered output C. +""" + +import pytest +import torch +import torch.distributed as dist + +import iris +import os +from iris.ops.matmul_all_gather_host_copy_engine import ( + matmul_all_gather_host_copy_engine, + matmul_all_gather_host_copy_engine_preamble, +) +from iris.ops.matmul_all_gather_copy_engine import ( + matmul_all_gather_copy_engine, + matmul_all_gather_copy_engine_preamble, +) +from tritonblas.matmul import _make_matmul_selector + + +def _param_shapes(): + if "IRIS_TEST_M" in os.environ: + return [ + ( + int(os.environ["IRIS_TEST_M"]), + int(os.environ["IRIS_TEST_N"]), + int(os.environ["IRIS_TEST_K"]), + ) + ] + return [(1024, 256, 256)] + + +def _copy_engine_modes(): + mode = os.environ.get("IRIS_TEST_COPY_ENGINE_MODE") + if mode in {"host", "device"}: + return [mode] + return ["host", "device"] + + +def _heap_size() -> int: + return int(os.environ.get("IRIS_TEST_HEAP_SIZE", 1 << 34)) + + +def _full_validation_threshold_bytes() -> int: + return int(os.environ.get("IRIS_TEST_FULL_VALIDATION_THRESHOLD_BYTES", 2 << 30)) + + +def _should_stream_validation(rows_per_rank: int, n: int, dtype: torch.dtype) -> bool: + element_size = torch.tensor([], dtype=dtype).element_size() + local_ref_bytes = rows_per_rank * n * element_size + return local_ref_bytes > _full_validation_threshold_bytes() + + +def _validation_rows_per_chunk(rows_per_rank: int, n: int, dtype: torch.dtype) -> int: + if "IRIS_TEST_VALIDATION_ROWS_PER_CHUNK" in os.environ: + return max(1, min(rows_per_rank, int(os.environ["IRIS_TEST_VALIDATION_ROWS_PER_CHUNK"]))) + element_size = torch.tensor([], dtype=dtype).element_size() + target_bytes = 2 << 30 + rows_per_chunk = max(1, target_bytes // max(1, n * element_size)) + return max(1, min(rows_per_rank, rows_per_chunk)) + + +def _validation_cols_per_chunk(rows_per_chunk: int, n: int, dtype: torch.dtype) -> int: + if "IRIS_TEST_VALIDATION_COLS_PER_CHUNK" in os.environ: + return max(1, min(n, int(os.environ["IRIS_TEST_VALIDATION_COLS_PER_CHUNK"]))) + element_size = torch.tensor([], dtype=dtype).element_size() + target_bytes = 128 << 20 + cols_per_chunk = max(1, target_bytes // max(1, rows_per_chunk * element_size * 4)) + return max(1, min(n, cols_per_chunk)) + + +def _assert_close_chunked(output_chunk, ref_chunk, atol, rtol, copy_engine_mode, src_rank, row_start): + cols_per_chunk = _validation_cols_per_chunk(output_chunk.shape[0], output_chunk.shape[1], output_chunk.dtype) + + for col_start in range(0, output_chunk.shape[1], cols_per_chunk): + col_end = min(col_start + cols_per_chunk, output_chunk.shape[1]) + output_slice = output_chunk[:, col_start:col_end] + ref_slice = ref_chunk[:, col_start:col_end] + abs_diff = torch.abs(output_slice - ref_slice) + tolerance = atol + rtol * torch.abs(ref_slice) + mismatch = abs_diff > tolerance + + if torch.any(mismatch): + mismatch_idx = torch.nonzero(mismatch, as_tuple=False)[0] + local_row = int(mismatch_idx[0].item()) + local_col = int(mismatch_idx[1].item()) + global_row = row_start + local_row + global_col = col_start + local_col + max_diff = torch.max(abs_diff).item() + output_val = output_slice[local_row, local_col].item() + ref_val = ref_slice[local_row, local_col].item() + pytest.fail( + f"Mismatch in gathered rows from src_rank={src_rank} at row={global_row}, col={global_col}: " + f"output={output_val}, ref={ref_val}, max_diff={max_diff}, expected within atol={atol}, rtol={rtol}\n" + f"Rank validation failed for matmul_all_gather_copy_engine (mode={copy_engine_mode})" + ) + + +def _assert_gathered_rows_match_dense(output, local_ref, rank, world_size, atol, rtol, copy_engine_mode): + recv_chunk = None + rows_per_rank = local_ref.shape[0] + + for src_rank in range(world_size): + if rank == src_rank: + ref_chunk = local_ref + else: + if recv_chunk is None: + recv_chunk = torch.empty_like(local_ref) + ref_chunk = recv_chunk + + dist.broadcast(ref_chunk, src=src_rank) + row_start = src_rank * rows_per_rank + row_end = row_start + rows_per_rank + output_chunk = output[row_start:row_end] + + if not torch.allclose(output_chunk, ref_chunk, atol=atol, rtol=rtol): + max_diff = torch.max(torch.abs(output_chunk - ref_chunk)).item() + pytest.fail( + f"Max difference in gathered rows from src_rank={src_rank}: {max_diff}, expected < {atol}\n" + f"Rank validation failed for matmul_all_gather_copy_engine (mode={copy_engine_mode})" + ) + + +def _assert_gathered_rows_match_streamed(output, A_local, B, rank, world_size, atol, rtol, copy_engine_mode): + rows_per_rank = A_local.shape[0] + rows_per_chunk = _validation_rows_per_chunk(rows_per_rank, output.shape[1], output.dtype) + + for src_rank in range(world_size): + for local_row_start in range(0, rows_per_rank, rows_per_chunk): + local_row_end = min(local_row_start + rows_per_chunk, rows_per_rank) + chunk_rows = local_row_end - local_row_start + + if rank == src_rank: + ref_chunk = torch.matmul(A_local[local_row_start:local_row_end], B) + else: + ref_chunk = torch.empty((chunk_rows, output.shape[1]), dtype=output.dtype, device=output.device) + + dist.broadcast(ref_chunk, src=src_rank) + global_row_start = src_rank * rows_per_rank + local_row_start + global_row_end = global_row_start + chunk_rows + output_chunk = output[global_row_start:global_row_end] + _assert_close_chunked( + output_chunk, + ref_chunk, + atol, + rtol, + copy_engine_mode, + src_rank, + global_row_start, + ) + + +def _assert_gathered_rows_match(output, A_local, B, rank, world_size, atol, rtol, copy_engine_mode): + if _should_stream_validation(A_local.shape[0], output.shape[1], output.dtype): + _assert_gathered_rows_match_streamed(output, A_local, B, rank, world_size, atol, rtol, copy_engine_mode) + return + + local_ref = torch.matmul(A_local, B) + torch.cuda.synchronize() + _assert_gathered_rows_match_dense(output, local_ref, rank, world_size, atol, rtol, copy_engine_mode) + + +def _make_selector(M_local, N, K, dtype, device): + return _make_matmul_selector( + M_local, + N, + K, + dtype, + dtype, + dtype, + device, + streamk=False, + ) + + +@pytest.mark.parametrize("dtype, atol, rtol", [(torch.float16, 5e-2, 5e-2)]) +@pytest.mark.parametrize("copy_engine_mode", _copy_engine_modes()) +@pytest.mark.parametrize("M,N,K", _param_shapes()) +def test_matmul_all_gather_copy_engine(dtype, atol, rtol, copy_engine_mode, M, N, K): + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = _heap_size() + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + device = torch.device(f"cuda:{rank}") + + if M % world_size != 0: + pytest.skip(f"M={M} not divisible by world_size={world_size}") + + M_local = M // world_size + selector = _make_selector(M_local, N, K, dtype, device) + + if M_local % selector.block_m != 0: + pytest.skip(f"M_local={M_local} must be divisible by block_size_m={selector.block_m}") + if K % selector.block_k != 0: + pytest.skip(f"K={K} must be divisible by block_size_k={selector.block_k}") + + A_local = shmem.randn((M_local, K), dtype=dtype) + B = shmem.randn((K, N), dtype=dtype) + output = shmem.zeros((M, N), dtype=dtype) + + if copy_engine_mode == "device": + workspace = matmul_all_gather_copy_engine_preamble( + shmem, + A_local, + B, + selector=selector, + ) + else: + workspace = matmul_all_gather_host_copy_engine_preamble( + shmem, + A_local, + B, + trace=False, + selector=selector, + ) + + shmem.barrier() + + if copy_engine_mode == "device": + matmul_all_gather_copy_engine( + shmem, + output, + A_local, + B, + workspace=workspace, + ) + else: + matmul_all_gather_host_copy_engine( + shmem, + output, + A_local, + B, + workspace=workspace, + trace=False, + ) + + torch.cuda.synchronize() + shmem.barrier() + + _assert_gathered_rows_match(output, A_local, B, rank, world_size, atol, rtol, copy_engine_mode) + + shmem.barrier() + del output + del B + del A_local + del shmem + import gc + + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + import sys + + if not dist.is_initialized(): + print("Run with: torchrun --nproc_per_node=2 tests/ops/test_matmul_all_gather_copy_engine.py") + sys.exit(1) + + rank = dist.get_rank() + torch.cuda.set_device(rank) + print(f"[Rank {rank}] Tests in this file require pytest + torchrun. See tests/run_tests_distributed.py")