Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reduce sum onnx ops to burn imports #1723

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ represent the corresponding Burn Op.
| [ReduceMean][136] | ✅ | ✅ |
| [ReduceMin][137] | ❌ | ✅ |
| [ReduceProd][138] | ❌ | ✅ |
| [ReduceSum][139] | | ✅ |
| [ReduceSum][139] | | ✅ |
| [ReduceSumSquare][140] | ❌ | ❌ |
| [Relu][141] | ✅ | ✅ |
| [Reshape][142] | ✅ | ✅ |
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ fn main() {
.input("tests/prelu/prelu.onnx")
.input("tests/reduce_max/reduce_max.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
.input("tests/reduce_sum/reduce_sum_opset11.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/shape/shape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
Expand Down
34 changes: 34 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ include_models!(
recip,
reduce_max,
reduce_mean,
reduce_sum_opset13,
reduce_sum_opset11,
relu,
reshape,
shape,
Expand Down Expand Up @@ -545,6 +547,38 @@ mod tests {
assert_eq!(output_value.to_data(), expected);
}

#[test]
fn reduce_sum_opset11() {
let device = Default::default();
let model: reduce_sum_opset11::Model<Backend> = reduce_sum_opset11::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
let expected_scalar = Data::from([39.]);
let expected = Data::from([[[[39.]]]]);

assert_eq!(output_scalar.to_data(), expected_scalar);
assert_eq!(output_tensor.to_data(), input.to_data());
assert_eq!(output_value.to_data(), expected);
}

#[test]
fn reduce_sum_opset13() {
let device = Default::default();
let model: reduce_sum_opset13::Model<Backend> = reduce_sum_opset13::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
let expected_scalar = Data::from([39.]);
let expected = Data::from([[[[39.]]]]);

assert_eq!(output_scalar.to_data(), expected_scalar);
assert_eq!(output_tensor.to_data(), input.to_data());
assert_eq!(output_value.to_data(), expected);
}

#[test]
fn reshape() {
// Initialize the model without weights (because the exported file does not contain them)
Expand Down
46 changes: 46 additions & 0 deletions crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/reduce_sum/reduce_sum.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return (
# ReduceSum, keepdims=0, axes=None
torch.sum(x),
# ReduceSum, keepdims=1, axes=[1]
torch.sum(x, dim=1, keepdim=True),
# ReduceSum, keepdims=1, axes=[-1]
torch.sum(x, dim=-1, keepdim=True),
)


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)

torch.onnx.export(model, test_input, "reduce_sum_opset11.onnx", verbose=False, opset_version=11)
torch.onnx.export(model, test_input, "reduce_sum_opset13.onnx", verbose=False, opset_version=13)

print("Finished exporting model")

# Output some test data for use in the test
print(f"Test input data: {test_input}")
output = model.forward(*test_input)
print(f"Test output data: {output}")


if __name__ == "__main__":
main()
Binary file not shown.
Binary file not shown.
69 changes: 69 additions & 0 deletions crates/burn-import/src/burn/node/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub enum UnaryNodeKind {
Not,
ReduceMax,
ReduceMean,
ReduceSum,
Reciprocal,
Relu,
Shape,
Expand Down Expand Up @@ -62,6 +63,7 @@ impl UnaryNodeKind {
Self::Not => "not",
Self::ReduceMax => "reduce_max",
Self::ReduceMean => "reduce_mean",
Self::ReduceSum => "reduce_sum",
Self::Reciprocal => "reciprocal",
Self::Relu => "relu",
Self::Shape => "shape",
Expand Down Expand Up @@ -355,6 +357,36 @@ impl UnaryNode {
}
}

pub(crate) fn reduce_sum(input: Type, output: Type, dim: Option<usize>) -> Self {
if let Type::Tensor(ref tensor) = output {
if let Some(dim) = dim {
if tensor.kind == TensorKind::Bool {
// Sum is only implemented on numeric tensors
panic!("ReduceSum is not supported for boolean");
}

// ReduceSum, keepdims=1, axes=[dim]
let dim = dim.to_tokens();
Self::new(
input,
output,
UnaryNodeKind::ReduceSum,
Rc::new(move |input| quote! { #input.sum_dim(#dim) }),
)
} else {
// ReduceSum, keepdims=0, axes=None
Self::new(
input,
output,
UnaryNodeKind::ReduceSum,
Rc::new(move |input| quote! { #input.sum() }),
)
}
} else {
panic!("ReduceSum only supports tensor output");
}
}

pub(crate) fn shape(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self {
// Shape as defined by the ONNX op should return a tensor because other ops
// (e.g., Gather) will be used on a tensor
Expand Down Expand Up @@ -634,6 +666,43 @@ mod tests {
);
}

#[test]
fn test_unary_codegen_reduce_sum() {
one_node_graph(
UnaryNode::reduce_sum(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
Some(1),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.sum_dim(1);

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);

one_node_graph(
UnaryNode::reduce_sum(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 1)),
None,
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
let tensor2 = tensor1.sum();

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}

#[test]
fn test_unary_codegen_reciprocal() {
one_node_graph(
Expand Down
39 changes: 39 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Reciprocal => same_as_input(node),
NodeType::ReduceMax => reduce_max_update_outputs(node),
NodeType::ReduceMean => reduce_mean_update_outputs(node),
NodeType::ReduceSum => reduce_sum_update_outputs(node),
NodeType::Relu => same_as_input(node),
NodeType::Reshape => reshape_update_outputs(node),
NodeType::Shape => shape_update_outputs(node),
Expand Down Expand Up @@ -461,6 +462,44 @@ fn reduce_max_update_outputs(node: &mut Node) {
}
}

/// Infers the shape of a ReduceSum node and replaces the shape of the output tensor.
fn reduce_sum_update_outputs(node: &mut Node) {
let node_input = &mut node.inputs[0];
let tensor = match node_input.clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

let dim_only = match node.attrs.get("axes") {
Some(value) => match &value {
AttributeValue::Int64(_) => true,
AttributeValue::Int64s(ints) => ints.len() == 1,
_ => false,
},
None => false,
};

let dim_only = match node.inputs.get(1).and_then(|arg| arg.value.as_ref()) {
Some(value) => match &value {
Data::Int64(_) => true,
Data::Int64s(ints) => ints.len() == 1,
_ => false,
},
None => dim_only,
};

if dim_only {
node.outputs[0].ty = ArgType::Tensor(tensor);
} else {
// NOTE: ReduceSum w/o keepdims reduces to a scalar value, but Burn doesn't have
// 0-dim tensor so we can't track or perform other ops on that value if we call
// `.into_scalar()` on the result of `tensor.sum()`
// node.outputs[0].ty = ArgType::Scalar(tensor.elem_type);
// Instead, we return a tensor of rank 1 (the result of `tensor.sum()`)
node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
}
}

fn where_update_outputs(node: &mut Node) {
match (
node.inputs[0].ty.clone(),
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/from_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ use super::ir::{ArgType, Argument, Node, NodeType};

use protobuf::Message;

const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 8] = [
NodeType::BatchNormalization,
NodeType::Clip,
NodeType::Conv1d,
NodeType::Conv2d,
NodeType::Dropout,
NodeType::Reshape,
NodeType::Unsqueeze,
NodeType::ReduceSum,
];

#[derive(Debug)]
Expand Down
54 changes: 54 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,60 @@ pub fn reduce_mean_config(node: &Node) -> Option<usize> {
}
}

pub fn reduce_sum_config(node: &Node) -> Option<usize> {
let mut axes = Vec::new();
let mut keepdims = 1;

let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Extract the attributes
for (key, value) in node.attrs.iter() {
match key.as_str() {
"keepdims" => keepdims = value.clone().into_i64(),
"axes" => axes = value.clone().into_i64s(),
// TODO: handle noop_with_empty_axes
_ => {}
}
}

// TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode.
if let Some(value) = node
.inputs
.get(1)
.and_then(|argument| argument.value.as_ref())
{
axes = value.clone().into_i64s();
}

if axes.len() > 1 {
panic!("ReduceMean: reducing on multiple dimensions is not supported")
}

if axes.is_empty() && keepdims == 1 {
panic!("ReduceMean: axes must be provided with keepdims")
}

if !axes.is_empty() && keepdims == 0 {
// Not supported in Burn
panic!("ReduceMean: the reduce operation must preserve the reduced dimension")
}

if axes.is_empty() {
None
} else {
let mut dim = axes[0];

if dim < 0 {
// Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
dim += tensor.dim as i64;
}
Some(dim as usize)
}
}

pub fn shape_config(curr: &Node) -> (usize, usize) {
if curr.inputs.len() != 1 {
panic!(
Expand Down
9 changes: 9 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ impl OnnxGraph {
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
NodeType::Shape => graph.register(Self::shape_conversion(node)),
Expand Down Expand Up @@ -501,6 +502,14 @@ impl OnnxGraph {
UnaryNode::reduce_mean(input, output, dim)
}

fn reduce_sum_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let dim = reduce_sum_config(&node);

UnaryNode::reduce_sum(input, output, dim)
}

fn shape_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
Expand Down
Loading