From 4388843e3cb75d9404d66df86b7ae12ada05aa0c Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Fri, 11 Apr 2025 15:46:37 +0200 Subject: [PATCH] Arm backend: Convert asserts to raise errors in op_avg_pool2d Asserts are converted to proper raises to ensure graph integrity. Signed-off-by: Sebastian Larsson Change-Id: I3cbb9a9d8b7aeae4d374e04087efa86e63054ae7 --- backends/arm/operators/op_avg_pool2d.py | 34 +++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 73a6713633..727fd52dfd 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -85,8 +85,12 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - input_tensor = inputs[0] - assert input_tensor.dtype == ts.DType.INT8 + supported_dtypes = [ts.DType.INT8] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f"IO data type needs to be one of {supported_dtypes}, got " + f'"{inputs[0].dtype}"' + ) accumulator_type = ts.DType.INT32 @@ -118,9 +122,12 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - assert ( - inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 - ), "Only FP32 and INT8 supported" + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f"IO data type needs to be one of {supported_dtypes}, got " + f'"{inputs[0].dtype}"' + ) if inputs[0].dtype == ts.DType.INT8: super().define_node(node, tosa_graph, inputs, output) @@ -205,8 +212,12 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - input_tensor = inputs[0] - assert input_tensor.dtype == ts.DType.INT8 + supported_dtypes = [ts.DType.INT8] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f"IO data type needs to be one of {supported_dtypes}, got " + f'"{inputs[0].dtype}"' + ) accumulator_type = ts.DType.INT32 @@ -241,9 +252,12 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - assert ( - inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 - ), "Only FP32 and INT8 supported" + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f"IO data type needs to be one of {supported_dtypes}, got " + f'"{inputs[0].dtype}"' + ) if inputs[0].dtype == ts.DType.INT8: super().define_node(node, tosa_graph, inputs, output)