37 changes: 37 additions & 0 deletions backends/cortex_m/ops/cortex_m_ops_common.h
@@ -129,6 +129,43 @@ inline void validate_quantization_params(
"Single quant Output");
}

inline bool is_channels_last_tensor(const Tensor& tensor) {
if (tensor.dim() != 4) {
return false;
}

// When channels or spatial dims are 1, the layout information is ambiguous; treat such tensors as channels last.
if (tensor.size(1) == 1 || (tensor.size(2) == 1 && tensor.size(3) == 1)) {
return true;
}

constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = {
0, 2, 3, 1};
executorch::aten::ArrayRef<executorch::aten::DimOrderType>
channels_last_order(kChannelsLastDimOrder, 4);

return tensor.dim_order() == channels_last_order;
}

inline bool is_channel_broadcast(const Tensor& tensor1, const Tensor& tensor2) {
if (tensor1.dim() != tensor2.dim()) {
return false;
}

if (tensor1.dim() != 4) {
return false;
}

if (tensor1.size(1) != tensor2.size(1)) {
return false;
}

const bool tensor1_channels_only = tensor1.numel() == tensor1.size(1);
const bool tensor2_channels_only = tensor2.numel() == tensor2.size(1);

return tensor1_channels_only || tensor2_channels_only;
}

// Refer to CMSIS-NN 'arm_nn_requantize' implementation for details:
// https://github.com/ARM-software/CMSIS-NN/blob/main/Include/arm_nnsupportfunctions.h#L1625
// multiplier: Range {ARM_NN_Q31_MIN + 1, Q32_MAX}
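For reference, a small sketch (not part of this diff; tensor name and shape invented) of what the dim-order comparison above corresponds to on the PyTorch side:

```python
# Sketch only: a 4D tensor in torch.channels_last memory format reports the
# dim order the new C++ helper compares against, i.e. (0, 2, 3, 1).
import torch

t = torch.empty(2, 3, 5, 7).to(memory_format=torch.channels_last)
print(t.dim_order())  # (0, 2, 3, 1)
```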
81 changes: 53 additions & 28 deletions backends/cortex_m/ops/op_quantized_add.cpp
@@ -33,7 +33,14 @@ Tensor& quantized_add_out(
const Scalar& output_shift,
Tensor& out) {
// Validate tensor types and dim order
validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8);
validate_cmsis_nn_tensor_requirements(
input1_int8,
input2_int8,
out,
ScalarType::Char,
/*require_channels_last=*/channel_broadcast,
/*require_same_sizes=*/!channel_broadcast);

// Validate quantization parameters
validate_quantization_params(
@@ -62,6 +69,8 @@ Tensor& quantized_add_out(
int32_t out_zp = extractScalarToInt32(output_zero_point);
int32_t output_mult = extractScalarToInt32(output_multiplier);
int output_shift_val = extractScalarToInt(output_shift);
int8_t* input1_ptr = input1_int8.data_ptr<int8_t>();
int8_t* input2_ptr = input2_int8.data_ptr<int8_t>();

// Left shift to maximize precision
const int32_t left_shift = 20;
@@ -87,33 +96,49 @@
// addition. To preserve precision when rescaling the inputs, they are first
// upscaled as much as possible; hence the left_shift parameter required here.

// Call CMSIS-NN kernel with precomputed parameters
arm_cmsis_nn_status status = arm_elementwise_add_s8(
input1_int8.const_data_ptr<int8_t>(),
input2_int8.const_data_ptr<int8_t>(),
-static_cast<int32_t>(zp1),
input1_mult,
input1_shift_val,
-static_cast<int32_t>(zp2),
input2_mult,
input2_shift_val,
left_shift,
out.mutable_data_ptr<int8_t>(),
static_cast<int32_t>(out_zp),
output_mult,
output_shift_val,
activation_min,
activation_max,
static_cast<int32_t>(out.numel()));

if (status != ARM_CMSIS_NN_SUCCESS) {
ET_LOG(
Error,
"quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
status);

context.fail(Error::Internal); // Fail the execution context
return out;
int32_t adds_per_loop = 0;
if (channel_broadcast) {
if (input1_int8.numel() < input2_int8.numel()) {
std::swap<int32_t>(zp1, zp2);
std::swap<int32_t>(input1_mult, input2_mult);
std::swap<int>(input1_shift_val, input2_shift_val);
std::swap<int8_t*>(input1_ptr, input2_ptr);
}
adds_per_loop = input1_int8.size(1);
} else {
adds_per_loop = out.numel();
}

for (int32_t broadcast_offset = 0; broadcast_offset < out.numel();
broadcast_offset += adds_per_loop) {
// Call CMSIS-NN kernel with precomputed parameters
arm_cmsis_nn_status status = arm_elementwise_add_s8(
input1_ptr + broadcast_offset,
input2_ptr,
-static_cast<int32_t>(zp1),
input1_mult,
input1_shift_val,
-static_cast<int32_t>(zp2),
input2_mult,
input2_shift_val,
left_shift,
out.mutable_data_ptr<int8_t>() + broadcast_offset,
static_cast<int32_t>(out_zp),
output_mult,
output_shift_val,
activation_min,
activation_max,
adds_per_loop);

if (status != ARM_CMSIS_NN_SUCCESS) {
ET_LOG(
Error,
"quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
status);

context.fail(Error::Internal); // Fail the execution context
return out;
}
}
ET_LOG(
Info,
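To illustrate why the loop above can advance only the larger operand's pointer while reusing the per-channel operand, here is a rough PyTorch sketch of the same tiling (shapes invented; the CMSIS-NN requantization arithmetic is omitted):

```python
# Sketch only: in channels-last layout, every contiguous run of C elements of the
# larger operand corresponds to one (n, h, w) position, so the per-channel operand
# can be re-applied unchanged to each "adds_per_loop"-sized chunk.
import torch

N, C, H, W = 2, 8, 4, 4
big = torch.arange(N * C * H * W, dtype=torch.int32).reshape(N, C, H, W)
small = torch.arange(C, dtype=torch.int32).reshape(1, C, 1, 1)

chunks = big.permute(0, 2, 3, 1).reshape(-1, C)   # one row per (n, h, w) position
out = (chunks + small.reshape(C)).reshape(N, H, W, C).permute(0, 3, 1, 2)

assert torch.equal(out, big + small)  # matches PyTorch broadcasting semantics
```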
74 changes: 49 additions & 25 deletions backends/cortex_m/ops/op_quantized_mul.cpp
@@ -34,7 +34,15 @@ Tensor& quantized_mul_out(
const Scalar& output_shift,
Tensor& out) {
// Validate tensor types and quantization parameters
validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);

bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8);
validate_cmsis_nn_tensor_requirements(
input1_int8,
input2_int8,
out,
ScalarType::Char,
/*require_channels_last=*/channel_broadcast,
/*require_same_sizes=*/!channel_broadcast);

const Scalar kIdentityMultiplier(/*value=*/1);
const Scalar kZeroShift(/*value=*/0);
@@ -51,12 +59,26 @@
out);

// Extract quantization parameters
const int32_t zp1 = extractScalarToInt32(input1_zero_point);
const int32_t zp2 = extractScalarToInt32(input2_zero_point);
int8_t* input1_ptr = input1_int8.data_ptr<int8_t>();
int8_t* input2_ptr = input2_int8.data_ptr<int8_t>();
int32_t zp1 = extractScalarToInt32(input1_zero_point);
int32_t zp2 = extractScalarToInt32(input2_zero_point);
const int32_t out_zp = extractScalarToInt32(output_zero_point);
const int32_t output_mult = extractScalarToInt32(output_multiplier);
const int32_t output_shift_val = extractScalarToInt32(output_shift);

int32_t muls_per_loop = 0;

if (channel_broadcast) {
if (input1_int8.numel() < input2_int8.numel()) {
std::swap<int32_t>(zp1, zp2);
std::swap<int8_t*>(input1_ptr, input2_ptr);
}

muls_per_loop = input1_int8.size(1);
} else {
muls_per_loop = out.numel();
}
// Note 1: The CMSIS-NN kernel implementation uses offsets which are always
// added to the data, whereas zero_points are subtracted when dequantizing
// (for the inputs) and added when quantizing (for the output). Hence the
@@ -72,29 +94,31 @@ Tensor& quantized_mul_out(
// effective_scale = (scale_in1 * scale_in2 / scale_out)
// Hence no input quantization params required here.

// Call CMSIS-NN elementwise multiply kernel
arm_cmsis_nn_status status = arm_elementwise_mul_s8(
input1_int8.const_data_ptr<int8_t>(),
input2_int8.const_data_ptr<int8_t>(),
-static_cast<int32_t>(zp1),
-static_cast<int32_t>(zp2),
out.mutable_data_ptr<int8_t>(),
static_cast<int32_t>(out_zp),
output_mult,
output_shift_val,
kInt8ActivationMin,
kInt8ActivationMax,
static_cast<int32_t>(out.numel()));

if (status != ARM_CMSIS_NN_SUCCESS) {
ET_LOG(
Error,
"quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
status);
context.fail(Error::Internal);
return out;
}
for (int32_t broadcast_offset = 0; broadcast_offset < out.numel();
broadcast_offset += muls_per_loop) {
// Call CMSIS-NN elementwise multiply kernel
arm_cmsis_nn_status status = arm_elementwise_mul_s8(
input1_ptr + broadcast_offset,
input2_ptr,
-static_cast<int32_t>(zp1),
-static_cast<int32_t>(zp2),
out.mutable_data_ptr<int8_t>() + broadcast_offset,
static_cast<int32_t>(out_zp),
output_mult,
output_shift_val,
kInt8ActivationMin,
kInt8ActivationMax,
muls_per_loop);

if (status != ARM_CMSIS_NN_SUCCESS) {
ET_LOG(
Error,
"quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
status);
context.fail(Error::Internal);
return out;
}
}
return out;
}

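As a quick check of the effective-scale note above, a numeric sketch (all values invented, not taken from the kernel):

```python
# Sketch only: the dequantized product is (s1*(q1-zp1)) * (s2*(q2-zp2))
#            = (s1*s2) * (q1-zp1) * (q2-zp2),
# so requantizing to the output grid needs a single scale, s1*s2/s_out,
# which is what output_multiplier / output_shift encode.
s1, s2, s_out = 0.02, 0.05, 0.1
q1, zp1 = 57, -3
q2, zp2 = 10, 4

real_product = (s1 * (q1 - zp1)) * (s2 * (q2 - zp2))
via_effective_scale = (q1 - zp1) * (q2 - zp2) * (s1 * s2 / s_out)

assert round(real_product / s_out) == round(via_effective_scale)  # same requantized value
```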
31 changes: 19 additions & 12 deletions backends/cortex_m/ops/operators.py
@@ -11,6 +11,7 @@
import torch
import torch.nn.functional as F
from executorch.backends.cortex_m.passes.passes_utils import (
is_channel_broadcast,
requantize_cmsis,
SHIFT_INT8,
)
@@ -140,12 +141,15 @@ def quantized_add_meta(
output_multiplier: int,
output_shift: int,
) -> torch.Tensor:
assert self.shape == other.shape, (
"Cortex-M quantized_mul: broadcasting is not yet supported — "
assert self.shape == other.shape or is_channel_broadcast(self, other), (
"Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — "
f"got self.shape={self.shape}, other.shape={other.shape}"
)
broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
if self.numel() > other.numel():
output_tensor = self
else:
output_tensor = other
return torch.empty_like(output_tensor)


@impl(lib, "quantized_add", "CompositeExplicitAutograd")
@@ -162,8 +166,8 @@ def quantized_add_impl(
output_multiplier: int,
output_shift: int,
) -> torch.Tensor:
assert self.shape == other.shape, (
"Cortex-M quantized_mul: broadcasting is not yet supported — "
assert self.shape == other.shape or is_channel_broadcast(self, other), (
"Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — "
f"got self.shape={self.shape}, other.shape={other.shape}"
)
self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
@@ -207,12 +211,15 @@ def quantized_mul_meta(
output_shift: int,
) -> torch.Tensor:
# Broadcast to output shape
assert self.shape == other.shape, (
"Cortex-M quantized_mul: broadcasting is not yet supported — "
assert self.shape == other.shape or is_channel_broadcast(self, other), (
"Cortex-M quantized_mul: broadcasting is not yet supported except for channel dim — "
f"got self.shape={self.shape}, other.shape={other.shape}"
)
broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
if self.numel() > other.numel():
output_tensor = self
else:
output_tensor = other
return torch.empty_like(output_tensor)


@impl(lib, "quantized_mul", "CompositeExplicitAutograd")
@@ -228,8 +235,8 @@ def quantized_mul_impl(
# CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and
# only uses the output multiplier/shift for rescaling. Mirror that here to
# keep the composite implementation numerically aligned with the backend.
assert self.shape == other.shape, (
"Cortex-M quantized_mul: broadcasting is not yet supported — "
assert self.shape == other.shape or is_channel_broadcast(self, other), (
"Cortex-M quantized_mul: broadcasting is not yet supported except for channel dim — "
f"got self.shape={self.shape}, other.shape={other.shape}"
)
self_int = self.to(torch.int32) - self_zero_point
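A small sanity sketch (not part of the diff; shapes invented) of the new meta behaviour: for the allowed channel-broadcast case the larger operand's shape is exactly the broadcast shape, and empty_like additionally preserves its dim order:

```python
# Sketch only: mirrors the numel() comparison used in the updated
# quantized_add_meta / quantized_mul_meta functions.
import torch

a = torch.empty(2, 8, 4, 4, dtype=torch.int8).to(memory_format=torch.channels_last)
b = torch.empty(1, 8, 1, 1, dtype=torch.int8)

larger = a if a.numel() > b.numel() else b
assert larger.shape == torch.broadcast_shapes(a.shape, b.shape)
assert torch.empty_like(larger).dim_order() == larger.dim_order()
```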
31 changes: 31 additions & 0 deletions backends/cortex_m/passes/passes_utils.py
@@ -193,3 +193,34 @@ def cleanup_nodes(nodes_to_erase, graph):
print(f"Warning: {len(failed_nodes)} nodes could not be erased")

return failed_nodes


def is_channels_last(tensor: torch.Tensor) -> bool:
"""Check if a 4D tensor is in channels last format."""
if tensor.ndim != 4:
return False

if tensor.shape[1] == 1 or tensor.shape[2] == tensor.shape[3] == 1:
return True

dim_order = list(tensor.dim_order())
return dim_order[0:2] == [0, 2]


def is_channel_broadcast(tensor1: torch.Tensor, tensor2: torch.Tensor) -> bool:
"""
Check if tensor1 is broadcasted to tensor2 along channel dimension.
Assumes tensor2 has shape [N, C, ...] and tensor1 has shape [N, 1, ...] or [1, C, ...].
"""
if tensor1.dim() != tensor2.dim():
return False
if not is_channels_last(tensor1):
return False
if not is_channels_last(tensor2):
return False

channel_match = tensor1.size(1) == tensor2.size(1)
tensor1_channels_only = tensor1.numel() == tensor1.size(1)
tensor2_channels_only = tensor2.numel() == tensor2.size(1)

return channel_match and (tensor1_channels_only or tensor2_channels_only)
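For readers of the two helpers above, a brief usage sketch (not part of the diff; shapes invented) of which operand pairs qualify as a channel broadcast:

```python
# Sketch only: exercising the new passes_utils helpers.
import torch
from executorch.backends.cortex_m.passes.passes_utils import (
    is_channel_broadcast,
    is_channels_last,
)

full = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)
per_channel = torch.randn(1, 8, 1, 1)  # numel == C; layout ambiguous, treated as channels last

assert is_channels_last(full)
assert is_channel_broadcast(full, per_channel)                  # one operand is a per-channel vector
assert not is_channel_broadcast(full, torch.randn_like(full))   # same shape, no per-channel operand
assert not is_channel_broadcast(full, torch.randn(1, 1, 4, 4))  # channel sizes differ
```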
10 changes: 8 additions & 2 deletions backends/cortex_m/quantizer/quantizer.py
@@ -11,6 +11,10 @@
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
from executorch.backends.cortex_m.passes.passes_utils import (
is_channel_broadcast,
is_channels_last,
)
from executorch.backends.cortex_m.quantizer.operator_configs import (
BINARY_OP_PATTERNS,
CONV_OP_PATTERNS,
@@ -61,7 +65,9 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool:
if len(node.all_input_nodes) == 2:
t1 = get_first_fake_tensor(node.all_input_nodes[0])
t2 = get_first_fake_tensor(node.all_input_nodes[1])
return t1.shape != t2.shape
return t1.shape != t2.shape and not (
is_channel_broadcast(t1, t2) and is_channels_last(t1)
)

return False

@@ -78,7 +84,7 @@ def nchw_filter(self, node: Optional[Node]) -> bool:
if tensor is None:
return False

return not tensor.is_contiguous(memory_format=torch.channels_last)
return not is_channels_last(tensor)

def __init__(self) -> None:
quantizers: List[Quantizer] = [
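To illustrate the updated filter condition, a standalone sketch (the helper name `would_filter_out` and the shapes are made up) of which input pairs are still rejected:

```python
# Sketch only: mirrors the condition in broadcasting_filter; mismatched shapes are
# still filtered out unless they form a channels-last channel broadcast.
import torch
from executorch.backends.cortex_m.passes.passes_utils import (
    is_channel_broadcast,
    is_channels_last,
)

def would_filter_out(t1: torch.Tensor, t2: torch.Tensor) -> bool:
    return t1.shape != t2.shape and not (
        is_channel_broadcast(t1, t2) and is_channels_last(t1)
    )

nchw = torch.randn(2, 8, 4, 4)
nhwc = nchw.to(memory_format=torch.channels_last)
per_channel = torch.randn(1, 8, 1, 1)

assert not would_filter_out(nhwc, per_channel)  # channels-last channel broadcast: now allowed
assert would_filter_out(nchw, per_channel)      # NCHW operand: still filtered out
```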