|
| 1 | +# Copyright 2024 The Flax Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import builtins |
| 18 | +import dataclasses |
| 19 | +import typing as tp |
| 20 | + |
| 21 | +from flax.typing import Key, PathParts |
| 22 | + |
| 23 | + |
| 24 | +if tp.TYPE_CHECKING: |
| 25 | + ellipsis = builtins.ellipsis |
| 26 | +else: |
| 27 | + ellipsis = tp.Any |
| 28 | + |
| 29 | +Predicate = tp.Callable[[PathParts, tp.Any], bool] |
| 30 | + |
| 31 | +FilterLiteral = tp.Union[type, str, Predicate, bool, ellipsis, None] |
| 32 | +Filter = tp.Union[FilterLiteral, tuple['Filter', ...], list['Filter']] |
| 33 | + |
| 34 | + |
| 35 | + |
| 36 | + |
| 37 | + |
| 38 | +def to_predicate(filter_: Filter) -> Predicate: |
| 39 | + """Converts a Filter to a predicate function.""" |
| 40 | + |
| 41 | + if isinstance(filter_, str): |
| 42 | + return WithTag(filter_) |
| 43 | + elif isinstance(filter_, type): |
| 44 | + return OfType(filter_) |
| 45 | + elif isinstance(filter_, bool): |
| 46 | + if filter_: |
| 47 | + return Everything() |
| 48 | + else: |
| 49 | + return Nothing() |
| 50 | + elif filter_ is Ellipsis: |
| 51 | + return Everything() |
| 52 | + elif filter_ is None: |
| 53 | + return Nothing() |
| 54 | + elif callable(filter_): |
| 55 | + return filter_ |
| 56 | + elif isinstance(filter_, (list, tuple)): |
| 57 | + return Any(*filter_) |
| 58 | + else: |
| 59 | + raise TypeError(f'Invalid collection filter: {filter_:!r}. ') |
| 60 | + |
| 61 | + |
| 62 | +def filters_to_predicates( |
| 63 | + filters: tp.Sequence[Filter], |
| 64 | +) -> tuple[Predicate, ...]: |
| 65 | + for i, filter_ in enumerate(filters): |
| 66 | + if filter_ in (..., True) and i != len(filters) - 1: |
| 67 | + remaining_filters = filters[i + 1 :] |
| 68 | + if not all(f in (..., True) for f in remaining_filters): |
| 69 | + raise ValueError( |
| 70 | + '`...` or `True` can only be used as the last filters, ' |
| 71 | + f'got {filter_} it at index {i}.' |
| 72 | + ) |
| 73 | + return tuple(map(to_predicate, filters)) |
| 74 | + |
| 75 | + |
| 76 | +class HasTag(tp.Protocol): |
| 77 | + tag: str |
| 78 | + |
| 79 | + |
| 80 | +def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]: |
| 81 | + return hasattr(x, 'tag') |
| 82 | + |
| 83 | + |
| 84 | +@dataclasses.dataclass(frozen=True) |
| 85 | +class WithTag: |
| 86 | + tag: str |
| 87 | + |
| 88 | + def __call__(self, path: PathParts, x: tp.Any): |
| 89 | + return _has_tag(x) and x.tag == self.tag |
| 90 | + |
| 91 | + def __repr__(self): |
| 92 | + return f'WithTag({self.tag!r})' |
| 93 | + |
| 94 | + |
| 95 | +@dataclasses.dataclass(frozen=True) |
| 96 | +class PathContains: |
| 97 | + key: Key |
| 98 | + |
| 99 | + def __call__(self, path: PathParts, x: tp.Any): |
| 100 | + return self.key in path |
| 101 | + |
| 102 | + def __repr__(self): |
| 103 | + return f'PathContains({self.key!r})' |
| 104 | + |
| 105 | + |
| 106 | +class PathIn: |
| 107 | + |
| 108 | + def __init__(self, *paths: PathParts): |
| 109 | + self.paths = frozenset(paths) |
| 110 | + |
| 111 | + def __call__(self, path: PathParts, x: tp.Any): |
| 112 | + return path in self.paths |
| 113 | + |
| 114 | + def __repr__(self): |
| 115 | + paths_repr = ','.join(map(repr, self.paths)) |
| 116 | + return f'PathIn({paths_repr})' |
| 117 | + |
| 118 | + def __eq__(self, other): |
| 119 | + return isinstance(other, PathIn) and self.paths == other.paths |
| 120 | + |
| 121 | + def __hash__(self): |
| 122 | + return hash(self.paths) |
| 123 | + |
| 124 | + |
| 125 | +@dataclasses.dataclass(frozen=True) |
| 126 | +class OfType: |
| 127 | + type: type |
| 128 | + |
| 129 | + def __call__(self, path: PathParts, x: tp.Any): |
| 130 | + return isinstance(x, self.type) or ( |
| 131 | + hasattr(x, 'type') and issubclass(x.type, self.type) |
| 132 | + ) |
| 133 | + |
| 134 | + def __repr__(self): |
| 135 | + return f'OfType({self.type!r})' |
| 136 | + |
| 137 | + |
| 138 | +class Any: |
| 139 | + |
| 140 | + def __init__(self, *filters: Filter): |
| 141 | + self.predicates = tuple( |
| 142 | + to_predicate(collection_filter) for collection_filter in filters |
| 143 | + ) |
| 144 | + |
| 145 | + def __call__(self, path: PathParts, x: tp.Any) -> bool: |
| 146 | + return any(predicate(path, x) for predicate in self.predicates) |
| 147 | + |
| 148 | + def __repr__(self): |
| 149 | + return f'Any({", ".join(map(repr, self.predicates))})' |
| 150 | + |
| 151 | + def __eq__(self, other): |
| 152 | + return isinstance(other, Any) and self.predicates == other.predicates |
| 153 | + |
| 154 | + def __hash__(self): |
| 155 | + return hash(self.predicates) |
| 156 | + |
| 157 | + |
| 158 | +class All: |
| 159 | + |
| 160 | + def __init__(self, *filters: Filter): |
| 161 | + self.predicates = tuple( |
| 162 | + to_predicate(collection_filter) for collection_filter in filters |
| 163 | + ) |
| 164 | + |
| 165 | + def __call__(self, path: PathParts, x: tp.Any) -> bool: |
| 166 | + return all(predicate(path, x) for predicate in self.predicates) |
| 167 | + |
| 168 | + def __repr__(self): |
| 169 | + return f'All({", ".join(map(repr, self.predicates))})' |
| 170 | + |
| 171 | + def __eq__(self, other): |
| 172 | + return isinstance(other, All) and self.predicates == other.predicates |
| 173 | + |
| 174 | + def __hash__(self): |
| 175 | + return hash(self.predicates) |
| 176 | + |
| 177 | + |
| 178 | +class Not: |
| 179 | + |
| 180 | + def __init__(self, collection_filter: Filter, /): |
| 181 | + self.predicate = to_predicate(collection_filter) |
| 182 | + |
| 183 | + def __call__(self, path: PathParts, x: tp.Any) -> bool: |
| 184 | + return not self.predicate(path, x) |
| 185 | + |
| 186 | + def __repr__(self): |
| 187 | + return f'Not({self.predicate!r})' |
| 188 | + |
| 189 | + def __eq__(self, other): |
| 190 | + return isinstance(other, Not) and self.predicate == other.predicate |
| 191 | + |
| 192 | + def __hash__(self): |
| 193 | + return hash(self.predicate) |
| 194 | + |
| 195 | + |
| 196 | +class Everything: |
| 197 | + |
| 198 | + def __call__(self, path: PathParts, x: tp.Any) -> bool: |
| 199 | + return True |
| 200 | + |
| 201 | + def __repr__(self): |
| 202 | + return 'Everything()' |
| 203 | + |
| 204 | + def __eq__(self, other): |
| 205 | + return isinstance(other, Everything) |
| 206 | + |
| 207 | + def __hash__(self): |
| 208 | + return hash(Everything) |
| 209 | + |
| 210 | + |
| 211 | +class Nothing: |
| 212 | + |
| 213 | + def __call__(self, path: PathParts, x: tp.Any) -> bool: |
| 214 | + return False |
| 215 | + |
| 216 | + def __repr__(self): |
| 217 | + return 'Nothing()' |
| 218 | + |
| 219 | + def __eq__(self, other): |
| 220 | + return isinstance(other, Nothing) |
| 221 | + |
| 222 | + def __hash__(self): |
| 223 | + return hash(Nothing) |
0 commit comments