Skip to content

Commit

Permalink
Wrap the FeedForward layers inside Einsum
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707468281
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Dec 18, 2024
1 parent c4cc65b commit bb23e4f
Show file tree
Hide file tree
Showing 5 changed files with 384 additions and 1 deletion.
5 changes: 4 additions & 1 deletion kauldron/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"""Optimizers etc."""

# pylint: disable=g-importing-member

from kauldron.optim._freeze import partial_updates
from kauldron.optim._masks import exclude
from kauldron.optim._masks import select
from kauldron.optim.combine import named_chain
from kauldron.optim.transform import decay_to_init
# pylint: enable=g-importing-memberfrom
54 changes: 54 additions & 0 deletions kauldron/optim/_freeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Freeze utils."""

from collections.abc import Callable
import functools
from typing import Any

import jax
import optax

_PyTree = Any


def partial_updates(
optimizer: optax.GradientTransformation,
mask: _PyTree | Callable[[_PyTree], _PyTree],
) -> optax.GradientTransformation:
"""Applies the optimizer to a subset of the parameters.
Args:
optimizer: The optimizer to use.
mask: A tree or callable returning a tree of bools to apply the optimizer
to.
Returns:
The wrapped optimizer.
"""

return optax.multi_transform(
{
'train': optimizer,
'freeze': optax.set_to_zero(),
},
functools.partial(_make_labels, mask=mask),
)


def _make_labels(tree, mask):
if callable(mask):
mask = mask(tree)
return jax.tree.map(lambda x: 'train' if x else 'freeze', mask)
58 changes: 58 additions & 0 deletions kauldron/optim/_freeze_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
from kauldron import kd
import optax


def test_partial_updates():
optimizer = kd.optim.partial_updates(
optax.adam(learning_rate=1e-3),
mask=kd.optim.select('lora'),
)

params = {
'a': {
'lora': {
'x': jnp.zeros((2,)),
'y': jnp.zeros((2,)),
}
},
'x': jnp.zeros((2,)),
'y': jnp.zeros((2,)),
}

assert kd.optim._freeze._make_labels(params, kd.optim.select('lora')) == {
'a': {
'lora': {
'x': 'train',
'y': 'train',
}
},
'x': 'freeze',
'y': 'freeze',
}

# TODO(epot): Could check the state params is empty for frozen params.
optimizer.init({
'a': {
'lora': {
'x': jnp.zeros((2,)),
'y': jnp.zeros((2,)),
}
},
'x': jnp.zeros((2,)),
'y': jnp.zeros((2,)),
})
149 changes: 149 additions & 0 deletions kauldron/optim/_masks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masks utils."""

from collections.abc import Callable, Sequence
import re
from typing import Any

import jax

_PyTree = Any


# Improvements:
# * Could add `exclude=` kwargs, similar to `glob()`.


def select(pattern: str | Sequence[str]) -> Callable[[_PyTree], _PyTree]:
r"""Create a mask which selects only the sub-pytree matching the pattern.
* `xx` will match all `{'xx': ...}` dict anywhere inside the tree. Note that
the match is strict, so `xx` will NOT match `{'xxyy': }`
* `xx.yy` will match `{'xx': {'yy': ...}}` dict
* Regex are supported, when using regex, make sure to escape `.` (e.g.
`xx\.yy[0-9]+`)
Example:
```python
mask_fn = kg.optim.select("lora")
mask_fn({
'layer0': {
'lora': {
'a': jnp.zeros(),
'b': jnp.zeros(),
},
'weights': jnp.zeros(),
'bias': jnp.zeros(),
}
}) == {
'layer0': {
'lora': {
'a': True,
'b': True,
},
'weights': False,
'bias': False,
}
}
```
Args:
pattern: The pattern to include. Everything else will be `False`.
Returns:
The optax mask factory.
"""

# Convert the pattern to a regex.
if isinstance(pattern, str):
pattern = [pattern]

pattern_regexes = [_make_regex(p) for p in pattern]

def _path_match_pattern(path: jax.tree_util.KeyPath) -> bool:
path_str = ".".join(_jax_key_entry_to_str(p) for p in path)
return any(bool(p.search(path_str)) for p in pattern_regexes)

def _make_mask(tree: _PyTree) -> _PyTree:
# TODO(epot): Replace by `jax.tree.flatten_with_path` once Colab is updated
leaves_with_path, treedef = jax.tree_util.tree_flatten_with_path(tree)

