diff --git a/mx/elemwise_ops.py b/mx/elemwise_ops.py index 32d9e9b..fb9468b 100644 --- a/mx/elemwise_ops.py +++ b/mx/elemwise_ops.py @@ -41,7 +41,9 @@ def _safe_rshift(x, bits, exp): if exp is None: return x / (2**bits) else: - return x / (2**bits) * (2 ** exp) + out = x / (2**bits) * (2 ** exp) + out[torch.isnan(out)] = 0. + return out def _round_mantissa(A, bits, round, clamp=False): @@ -162,7 +164,7 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest', if not custom_cuda: out[A == float("Inf")] = float("Inf") out[A == -float("Inf")] = -float("Inf") - out[A == float("NaN")] = float("NaN") + out[torch.isnan(A)] = float("NaN") if A_is_sparse: output = torch.sparse_coo_tensor(sparse_A.indices(), output,