
Commit a33e91f

Arm backend: Add support for sigmoid and tanh int16x8 (#15101)
Adds int16x8 support for the sigmoid and tanh operators in the Arm backend. Removes the now-unnecessary test_sigmoid_16bit, as those cases are covered by test_sigmoid.
1 parent 6f5ce2d

File tree

  backends/arm/operator_support/ethos_u55_support.py
  backends/arm/quantizer/arm_quantizer.py
  backends/arm/test/ops/test_sigmoid.py
  backends/arm/test/ops/test_sigmoid_16bit.py
  backends/arm/test/ops/test_tanh.py

5 files changed: +20 −211 lines

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 2 deletions
@@ -114,9 +114,9 @@ def is_node_supported(  # noqa: C901
             return False
 
         if node.target in self.target_ops_i8:
-            if dtype not in (torch.int8,):
+            if dtype not in (torch.int8, torch.int16):
                 self.reporter.report_reject(
-                    node, f"Unsupported dtype {dtype} (Supports i8)."
+                    node, f"Unsupported dtype {dtype} (Supports i8, i16)."
                 )
                 return False
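
For reference, a minimal sketch of the dtype gate this hunk relaxes. The names SUPPORTED_ACTIVATION_DTYPES and is_dtype_supported are illustrative stand-ins, not part of the backend; the real check is the is_node_supported method shown above.

import torch

# Previously only torch.int8 was accepted for the ops in target_ops_i8.
SUPPORTED_ACTIVATION_DTYPES = (torch.int8, torch.int16)

def is_dtype_supported(dtype: torch.dtype) -> bool:
    # Mirrors the updated check: int16 is now accepted alongside int8.
    if dtype not in SUPPORTED_ACTIVATION_DTYPES:
        print(f"Unsupported dtype {dtype} (Supports i8, i16).")
        return False
    return True

assert is_dtype_supported(torch.int16)      # newly supported
assert not is_dtype_supported(torch.int32)  # still rejected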

backends/arm/quantizer/arm_quantizer.py

Lines changed: 3 additions & 1 deletion
@@ -161,6 +161,7 @@ def get_symmetric_a16w8_quantization_config(
     is_dynamic: bool = False,
     weight_qmin: int = -127,
     weight_qmax: int = 127,
+    epsilon: float = 2**-12,
 ):
     """
     16A8W quantization config: 16-bit activations, 8-bit weights.
@@ -174,11 +175,12 @@ def get_symmetric_a16w8_quantization_config(
         is_dynamic: Whether to use dynamic quantization
         weight_qmin: Minimum quantization value for weights
         weight_qmax: Maximum quantization value for weights
+        epsilon: Value used to pad observed [qmin, qmax] before initial zero point and scale calculation
 
     Returns:
         QuantizationConfig with 16-bit activations and 8-bit weights
     """
-    extra_args: Dict[str, Any] = {"eps": 2**-12}
+    extra_args: Dict[str, Any] = {"eps": epsilon}
 
     # Setup observer/fake-quant for 16-bit activations
     if is_qat:
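
The new keyword lets callers tune the observer eps without changing the default behaviour. A usage sketch follows; the function name and parameter come from this diff, while the import path is an assumption.

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
)

# Default is unchanged: eps = 2**-12.
default_config = get_symmetric_a16w8_quantization_config()

# Ops whose outputs span a small, known range (e.g. sigmoid in [0, 1])
# can request a tighter floor so the quantization range is not inflated.
tight_config = get_symmetric_a16w8_quantization_config(epsilon=2**-16)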

backends/arm/test/ops/test_sigmoid.py

Lines changed: 8 additions & 9 deletions
@@ -34,6 +34,7 @@
     "zeros": lambda: torch.zeros(10, 10, 10, 10),
     "ones": lambda: torch.ones(10, 10, 10),
     "rand": lambda: torch.rand(10, 10) - 0.5,
+    "rand_4d": lambda: torch.rand(1, 1, 5, 10),
     "randn_pos": lambda: torch.randn(10) + 10,
     "randn_neg": lambda: torch.randn(10) - 10,
     "ramp": lambda: torch.arange(-16, 16, 0.2),
@@ -269,22 +270,23 @@ def get_symmetric_a16w8_sigmoid_quantizer(per_channel_quantization=False):
     }
 
     quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+
+    # Use a smaller epsilon value to not greatly inflate [qmin, qmax]
     quantizer.set_global(
-        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization, epsilon=2**-16
+        )
     )
 
     return Quantize(
         quantizer,
         get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization
+            is_per_channel=per_channel_quantization, epsilon=2**-16
         ),
     )
 
 
 @common.parametrize("test_data", test_data_suite)
-@pytest.mark.xfail(
-    reason="missing int16 sigmoid ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13974"
-)
 def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor):
     """Test sigmoid operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
@@ -311,7 +313,7 @@ def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor):
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone300
 @pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
+    reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
 )
 def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor):
     """Test sigmoid operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
@@ -337,9 +339,6 @@ def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor):
 
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
-)
 def test_sigmoid_16a8w_u85_INT16(test_data: torch.Tensor):
     """Test sigmoid operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
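
The "smaller epsilon" comment can be checked with quick arithmetic. Assuming eps acts as a lower bound on the observer's computed scale (as with the stock torch.ao observers), the old default of 2**-12 would clamp the scale for a roughly [0, 1] sigmoid output to about 16x its ideal value:

qmin, qmax = -32768, 32767                # int16 activation range
ideal_scale = 1.0 / (qmax - qmin)         # ~1.53e-05, i.e. roughly 2**-16

scale_default = max(ideal_scale, 2**-12)  # clamped well above ideal
scale_tight = max(ideal_scale, 2**-16)    # essentially unclamped

print(scale_default / ideal_scale)        # ~16.0: representable range inflated 16x
print(scale_tight / ideal_scale)          # ~1.0: full int16 resolution used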

backends/arm/test/ops/test_sigmoid_16bit.py

Lines changed: 0 additions & 190 deletions
This file was deleted.

backends/arm/test/ops/test_tanh.py

Lines changed: 7 additions & 9 deletions
@@ -121,22 +121,23 @@ def get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False):
     }
 
     quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+
+    # Use a smaller epsilon value to not greatly inflate [qmin, qmax]
     quantizer.set_global(
-        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization, epsilon=2**-16
+        )
     )
 
     return Quantize(
         quantizer,
         get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization
+            is_per_channel=per_channel_quantization, epsilon=2**-16
         ),
     )
 
 
 @common.parametrize("test_data", test_data_suite)
-@pytest.mark.xfail(
-    reason="missing int16 tanh ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13975"
-)
 def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
     """Test tanh operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
@@ -163,7 +164,7 @@ def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone300
 @pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
+    reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
 )
 def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
     """Test tanh operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
@@ -189,9 +190,6 @@ def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
 
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
-)
 def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor):
     """Test tanh operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
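
As a standalone illustration of why 16-bit activations matter for these ops, the snippet below applies plain symmetric quantization arithmetic to tanh outputs; this is not the backend's actual rescaling pipeline, just the resolution argument behind int16x8.

import torch

x = torch.arange(-4.0, 4.0, 0.25)
y = torch.tanh(x)                     # observed range ~[-1, 1]

scale = 1.0 / 32767                   # symmetric int16 scale
q = torch.clamp((y / scale).round(), -32768, 32767).to(torch.int16)
dequant = q.to(torch.float32) * scale

print((y - dequant).abs().max())      # worst-case error ~1.5e-5 (vs ~4e-3 for int8)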
