Skip to content

Commit 712dd9e

Browse files
committed
ENH: jax_autojit
1 parent 4425d14 commit 712dd9e

File tree

8 files changed

+628
-104
lines changed

8 files changed

+628
-104
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

+259-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,24 @@
22

33
from __future__ import annotations
44

5+
import io
56
import math
6-
from collections.abc import Generator, Iterable
7+
import pickle
8+
import types
9+
from collections.abc import Callable, Generator, Iterable
10+
from functools import wraps
711
from types import ModuleType
8-
from typing import TYPE_CHECKING, cast
12+
from typing import (
13+
TYPE_CHECKING,
14+
Any,
15+
ClassVar,
16+
Generic,
17+
Literal,
18+
ParamSpec,
19+
TypeAlias,
20+
TypeVar,
21+
cast,
22+
)
923

1024
from . import _compat
1125
from ._compat import (
@@ -19,8 +33,16 @@
1933
from ._typing import Array
2034

2135
if TYPE_CHECKING: # pragma: no cover
22-
# TODO import from typing (requires Python >=3.13)
23-
from typing_extensions import TypeIs
36+
# TODO import from typing (requires Python >=3.12 and >=3.13)
37+
from typing_extensions import TypeIs, override
38+
else:
39+
40+
def override(func):
41+
return func
42+
43+
44+
P = ParamSpec("P")
45+
T = TypeVar("T")
2446

2547

2648
__all__ = [
@@ -29,8 +51,11 @@
2951
"eager_shape",
3052
"in1d",
3153
"is_python_scalar",
54+
"jax_autojit",
3255
"mean",
3356
"meta_namespace",
57+
"pickle_flatten",
58+
"pickle_unflatten",
3459
]
3560

3661

@@ -306,3 +331,233 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
306331
out["boolean indexing"] = True
307332
out["data-dependent shapes"] = True
308333
return out
334+
335+
336+
_BASIC_PICKLED_TYPES = frozenset((
337+
bool, int, float, complex, str, bytes, bytearray,
338+
list, tuple, dict, set, frozenset, range, slice,
339+
types.NoneType, types.EllipsisType,
340+
)) # fmt: skip
341+
_BASIC_REST_TYPES = frozenset((
342+
type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType
343+
)) # fmt: skip
344+
345+
FlattenRest: TypeAlias = tuple[object, ...]
346+
347+
348+
def pickle_flatten(
    obj: object, cls: type[T] | tuple[type[T], ...]
) -> tuple[list[T], FlattenRest]:
    """
    Use the pickle machinery to extract objects out of an arbitrary container.

    Unlike regular ``pickle.dumps``, this function always succeeds.

    Parameters
    ----------
    obj : object
        The object to pickle.
    cls : type | tuple[type, ...]
        One or multiple classes to extract from the object.
        The instances of these classes inside ``obj`` will not be pickled.

    Returns
    -------
    instances : list[cls]
        All instances of ``cls`` found inside ``obj`` (not pickled).
    rest
        Opaque object containing the pickled bytes plus all other objects where
        ``__reduce__`` / ``__reduce_ex__`` is either not implemented or raised.
        These are unpickleable objects, types, modules, and functions.

        This object is *typically* hashable save for fairly exotic objects
        that are neither pickleable nor hashable.

        This object is pickleable if everything except ``instances`` was pickleable
        in the input object.

    See Also
    --------
    pickle_unflatten : Reverse function.

    Examples
    --------
    >>> class A:
    ...     def __repr__(self):
    ...         return "<A>"
    >>> class NS:
    ...     def __repr__(self):
    ...         return "<NS>"
    ...     def __reduce__(self):
    ...         assert False, "not serializable"
    >>> obj = {1: A(), 2: [A(), NS(), A()]}
    >>> instances, rest = pickle_flatten(obj, A)
    >>> instances
    [<A>, <A>, <A>]
    >>> pickle_unflatten(instances, rest)
    {1: <A>, 2: [<A>, <NS>, <A>]}

    This can also be used to swap inner objects; the only constraint is that
    the number of objects in and out must be the same:

    >>> pickle_unflatten(["foo", "bar", "baz"], rest)
    {1: 'foo', 2: ['bar', <NS>, 'baz']}
    """
    instances: list[T] = []
    rest: list[object] = []

    class Pickler(pickle.Pickler):  # numpydoc ignore=GL08
        """
        Use the `pickle.Pickler.persistent_id` hook to extract objects.
        """

        @override
        def persistent_id(self, obj: object) -> Literal[0, 1, None]:  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08
            # The returned persistent ID selects the stream the object is
            # stored in: 0 -> `instances`, 1 -> `rest`. None means
            # "pickle this object normally".
            if isinstance(obj, cls):
                instances.append(obj)  # type: ignore[arg-type]
                return 0

            typ_ = type(obj)
            if typ_ in _BASIC_PICKLED_TYPES:  # No subclasses!
                # If obj is a collection, recursively descend inside it
                return None
            if typ_ in _BASIC_REST_TYPES:
                rest.append(obj)
                return 1

            try:
                # Probe picklability with the same protocol used by the dump
                # below. Note: a class that defines __slots__ without defining
                # __getstate__ cannot be pickled with __reduce__(), but can
                # with __reduce_ex__() at this protocol.
                _ = obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL)
            except Exception:  # pylint: disable=broad-exception-caught
                # Deliberately broad: any failure means "keep the object
                # out of the pickle stream", never "abort the flatten".
                rest.append(obj)
                return 1

            # Object can be pickled. Let the Pickler recursively descend inside it.
            return None

    f = io.BytesIO()
    p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL)
    p.dump(obj)
    # The pickled payload always comes first; pickle_unflatten relies on it.
    return instances, (f.getvalue(), *rest)
