|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
| 5 | +import io |
5 | 6 | import math
|
6 |
| -from collections.abc import Generator, Iterable |
| 7 | +import pickle |
| 8 | +import types |
| 9 | +from collections.abc import Callable, Generator, Iterable |
| 10 | +from functools import wraps |
7 | 11 | from types import ModuleType
|
8 |
| -from typing import TYPE_CHECKING, cast |
| 12 | +from typing import ( |
| 13 | + TYPE_CHECKING, |
| 14 | + Any, |
| 15 | + ClassVar, |
| 16 | + Generic, |
| 17 | + Literal, |
| 18 | + ParamSpec, |
| 19 | + TypeAlias, |
| 20 | + TypeVar, |
| 21 | + cast, |
| 22 | +) |
9 | 23 |
|
10 | 24 | from . import _compat
|
11 | 25 | from ._compat import (
|
|
19 | 33 | from ._typing import Array
|
20 | 34 |
|
21 | 35 | if TYPE_CHECKING: # pragma: no cover
|
22 |
| - # TODO import from typing (requires Python >=3.13) |
23 |
| - from typing_extensions import TypeIs |
| 36 | + # TODO import from typing (override requires Python >=3.12; TypeIs requires >=3.13) |
| 37 | + from typing_extensions import TypeIs, override |
| 38 | +else: |
| 39 | + |
| 40 | + def override(func): |
| 41 | + return func |
| 42 | + |
| 43 | + |
| 44 | +P = ParamSpec("P") |
| 45 | +T = TypeVar("T") |
24 | 46 |
|
25 | 47 |
|
26 | 48 | __all__ = [
|
|
29 | 51 | "eager_shape",
|
30 | 52 | "in1d",
|
31 | 53 | "is_python_scalar",
|
| 54 | + "jax_autojit", |
32 | 55 | "mean",
|
33 | 56 | "meta_namespace",
|
| 57 | + "pickle_flatten", |
| 58 | + "pickle_unflatten", |
34 | 59 | ]
|
35 | 60 |
|
36 | 61 |
|
@@ -306,3 +331,233 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
|
306 | 331 | out["boolean indexing"] = True
|
307 | 332 | out["data-dependent shapes"] = True
|
308 | 333 | return out
|
| 334 | + |
| 335 | + |
| 336 | +_BASIC_PICKLED_TYPES = frozenset(( |
| 337 | + bool, int, float, complex, str, bytes, bytearray, |
| 338 | + list, tuple, dict, set, frozenset, range, slice, |
| 339 | + types.NoneType, types.EllipsisType, |
| 340 | +)) # fmt: skip |
| 341 | +_BASIC_REST_TYPES = frozenset(( |
| 342 | + type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType |
| 343 | +)) # fmt: skip |
| 344 | + |
| 345 | +FlattenRest: TypeAlias = tuple[object, ...] |
| 346 | + |
| 347 | + |
def pickle_flatten(
    obj: object, cls: type[T] | tuple[type[T], ...]
) -> tuple[list[T], FlattenRest]:
    """
    Use the pickle machinery to extract objects out of an arbitrary container.

    Unlike regular ``pickle.dumps``, this function is designed to succeed even
    when parts of ``obj`` cannot be pickled: such parts are carried along by
    reference instead of being serialized.

    Parameters
    ----------
    obj : object
        The object to pickle.
    cls : type | tuple[type, ...]
        One or multiple classes to extract from the object.
        The instances of these classes inside ``obj`` will not be pickled.

    Returns
    -------
    instances : list[cls]
        All instances of ``cls`` found inside ``obj`` (not pickled), in the
        order in which the pickler encountered them.
    rest
        Opaque object containing the pickled bytes plus all other objects where
        ``__reduce__`` / ``__reduce_ex__`` is either not implemented or raised.
        These are unpickleable objects, types, modules, and functions.

        This object is *typically* hashable save for fairly exotic objects
        that are neither pickleable nor hashable.

        This object is pickleable if everything except ``instances`` was pickleable
        in the input object.

    See Also
    --------
    pickle_unflatten : Reverse function.

    Examples
    --------
    >>> class A:
    ...     def __repr__(self):
    ...         return "<A>"
    >>> class NS:
    ...     def __repr__(self):
    ...         return "<NS>"
    ...     def __reduce__(self):
    ...         assert False, "not serializable"
    >>> obj = {1: A(), 2: [A(), NS(), A()]}
    >>> instances, rest = pickle_flatten(obj, A)
    >>> instances
    [<A>, <A>, <A>]
    >>> pickle_unflatten(instances, rest)
    {1: <A>, 2: [<A>, <NS>, <A>]}

    This can be also used to swap inner objects; the only constraint is that
    the number of objects in and out must be the same:

    >>> pickle_unflatten(["foo", "bar", "baz"], rest)
    {1: 'foo', 2: ['bar', <NS>, 'baz']}
    """
    instances: list[T] = []
    rest: list[object] = []

    class Pickler(pickle.Pickler):  # numpydoc ignore=GL08
        """
        Use the `pickle.Pickler.persistent_id` hook to extract objects.
        """

        @override
        def persistent_id(self, obj: object) -> Literal[0, 1, None]:  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08
            # The returned id selects the list the object was stored in:
            # 0 -> ``instances``, 1 -> ``rest``. None means "pickle normally".
            if isinstance(obj, cls):
                instances.append(obj)  # type: ignore[arg-type]
                return 0

            typ_ = type(obj)
            if typ_ in _BASIC_PICKLED_TYPES:  # No subclasses!
                # If obj is a collection, recursively descend inside it
                return None
            if typ_ in _BASIC_REST_TYPES:
                rest.append(obj)
                return 1

            try:
                # Probe picklability with the same protocol used by the dump below.
                # Note: a class that defines __slots__ without defining __getstate__
                # cannot be pickled with __reduce__(), but can with __reduce_ex__(5)
                _ = obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL)
            except Exception:  # pylint: disable=broad-exception-caught
                # Deliberately broad: any failure means "carry by reference".
                rest.append(obj)
                return 1

            # Object can be pickled. Let the Pickler recursively descend inside it.
            return None

    f = io.BytesIO()
    p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL)
    p.dump(obj)
    # The pickled payload always comes first, followed by the by-reference objects.
    return instances, (f.getvalue(), *rest)
