Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 64 additions & 19 deletions .claude/skills/kernel-trace-analysis/scripts/hotspot_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,33 @@ def print_source_detail(hotspot, source_cache, context=3):
print(f" stall={fmt_cycles(inst.stall_cycles):>7} type={inst.stall_type:<12} {inst.asm}")


def read_kernel_metadata(dispatch_dir):
def read_kernel_metadata(dispatch_dir, kernel_filter=""):
"""Read authoritative resource counts from ``out_kernel_trace.csv`` if present.

The ATT ``code.json`` only contains the (possibly single-CU, possibly
vgpr-form) disassembly, so it cannot reveal accum_vgpr / SGPR / LDS /
workgroup size. The kernel-trace CSV carries the real launch metadata.
Searches the dispatch dir and its parent (staging often copies the CSV
next to the ui_output_agent_* dir). Returns {} if not found.

Row selection priority:
1. ``kernel_filter`` substring matched against Kernel_Name, optionally
narrowed by Dispatch_Id when the dir name encodes ``dispatch_<id>``
(rocprofv3 ``ui_output_agent_*_dispatch_<id>`` layout). Dispatch_Id
matching avoids false matches when a PyTorch reference kernel shares
the same name substring.
2. Bidirectional name heuristic against the directory basename (legacy
path for timestamped dirs like ``20240101_120000_pa_decode_kernel``).
"""
candidates = []
for base in (dispatch_dir, os.path.dirname(os.path.abspath(dispatch_dir))):
candidates += glob.glob(os.path.join(base, "*kernel_trace*.csv"))

dir_name = os.path.basename(os.path.abspath(dispatch_dir))
# Extract the dispatch id from rocprofv3's ui_output_agent_<N>_dispatch_<id> layout.
_dispatch_id_m = re.search(r"dispatch_(\d+)$", dir_name)
dispatch_id = _dispatch_id_m.group(1) if _dispatch_id_m else None

for path in candidates:
try:
with open(path) as f:
Expand All @@ -258,24 +273,40 @@ def read_kernel_metadata(dispatch_dir):
continue
if not rows or "Accum_VGPR_Count" not in rows[0]:
continue
# Pick the row whose kernel matches the dispatch dir name. The dir is
# usually staged as "<timestamp>_<short_kernel_name>" while the CSV
# Kernel_Name has a trailing index (e.g. dir ".._pa_decode_ps_kernel"
# vs kernel "pa_decode_ps_kernel_0"), so match bidirectionally on the
# timestamp-stripped short name.
dir_name = os.path.basename(os.path.abspath(dispatch_dir))
short = re.sub(r"^\d{8}_\d{6}_", "", dir_name) # strip YYYYMMDD_HHMMSS_

def _matches(kn):
if not kn:
return False
return kn in dir_name or short in kn or kn.startswith(short) or short.startswith(kn)

has_dispatch_col = "Dispatch_Id" in rows[0]

chosen = None
for r in rows:
if _matches(r.get("Kernel_Name", "")):
chosen = r
break
if kernel_filter:
# Explicit filter: kernel name substring, narrowed by Dispatch_Id when available.
can_disambiguate = bool(dispatch_id and has_dispatch_col)
matches = [r for r in rows if kernel_filter in r.get("Kernel_Name", "")]
if can_disambiguate:
matches = [r for r in matches if str(r.get("Dispatch_Id", "")).strip() == dispatch_id]
if matches:
chosen = matches[0]
if not can_disambiguate and len(matches) > 1:
# First-substring-wins: no dispatch id available to pick between same-named rows.
print(
f" warning: --kernel '{kernel_filter}' matched {len(matches)} rows in "
f"{os.path.basename(path)} with no dispatch id to disambiguate; using the "
"first match (pass a more specific --kernel)"
)
else:
# Legacy heuristic: bidirectional substring match against the dir basename.
# Works for timestamped dirs like ``20240101_120000_pa_decode_kernel``.
short = re.sub(r"^\d{8}_\d{6}_", "", dir_name) # strip YYYYMMDD_HHMMSS_

