clean up
epwalsh committed Nov 19, 2024
1 parent 8592cee commit 0face65
Showing 4 changed files with 47 additions and 8 deletions.
14 changes: 14 additions & 0 deletions src/olmo_core/nn/attention.py
@@ -391,6 +391,20 @@ def forward(
         # shape: (batch_size, seq_len, d_model)
         return self.w_out(att)
 
+    @torch.no_grad()
+    def normalize_matrices(self):
+        """
+        Normalize the weights in all matrices. This should be called after each optimizer step;
+        the :class:`~olmo_core.train.callbacks.MatrixNormalizerCallback` will handle that for you.
+        """
+        self._normalize_matrix(self.w_q.weight)
+        self._normalize_matrix(self.w_k.weight)
+        self._normalize_matrix(self.w_v.weight)
+        self._normalize_matrix(self.w_out.weight, dim=0)
+
+    def _normalize_matrix(self, w: torch.Tensor, dim: int = -1):
+        w.copy_(l2_normalize(w, dim=dim))
+
 
 class FusedAttention(nn.Module):
     """
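For context: l2_normalize comes from olmo_core.nn.functional (its import is visible in the feed_forward.py hunk below). Its definition is not part of this diff; a minimal sketch consistent with how it is used here, assuming it simply divides by the L2 norm along dim with a small epsilon for safety, would be:

import torch

def l2_normalize(w: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    # Divide each slice along `dim` by its L2 norm (eps guards against zero slices).
    return w / w.norm(p=2, dim=dim, keepdim=True).clamp_min(eps)

w = torch.randn(8, 4)
w.copy_(l2_normalize(w))
print(w.norm(dim=-1))  # all ones, up to floating-point error

With dim=-1 (the default) each row of a Linear weight, i.e. one output unit's fan-in, lands on the unit sphere; dim=0, used for w_out and w2 above, normalizes columns instead.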
14 changes: 14 additions & 0 deletions src/olmo_core/nn/feed_forward.py
@@ -7,6 +7,7 @@

 from ..config import Config, DType, StrEnum
 from ..exceptions import OLMoConfigurationError
+from .functional import l2_normalize
 
 __all__ = ["FeedForwardConfig", "FeedForwardType", "FeedForward", "NormalizedFeedForward"]

@@ -130,3 +131,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         sw1 = self.sw1 * ((self.sw_init_value / self.sw_init_scaling) * self.sqrt_d_model)
         sw3 = self.sw3 * (self.sw_init_value / self.sw_init_scaling)
         return self.w2(F.silu(sw1 * self.w1(x)) * (sw3 * self.w3(x)))
+
+    @torch.no_grad()
+    def normalize_matrices(self):
+        """
+        Normalize the weights in all matrices. This should be called after each optimizer step;
+        the :class:`~olmo_core.train.callbacks.MatrixNormalizerCallback` will handle that for you.
+        """
+        self._normalize_matrix(self.w1.weight)
+        self._normalize_matrix(self.w2.weight, dim=0)
+        self._normalize_matrix(self.w3.weight)
+
+    def _normalize_matrix(self, w: torch.Tensor, dim: int = -1):
+        w.copy_(l2_normalize(w, dim=dim))
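As the docstrings note, normalize_matrices is meant to run after every optimizer step, with MatrixNormalizerCallback doing the calling during training. The callback itself is not shown in this diff; the following toy loop, with stand-in model and optimizer, illustrates the intended call order:

import torch
import torch.nn as nn

# Toy stand-ins for the real model and optimizer; the point is the call order.
model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):  # toy "training" steps
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    with torch.no_grad():  # what normalize_matrices() does for each weight matrix
        model.weight.copy_(model.weight / model.weight.norm(dim=-1, keepdim=True))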
15 changes: 15 additions & 0 deletions src/olmo_core/nn/transformer/block.py
@@ -321,6 +321,21 @@ def forward(

         return h
 
+
+    @torch.no_grad()
+    def normalize_matrices(self):
+        """
+        Normalize the weights in all matrices. This should be called after each optimizer step;
+        the :class:`~olmo_core.train.callbacks.MatrixNormalizerCallback` will handle that for you.
+        """
+        if hasattr(self.attention, "normalize_matrices"):
+            self.attention.normalize_matrices()
+
+        if hasattr(self.feed_forward, "normalize_matrices"):
+            self.feed_forward.normalize_matrices()
+
+    def _normalize_matrix(self, w: torch.Tensor, dim: int = -1):
+        w.copy_(l2_normalize(w, dim=dim))
 
 
 class MoETransformerBlock(TransformerBlockBase):
     """
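The hasattr checks make the block-level method a safe no-op for sub-modules that are not normalized variants: a plain Attention or FeedForward simply lacks a normalize_matrices attribute and is skipped. A self-contained sketch of the same duck-typing pattern:

import torch.nn as nn

class PlainFF(nn.Module):
    pass  # no normalize_matrices, so it gets skipped

class NormalizedFF(nn.Module):
    def normalize_matrices(self):
        print("normalized")

for module in (PlainFF(), NormalizedFF()):
    if hasattr(module, "normalize_matrices"):
        module.normalize_matrices()  # only prints for NormalizedFF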
12 changes: 4 additions & 8 deletions src/olmo_core/nn/transformer/model.py
@@ -496,14 +496,8 @@ def normalize_matrices(self):
         self._normalize_matrix(self.embeddings.weight)
 
         for block in self.blocks:
-            self._normalize_matrix(block.attention.w_q.weight)
-            self._normalize_matrix(block.attention.w_k.weight)
-            self._normalize_matrix(block.attention.w_v.weight)
-            self._normalize_matrix(block.attention.w_out.weight, dim=0)
-
-            self._normalize_matrix(block.feed_forward.w1.weight)
-            self._normalize_matrix(block.feed_forward.w2.weight, dim=0)
-            self._normalize_matrix(block.feed_forward.w3.weight)
+            if hasattr(block, "normalize_matrices"):
+                block.normalize_matrices()
 
         if self.w_out is not None:
             self._normalize_matrix(self.w_out.weight)
@@ -518,9 +512,11 @@ def forward(
         max_doc_lens: Optional[Sequence[int]] = None,
     ) -> torch.Tensor:
         out = super().forward(input_ids, doc_lens=doc_lens, max_doc_lens=max_doc_lens)
+
         if self.w_out is not None:
             sz = self.sz * (self.sz_init_value / self.sz_init_scaling)
             out = sz * out
+
         return out
 
     def apply_compile(self):
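The sz rescaling in forward mirrors the sw1/sw3 trick in the feed-forward hunk: the learnable scale is stored at the magnitude sz_init_scaling but applied at its effective magnitude, which starts at sz_init_value. A small numeric sketch with assumed init values (the real ones come from the model config):

import torch

# Assumed hyperparameters for illustration only.
sz_init_value, sz_init_scaling = 1.0, 1.0 / 32

# The stored parameter starts at sz_init_scaling...
sz_param = torch.full((8,), sz_init_scaling)

# ...but its effective value at use time starts at sz_init_value.
effective_sz = sz_param * (sz_init_value / sz_init_scaling)
print(effective_sz)  # tensor of ones: outputs pass through unscaled at init

out = torch.randn(2, 8)
out = effective_sz * out  # per-dimension scaling of the final outputs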
