Skip to content

Commit 7bd16d2

Browse files
Arm backend: Simplify testing of int inputs
INT-test pipelines always checked that quantized models produced quantize and dequantize nodes, and that these are removed in to_edge. This meant that when you add a test with inputs that do not need to be quantized, e.g. integer and boolean inputs, you would have to pop those stages from the test pipeline. This patch removes the need for popping those stages by making sure that they are only added if at least one input is in floating point. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I51220b719dfd19e3a4c109e23de544fea374333c
1 parent 2aeee9b commit 7bd16d2

20 files changed

+158
-202
lines changed

backends/arm/scripts/parse_test_names.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
"multihead_attention.default",
2020
"adaptive_avg_pool2d.default",
2121
"bitwise_right_shift.Tensor",
22+
"bitwise_right_shift.Scalar",
2223
"bitwise_left_shift.Tensor",
24+
"bitwise_left_shift.Scalar",
2325
"native_group_norm.default",
2426
"silu.default",
2527
"sdpa.default",

backends/arm/test/models/test_nn_functional.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ def test_nn_functional_INT(test_data):
110110
)
111111
pipeline.pop_stage("check.aten")
112112
pipeline.pop_stage("check_count.exir")
113-
pipeline.pop_stage("check.quant_nodes")
114-
pipeline.pop_stage("check_not.quant_nodes")
113+
if pipeline.has_stage("check.quant_nodes"):
114+
pipeline.pop_stage("check.quant_nodes")
115+
if pipeline.has_stage("check_not.quant_nodes"):
116+
pipeline.pop_stage("check_not.quant_nodes")
115117
try:
116118
pipeline.run()
117119
except RuntimeError as e:

backends/arm/test/models/test_nn_modules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ def test_nn_Modules_INT(test_data):
147147
)
148148
pipeline.pop_stage("check.aten")
149149
pipeline.pop_stage("check_count.exir")
150-
pipeline.pop_stage("check.quant_nodes")
151-
pipeline.pop_stage("check_not.quant_nodes")
150+
if pipeline.has_stage("check.quant_nodes"):
151+
pipeline.pop_stage("check.quant_nodes")
152+
if pipeline.has_stage("check_not.quant_nodes"):
153+
pipeline.pop_stage("check_not.quant_nodes")
152154
try:
153155
pipeline.run()
154156
except RuntimeError as e:

backends/arm/test/ops/test_any.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,6 @@ def test_any_tosa_INT(test_data: input_t1):
149149
rtol=0,
150150
qtol=0,
151151
)
152-
pipeline.pop_stage("quantize")
153-
pipeline.pop_stage("check.quant_nodes")
154152
pipeline.run()
155153

156154

@@ -181,8 +179,6 @@ def test_any_u85_INT(test_data: input_t1):
181179
rtol=0,
182180
qtol=0,
183181
)
184-
pipeline.pop_stage("quantize")
185-
pipeline.pop_stage("check.quant_nodes")
186182
pipeline.run()
187183

188184

@@ -211,6 +207,4 @@ def test_any_vgf_INT(test_data: input_t1):
211207
op.exir_op,
212208
tosa_version="TOSA-1.0+INT",
213209
)
214-
pipeline.pop_stage("quantize")
215-
pipeline.pop_stage("check.quant_nodes")
216210
pipeline.run()

backends/arm/test/ops/test_arange.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def test_arange_start_step_tosa_INT(test_data: test_data_t):
9898
ArangeAdd.aten_op,
9999
ArangeAdd.exir_op,
100100
)
101-
pipeline.pop_stage("check.quant_nodes")
102101
pipeline.run()
103102

104103

@@ -111,7 +110,6 @@ def test_arange_start_step_u55_INT(test_data: test_data_t):
111110
input_data(),
112111
ArangeAdd.aten_op,
113112
)
114-
pipeline.pop_stage("check.quant_nodes")
115113
pipeline.run()
116114

117115

@@ -124,7 +122,6 @@ def test_arange_start_step_u85_INT(test_data: test_data_t):
124122
input_data(),
125123
ArangeAdd.aten_op,
126124
)
127-
pipeline.pop_stage("check.quant_nodes")
128125
pipeline.run()
129126

130127

backends/arm/test/ops/test_bitwise.py

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
109109

110110

111111
class AndScalar(BitwiseBinaryScalar):
112-
aten_op = "torch.ops.aten.bitwise_and.Scalar"
113112
# Tensor because it gets converted from Scalar -> Tensor in lowering
113+
aten_op = "torch.ops.aten.bitwise_and.Tensor"
114114
exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor"
115115
exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Scalar"
116116

@@ -119,8 +119,8 @@ def forward(self, tensor: torch.Tensor, scalar: int):
119119

120120

121121
class XorScalar(BitwiseBinaryScalar):
122-
aten_op = "torch.ops.aten.bitwise_xor.Scalar"
123122
# Tensor because it gets converted from Scalar -> Tensor in lowering
123+
aten_op = "torch.ops.aten.bitwise_xor.Tensor"
124124
exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_xor_Tensor"
125125
exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_xor_Scalar"
126126