def _matches(kn):
if not kn:
return False
return kn in dir_name or short in kn or kn.startswith(short) or short.startswith(kn)

for r in rows:
if _matches(r.get("Kernel_Name", "")):
chosen = r
break

if chosen is None:
continue # no matching row in this CSV — try the next candidate

Expand Down Expand Up @@ -457,7 +488,10 @@ def print_reg_pressure(reg_info):
print_header("Register Pressure & Occupancy")
print(f" Architecture: {reg_info['arch']}")
if not reg_info["has_meta"]:
print(" (no kernel_trace CSV found — accum/LDS/SGPR estimated from ISA only)")
print(
" (kernel_trace CSV not matched — accum/LDS/SGPR estimated from ISA only; "
"pass --kernel <name_substr> to enable CSV metadata lookup)"
)
if reg_info["is_vgpr_form"]:
print(f" arch_vgpr: {reg_info['arch_vgpr']} (MFMA vgpr-form: accumulators in arch file, no AGPR)")
else:
Expand Down Expand Up @@ -496,6 +530,17 @@ def main():
"--detail", action="store_true", help="Show source snippet + instruction breakdown under each source hotspot"
)
parser.add_argument("--context", type=int, default=3, help="Source lines of context around hotspot (default: 3)")
parser.add_argument(
"--kernel",
default="",
metavar="SUBSTR",
help="Kernel name substring for CSV metadata lookup "
"(e.g. 'pa_mqa_logits_fp4_kernel_0'). "
"Required when the dispatch dir name does not encode the kernel name, "
"as with rocprofv3 ui_output_agent_*_dispatch_<id> directories. "
"Combined with the dispatch id from the dir name when a Dispatch_Id "
"column is present in the CSV.",
)
args = parser.parse_args()

if not os.path.isdir(args.dispatch_dir):
Expand All @@ -515,7 +560,7 @@ def main():
print(f" Total cycles: {fmt_cycles(total_cycles)}")
print(f" Total stalls: {fmt_cycles(total_stall)} ({100*total_stall/total_cycles:.1f}% of total cycles)")

meta = read_kernel_metadata(args.dispatch_dir)
meta = read_kernel_metadata(args.dispatch_dir, kernel_filter=args.kernel)
reg_info = detect_arch_and_reg_pressure(instructions, meta)
print_reg_pressure(reg_info)

Expand Down
151 changes: 151 additions & 0 deletions tests/unit/test_hotspot_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

"""Unit tests for the kernel-trace-analysis hotspot_analyzer CSV row selection.

The analyzer reads authoritative VGPR/SGPR/LDS/occupancy data from
``*_kernel_trace.csv`` and must pick the right row for the dispatch under
analysis. Row selection is plain string/CSV matching and is the part most
prone to silent mis-selection, so it is covered here:

- legacy dir-name heuristic (timestamped dirs) still matches
- ``ui_output_agent_*_dispatch_*`` dirs return {} without ``--kernel``
- ``--kernel`` + ``Dispatch_Id`` selects the correct row
- ``--kernel`` without a ``Dispatch_Id`` column falls back to name match
- ``--kernel`` matching several rows with no dispatch id to disambiguate
  prints a warning and uses the first match
- argparse wires ``--kernel`` through to ``read_kernel_metadata``
"""

import csv
import importlib.util
import os
import sys
from pathlib import Path

import pytest

# Module-level pytest marker list: applied to every test collected from this file.
# NOTE(review): the meaning of ``l0_backend_agnostic`` is defined by the project's
# pytest configuration, which is not visible here.
pytestmark = [pytest.mark.l0_backend_agnostic]

