diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h
index eea9792efcf..71cf718c9a8 100644
--- a/backends/cortex_m/ops/cortex_m_ops_common.h
+++ b/backends/cortex_m/ops/cortex_m_ops_common.h
@@ -129,6 +129,43 @@ inline void validate_quantization_params(
       "Single quant Output");
 }
 
+inline bool is_channels_last_tensor(const Tensor& tensor) {
+  if (tensor.dim() != 4) {
+    return false;
+  }
+
+  // When channels or spatial dims are 1, the layout information is ambiguous.
+  if (tensor.size(1) == 1 || (tensor.size(2) == 1 && tensor.size(3) == 1)) {
+    return true;
+  }
+
+  constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = {
+      0, 2, 3, 1};
+  executorch::aten::ArrayRef<executorch::aten::DimOrderType>
+      channels_last_order(kChannelsLastDimOrder, 4);
+
+  return tensor.dim_order() == channels_last_order;
+}
+
+inline bool is_channel_broadcast(const Tensor& tensor1, const Tensor& tensor2) {
+  if (tensor1.dim() != tensor2.dim()) {
+    return false;
+  }
+
+  if (tensor1.dim() != 4) {
+    return false;
+  }
+
+  if (tensor1.size(1) != tensor2.size(1)) {
+    return false;
+  }
+
+  const bool tensor1_channels_only = tensor1.numel() == tensor1.size(1);
+  const bool tensor2_channels_only = tensor2.numel() == tensor2.size(1);
+
+  return tensor1_channels_only || tensor2_channels_only;
+}
+
 // Refer to CMSIS-NN 'arm_nn_requantize' implementation for details:
 // https://github.com/ARM-software/CMSIS-NN/blob/main/Include/arm_nnsupportfunctions.h#L1625
 // multiplier: Range {ARM_NN_Q31_MIN + 1, Q32_MAX}
diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp
index ddc4b4bb869..019ab4cfb58 100644
--- a/backends/cortex_m/ops/op_quantized_add.cpp
+++ b/backends/cortex_m/ops/op_quantized_add.cpp
@@ -33,7 +33,14 @@ Tensor& quantized_add_out(
     const Scalar& output_shift,
     Tensor& out) {
   // Validate tensor types and dim order
-  validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
+  bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8);
+  validate_cmsis_nn_tensor_requirements(
+      input1_int8,
+      input2_int8,
+      out,
+      ScalarType::Char,
+      /*require_channels_last=*/channel_broadcast,
+      /*require_same_sizes=*/!channel_broadcast);
 
   // Validate quantization parameters
   validate_quantization_params(
@@ -62,6 +69,8 @@ Tensor& quantized_add_out(
   int32_t out_zp = extractScalarToInt32(output_zero_point);
   int32_t output_mult = extractScalarToInt32(output_multiplier);
   int output_shift_val = extractScalarToInt(output_shift);
+  int8_t* input1_ptr = input1_int8.data_ptr<int8_t>();
+  int8_t* input2_ptr = input2_int8.data_ptr<int8_t>();
 
   // Left shift to maximize precision
   const int32_t left_shift = 20;
@@ -87,33 +96,49 @@ Tensor& quantized_add_out(
   // addition. To preserve precision when rescaling the inputs, they are first
   // upscaled as much as possible, Hence the left_shift parameter required here.
 
-  // Call CMSIS-NN kernel with precomputed parameters
-  arm_cmsis_nn_status status = arm_elementwise_add_s8(
-      input1_int8.const_data_ptr<int8_t>(),
-      input2_int8.const_data_ptr<int8_t>(),
-      -static_cast<int32_t>(zp1),
-      input1_mult,
-      input1_shift_val,
-      -static_cast<int32_t>(zp2),
-      input2_mult,
-      input2_shift_val,
-      left_shift,
-      out.mutable_data_ptr<int8_t>(),
-      static_cast<int32_t>(out_zp),
-      output_mult,
-      output_shift_val,
-      activation_min,
-      activation_max,
-      static_cast<int32_t>(out.numel()));
-
-  if (status != ARM_CMSIS_NN_SUCCESS) {
-    ET_LOG(
-        Error,
-        "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
-        status);
-
-    context.fail(Error::Internal); // Fail the execution context
-    return out;
+  int32_t adds_per_loop = 0;
+  if (channel_broadcast) {
+    if (input1_int8.numel() < input2_int8.numel()) {
+      std::swap(zp1, zp2);
+      std::swap(input1_mult, input2_mult);
+      std::swap(input1_shift_val, input2_shift_val);
+      std::swap(input1_ptr, input2_ptr);
+    }
+    adds_per_loop = input1_int8.size(1);
+  } else {
+    adds_per_loop = out.numel();
+  }
+
+  for (int32_t broadcast_offset = 0; broadcast_offset < out.numel();
+       broadcast_offset += adds_per_loop) {
+    // Call CMSIS-NN kernel with precomputed parameters
+    arm_cmsis_nn_status status = arm_elementwise_add_s8(
+        input1_ptr + broadcast_offset,
+        input2_ptr,
+        -static_cast<int32_t>(zp1),
+        input1_mult,
+        input1_shift_val,
+        -static_cast<int32_t>(zp2),
+        input2_mult,
+        input2_shift_val,
+        left_shift,
+        out.mutable_data_ptr<int8_t>() + broadcast_offset,
+        static_cast<int32_t>(out_zp),
+        output_mult,
+        output_shift_val,
+        activation_min,
+        activation_max,
+        adds_per_loop);
+
+    if (status != ARM_CMSIS_NN_SUCCESS) {
+      ET_LOG(
+          Error,
+          "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
+          status);
+
+      context.fail(Error::Internal); // Fail the execution context
+      return out;
+    }
   }
   ET_LOG(
       Info,
diff --git a/backends/cortex_m/ops/op_quantized_mul.cpp b/backends/cortex_m/ops/op_quantized_mul.cpp
index 28af8406f87..3d2d7657e36 100644
--- a/backends/cortex_m/ops/op_quantized_mul.cpp
+++ b/backends/cortex_m/ops/op_quantized_mul.cpp
@@ -34,7 +34,15 @@ Tensor& quantized_mul_out(
     const Scalar& output_shift,
     Tensor& out) {
   // Validate tensor types and quantization parameters
-  validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
+
+  bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8);
+  validate_cmsis_nn_tensor_requirements(
+      input1_int8,
+      input2_int8,
+      out,
+      ScalarType::Char,
+      /*require_channels_last=*/channel_broadcast,
+      /*require_same_sizes=*/!channel_broadcast);
 
   const Scalar kIdentityMultiplier(/*value=*/1);
   const Scalar kZeroShift(/*value=*/0);
@@ -51,12 +59,26 @@ Tensor& quantized_mul_out(
       out);
 
   // Extract quantization parameters
-  const int32_t zp1 = extractScalarToInt32(input1_zero_point);
-  const int32_t zp2 = extractScalarToInt32(input2_zero_point);
+  int8_t* input1_ptr = input1_int8.data_ptr<int8_t>();
+  int8_t* input2_ptr = input2_int8.data_ptr<int8_t>();
+  int32_t zp1 = extractScalarToInt32(input1_zero_point);
+  int32_t zp2 = extractScalarToInt32(input2_zero_point);
   const int32_t out_zp = extractScalarToInt32(output_zero_point);
   const int32_t output_mult = extractScalarToInt32(output_multiplier);
   const int32_t output_shift_val = extractScalarToInt32(output_shift);
+  int32_t muls_per_loop = 0;
+
+  if (channel_broadcast) {
+    if (input1_int8.numel() < input2_int8.numel()) {
+      std::swap(zp1, zp2);
+      std::swap(input1_ptr, input2_ptr);
+    }
+
+    muls_per_loop = input1_int8.size(1);
+  } else {
+    muls_per_loop = out.numel();
+  }
 
   // Note 1: The CMSIS-NN kernel implementation uses offsets which are always
   // added to the data, whereas zero_points are subtracted when dequantizing
   // (for the inputs) and added when quantizing (for the output). Hence the
@@ -72,29 +94,31 @@ Tensor& quantized_mul_out(
 
   // effective_scale = (scale_in1 * scale_in2 / scale_out)
   // Hence no input quantization params required here.
 
-  // Call CMSIS-NN elementwise multiply kernel
-  arm_cmsis_nn_status status = arm_elementwise_mul_s8(
-      input1_int8.const_data_ptr<int8_t>(),
-      input2_int8.const_data_ptr<int8_t>(),
-      -static_cast<int32_t>(zp1),
-      -static_cast<int32_t>(zp2),
-      out.mutable_data_ptr<int8_t>(),
-      static_cast<int32_t>(out_zp),
-      output_mult,
-      output_shift_val,
-      kInt8ActivationMin,
-      kInt8ActivationMax,
-      static_cast<int32_t>(out.numel()));
-
-  if (status != ARM_CMSIS_NN_SUCCESS) {
-    ET_LOG(
-        Error,
-        "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
-        status);
-    context.fail(Error::Internal);
-    return out;
-  }
+  for (int32_t broadcast_offset = 0; broadcast_offset < out.numel();
+       broadcast_offset += muls_per_loop) {
+    // Call CMSIS-NN elementwise multiply kernel
+    arm_cmsis_nn_status status = arm_elementwise_mul_s8(
+        input1_ptr + broadcast_offset,
+        input2_ptr,
+        -static_cast<int32_t>(zp1),
+        -static_cast<int32_t>(zp2),
+        out.mutable_data_ptr<int8_t>() + broadcast_offset,
+        static_cast<int32_t>(out_zp),
+        output_mult,
+        output_shift_val,
+        kInt8ActivationMin,
+        kInt8ActivationMax,
+        muls_per_loop);
+    if (status != ARM_CMSIS_NN_SUCCESS) {
+      ET_LOG(
+          Error,
+          "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
+          status);
+      context.fail(Error::Internal);
+      return out;
+    }
+  }
 
   return out;
 }
diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py
index fe175ca9783..291615f613a 100644
--- a/backends/cortex_m/ops/operators.py
+++ b/backends/cortex_m/ops/operators.py
@@ -11,6 +11,7 @@
 import torch
 import torch.nn.functional as F
 from executorch.backends.cortex_m.passes.passes_utils import (
+    is_channel_broadcast,
     requantize_cmsis,
     SHIFT_INT8,
 )
@@ -140,12 +141,15 @@ def quantized_add_meta(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
-    assert self.shape == other.shape, (
-        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+    assert self.shape == other.shape or is_channel_broadcast(self, other), (
+        "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — "
         f"got self.shape={self.shape}, other.shape={other.shape}"
     )
-    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
-    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
+    if self.numel() > other.numel():
+        output_tensor = self
+    else:
+        output_tensor = other
+    return torch.empty_like(output_tensor)
 
 
 @impl(lib, "quantized_add", "CompositeExplicitAutograd")
@@ -162,8 +166,8 @@ def quantized_add_impl(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
-    assert self.shape == other.shape, (
-        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+    assert self.shape == other.shape or is_channel_broadcast(self, other), (
+        "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — "
         f"got self.shape={self.shape}, other.shape={other.shape}"
     )
     self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
@@ -207,12 +211,15 @@ def quantized_mul_meta(
     output_shift: int,
 ) -> torch.Tensor:
     # Broadcast to output shape
-    assert self.shape == other.shape, (
-        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+    assert self.shape == other.shape or is_channel_broadcast(self, other), (
+        "Cortex-M quantized_mul: broadcasting is not yet supported except for channel dim — "
         f"got self.shape={self.shape}, other.shape={other.shape}"
     )
-    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
-    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
+    if self.numel() > other.numel():
+        output_tensor = self
+    else:
+        output_tensor = other
+    return torch.empty_like(output_tensor)
 
 
 @impl(lib, "quantized_mul", "CompositeExplicitAutograd")
@@ -228,8 +235,8 @@ def quantized_mul_impl(
     # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and
     # only uses the output multiplier/shift for rescaling. Mirror that here to
     # keep the composite implementation numerically aligned with the backend.
-    assert self.shape == other.shape, (
-        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+    assert self.shape == other.shape or is_channel_broadcast(self, other), (
+        "Cortex-M quantized_mul: broadcasting is not yet supported except for channel dim — "
         f"got self.shape={self.shape}, other.shape={other.shape}"
     )
     self_int = self.to(torch.int32) - self_zero_point
diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py
index e3ee596d7ea..131541fcb75 100644
--- a/backends/cortex_m/passes/passes_utils.py
+++ b/backends/cortex_m/passes/passes_utils.py
@@ -193,3 +193,34 @@ def cleanup_nodes(nodes_to_erase, graph):
         print(f"Warning: {len(failed_nodes)} nodes could not be erased")
 
     return failed_nodes
+
+
+def is_channels_last(tensor: torch.Tensor) -> bool:
+    """Check if a 4D tensor is in channels last format."""
+    if tensor.ndim != 4:
+        return False
+
+    if tensor.shape[1] == 1 or tensor.shape[2] == tensor.shape[3] == 1:
+        return True
+
+    dim_order = list(tensor.dim_order())
+    return dim_order[0:2] == [0, 2]
+
+
+def is_channel_broadcast(tensor1: torch.Tensor, tensor2: torch.Tensor) -> bool:
+    """
+    Check if one tensor is broadcast against the other along the channel dimension.
+    Requires 4D channels-last tensors with equal channels, one of them per-channel ([1, C, 1, 1]).
+    """
+    if tensor1.dim() != tensor2.dim():
+        return False
+    if not is_channels_last(tensor1):
+        return False
+    if not is_channels_last(tensor2):
+        return False
+
+    channel_match = tensor1.size(1) == tensor2.size(1)
+    tensor1_channels_only = tensor1.numel() == tensor1.size(1)
+    tensor2_channels_only = tensor2.numel() == tensor2.size(1)
+
+    return channel_match and (tensor1_channels_only or tensor2_channels_only)
diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py
index b86372fa360..185a39b9eae 100644
--- a/backends/cortex_m/quantizer/quantizer.py
+++ b/backends/cortex_m/quantizer/quantizer.py
@@ -11,6 +11,10 @@
 from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
+from executorch.backends.cortex_m.passes.passes_utils import (
+    is_channel_broadcast,
+    is_channels_last,
+)
 from executorch.backends.cortex_m.quantizer.operator_configs import (
     BINARY_OP_PATTERNS,
     CONV_OP_PATTERNS,
@@ -61,7 +65,9 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool:
         if len(node.all_input_nodes) == 2:
             t1 = get_first_fake_tensor(node.all_input_nodes[0])
             t2 = get_first_fake_tensor(node.all_input_nodes[1])
-            return t1.shape != t2.shape
+            return t1.shape != t2.shape and not (
+                is_channel_broadcast(t1, t2) and is_channels_last(t1)
+            )
 
         return False
 
@@ -78,7 +84,7 @@ def nchw_filter(self, node: Optional[Node]) -> bool:
         if tensor is None:
             return False
 
-        return not tensor.is_contiguous(memory_format=torch.channels_last)
+        return not is_channels_last(tensor)
 
     def __init__(self) -> None:
         quantizers: List[Quantizer] = [
diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py
index 15918394ada..ad5f276b544 100644
--- a/backends/cortex_m/test/ops/test_add.py
+++ b/backends/cortex_m/test/ops/test_add.py
@@ -121,6 +121,27 @@ class CortexMAlphaAdd(ModelAlpha):
             ramp_tensor(-5, 5, (1, 2, 1, 2)),
         ),
     ),
+    "broadcast_channels_1": McuTestCase(
+        CortexMTensorAdd(),
+        (
+            ramp_tensor(-2, 2, (1, 8, 1, 1)).to(memory_format=torch.channels_last),
+            ramp_tensor(-5, 5, (1, 8, 5, 5)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "broadcast_channels_2": McuTestCase(
+        CortexMTensorAdd(),
+        (
+            ramp_tensor(-5, 5, (2, 8, 5, 5)).to(memory_format=torch.channels_last),
+            ramp_tensor(-2, 2, (1, 8, 1, 1)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "broadcast_channels_contiguous": McuTestCase(
+        CortexMTensorAdd(),
+        (
+            ramp_tensor(-5, 5, (2, 8, 5, 5)),
+            ramp_tensor(-2, 2, (1, 8, 1, 1)),
+        ),
+    ),
     "alpha": McuTestCase(
         CortexMAlphaAdd(0.5),
         (
@@ -143,6 +164,7 @@ class CortexMAlphaAdd(ModelAlpha):
     "broadcast_1": "Broadcasting is not supported in Cortex-M backend",
     "broadcast_2": "Broadcasting is not supported in Cortex-M backend",
     "broadcast_3": "Broadcasting is not supported in Cortex-M backend",
+    "broadcast_channels_contiguous": "Broadcasting channels is not supported in contiguous memory_format in Cortex-M backend.",
 }
diff --git a/backends/cortex_m/test/ops/test_mul.py b/backends/cortex_m/test/ops/test_mul.py
index 7a1419e83d6..88dd904eb6e 100644
--- a/backends/cortex_m/test/ops/test_mul.py
+++ b/backends/cortex_m/test/ops/test_mul.py
@@ -103,6 +103,27 @@ class CortexMTensorMul(Model):
             ramp_tensor(-5, 5, (1, 2, 1, 2)),
         ),
     ),
+    "broadcast_channels_1": McuTestCase(
+        CortexMTensorMul(),
+        (
+            ramp_tensor(-2, 2, (1, 8, 1, 1)).to(memory_format=torch.channels_last),
+            ramp_tensor(-5, 5, (1, 8, 5, 5)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "broadcast_channels_2": McuTestCase(
+        CortexMTensorMul(),
+        (
+            ramp_tensor(-5, 5, (2, 8, 5, 5)).to(memory_format=torch.channels_last),
+            ramp_tensor(-2, 2, (1, 8, 1, 1)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "broadcast_channels_contiguous": McuTestCase(
+        CortexMTensorMul(),
+        (
+            ramp_tensor(-5, 5, (2, 8, 5, 5)),
+            ramp_tensor(-2, 2, (1, 8, 1, 1)),
+        ),
+    ),
 }
 
 
@@ -112,6 +133,7 @@ class CortexMTensorMul(Model):
     "broadcast_1": "Broadcasting is not supported in Cortex-M backend",
     "broadcast_2": "Broadcasting is not supported in Cortex-M backend",
     "broadcast_3": "Broadcasting is not supported in Cortex-M backend",
+    "broadcast_channels_contiguous": "Broadcasting channels is not supported in contiguous memory_format in Cortex-M backend.",
 }
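
Note (illustration only, not part of the patch): the broadcast_offset loops above rely on channels-last layout storing the C channel values of each (n, h, w) position contiguously, so the per-channel operand can simply be re-applied every adds_per_loop / muls_per_loop (= C) elements of the output buffer. The minimal torch sketch below mirrors that reasoning for the shapes used by the broadcast_channels_2 test case; it is a reference for the memory-layout argument, not backend code.

import torch

# Shapes from the "broadcast_channels_2" test case: (2, 8, 5, 5) op (1, 8, 1, 1).
big = torch.arange(2 * 8 * 5 * 5, dtype=torch.float32).reshape(2, 8, 5, 5)
big = big.to(memory_format=torch.channels_last)
per_channel = torch.arange(8, dtype=torch.float32).reshape(1, 8, 1, 1)

expected = big + per_channel  # ordinary broadcasting

# Channels-last memory order is N, H, W, C, so the flattened buffer is a run of
# C-length channel vectors. Re-applying the C-length per-channel buffer at every
# offset that is a multiple of C (adds_per_loop in the kernel loop) reproduces
# the broadcast.
C = big.size(1)
big_flat = big.permute(0, 2, 3, 1).reshape(-1)  # raw channels-last memory order
chan_flat = per_channel.reshape(-1)             # C contiguous channel values
out_flat = torch.empty_like(big_flat)
for offset in range(0, big_flat.numel(), C):
    out_flat[offset : offset + C] = big_flat[offset : offset + C] + chan_flat

assert torch.equal(out_flat.reshape(2, 5, 5, 8).permute(0, 3, 1, 2), expected)

The same indexing argument explains why the contiguous (NCHW) variant cannot reuse this loop, which is why the broadcast_channels_contiguous cases are listed as expected failures above.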