Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LoRA support for Gemma #66

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gemma/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
class Einsum(nn.Module):
"""Einsum is a convenience module for parameterized tensor multiplication."""
shape: tuple[int, ...]
weight_name: str = 'w'

@nn.compact
def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
w = self.param('w', nn.initializers.normal(), self.shape)
w = self.param(self.weight_name, nn.initializers.normal(), self.shape)
return jnp.einsum(eqn, x, w)


Expand Down
37 changes: 18 additions & 19 deletions gemma/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,32 +248,31 @@ def __call__(self, x):
# Some versions use an alternate parameter ordering that
# transposes hidden_dim and features.
if self.transpose_gating_einsum:
w_gating = self.param(
'gating_einsum',
nn.initializers.normal(),
((2, self.hidden_dim, self.features)),
eq = '...F,NHF->...NH'
gating = layers.Einsum(
shape=(2, self.hidden_dim, self.features),
weight_name='gating_einsum',
)
w_gating = w_gating.transpose((0, 2, 1))
else:
w_gating = self.param(
'gating_einsum',
nn.initializers.normal(),
((2, self.features, self.hidden_dim)),
eq = '...F,NFH->...NH'
gating = layers.Einsum(
shape=(2, self.features, self.hidden_dim),
weight_name='gating_einsum',
)
ff_gate = jnp.dot(x, w_gating[0])
gate_value = nn.gelu(ff_gate)

# Up projection
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1
# Use the same scope for backwards compatibility with existing checkpoints
# created before using `layers.Einsum` here.
nn.share_scope(self, gating)
gate = gating(eq, x)
activations = nn.gelu(gate[..., 0, :]) * gate[..., 1, :]

# Down projection
w_linear = self.param(
'linear',
nn.initializers.zeros_init(),
(self.hidden_dim, self.features),
linear = layers.Einsum(
shape=(self.hidden_dim, self.features),
weight_name='linear',
)
outputs = jnp.dot(activations, w_linear)
nn.share_scope(self, linear)
outputs = linear('...H,HF->...F', activations)

return outputs

Expand Down