# This file lives at tests/unit/, so parents[2] is two directories above it.
_REPO_ROOT = Path(__file__).resolve().parents[2]
_SCRIPT = _REPO_ROOT / ".claude" / "skills" / "kernel-trace-analysis" / "scripts" / "hotspot_analyzer.py"

# Load the script by file path rather than ``import``: its directory names
# (leading dot, hyphens) are not valid Python package names. The three steps
# must stay in this order: build the spec, create the empty module from it,
# then execute the script's top-level code inside that module.
_SPEC = importlib.util.spec_from_file_location("hotspot_analyzer", _SCRIPT)
hotspot_analyzer = importlib.util.module_from_spec(_SPEC)
_SPEC.loader.exec_module(hotspot_analyzer)


# Smallest row template the tests need. "Accum_VGPR_Count" has to appear in the
# header or the CSV is not treated as a kernel-trace file; the remaining columns
# are the fields read_kernel_metadata returns. Individual tests layer
# Kernel_Name / Dispatch_Id / overrides on top of this. All values are strings.
_BASE_ROW = dict(
    VGPR_Count="100",
    Accum_VGPR_Count="0",
    SGPR_Count="50",
    LDS_Block_Size="4096",
    Workgroup_Size_X="256",
    Workgroup_Size_Y="1",
    Workgroup_Size_Z="1",
)


def _write_csv(dispatch_dir, rows):
    """Write ``out_kernel_trace.csv`` into ``dispatch_dir`` and return its path.

    ``dispatch_dir`` is created if it does not exist. ``rows`` is a list of
    dicts mapping column name to value.

    The header is the union of all row keys in first-seen order, so rows may
    carry different columns (e.g. only some rows having ``Dispatch_Id``);
    cells missing from a row are written empty. An empty ``rows`` list
    produces an empty file instead of raising.
    """
    os.makedirs(dispatch_dir, exist_ok=True)
    path = os.path.join(dispatch_dir, "out_kernel_trace.csv")
    # dict.fromkeys de-duplicates while keeping insertion order.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", newline="") as f:
        if fieldnames:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
    return path


def test_legacy_timestamp_heuristic_still_matches(tmp_path):
    """With no filter, a ``YYYYMMDD_HHMMSS_<kernel>`` dir still finds its row.

    The directory name lacks the trailing index that Kernel_Name carries
    (``pa_decode_kernel`` vs ``pa_decode_kernel_0``), so this exercises the
    bidirectional substring match.
    """
    staged_dir = str(tmp_path / "20240101_120000_pa_decode_kernel")
    row = dict(_BASE_ROW, Kernel_Name="pa_decode_kernel_0", VGPR_Count="111")
    _write_csv(staged_dir, [row])

    result = hotspot_analyzer.read_kernel_metadata(staged_dir)

    # Non-empty result first, then the value parsed from the matching row.
    assert result
    assert result["csv_vgpr"] == 111


def test_ui_output_dir_without_kernel_filter_returns_empty(tmp_path):
    """A ``ui_output_agent_*_dispatch_*`` dir with no filter yields ``{}``.

    The directory name carries no kernel name, so the legacy dir-name
    heuristic has nothing to match against (the bug this PR addresses).
    """
    agent_dir = str(tmp_path / "ui_output_agent_15249_dispatch_223")
    only_row = dict(_BASE_ROW, Kernel_Name="pa_mqa_logits_fp4_kernel_0")
    _write_csv(agent_dir, [only_row])

    result = hotspot_analyzer.read_kernel_metadata(agent_dir)

    assert result == {}


