Skip to content

Commit fd935e0

Browse files
committed
add experimental support for mutable array
1 parent 7600ad1 commit fd935e0

File tree

7 files changed

+1476
-0
lines changed

7 files changed

+1476
-0
lines changed

flax/experimental/nx/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2024 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .statelib import init as init
16+
from .statelib import merge as merge
17+
from .filterlib import PathContains as PathContains
18+
from .statelib import split as split
19+
from .variablelib import Variable as Variable
20+
from .variablelib import Param as Param
21+
from .variablelib import BatchStat as BatchStat
22+
from .variablelib import Cache as Cache
23+
from .statelib import TreeDef as TreeDef
24+
from .objectlib import Object as Object

flax/experimental/nx/filterlib.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright 2024 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import builtins
18+
import dataclasses
19+
import typing as tp
20+
21+
from flax.typing import Key, PathParts
22+
23+
24+
if tp.TYPE_CHECKING:
25+
ellipsis = builtins.ellipsis
26+
else:
27+
ellipsis = tp.Any
28+
29+
Predicate = tp.Callable[[PathParts, tp.Any], bool]
30+
31+
FilterLiteral = tp.Union[type, str, Predicate, bool, ellipsis, None]
32+
Filter = tp.Union[FilterLiteral, tuple['Filter', ...], list['Filter']]
33+
34+
35+
36+
37+
38+
def to_predicate(filter_: Filter) -> Predicate:
39+
"""Converts a Filter to a predicate function."""
40+
41+
if isinstance(filter_, str):
42+
return WithTag(filter_)
43+
elif isinstance(filter_, type):
44+
return OfType(filter_)
45+
elif isinstance(filter_, bool):
46+
if filter_:
47+
return Everything()
48+
else:
49+
return Nothing()
50+
elif filter_ is Ellipsis:
51+
return Everything()
52+
elif filter_ is None:
53+
return Nothing()
54+
elif callable(filter_):
55+
return filter_
56+
elif isinstance(filter_, (list, tuple)):
57+
return Any(*filter_)
58+
else:
59+
raise TypeError(f'Invalid collection filter: {filter_:!r}. ')
60+
61+
62+
def filters_to_predicates(
63+
filters: tp.Sequence[Filter],
64+
) -> tuple[Predicate, ...]:
65+
for i, filter_ in enumerate(filters):
66+
if filter_ in (..., True) and i != len(filters) - 1:
67+
remaining_filters = filters[i + 1 :]
68+
if not all(f in (..., True) for f in remaining_filters):
69+
raise ValueError(
70+
'`...` or `True` can only be used as the last filters, '
71+
f'got {filter_} it at index {i}.'
72+
)
73+
return tuple(map(to_predicate, filters))
74+
75+
76+
class HasTag(tp.Protocol):
77+
tag: str
78+
79+
80+
def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]:
81+
return hasattr(x, 'tag')
82+
83+
84+
@dataclasses.dataclass(frozen=True)
85+
class WithTag:
86+
tag: str
87+
88+
def __call__(self, path: PathParts, x: tp.Any):
89+
return _has_tag(x) and x.tag == self.tag
90+
91+
def __repr__(self):
92+
return f'WithTag({self.tag!r})'
93+
94+
95+
@dataclasses.dataclass(frozen=True)
96+
class PathContains:
97+
key: Key
98+
99+
def __call__(self, path: PathParts, x: tp.Any):
100+
return self.key in path
101+
102+
def __repr__(self):
103+
return f'PathContains({self.key!r})'
104+
105+
106+
class PathIn:
107+
108+
def __init__(self, *paths: PathParts):
109+
self.paths = frozenset(paths)
110+
111+
def __call__(self, path: PathParts, x: tp.Any):
112+
return path in self.paths
113+
114+
def __repr__(self):
115+
paths_repr = ','.join(map(repr, self.paths))
116+
return f'PathIn({paths_repr})'
117+
118+
def __eq__(self, other):
119+
return isinstance(other, PathIn) and self.paths == other.paths
120+
121+
def __hash__(self):
122+
return hash(self.paths)
123+
124+
125+
@dataclasses.dataclass(frozen=True)
126+
class OfType:
127+
type: type
128+
129+
def __call__(self, path: PathParts, x: tp.Any):
130+
return isinstance(x, self.type) or (
131+
hasattr(x, 'type') and issubclass(x.type, self.type)
132+
)
133+
134+
def __repr__(self):
135+
return f'OfType({self.type!r})'
136+
137+
138+
class Any:
139+
140+
def __init__(self, *filters: Filter):
141+
self.predicates = tuple(
142+
to_predicate(collection_filter) for collection_filter in filters
143+
)
144+
145+
def __call__(self, path: PathParts, x: tp.Any) -> bool:
146+
return any(predicate(path, x) for predicate in self.predicates)
147+
148+
def __repr__(self):
149+
return f'Any({", ".join(map(repr, self.predicates))})'
150+
151+
def __eq__(self, other):
152+
return isinstance(other, Any) and self.predicates == other.predicates
153+
154+
def __hash__(self):
155+
return hash(self.predicates)
156+
157+
158+
class All:
159+
160+
def __init__(self, *filters: Filter):
161+
self.predicates = tuple(
162+
to_predicate(collection_filter) for collection_filter in filters
163+
)
164+
165+
def __call__(self, path: PathParts, x: tp.Any) -> bool:
166+
return all(predicate(path, x) for predicate in self.predicates)
167+
168+
def __repr__(self):
169+
return f'All({", ".join(map(repr, self.predicates))})'
170+
171+
def __eq__(self, other):
172+
return isinstance(other, All) and self.predicates == other.predicates
173+
174+
def __hash__(self):
175+
return hash(self.predicates)
176+
177+
178+
class Not:
179+
180+
def __init__(self, collection_filter: Filter, /):
181+
self.predicate = to_predicate(collection_filter)
182+
183+
def __call__(self, path: PathParts, x: tp.Any) -> bool:
184+
return not self.predicate(path, x)
185+
186+
def __repr__(self):
187+
return f'Not({self.predicate!r})'
188+
189+
def __eq__(self, other):
190+
return isinstance(other, Not) and self.predicate == other.predicate
191+
192+
def __hash__(self):
193+
return hash(self.predicate)
194+
195+
196+
class Everything:
197+
198+
def __call__(self, path: PathParts, x: tp.Any) -> bool:
199+
return True
200+
201+
def __repr__(self):
202+
return 'Everything()'
203+
204+
def __eq__(self, other):
205+
return isinstance(other, Everything)
206+
207+
def __hash__(self):
208+
return hash(Everything)
209+
210+
211+
class Nothing:
212+
213+
def __call__(self, path: PathParts, x: tp.Any) -> bool:
214+
return False
215+
216+
def __repr__(self):
217+
return 'Nothing()'
218+
219+
def __eq__(self, other):
220+
return isinstance(other, Nothing)
221+
222+
def __hash__(self):
223+
return hash(Nothing)

0 commit comments

Comments
 (0)