@@ -129,8 +129,8 @@ def forward(self, tensor: torch.Tensor, scalar: int):
129129

130130

131131
class OrScalar(BitwiseBinaryScalar):
132-
aten_op = "torch.ops.aten.bitwise_or.Scalar"
133132
# Tensor because it gets converted from Scalar -> Tensor in lowering
133+
aten_op = "torch.ops.aten.bitwise_or.Tensor"
134134
exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_or_Tensor"
135135
exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_or_Scalar"
136136

@@ -174,8 +174,6 @@ def test_bitwise_and_tensor_tosa_INT(test_data: input_t2):
174174
rtol=0,
175175
qtol=0,
176176
)
177-
pipeline.pop_stage("quantize")
178-
pipeline.pop_stage("check.quant_nodes")
179177
pipeline.run()
180178

181179

@@ -190,8 +188,6 @@ def test_bitwise_and_scalar_tosa_INT(test_data: input_t2):
190188
rtol=0,
191189
qtol=0,
192190
)
193-
pipeline.pop_stage("quantize")
194-
pipeline.pop_stage("check.quant_nodes")
195191
pipeline.run()
196192

197193

@@ -239,8 +235,6 @@ def test_bitwise_and_scalar_u85_INT(test_data: input_t2):
239235
rtol=0,
240236
qtol=0,
241237
)
242-
pipeline.pop_stage("quantize")
243-
pipeline.pop_stage("check.quant_nodes")
244238
pipeline.run()
245239

246240

@@ -256,8 +250,6 @@ def test_bitwise_and_tensor_u85_INT(test_data: input_t2):
256250
rtol=0,
257251
qtol=0,
258252
)
259-
pipeline.pop_stage("quantize")
260-
pipeline.pop_stage("check.quant_nodes")
261253
pipeline.run()
262254

263255

@@ -296,8 +288,6 @@ def test_bitwise_and_tensor_vgf_INT(test_data: input_t2):
296288
qtol=0,
297289
tosa_version="TOSA-1.0+INT",
298290
)
299-
pipeline.pop_stage("quantize")
300-
pipeline.pop_stage("check.quant_nodes")
301291
pipeline.run()
302292

303293

@@ -314,8 +304,6 @@ def test_bitwise_and_scalar_vgf_INT(test_data: input_t2):
314304
qtol=0,
315305
tosa_version="TOSA-1.0+INT",
316306
)
317-
pipeline.pop_stage("quantize")
318-
pipeline.pop_stage("check.quant_nodes")
319307
pipeline.run()
320308

321309

@@ -355,8 +343,6 @@ def test_bitwise_xor_tensor_tosa_INT(test_data: input_t2):
355343
rtol=0,
356344
qtol=0,
357345
)
358-
pipeline.pop_stage("quantize")
359-
pipeline.pop_stage("check.quant_nodes")
360346
pipeline.run()
361347

362348

@@ -371,8 +357,6 @@ def test_bitwise_xor_scalar_tosa_INT(test_data: input_t2):
371357
rtol=0,
372358
qtol=0,
373359
)
374-
pipeline.pop_stage("quantize")
375-
pipeline.pop_stage("check.quant_nodes")
376360
pipeline.run()
377361

378362

@@ -420,8 +404,6 @@ def test_bitwise_xor_tensor_u85_INT(test_data: input_t2):
420404
rtol=0,
421405
qtol=0,
422406
)
423-
pipeline.pop_stage("quantize")
424-
pipeline.pop_stage("check.quant_nodes")
425407
pipeline.run()
426408

427409

@@ -437,8 +419,6 @@ def test_bitwise_xor_scalar_u85_INT(test_data: input_t2):
437419
rtol=0,
438420
qtol=0,
439421
)
440-
pipeline.pop_stage("quantize")
441-
pipeline.pop_stage("check.quant_nodes")
442422
pipeline.run()
443423

444424

@@ -477,8 +457,6 @@ def test_bitwise_xor_tensor_vgf_INT(test_data: input_t2):
477457
qtol=0,
478458
tosa_version="TOSA-1.0+INT",
479459
)
480-
pipeline.pop_stage("quantize")
481-
pipeline.pop_stage("check.quant_nodes")
482460
pipeline.run()
483461

484462

@@ -495,8 +473,6 @@ def test_bitwise_xor_scalar_vgf_INT(test_data: input_t2):
495473
qtol=0,
496474
tosa_version="TOSA-1.0+INT",
497475
)
498-
pipeline.pop_stage("quantize")
499-
pipeline.pop_stage("check.quant_nodes")
500476
pipeline.run()
501477

502478

