Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def _validate_ref_impl_exists() -> None:
_WARN_ONLY = {
"cadence::quantized_softmax.per_tensor",
"cadence::quantized_softmax",
"cadence::quantized_w8a32_gru",
}

ref_impls = get_registered_ref_implementations()
Expand Down Expand Up @@ -2753,7 +2752,7 @@ def quantized_w8a32_gru_meta(
bias_hidden: torch.Tensor,
b_h_scale: float,
) -> torch.Tensor:
return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)


# Validate that all meta kernels have reference implementations
Expand Down
64 changes: 64 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,70 @@ def quantized_w8a32_linear(
return output


@impl_tracked(m, "quantized_w8a32_gru")
def quantized_w8a32_gru(
    inputs: torch.Tensor,
    hidden: torch.Tensor,
    weights_inputs: torch.Tensor,
    w_i_scale: float,
    weights_hidden: torch.Tensor,
    w_h_scale: float,
    bias_inputs: torch.Tensor,
    b_i_scale: float,
    bias_hidden: torch.Tensor,
    b_h_scale: float,
) -> torch.Tensor:
    """Reference implementation of one W8A32 GRU cell step.

    Weights and biases are symmetric per-tensor int8; activations are
    float32. Mirrors the HiFi C++ kernel, including its quirk of averaging
    the two bias scales before dequantizing either bias.

    Args:
        inputs: float32 input activations, (1, input_dim) or (input_dim,).
        hidden: float32 previous hidden state, (1, hidden_dim) or
            (hidden_dim,); hidden_dim must be a multiple of 4.
        weights_inputs: int8 input-to-hidden weights, (3*hidden_dim, input_dim).
        w_i_scale: dequant scale for weights_inputs.
        weights_hidden: int8 hidden-to-hidden weights, (3*hidden_dim, hidden_dim).
        w_h_scale: dequant scale for weights_hidden.
        bias_inputs: int8 input-side bias, (3*hidden_dim,).
        b_i_scale: dequant scale for bias_inputs.
        bias_hidden: int8 hidden-side bias, (3*hidden_dim,).
        b_h_scale: dequant scale for bias_hidden.

    Returns:
        float32 tensor of shape (2, hidden_dim): the new hidden state
        duplicated along dim 0 (matches the C++ kernel's output layout,
        and the meta kernel's (2, hidden_dim) contract).

    Raises:
        ValueError: if hidden is more than 2D, has a leading (batch) dim
            other than 1, or hidden_dim is not a multiple of 4; also if
            the computed state does not collapse to hidden_dim elements
            (e.g. inputs carried a batch dim > 1).
    """
    assert weights_inputs.dtype == torch.int8
    assert weights_hidden.dtype == torch.int8
    assert bias_inputs.dtype == torch.int8
    assert bias_hidden.dtype == torch.int8
    assert inputs.dtype == torch.float32
    assert hidden.dtype == torch.float32

    if len(hidden.shape) > 2:
        raise ValueError("Hidden state must be 2D or 1D")

    if len(hidden.shape) == 2 and hidden.shape[0] != 1:
        raise ValueError("Leading dimension of hidden state must be 1")

    # Work with a flat (hidden_dim,) state internally.
    hidden = hidden.view(-1)

    hidden_dim = hidden.shape[0]
    if (hidden_dim % 4) != 0:
        raise ValueError(
            "Hidden dimension must be a multiple of 4 for HiFi SIMD operations"
        )

    dequant_weights_inputs = weights_inputs.float() * w_i_scale
    dequant_weights_hidden = weights_hidden.float() * w_h_scale

    # C++ implementation averages the two bias scales.
    avg_bias_scale = (b_i_scale + b_h_scale) / 2
    dequant_bias_inputs = bias_inputs.float() * avg_bias_scale
    dequant_bias_hidden = bias_hidden.float() * avg_bias_scale

    # Gate pre-activations: gi from the input path, gh from the hidden path.
    gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs)
    gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden)

    i_r, i_z, i_n = gi.chunk(3, -1)
    h_r, h_z, h_n = gh.chunk(3, -1)

    reset_gate = torch.sigmoid(i_r + h_r)
    update_gate = torch.sigmoid(i_z + h_z)
    new_gate = torch.tanh(i_n + reset_gate * h_n)

    new_hidden = (1 - update_gate) * new_gate + update_gate * hidden

    # Fix: validate on the flattened result so the 1D-inputs / 1D-hidden
    # shapes (explicitly allowed above) aren't rejected by checks written
    # for the (1, hidden_dim) case. A batched `inputs` (batch > 1) still
    # fails here, exactly as the original check intended.
    new_hidden = new_hidden.reshape(-1)
    if new_hidden.shape[0] != hidden_dim:
        raise ValueError("Leading dimension of hidden state must be 1")

    return torch.stack([new_hidden, new_hidden], dim=0)


