Skip to content

Commit e9abb0a

Browse files
ConchylicultorThe kauldron Authors
authored andcommitted
Wrap the FeedForward layers inside Einsum
PiperOrigin-RevId: 707468281
1 parent c4cc65b commit e9abb0a

File tree

5 files changed

+384
-1
lines changed

5 files changed

+384
-1
lines changed

kauldron/optim/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
"""Optimizers etc."""
1616

1717
# pylint: disable=g-importing-member
18+
19+
from kauldron.optim._freeze import partial_updates
20+
from kauldron.optim._masks import exclude
21+
from kauldron.optim._masks import select
1822
from kauldron.optim.combine import named_chain
1923
from kauldron.optim.transform import decay_to_init
20-
# pylint: enable=g-importing-memberfrom

kauldron/optim/_freeze.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2024 The kauldron Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Freeze utils."""
16+
17+
from collections.abc import Callable
18+
import functools
19+
from typing import Any
20+
21+
import jax
22+
import optax
23+
24+
_PyTree = Any
25+
26+
27+
def partial_updates(
28+
optimizer: optax.GradientTransformation,
29+
mask: _PyTree | Callable[[_PyTree], _PyTree],
30+
) -> optax.GradientTransformation:
31+
"""Applies the optimizer to a subset of the parameters.
32+
33+
Args:
34+
optimizer: The optimizer to use.
35+
mask: A tree or callable returning a tree of bools to apply the optimizer
36+
to.
37+
38+
Returns:
39+
The wrapped optimizer.
40+
"""
41+
42+
return optax.multi_transform(
43+
{
44+
'train': optimizer,
45+
'freeze': optax.set_to_zero(),
46+
},
47+
functools.partial(_make_labels, mask=mask),
48+
)
49+
50+
51+
def _make_labels(tree, mask):
52+
if callable(mask):
53+
mask = mask(tree)
54+
return jax.tree.map(lambda x: 'train' if x else 'freeze', mask)

kauldron/optim/_freeze_test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2024 The kauldron Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax.numpy as jnp
16+
from kauldron import kd
17+
import optax
18+
19+
20+
def test_partial_updates():
21+
optimizer = kd.optim.partial_updates(
22+
optax.adam(learning_rate=1e-3),
23+
mask=kd.optim.select('lora'),
24+
)
25+
26+
params = {
27+
'a': {
28+
'lora': {
29+
'x': jnp.zeros((2,)),
30+
'y': jnp.zeros((2,)),
31+
}
32+
},
33+
'x': jnp.zeros((2,)),
34+
'y': jnp.zeros((2,)),
35+
}
36+
37+
assert kd.optim._freeze._make_labels(params, kd.optim.select('lora')) == {
38+
'a': {
39+
'lora': {
40+
'x': 'train',
41+
'y': 'train',
42+
}
43+
},
44+
'x': 'freeze',
45+
'y': 'freeze',
46+
}
47+
48+
# TODO(epot): Could check the state params is empty for frozen params.
49+
optimizer.init({
50+
'a': {
51+
'lora': {
52+
'x': jnp.zeros((2,)),
53+
'y': jnp.zeros((2,)),
54+
}
55+
},
56+
'x': jnp.zeros((2,)),
57+
'y': jnp.zeros((2,)),
58+
})

kauldron/optim/_masks.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2024 The kauldron Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Masks utils."""
16+
17+
from collections.abc import Callable, Sequence
18+
import re
19+
from typing import Any
20+
21+
import jax
22+
23+
_PyTree = Any
24+
25+
26+
# Improvements:
27+
# * Could add `exclude=` kwargs, similar to `glob()`.
28+
29+
30+
def select(pattern: str | Sequence[str]) -> Callable[[_PyTree], _PyTree]:
31+
r"""Create a mask which selects only the sub-pytree matching the pattern.
32+
33+
* `xx` will match all `{'xx': ...}` dict anywhere inside the tree. Note that
34+
the match is strict, so `xx` will NOT match `{'xxyy': }`
35+
* `xx.yy` will match `{'xx': {'yy': ...}}` dict
36+
* Regex are supported, when using regex, make sure to escape `.` (e.g.
37+
`xx\.yy[0-9]+`)
38+
39+
Example:
40+
41+
```python
42+
mask_fn = kg.optim.select("lora")
43+
44+
mask_fn({
45+
'layer0': {
46+
'lora': {
47+
'a': jnp.zeros(),
48+
'b': jnp.zeros(),
49+
},
50+
'weights': jnp.zeros(),
51+
'bias': jnp.zeros(),
52+
}
53+
}) == {
54+
'layer0': {
55+
'lora': {
56+
'a': True,
57+
'b': True,
58+
},
59+
'weights': False,
60+
'bias': False,
61+
}
62+
}
63+
```
64+
65+
Args:
66+
pattern: The pattern to include. Everything else will be `False`.
67+
68+
Returns:
69+
The optax mask factory.
70+
"""
71+
72+
# Convert the pattern to a regex.
73+
if isinstance(pattern, str):
74+
pattern = [pattern]
75+
76+
pattern_regexes = [_make_regex(p) for p in pattern]
77+
78+
def _path_match_pattern(path: jax.tree_util.KeyPath) -> bool:
79+
path_str = ".".join(_jax_key_entry_to_str(p) for p in path)
80+
return any(bool(p.search(path_str)) for p in pattern_regexes)
81+
82+
def _make_mask(tree: _PyTree) -> _PyTree:
83+
# TODO(epot): Replace by `jax.tree.flatten_with_path` once Colab is updated
84+
leaves_with_path, treedef = jax.tree_util.tree_flatten_with_path(tree)
85+
86+
# Parse each leaves
87+
leaves = []
88+
for path, _ in leaves_with_path:
89+
leaves.append(_path_match_pattern(path))
90+
91+
# Restore the tree structure.
92+
return jax.tree.unflatten(treedef, leaves)
93+
94+
return _make_mask
95+
96+
97+
def exclude(pattern: str | Sequence[str]) -> Callable[[_PyTree], _PyTree]:
98+
"""Create a mask which selects all nodes except the ones matching the pattern.
99+
100+
This is the inverse of `select()`.
101+
102+
Example:
103+
104+
```python
105+
optax.masked(
106+
optax.set_to_zero(),
107+
kd.optim.exclude("lora"), # Only `lora` weights are trained.
108+
)
109+
```
110+
111+
Args:
112+
pattern: The pattern to exclude. See `select()` for more details.
113+
114+
Returns:
115+
The optax mask factory.
116+
"""
117+
make_select_mask = select(pattern)
118+
119+
def _make_mask(tree: _PyTree) -> _PyTree:
120+
# Invert the select mask.
121+
tree = make_select_mask(tree)
122+
return jax.tree.map(lambda x: not x, tree)
123+
124+
return _make_mask
125+
126+
127+
_REGEX_SPECIAL_CHARS = set("()[]?+*^$|\\")
128+
129+
130+
def _make_regex(pattern: str) -> re.Pattern[str]:
131+
# Auto-detect regex and forward them as-is.
132+
if any(c in _REGEX_SPECIAL_CHARS for c in pattern):
133+
pass
134+
else: # Otherwise, escape special characters (`.`).
135+
pattern = re.escape(pattern)
136+
137+
pattern = rf"(?:^|\.){pattern}(?:$|\.)"
138+
return re.compile(pattern)
139+
140+
141+
def _jax_key_entry_to_str(
142+
jax_key_entry: jax.tree_util.KeyEntry,
143+
) -> str:
144+
"""Convert a JaxKeyEntry into a valid `kontext.Path` element."""
145+
match jax_key_entry:
146+
case jax.tree_util.DictKey(key):
147+
return key
148+
case _:
149+
raise TypeError(f"Unknown key entry type {type(jax_key_entry)}")

