perf(jit): lean cache-key fast path for @flyc.jit launches #669
fsx950223 wants to merge 3 commits into
Every @flyc.jit launch rebuilds its cache key to probe _call_state_cache
(polymorphic dispatch — unavoidable). Profiling the softmax launcher showed
~50us of that ~108us host overhead was _resolve_and_make_cache_key, dominated
by two avoidable costs:
1. building a full TensorAdaptor per tensor (DLPack export) just to read a
cache signature — ~6.6us/tensor;
2. isinstance(arg, JitArgument) where JitArgument is a runtime_checkable
Protocol → structural check ~2.2us/arg.
Add a lean fast probe that produces a BYTE-IDENTICAL key without either cost:
- TensorAdaptor.lean_cache_signature(t): derives the signature directly from
dtype/shape/strides (no DLPack); the full adaptor is built only on a cache
miss when actually compiling.
- JitFunction._fast_cache_key + _build_fast_key_plan: per-param plan computed
once; uses lean tensor sig and a cheap hasattr("__cache_signature__") instead
of the Protocol isinstance; does NOT mutate bound args.
- __call__ probes with the lean key first; the cache-hit path passes raw args
straight to CallState (whose slot extractors already accept raw tensors). The
miss path is unchanged (full resolve + compile), and its key is identical to
the lean key so the CallState it caches is found by the fast probe.
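
The probe-first dispatch described above can be sketched as follows. This is a minimal, self-contained illustration and not the flydsl implementation: `FakeTensor`, `FullAdaptor`, `ToyJit`, and the `(dtype, shape, strides)` signature layout are all hypothetical stand-ins, chosen only to show that a lean key equal to the full key lets the hit path skip the expensive adaptor entirely.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a framework tensor (metadata only)."""
    dtype: str
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]


def lean_signature(t):
    # Assumed signature layout: read metadata directly, no wrapper object.
    return (t.dtype, t.shape, t.strides)


class FullAdaptor:
    """Stand-in for the expensive wrapper that the miss path still builds."""
    built = 0  # counts constructions so the saving is observable

    def __init__(self, t):
        FullAdaptor.built += 1
        self.t = t

    def __cache_signature__(self):
        return lean_signature(self.t)


class ToyJit:
    def __init__(self, kinds):
        # The plan is computed once: one (name, key-extractor) pair per parameter.
        extract = {
            "tensor": lean_signature,
            "scalar": lambda v: (type(v).__name__, v),
        }
        self.plan = tuple((name, extract[kind]) for name, kind in kinds.items())
        self.cache = {}

    def fast_key(self, args):
        # Lean probe: no adaptor construction, and args are not mutated.
        return tuple((name, fn(arg)) for (name, fn), arg in zip(self.plan, args))

    def full_key(self, args):
        # Slow path: wraps every tensor, but must yield the same key.
        parts = []
        for (name, _), arg in zip(self.plan, args):
            if isinstance(arg, FakeTensor):
                parts.append((name, FullAdaptor(arg).__cache_signature__()))
            else:
                parts.append((name, (type(arg).__name__, arg)))
        return tuple(parts)

    def __call__(self, *args):
        state = self.cache.get(self.fast_key(args))
        if state is None:  # miss: full resolve + "compile"
            key = self.full_key(args)
            state = self.cache[key] = f"compiled#{len(self.cache)}"
        return state


jit = ToyJit({"x": "tensor", "n": "scalar"})
x = FakeTensor("f32", (4, 8), (8, 1))
first = jit(x, 3)   # miss: builds one FullAdaptor and caches the state
second = jit(x, 3)  # hit: found by the lean key, no adaptor built
assert first == second and FullAdaptor.built == 1
```

Because the miss path stores its state under the full key, the scheme only works if the two keys are equal for every input, which is what the unit test in this commit is meant to guard.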
Measured (softmax 8192^2, MI308X): key build 52us -> 8.6us (6x); full __call__
on a launch-bound shape 108us -> 50us. Helps the polymorphic path (varying
shapes, where flyc.compile can't be used); GPU-bound kernels were already
hiding this behind execution.
tests/unit/test_lean_cache_signature.py asserts lean_cache_signature ==
TensorAdaptor.__cache_signature__ and _fast_cache_key == _build_full_cache_key
across dtypes/ranks/strides (incl. fp8, transposed, unit-size, non-pow2).
Verified test_softmax.py and test_pa.py pass.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Pull request overview
This PR optimizes the @flyc.jit launch hot path by adding a lean cache-key probe that avoids per-launch TensorAdaptor/DLPack construction and avoids runtime_checkable Protocol structural isinstance checks, while ensuring the fast key is byte-identical to the full cache key.
Changes:
- Add `TensorAdaptor.lean_cache_signature(t)` to derive tensor cache signatures from dtype/shape/strides without DLPack.
- Add `JitFunction._fast_cache_key` + a lazily-built per-signature plan, and probe `_call_state_cache` with the lean key before doing full resolution/compilation.
- Add GPU unit tests validating `lean_cache_signature` and `_fast_cache_key` match the full-path signatures/keys.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/unit/test_lean_cache_signature.py | Adds regression tests ensuring lean tensor signatures and fast cache keys match the full path. |
| python/flydsl/compiler/jit_function.py | Introduces fast cache-key planning/building and uses it for the cache-hit probe in __call__. |
| python/flydsl/compiler/jit_argument.py | Adds TensorAdaptor.lean_cache_signature to avoid DLPack export on cache-hit probing. |
    # Arg already a JitArgument (cheap duck-check == the Protocol isinstance).
    if hasattr(arg, "__cache_signature__"):
        parts.append((name, arg.__cache_signature__()))
        continue

    strides = tuple(int(s) for s in t.stride())
    rank = len(strides)
    unit_axis = next((i for i, s in enumerate(strides) if s == 1), None)
    if unit_axis is None:
        raise RuntimeError("tensor has no axis with stride == 1; layout-dynamic memref requires one")
…k comment
Boundary review of the lean cache-key fast path. Verified
lean_cache_signature stays byte-identical to TensorAdaptor.__cache_signature__
across the layouts where the framework and DLPack stride views can disagree
(DLPack coerces unit/zero-size strides) — the lean path follows the framework
view via _pick_unit_stride_axis, matching the full path:
broadcast (stride 0), 3d broadcast, leading/mid/all size-1 dims, trailing
size-1 (ambiguous unit axis), permuted, channels_last, strided rows,
zero-size dims, fp8 transposed.
Also add test_lean_and_full_reject_no_unit_stride_consistently: a tensor with
no stride-1 axis cannot be a layout-dynamic memref; both the full path
(TensorAdaptor.__init__) and the lean path must reject it, so the fast probe
never silently dispatches a tensor the full path would refuse.
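
The consistency requirement above can be illustrated with a small sketch. It assumes, based only on the quoted hunk, that the unit axis is the first axis with stride 1; `pick_unit_axis`, `lean_sig`, `FullSig`, and the `(dtype, rank, unit_axis)` signature shape are hypothetical names, and the real `_pick_unit_stride_axis` may break ties differently.

```python
def pick_unit_axis(strides):
    """Return the first axis with stride == 1, or raise if there is none."""
    unit_axis = next((i for i, s in enumerate(strides) if s == 1), None)
    if unit_axis is None:
        raise RuntimeError("tensor has no axis with stride == 1")
    return unit_axis


def lean_sig(dtype, shape, strides):
    # Lean path: metadata only (assumed signature shape).
    return (dtype, len(shape), pick_unit_axis(strides))


class FullSig:
    # Full path stand-in: rejects in __init__, like an adaptor constructor would.
    def __init__(self, dtype, shape, strides):
        self.sig = (dtype, len(shape), pick_unit_axis(strides))


def paths_agree(dtype, shape, strides):
    """True if both paths produce the same signature or both reject."""
    try:
        lean = lean_sig(dtype, shape, strides)
    except RuntimeError:
        lean = RuntimeError
    try:
        full = FullSig(dtype, shape, strides).sig
    except RuntimeError:
        full = RuntimeError
    return lean == full


assert paths_agree("f32", (4, 8), (8, 1))   # contiguous rows
assert paths_agree("f32", (4, 8), (16, 2))  # no unit stride: both reject
```

Sharing one axis-picking helper between the two paths is the simplest way to keep them from diverging; the test in this commit checks the same property against real tensors.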
Fix the _fast_cache_key comment that claimed hasattr("__cache_signature__") is
equivalent to the runtime_checkable isinstance(arg, JitArgument): the Protocol
also probes __get_ir_types__/__get_c_pointers__. They select the same branch
only because every __cache_signature__ implementer in the codebase is a
complete JitArgument — state that instead of claiming exact equivalence.
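
The gap between the two checks can be reproduced with the standard library alone. In this sketch `JitArgumentLike`, `Complete`, and `SignatureOnly` are hypothetical classes; only the three method names come from the commit message.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class JitArgumentLike(Protocol):
    # Assumed three-method shape, mirroring the names in the commit message.
    def __cache_signature__(self): ...
    def __get_ir_types__(self): ...
    def __get_c_pointers__(self): ...


class Complete:
    def __cache_signature__(self):
        return ("complete",)

    def __get_ir_types__(self):
        return []

    def __get_c_pointers__(self):
        return []


class SignatureOnly:
    # Implements just one of the three members.
    def __cache_signature__(self):
        return ("partial",)


def checks_agree(obj):
    """True when the cheap duck-check selects the same branch as isinstance."""
    return hasattr(obj, "__cache_signature__") == isinstance(obj, JitArgumentLike)


assert checks_agree(Complete())           # both checks say yes
assert checks_agree(42)                   # both checks say no
assert not checks_agree(SignatureOnly())  # hasattr yes, Protocol no
```

The `hasattr` shortcut is therefore safe only as long as no partial implementer such as `SignatureOnly` exists in the codebase, which is the invariant the corrected comment states.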
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Follow-up to the lean cache-key fast path. After the key build dropped to
~8.6us, profiling the softmax launcher (256x512, MI308X) showed the next
host-overhead lever was _check_globals_drift at ~7.7us — 16% of the ~50us
cache-hit __call__ — run on every launch.
The cost is re-summarizing every captured global via _snapshot_global_value
each call. For a typical kernel most captured "globals" are modules, imported
helper functions, and enum/classes (softmax captures 9 refs: 7 are
module/function/type, 1 EnumType, only BLOCK_THREADS is a scalar). These are
summarized purely by id() — and the callable branch additionally builds
repr(val) every call.
For such identity-stable values an `is` comparison against the baseline object
is exactly equivalent to (and far cheaper than) re-summarizing the snapshot:
the snapshot's only discriminant is the identity. Holding the baseline object
also keeps it alive, eliminating the id-reuse aliasing that comparing stored
id() snapshots is theoretically prone to.
- _is_identity_stable(val): mirrors the id()-based ("callable"/"obj") branches
of _snapshot_global_value(stable=False); False for scalars and builtin
containers (those stay on the by-value path so in-place mutation is caught).
- _snapshot_refs_for_drift(refs): drift baseline now maps (name, mod) ->
(snapshot, fastref), fastref being the live object for identity-stable values
else None.
- _check_globals_drift: short-circuits when fastref is not None and the global
is still bound to it; otherwise falls through to the unchanged full compare.
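
A minimal sketch of the three pieces listed above, using a plain dict as the namespace. The function names echo the commit but the bodies are assumptions: in particular, the by-value summary here is just `repr(val)`, which may not match what `_snapshot_global_value` actually computes.

```python
import math
import types


def is_identity_stable(val):
    # Modules, classes, and callables are summarized by identity alone.
    return isinstance(val, (types.ModuleType, type)) or callable(val)


def snapshot_value(val):
    if is_identity_stable(val):
        return ("obj", id(val))
    # By-value summary (assumed form): catches in-place container mutation.
    return ("val", repr(val))


def snapshot_refs_for_drift(namespace, names):
    baseline = {}
    for name in names:
        val = namespace[name]
        # Holding the live object keeps it alive, so its id cannot be reused.
        fastref = val if is_identity_stable(val) else None
        baseline[name] = (snapshot_value(val), fastref)
    return baseline


def check_globals_drift(namespace, baseline):
    for name, (snapshot, fastref) in baseline.items():
        current = namespace[name]
        if fastref is not None and current is fastref:
            continue  # identity fast path: nothing to re-summarize
        if snapshot_value(current) != snapshot:
            raise RuntimeError(f"global {name!r} changed since compilation")


ns = {"math": math, "helper": len, "BLOCK_THREADS": 256, "cfg": [1, 2]}
baseline = snapshot_refs_for_drift(ns, list(ns))
check_globals_drift(ns, baseline)  # unchanged namespace: no error
```

Rebinding `ns["helper"]` to another function misses the `is` check and then fails the snapshot comparison, while `ns["cfg"].append(3)` is caught on the by-value path because lists never get a `fastref`.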
Measured: _check_globals_drift 7.7us -> 1.57us (5x); full __call__ ~49.6us ->
~40.5us (min-of-8). Drift semantics unchanged: rebinding a scalar OR an
identity-stable function/module still raises, in-place container mutation still
raises, no false positives after restore.
test_drift_identity_stable_fastpath_and_container_mutation covers the is-hit,
is-miss-rebind, and in-place-container-mutation paths; the per-owner-cls drift
test is updated to the new baseline shape via _snapshot_refs_for_drift.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
I’m not fully convinced this is the right layer to optimize. Passing raw
What's the purpose of using dlpack here?
Summary
Every `@flyc.jit` launch rebuilds its cache key to probe `_call_state_cache` (polymorphic dispatch, so unavoidable). Profiling the softmax launcher showed ~50us of host overhead per cache-hit launch. This PR drives that down in three steps.

1. Lean cache-key fast path (34b5f48)

`_resolve_and_make_cache_key` was dominated by two avoidable costs:
- building a full `TensorAdaptor` per tensor (DLPack export) just to read a cache signature: ~6.6us/tensor;
- `isinstance(arg, JitArgument)` where `JitArgument` is a `runtime_checkable` Protocol → structural check, ~2.2us/arg.

Added a lean fast probe that produces a byte-identical key without either cost:
- `TensorAdaptor.lean_cache_signature(t)`: derives the signature directly from dtype/shape/strides (no DLPack); the full adaptor is built only on a cache miss when actually compiling.
- `JitFunction._fast_cache_key` + `_build_fast_key_plan`: per-param plan computed once; lean tensor sig + a cheap `hasattr("__cache_signature__")` instead of the Protocol `isinstance`; does NOT mutate bound args.
- `__call__` probes with the lean key first; the cache-hit path passes raw args straight to `CallState`. The miss path is unchanged and its key is identical, so the cached `CallState` is found by the fast probe.

2. Boundary hardening (e511c9f)

Verified `lean_cache_signature` stays byte-identical to `TensorAdaptor.__cache_signature__` across the layouts where the framework and DLPack stride views can disagree (DLPack coerces unit/zero-size strides). The lean path follows the framework view via `_pick_unit_stride_axis`, same as the full path: broadcast (stride 0), 3d broadcast, leading/mid/all size-1 dims, trailing size-1 (ambiguous unit axis), permuted, channels_last, strided rows, zero-size dims, fp8 transposed.

Plus `test_lean_and_full_reject_no_unit_stride_consistently`: a tensor with no stride-1 axis is rejected by both paths, so the fast probe never silently dispatches a tensor the full path would refuse. Also corrected the `_fast_cache_key` comment that overclaimed `hasattr(...)` ≡ the 3-method `isinstance(arg, JitArgument)`.

3. Drift-check optimization (e33eca2)

After the key build dropped, the next lever was `_check_globals_drift` (~7.7us, 16% of the cache-hit `__call__`), which re-summarizes every captured global per launch. Most captured "globals" are modules / imported helpers / classes (softmax: 7 of 9 refs), summarized purely by `id()`, and the callable branch builds `repr(val)` every call.

For these identity-stable values an `is` compare against the baseline object is exactly equivalent to re-summarizing (the snapshot's only discriminant is identity), and holding the object alive also removes id-reuse aliasing. Scalars and containers stay on the by-value path so in-place mutation is still caught.

Measurements (softmax 8192² / 256², MI308X)

| Metric | Before | After |
|---|---|---|
| `_check_globals_drift` | 7.7us | 1.57us (5x) |
| `__call__` (cache hit) | ~49.6us | ~40.5us |

Helps the polymorphic launch path (varying shapes, where `flyc.compile` can't be used); GPU-bound kernels were already hiding this behind execution.

Tests

- `tests/unit/test_lean_cache_signature.py`: lean sig ≡ `__cache_signature__` and `_fast_cache_key` ≡ `_build_full_cache_key` across dtypes/ranks/strides incl. all the boundary layouts above.
- `tests/unit/test_jit_cache_key_completeness.py`: `test_drift_identity_stable_fastpath_and_container_mutation` covers the `is`-hit, `is`-miss-rebind, and in-place-container-mutation paths; existing drift tests updated to the new baseline shape.
- `test_softmax.py` passes. (Note: `test_tensor_cache_signature.py` failures in my local run are a pre-existing stale-`_mlir`-binary mismatch, unrelated to these Python-only changes.)

🤖 Generated with Claude Code