Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions python/flydsl/compiler/jit_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,38 @@ def __get_c_pointers__(self):
def __cache_signature__(self):
return (type(self), self._orig_dtype) + self.tensor_adaptor.get_cache_signature()

@classmethod
def lean_cache_signature(cls, t):
"""Cache signature identical to ``cls(t).__cache_signature__()`` but
derived directly from ``dtype``/``shape``/``strides`` — WITHOUT the
DLPack export + ``DLTensorAdaptor`` construction that ``__init__`` does.

Used on the JIT cache-key fast path (cache hit), where only the key is
needed; the full adaptor is built lazily on a cache miss when the kernel
is actually compiled/launched. Must stay byte-identical to
``__cache_signature__()`` for the default construction
(``use_32bit_stride=False``, ``dynamic_layout=True``) — guarded by
``tests/unit/test_lean_cache_signature.py``.

``get_cache_signature()`` returns
``(elem_bytes, use_32bit_stride, dynamic_dims, unit_stride_marker)``:
all dims dynamic (-1) and a 1 at the first unit-stride axis
(``_pick_unit_stride_axis``), the rest -1.
"""
strides = tuple(int(s) for s in t.stride())
rank = len(strides)
unit_axis = next((i for i, s in enumerate(strides) if s == 1), None)
if unit_axis is None:
raise RuntimeError("tensor has no axis with stride == 1; layout-dynamic memref requires one")
Comment on lines +328 to +332
return (
cls,
t.dtype,
t.element_size(),
False,
(-1,) * rank,
tuple(1 if i == unit_axis else -1 for i in range(rank)),
)

