From 4da8d3dea6f1b0151bd0372f445ec60afe6d6c03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Blomstr=C3=B6m?= Date: Sun, 5 May 2024 14:21:17 +0200 Subject: [PATCH] Add reduce sum onnx ops to burn imports --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- crates/burn-import/onnx-tests/build.rs | 2 + .../onnx-tests/tests/onnx_tests.rs | 34 +++++++++ .../onnx-tests/tests/reduce_sum/reduce_sum.py | 46 ++++++++++++ .../tests/reduce_sum/reduce_sum_opset11.onnx | Bin 0 -> 384 bytes .../tests/reduce_sum/reduce_sum_opset13.onnx | Bin 0 -> 529 bytes crates/burn-import/src/burn/node/unary.rs | 69 ++++++++++++++++++ crates/burn-import/src/onnx/dim_inference.rs | 39 ++++++++++ crates/burn-import/src/onnx/from_onnx.rs | 3 +- .../burn-import/src/onnx/op_configuration.rs | 54 ++++++++++++++ crates/burn-import/src/onnx/to_burn.rs | 9 +++ 11 files changed, 256 insertions(+), 2 deletions(-) create mode 100755 crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py create mode 100644 crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset11.onnx create mode 100644 crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset13.onnx diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index c884bee82e..0b39bd6ce1 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -143,7 +143,7 @@ represent the corresponding Burn Op. 
| [ReduceMean][136] | ✅ | ✅ | | [ReduceMin][137] | ❌ | ✅ | | [ReduceProd][138] | ❌ | ✅ | -| [ReduceSum][139] | ❌ | ✅ | +| [ReduceSum][139] | ✅ | ✅ | | [ReduceSumSquare][140] | ❌ | ❌ | | [Relu][141] | ✅ | ✅ | | [Reshape][142] | ✅ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 3a2c9b685d..89f99ffb2c 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -42,6 +42,8 @@ fn main() { .input("tests/prelu/prelu.onnx") .input("tests/reduce_max/reduce_max.onnx") .input("tests/reduce_mean/reduce_mean.onnx") + .input("tests/reduce_sum/reduce_sum_opset13.onnx") + .input("tests/reduce_sum/reduce_sum_opset11.onnx") .input("tests/reshape/reshape.onnx") .input("tests/shape/shape.onnx") .input("tests/sigmoid/sigmoid.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index c237542d41..de8eee7921 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -51,6 +51,8 @@ include_models!( recip, reduce_max, reduce_mean, + reduce_sum_opset13, + reduce_sum_opset11, relu, reshape, shape, @@ -545,6 +547,38 @@ mod tests { assert_eq!(output_value.to_data(), expected); } + #[test] + fn reduce_sum_opset11() { + let device = Default::default(); + let model: reduce_sum_opset11::Model = reduce_sum_opset11::Model::new(&device); + + // Run the model + let input = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device); + let (output_scalar, output_tensor, output_value) = model.forward(input.clone()); + let expected_scalar = Data::from([39.]); + let expected = Data::from([[[[39.]]]]); + + assert_eq!(output_scalar.to_data(), expected_scalar); + assert_eq!(output_tensor.to_data(), input.to_data()); + assert_eq!(output_value.to_data(), expected); + } + + #[test] + fn reduce_sum_opset13() { + let device = Default::default(); + let model: reduce_sum_opset13::Model = 
reduce_sum_opset13::Model::new(&device); + + // Run the model + let input = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device); + let (output_scalar, output_tensor, output_value) = model.forward(input.clone()); + let expected_scalar = Data::from([39.]); + let expected = Data::from([[[[39.]]]]); + + assert_eq!(output_scalar.to_data(), expected_scalar); + assert_eq!(output_tensor.to_data(), input.to_data()); + assert_eq!(output_value.to_data(), expected); + } + #[test] fn reshape() { // Initialize the model without weights (because the exported file does not contain them) diff --git a/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py b/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py new file mode 100755 index 0000000000..7bb83e6b90 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/reduce_sum/reduce_sum.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + return ( + # ReduceSum, keepdims=0, axes=None + torch.sum(x), + # ReduceSum, keepdims=1, axes=[1] + torch.sum(x, dim=1, keepdim=True), + # ReduceSum, keepdims=1, axes=[-1] + torch.sum(x, dim=-1, keepdim=True), + ) + + +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device) + + torch.onnx.export(model, test_input, "reduce_sum_opset11.onnx", verbose=False, opset_version=11) + torch.onnx.export(model, test_input, "reduce_sum_opset13.onnx", verbose=False, opset_version=13) + + print("Finished exporting model") + + # Output some test data for use in the test + print(f"Test input data: {test_input}") + output = model.forward(*test_input) + print(f"Test output data: {output}") + + +if __name__ == 
"__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset11.onnx b/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset11.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cb9f5773b8dc2f6678fe005fe87a97e9f791a706 GIT binary patch literal 384 zcmd;J6Jjr@EXglQ&X8g?(lgXEw0h3OWyd9$pO;r*Wfhc~Qkt9^T$&qiAjD`W#ib7y zP~wC$wfMO>vQtwFQZjRkB^VYkGI9B0)o&!lgU}yuh-?Bk7fWJAYOw?30!DTevQtwFQZjRkB^VYkGI6;IVdyuK;&RTlD0$fZSj7XRziJObjP>79-1t`go VB+bQWj8i30(iDdzqZ5+=F91rzh?f8W literal 0 HcmV?d00001 diff --git a/crates/burn-import/src/burn/node/unary.rs b/crates/burn-import/src/burn/node/unary.rs index c46da67861..d4433d09b1 100644 --- a/crates/burn-import/src/burn/node/unary.rs +++ b/crates/burn-import/src/burn/node/unary.rs @@ -34,6 +34,7 @@ pub enum UnaryNodeKind { Not, ReduceMax, ReduceMean, + ReduceSum, Reciprocal, Relu, Shape, @@ -62,6 +63,7 @@ impl UnaryNodeKind { Self::Not => "not", Self::ReduceMax => "reduce_max", Self::ReduceMean => "reduce_mean", + Self::ReduceSum => "reduce_sum", Self::Reciprocal => "reciprocal", Self::Relu => "relu", Self::Shape => "shape", @@ -355,6 +357,36 @@ impl UnaryNode { } } + pub(crate) fn reduce_sum(input: Type, output: Type, dim: Option) -> Self { + if let Type::Tensor(ref tensor) = output { + if let Some(dim) = dim { + if tensor.kind == TensorKind::Bool { + // Sum is only implemented on numeric tensors + panic!("ReduceSum is not supported for boolean"); + } + + // ReduceSum, keepdims=1, axes=[dim] + let dim = dim.to_tokens(); + Self::new( + input, + output, + UnaryNodeKind::ReduceSum, + Rc::new(move |input| quote! { #input.sum_dim(#dim) }), + ) + } else { + // ReduceSum, keepdims=0, axes=None + Self::new( + input, + output, + UnaryNodeKind::ReduceSum, + Rc::new(move |input| quote! 
{ #input.sum() }), + ) + } + } else { + panic!("ReduceSum only supports tensor output"); + } + } + pub(crate) fn shape(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self { // Shape as defined by the ONNX op should return a tensor because other ops // (e.g., Gather) will be used on a tensor @@ -634,6 +666,43 @@ mod tests { ); } + #[test] + fn test_unary_codegen_reduce_sum() { + one_node_graph( + UnaryNode::reduce_sum( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + Some(1), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.sum_dim(1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + + one_node_graph( + UnaryNode::reduce_sum( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 1)), + None, + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.sum(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + #[test] fn test_unary_codegen_reciprocal() { one_node_graph( diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index bc5f75a662..a93e6b6619 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -45,6 +45,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { NodeType::Reciprocal => same_as_input(node), NodeType::ReduceMax => reduce_max_update_outputs(node), NodeType::ReduceMean => reduce_mean_update_outputs(node), + NodeType::ReduceSum => reduce_sum_update_outputs(node), NodeType::Relu => same_as_input(node), NodeType::Reshape => reshape_update_outputs(node), NodeType::Shape => shape_update_outputs(node), @@ -461,6 +462,44 @@ fn reduce_max_update_outputs(node: &mut Node) { } } +/// Infers the shape of a ReduceSum node and replaces the shape of the output 
tensor. +fn reduce_sum_update_outputs(node: &mut Node) { + let node_input = &mut node.inputs[0]; + let tensor = match node_input.clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let dim_only = match node.inputs.get(1).and_then(|arg| arg.value.as_ref()) { + Some(value) => match &value { + Data::Int64(_) => true, + Data::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => dim_only, + }; + + if dim_only { + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + // NOTE: ReduceSum w/o keepdims reduces to a scalar value, but Burn doesn't have + // 0-dim tensor so we can't track or perform other ops on that value if we call + // `.into_scalar()` on the result of `tensor.sum()` + // node.outputs[0].ty = ArgType::Scalar(tensor.elem_type); + // Instead, we return a tensor of rank 1 (the result of `tensor.sum()`) + node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor }); + } +} + fn where_update_outputs(node: &mut Node) { match ( node.inputs[0].ty.clone(), diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/burn-import/src/onnx/from_onnx.rs index 0a38722aec..32b4e2b634 100644 --- a/crates/burn-import/src/onnx/from_onnx.rs +++ b/crates/burn-import/src/onnx/from_onnx.rs @@ -17,7 +17,7 @@ use super::ir::{ArgType, Argument, Node, NodeType}; use protobuf::Message; -const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [ +const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 8] = [ NodeType::BatchNormalization, NodeType::Clip, NodeType::Conv1d, @@ -25,6 +25,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [ NodeType::Dropout, NodeType::Reshape, NodeType::Unsqueeze, + NodeType::ReduceSum, ]; #[derive(Debug)] diff --git a/crates/burn-import/src/onnx/op_configuration.rs 
b/crates/burn-import/src/onnx/op_configuration.rs index 1416e55230..ebfa2cfd00 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -798,6 +798,60 @@ pub fn reduce_mean_config(node: &Node) -> Option { } } +pub fn reduce_sum_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "keepdims" => keepdims = value.clone().into_i64(), + "axes" => axes = value.clone().into_i64s(), + // TODO: handle noop_with_empty_axes + _ => {} + } + } + + // TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode. + if let Some(value) = node + .inputs + .get(1) + .and_then(|argument| argument.value.as_ref()) + { + axes = value.clone().into_i64s(); + } + + if axes.len() > 1 { + panic!("ReduceSum: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceSum: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceSum: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.dim as i64; + } + Some(dim as usize) + } +} + pub fn shape_config(curr: &Node) -> (usize, usize) { if curr.inputs.len() != 1 { panic!( diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 44b328492a..dfec530448 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -265,6 +265,7 @@ impl OnnxGraph { NodeType::Constant =>
graph.register(Self::constant_conversion::(node)), NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)), NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)), + NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)), NodeType::Reshape => graph.register(Self::reshape_conversion(node)), NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)), NodeType::Shape => graph.register(Self::shape_conversion(node)), @@ -501,6 +502,14 @@ impl OnnxGraph { UnaryNode::reduce_mean(input, output, dim) } + fn reduce_sum_conversion(node: Node) -> UnaryNode { + let input = node.inputs.first().unwrap().to_type(); + let output = node.outputs.first().unwrap().to_type(); + let dim = reduce_sum_config(&node); + + UnaryNode::reduce_sum(input, output, dim) + } + fn shape_conversion(node: Node) -> UnaryNode { let input = node.inputs.first().unwrap().to_type(); let output = node.outputs.first().unwrap().to_type();