def test_kernel_filter_with_dispatch_id_selects_correct_row(tmp_path):
    """Same-named rows are told apart by the dispatch id taken from the dir name."""
    agent_dir = str(tmp_path / "ui_output_agent_15249_dispatch_223")
    kernel_name = "pa_mqa_logits_fp4_kernel_0"
    # The decoy comes first, so a plain first-name-match would pick the wrong row.
    decoy = dict(_BASE_ROW, Kernel_Name=kernel_name, Dispatch_Id="999", VGPR_Count="11")
    wanted = dict(_BASE_ROW, Kernel_Name=kernel_name, Dispatch_Id="223", VGPR_Count="22")
    _write_csv(agent_dir, [decoy, wanted])

    result = hotspot_analyzer.read_kernel_metadata(
        agent_dir,
        kernel_filter="pa_mqa_logits_fp4_kernel",
    )

    assert result["csv_vgpr"] == 22


def test_kernel_filter_without_dispatch_column_falls_back_to_name(tmp_path):
    """Without a ``Dispatch_Id`` column the filter matches on kernel name alone."""
    agent_dir = str(tmp_path / "ui_output_agent_15249_dispatch_223")
    row = dict(_BASE_ROW, Kernel_Name="pa_mqa_logits_fp4_kernel_0", VGPR_Count="77")
    _write_csv(agent_dir, [row])

    result = hotspot_analyzer.read_kernel_metadata(agent_dir, kernel_filter="pa_mqa_logits_fp4")

    assert result["csv_vgpr"] == 77


def test_ambiguous_match_without_dispatch_id_warns_and_picks_first(tmp_path, capsys):
    """Several name matches and no dispatch id: the first row wins, with a warning.

    The CSV does have a ``Dispatch_Id`` column, but the directory name has no
    ``dispatch_<id>`` suffix, so there is no id to narrow the matches with.
    """
    plain_dir = str(tmp_path / "plain_dir")
    first = dict(_BASE_ROW, Kernel_Name="some_kernel_0", Dispatch_Id="1", VGPR_Count="11")
    second = dict(_BASE_ROW, Kernel_Name="some_kernel_1", Dispatch_Id="2", VGPR_Count="22")
    _write_csv(plain_dir, [first, second])

    result = hotspot_analyzer.read_kernel_metadata(plain_dir, kernel_filter="some_kernel")
    printed = capsys.readouterr().out

    assert result["csv_vgpr"] == 11
    # The warning goes to stdout and reports how many rows matched.
    assert "matched 2 rows" in printed
    assert "warning" in printed


def test_argparse_wires_kernel_through_to_read_kernel_metadata(tmp_path, monkeypatch):
    """``--kernel`` on the command line reaches ``read_kernel_metadata``.

    Everything ``main()`` calls besides argument parsing is stubbed out, so
    the test only observes which ``kernel_filter`` value is passed along.
    """
    trace_dir = tmp_path / "ui_output_agent_1_dispatch_5"
    trace_dir.mkdir()  # created up front for main()'s directory check

    seen = {}

    class _FakeInst:
        # Non-zero cycle counts: main() prints a stall percentage that divides
        # by the total cycle count (presumably summed from these attributes).
        stall_cycles = 1
        total_cycles = 2

    def fake_read(dispatch_dir, kernel_filter=""):
        # Record what main() forwarded; the returned metadata is irrelevant here.
        seen["kernel_filter"] = kernel_filter
        return {}

    stubs = {
        "read_kernel_metadata": fake_read,
        "load_instructions": lambda _d: [_FakeInst()],
        "aggregate_by_source": lambda _i: [],
        "load_source_map": lambda _d: {},
        "detect_arch_and_reg_pressure": lambda _i, _m: {},
        "print_reg_pressure": lambda _r: None,
        "print_stall_type_summary": lambda _i, _t: None,
        "print_source_hotspots": lambda *a, **k: None,
        "print_asm_hotspots": lambda *a, **k: None,
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(hotspot_analyzer, attr_name, stub)
    monkeypatch.setattr(sys, "argv", ["hotspot_analyzer.py", str(trace_dir), "--kernel", "my_kernel_substr"])

    exit_code = hotspot_analyzer.main()

    assert exit_code == 0
    assert seen["kernel_filter"] == "my_kernel_substr"
Loading