diff --git a/mx/simd_ops.py b/mx/simd_ops.py index cd9a553..712a4b2 100644 --- a/mx/simd_ops.py +++ b/mx/simd_ops.py @@ -18,6 +18,8 @@ simd_reduce_mean y = x.mean(dim) simd_norm y = (x**2).sum().sqrt() """ +from typing import Any, Optional, TYPE_CHECKING, Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -540,7 +542,14 @@ def simd_reduce_mean(in1, dim=None, keepdim=False, mx_specs=None): return SIMDMul.apply(s, 1/denom, mx_specs) -def simd_norm(in1, keepdim=False, mx_specs=None): +def simd_norm( + in1, + p: Optional[Union[float, str]] = "fro", + dim=None, + keepdim=False, + dtype=None, + mx_specs=None, +): """ Computes Frobenius norm sqrt(sum(in1**2)), same as torch.linalg.norm(in1) with no other args """ mx_assert_test(mx_specs)