diff --git a/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/csrc/BatchRulesBinaryOps.cpp index 2f5e9ae305..7367f43a77 100644 --- a/functorch/csrc/BatchRulesBinaryOps.cpp +++ b/functorch/csrc/BatchRulesBinaryOps.cpp @@ -183,6 +183,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { UNARY_POINTWISE(clamp_max); POINTWISE_BOXED(clamp_max_); + UNARY_POINTWISE(clip); + POINTWISE_BOXED(clip.Tensor); + // Commented out so we have a test op // BINARY_SCALAR_2(copysign, Tensor, Scalar); BINARY_SCALAR_2(div, Tensor, Scalar); diff --git a/functorch/csrc/BatchingRegistrations.cpp b/functorch/csrc/BatchingRegistrations.cpp index 02f6cef3b0..1b92543a70 100644 --- a/functorch/csrc/BatchingRegistrations.cpp +++ b/functorch/csrc/BatchingRegistrations.cpp @@ -888,7 +888,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { m.impl("view_as", native::view_as); // composite wrt autograd m.impl("addmm", addmm_batching_rule); - // clamp operations // unary pointwise, out-of-place, no additional arguments. #define TO_BATCHING_RULE(name, ...) \