Skip to content

Commit

Permalink
Wrap the FeedForward layers inside Einsum
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707468281
  • Loading branch information
Conchylicultor authored and The gemma Authors committed Dec 18, 2024
1 parent af38d6e commit 1db24de
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
3 changes: 2 additions & 1 deletion gemma/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
class Einsum(nn.Module):
"""Einsum is a convenience module for parameterized tensor multiplication."""
shape: tuple[int, ...]
weight_name: str = 'w'

@nn.compact
def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
w = self.param('w', nn.initializers.normal(), self.shape)
w = self.param(self.weight_name, nn.initializers.normal(), self.shape)
return jnp.einsum(eqn, x, w)


Expand Down
37 changes: 18 additions & 19 deletions gemma/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,32 +248,31 @@ def __call__(self, x):
# Some versions use an alternate parameter ordering that
# transposes hidden_dim and features.
if self.transpose_gating_einsum:
w_gating = self.param(
'gating_einsum',
nn.initializers.normal(),
((2, self.hidden_dim, self.features)),
eq = '...F,NHF->...,NH'
gating = layers.Einsum(
shape=(2, self.hidden_dim, self.features),
weight_name='gating_einsum',
)
w_gating = w_gating.transpose((0, 2, 1))
else:
w_gating = self.param(
'gating_einsum',
nn.initializers.normal(),
((2, self.features, self.hidden_dim)),
eq = '...F,NFH->...NH'
gating = layers.Einsum(
shape=(2, self.features, self.hidden_dim),
weight_name='gating_einsum',
)
ff_gate = jnp.dot(x, w_gating[0])
gate_value = nn.gelu(ff_gate)

# Up projection
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1
# Use the same scope for backwards compatibility with existing checkpoints
# created before using `layers.Einsum` here.
nn.share_scope(self, gating)
gate = gating(eq, x)
activations = nn.gelu(gate[..., 0, :]) * gate[..., 1, :]

# Down projection
w_linear = self.param(
'linear',
nn.initializers.zeros_init(),
(self.hidden_dim, self.features),
linear = layers.Einsum(
shape=(self.hidden_dim, self.features),
weight_name='linear',
)
outputs = jnp.dot(activations, w_linear)
nn.share_scope(self, linear)
outputs = linear('...H,HF->...F', activations)

return outputs

Expand Down

0 comments on commit 1db24de

Please sign in to comment.