def _mark_layout_dynamic(self, leading_dim: int, divisibility: int):
# Always pass a concrete axis index down. The DLPack stride view that
# the backend sees can disagree with the framework view for tensors
Expand Down
169 changes: 158 additions & 11 deletions python/flydsl/compiler/jit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
from ..utils import env, log
from .ast_rewriter import ASTRewriter
from .backends import compile_backend_name, get_backend
from .jit_argument import convert_to_jit_arguments, is_type_param_annotation, resolve_signature
from .jit_argument import (
TensorAdaptor,
convert_to_jit_arguments,
is_type_param_annotation,
resolve_signature,
)
from .jit_executor import CompiledArtifact
from .kernel_function import (
CompilationContext,
Expand Down Expand Up @@ -279,6 +284,45 @@ def _snapshot_refs(refs: List[Tuple[str, str, dict]], *, stable: bool) -> Dict[T
return out


def _is_identity_stable(val) -> bool:
"""True for values that ``_snapshot_global_value(stable=False)`` summarizes
purely by ``id()`` — i.e. the ``"callable"`` and opaque ``"obj"`` branches.

For these, an ``is`` comparison against the baseline object is exactly
equivalent to re-summarizing and comparing the snapshot (the snapshot's only
discriminant is the identity), so the drift check can skip the resummary —
including the per-call ``repr(val)`` that the callable branch builds. Scalars
and builtin containers are summarized *by value* (to catch in-place mutation),
so they stay on the full-compare path.
"""
if isinstance(val, (int, float, bool, str, bytes, type(None))):
return False
if isinstance(val, (tuple, list, set, frozenset, dict)):
return False
return True


def _snapshot_refs_for_drift(refs: List[Tuple[str, str, dict]]) -> Dict[Tuple[str, str], Any]:
"""Baseline for in-process drift detection: ``(name, mod) -> (snapshot, fastref)``.

``snapshot`` is the ``stable=False`` summary (as in :func:`_snapshot_refs`);
``fastref`` is the live baseline object for identity-stable values (modules,
functions, classes, opaque objects) or ``None`` otherwise. Holding the
reference lets :meth:`JitFunction._check_globals_drift` short-circuit with an
``is`` check — and, because the object is kept alive, eliminates the id-reuse
aliasing that comparing stored ``id()`` snapshots is theoretically prone to.
"""
out: Dict[Tuple[str, str], Any] = {}
for name, mod_name, var_dict in refs:
if name in var_dict:
val = var_dict[name]
out[(name, mod_name)] = (
_snapshot_global_value(val, stable=False),
val if _is_identity_stable(val) else None,
)
return out


class FlyDSLCompileError(RuntimeError):
"""Raised when an MLIR pass pipeline fails.

Expand Down Expand Up @@ -1241,6 +1285,7 @@ def __init__(self, func: Callable, compile_hints: Optional[dict] = None):
self._manager_owner_cls = None
self.cache_manager = None
self._call_state_cache = {} # cache_key -> CallState
self._fast_key_plan = None # lazily-built per-param key-extraction plan
self._sig = None # lazy: set on first call
self._has_self_param = False # lazy: set in _ensure_sig
self._backend_target = None # lazy: GPUTarget resolved once in _ensure_sig
Expand Down Expand Up @@ -1282,8 +1327,14 @@ def _check_globals_drift(self, owner_cls=None) -> None:
baseline = self._used_global_vals[owner_cls]
for name, mod_name, var_dict in self._get_global_refs(owner_cls):
key = (name, mod_name)
old = baseline.get(key, _NOT_IN_BASELINE)
if old is _NOT_IN_BASELINE:
entry = baseline.get(key, _NOT_IN_BASELINE)
if entry is _NOT_IN_BASELINE:
continue
old, fastref = entry
# Identity-stable global (module/function/class/opaque object) that is
# still bound to the same object: nothing to re-summarize. Equivalent
# to ``new == old`` for these, but skips the resummary + repr() build.
if fastref is not None and var_dict.get(name) is fastref:
continue
new = _snapshot_global_value(var_dict[name], stable=False) if name in var_dict else None
if new != old:
Expand Down Expand Up @@ -1420,6 +1471,93 @@ def _build_full_cache_key(self, bound_arguments, *, owner_cls=None, bound_self=N
cache_key = (("_self_type_", type(bound_self)),) + cache_key
return cache_key

def _build_fast_key_plan(self):
"""Per-parameter classification, computed once from the signature.

Each entry is ``(kind, name)`` where ``kind`` mirrors the branch
``_resolve_and_make_cache_key`` would take based on the *annotation*:
``"cx"`` (Constexpr), ``"tp"`` (Type[...] param), ``"ann"`` (annotation
is a JitArgument subclass, e.g. ``Stream``/``Int32``), or ``"reg"``
(registry-resolved, e.g. tensors). The remaining per-call decision
(arg already a JitArgument? tensor → lean signature?) is made cheaply in
``_fast_cache_key`` without the runtime_checkable-Protocol ``isinstance``.
"""
plan = []
empty = inspect.Parameter.empty
for name, param in self._sig.parameters.items():
ann = param.annotation
if ann is not empty and Constexpr.is_constexpr_annotation(ann):
plan.append(("cx", name))
elif ann is not empty and is_type_param_annotation(ann):
plan.append(("tp", name))
elif isinstance(ann, type) and issubclass(ann, JitArgument):
plan.append(("ann", name, ann))
else:
plan.append(("reg", name))
return plan

def _fast_cache_key(self, bound_arguments, *, owner_cls=None, bound_self=None):
"""Lean equivalent of ``_build_full_cache_key`` for the cache-hit probe.

Produces a **byte-identical** key but avoids the two dominant per-launch
costs measured on the hot path: (1) building a full ``TensorAdaptor``
(DLPack export) per tensor just to read a signature — replaced by
``TensorAdaptor.lean_cache_signature``; (2) the ``runtime_checkable``
``isinstance(arg, JitArgument)`` Protocol structural check — replaced by
a cheap ``hasattr(arg, "__cache_signature__")``. Unlike
``_resolve_and_make_cache_key`` it does NOT mutate ``bound_arguments``;
the full resolution still runs on a cache miss.
"""
from .jit_argument import JitArgumentRegistry

if self._fast_key_plan is None:
self._fast_key_plan = self._build_fast_key_plan()

parts = list(self._globals_key_prefix(owner_cls))
parts.append(("_env_", _cache_invalidating_env_values()))
parts.append(("_target_", self._backend_target))
if self.compile_hints:
parts.append(("_hints_", tuple(sorted((k, str(v)) for k, v in self.compile_hints.items()))))

for entry in self._fast_key_plan:
kind = entry[0]
name = entry[1]
arg = bound_arguments[name]
if kind == "cx":
parts.append((name, Constexpr.value_signature(arg)))
continue
if kind == "tp":
parts.append((name, arg))
continue
# Arg already a JitArgument: duck-check for the only method the key
# needs, instead of the full runtime_checkable-Protocol isinstance
# (which also probes __get_ir_types__/__get_c_pointers__). Every type
# that implements __cache_signature__ in the codebase is a complete
# JitArgument, so this selects the same branch as the full path while
# skipping the ~2.2us structural check.
if hasattr(arg, "__cache_signature__"):
parts.append((name, arg.__cache_signature__()))
continue
if kind == "ann":
parts.append((name, cache_signature(entry[2](arg))))
continue
# registry-resolved: tensors take the no-DLPack lean signature.
ctor, _ = JitArgumentRegistry.get(type(arg))
if ctor is None:
raise TypeError(
f"{name}: {type(arg).__name__} is neither a JitArgument nor has a registered "
f"constructor; cannot derive cache signature."
)
if ctor is TensorAdaptor:
parts.append((name, TensorAdaptor.lean_cache_signature(arg)))
else:
parts.append((name, cache_signature(ctor(arg))))

cache_key = tuple(parts)
if bound_self is not None:
cache_key = (("_self_type_", type(bound_self)),) + cache_key
return cache_key

@staticmethod
def _cache_key_to_str(cache_key) -> str:
"""Convert tuple cache key to string for disk cache."""
Expand All @@ -1442,29 +1580,38 @@ def __call__(self, *args, **kwargs):
# snapshot the used globals on first compile (per owner_cls) and RAISE on
# any later change.
if owner_cls not in self._used_global_vals:
self._used_global_vals[owner_cls] = _snapshot_refs(self._get_global_refs(owner_cls), stable=False)
self._used_global_vals[owner_cls] = _snapshot_refs_for_drift(self._get_global_refs(owner_cls))
else:
self._check_globals_drift(owner_cls)

sig = self._sig
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

cache_key = self._build_full_cache_key(bound.arguments, owner_cls=owner_cls, bound_self=bound_self)

args_tuple = tuple(bound.arguments.values())

# Compile/runtime pairing at JIT entry (not in CompiledArtifact / ExecutionEngine init).
from ..runtime.device_runtime import ensure_compile_runtime_pairing_from_env

ensure_compile_runtime_pairing_from_env(compile_backend_name())

# Fast path: reuse pre-built CallState (no ctypes alloc, no DLPack)
call_state = self._call_state_cache.get(cache_key)
# Fast path: probe with a LEAN cache key (byte-identical to the full
# key, but no per-tensor DLPack/TensorAdaptor and no runtime_checkable
# Protocol isinstance — see _fast_cache_key). CallState's slot
# extractors accept raw tensors, so the TensorAdaptor wrapping is only
# needed on a miss; pass the raw bound args straight through.
fast_key = self._fast_cache_key(bound.arguments, owner_cls=owner_cls, bound_self=bound_self)
call_state = self._call_state_cache.get(fast_key)
if call_state is not None:
if env.compile.compile_only:
return None
return call_state(args_tuple)
return call_state(tuple(bound.arguments.values()))

# Miss: do the full resolution (mutates bound.arguments into JitArgument
# instances for the compile path below). The resulting key is identical
# to fast_key, so the CallState cached here is found by the fast probe on
# subsequent calls.
cache_key = self._build_full_cache_key(bound.arguments, owner_cls=owner_cls, bound_self=bound_self)

args_tuple = tuple(bound.arguments.values())

# Normal path: check in-process cache first, then optional disk cache.
# In run_only mode the disk cache is read regardless of enable_cache, since
Expand Down
43 changes: 40 additions & 3 deletions tests/unit/test_jit_cache_key_completeness.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_drift_baseline_is_per_owner_cls():
against its own baseline, not the first owner's. Otherwise a global seen only
under the second owner is skipped (not in the first owner's baseline) and a
later mutation silently reuses the memoized key segment instead of raising."""
from flydsl.compiler.jit_function import _snapshot_refs
from flydsl.compiler.jit_function import _snapshot_refs_for_drift

@flyc.jit
def launch(A: fx.Tensor):
Expand All @@ -249,11 +249,48 @@ class B:

launch._global_refs_cache[A] = [("FOO", "m", g)]
launch._global_refs_cache[B] = [("BAR", "m", g)]
launch._used_global_vals[A] = _snapshot_refs(launch._global_refs_cache[A], stable=False)
launch._used_global_vals[B] = _snapshot_refs(launch._global_refs_cache[B], stable=False)
launch._used_global_vals[A] = _snapshot_refs_for_drift(launch._global_refs_cache[A])
launch._used_global_vals[B] = _snapshot_refs_for_drift(launch._global_refs_cache[B])

g["BAR"] = 2 # mutate a global seen only under owner B

launch._check_globals_drift(A) # A's baseline (FOO) unchanged → no raise
with pytest.raises(RuntimeError, match="BAR"):
launch._check_globals_drift(B) # B's own baseline catches the BAR drift


def test_drift_identity_stable_fastpath_and_container_mutation():
"""The drift check short-circuits identity-stable globals (modules / functions
/ classes) with an ``is`` compare, but must still: (1) raise when such a global
is rebound to a different object, and (2) catch *in-place* mutation of a
by-value container, which keeps its identity and so must not be fast-pathed."""
from flydsl.compiler.jit_function import _snapshot_refs_for_drift

@flyc.jit
def launch(A: fx.Tensor):
return A

def helper_a():
return 1

def helper_b():
return 2

g = {"__name__": "m", "fn": helper_a, "lst": [1, 2]}
launch._global_refs_cache[None] = [("fn", "m", g), ("lst", "m", g)]
launch._used_global_vals[None] = _snapshot_refs_for_drift(launch._global_refs_cache[None])

# Identity-stable + unchanged → fast ``is`` hit, no raise.
launch._check_globals_drift(None)

# Identity-stable rebound to a different object → ``is`` misses, full compare raises.
g["fn"] = helper_b
with pytest.raises(RuntimeError, match="fn"):
launch._check_globals_drift(None)
g["fn"] = helper_a
launch._check_globals_drift(None) # restored → no raise

# In-place mutation of a by-value container (same identity) must still be caught.
g["lst"].append(3)
with pytest.raises(RuntimeError, match="lst"):
launch._check_globals_drift(None)
Loading
Loading