@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
def quantized_conv2d_nhwc_per_tensor(
input_tensor: torch.Tensor,
Expand Down
185 changes: 185 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2894,3 +2894,188 @@ def test_softmax_f32_f32(self) -> None:
output = torch.ops.cadence._softmax_f32_f32(input_tensor, dim=1)
self.assertEqual(output.dtype, torch.float32)
self.assertEqual(output.shape, input_tensor.shape)

@expand(
    [
        (
            "basic_hidden_dim_4",
            torch.tensor([[1.0, 2.0]], dtype=torch.float32),  # inputs: 1x2
            torch.tensor(
                [[0.5, 0.5, 0.5, 0.5]], dtype=torch.float32
            ),  # hidden: 1x4
            torch.ones(
                (12, 2), dtype=torch.int8
            ),  # weights_inputs: 12x2 (3*4 x input_dim=2)
            0.1,  # w_i_scale
            torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4 (3*4 x 4)
            0.1,  # w_h_scale
            torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
            0.1,  # b_i_scale
            torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
            0.1,  # b_h_scale
        ),
        (
            "invalid_batch_size_2",
            torch.tensor(
                [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=torch.float32
            ),  # inputs: 2x3
            torch.tensor(
                [[0.5, 0.5, 0.5, 0.5], [0.3, 0.3, 0.3, 0.3]], dtype=torch.float32
            ),  # hidden: 2x4
            torch.ones((12, 3), dtype=torch.int8),  # weights_inputs: 12x3
            0.1,  # w_i_scale
            torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
            0.1,  # w_h_scale
            torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
            0.1,  # b_i_scale
            torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
            0.1,  # b_h_scale
        ),
        (
            "non_zero_biases",
            torch.tensor([[1.0, 1.0]], dtype=torch.float32),  # inputs: 1x2
            torch.zeros((1, 4), dtype=torch.float32),  # hidden: 1x4
            torch.ones((12, 2), dtype=torch.int8),  # weights_inputs: 12x2
            0.2,  # w_i_scale
            torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
            0.1,  # w_h_scale
            torch.tensor(
                [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
            ),  # bias_inputs: 12
            0.1,  # b_i_scale
            torch.tensor(
                [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
            ),  # bias_hidden: 12
            0.1,  # b_h_scale
        ),
        (
            "negative_weights",
            torch.tensor([[1.0, -1.0]], dtype=torch.float32),  # inputs: 1x2
            torch.tensor(
                [[0.5, -0.5, 0.5, -0.5]], dtype=torch.float32
            ),  # hidden: 1x4
            torch.tensor(
                [[1, -1], [-1, 1]] * 6, dtype=torch.int8
            ),  # weights_inputs: 12x2 (alternating pattern)
            0.1,  # w_i_scale
            torch.tensor(
                [[1, -1, 1, -1], [-1, 1, -1, 1]] * 6, dtype=torch.int8
            ),  # weights_hidden: 12x4 (alternating pattern)
            0.1,  # w_h_scale
            torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
            0.1,  # b_i_scale
            torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
            0.1,  # b_h_scale
        ),
        (
            "hidden_dim_8",
            torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32),  # inputs: 1x3
            torch.tensor(
                [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
            ),  # hidden: 1x8
            torch.ones((24, 3), dtype=torch.int8),  # weights_inputs: 24x3 (3*8 x 3)
            0.1,  # w_i_scale
            torch.ones((24, 8), dtype=torch.int8),  # weights_hidden: 24x8 (3*8 x 8)
            0.1,  # w_h_scale
            torch.zeros(24, dtype=torch.int8),  # bias_inputs: 24
            0.1,  # b_i_scale
            torch.zeros(24, dtype=torch.int8),  # bias_hidden: 24
            0.1,  # b_h_scale
        ),
    ]
)
def test_quantized_w8a32_gru(
    self,
    name: str,
    inputs: torch.Tensor,
    hidden: torch.Tensor,
    weights_inputs: torch.Tensor,
    w_i_scale: float,
    weights_hidden: torch.Tensor,
    w_h_scale: float,
    bias_inputs: torch.Tensor,
    b_i_scale: float,
    bias_hidden: torch.Tensor,
    b_h_scale: float,
) -> None:
    """Exercise quantized_w8a32_gru across valid shapes and one invalid batch.

    The "invalid_batch_size_2" case checks that a hidden state with a
    leading batch dim of 2 is rejected; all other cases verify output
    dtype, shape (2, hidden_dim), and that values stay in [-1, 1].
    """
    if name == "invalid_batch_size_2":
        with self.assertRaises(ValueError) as context:
            torch.ops.cadence.quantized_w8a32_gru(
                inputs,
                hidden,
                weights_inputs,
                w_i_scale,
                weights_hidden,
                w_h_scale,
                bias_inputs,
                b_i_scale,
                bias_hidden,
                b_h_scale,
            )
        self.assertIn(
            "Leading dimension of hidden state must be 1", str(context.exception)
        )
        return

    output = torch.ops.cadence.quantized_w8a32_gru(
        inputs,
        hidden,
        weights_inputs,
        w_i_scale,
        weights_hidden,
        w_h_scale,
        bias_inputs,
        b_i_scale,
        bias_hidden,
        b_h_scale,
    )

    # Verify output properties
    self.assertEqual(
        output.dtype,
        torch.float32,
        f"Output dtype should be float32 in {name}",
    )
    self.assertEqual(
        output.shape,
        (2, hidden.shape[-1]),
        f"Output shape should match {(2, hidden.shape[-1])} in {name}",
    )
    assert isinstance(output, torch.Tensor)

    # Verify output is bounded: GRU hidden state is a convex combination of
    # tanh([-1,1]) and previous hidden([-1,1]), so output should be in [-1,1]
    # (Fix: message previously said "[-1.1, 1.1]", contradicting the check.)
    self.assertTrue(
        torch.all(output >= -1.0) and torch.all(output <= 1.0),
        f"Output values should be in [-1.0, 1.0] in {name}. Got min={output.min():.4f}, max={output.max():.4f}",
    )

def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
    """A hidden dimension not divisible by 4 must raise ValueError."""
    # hidden_dim = 3: rank and batch are valid, but 3 % 4 != 0, which the
    # HiFi SIMD constraint forbids. Gate matrices have 3 * hidden_dim rows.
    gate_rows = 9
    in_activations = torch.tensor([[1.0, 2.0]], dtype=torch.float32)  # 1x2
    prev_state = torch.tensor([[0.5, 0.5, 0.5]], dtype=torch.float32)  # 1x3
    w_in = torch.zeros((gate_rows, 2), dtype=torch.int8)  # 9x2
    w_hid = torch.zeros((gate_rows, 3), dtype=torch.int8)  # 9x3
    b_in = torch.zeros(gate_rows, dtype=torch.int8)
    b_hid = torch.zeros(gate_rows, dtype=torch.int8)

    with self.assertRaises(ValueError) as context:
        torch.ops.cadence.quantized_w8a32_gru(
            in_activations,
            prev_state,
            w_in,
            0.1,
            w_hid,
            0.1,
            b_in,
            0.1,
            b_hid,
            0.1,
        )

    self.assertIn(
        "Hidden dimension must be a multiple of 4", str(context.exception)
    )
Loading