backends/cadence/aot/ops_registrations.py: 6 additions & 6 deletions
@@ -53,7 +53,6 @@ def _validate_ref_impl_exists() -> None:
     # 1. be removed
     # 2. have a reference implementation added to ref_implementations.py
     _WARN_ONLY = {
-        "cadence::_softmax_f32_f32",
         "cadence::quantized_softmax.per_tensor",
         "cadence::quantized_softmax",
         "cadence::quantized_w8a32_gru",
@@ -640,10 +639,10 @@ def register_fake(
     "int sampling_ratio, bool aligned) -> (Tensor out)"
 )
 lib.define(
-    "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float) -> (Tensor out)"
+    "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float = None) -> (Tensor out)"
 )
 lib.define(
-    "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
+    "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float = None, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
 lib.define(
Expand Down Expand Up @@ -2652,12 +2651,13 @@ def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_meta(

@register_fake("cadence::_softmax_f32_f32")
def softmax_f32_f32_meta(
self: torch.Tensor,
input_tensor: torch.Tensor,
dim: int,
dtype: torch.dtype,
half_to_float: Optional[bool] = None,
) -> torch.Tensor:
return self.new_empty(self.size(), dtype=self.dtype)
assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
assert not half_to_float, "half_to_float is not supported"
return input_tensor.new_empty(input_tensor.size(), dtype=torch.float32)


@register_fake("cadence::quantized_softmax")
Expand Down
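Taken together, the ops_registrations.py changes drop "cadence::_softmax_f32_f32" from the _WARN_ONLY set (so _validate_ref_impl_exists now requires a reference implementation for it), give half_to_float a default of None in both schemas, and tighten the fake kernel to validate its input rather than echoing the input's dtype. A minimal sketch of how the fake kernel is exercised during tracing, assuming the Cadence op library has been imported so the op is registered (the FakeTensorMode usage here is standard PyTorch, not something this PR adds):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode, calling the op dispatches to softmax_f32_f32_meta,
# which now asserts a float32 input, rejects half_to_float=True, and returns
# an empty float32 tensor of the same shape for shape/dtype propagation.
with FakeTensorMode():
    x = torch.empty(2, 3, dtype=torch.float32)
    out = torch.ops.cadence._softmax_f32_f32(x, 1)
    assert out.shape == (2, 3) and out.dtype == torch.float32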
backends/cadence/aot/ref_implementations.py: 11 additions & 0 deletions
@@ -1979,3 +1979,14 @@ def linalg_svd(
     assert compute_uv
     U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver)
     return U.contiguous(), S.contiguous(), Vh.contiguous()
+
+
+@impl_tracked(m, "_softmax_f32_f32")
+def softmax_f32_f32(
+    input_tensor: torch.Tensor,
+    dim: int,
+    half_to_float: bool | None = None,
+) -> torch.Tensor:
+    assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
+    assert not half_to_float, "half_to_float is not supported"
+    return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32)
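Because the new reference implementation delegates straight to torch.nn.functional.softmax, its output should agree exactly with eager softmax. A quick sanity sketch of that equivalence, assuming the backends.cadence.aot modules have been imported so the op and its reference implementation are registered:

import torch

x = torch.randn(4, 8, dtype=torch.float32)
expected = torch.nn.functional.softmax(x, dim=-1)
actual = torch.ops.cadence._softmax_f32_f32(x, -1)
torch.testing.assert_close(actual, expected)
# Softmax outputs form a probability distribution along `dim`.
torch.testing.assert_close(actual.sum(dim=-1), torch.ones(4))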
backends/cadence/aot/tests/test_ref_implementations.py: 9 additions & 0 deletions
@@ -2885,3 +2885,12 @@ def test_quantized_layer_norm(self) -> None:
             output_scale,
             output_zero_point,
         )
+
+    def test_softmax_f32_f32(self) -> None:
+        # Just a wrapper around torch.nn.functional.softmax, so just ensure that it runs
+        input_tensor = torch.tensor(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32
+        )
+        output = torch.ops.cadence._softmax_f32_f32(input_tensor, dim=1)
+        self.assertEqual(output.dtype, torch.float32)
+        self.assertEqual(output.shape, input_tensor.shape)
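The added test is deliberately a smoke test for shape and dtype. A hypothetical companion test, not part of this PR, that would also exercise the new dtype assertion in ref_implementations.py (the method name is illustrative, and it assumes the assertion surfaces as an AssertionError through the op dispatch):

    def test_softmax_f32_f32_rejects_non_float32(self) -> None:
        # Hypothetical (not in the PR): the reference implementation asserts
        # input_tensor.dtype == torch.float32, so a float64 input should raise.
        bad_input = torch.tensor([[1.0, 2.0]], dtype=torch.float64)
        with self.assertRaises(AssertionError):
            torch.ops.cadence._softmax_f32_f32(bad_input, dim=1)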