444+
445+
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any:  # type: ignore[explicit-any]
    """
    Reverse of ``pickle_flatten``.

    Parameters
    ----------
    instances : Iterable
        Inner objects to be reinserted into the flattened container.
    rest : FlattenRest
        Extra bits, as returned by ``pickle_flatten``.

    Returns
    -------
    object
        The outer object originally passed to ``pickle_flatten`` after a
        pickle->unpickle round-trip.

    Raises
    ------
    ValueError
        If ``rest`` is empty, or if either ``instances`` or ``rest`` is exhausted
        before the object has been fully rebuilt.

    See Also
    --------
    pickle_flatten : Serializing function.
    pickle.loads : Standard unpickle function.

    Notes
    -----
    The `instances` iterable must yield at least the same number of elements as the ones
    returned by ``pickle_flatten``, but the elements do not need to be the same objects
    or even the same types of objects. Excess elements, if any, will be left untouched.

    The first element of ``rest`` is fed to ``pickle.Unpickler``, so ``rest`` must
    come from a trusted source, exactly like the input of ``pickle.loads``.
    """
    # Index 0 feeds persistent ID 0 (extracted instances); index 1 feeds
    # persistent ID 1 (objects that were kept out of the pickle stream).
    iters = iter(instances), iter(rest)
    try:
        # The pickled payload is always the first element of `rest`.
        pik = cast(bytes, next(iters[1]))
    except StopIteration as e:
        # Don't let a bare StopIteration escape: inside a caller's iterator
        # it would be mistaken for normal exhaustion.
        msg = "rest is empty: expected the pickled payload as its first element"
        raise ValueError(msg) from e

    class Unpickler(pickle.Unpickler):  # numpydoc ignore=GL08
        """Mirror of the overridden Pickler in pickle_flatten."""

        @override
        def persistent_load(self, pid: Literal[0, 1]) -> object:  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08
            try:
                return next(iters[pid])
            except StopIteration as e:
                msg = "Not enough objects to unpickle"
                raise ValueError(msg) from e

    f = io.BytesIO(pik)
    return Unpickler(f).load()
490+
491+
class _AutoJITWrapper(Generic[T]):  # numpydoc ignore=PR01
    """
    Helper of :func:`jax_autojit`.

    Box holding the arbitrary inputs or outputs of the jitted function.
    The class is registered as a JAX PyTree node whose flatten/unflatten
    steps are ``pickle_flatten`` / ``pickle_unflatten``.
    """

    obj: T
    _registered: ClassVar[bool] = False
    __slots__: tuple[str, ...] = ("obj",)

    def __init__(self, obj: T) -> None:  # numpydoc ignore=GL08
        # Make sure JAX knows about this class before any instance exists.
        self._register()
        self.obj = obj

    @classmethod
    def _register(cls) -> None:  # numpydoc ignore=SS06
        """
        Register the class as a PyTree node on first use rather than at
        import time, so that JAX is not imported globally.
        """
        if cls._registered:
            return

        import jax

        def flatten(wrapper: object) -> tuple[list[jax.Array], FlattenRest]:
            # children = the jax.Array objects found inside the wrapper,
            # aux_data = everything else.
            return pickle_flatten(wrapper, jax.Array)

        def unflatten(aux_data: FlattenRest, children: Iterable[object]) -> object:
            # JAX passes (aux_data, children); pickle_unflatten takes them
            # in the opposite order.
            return pickle_unflatten(children, aux_data)

        jax.tree_util.register_pytree_node(cls, flatten, unflatten)
        cls._registered = True
523+
524+
def jax_autojit(
    func: Callable[P, T],
) -> Callable[P, T]:  # numpydoc ignore=PR01,RT01,SS03
    """
    Variant of ``jax.jit`` for `func` that accepts and returns arbitrary objects.

    Compared to plain ``jax.jit``:

    - Python scalars passed in or returned are left as they are, instead of being
      converted to ``jax.Array`` objects.
    - Every non-array argument is treated as static, without having to be declared.
      Static arguments must be either hashable or serializable with ``pickle``.
    - Non-array arguments and return values may be any object serializable with
      ``pickle``, not just tuple/list/dict.
    - ``jax.Array`` objects nested inside non-array arguments are located
      automatically; the arguments are rebuilt when entering `func`, with the JAX
      concrete arrays replaced by tracer objects.
    - ``jax.Array`` objects nested inside non-array return values are located
      automatically; the return values are rebuilt after leaving the JIT, with the
      JAX tracer objects replaced by concrete arrays.

    See Also
    --------
    jax.jit : JAX JIT compilation function.
    """
    import jax

    def inner(  # type: ignore[explicit-any]  # numpydoc ignore=GL08
        wargs: _AutoJITWrapper[Any],
    ) -> _AutoJITWrapper[T]:
        # Unbox the call arguments, run the user function, box the result.
        call_args, call_kwargs = wargs.obj
        return _AutoJITWrapper(func(*call_args, **call_kwargs))  # pyright: ignore[reportCallIssue]

    # The jitted function only ever sees a single wrapper in and a single
    # wrapper out.
    jitted = jax.jit(inner)  # type: ignore[misc]  # pyright: ignore[reportUntypedFunctionDecorator]

    @wraps(func)
    def outer(*args: P.args, **kwargs: P.kwargs) -> T:  # numpydoc ignore=GL08
        boxed = _AutoJITWrapper((args, kwargs))
        return jitted(boxed).obj

    return outer

0 commit comments

Comments
 (0)