-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Wrap the FeedForward layers inside
Einsum
PiperOrigin-RevId: 707468281
- Loading branch information
1 parent
c4cc65b
commit bb23e4f
Showing
5 changed files
with
384 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright 2024 The kauldron Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Freeze utils.""" | ||
|
||
from collections.abc import Callable | ||
import functools | ||
from typing import Any | ||
|
||
import jax | ||
import optax | ||
|
||
_PyTree = Any | ||
|
||
|
||
def partial_updates(
    optimizer: optax.GradientTransformation,
    mask: _PyTree | Callable[[_PyTree], _PyTree],
) -> optax.GradientTransformation:
  """Restricts an optimizer to the parameters selected by a mask.

  Leaves whose mask entry is truthy are labelled `'train'` and routed to
  `optimizer`; every other leaf is labelled `'freeze'` and routed to
  `optax.set_to_zero()`.

  Args:
    optimizer: The optimizer applied to the selected parameters.
    mask: A tree of bools, or a callable returning such a tree when called
      with the params tree. Truthy entries mark the parameters to train.

  Returns:
    The wrapped optimizer.
  """
  # One transform per label; `label_fn` assigns a label to every leaf.
  transforms = {
      'train': optimizer,
      'freeze': optax.set_to_zero(),
  }
  label_fn = functools.partial(_make_labels, mask=mask)
  return optax.multi_transform(transforms, label_fn)
|
||
|
||
def _make_labels(tree, mask):
  """Converts a boolean mask into a tree of `'train'` / `'freeze'` labels.

  Args:
    tree: The params tree, only used when `mask` is a callable.
    mask: A tree of bools, or a callable producing one from `tree`.

  Returns:
    A tree with the structure of the (resolved) mask, where truthy leaves are
    replaced by `'train'` and falsy leaves by `'freeze'`.
  """
  resolved_mask = mask(tree) if callable(mask) else mask

  def _to_label(is_trainable):
    if is_trainable:
      return 'train'
    return 'freeze'

  return jax.tree.map(_to_label, resolved_mask)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2024 The kauldron Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import jax.numpy as jnp | ||
from kauldron import kd | ||
import optax | ||
|
||
|
||
def test_partial_updates():
  """Checks labelling and state init of `kd.optim.partial_updates`."""
  optimizer = kd.optim.partial_updates(
      optax.adam(learning_rate=1e-3),
      mask=kd.optim.select('lora'),
  )

  # Only the leaves nested under a `lora` key should be trainable.
  lora_params = {
      'x': jnp.zeros((2,)),
      'y': jnp.zeros((2,)),
  }
  params = {
      'a': {'lora': lora_params},
      'x': jnp.zeros((2,)),
      'y': jnp.zeros((2,)),
  }

  expected_labels = {
      'a': {
          'lora': {
              'x': 'train',
              'y': 'train',
          }
      },
      'x': 'freeze',
      'y': 'freeze',
  }
  labels = kd.optim._freeze._make_labels(params, kd.optim.select('lora'))
  assert labels == expected_labels

  # TODO(epot): Could check the state params is empty for frozen params.
  # Initializing the state must succeed on the same tree structure.
  optimizer.init(params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Copyright 2024 The kauldron Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Masks utils.""" | ||
|
||
from collections.abc import Callable, Sequence | ||
import re | ||
from typing import Any | ||
|
||
import jax | ||
|
||
_PyTree = Any | ||
|
||
|
||
# Improvements: | ||
# * Could add `exclude=` kwargs, similar to `glob()`. | ||
|
||
|
||
def select(pattern: str | Sequence[str]) -> Callable[[_PyTree], _PyTree]:
  r"""Builds a mask factory selecting the sub-trees whose path matches.

  Each leaf path is joined with `.` and tested against every pattern; a leaf
  is selected when at least one pattern matches.

  * `xx` matches any `{'xx': ...}` dict anywhere inside the tree. The match is
    strict, so `xx` will NOT match `{'xxyy': ...}`.
  * `xx.yy` matches `{'xx': {'yy': ...}}`.
  * Regexes are supported. When using a regex, make sure to escape `.` (e.g.
    `xx\.yy[0-9]+`).

  Example:

  ```python
  mask_fn = kd.optim.select("lora")

  mask_fn({
      'layer0': {
          'lora': {
              'a': jnp.zeros((2,)),
              'b': jnp.zeros((2,)),
          },
          'weights': jnp.zeros((2,)),
          'bias': jnp.zeros((2,)),
      }
  }) == {
      'layer0': {
          'lora': {
              'a': True,
              'b': True,
          },
          'weights': False,
          'bias': False,
      }
  }
  ```

  Args:
    pattern: The pattern (or patterns) to include. Everything else will be
      `False`.

  Returns:
    The optax mask factory.
  """
  # A single string is treated as a one-element list of patterns.
  patterns = [pattern] if isinstance(pattern, str) else pattern
  compiled = [_make_regex(p) for p in patterns]

  def _is_selected(path: jax.tree_util.KeyPath) -> bool:
    dotted_path = ".".join(_jax_key_entry_to_str(entry) for entry in path)
    for regex in compiled:
      if regex.search(dotted_path):
        return True
    return False

  def _make_mask(tree: _PyTree) -> _PyTree:
    # TODO(epot): Replace by `jax.tree.flatten_with_path` once Colab is updated
    paths_and_leaves, structure = jax.tree_util.tree_flatten_with_path(tree)

    # One bool per leaf, in flattening order, then restore the tree structure.
    flags = [_is_selected(path) for path, _ in paths_and_leaves]
    return jax.tree.unflatten(structure, flags)

  return _make_mask