| 443 | + |
| 444 | + |
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any:  # type: ignore[explicit-any]
    """
    Reverse of ``pickle_flatten``.

    Parameters
    ----------
    instances : Iterable
        Inner objects to be reinserted into the flattened container.
    rest : FlattenRest
        Extra bits, as returned by ``pickle_flatten``.

    Returns
    -------
    object
        The outer object originally passed to ``pickle_flatten`` after a
        pickle->unpickle round-trip.

    Raises
    ------
    ValueError
        If `rest` is empty, or if `instances` or `rest` yield fewer objects
        than the pickled payload refers to.

    See Also
    --------
    pickle_flatten : Serializing function.
    pickle.loads : Standard unpickle function.

    Notes
    -----
    The `instances` iterable must yield at least the same number of elements as the ones
    returned by ``pickle_flatten``, but the elements do not need to be the same objects
    or even the same types of objects. Excess elements, if any, will be left untouched.

    The first element of `rest` is fed to the standard unpickler, so `rest` must
    come from a trusted ``pickle_flatten`` call and never from untrusted input.
    """
    iters = iter(instances), iter(rest)
    try:
        # The pickled payload is always the first element of ``rest``.
        pik = cast(bytes, next(iters[1]))
    except StopIteration as e:
        msg = "rest is empty; expected the value returned by pickle_flatten"
        raise ValueError(msg) from e

    class Unpickler(pickle.Unpickler):  # numpydoc ignore=GL08
        """Mirror of the overridden Pickler in pickle_flatten."""

        @override
        def persistent_load(self, pid: Literal[0, 1]) -> object:  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08
            # pid 0 pulls the next object from ``instances``, pid 1 from ``rest``.
            try:
                return next(iters[pid])
            except StopIteration as e:
                msg = "Not enough objects to unpickle"
                raise ValueError(msg) from e

    f = io.BytesIO(pik)
    return Unpickler(f).load()
| 489 | + |
| 490 | + |
class _AutoJITWrapper(Generic[T]):  # numpydoc ignore=PR01
    """
    Helper of :func:`jax_autojit`.

    Box an arbitrary object so that it can cross the ``jax.jit`` boundary as a
    PyTree: ``jax.Array`` objects found inside it become the tree's children,
    everything else travels as auxiliary data.
    """

    __slots__: tuple[str, ...] = ("obj",)
    _registered: ClassVar[bool] = False
    obj: T

    def __init__(self, obj: T) -> None:  # numpydoc ignore=GL08
        # Make sure JAX knows about this class before the first instance exists.
        self._register()
        self.obj = obj

    @classmethod
    def _register(cls):  # numpydoc ignore=SS06
        """
        Register the class as a PyTree node the first time it is used rather
        than at import time, so that JAX is not imported globally.
        """
        if cls._registered:
            return

        import jax

        def flatten(wrapper):  # numpydoc ignore=GL08
            # -> (children, aux_data): the arrays, then the pickled remainder
            return pickle_flatten(wrapper, jax.Array)

        def unflatten(aux_data, children):  # numpydoc ignore=GL08
            return pickle_unflatten(children, aux_data)

        jax.tree_util.register_pytree_node(cls, flatten, unflatten)
        cls._registered = True
| 522 | + |
| 523 | + |
def jax_autojit(
    func: Callable[P, T],
) -> Callable[P, T]:  # numpydoc ignore=PR01,RT01,SS03
    """
    Apply ``jax.jit`` to `func`, deviating from plain ``jax.jit`` as follows:

    - Python scalars passed in or returned are left as they are instead of being
      converted to ``jax.Array`` objects.
    - Every non-array argument is treated as static. Static arguments need to be
      hashable or serializable with ``pickle``, which differs from ``jax.jit``.
    - Non-array arguments and return values may be any object serializable with
      ``pickle``; they are not restricted to tuple/list/dict as in ``jax.jit``.
    - ``jax.Array`` objects nested inside non-array arguments are located
      automatically; the arguments are rebuilt on entering `func`, with tracer
      objects in place of the concrete JAX arrays.
    - ``jax.Array`` objects nested inside non-array return values are located
      automatically; the return values are rebuilt after leaving the JIT, with
      concrete arrays in place of the JAX tracer objects.

    See Also
    --------
    jax.jit : JAX JIT compilation function.
    """
    import jax

    def inner(  # type: ignore[decorated-any,explicit-any]  # numpydoc ignore=GL08
        packed: _AutoJITWrapper[Any],
    ) -> _AutoJITWrapper[T]:
        # Unbox the (args, kwargs) pair, call through, and box the result again.
        positional, keywords = packed.obj
        result = func(*positional, **keywords)  # pyright: ignore[reportCallIssue]
        return _AutoJITWrapper(result)

    jitted = jax.jit(inner)  # type: ignore[misc]

    @wraps(func)
    def outer(*args: P.args, **kwargs: P.kwargs) -> T:  # numpydoc ignore=GL08
        # All call arguments travel through the JIT as one boxed object.
        boxed = _AutoJITWrapper((args, kwargs))
        return jitted(boxed).obj

    return outer
0 commit comments