kauldron/optim/_masks_test.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2024 The kauldron Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from kauldron import kd
16+
from kauldron.optim import _masks
17+
18+
19+
def test_select():
20+
# Check the regex is restricted to the exact path.
21+
assert kd.optim.select("lora")({
22+
"lora": 0,
23+
"notlora": 0,
24+
"lora.more": 0,
25+
"loranot.more": 0,
26+
"notlora.more": 0,
27+
"more.lora": 0,
28+
"more.notlora": 0,
29+
"more.lora.more": 0,
30+
"more.notlora.more": 0,
31+
}) == {
32+
"lora": True,
33+
"notlora": False,
34+
"lora.more": True,
35+
"loranot.more": False,
36+
"notlora.more": False,
37+
"more.lora": True,
38+
"more.notlora": False,
39+
"more.lora.more": True,
40+
"more.notlora.more": False,
41+
}
42+
43+
# Exclude returns the opossite mask.
44+
assert kd.optim.exclude("lora")({
45+
"lora": 0,
46+
"notlora": 0,
47+
"lora.more": 0,
48+
"loranot.more": 0,
49+
"notlora.more": 0,
50+
"more.lora": 0,
51+
"more.notlora": 0,
52+
"more.lora.more": 0,
53+
"more.notlora.more": 0,
54+
}) == {
55+
"lora": False,
56+
"notlora": True,
57+
"lora.more": False,
58+
"loranot.more": True,
59+
"notlora.more": True,
60+
"more.lora": False,
61+
"more.notlora": True,
62+
"more.lora.more": False,
63+
"more.notlora.more": True,
64+
}
65+
66+
# Test that a `.` in the path is properly escaped.
67+
assert kd.optim.select("lora.more")({
68+
"lora": 0,
69+
"loraxmore": 0,
70+
"lora.more": 0,
71+
"more.loraxmore.more": 0,
72+
"more.lora.more.more": 0,
73+
}) == {
74+
"lora": False,
75+
"loraxmore": False,
76+
"lora.more": True,
77+
"more.loraxmore.more": False,
78+
"more.lora.more.more": True,
79+
}
80+
81+
# Test that the select works on nested tree
82+
assert kd.optim.select("lora.more")({
83+
"lora": {
84+
"more": {
85+
"x": 0,
86+
"y": 0,
87+
},
88+
"notmore": 0,
89+
},
90+
"y": {"lora": {"more": 0}},
91+
"z": 0,
92+
}) == {
93+
"lora": {
94+
"more": {
95+
"x": True,
96+
"y": True,
97+
},
98+
"notmore": False,
99+
},
100+
"y": {"lora": {"more": True}},
101+
"z": False,
102+
}
103+
104+
# Tests that regex are properly escaped
105+
assert kd.optim.select("lora[0-9]+")({
106+
"lora00": 0,
107+
"lora1": 0,
108+
"lora1x": 0,
109+
"lora1": 0,
110+
"xx.lora": 0,
111+
"xx.lora3.aa": 0,
112+
}) == {
113+
"lora00": True,
114+
"lora1": True,
115+
"lora1x": False,
116+
"lora1": True,
117+
"xx.lora": False,
118+
"xx.lora3.aa": True,
119+
}

0 commit comments

Comments
 (0)