|
||
|
||
def exclude(pattern: str | Sequence[str]) -> Callable[[_PyTree], _PyTree]:
  """Builds a mask factory selecting every leaf NOT matching the pattern.

  This is the inverse of `select()`: the mask produced by `select(pattern)` is
  computed first, then every leaf is negated.

  Example:

  ```python
  optax.masked(
      optax.set_to_zero(),
      kd.optim.exclude("lora"),  # Only `lora` weights are trained.
  )
  ```

  Args:
    pattern: The pattern to exclude. See `select()` for more details.

  Returns:
    The optax mask factory.
  """
  select_mask_fn = select(pattern)

  def _make_mask(tree: _PyTree) -> _PyTree:
    selected = select_mask_fn(tree)
    # Flip every leaf of the select mask.
    return jax.tree.map(lambda is_selected: not is_selected, selected)

  return _make_mask
|
||
|
||
# Characters whose presence makes a pattern be treated as a regex.
_REGEX_SPECIAL_CHARS = set("()[]?+*^$|\\")


def _make_regex(pattern: str) -> re.Pattern[str]:
  r"""Compiles a path pattern into a regex anchored on `.` boundaries.

  The pattern only matches whole path components: it has to start at the
  beginning of the path or right after a `.`, and end at the end of the path
  or right before a `.`.

  Args:
    pattern: Either a literal dotted path (e.g. `xx.yy`), which gets escaped,
      or a regex (detected by the presence of any `_REGEX_SPECIAL_CHARS`),
      which is forwarded as-is.

  Returns:
    The compiled regex, meant to be used with `.search()` on a dotted path.
  """
  # Auto-detect regexes and forward them as-is. Otherwise, escape special
  # characters (`.`) so the pattern is matched literally.
  if not any(c in _REGEX_SPECIAL_CHARS for c in pattern):
    pattern = re.escape(pattern)

  # The pattern is wrapped in a non-capturing group so the boundary anchors
  # apply to the whole pattern. Without it, a top-level alternation like
  # `lora|bias` would bind each anchor to a single alternative only.
  return re.compile(rf"(?:^|\.)(?:{pattern})(?:$|\.)")
|
||
|
||
def _jax_key_entry_to_str( | ||
jax_key_entry: jax.tree_util.KeyEntry, | ||
) -> str: | ||
"""Convert a JaxKeyEntry into a valid `kontext.Path` element.""" | ||
match jax_key_entry: | ||
case jax.tree_util.DictKey(key): | ||
return key | ||
case _: | ||
raise TypeError(f"Unknown key entry type {type(jax_key_entry)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright 2024 The kauldron Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from kauldron import kd | ||
from kauldron.optim import _masks | ||
|
||
|
||
def test_select():
  """Checks `select()` / `exclude()` on flat, nested and regex patterns."""
  # Flat tree whose keys already contain `.`: each key is the full path.
  flat_tree = {
      "lora": 0,
      "notlora": 0,
      "lora.more": 0,
      "loranot.more": 0,
      "notlora.more": 0,
      "more.lora": 0,
      "more.notlora": 0,
      "more.lora.more": 0,
      "more.notlora.more": 0,
  }

  # Check the regex is restricted to the exact path.
  assert kd.optim.select("lora")(flat_tree) == {
      "lora": True,
      "notlora": False,
      "lora.more": True,
      "loranot.more": False,
      "notlora.more": False,
      "more.lora": True,
      "more.notlora": False,
      "more.lora.more": True,
      "more.notlora.more": False,
  }

  # Exclude returns the opposite mask.
  assert kd.optim.exclude("lora")(flat_tree) == {
      "lora": False,
      "notlora": True,
      "lora.more": False,
      "loranot.more": True,
      "notlora.more": True,
      "more.lora": False,
      "more.notlora": True,
      "more.lora.more": False,
      "more.notlora.more": True,
  }

  # Test that a `.` in the pattern is properly escaped.
  assert kd.optim.select("lora.more")({
      "lora": 0,
      "loraxmore": 0,
      "lora.more": 0,
      "more.loraxmore.more": 0,
      "more.lora.more.more": 0,
  }) == {
      "lora": False,
      "loraxmore": False,
      "lora.more": True,
      "more.loraxmore.more": False,
      "more.lora.more.more": True,
  }

  # Test that the select works on nested tree
  assert kd.optim.select("lora.more")({
      "lora": {
          "more": {
              "x": 0,
              "y": 0,
          },
          "notmore": 0,
      },
      "y": {"lora": {"more": 0}},
      "z": 0,
  }) == {
      "lora": {
          "more": {
              "x": True,
              "y": True,
          },
          "notmore": False,
      },
      "y": {"lora": {"more": True}},
      "z": False,
  }

  # Test that regex patterns are forwarded as-is (not escaped) while still
  # being restricted to whole path components.
  assert kd.optim.select("lora[0-9]+")({
      "lora00": 0,
      "lora1": 0,
      "lora1x": 0,
      "xx.lora": 0,
      "xx.lora3.aa": 0,
  }) == {
      "lora00": True,
      "lora1": True,
      "lora1x": False,
      "xx.lora": False,
      "xx.lora3.aa": True,
  }