Skip to content

Commit

Permalink
use vector norm
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Nov 19, 2024
1 parent 825eb26 commit 8592cee
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/olmo_core/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@


def l2_normalize(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
# NOTE: could also use F.normalize(), but that doesn't work with DTensor at the moment
return x / torch.linalg.norm(x, dim=dim, keepdim=True, dtype=torch.float32).type_as(x)
# NOTE: could also use F.normalize(), but that doesn't work with DTensor at the moment.
return x / torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32).type_as(x)

0 comments on commit 8592cee

Please sign in to comment.