Add Metal T.gemm_v2 using simdgroup_multiply_accumulate #1869
oraluben wants to merge 2 commits into tile-ai:main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough

Adds Metal/MPS backend support: a new GemmMetal implementation, a Metal-specific intrinsics emitter, simdgroup copy/store/load support, Metal-target GEMM dispatch and enums, Metal-focused tests, and a Darwin-specific apache-tvm-ffi version constraint.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Host as Host/PyTorch
    participant Dispatch as GemmPy Dispatch
    participant Backend as GemmMetal
    participant Emitter as MPSIntrinEmitter
    participant Metal as Metal Runtime
    Host->>Dispatch: Request GEMM (Target=Metal)
    Dispatch->>Backend: Route to GemmMetal (is_metal)
    Backend->>Backend: Compute warp partitions / alloc buffers
    Backend->>Emitter: init with tiling & dtypes
    loop per K block
        Backend->>Emitter: ldmatrix_a / ldmatrix_b
        Emitter->>Metal: simdgroup_load / simdgroup_multiply_accumulate
        Emitter->>Backend: return partial accumulators
    end
    Backend->>Emitter: simdgroup_store C_simd -> C_buf
    Backend->>Host: Return lowered kernel / kernel source
    Host->>Metal: Execute kernel on device
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing Touches: 🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/transform/decouple_type_cast.py (1)
92-95: ⚠️ Potential issue | 🟡 Minor: Error message doesn't mention `metal.simdgroup` as a known scope.

Since `is_local_buffer` now accepts `metal.simdgroup`, the error message should list it as a valid scope for completeness.

Suggested fix
```diff
 raise ValueError(
     f"Unknown buffer scope '{buffer.scope()}' for buffer '{buffer.name}'. "
-    f"Expected one of: local, local.fragment, local.var, global, shared, shared.dyn"
+    f"Expected one of: local, local.fragment, local.var, metal.simdgroup, global, shared, shared.dyn"
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 92 - 95, The ValueError raised in decouple_type_cast.py lists valid buffer scopes but omits "metal.simdgroup"; update the error message (the raise ValueError that references buffer.scope() and buffer.name) to include "metal.simdgroup" in the Expected one of: list so it matches the scopes accepted by is_local_buffer/other checks and clearly communicates valid scopes.
🧹 Nitpick comments (7)
src/op/copy.cc (1)
639-645: Method name `CheckSIMDGroupStore` is misleading: it matches any simdgroup↔simdgroup copy.

Both `src` and `dst` are checked for `"metal.simdgroup"` scope, meaning this matches loads and stores (or, more accurately, simdgroup-to-simdgroup transfers). Compare with `CheckLDSMCopy` (shared→fragment) and `CheckSTSMCopy` (fragment→shared), which are directional. Consider renaming to `CheckSIMDGroupCopy` (and correspondingly `LowerSIMDGroupCopy`) for consistency and clarity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 639-645: The method CheckSIMDGroupStore implies a directional store but actually matches any simdgroup↔simdgroup transfer; rename CheckSIMDGroupStore to CheckSIMDGroupCopy (and rename its lowering helper LowerSIMDGroupStore to LowerSIMDGroupCopy or similarly) and update all call sites so the intent matches the implementation; keep the same body (checking src.scope() == "metal.simdgroup" && dst.scope() == "metal.simdgroup") but change identifiers for consistency with CheckLDSMCopy/CheckSTSMCopy naming conventions.

src/op/gemm_py.cc (1)
136-137: Metal branch placement is correct; consider a minimal `allowMetal()` guard for future-proofing.

The ordering (TCGEN5MMA → WGMMA → CDNA → CUDA → Metal → fallback) is correct: none of the earlier guards can fire on a Metal target. The change is logically sound.
For consistency with the other paths (which all have dedicated `allow*` functions that validate shape, dtype, and scope constraints before returning their instruction type), a lightweight `allowMetal()` gate would help catch unsupported Metal configurations early rather than deferring to downstream errors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/gemm_py.cc` around lines 136-137: Add a lightweight allowMetal() guard before returning GemmInst::kMetalExp: implement allowMetal() similar to the other allow* helpers (validate shapes, dtypes, and memory scope constraints expected by the Metal path) and call it in the branch that currently checks TargetIsMetal(target); only return GemmInst::kMetalExp when allowMetal() returns true, otherwise fall back to the existing fallback/error path so unsupported Metal configs are rejected early. Follow the validation pattern used by functions like allowCuda()/allowCdna()/allowWgmma() to ensure consistency.

src/op/gemm.h (1)
42-42: `kMetalExp` naming doesn't align with Python's `METAL`; enum values should be explicit.

Two related concerns:

The C++ enumerator is named `kMetalExp` (Exp = experimental), but the Python counterpart in `tilelang/tileop/gemm/inst.py` uses the plain name `METAL = 4` with no "experimental" qualifier. This asymmetry makes it unclear whether "experimental" is a meaningful status or just a stale suffix.

The C++ enum assigns values implicitly (sequential 0–4), while the Python `IntEnum` assigns values explicitly (0–4). These are currently in sync, but inserting a new entry between existing ones would silently break the C++↔Python ABI. Adding explicit values in C++ is a cheap safeguard:

♻️ Suggested: add explicit values and align naming
```diff
-enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA, kMetalExp };
+enum class GemmInst : uint8_t {
+  kMMA = 0,
+  kWGMMA = 1,
+  kTCGEN5MMA = 2,
+  kMFMA = 3,
+  kMetal = 4, // align with Python GemmInst.METAL
+};
```

🤖 Prompt for AI Agents
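For reference, the Python side the review describes presumably looks like the sketch below. Only `METAL = 4` in `tilelang/tileop/gemm/inst.py` is confirmed by the review; the other member names are inferred from the C++ enumerators and are illustrative.

```python
from enum import IntEnum

class GemmInst(IntEnum):
    # Explicit values mirror the C++ enum. Only METAL = 4 is confirmed by the
    # review text; the remaining names are hypothetical stand-ins.
    MMA = 0
    WGMMA = 1
    TCGEN5MMA = 2
    MFMA = 3
    METAL = 4

# Explicit values keep the C++/Python ABI stable even if a member is later
# inserted or reordered on one side.
print(int(GemmInst.METAL))  # → 4
```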
Verify each finding against the current code and only fix it if needed. In `@src/op/gemm.h` at line 42: The GemmInst enum currently uses an asymmetrical name and implicit values; change the enumerator kMetalExp to match the Python name (e.g., kMETAL) and give all enum members explicit integer values matching the Python IntEnum (e.g., enum class GemmInst : uint8_t { kMMA = 0, kWGMMA = 1, kTCGEN5MMA = 2, kMFMA = 3, kMETAL = 4 };), keep the underlying type uint8_t, and update any usages of GemmInst::kMetalExp to the new symbol to preserve ABI alignment with tilelang/tileop/gemm/inst.py.

testing/python/metal/test_metal_gemm_v2.py (3)
91-93: Redundant `torch.mps.is_available()` guard silently swallows test output when Metal is absent.

The individual test functions already carry `@tilelang.testing.requires_metal`, which skips them if Metal is unavailable. Wrapping `tilelang.testing.main()` in an additional `torch.mps.is_available()` check means that on a non-Metal machine no tests are registered at all: no "skipped" output, no indication the tests exist. Compare with `test_metal_gemm_v2_linux.py`, which calls `tilelang.testing.main()` unconditionally.

🔧 Suggested fix
```diff
 if __name__ == "__main__":
-    if torch.mps.is_available():
-        tilelang.testing.main()
+    tilelang.testing.main()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2.py` around lines 91-93: Remove the redundant torch.mps.is_available() guard in the __main__ block so that tilelang.testing.main() is called unconditionally; the tests themselves use the `@tilelang.testing.requires_metal` decorator (in this file's test functions) to mark/skip tests when Metal is absent, so update the __main__ section (the block checking __name__ == "__main__") to simply invoke tilelang.testing.main() without wrapping it in torch.mps.is_available().
80-83: Consider putting `requires_metal` as the outermost decorator for conventional ordering.

Placing `@tilelang.testing.requires_metal` inside `@pytest.mark.xfail` works correctly (the skip propagates), but the conventional pattern is to put guard/skip decorators outermost so the intent is immediately obvious at the call site.

🎨 Suggested reorder
```diff
+@tilelang.testing.requires_metal
 @pytest.mark.xfail(reason="TODO: codegen not support float16x8")
-@tilelang.testing.requires_metal
 def test_gemm_v2_large():
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2.py` around lines 80 - 83, Move the decorator order so the platform guard is outermost: swap the two decorators on test_gemm_v2_large so `@tilelang.testing.requires_metal` appears above `@pytest.mark.xfail`; update the decorators on the test_gemm_v2_large function accordingly to keep the same behavior but follow conventional ordering.
88-88: Add a comment explaining `atol=1.0` for the large-K test.

For K=1024 with float16 inputs, accumulated rounding error per element can approach K × ε_fp16 ≈ 1024 × 9.7e-4 ≈ 1.0, so the tolerance is intentional. A brief inline comment would prevent future readers from tightening it incorrectly.

💬 Suggested change
```diff
-    assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)
+    # atol=1.0: with K=1024 fp16 inputs, accumulated rounding error ≈ K × ε_fp16 ≈ 1.0
+    assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)
```

🤖 Prompt for AI Agents
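The tolerance arithmetic above can be checked directly; fp16 has 10 fraction bits, so its epsilon is 2^-10 ≈ 9.77e-4, and a length-1024 accumulation can plausibly drift by roughly the product:

```python
# Worst-case-style absolute error bound for a length-K fp16 dot product:
# each accumulation step can introduce up to ~eps_fp16 of rounding error on
# O(1)-magnitude values, so the total can grow to roughly K * eps_fp16.
eps_fp16 = 2 ** -10  # fp16 machine epsilon (10 fraction bits) ≈ 9.77e-4
K = 1024
bound = K * eps_fp16
print(bound)  # → 1.0
```

This is a heuristic bound, not a rigorous one, but it motivates why `atol=1.0` is not simply sloppy.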
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2.py` at line 88: Add a brief inline comment next to the call to assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0) explaining that atol=1.0 is intentional because for K=1024 with float16 inputs the accumulated rounding error per element can approach K * ε_fp16 ≈ 1024 * 9.7e-4 ≈ 1.0, so the larger absolute tolerance is required for this large-K test to avoid false failures.

testing/python/metal/test_metal_gemm_v2_linux.py (1)
50-53: Misleading variable name and redundant target specification.

Two minor issues:

- Line 50 sets `tvm.target.Target("metal")` as a context manager and passes `target="metal"` to `tilelang.lower` on line 51. The context manager is redundant; the explicit `target=` argument alone is sufficient (and matches the pattern in similar test files).
- The return value of `tilelang.lower` is named `artifact`, but `kernel_source` is a property on the JIT kernel object (as shown in `tilelang/jit/kernel.py`), not on a raw artifact. Naming it `kernel` would be more accurate.

🔧 Suggested cleanup
```diff
-    with tvm.transform.PassContext(), tvm.target.Target("metal"):
-        artifact = tilelang.lower(func, target="metal")
-
-    src_code = artifact.kernel_source
+    with tvm.transform.PassContext():
+        kernel = tilelang.lower(func, target="metal")
+
+    src_code = kernel.kernel_source
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 50 - 53, Remove the redundant PassContext target context and rename the returned value from tilelang.lower to reflect it's a JIT kernel: call tilelang.lower with the explicit target argument only (remove the tvm.transform.PassContext(), tvm.target.Target("metal") context manager) and rename the variable from artifact to kernel so you access kernel.kernel_source (i.e., locate the tilelang.lower call and the subsequent kernel_source access and update the target usage and variable name accordingly).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 70-71: The test test_metal_gemm_v2_larger currently calls
assert_metal_gemm_v2_codegen with parameters known to fail at runtime; update
the test to either mark it as an expected failure or explicitly verify
codegen-only success: add `@pytest.mark.xfail`(reason="TODO: codegen not support
float16x8") above test_metal_gemm_v2_larger to match the runtime test, or
replace the single assert_metal_gemm_v2_codegen call with a two-step check that
runs the codegen path (tilelang.lower/codegen) and asserts it succeeds while
keeping execution separately marked xfail; reference test_metal_gemm_v2_larger
and assert_metal_gemm_v2_codegen when making the change.
- Around line 22-34: The codegen test's matmul_gemm_v2 kernel is structurally
different from the runtime test; update the matmul_gemm_v2 in the Linux codegen
test so its symbols match the runtime test: change C_local to use
T.alloc_shared(..., scope="shared") (instead of T.alloc_fragment), remove/adjust
coalesced_width=2 on T.copy calls to match the runtime test's copies, and use
T.gemm_v2 (or make the runtime test use T.gemm if you choose that canonical API)
so the GEMM operator is identical; ensure these changes are applied to the
matmul_gemm_v2 definition so the codegen pre-flight validates the same kernel
that the runtime test executes.
- Line 32: The test is calling the old lowering path via T.gemm instead of the
new Metal path; update the call to use T.gemm_v2 so the codegen test exercises
the gemm_v2 lowering (replace the invocation T.gemm(A_shared, B_shared, C_local)
with T.gemm_v2(A_shared, B_shared, C_local) in the test function) to align this
codegen test with the runtime test and ensure assertions on
simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store validate the
correct backend.
In `@tilelang/intrinsics/metal_macro_generator.py`:
- Around line 43-44: The code computes self.warp_rows and self.warp_cols via
integer division of warp_row_tiles//micro_size_x and
warp_col_tiles//micro_size_y which will silently truncate if inputs are not
divisible by the micro sizes (8); add validation in the same initializer or
before these assignments (e.g., in the MetalMacroGenerator constructor or method
that sets warp_row_tiles/warp_col_tiles) that raises a clear error if
warp_row_tiles % micro_size_x != 0 or warp_col_tiles % micro_size_y != 0,
mentioning the offending values, and only then compute self.warp_rows and
self.warp_cols as the integer quotient.
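The fail-fast divisibility check requested above can be sketched as a standalone helper. This is a hypothetical function (the real initializer's attribute and parameter names may differ); it only demonstrates the validate-then-floor-divide pattern.

```python
def compute_warp_tiles(warp_row_tiles: int, warp_col_tiles: int,
                       micro_size_x: int = 8, micro_size_y: int = 8):
    """Raise a clear error instead of silently truncating with //."""
    if warp_row_tiles % micro_size_x != 0:
        raise ValueError(
            f"warp_row_tiles={warp_row_tiles} is not divisible by "
            f"micro_size_x={micro_size_x}")
    if warp_col_tiles % micro_size_y != 0:
        raise ValueError(
            f"warp_col_tiles={warp_col_tiles} is not divisible by "
            f"micro_size_y={micro_size_y}")
    # Safe now: the quotients are exact.
    return warp_row_tiles // micro_size_x, warp_col_tiles // micro_size_y

print(compute_warp_tiles(32, 16))  # → (4, 2)
```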
In `@tilelang/jit/adapter/torch/metal.py`:
- Around line 56-58: The method get_kernel_source currently claims to return str
but may return None and ignores the kernel_only flag; change its signature to ->
str | None (or keep -> str but assert/raise if kernel_global_source is None) and
implement the kernel_only branch: if kernel_only is True return
self.kernel_global_source, otherwise return the full Metal source (compose or
return the attribute that holds the complete module/source such as
self.metal_source or self.full_source); ensure you reference get_kernel_source
and kernel_global_source and either assert kernel_global_source is not None
before returning a str or update callers/types to accept Optional[str].
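The contract described for get_kernel_source can be sketched as below. The class and attribute names (`host_source` in particular) are stand-ins, not the adapter's real fields; the point is the two behaviors the prompt asks for: fail loudly when no source exists, and honor `kernel_only`.

```python
class MetalKernelAdapterSketch:
    """Hypothetical sketch of the suggested get_kernel_source contract."""

    def __init__(self, kernel_global_source, host_source: str = ""):
        self.kernel_global_source = kernel_global_source
        self.host_source = host_source  # stand-in for the host-side wrapper

    def get_kernel_source(self, kernel_only: bool = False) -> str:
        # Raise instead of silently returning None, keeping the -> str annotation honest.
        if self.kernel_global_source is None:
            raise RuntimeError("kernel source has not been generated yet")
        if kernel_only:
            return self.kernel_global_source
        # Full module: host wrapper plus the device kernel.
        return self.host_source + self.kernel_global_source
```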
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 22-23: The int() cast on potentially symbolic shapes self.M and
self.N will fail at runtime for PrimExpr; update the computation of
warp_row_tiles and warp_col_tiles (currently int(self.M // m_warp) and
int(self.N // n_warp)) to preserve symbolic expressions instead of forcing
Python ints—either remove the int() and keep self.M // m_warp and self.N //
n_warp, or use tir.floordiv/tvm.tir.floordiv to produce a PrimExpr;
alternatively, if a concrete int is required, guard with an isinstance check for
tir.IntImm before casting. Ensure you change both warp_row_tiles and
warp_col_tiles and keep references to m_warp and n_warp.
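To see why the `int()` cast is fragile, here is a toy stand-in for a symbolic dimension (not TVM's actual PrimExpr class): `int()` on such an object raises TypeError, while plain floor division keeps the expression symbolic and still yields a Python int for concrete inputs.

```python
class SymDim:
    """Toy stand-in for a symbolic shape expression; int(SymDim) would fail."""

    def __init__(self, name: str):
        self.name = name

    def __floordiv__(self, other):
        # Floor division builds a new symbolic expression instead of a number.
        return SymDim(f"({self.name} // {other})")

    def __repr__(self):
        return self.name

def warp_row_tiles(M, m_warp):
    # int(M // m_warp) would raise TypeError when M is symbolic;
    # plain floordiv works for both concrete ints and symbolic dims.
    return M // m_warp

print(warp_row_tiles(128, 2))          # → 64
print(warp_row_tiles(SymDim("M"), 2))  # → (M // 2)
```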
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- pyproject.toml
- requirements.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/gemm.cc
- src/op/gemm.h
- src/op/gemm_py.cc
- src/op/parallel.cc
- src/op/utils.h
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/transform/decouple_type_cast.py
- tilelang/utils/language.py
```toml
"apache-tvm-ffi~=0.1.0,>0.1.2",
# apache/tvm-ffi#464
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
```
Lower-bound inconsistency with requirements.txt may expose the memory bug on Darwin.
requirements.txt keeps a general lower bound of >=0.1.6 (the fix for tilelang#1502), while pyproject.toml uses >0.1.2. On Darwin a user installing the wheel from PyPI sees:
| File | Effective Darwin range |
|---|---|
| requirements.txt | 0.1.6 – 0.1.7 |
| pyproject.toml | 0.1.3 – 0.1.7 |
Versions 0.1.3–0.1.5 are still selectable from pyproject.toml and still carry the memory regression. Unless broad compatibility is the explicit goal (the existing comment suggests it is), consider aligning the pyproject.toml lower bound with requirements.txt:
🔧 Proposed alignment
```diff
- "apache-tvm-ffi~=0.1.0,>0.1.2",
+ "apache-tvm-ffi~=0.1.0,>=0.1.6",
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```toml
"apache-tvm-ffi~=0.1.0,>=0.1.6",
# apache/tvm-ffi#464
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pyproject.toml` around lines 33 - 35, The pyproject lower-bound for
"apache-tvm-ffi" is too permissive (">0.1.2") and allows Darwin installs of
versions 0.1.3–0.1.5 that contain the memory regression; update the
constraint(s) for "apache-tvm-ffi" to match requirements.txt by using a >=0.1.6
lower bound (e.g., change ">0.1.2" to ">=0.1.6" and ensure the Darwin-specific
constraint "apache-tvm-ffi<0.1.8; platform_system == 'Darwin'" remains
consistent with that lower bound).
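The effective Darwin ranges in the table above can be checked by hand. The sketch below evaluates the three specifiers (`~=0.1.0`, the lower bound, and the Darwin-only `<0.1.8` cap) with plain tuple comparisons rather than a packaging library; it handles only simple `x.y.z` versions, which is all this case needs.

```python
def parse(v: str):
    # "0.1.6" -> (0, 1, 6); tuples compare lexicographically like versions
    return tuple(int(p) for p in v.split("."))

def allowed_on_darwin(v, lower_exclusive=None, lower_inclusive=None):
    """Evaluate '~=0.1.0' + a lower bound + Darwin's '<0.1.8' by hand."""
    ver = parse(v)
    compatible = (0, 1, 0) <= ver < (0, 2, 0)  # ~=0.1.0 means >=0.1.0, ==0.1.*
    darwin_cap = ver < (0, 1, 8)               # <0.1.8 on Darwin
    if lower_exclusive is not None:
        floor = ver > parse(lower_exclusive)
    else:
        floor = ver >= parse(lower_inclusive)
    return compatible and darwin_cap and floor

# Current pyproject bound (>0.1.2) still admits the regression versions:
print(allowed_on_darwin("0.1.3", lower_exclusive="0.1.2"))  # → True
# Aligning with requirements.txt (>=0.1.6) excludes them:
print(allowed_on_darwin("0.1.3", lower_inclusive="0.1.6"))  # → False
print(allowed_on_darwin("0.1.6", lower_inclusive="0.1.6"))  # → True
```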
```cpp
} else if (CheckSIMDGroupStore(target)) {
  return CopyInst::kMetalSIMDGroup;
```
Dispatch to kMetalSIMDGroup added in GetCopyInst but Lower() has no corresponding case — runtime crash.
GetCopyInst can now return CopyInst::kMetalSIMDGroup, but Lower() (line 694–726) does not handle this value. Execution will fall through to LOG(FATAL) << "Unsupported copy inst" at line 724, crashing at runtime.
Add a case in Lower() before the kNormal case:
🐛 Proposed fix in Lower() (around line 721)
```diff
 } else if (copy_inst == CopyInst::kNormal) {
   return LowerNormalCopy(T, analyzer);
+} else if (copy_inst == CopyInst::kMetalSIMDGroup) {
+  return LowerSIMDGroupStore(T, analyzer);
 } else {
   LOG(FATAL) << "Unsupported copy inst " << static_cast<int>(copy_inst);
 }
```

🤖 Prompt for AI Agents
}🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/op/copy.cc` around lines 685 - 686, GetCopyInst can return
CopyInst::kMetalSIMDGroup but Lower() lacks a matching case, causing a fatal
crash; add a case for CopyInst::kMetalSIMDGroup in Lower() (in the function
Lower) before the kNormal branch that implements the appropriate lowering
behavior for SIMD-group metal stores (mirror the pattern used for other
Metal-specific cases in Lower(), dispatching to the SIMD-group store lowering
path and returning the resultant statement), ensuring the new case references
CopyInst::kMetalSIMDGroup so execution no longer falls through to LOG(FATAL).
```cpp
/*!
 * \brief Check if tensor memory store is supported.
 */
bool CheckSIMDGroupStore(Target target) const;
```
Doxygen comment is copy-pasted from CheckTMemStore.
The brief says "Check if tensor memory store is supported" but this method checks for Metal SIMD group store support. Should be updated to reflect its actual purpose.
📝 Proposed fix
```diff
 /*!
- * \brief Check if tensor memory store is supported.
+ * \brief Check if Metal SIMD group store is supported.
 */
 bool CheckSIMDGroupStore(Target target) const;
```

📝 Committable suggestion
```cpp
/*!
 * \brief Check if Metal SIMD group store is supported.
 */
bool CheckSIMDGroupStore(Target target) const;
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/op/copy.h` around lines 232 - 235, The Doxygen comment for
CheckSIMDGroupStore is incorrect/copy-pasted from CheckTMemStore; update the
comment to describe that CheckSIMDGroupStore(Target target) determines whether
the Metal SIMD-group (SIMD-group/warp/wavefront) store instruction is supported
on the given Target (i.e., checks for Metal SIMD group store capability), not
tensor memory store support; reference CheckSIMDGroupStore and, for comparison,
CheckTMemStore to ensure the wording distinguishes SIMD group store support from
tensor memory store support.
```cpp
/*!
 * \brief Generate lowering for simdgroup store.
 */
Stmt LowerSIMDGroupStore(const LowerArgs &T, arith::Analyzer *analyzer) const;
```
🧩 Analysis chain

🏁 Scripts executed (repository: tile-ai/tilelang):

```shell
rg -n 'LowerSIMDGroupStore' --type=cpp
cat -n src/op/copy.cc | head -750 | tail -100
rg -n 'kMetalSIMDGroup' --type=cpp
rg -A 10 'LowerSIMDGroupStore' --type=cpp src/op/copy.cc
sed -n '50,60p' src/op/copy.h
rg -B 5 -A 10 'kMetalSIMDGroup' src/op/copy.h
```
Add dispatch case for kMetalSIMDGroup in CopyNode::Lower() and implement LowerSIMDGroupStore().
GetCopyInst() can return CopyInst::kMetalSIMDGroup when CheckSIMDGroupStore(target) is true, but Lower() (lines 694–726) has no case to handle it. The method will fall through to LOG(FATAL) at line 724. Additionally, LowerSIMDGroupStore() is declared in copy.h:277 but has no implementation in copy.cc.
You need to:

- Add a dispatch case in `Lower()` for `kMetalSIMDGroup` that calls `LowerSIMDGroupStore()`
- Implement `LowerSIMDGroupStore()` in copy.cc
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/op/copy.h` around lines 274 - 277, The CopyNode::Lower() switch must
handle CopyInst::kMetalSIMDGroup: add a case in Lower() that calls
LowerSIMDGroupStore(T, analyzer) when GetCopyInst(...) returns kMetalSIMDGroup
(this path is produced when CheckSIMDGroupStore(target) is true) to avoid the
LOG(FATAL) fallthrough; then implement the missing LowerSIMDGroupStore(const
LowerArgs &T, arith::Analyzer *analyzer) in copy.cc mirroring the pattern of
existing store lowerers (use the same argument handling and memory access
lowering used by other Metal-specific store methods, perform bounds/stride
checks as done in related Lower*Store functions, and produce the appropriate
Stmt), ensuring the declaration in copy.h matches the definition.
```cpp
int kMPerWarp = 16; // Rows processed by a single warp
if (TargetIsMetal(target)) {
  kMPerWarp = 8;
}
int kNPerWarp = 8; // Columns processed by a single warp
```
Stale comments hardcode "16" after kMPerWarp became target-dependent
The arithmetic correctly reads the runtime kMPerWarp variable everywhere, but two nearby comments still reference the old constant "16" and will now mislead on Metal:
- Line 248: `// If M cannot be evenly divided by m_warp*16, try to split remaining warps`
- Line 278: `// Each warp needs at least 16 elements in M`
📝 Suggested comment fixes

```diff
- // If M cannot be evenly divided by m_warp*16, try to split remaining warps
- // to N
+ // If M cannot be evenly divided by m_warp*kMPerWarp, try to split remaining warps
+ // to N
```

```diff
- int max_m_warps = M / kMPerWarp; // Each warp needs at least 16 elements in M
+ int max_m_warps = M / kMPerWarp; // Each warp needs at least kMPerWarp elements in M
```
+ int max_m_warps = M / kMPerWarp; // Each warp needs at least kMPerWarp elements in M🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/op/gemm.cc` around lines 150 - 154, Comments referencing a hardcoded "16"
are now stale because kMPerWarp is target-dependent; update the comments near
kMPerWarp, TargetIsMetal, and the logic that mentions "m_warp*16" and "Each warp
needs at least 16 elements in M" to reflect the runtime variable (e.g., refer to
"m_warp * kMPerWarp" and "kMPerWarp elements") so they accurately describe
behavior on Metal and other targets.
```python
T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)

T.gemm(A_shared, B_shared, C_local)
```
Replace T.gemm with T.gemm_v2 on line 32.
T.gemm and T.gemm_v2 are distinct operations with different lowering paths (tl.tileop.gemm vs. tl.tileop.gemm_py). This test is meant to verify Metal codegen for the new gemm_v2 path, but it calls T.gemm instead. The assertions for simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store may be validating against the wrong lowering backend and could pass for the wrong reason or miss actual Metal simdgroup codegen issues.
The runtime test (test_metal_gemm_v2.py line 34) correctly uses T.gemm_v2; align this codegen test with it.
```python
def test_metal_gemm_v2_larger():
    assert_metal_gemm_v2_codegen(128, 128, 128, 32, 32, 32, dtype=T.float16)
```
test_metal_gemm_v2_larger may silently pass despite a known codegen limitation.

The runtime test test_gemm_v2_large in test_metal_gemm_v2.py tests the same parameters (128, 128, 128, 32, 32, 32, dtype=T.float16) and is marked `@pytest.mark.xfail(reason="TODO: codegen not support float16x8")`. If the codegen for this block configuration is genuinely broken, the Linux codegen test should also fail (and be expected to fail) for consistency. Either:

- Add `@pytest.mark.xfail(reason="TODO: codegen not support float16x8")` here as well, or
- Verify that the codegen-only path (`tilelang.lower` without execution) succeeds where execution fails.
```python
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
```
Add validation that warp_row_tiles and warp_col_tiles are divisible by 8.
If warp_row_tiles or warp_col_tiles aren't multiples of micro_size_x/micro_size_y (8), the integer division would silently truncate, leading to incorrect tiling.
Suggested validation

```diff
  # Number of 8x8 tiles per warp
+ assert warp_row_tiles % self.micro_size_x == 0, (
+     f"warp_row_tiles ({warp_row_tiles}) must be divisible by micro_size_x ({self.micro_size_x})"
+ )
+ assert warp_col_tiles % self.micro_size_y == 0, (
+     f"warp_col_tiles ({warp_col_tiles}) must be divisible by micro_size_y ({self.micro_size_y})"
+ )
  self.warp_rows = warp_row_tiles // self.micro_size_x
  self.warp_cols = warp_col_tiles // self.micro_size_y
```
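The guarded arithmetic as a runnable sketch (hypothetical helper; the real fields live on MPSIntrinEmitter):

```python
# Sketch of the 8x8 tiling arithmetic: a warp tile must decompose into
# whole simdgroup_matrix tiles, otherwise integer division would
# silently truncate. Names are illustrative, not the real emitter's.

MICRO_M = 8  # simdgroup_matrix rows
MICRO_N = 8  # simdgroup_matrix cols

def warp_tile_counts(warp_row_tiles: int, warp_col_tiles: int) -> tuple[int, int]:
    """Number of 8x8 simdgroup tiles a single warp covers in M and N."""
    if warp_row_tiles % MICRO_M or warp_col_tiles % MICRO_N:
        raise ValueError(
            f"warp tile ({warp_row_tiles}x{warp_col_tiles}) must be a "
            f"multiple of {MICRO_M}x{MICRO_N}")
    return warp_row_tiles // MICRO_M, warp_col_tiles // MICRO_N
```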
```python
def get_kernel_source(self, kernel_only: bool = True) -> str:
    return self.kernel_global_source
```
Fix the return type annotation and honour the kernel_only flag.

Two issues here:

1. **Return type mismatch** — `kernel_global_source` is declared `str | None` (Line 23), so the method can return `None`, contradicting the `-> str` annotation. This will cause silent type errors for callers.
2. **Unused `kernel_only` parameter** — every peer adapter branches on this flag (base.py, nvrtc/adapter.py, cython/adapter.py). Silently ignoring it here means `get_kernel_source(kernel_only=False)` behaves identically to `kernel_only=True`, breaking the expected contract.
🛠️ Proposed fix

```diff
- def get_kernel_source(self, kernel_only: bool = True) -> str:
-     return self.kernel_global_source
+ def get_kernel_source(self, kernel_only: bool = True) -> str | None:
+     # Metal has a single unified source; kernel_only has no distinct meaning here.
+     return self.kernel_global_source
```

If a non-None guarantee is truly required at call sites, add an explicit assertion:

```diff
- def get_kernel_source(self, kernel_only: bool = True) -> str:
-     return self.kernel_global_source
+ def get_kernel_source(self, kernel_only: bool = True) -> str | None:
+     assert self.kernel_global_source is not None, "kernel_global_source is not available"
+     return self.kernel_global_source
```

🧰 Tools: 🪛 Ruff (0.15.1) — ARG002, line 56: unused method argument `kernel_only`
```python
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
```
`int()` cast on potentially symbolic `self.M` / `self.N`.

`self.M` and `self.N` originate from buffer shapes and may be `tir.IntImm` or symbolic `PrimExpr`. If symbolic, `int(self.M // m_warp)` will raise at runtime. Other GEMM backends (e.g., GemmMMA) handle this similarly, so this is likely fine for Metal's concrete-size use case, but worth noting.
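A minimal illustration of the guard this implies (`Sym` is a hypothetical stand-in for a TIR `PrimExpr`; the real fix would check `tir.IntImm` or use `tir.floordiv`):

```python
# Illustrative guard for symbolic shapes: only force a Python int when
# the value is already concrete, otherwise keep the expression symbolic.
# "Sym" is a hypothetical stand-in for a TIR PrimExpr, not a TVM class.

class Sym:
    def __init__(self, name: str):
        self.name = name

    def __floordiv__(self, other):
        return Sym(f"({self.name} // {other})")

def concrete_or_symbolic_div(extent, warps):
    """extent // warps, cast to int only when the result is concrete."""
    q = extent // warps
    return int(q) if isinstance(q, int) else q  # no int() on symbolic exprs
```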
Why do we need to introduce a new storage scope?
Metal GEMM v2: Direct Simdgroup Matrix Operations Without Fragment Layout
Summary
This PR adds a new GEMM implementation for Metal targets that uses `simdgroup_matrix` operations directly, bypassing TileLang's thread ↔ register fragment layout system entirely. The accumulator lives in simdgroup registers throughout the entire k-loop, eliminating the costly shared-memory round-trips that the previous approach required.
simdgroup_matrixoperations directly, bypassing TileLang's thread <-> register fragment layout system entirely. The accumulator lives in simdgroup registers throughout the entire k-loop, eliminating the costly shared memory round-trips that the previous approach required.Background: Simdgroup Matrix Operations in Metal
Metal provides simdgroup matrix operations as the only public API for accessing matrix acceleration hardware (analogous to NVIDIA's Tensor Cores). Unlike CUDA, which exposes multiple abstraction levels (`wmma`, `mma.sync` PTX, `wgmma`), Metal offers a single, opaque interface through the `simdgroup_matrix<T, Rows, Cols>` type:

- `simdgroup_load(mat, ptr, stride)` — cooperative load from device/threadgroup memory
- `simdgroup_store(mat, ptr, stride)` — cooperative store to device/threadgroup memory
- `simdgroup_multiply_accumulate(D, A, B, C)` — fused multiply-accumulate: D = A × B + C
wmma,mma.syncPTX,wgmma), Metal offers a single, opaque interface through thesimdgroup_matrix<T, Rows, Cols>type:simdgroup_load(mat, ptr, stride)— cooperative load from device/threadgroup memorysimdgroup_store(mat, ptr, stride)— cooperative store to device/threadgroup memorysimdgroup_multiply_accumulate(D, A, B, C)— fused multiply-accumulate: D = A × B + CThese operate on 8×8 matrix tiles and are executed collectively by a simdgroup (32 threads, equivalent to a CUDA warp). The key distinction from CUDA is that the hardware-internal data distribution across threads is completely opaque — there is no documented per-thread register layout, and no lower-level PTX-like alternative exists. This makes simdgroup operations the sole mechanism for matrix-accelerated computation on Apple GPUs.
Main Changes
New GEMM backend for Metal (`GemmMetal`)

- `GemmInst::kMetalExp` / `GemmInst.METAL` as a new GEMM instruction type
- `GemmMetal` class that lowers `T.gemm` / `T.gemm_v2` to Metal simdgroup intrinsics: shared → simdgroup_load → simdgroup_multiply_accumulate → simdgroup_store → shared/global
- `MPSIntrinEmitter` (metal_macro_generator.py) emitting `simdgroup_load`, `simdgroup_multiply_accumulate`, and `simdgroup_store`, addressed by warp position `(warp_m, warp_n)` and tile position `(i, j)`
- unified `simdgroup_copy` method for both load and store directions
- `_parse_buffer_2d` utility for Buffer/BufferRegion handling

Infrastructure changes
- `metal.simdgroup` scope recognized as a local buffer type in `IsFragmentBuffer`, `is_local_buffer`, and `decouple_type_cast`
- `CopyInst::kMetalSIMDGroup` added for future lowering from `T.copy` to `simdgroup_{load,store}`
- `kMPerWarp = 8` on Metal (matching the 8×8 simdgroup matrix size)
- parallel.cc: guard the `as<Fragment>()` cast for non-fragment layouts to avoid a crash

Why No Layout Inference Is Needed
TileLang's layout system maps individual threads to individual register elements — this is essential for CUDA's `wmma`/`mma` instructions, where software must precisely control which thread holds which data fragment, and for `ldmatrix`/`stmatrix` instructions, which require knowing the per-thread address pattern.

Metal's simdgroup matrix operations are fundamentally different: they are opaque, collective operations where the entire simdgroup (32 threads) cooperates and the hardware decides the internal data distribution. The programmer never needs to know (and cannot control) which thread holds which element of an 8×8 matrix. Therefore:

| CUDA (layout-sensitive) | Metal (opaque) |
| --- | --- |
| `ldmatrix` — each thread provides an address based on its lane position | `simdgroup_load(mat, ptr, stride)` — hardware handles everything |
| `mma` — thread ↔ register mapping must match instruction encoding | `simdgroup_multiply_accumulate(D, A, B, C)` — hardware handles everything |
| `stmatrix` — inverse of load layout | `simdgroup_store(mat, ptr, stride)` — hardware handles everything |

By using simdgroup operations directly, we also eliminate the shared memory round-trip that the previous implementation required to convert between the opaque simdgroup layout and TileLang's fragment layout on every k-iteration.
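The dataflow described above — an accumulator resident across the whole k-loop, stored out exactly once — can be sketched in plain Python; this models only the semantics, not the Metal API or TileLang's actual lowering:

```python
# Pure-Python model of the SS GEMM dataflow: the 8x8 accumulator stays
# resident across the whole k-loop and is stored to the output exactly
# once, mirroring load -> multiply_accumulate -> store per tile.
# Illustrative only; a real kernel runs this collectively on a simdgroup.

T = 8  # simdgroup_matrix tile size

def gemm_ss(A, B, M, N, K):
    """C = A @ B over 8x8 tiles; A is MxK, B is KxN (nested lists)."""
    C = [[0.0] * N for _ in range(M)]
    for i in range(0, M, T):
        for j in range(0, N, T):
            acc = [[0.0] * T for _ in range(T)]  # accumulator "in registers"
            for k in range(0, K, T):             # k-loop: no smem round-trip
                for r in range(T):
                    for c in range(T):
                        s = 0.0
                        for t in range(T):
                            s += A[i + r][k + t] * B[k + t][j + c]
                        acc[r][c] += s           # simdgroup_multiply_accumulate
            for r in range(T):                   # one simdgroup_store at the end
                C[i + r][j:j + T] = acc[r]
    return C
```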
Why Only `shared × shared → shared/global` GEMM

This PR implements only the `smem × smem = smem/global` (SS) case because:

- Metal `simdgroup_load` reads from device/threadgroup memory — it cannot read from registers or other opaque sources. Both A and B operands must come from addressable memory (shared or global), making the SS pattern the natural match.
- `simdgroup_store` can write to both shared and global memory — the result can go directly to global memory without an intermediate shared-memory buffer, which is what the frontend exploits.
- No register↔register GEMM — unlike CUDA, where fragments can be loaded from registers and composed in various ways, Metal simdgroup operations always read from memory. Fragment-scope (register-level) operands would require materializing to shared memory first, which is exactly what the SS pattern already handles.
- Global operands go through shared memory anyway — the standard TileLang pattern `T.copy(A_global, A_shared)` followed by `T.gemm(A_shared, B_shared, C)` already stages global data through shared memory, so supporting `global × global` directly adds no benefit.

Test Coverage
- test_metal_gemm_v2.py: End-to-end correctness tests on Metal hardware (requires macOS + MPS), comparing against `torch.matmul`
- test_metal_gemm_v2_linux.py: Codegen-only tests verifiable on any platform, checking that the generated Metal shader source contains the expected simdgroup operations

Summary by CodeRabbit
New Features
Tests
Bug Fixes
Dependencies