
Add Metal T.gemm_v2 using simdgroup_multiply_accumulate #1869

Open
oraluben wants to merge 2 commits into tile-ai:main from oraluben:metal-gemm

Conversation


@oraluben oraluben commented Feb 23, 2026

Metal GEMM v2: Direct Simdgroup Matrix Operations Without Fragment Layout

Summary

This PR adds a new GEMM implementation for Metal targets that uses simdgroup_matrix operations directly, bypassing TileLang's thread <-> register fragment layout system entirely. The accumulator lives in simdgroup registers throughout the entire k-loop, eliminating the costly shared memory round-trips that the previous approach required.


Background: Simdgroup Matrix Operations in Metal

Metal provides simdgroup matrix operations as the only public API for accessing matrix acceleration hardware (analogous to NVIDIA's Tensor Cores). Unlike CUDA, which exposes multiple abstraction levels (wmma, mma.sync PTX, wgmma), Metal offers a single, opaque interface through the simdgroup_matrix<T, Rows, Cols> type:

  • simdgroup_load(mat, ptr, stride) — cooperative load from device/threadgroup memory
  • simdgroup_store(mat, ptr, stride) — cooperative store to device/threadgroup memory
  • simdgroup_multiply_accumulate(D, A, B, C) — fused multiply-accumulate: D = A × B + C

These operate on 8×8 matrix tiles and are executed collectively by a simdgroup (32 threads, equivalent to a CUDA warp). The key distinction from CUDA is that the hardware-internal data distribution across threads is completely opaque — there is no documented per-thread register layout, and no lower-level PTX-like alternative exists. This makes simdgroup operations the sole mechanism for matrix-accelerated computation on Apple GPUs.
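To make the semantics concrete, here is a pure-Python model of the three intrinsics on a single 8×8 tile. This only illustrates what the hardware computes; the real operations run collectively across the 32 threads of a simdgroup with an opaque internal layout, and while the function names mirror the Metal API, the implementations below are hypothetical stand-ins.

```python
TILE = 8  # simdgroup_matrix tiles are 8x8

def simdgroup_load(src, row0, col0, stride):
    """Model of simdgroup_load: read an 8x8 tile from a flat buffer."""
    return [[src[(row0 + r) * stride + (col0 + c)] for c in range(TILE)]
            for r in range(TILE)]

def simdgroup_store(tile, dst, row0, col0, stride):
    """Model of simdgroup_store: write an 8x8 tile into a flat buffer."""
    for r in range(TILE):
        for c in range(TILE):
            dst[(row0 + r) * stride + (col0 + c)] = tile[r][c]

def simdgroup_multiply_accumulate(a, b, c):
    """Model of simdgroup_multiply_accumulate: D = A x B + C on 8x8 tiles."""
    return [[sum(a[i][k] * b[k][j] for k in range(TILE)) + c[i][j]
             for j in range(TILE)] for i in range(TILE)]
```

Note that none of these expose which thread ends up holding which element; that is exactly the opacity the PR exploits.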


Main Changes

New GEMM backend for Metal (GemmMetal)

  • Adds GemmInst::kMetalExp / GemmInst.METAL as a new GEMM instruction type
  • Implements GemmMetal class that lowers T.gemm / T.gemm_v2 to Metal simdgroup intrinsics
  • Data flow: shared → simdgroup_load → simdgroup_multiply_accumulate → simdgroup_store → shared/global
  • No intermediate fragment buffers, no shared memory scratch for layout conversion
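The data flow above can be sketched as a pure-Python model of one simdgroup computing a single 8×8 output tile. This is illustrative only: the names gemm_ss and load_tile are hypothetical, and the real backend emits TIR that lowers to the Metal intrinsics.

```python
TILE = 8  # 8x8 simdgroup matrix size

def load_tile(buf, r0, c0, stride):
    """Stand-in for simdgroup_load on a flat shared-memory buffer."""
    return [[buf[(r0 + r) * stride + (c0 + c)] for c in range(TILE)]
            for r in range(TILE)]

def gemm_ss(a_shared, b_shared, c_out, K, stride_a, stride_b, stride_c):
    """One simdgroup's 8x8 output tile at block origin (0, 0): C = A @ B."""
    acc = [[0.0] * TILE for _ in range(TILE)]  # stays in simdgroup registers
    for k0 in range(0, K, TILE):               # k-loop over 8-wide slices
        a = load_tile(a_shared, 0, k0, stride_a)   # simdgroup_load A
        b = load_tile(b_shared, k0, 0, stride_b)   # simdgroup_load B
        # simdgroup_multiply_accumulate: acc = a @ b + acc
        acc = [[sum(a[i][t] * b[t][j] for t in range(TILE)) + acc[i][j]
                for j in range(TILE)] for i in range(TILE)]
    for r in range(TILE):                      # one simdgroup_store at the end
        for c in range(TILE):
            c_out[r * stride_c + c] = acc[r][c]
```

The key property is that acc is never written back to shared memory inside the k-loop; that per-iteration round-trip is what the previous approach needed and this PR removes.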

MPSIntrinEmitter (metal_macro_generator.py)

  • Emits TIR macros for simdgroup_load, simdgroup_multiply_accumulate, and simdgroup_store
  • Manages multi-tile mapping: each simdgroup holds multiple 8×8 matrix blocks, indexed by (warp_m, warp_n) and tile position (i, j)
  • Unified simdgroup_copy method for both load and store directions
  • Common _parse_buffer_2d utility for Buffer/BufferRegion handling
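The multi-tile mapping can be illustrated with a small sketch. The exact layout used by MPSIntrinEmitter may differ, so treat the formula below as an assumption about the indexing scheme rather than the emitter's actual code:

```python
MICRO = 8  # 8x8 simdgroup matrix size

def tile_origin(warp_m, warp_n, i, j, warp_rows, warp_cols):
    """Top-left (row, col) of the 8x8 block (i, j) owned by warp (warp_m, warp_n),
    assuming each warp holds a contiguous warp_rows x warp_cols grid of blocks."""
    row = (warp_m * warp_rows + i) * MICRO
    col = (warp_n * warp_cols + j) * MICRO
    return row, col
```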

Infrastructure changes

  • metal.simdgroup scope recognized as a local buffer type in IsFragmentBuffer, is_local_buffer, and decouple_type_cast
  • CopyInst::kMetalSIMDGroup added for future lowering from T.copy to simdgroup_{load,store}
  • Metal warp partition uses kMPerWarp=8 (matching 8×8 simdgroup matrix size)
  • Fix in parallel.cc: guard as<Fragment>() cast for non-fragment layouts to avoid crash
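As a rough sketch of the kMPerWarp=8 constraint (hypothetical helper; the actual partition heuristic lives in src/op/gemm.cc and may weigh more factors), the warp count along M is capped so every warp covers at least one 8-row simdgroup matrix:

```python
def partition_warps(block_m, num_warps, k_m_per_warp=8):
    """Split num_warps into (m_warp, n_warp) with m_warp * k_m_per_warp <= block_m."""
    m_warp = max(1, min(num_warps, block_m // k_m_per_warp))
    while num_warps % m_warp != 0:  # keep the split exact
        m_warp -= 1
    return m_warp, num_warps // m_warp
```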

Why No Layout Inference Is Needed

TileLang's layout system maps individual threads to individual register elements. This is essential for CUDA's wmma/mma instructions, where software must control precisely which thread holds which data fragment, and for ldmatrix/stmatrix instructions, which require knowing the per-thread address pattern.

Metal's simdgroup matrix operations are fundamentally different: they are opaque, collective operations where the entire simdgroup (32 threads) cooperates and the hardware decides the internal data distribution. The programmer never needs to know (and cannot control) which thread holds which element of an 8×8 matrix. Therefore:

| Operation | CUDA (needs layout) | Metal (no layout needed) |
| --- | --- | --- |
| Load from shared | ldmatrix — each thread provides an address based on its lane position | simdgroup_load(mat, ptr, stride) — hardware handles everything |
| Multiply-accumulate | mma — thread <-> register mapping must match instruction encoding | simdgroup_multiply_accumulate(D, A, B, C) — hardware handles everything |
| Store to memory | stmatrix — inverse of the load layout | simdgroup_store(mat, ptr, stride) — hardware handles everything |

By using simdgroup operations directly, we also eliminate the shared memory round-trip that the previous implementation required to convert between the opaque simdgroup layout and TileLang's fragment layout on every k-iteration.


Why Only shared × shared -> shared/global GEMM

This PR implements only the smem × smem → smem/global (SS) case because:

  1. Metal simdgroup_load reads from device/threadgroup memory — it cannot read from registers or other opaque sources. Both A and B operands must come from addressable memory (shared or global), making the SS pattern the natural match.

  2. simdgroup_store can write to both shared and global memory — the result can go directly to global memory without an intermediate shared memory buffer, which is what the frontend exploits:

    T.gemm_v2(A_shared, B_shared, C[by * block_M, bx * block_N])  # C is global
  3. No register↔register GEMM — unlike CUDA where fragments can be loaded from registers and composed in various ways, Metal simdgroup operations always read from memory. Fragment-scope (register-level) operands would require materializing to shared memory first, which is exactly what the SS pattern already handles.

  4. Global operands go through shared memory anyway — the standard TileLang pattern T.copy(A_global, A_shared) followed by T.gemm(A_shared, B_shared, C) already stages global data through shared memory, so supporting global × global directly adds no benefit.


Test Coverage

  • test_metal_gemm_v2.py: End-to-end correctness tests on Metal hardware (requires macOS + MPS), comparing against torch.matmul
  • test_metal_gemm_v2_linux.py: Codegen-only tests verifiable on any platform, checking that generated Metal shader source contains expected simdgroup operations
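The codegen-only check in the Linux test can be approximated by a simple substring assertion on the generated shader source (hypothetical helper; the real test obtains the source from the lowered kernel):

```python
def uses_simdgroup_ops(src_code: str) -> bool:
    """True if the Metal source references all three simdgroup intrinsics."""
    needed = ("simdgroup_load", "simdgroup_multiply_accumulate", "simdgroup_store")
    return all(op in src_code for op in needed)
```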

Summary by CodeRabbit

  • New Features

    • Added Metal GPU backend support for matrix-multiply with SIMD-group optimizations and a way to retrieve Metal kernel source.
  • Tests

    • Added end-to-end Metal GEMM verification tests and Metal shader codegen tests.
  • Bug Fixes

    • Improved layout inference to avoid unsafe fragment access.
    • Recognize additional Metal-specific buffer scope for SIMD group ops.
  • Dependencies

    • Tightened apache-tvm-ffi constraints for macOS compatibility.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Feb 23, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 696f703 and 0976055.

📒 Files selected for processing (1)
  • requirements.txt
🚧 Files skipped from review as they are similar to previous changes (1)
  • requirements.txt

📝 Walkthrough

Walkthrough

Adds Metal/MPS backend support: new GemmMetal implementation, Metal-specific intrinsics emitter, simdgroup copy/store/load support, Metal-target GEMM dispatch and enums, Metal-focused tests, and Darwin-specific apache-tvm-ffi version constraint.

Changes

  • Dependency Management (pyproject.toml, requirements.txt): Tightened apache-tvm-ffi spec: changed >=0.1.2 to >0.1.2 and added a Darwin-specific upper bound apache-tvm-ffi<0.1.8; platform_system == 'Darwin'.
  • Copy Operation Extensions (src/op/copy.cc, src/op/copy.h): Added Metal SIMD-group store support: new CopyInst::kMetalSIMDGroup, CheckSIMDGroupStore(Target) API, and LowerSIMDGroupStore lowering path and dispatch.
  • GEMM Core Changes (src/op/gemm.cc, src/op/gemm.h, src/op/gemm_py.cc): Introduced kMetalExp/Metal GemmInst, conditional warp-partition adjustment for Metal (kMPerWarp=8), and a Metal branch in GemmPyNode::getGemmInst.
  • Layout & Buffer Utilities (src/op/parallel.cc, src/op/utils.h, tilelang/transform/decouple_type_cast.py, tilelang/utils/language.py): Made fragment layout access optional (safety checks); recognized metal.simdgroup as a fragment/local buffer via is_metal_simdgroup and updated relevant helpers.
  • Metal Intrinsics Emitter (tilelang/intrinsics/metal_macro_generator.py): Added MPSIntrinEmitter implementing simdgroup load/store, ldmatrix-like loads, mma accumulate, warp/thread indexing, and helpers for Buffer/BufferRegion handling.
  • GEMM Backend & Dispatch (tilelang/tileop/gemm/inst.py, tilelang/tileop/gemm/__init__.py, tilelang/tileop/gemm/gemm_metal.py): Added METAL GemmInst, is_metal() predicate, exported GemmMetal, and implemented GemmMetal lowering that uses MPSIntrinEmitter and metal.simdgroup buffers.
  • JIT / Adapter (tilelang/jit/adapter/torch/metal.py): Added MetalKernelAdapter.get_kernel_source() to expose the kernel source string.
  • Tests (testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_gemm_v2_linux.py): Added device-side and codegen tests validating Metal GEMM v2 lowering, simdgroup ops presence, and numerical correctness against PyTorch (MPS) where applicable.

Sequence Diagram

sequenceDiagram
    participant Host as Host/PyTorch
    participant Dispatch as GemmPy Dispatch
    participant Backend as GemmMetal
    participant Emitter as MPSIntrinEmitter
    participant Metal as Metal Runtime

    Host->>Dispatch: Request GEMM (Target=Metal)
    Dispatch->>Backend: Route to GemmMetal (is_metal)
    Backend->>Backend: Compute warp partitions / alloc buffers
    Backend->>Emitter: init with tiling & dtypes
    loop per K block
        Backend->>Emitter: ldmatrix_a / ldmatrix_b
        Emitter->>Metal: simdgroup_load / simdgroup_multiply_accumulate
        Emitter->>Backend: return partial accumulators
    end
    Backend->>Emitter: simdgroup_store C_simd -> C_buf
    Backend->>Host: Return lowered kernel / kernel source
    Host->>Metal: Execute kernel on device

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 In simdgroup rows my whiskers weave,

Threads hop, matrices spin and cleave,
MPS hums, warps align in tune,
TileLang hops under Metal's moon,
A tiny rabbit cheers — gemm's new bloom.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 30.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title clearly summarizes the main change, adding a Metal GEMM v2 implementation using simdgroup_multiply_accumulate intrinsics, which is the primary focus of this substantial changeset.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 11

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/transform/decouple_type_cast.py (1)

92-95: ⚠️ Potential issue | 🟡 Minor

Error message doesn't mention metal.simdgroup as a known scope.

Since is_local_buffer now accepts metal.simdgroup, the error message should list it as a valid scope for completeness.

Suggested fix
         raise ValueError(
             f"Unknown buffer scope '{buffer.scope()}' for buffer '{buffer.name}'. "
-            f"Expected one of: local, local.fragment, local.var, global, shared, shared.dyn"
+            f"Expected one of: local, local.fragment, local.var, metal.simdgroup, global, shared, shared.dyn"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 92 - 95, The
ValueError raised in decouple_type_cast.py lists valid buffer scopes but omits
"metal.simdgroup"; update the error message (the raise ValueError that
references buffer.scope() and buffer.name) to include "metal.simdgroup" in the
Expected one of: list so it matches the scopes accepted by is_local_buffer/other
checks and clearly communicates valid scopes.
🧹 Nitpick comments (7)
src/op/copy.cc (1)

639-645: Method name CheckSIMDGroupStore is misleading — it matches any simdgroup↔simdgroup copy.

Both src and dst are checked for "metal.simdgroup" scope, meaning this matches loads and stores (or more accurately, simdgroup-to-simdgroup transfers). Compare with CheckLDSMCopy (shared→fragment) and CheckSTSMCopy (fragment→shared) which are directional. Consider renaming to CheckSIMDGroupCopy (and correspondingly LowerSIMDGroupCopy) for consistency and clarity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 639 - 645, The method CheckSIMDGroupStore
incorrectly implies a directional store but actually matches any
simdgroup↔simdgroup transfer; rename CheckSIMDGroupStore to CheckSIMDGroupCopy
(and rename its lowering helper LowerSIMDGroupStore to LowerSIMDGroupCopy or
similarly) and update all call sites to use the new names so the intent matches
the implementation; keep the same body (checking src.scope() ==
"metal.simdgroup" && dst.scope() == "metal.simdgroup") but change identifiers
for consistency with CheckLDSMCopy/CheckSTSMCopy naming conventions.
src/op/gemm_py.cc (1)

136-137: Metal branch placement is correct; consider a minimal allowMetal() guard for future-proofing

The ordering (TCGEN5MMA → WGMMA → CDNA → CUDA → Metal → fallback) is correct — none of the earlier guards can fire on a Metal target. The change is logically sound.

For consistency with the other paths (which all have dedicated allow* functions that validate shape, dtype, and scope constraints before returning their instruction type), a lightweight allowMetal() gate would help catch unsupported Metal configurations early rather than deferring to downstream errors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/gemm_py.cc` around lines 136 - 137, Add a lightweight allowMetal()
guard before returning GemmInst::kMetalExp: implement allowMetal() similar to
the other allow* helpers (validate shapes, dtypes, and memory scope constraints
expected by the Metal path) and call it in the branch that currently checks
TargetIsMetal(target); only return GemmInst::kMetalExp when allowMetal() returns
true, otherwise fall back to the existing fallback/error path so unsupported
Metal configs are rejected early. Reference: the TargetIsMetal(target) branch
and GemmInst::kMetalExp; follow the validation pattern used by functions like
allowCuda()/allowCdna()/allowWgmma() to ensure consistency.
src/op/gemm.h (1)

42-42: kMetalExp naming doesn't align with Python's METAL; enum values should be explicit

Two related concerns:

  1. The C++ enumerator is named kMetalExp (Exp = experimental), but the Python counterpart in tilelang/tileop/gemm/inst.py uses the plain name METAL = 4 with no "experimental" qualifier. This asymmetry makes it unclear whether "experimental" is a meaningful status or just a stale suffix.

  2. The C++ enum assigns values implicitly (sequential 0–4). The Python IntEnum assigns values explicitly (0–4). These are currently in sync, but inserting a new entry between existing ones would silently break the C++↔Python ABI. Adding explicit values in C++ is a cheap safeguard:

♻️ Suggested: add explicit values and align naming
-enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA, kMetalExp };
+enum class GemmInst : uint8_t {
+  kMMA = 0,
+  kWGMMA = 1,
+  kTCGEN5MMA = 2,
+  kMFMA = 3,
+  kMetal = 4,  // align with Python GemmInst.METAL
+};
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/gemm.h` at line 42, The GemmInst enum currently uses an asymmetrical
name and implicit values; change the enumerator kMetalExp to match the Python
name (e.g., kMETAL) and make all enum members have explicit integer values
matching the Python IntEnum (e.g., enum class GemmInst : uint8_t { kMMA = 0,
kWGMMA = 1, kTCGEN5MMA = 2, kMFMA = 3, kMETAL = 4 };), keep the underlying type
uint8_t, and update any usages of GemmInst::kMetalExp to the new symbol to
preserve ABI alignment with tilelang/tileop/gemm/inst.py.
testing/python/metal/test_metal_gemm_v2.py (3)

91-93: Redundant torch.mps.is_available() guard silently swallows test output when Metal is absent.

The individual test functions already carry @tilelang.testing.requires_metal, which skips them if Metal is unavailable. Wrapping tilelang.testing.main() in an additional torch.mps.is_available() check means that when run on a non-Metal machine, no tests are registered at all — no "skipped" output, no indication tests exist. Compare with test_metal_gemm_v2_linux.py which calls tilelang.testing.main() unconditionally.

🔧 Suggested fix
 if __name__ == "__main__":
-    if torch.mps.is_available():
-        tilelang.testing.main()
+    tilelang.testing.main()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` around lines 91 - 93, Remove the
redundant torch.mps.is_available() guard in the __main__ block so that
tilelang.testing.main() is called unconditionally; the tests themselves use the
`@tilelang.testing.requires_metal` decorator (in this file's test functions) to
mark/skips tests when Metal is absent, so update the __main__ section (the block
checking __name__ == "__main__") to simply invoke tilelang.testing.main()
without wrapping it in torch.mps.is_available().

80-83: Consider putting requires_metal as the outermost decorator for conventional ordering.

Placing @tilelang.testing.requires_metal inside @pytest.mark.xfail works correctly (skip propagates), but the conventional pattern is to put guard/skip decorators outermost so the intent is immediately obvious at the call site.

🎨 Suggested reorder
-@pytest.mark.xfail(reason="TODO: codegen not support float16x8")
 @tilelang.testing.requires_metal
+@pytest.mark.xfail(reason="TODO: codegen not support float16x8")
 def test_gemm_v2_large():
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` around lines 80 - 83, Move the
decorator order so the platform guard is outermost: swap the two decorators on
test_gemm_v2_large so `@tilelang.testing.requires_metal` appears above
`@pytest.mark.xfail`; update the decorators on the test_gemm_v2_large function
accordingly to keep the same behavior but follow conventional ordering.

88-88: Add a comment explaining atol=1.0 for the large-K test.

For K=1024 with float16 inputs, accumulated rounding error per element can approach K × ε_fp16 ≈ 1024 × 9.7e-4 ≈ 1.0, so the tolerance is intentional. A brief inline comment would prevent future readers from tightening it incorrectly.

💬 Suggested change
-    assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)
+    # atol=1.0: with K=1024 fp16 inputs, accumulated rounding error ≈ K × ε_fp16 ≈ 1.0
+    assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` at line 88, Add a brief inline
comment next to the call to assert_gemm_v2(1024, 1024, 1024, 16, 16, 16,
atol=1.0) explaining that atol=1.0 is intentional because for K=1024 with
float16 inputs the accumulated rounding error per element can approach K *
ε_fp16 ≈ 1024 * 9.7e-4 ≈ 1.0, so the larger absolute tolerance is required for
this large-K test to avoid false failures.
testing/python/metal/test_metal_gemm_v2_linux.py (1)

50-53: Misleading variable name and redundant target specification.

Two minor issues:

  1. Line 50 sets tvm.target.Target("metal") as a context manager and passes target="metal" to tilelang.lower on line 51. The context manager is redundant; the explicit target= arg alone is sufficient (and matches the pattern in similar test files).
  2. The return value of tilelang.lower is named artifact, but kernel_source is a property on the JIT kernel object (as shown in tilelang/jit/kernel.py), not on a raw artifact. Naming it kernel would be more accurate.
🔧 Suggested cleanup
-    with tvm.transform.PassContext(), tvm.target.Target("metal"):
-        artifact = tilelang.lower(func, target="metal")
-
-    src_code = artifact.kernel_source
+    with tvm.transform.PassContext():
+        kernel = tilelang.lower(func, target="metal")
+
+    src_code = kernel.kernel_source
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 50 - 53,
Remove the redundant PassContext target context and rename the returned value
from tilelang.lower to reflect it's a JIT kernel: call tilelang.lower with the
explicit target argument only (remove the tvm.transform.PassContext(),
tvm.target.Target("metal") context manager) and rename the variable from
artifact to kernel so you access kernel.kernel_source (i.e., locate the
tilelang.lower call and the subsequent kernel_source access and update the
target usage and variable name accordingly).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pyproject.toml`:
- Around line 33-35: The pyproject lower-bound for "apache-tvm-ffi" is too
permissive (">0.1.2") and allows Darwin installs of versions 0.1.3–0.1.5 that
contain the memory regression; update the constraint(s) for "apache-tvm-ffi" to
match requirements.txt by using a >=0.1.6 lower bound (e.g., change ">0.1.2" to
">=0.1.6" and ensure the Darwin-specific constraint "apache-tvm-ffi<0.1.8;
platform_system == 'Darwin'" remains consistent with that lower bound).

In `@src/op/copy.cc`:
- Around line 685-686: GetCopyInst can return CopyInst::kMetalSIMDGroup but
Lower() lacks a matching case, causing a fatal crash; add a case for
CopyInst::kMetalSIMDGroup in Lower() (in the function Lower) before the kNormal
branch that implements the appropriate lowering behavior for SIMD-group metal
stores (mirror the pattern used for other Metal-specific cases in Lower(),
dispatching to the SIMD-group store lowering path and returning the resultant
statement), ensuring the new case references CopyInst::kMetalSIMDGroup so
execution no longer falls through to LOG(FATAL).

In `@src/op/copy.h`:
- Around line 232-235: The Doxygen comment for CheckSIMDGroupStore is
incorrect/copy-pasted from CheckTMemStore; update the comment to describe that
CheckSIMDGroupStore(Target target) determines whether the Metal SIMD-group
(SIMD-group/warp/wavefront) store instruction is supported on the given Target
(i.e., checks for Metal SIMD group store capability), not tensor memory store
support; reference CheckSIMDGroupStore and, for comparison, CheckTMemStore to
ensure the wording distinguishes SIMD group store support from tensor memory
store support.
- Around line 274-277: The CopyNode::Lower() switch must handle
CopyInst::kMetalSIMDGroup: add a case in Lower() that calls
LowerSIMDGroupStore(T, analyzer) when GetCopyInst(...) returns kMetalSIMDGroup
(this path is produced when CheckSIMDGroupStore(target) is true) to avoid the
LOG(FATAL) fallthrough; then implement the missing LowerSIMDGroupStore(const
LowerArgs &T, arith::Analyzer *analyzer) in copy.cc mirroring the pattern of
existing store lowerers (use the same argument handling and memory access
lowering used by other Metal-specific store methods, perform bounds/stride
checks as done in related Lower*Store functions, and produce the appropriate
Stmt), ensuring the declaration in copy.h matches the definition.

In `@src/op/gemm.cc`:
- Around line 150-154: Comments referencing a hardcoded "16" are now stale
because kMPerWarp is target-dependent; update the comments near kMPerWarp,
TargetIsMetal, and the logic that mentions "m_warp*16" and "Each warp needs at
least 16 elements in M" to reflect the runtime variable (e.g., refer to "m_warp
* kMPerWarp" and "kMPerWarp elements") so they accurately describe behavior on
Metal and other targets.

In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 70-71: The test test_metal_gemm_v2_larger currently calls
assert_metal_gemm_v2_codegen with parameters known to fail at runtime; update
the test to either mark it as an expected failure or explicitly verify
codegen-only success: add `@pytest.mark.xfail`(reason="TODO: codegen not support
float16x8") above test_metal_gemm_v2_larger to match the runtime test, or
replace the single assert_metal_gemm_v2_codegen call with a two-step check that
runs the codegen path (tilelang.lower/codegen) and asserts it succeeds while
keeping execution separately marked xfail; reference test_metal_gemm_v2_larger
and assert_metal_gemm_v2_codegen when making the change.
- Around line 22-34: The codegen test's matmul_gemm_v2 kernel is structurally
different from the runtime test; update the matmul_gemm_v2 in the Linux codegen
test so its symbols match the runtime test: change C_local to use
T.alloc_shared(..., scope="shared") (instead of T.alloc_fragment), remove/adjust
coalesced_width=2 on T.copy calls to match the runtime test's copies, and use
T.gemm_v2 (or make the runtime test use T.gemm if you choose that canonical API)
so the GEMM operator is identical; ensure these changes are applied to the
matmul_gemm_v2 definition so the codegen pre-flight validates the same kernel
that the runtime test executes.
- Line 32: The test is calling the old lowering path via T.gemm instead of the
new Metal path; update the call to use T.gemm_v2 so the codegen test exercises
the gemm_v2 lowering (replace the invocation T.gemm(A_shared, B_shared, C_local)
with T.gemm_v2(A_shared, B_shared, C_local) in the test function) to align this
codegen test with the runtime test and ensure assertions on
simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store validate the
correct backend.

In `@tilelang/intrinsics/metal_macro_generator.py`:
- Around line 43-44: The code computes self.warp_rows and self.warp_cols via
integer division of warp_row_tiles//micro_size_x and
warp_col_tiles//micro_size_y which will silently truncate if inputs are not
divisible by the micro sizes (8); add validation in the same initializer or
before these assignments (e.g., in the MetalMacroGenerator constructor or method
that sets warp_row_tiles/warp_col_tiles) that raises a clear error if
warp_row_tiles % micro_size_x != 0 or warp_col_tiles % micro_size_y != 0,
mentioning the offending values, and only then compute self.warp_rows and
self.warp_cols as the integer quotient.

In `@tilelang/jit/adapter/torch/metal.py`:
- Around line 56-58: The method get_kernel_source currently claims to return str
but may return None and ignores the kernel_only flag; change its signature to ->
str | None (or keep -> str but assert/raise if kernel_global_source is None) and
implement the kernel_only branch: if kernel_only is True return
self.kernel_global_source, otherwise return the full Metal source (compose or
return the attribute that holds the complete module/source such as
self.metal_source or self.full_source); ensure you reference get_kernel_source
and kernel_global_source and either assert kernel_global_source is not None
before returning a str or update callers/types to accept Optional[str].

In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 22-23: The int() cast on potentially symbolic shapes self.M and
self.N will fail at runtime for PrimExpr; update the computation of
warp_row_tiles and warp_col_tiles (currently int(self.M // m_warp) and
int(self.N // n_warp)) to preserve symbolic expressions instead of forcing
Python ints—either remove the int() and keep self.M // m_warp and self.N //
n_warp, or use tir.floordiv/tvm.tir.floordiv to produce a PrimExpr;
alternatively, if a concrete int is required, guard with an isinstance check for
tir.IntImm before casting. Ensure you change both warp_row_tiles and
warp_col_tiles and keep references to m_warp and n_warp.

---

Outside diff comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 92-95: The ValueError raised in decouple_type_cast.py lists valid
buffer scopes but omits "metal.simdgroup"; update the error message (the raise
ValueError that references buffer.scope() and buffer.name) to include
"metal.simdgroup" in the Expected one of: list so it matches the scopes accepted
by is_local_buffer/other checks and clearly communicates valid scopes.

---

Nitpick comments:
In `@src/op/copy.cc`:
- Around line 639-645: The method CheckSIMDGroupStore incorrectly implies a
directional store but actually matches any simdgroup↔simdgroup transfer; rename
CheckSIMDGroupStore to CheckSIMDGroupCopy (and rename its lowering helper
LowerSIMDGroupStore to LowerSIMDGroupCopy or similarly) and update all call
sites to use the new names so the intent matches the implementation; keep the
same body (checking src.scope() == "metal.simdgroup" && dst.scope() ==
"metal.simdgroup") but change identifiers for consistency with
CheckLDSMCopy/CheckSTSMCopy naming conventions.

In `@src/op/gemm_py.cc`:
- Around line 136-137: Add a lightweight allowMetal() guard before returning
GemmInst::kMetalExp: implement allowMetal() similar to the other allow* helpers
(validate shapes, dtypes, and memory scope constraints expected by the Metal
path) and call it in the branch that currently checks TargetIsMetal(target);
only return GemmInst::kMetalExp when allowMetal() returns true, otherwise fall
back to the existing fallback/error path so unsupported Metal configs are
rejected early. Reference: the TargetIsMetal(target) branch and
GemmInst::kMetalExp; follow the validation pattern used by functions like
allowCuda()/allowCdna()/allowWgmma() to ensure consistency.

In `@src/op/gemm.h`:
- Line 42: The GemmInst enum currently uses an asymmetrical name and implicit
values; change the enumerator kMetalExp to match the Python name (e.g., kMETAL)
and make all enum members have explicit integer values matching the Python
IntEnum (e.g., enum class GemmInst : uint8_t { kMMA = 0, kWGMMA = 1, kTCGEN5MMA
= 2, kMFMA = 3, kMETAL = 4 };), keep the underlying type uint8_t, and update any
usages of GemmInst::kMetalExp to the new symbol to preserve ABI alignment with
tilelang/tileop/gemm/inst.py.

In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 50-53: Remove the redundant PassContext target context and rename
the returned value from tilelang.lower to reflect it's a JIT kernel: call
tilelang.lower with the explicit target argument only (remove the
tvm.transform.PassContext(), tvm.target.Target("metal") context manager) and
rename the variable from artifact to kernel so you access kernel.kernel_source
(i.e., locate the tilelang.lower call and the subsequent kernel_source access
and update the target usage and variable name accordingly).

In `@testing/python/metal/test_metal_gemm_v2.py`:
- Around line 91-93: Remove the redundant torch.mps.is_available() guard in the
__main__ block so that tilelang.testing.main() is called unconditionally; the
tests themselves use the `@tilelang.testing.requires_metal` decorator (in this
file's test functions) to mark/skips tests when Metal is absent, so update the
__main__ section (the block checking __name__ == "__main__") to simply invoke
tilelang.testing.main() without wrapping it in torch.mps.is_available().
- Around line 80-83: Move the decorator order so the platform guard is
outermost: swap the two decorators on test_gemm_v2_large so
`@tilelang.testing.requires_metal` appears above `@pytest.mark.xfail`; update the
decorators on the test_gemm_v2_large function accordingly to keep the same
behavior but follow conventional ordering.
- Line 88: Add a brief inline comment next to the call to assert_gemm_v2(1024,
1024, 1024, 16, 16, 16, atol=1.0) explaining that atol=1.0 is intentional
because for K=1024 with float16 inputs the accumulated rounding error per
element can approach K * ε_fp16 ≈ 1024 * 9.7e-4 ≈ 1.0, so the larger absolute
tolerance is required for this large-K test to avoid false failures.
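The bound quoted in that comment checks out numerically (the machine epsilon of IEEE float16 is 2⁻¹⁰ ≈ 9.77e-4):

```python
# Crude per-element accumulation bound for a K=1024 float16 dot product:
# each of the K products can contribute up to ~eps of rounding error.
fp16_eps = 2.0 ** -10          # machine epsilon of IEEE float16 (~9.77e-4)
K = 1024
worst_case_abs_error = K * fp16_eps
assert worst_case_abs_error == 1.0  # hence atol=1.0 in the test
```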

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f25954 and 696f703.

📒 Files selected for processing (18)
  • pyproject.toml
  • requirements.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/gemm_py.cc
  • src/op/parallel.cc
  • src/op/utils.h
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/utils/language.py

Comment on lines +33 to +35
"apache-tvm-ffi~=0.1.0,>0.1.2",
# apache/tvm-ffi#464
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",

⚠️ Potential issue | 🟡 Minor

Lower-bound inconsistency with requirements.txt may expose the memory bug on Darwin.

requirements.txt keeps a general lower bound of >=0.1.6 (the fix for tilelang#1502), while pyproject.toml uses >0.1.2. On Darwin a user installing the wheel from PyPI sees:

| File | Effective Darwin range |
|------|------------------------|
| requirements.txt | 0.1.6 – 0.1.7 |
| pyproject.toml | 0.1.3 – 0.1.7 |

Versions 0.1.3–0.1.5 are still selectable from pyproject.toml and still carry the memory regression. Unless broad compatibility is the explicit goal (the existing comment suggests it is), consider aligning the pyproject.toml lower bound with requirements.txt:

🔧 Proposed alignment
-    "apache-tvm-ffi~=0.1.0,>0.1.2",
+    "apache-tvm-ffi~=0.1.0,>=0.1.6",
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"apache-tvm-ffi~=0.1.0,>0.1.2",
# apache/tvm-ffi#464
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
"apache-tvm-ffi~=0.1.0,>=0.1.6",
# apache/tvm-ffi#464
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pyproject.toml` around lines 33 - 35, The pyproject lower-bound for
"apache-tvm-ffi" is too permissive (">0.1.2") and allows Darwin installs of
versions 0.1.3–0.1.5 that contain the memory regression; update the
constraint(s) for "apache-tvm-ffi" to match requirements.txt by using a >=0.1.6
lower bound (e.g., change ">0.1.2" to ">=0.1.6" and ensure the Darwin-specific
constraint "apache-tvm-ffi<0.1.8; platform_system == 'Darwin'" remains
consistent with that lower bound).
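The gap is easy to confirm with the `packaging` library, which implements PEP 440 specifier semantics (combining the general and Darwin-only markers into one SpecifierSet here is a simplification of how pip evaluates them):

```python
from packaging.specifiers import SpecifierSet

# Effective Darwin constraints from each file:
pyproject = SpecifierSet("~=0.1.0,>0.1.2,<0.1.8")
requirements = SpecifierSet(">=0.1.6,<0.1.8")

# 0.1.5 still carries the memory regression, yet pyproject.toml accepts it:
assert pyproject.contains("0.1.5")
assert not requirements.contains("0.1.5")
```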

Comment on lines +685 to +686
} else if (CheckSIMDGroupStore(target)) {
return CopyInst::kMetalSIMDGroup;

⚠️ Potential issue | 🔴 Critical

Dispatch to kMetalSIMDGroup added in GetCopyInst but Lower() has no corresponding case — runtime crash.

GetCopyInst can now return CopyInst::kMetalSIMDGroup, but Lower() (line 694–726) does not handle this value. Execution will fall through to LOG(FATAL) << "Unsupported copy inst" at line 724, crashing at runtime.

Add a case in Lower() before the kNormal case:

🐛 Proposed fix in Lower() (around line 721)
   } else if (copy_inst == CopyInst::kNormal) {
     return LowerNormalCopy(T, analyzer);
+  } else if (copy_inst == CopyInst::kMetalSIMDGroup) {
+    return LowerSIMDGroupStore(T, analyzer);
   } else {
     LOG(FATAL) << "Unsupported copy inst " << static_cast<int>(copy_inst);
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 685 - 686, GetCopyInst can return
CopyInst::kMetalSIMDGroup but Lower() lacks a matching case, causing a fatal
crash; add a case for CopyInst::kMetalSIMDGroup in Lower() (in the function
Lower) before the kNormal branch that implements the appropriate lowering
behavior for SIMD-group metal stores (mirror the pattern used for other
Metal-specific cases in Lower(), dispatching to the SIMD-group store lowering
path and returning the resultant statement), ensuring the new case references
CopyInst::kMetalSIMDGroup so execution no longer falls through to LOG(FATAL).

Comment on lines +232 to +235
/*!
* \brief Check if tensor memory store is supported.
*/
bool CheckSIMDGroupStore(Target target) const;

⚠️ Potential issue | 🟡 Minor

Doxygen comment is copy-pasted from CheckTMemStore.

The brief says "Check if tensor memory store is supported" but this method checks for Metal SIMD group store support. Should be updated to reflect its actual purpose.

📝 Proposed fix
   /*!
-   * \brief Check if tensor memory store is supported.
+   * \brief Check if Metal SIMD group store is supported.
    */
   bool CheckSIMDGroupStore(Target target) const;
📝 Committable suggestion


Suggested change
/*!
* \brief Check if tensor memory store is supported.
*/
bool CheckSIMDGroupStore(Target target) const;
/*!
* \brief Check if Metal SIMD group store is supported.
*/
bool CheckSIMDGroupStore(Target target) const;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.h` around lines 232 - 235, The Doxygen comment for
CheckSIMDGroupStore is incorrect/copy-pasted from CheckTMemStore; update the
comment to describe that CheckSIMDGroupStore(Target target) determines whether
the Metal SIMD-group (SIMD-group/warp/wavefront) store instruction is supported
on the given Target (i.e., checks for Metal SIMD group store capability), not
tensor memory store support; reference CheckSIMDGroupStore and, for comparison,
CheckTMemStore to ensure the wording distinguishes SIMD group store support from
tensor memory store support.

Comment on lines +274 to +277
/*!
* \brief Generate lowering for simdgroup store.
*/
Stmt LowerSIMDGroupStore(const LowerArgs &T, arith::Analyzer *analyzer) const;

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

rg -n 'LowerSIMDGroupStore' --type=cpp

Repository: tile-ai/tilelang

Length of output: 158


🏁 Script executed:

cat -n src/op/copy.cc | head -750 | tail -100

Repository: tile-ai/tilelang

Length of output: 5249


🏁 Script executed:

rg -n 'kMetalSIMDGroup' --type=cpp

Repository: tile-ai/tilelang

Length of output: 237


🏁 Script executed:

rg -A 10 'LowerSIMDGroupStore' --type=cpp src/op/copy.cc

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

sed -n '50,60p' src/op/copy.h

Repository: tile-ai/tilelang

Length of output: 300


🏁 Script executed:

rg -B 5 -A 10 'kMetalSIMDGroup' src/op/copy.h

Repository: tile-ai/tilelang

Length of output: 1072


Add dispatch case for kMetalSIMDGroup in CopyNode::Lower() and implement LowerSIMDGroupStore().

GetCopyInst() can return CopyInst::kMetalSIMDGroup when CheckSIMDGroupStore(target) is true, but Lower() (lines 694–726) has no case to handle it. The method will fall through to LOG(FATAL) at line 724. Additionally, LowerSIMDGroupStore() is declared in copy.h:277 but has no implementation in copy.cc.

You need to:

  1. Add a dispatch case in Lower() for kMetalSIMDGroup that calls LowerSIMDGroupStore()
  2. Implement LowerSIMDGroupStore() in copy.cc
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.h` around lines 274 - 277, The CopyNode::Lower() switch must
handle CopyInst::kMetalSIMDGroup: add a case in Lower() that calls
LowerSIMDGroupStore(T, analyzer) when GetCopyInst(...) returns kMetalSIMDGroup
(this path is produced when CheckSIMDGroupStore(target) is true) to avoid the
LOG(FATAL) fallthrough; then implement the missing LowerSIMDGroupStore(const
LowerArgs &T, arith::Analyzer *analyzer) in copy.cc mirroring the pattern of
existing store lowerers (use the same argument handling and memory access
lowering used by other Metal-specific store methods, perform bounds/stride
checks as done in related Lower*Store functions, and produce the appropriate
Stmt), ensuring the declaration in copy.h matches the definition.

Comment on lines +150 to +154
int kMPerWarp = 16; // Rows processed by a single warp
if (TargetIsMetal(target)) {
kMPerWarp = 8;
}
int kNPerWarp = 8; // Columns processed by a single warp

⚠️ Potential issue | 🟡 Minor

Stale comments hardcode "16" after kMPerWarp became target-dependent

The arithmetic correctly reads the runtime kMPerWarp variable everywhere, but two nearby comments still reference the old constant "16" and will now mislead on Metal:

  • Line 248: // If M cannot be evenly divided by m_warp*16, try to split remaining warps
  • Line 278: // Each warp needs at least 16 elements in M
📝 Suggested comment fixes
-    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
-    // to N
+    // If M cannot be evenly divided by m_warp*kMPerWarp, try to split remaining warps
+    // to N
-    int max_m_warps = M / kMPerWarp; // Each warp needs at least 16 elements in M
+    int max_m_warps = M / kMPerWarp; // Each warp needs at least kMPerWarp elements in M
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/gemm.cc` around lines 150 - 154, Comments referencing a hardcoded "16"
are now stale because kMPerWarp is target-dependent; update the comments near
kMPerWarp, TargetIsMetal, and the logic that mentions "m_warp*16" and "Each warp
needs at least 16 elements in M" to reflect the runtime variable (e.g., refer to
"m_warp * kMPerWarp" and "kMPerWarp elements") so they accurately describe
behavior on Metal and other targets.

T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)

T.gemm(A_shared, B_shared, C_local)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's check the test file in question to see what's actually being called
find . -name "test_metal_gemm_v2_linux.py" -type f | head -1 | xargs cat -n

Repository: tile-ai/tilelang

Length of output: 3353


🏁 Script executed:

# Search for gemm and gemm_v2 definitions in the language
rg -n "def gemm\b|def gemm_v2\b" --type py tilelang/language/ -A 5

Repository: tile-ai/tilelang

Length of output: 743


🏁 Script executed:

# Check what tl.tileop operations are defined and used
rg -n "tl\.tileop\.gemm" --type py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 1314


🏁 Script executed:

# Check runtime test mentioned in review (test_metal_gemm_v2.py)
find . -name "test_metal_gemm_v2.py" -type f | head -1 | xargs cat -n 2>/dev/null || echo "File not found"

Repository: tile-ai/tilelang

Length of output: 3454


🏁 Script executed:

# Search for T.gemm and T.gemm_v2 usage patterns
rg -n "T\.gemm\(|T\.gemm_v2\(" --type py | head -20

Repository: tile-ai/tilelang

Length of output: 2497


Replace T.gemm with T.gemm_v2 on line 32.

T.gemm and T.gemm_v2 are distinct operations with different lowering paths (tl.tileop.gemm vs. tl.tileop.gemm_py). This test is meant to verify Metal codegen for the new gemm_v2 path, but it calls T.gemm instead. The assertions for simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store may be validating against the wrong lowering backend and could pass for the wrong reason or miss actual Metal simdgroup codegen issues.

The runtime test (test_metal_gemm_v2.py line 34) correctly uses T.gemm_v2; align this codegen test with it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` at line 32, The test is
calling the old lowering path via T.gemm instead of the new Metal path; update
the call to use T.gemm_v2 so the codegen test exercises the gemm_v2 lowering
(replace the invocation T.gemm(A_shared, B_shared, C_local) with
T.gemm_v2(A_shared, B_shared, C_local) in the test function) to align this
codegen test with the runtime test and ensure assertions on
simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store validate the
correct backend.

Comment on lines +70 to +71
def test_metal_gemm_v2_larger():
assert_metal_gemm_v2_codegen(128, 128, 128, 32, 32, 32, dtype=T.float16)

⚠️ Potential issue | 🟡 Minor

test_metal_gemm_v2_larger may silently pass despite a known codegen limitation.

The runtime test test_gemm_v2_large in test_metal_gemm_v2.py tests the same parameters (128, 128, 128, 32, 32, 32, dtype=T.float16) and is marked @pytest.mark.xfail(reason="TODO: codegen not support float16x8"). If the codegen for this block configuration is genuinely broken, the Linux codegen test should also fail (and be expected to fail) for consistency. Either:

  • Add @pytest.mark.xfail(reason="TODO: codegen not support float16x8") here as well, or
  • Verify that the codegen-only path (tilelang.lower without execution) succeeds where execution fails.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 70 - 71, The
test test_metal_gemm_v2_larger currently calls assert_metal_gemm_v2_codegen with
parameters known to fail at runtime; update the test to either mark it as an
expected failure or explicitly verify codegen-only success: add
`@pytest.mark.xfail`(reason="TODO: codegen not support float16x8") above
test_metal_gemm_v2_larger to match the runtime test, or replace the single
assert_metal_gemm_v2_codegen call with a two-step check that runs the codegen
path (tilelang.lower/codegen) and asserts it succeeds while keeping execution
separately marked xfail; reference test_metal_gemm_v2_larger and
assert_metal_gemm_v2_codegen when making the change.

Comment on lines +43 to +44
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y

⚠️ Potential issue | 🟡 Minor

Add validation that warp_row_tiles and warp_col_tiles are divisible by 8.

If warp_row_tiles or warp_col_tiles aren't multiples of micro_size_x/micro_size_y (8), the integer division would silently truncate, leading to incorrect tiling.

Suggested validation
         # Number of 8x8 tiles per warp
+        assert warp_row_tiles % self.micro_size_x == 0, (
+            f"warp_row_tiles ({warp_row_tiles}) must be divisible by micro_size_x ({self.micro_size_x})"
+        )
+        assert warp_col_tiles % self.micro_size_y == 0, (
+            f"warp_col_tiles ({warp_col_tiles}) must be divisible by micro_size_y ({self.micro_size_y})"
+        )
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
📝 Committable suggestion


Suggested change
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
# Number of 8x8 tiles per warp
assert warp_row_tiles % self.micro_size_x == 0, (
f"warp_row_tiles ({warp_row_tiles}) must be divisible by micro_size_x ({self.micro_size_x})"
)
assert warp_col_tiles % self.micro_size_y == 0, (
f"warp_col_tiles ({warp_col_tiles}) must be divisible by micro_size_y ({self.micro_size_y})"
)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/metal_macro_generator.py` around lines 43 - 44, The code
computes self.warp_rows and self.warp_cols via integer division of
warp_row_tiles//micro_size_x and warp_col_tiles//micro_size_y which will
silently truncate if inputs are not divisible by the micro sizes (8); add
validation in the same initializer or before these assignments (e.g., in the
MetalMacroGenerator constructor or method that sets
warp_row_tiles/warp_col_tiles) that raises a clear error if warp_row_tiles %
micro_size_x != 0 or warp_col_tiles % micro_size_y != 0, mentioning the
offending values, and only then compute self.warp_rows and self.warp_cols as the
integer quotient.

Comment on lines +56 to +58
def get_kernel_source(self, kernel_only: bool = True) -> str:
return self.kernel_global_source


⚠️ Potential issue | 🟡 Minor

Fix the return type annotation and honour the kernel_only flag.

Two issues here:

  1. Return type mismatch: kernel_global_source is declared str | None (line 23), so the method can return None, contradicting the -> str annotation. This will cause silent type errors for callers.

  2. Unused kernel_only parameter: every peer adapter branches on this flag (base.py, nvrtc/adapter.py, cython/adapter.py). Silently ignoring it here means get_kernel_source(kernel_only=False) behaves identically to kernel_only=True, breaking the expected contract.

🛠️ Proposed fix
-    def get_kernel_source(self, kernel_only: bool = True) -> str:
-        return self.kernel_global_source
+    def get_kernel_source(self, kernel_only: bool = True) -> str | None:
+        # Metal has a single unified source; kernel_only has no distinct meaning here.
+        return self.kernel_global_source

If a non-None guarantee is truly required at call sites, add an explicit assertion:

-    def get_kernel_source(self, kernel_only: bool = True) -> str:
-        return self.kernel_global_source
+    def get_kernel_source(self, kernel_only: bool = True) -> str | None:
+        assert self.kernel_global_source is not None, "kernel_global_source is not available"
+        return self.kernel_global_source
🧰 Tools
🪛 Ruff (0.15.1)

[warning] 56-56: Unused method argument: kernel_only

(ARG002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/torch/metal.py` around lines 56 - 58, The method
get_kernel_source currently claims to return str but may return None and ignores
the kernel_only flag; change its signature to -> str | None (or keep -> str but
assert/raise if kernel_global_source is None) and implement the kernel_only
branch: if kernel_only is True return self.kernel_global_source, otherwise
return the full Metal source (compose or return the attribute that holds the
complete module/source such as self.metal_source or self.full_source); ensure
you reference get_kernel_source and kernel_global_source and either assert
kernel_global_source is not None before returning a str or update callers/types
to accept Optional[str].

Comment on lines +22 to +23
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)

⚠️ Potential issue | 🟡 Minor

int() cast on potentially symbolic self.M / self.N.

self.M and self.N originate from buffer shapes and may be tir.IntImm or symbolic PrimExpr. If symbolic, int(self.M // m_warp) will raise at runtime. Other GEMM backends (e.g., GemmMMA) handle this similarly, so this is likely fine for Metal's concrete-size use case, but worth noting.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_metal.py` around lines 22 - 23, The int() cast on
potentially symbolic shapes self.M and self.N will fail at runtime for PrimExpr;
update the computation of warp_row_tiles and warp_col_tiles (currently
int(self.M // m_warp) and int(self.N // n_warp)) to preserve symbolic
expressions instead of forcing Python ints—either remove the int() and keep
self.M // m_warp and self.N // n_warp, or use tir.floordiv/tvm.tir.floordiv to
produce a PrimExpr; alternatively, if a concrete int is required, guard with an
isinstance check for tir.IntImm before casting. Ensure you change both
warp_row_tiles and warp_col_tiles and keep references to m_warp and n_warp.
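The isinstance-style guard suggested in that prompt can be sketched without a TVM dependency (the `.value` attribute mirrors what `tir.IntImm` exposes; the helper name `to_concrete_int` and the stub below are illustrative, not TileLang API):

```python
def to_concrete_int(dim, name: str) -> int:
    """Cast a buffer-shape dimension to int only when it is concrete.

    `dim` may be a plain Python int or an IntImm-like object carrying
    `.value`; a symbolic expression has neither, so we raise a clear
    TypeError instead of letting a bare int() call fail opaquely.
    """
    if isinstance(dim, int):
        return dim
    value = getattr(dim, "value", None)
    if isinstance(value, int):
        return value
    raise TypeError(f"{name} must be a concrete integer for the Metal path, got {dim!r}")
```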

@oraluben oraluben requested a review from LeiWang1999 February 23, 2026 11:30
@LeiWang1999

Why do we need to introduce the storage scope metal.simdgroup instead of just using local?
