
Commit f04eaec

Andrew Grebenisan authored and facebook-github-bot committed
Cadence ops: Support quantized gru (#15209)
Summary: As titled. Differential Revision: D84855084
1 parent efccca4 commit f04eaec

3 files changed (+250, −2 lines)


backends/cadence/aot/ops_registrations.py

Lines changed: 1 addition & 2 deletions

@@ -55,7 +55,6 @@ def _validate_ref_impl_exists() -> None:
     _WARN_ONLY = {
         "cadence::quantized_softmax.per_tensor",
         "cadence::quantized_softmax",
-        "cadence::quantized_w8a32_gru",
     }

     ref_impls = get_registered_ref_implementations()

@@ -2753,7 +2752,7 @@ def quantized_w8a32_gru_meta(
     bias_hidden: torch.Tensor,
     b_h_scale: float,
 ) -> torch.Tensor:
-    return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
+    return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)


 # Validate that all meta kernels have reference implementations
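With the reference implementation added below, cadence::quantized_w8a32_gru is removed from the _WARN_ONLY set, so _validate_ref_impl_exists now enforces that a reference kernel exists for it, and the meta kernel derives its fake output from hidden (not inputs) with the dtype pinned to float32. A minimal sketch of the shape/dtype contract the meta kernel encodes, using hypothetical toy shapes rather than the real op registration:

import torch

# Hypothetical batch-1 shapes: input width 2, hidden width 4.
inputs = torch.randn(1, 2, dtype=torch.float32)
hidden = torch.randn(1, 4, dtype=torch.float32)

# What quantized_w8a32_gru_meta allocates after this change:
out = hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)
assert out.shape == (2, 4) and out.dtype == torch.float32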

backends/cadence/aot/ref_implementations.py

Lines changed: 64 additions & 0 deletions

@@ -985,6 +985,70 @@ def quantized_w8a32_linear(
     return output


+@impl_tracked(m, "quantized_w8a32_gru")
+def quantized_w8a32_gru(
+    inputs: torch.Tensor,
+    hidden: torch.Tensor,
+    weights_inputs: torch.Tensor,
+    w_i_scale: float,
+    weights_hidden: torch.Tensor,
+    w_h_scale: float,
+    bias_inputs: torch.Tensor,
+    b_i_scale: float,
+    bias_hidden: torch.Tensor,
+    b_h_scale: float,
+) -> torch.Tensor:
+    assert weights_inputs.dtype == torch.int8
+    assert weights_hidden.dtype == torch.int8
+    assert bias_inputs.dtype == torch.int8
+    assert bias_hidden.dtype == torch.int8
+    assert inputs.dtype == torch.float32
+    assert hidden.dtype == torch.float32
+
+    if len(hidden.shape) > 2:
+        raise ValueError("Hidden state must be 2D or 1D")
+
+    if len(hidden.shape) == 2 and hidden.shape[0] != 1:
+        raise ValueError("Leading dimension of hidden state must be 1")
+
+    original_hidden_shape = hidden.shape
+    hidden = hidden.view(-1)
+
+    hidden_dim = hidden.shape[0]
+    if (hidden_dim % 4) != 0:
+        raise ValueError(
+            "Hidden dimension must be a multiple of 4 for HiFi SIMD operations"
+        )
+
+    dequant_weights_inputs = weights_inputs.float() * w_i_scale
+    dequant_weights_hidden = weights_hidden.float() * w_h_scale
+
+    # C++ implementation averages the two bias scales
+    avg_bias_scale = (b_i_scale + b_h_scale) / 2
+    dequant_bias_inputs = bias_inputs.float() * avg_bias_scale
+    dequant_bias_hidden = bias_hidden.float() * avg_bias_scale
+
+    gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs)
+    gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden)
+
+    i_r, i_z, i_n = gi.chunk(3, -1)
+    h_r, h_z, h_n = gh.chunk(3, -1)
+
+    reset_gate = torch.sigmoid(i_r + h_r)
+    update_gate = torch.sigmoid(i_z + h_z)
+    new_gate = torch.tanh(i_n + reset_gate * h_n)
+
+    new_hidden = (1 - update_gate) * new_gate + update_gate * hidden
+
+    if new_hidden.shape[0] != 1:
+        raise ValueError("Leading dimension of hidden state must be 1")
+
+    assert new_hidden.shape == original_hidden_shape
+
+    new_hidden = new_hidden.view(-1)
+    return torch.stack([new_hidden, new_hidden], dim=0)
+
+
 @impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
 def quantized_conv2d_nhwc_per_tensor(
     input_tensor: torch.Tensor,
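On dequantized weights the reference follows the standard GRU cell update: r = sigmoid(i_r + h_r), z = sigmoid(i_z + h_z), n = tanh(i_n + r * h_n), and h' = (1 − z) * n + z * h, where gi = [i_r, i_z, i_n] and gh = [h_r, h_z, h_n] come from the two linear layers; the new hidden state is then stacked twice into the (2, hidden_dim) output. A standalone sketch with hypothetical toy tensors that mirrors this math directly (it does not call the registered op):

import torch
import torch.nn.functional as F

hidden_dim = 4
x = torch.randn(1, 2)                                  # float32 activations
h = torch.randn(hidden_dim)                            # flattened hidden state
w_x = torch.randint(-128, 128, (3 * hidden_dim, 2), dtype=torch.int8)
w_h = torch.randint(-128, 128, (3 * hidden_dim, hidden_dim), dtype=torch.int8)
b = torch.zeros(3 * hidden_dim, dtype=torch.int8)
w_x_scale = w_h_scale = b_scale = 0.1

# Dequantize, then apply the same gate math as the reference above.
gi = F.linear(x, w_x.float() * w_x_scale, b.float() * b_scale)
gh = F.linear(h, w_h.float() * w_h_scale, b.float() * b_scale)
i_r, i_z, i_n = gi.chunk(3, -1)
h_r, h_z, h_n = gh.chunk(3, -1)
r = torch.sigmoid(i_r + h_r)
z = torch.sigmoid(i_z + h_z)
n = torch.tanh(i_n + r * h_n)
h_new = (1 - z) * n + z * h                            # (1, hidden_dim)
out = torch.stack([h_new.view(-1), h_new.view(-1)])    # (2, hidden_dim)
print(out.shape, out.dtype)                            # torch.Size([2, 4]) torch.float32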

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 185 additions & 0 deletions

@@ -2894,3 +2894,188 @@ def test_softmax_f32_f32(self) -> None:
         output = torch.ops.cadence._softmax_f32_f32(input_tensor, dim=1)
         self.assertEqual(output.dtype, torch.float32)
         self.assertEqual(output.shape, input_tensor.shape)
+
+    @expand(
+        [
+            (
+                "basic_hidden_dim_4",
+                torch.tensor([[1.0, 2.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.tensor(
+                    [[0.5, 0.5, 0.5, 0.5]], dtype=torch.float32
+                ),  # hidden: 1x4
+                torch.ones(
+                    (12, 2), dtype=torch.int8
+                ),  # weights_inputs: 12x2 (3*4 x input_dim=2)
+                0.1,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4 (3*4 x 4)
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "invalid_batch_size_2",
+                torch.tensor(
+                    [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=torch.float32
+                ),  # inputs: 2x3
+                torch.tensor(
+                    [[0.5, 0.5, 0.5, 0.5], [0.3, 0.3, 0.3, 0.3]], dtype=torch.float32
+                ),  # hidden: 2x4
+                torch.ones((12, 3), dtype=torch.int8),  # weights_inputs: 12x3
+                0.1,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "non_zero_biases",
+                torch.tensor([[1.0, 1.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.zeros((1, 4), dtype=torch.float32),  # hidden: 1x4
+                torch.ones((12, 2), dtype=torch.int8),  # weights_inputs: 12x2
+                0.2,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
+                0.1,  # w_h_scale
+                torch.tensor(
+                    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
+                ),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.tensor(
+                    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
+                ),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "negative_weights",
+                torch.tensor([[1.0, -1.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.tensor(
+                    [[0.5, -0.5, 0.5, -0.5]], dtype=torch.float32
+                ),  # hidden: 1x4
+                torch.tensor(
+                    [[1, -1], [-1, 1]] * 6, dtype=torch.int8
+                ),  # weights_inputs: 12x2 (alternating pattern)
+                0.1,  # w_i_scale
+                torch.tensor(
+                    [[1, -1, 1, -1], [-1, 1, -1, 1]] * 6, dtype=torch.int8
+                ),  # weights_hidden: 12x4 (alternating pattern)
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "hidden_dim_8",
+                torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32),  # inputs: 1x3
+                torch.tensor(
+                    [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
+                ),  # hidden: 1x8
+                torch.ones((24, 3), dtype=torch.int8),  # weights_inputs: 24x3 (3*8 x 3)
+                0.1,  # w_i_scale
+                torch.ones((24, 8), dtype=torch.int8),  # weights_hidden: 24x8 (3*8 x 8)
+                0.1,  # w_h_scale
+                torch.zeros(24, dtype=torch.int8),  # bias_inputs: 24
+                0.1,  # b_i_scale
+                torch.zeros(24, dtype=torch.int8),  # bias_hidden: 24
+                0.1,  # b_h_scale
+            ),
+        ]
+    )
+    def test_quantized_w8a32_gru(
+        self,
+        name: str,
+        inputs: torch.Tensor,
+        hidden: torch.Tensor,
+        weights_inputs: torch.Tensor,
+        w_i_scale: float,
+        weights_hidden: torch.Tensor,
+        w_h_scale: float,
+        bias_inputs: torch.Tensor,
+        b_i_scale: float,
+        bias_hidden: torch.Tensor,
+        b_h_scale: float,
+    ) -> None:
+
+        if name == "invalid_batch_size_2":
+            with self.assertRaises(ValueError) as context:
+                torch.ops.cadence.quantized_w8a32_gru(
+                    inputs,
+                    hidden,
+                    weights_inputs,
+                    w_i_scale,
+                    weights_hidden,
+                    w_h_scale,
+                    bias_inputs,
+                    b_i_scale,
+                    bias_hidden,
+                    b_h_scale,
+                )
+            self.assertIn(
+                "Leading dimension of hidden state must be 1", str(context.exception)
+            )
+            return
+
+        output = torch.ops.cadence.quantized_w8a32_gru(
+            inputs,
+            hidden,
+            weights_inputs,
+            w_i_scale,
+            weights_hidden,
+            w_h_scale,
+            bias_inputs,
+            b_i_scale,
+            bias_hidden,
+            b_h_scale,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            (2, hidden.shape[-1]),
+            f"Output shape should match {(2, hidden.shape[-1])} in {name}",
+        )
+        assert isinstance(output, torch.Tensor)
+
+        # Verify output is bounded: GRU hidden state is a convex combination of
+        # tanh([-1,1]) and previous hidden([-1,1]), so output should be in [-1,1]
+        self.assertTrue(
+            torch.all(output >= -1.0) and torch.all(output <= 1.0),
+            f"Output values should be in [-1.0, 1.0] in {name}. Got min={output.min():.4f}, max={output.max():.4f}",
+        )
+
+    def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
+        # Test that non-multiple of 4 hidden dimension raises error
+        inputs = torch.tensor([[1.0, 2.0]], dtype=torch.float32)  # 1x2
+        hidden = torch.tensor(
+            [[0.5, 0.5, 0.5]], dtype=torch.float32
+        )  # 1x3 (not divisible by 4)
+        weights_inputs = torch.zeros((9, 2), dtype=torch.int8)  # 9x2
+        weights_hidden = torch.zeros((9, 3), dtype=torch.int8)  # 9x3
+        bias_inputs = torch.zeros(9, dtype=torch.int8)
+        bias_hidden = torch.zeros(9, dtype=torch.int8)
+
+        with self.assertRaises(ValueError) as context:
+            torch.ops.cadence.quantized_w8a32_gru(
+                inputs,
+                hidden,
+                weights_inputs,
+                0.1,
+                weights_hidden,
+                0.1,
+                bias_inputs,
+                0.1,
+                bias_hidden,
+                0.1,
+            )
+
+        self.assertIn(
+            "Hidden dimension must be a multiple of 4", str(context.exception)
+        )
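The tests build the int8 weights and per-tensor scales directly. For context, here is a minimal sketch of how float GRU parameters could be mapped onto the (int8 weight, float scale) pairs the op consumes, assuming simple symmetric per-tensor quantization; the helper below is hypothetical and not part of this commit:

import torch

def quantize_per_tensor_symmetric(w: torch.Tensor) -> tuple[torch.Tensor, float]:
    # Hypothetical helper: int8 values plus one float scale per tensor.
    max_abs = w.abs().max().item()
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

# Float input-to-hidden weights for hidden_dim=4, input_dim=2 (3*4 = 12 rows).
w_ih_fp = torch.randn(12, 2)
weights_inputs, w_i_scale = quantize_per_tensor_symmetric(w_ih_fp)
print(weights_inputs.dtype, w_i_scale)  # torch.int8 and a float scale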

0 commit comments