@@ -536,8 +512,6 @@ def test_bitwise_or_tensor_tosa_INT(test_data: input_t2):
536512
rtol=0,
537513
qtol=0,
538514
)
539-
pipeline.pop_stage("quantize")
540-
pipeline.pop_stage("check.quant_nodes")
541515
pipeline.run()
542516

543517

@@ -552,8 +526,6 @@ def test_bitwise_or_scalar_tosa_INT(test_data: input_t2):
552526
rtol=0,
553527
qtol=0,
554528
)
555-
pipeline.pop_stage("quantize")
556-
pipeline.pop_stage("check.quant_nodes")
557529
pipeline.run()
558530

559531

@@ -601,8 +573,6 @@ def test_bitwise_or_tensor_u85_INT(test_data: input_t2):
601573
rtol=0,
602574
qtol=0,
603575
)
604-
pipeline.pop_stage("quantize")
605-
pipeline.pop_stage("check.quant_nodes")
606576
pipeline.run()
607577

608578

@@ -618,8 +588,6 @@ def test_bitwise_or_scalar_u85_INT(test_data: input_t2):
618588
rtol=0,
619589
qtol=0,
620590
)
621-
pipeline.pop_stage("quantize")
622-
pipeline.pop_stage("check.quant_nodes")
623591
pipeline.run()
624592

625593

@@ -658,8 +626,6 @@ def test_bitwise_or_tensor_vgf_INT(test_data: input_t2):
658626
qtol=0,
659627
tosa_version="TOSA-1.0+INT",
660628
)
661-
pipeline.pop_stage("quantize")
662-
pipeline.pop_stage("check.quant_nodes")
663629
pipeline.run()
664630

665631

@@ -676,8 +642,6 @@ def test_bitwise_or_scalar_vgf_INT(test_data: input_t2):
676642
qtol=0,
677643
tosa_version="TOSA-1.0+INT",
678644
)
679-
pipeline.pop_stage("quantize")
680-
pipeline.pop_stage("check.quant_nodes")
681645
pipeline.run()
682646

683647

backends/arm/test/ops/test_bitwise_not.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ def test_bitwise_not_tosa_INT(test_data: Tuple):
6060
aten_op=aten_op,
6161
exir_op=exir_op,
6262
)
63-
pipeline.pop_stage("quantize")
64-
pipeline.pop_stage("check.quant_nodes")
6563
pipeline.run()
6664

6765

@@ -87,8 +85,6 @@ def test_bitwise_not_u85_INT(test_data: Tuple):
8785
aten_ops=aten_op,
8886
exir_ops=exir_op,
8987
)
90-
pipeline.pop_stage("quantize")
91-
pipeline.pop_stage("check.quant_nodes")
9288
pipeline.run()
9389

9490

@@ -115,6 +111,4 @@ def test_bitwise_not_vgf_INT(test_data: Tuple):
115111
exir_op,
116112
tosa_version="TOSA-1.0+INT",
117113
)
118-
pipeline.pop_stage("quantize")
119-
pipeline.pop_stage("check.quant_nodes")
120114
pipeline.run()

backends/arm/test/ops/test_eye.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def test_eye_tosa_INT(test_data: test_data_t):
6868
input_data(),
6969
EyeAdd.aten_op,
7070
)
71-
pipeline.pop_stage("check.quant_nodes")
71+
if pipeline.has_stage("check.quant_nodes"):
72+
pipeline.pop_stage("check.quant_nodes")
7273
pipeline.run()
7374

7475

@@ -82,7 +83,8 @@ def test_eye_u55_INT(test_data: test_data_t):
8283
EyeAdd.aten_op,
8384
use_to_edge_transform_and_lower=True,
8485
)
85-
pipeline.pop_stage("check.quant_nodes")
86+
if pipeline.has_stage("check.quant_nodes"):
87+
pipeline.pop_stage("check.quant_nodes")
8688
pipeline.run()
8789

8890

@@ -96,7 +98,8 @@ def test_eye_u85_INT(test_data: test_data_t):
9698
EyeAdd.aten_op,
9799
use_to_edge_transform_and_lower=True,
98100
)
99-
pipeline.pop_stage("check.quant_nodes")
101+
if pipeline.has_stage("check.quant_nodes"):
102+
pipeline.pop_stage("check.quant_nodes")
100103
pipeline.run()
101104

102105

@@ -132,7 +135,8 @@ def test_eye_vgf_INT(test_data: test_data_t):
132135
EyeAdd.aten_op,
133136
tosa_version="TOSA-1.0+INT",
134137
)
135-
pipeline.pop_stage("check.quant_nodes")
138+
if pipeline.has_stage("check.quant_nodes"):
139+
pipeline.pop_stage("check.quant_nodes")
136140
pipeline.run()
137141

138142

backends/arm/test/ops/test_full.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def test_full_like_tosa_INT(test_data: Tuple):
117117
aten_op=[],
118118
exir_op=exir_op,
119119
)
120-
pipeline.pop_stage("check.quant_nodes")
121120
pipeline.run()
122121

123122

0 commit comments

Comments
 (0)