# Parse each leaves
leaves = []
for path, _ in leaves_with_path:
leaves.append(_path_match_pattern(path))

# Restore the tree structure.
return jax.tree.unflatten(treedef, leaves)

return _make_mask


def exclude(pattern: str | Sequence[str]) -> Callable[[_PyTree], _PyTree]:
"""Create a mask which selects all nodes except the ones matching the pattern.
This is the inverse of `select()`.
Example:
```python
optax.masked(
optax.set_to_zero(),
kd.optim.exclude("lora"), # Only `lora` weights are trained.
)
```
Args:
pattern: The pattern to exclude. See `select()` for more details.
Returns:
The optax mask factory.
"""
make_select_mask = select(pattern)

def _make_mask(tree: _PyTree) -> _PyTree:
# Invert the select mask.
tree = make_select_mask(tree)
return jax.tree.map(lambda x: not x, tree)

return _make_mask


_REGEX_SPECIAL_CHARS = set("()[]?+*^$|\\")


def _make_regex(pattern: str) -> re.Pattern[str]:
# Auto-detect regex and forward them as-is.
if any(c in _REGEX_SPECIAL_CHARS for c in pattern):
pass
else: # Otherwise, escape special characters (`.`).
pattern = re.escape(pattern)

pattern = rf"(?:^|\.){pattern}(?:$|\.)"
return re.compile(pattern)


def _jax_key_entry_to_str(
jax_key_entry: jax.tree_util.KeyEntry,
) -> str:
"""Convert a JaxKeyEntry into a valid `kontext.Path` element."""
match jax_key_entry:
case jax.tree_util.DictKey(key):
return key
case _:
raise TypeError(f"Unknown key entry type {type(jax_key_entry)}")
119 changes: 119 additions & 0 deletions kauldron/optim/_masks_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kauldron import kd
from kauldron.optim import _masks


def test_select():
# Check the regex is restricted to the exact path.
assert kd.optim.select("lora")({
"lora": 0,
"notlora": 0,
"lora.more": 0,
"loranot.more": 0,
"notlora.more": 0,
"more.lora": 0,
"more.notlora": 0,
"more.lora.more": 0,
"more.notlora.more": 0,
}) == {
"lora": True,
"notlora": False,
"lora.more": True,
"loranot.more": False,
"notlora.more": False,
"more.lora": True,
"more.notlora": False,
"more.lora.more": True,
"more.notlora.more": False,
}

# Exclude returns the opossite mask.
assert kd.optim.exclude("lora")({
"lora": 0,
"notlora": 0,
"lora.more": 0,
"loranot.more": 0,
"notlora.more": 0,
"more.lora": 0,
"more.notlora": 0,
"more.lora.more": 0,
"more.notlora.more": 0,
}) == {
"lora": False,
"notlora": True,
"lora.more": False,
"loranot.more": True,
"notlora.more": True,
"more.lora": False,
"more.notlora": True,
"more.lora.more": False,
"more.notlora.more": True,
}

# Test that a `.` in the path is properly escaped.
assert kd.optim.select("lora.more")({
"lora": 0,
"loraxmore": 0,
"lora.more": 0,
"more.loraxmore.more": 0,
"more.lora.more.more": 0,
}) == {
"lora": False,
"loraxmore": False,
"lora.more": True,
"more.loraxmore.more": False,
"more.lora.more.more": True,
}

# Test that the select works on nested tree
assert kd.optim.select("lora.more")({
"lora": {
"more": {
"x": 0,
"y": 0,
},
"notmore": 0,
},
"y": {"lora": {"more": 0}},
"z": 0,
}) == {
"lora": {
"more": {
"x": True,
"y": True,
},
"notmore": False,
},
"y": {"lora": {"more": True}},
"z": False,
}

# Tests that regex are properly escaped
assert kd.optim.select("lora[0-9]+")({
"lora00": 0,
"lora1": 0,
"lora1x": 0,
"lora1": 0,
"xx.lora": 0,
"xx.lora3.aa": 0,
}) == {
"lora00": True,
"lora1": True,
"lora1x": False,
"lora1": True,
"xx.lora": False,
"xx.lora3.aa": True,
}

0 comments on commit bb23e4f

Please sign in to comment.