From bd254bd03ce163af834e3113b81578e030f13492 Mon Sep 17 00:00:00 2001 From: Lee Jae Heon Date: Thu, 4 Sep 2025 18:10:50 +0900 Subject: [PATCH] simd_norm input parameters update current input parameters for simd_norm are not consistent with those of 'torch.norm'(refer to https://docs.pytorch.org/docs/stable/generated/torch.norm.html) thus I changed it accordingly * note that my pytorch does not require 'out' as input paramters so I didn't add it (return torch.norm(self, p, dim, keepdim, dtype=dtype)) --- mx/simd_ops.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/mx/simd_ops.py b/mx/simd_ops.py index cd9a553..712a4b2 100644 --- a/mx/simd_ops.py +++ b/mx/simd_ops.py @@ -18,6 +18,8 @@ simd_reduce_mean y = x.mean(dim) simd_norm y = (x**2).sum().sqrt() """ +from typing import Any, Optional, TYPE_CHECKING, Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -540,7 +542,14 @@ def simd_reduce_mean(in1, dim=None, keepdim=False, mx_specs=None): return SIMDMul.apply(s, 1/denom, mx_specs) -def simd_norm(in1, keepdim=False, mx_specs=None): +def simd_norm( + in1, + p: Optional[Union[float, str]] = "fro", + dim=None, + keepdim=False, + dtype=None, + mx_specs=None, +): """ Computes Frobenius norm sqrt(sum(in1**2)), same as torch.linalg.norm(in1) with no other args """ mx_assert_test(mx_specs)