From 285997620fbe559007f02c774566faef620b73e7 Mon Sep 17 00:00:00 2001 From: JC Date: Mon, 8 Jul 2024 18:44:08 -0400 Subject: [PATCH 1/4] Implement ONNX pad --- crates/burn-import/onnx-tests/build.rs | 1 + .../onnx-tests/tests/onnx_tests.rs | 21 +++ .../burn-import/onnx-tests/tests/pad/pad.onnx | Bin 0 -> 215 bytes .../burn-import/onnx-tests/tests/pad/pad.py | 158 ++++++++++++++++++ crates/burn-import/src/burn/node/base.rs | 5 +- crates/burn-import/src/burn/node/mod.rs | 1 + crates/burn-import/src/burn/node/pad.rs | 104 ++++++++++++ .../burn-import/src/onnx/op_configuration.rs | 68 +++++++- crates/burn-import/src/onnx/to_burn.rs | 12 +- crates/burn-tensor/src/tensor/api/numeric.rs | 2 +- crates/onnx-ir/src/dim_inference.rs | 1 + 11 files changed, 369 insertions(+), 4 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/pad/pad.onnx create mode 100755 crates/burn-import/onnx-tests/tests/pad/pad.py create mode 100644 crates/burn-import/src/burn/node/pad.rs diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 497400369f..f10ee38619 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -43,6 +43,7 @@ fn main() { .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") .input("tests/not/not.onnx") + .input("tests/pad/pad.onnx") .input("tests/expand/expand.onnx") .input("tests/greater/greater.onnx") .input("tests/greater_or_equal/greater_or_equal.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index d5cc470f06..5cafd3e4cb 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -55,6 +55,7 @@ include_models!( mul, neg, not, + pad, greater, greater_or_equal, less, @@ -1406,6 +1407,26 @@ mod tests { output.assert_eq(&expected, true); } + #[test] + fn pad() { + let device = Default::default(); + let model: pad::Model = 
pad::Model::new(&device); + + let input = Tensor::::from_floats([[1., 2.], [3., 4.], [5., 6.]], &device); + let output = model.forward(input).to_data(); + let expected = TensorData::from([ + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 1., 2., 0., 0., 0., 0.], + [0.0_f32, 0., 3., 4., 0., 0., 0., 0.], + [0.0_f32, 0., 5., 6., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + ]); + + output.assert_eq(&expected, true); + } + #[test] fn greater() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/pad/pad.onnx b/crates/burn-import/onnx-tests/tests/pad/pad.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4c8c265c42c2fb63932bdef8125748c41b3fd2a3 GIT binary patch literal 215 zcmdlxBg_ sT2dU00*p>*_IZJ&qoh#XCd9?X!NDlR!o|SFkR-wdbrZ6v6O(`_01baHe*gdg literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/pad/pad.py b/crates/burn-import/onnx-tests/tests/pad/pad.py new file mode 100755 index 0000000000..0600b89ce7 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/pad/pad.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/pad/pad.onnx + +### Helper Functions ### +from pathlib import Path +from typing import Any +import numpy +from numpy.core.multiarray import dtype +import onnx +from onnx import ModelProto, TensorProto, ValueInfoProto +from onnx.reference import ReferenceEvaluator +from onnx.checker import check_model +from onnx.helper import ( + make_model, + make_node, + make_graph, +) + + +def build_test_save( + name: str, + inputs: list[ValueInfoProto], + outputs: list[ValueInfoProto], + initializers: list[TensorProto] = [], + attributes: dict[str, Any] = {}, +) -> None: + node_inputs = [input.name for input in inputs + initializers] + node_outputs = [output.name for output in outputs] + + node = make_node( + name.capitalize(), + inputs=node_inputs, + 
outputs=node_outputs, + **attributes, + ) + + graph = make_graph( + nodes=[node], + name=f"{name.capitalize()}Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + + onnx_model = make_model(graph) + check_model(onnx_model) + + run_tests(onnx_model) + + onnx.save(onnx_model, Path(__file__).with_name(f"{name}.onnx")) + + +class TestCase: + def __init__( + self, name: str, feeds: dict[str, numpy.ndarray], expected: numpy.ndarray + ): + self.name = name + self.feeds = feeds + self.expected = expected + + def test_model(self, model: ModelProto): + sess = ReferenceEvaluator(model) + + result = numpy.array(sess.run(None, self.feeds)) + + if not numpy.array_equal(result, self.expected): + print( + f"""{self.name} +Expected result: {self.expected} +Got: {result}""" + ) + raise Exception("Test failed") + + +def test_positive_pads(model: ModelProto) -> None: + input_tensor = numpy.arange(1, 7, dtype="float32").reshape(3, 2) + pads = numpy.array([1, 2, 3, 4], dtype="int") + constant_value = 0.0 + feeds = { + "input_tensor": input_tensor, + "pads": pads, + "constant_value": constant_value, + } + expected = numpy.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ) + + TestCase("test_positive_constant_pads", feeds, expected).test_model(model) + + +def test_1d_input(model: ModelProto) -> None: + input_tensor = numpy.arange(1, 5, dtype="float32") + pads = numpy.array([1, 2], dtype="int") + constant_value = 0.0 + feeds = { + "input_tensor": input_tensor, + "pads": pads, + "constant_value": constant_value, + } + expected = numpy.array([[0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0]]) + + TestCase("test_1d_input", feeds, expected).test_model(model) + + +def run_tests(model: ModelProto) -> None: 
+ test_positive_pads(model) + test_1d_input(model) + # TODO: test_negative_pads + # TODO: support other modes: reflect, edge, wrap + + +### Helper Functions End ### + +import numpy +from onnx import TensorProto, numpy_helper +from onnx.helper import make_tensor_value_info + + +def get_initializers() -> list[TensorProto]: + pads = numpy_helper.from_array( + numpy.array([1, 2, 3, 4]).astype(numpy.int64), name="pads" + ) + constant_value = numpy_helper.from_array( + numpy.array([0.0]).astype(numpy.float32), name="constant_value" + ) + + return [pads, constant_value] + + +def main() -> None: + name = "pad" + + inputs = [make_tensor_value_info("input_tensor", TensorProto.FLOAT, [None, None])] + outputs = [make_tensor_value_info("output", TensorProto.FLOAT, [None, None])] + initializers = get_initializers() + + build_test_save( + name=name, + inputs=inputs, + outputs=outputs, + initializers=initializers, + attributes={"mode": "constant"}, + ) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index ffb4d28d47..751cbcb471 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -8,7 +8,7 @@ use super::{ conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode, gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, - max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, + max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, @@ -105,6 +105,7 @@ pub enum Node { Matmul(MatmulNode), MaxPool1d(MaxPool1dNode), 
MaxPool2d(MaxPool2dNode), + Pad(PadNode), Range(RangeNode), Reshape(ReshapeNode), Resize(ResizeNode), @@ -150,6 +151,7 @@ macro_rules! match_all { Node::Matmul(node) => $func(node), Node::MaxPool1d(node) => $func(node), Node::MaxPool2d(node) => $func(node), + Node::Pad(node) => $func(node), Node::Range(node) => $func(node), Node::Reshape(node) => $func(node), Node::Resize(node) => $func(node), @@ -203,6 +205,7 @@ impl Node { Node::Matmul(_) => "matmul", Node::MaxPool1d(_) => "max_pool1d", Node::MaxPool2d(_) => "max_pool2d", + Node::Pad(_) => "pad", Node::Range(_) => "range", Node::Reshape(_) => "reshape", Node::Resize(_) => "resize", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 73c6a0201a..9d1fdce591 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -25,6 +25,7 @@ pub(crate) mod mask_where; pub(crate) mod matmul; pub(crate) mod max_pool1d; pub(crate) mod max_pool2d; +pub(crate) mod pad; pub(crate) mod prelu; pub(crate) mod random_normal; pub(crate) mod random_uniform; diff --git a/crates/burn-import/src/burn/node/pad.rs b/crates/burn-import/src/burn/node/pad.rs new file mode 100644 index 0000000000..eabe77d7f1 --- /dev/null +++ b/crates/burn-import/src/burn/node/pad.rs @@ -0,0 +1,104 @@ +use std::str::FromStr; + +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; +use burn::config::Config; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Config, Debug)] +pub struct PadConfig { + pub pads: Vec, + pub constant_value: f32, +} + +#[derive(Debug, Clone, new)] +pub struct PadNode { + pub input: TensorType, + pub output: TensorType, + pub config: PadConfig, +} + +impl NodeCodegen for PadNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn forward(&self, scope: &mut 
Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + + let pads = self.config.pads.iter().map(|p| p.to_tokens()); + let constant_value_string = format!("{}_f32.elem()", self.config.constant_value); + let constant_value = TokenStream::from_str(&constant_value_string).unwrap(); + + quote! { + let #output = #input.pad((#(#pads),*), #constant_value); + } + } + fn into_node(self) -> Node { + Node::Pad(self) + } + + fn register_imports(&self, imports: &mut crate::burn::BurnImports) { + imports.register("burn::tensor::ElementConversion"); + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{pad::PadNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_pad() { + let mut graph = BurnGraph::::default(); + let config = PadConfig::new(vec![1, 2, 3, 4], -1.0); + graph.register(PadNode::new( + TensorType::new_float("input", 2), + TensorType::new_float("output", 2), + config, + )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::tensor::ElementConversion; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.pad((1, 2, 3, 4), -1_f32.elem()); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 1c85ab6e6a..d28ce72576 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,7 +7,7 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::resize::ResizeMode; +use crate::burn::node::{pad::PadConfig, resize::ResizeMode}; use onnx_ir::ir::{ArgType, AttributeValue, Data, Node}; /// Create a Conv1dConfig from the attributes of the node @@ -745,6 +745,72 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { ) } +/// Create a PadConfig from the attributes of the node +pub fn pad_config(node: &Node) -> PadConfig { + fn get_pads(node: &Node) -> Vec { + if node.inputs.len() != 3 { + panic!("Pad: must provide three inputs") + } + + let input_dim = match &node.inputs.first().unwrap().ty { + ArgType::Tensor(tensor) => tensor.dim, + _ => panic!("Pad: Only tensor input is valid"), + }; + + let pads: Vec = match &node.inputs[1].value { + Some(Data::Int64s(shape)) => shape + .iter() + .map(|&x| { + if x < 0 { + // TODO: support negative pads + panic!("Pad: Negative pad is not supported"); + } + x as usize + }) + .collect(), + _ => panic!("Pad: pads data type must be int64"), + }; + + if pads.len() 
!= input_dim * 2 { + panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); + } + // TODO: Burn's pad should support 1D tensor + if input_dim < 2 { + panic!("Pad: input tensor should be rank 2 or higher"); + } + + let left_index = input_dim - 1; + let top_index = input_dim - 2; + let right_index = pads.len() - 1; + let bottom_index = pads.len() - 2; + let index_list = [left_index, top_index, right_index, bottom_index]; + + for (index, &item) in pads.iter().enumerate() { + if !index_list.contains(&index) && item != 0 { + panic!("Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions"); + } + } + + let left = pads[left_index]; + let top = pads[top_index]; + let right = pads[right_index]; + let bottom = pads[bottom_index]; + vec![left, right, top, bottom] + } + fn get_constant_value(node: &Node) -> f32 { + // TODO: support int, boolean + match &node.inputs[2].value { + Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(), + _ => panic!("Pad: should provide a constant value input to pad with, for example 0.0"), + } + } + + let pads = get_pads(node); + let constant_value = get_constant_value(node); + + PadConfig::new(pads, constant_value) +} + /// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. 
/// /// # Arguments diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 3b32ba9c4f..292d515e81 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -40,6 +40,7 @@ use crate::{ matmul::MatmulNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, + pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, @@ -63,7 +64,7 @@ use super::op_configuration::{ concat_config, conv1d_config, conv2d_config, conv3d_config, conv_transpose2d_config, conv_transpose3d_config, dropout_config, expand_config, flatten_config, gather_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config, - max_pool2d_config, reduce_max_config, reduce_mean_config, reduce_min_config, + max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config, }; @@ -324,6 +325,7 @@ impl ParsedOnnxGraph { NodeType::ConvTranspose3d => { graph.register(Self::conv_transpose3d_conversion::(node)) } + NodeType::Pad => graph.register(Self::pad_conversion(node)), NodeType::Pow => graph.register(Self::pow_conversion(node)), NodeType::Unsqueeze => graph.register(Self::unsqueeze_conversion(node)), NodeType::Where => graph.register(Self::where_conversion(node)), @@ -1098,6 +1100,14 @@ impl ParsedOnnxGraph { BinaryNode::lower_equal(lhs, rhs, output) } + fn pad_conversion(node: Node) -> PadNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let config = pad_config(&node); + + PadNode::new(input, output, config) + } + fn pow_conversion(node: Node) -> BinaryNode { let lhs = Type::from(node.inputs.first().unwrap()); let rhs = Type::from(node.inputs.get(1).unwrap()); diff --git 
a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index ed05b16422..0f9432cc0a 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -736,7 +736,7 @@ where ) } - /// Pad the tensor with the given value on the last two dimensions. + /// Pad the tensor of rank two or higher with the given value on the last two dimensions. /// /// # Arguments /// diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index ae11ff1ac3..0fb45fd3c4 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -48,6 +48,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Mul => same_as_input(node), NodeType::Neg => same_as_input(node), NodeType::Not => same_as_input(node), + NodeType::Pad => same_as_input(node), NodeType::Greater => greater_update_outputs(node), NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node), NodeType::Less => less_update_outputs(node), From c7b00291089578b165272292a80c695a41ee7f38 Mon Sep 17 00:00:00 2001 From: mepatrick73 Date: Tue, 23 Jul 2024 11:58:45 -0400 Subject: [PATCH 2/4] ONNX pad arguments fix pad now requires 2 or more arguments if the third argument is not given, it will default to 0 --- crates/burn-import/src/onnx/op_configuration.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index d28ce72576..ca422a621d 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -748,7 +748,7 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { /// Create a PadConfig from the attributes of the node pub fn pad_config(node: &Node) -> PadConfig { fn get_pads(node: &Node) -> Vec { - if node.inputs.len() != 3 { + if node.inputs.len() > 1 { panic!("Pad: must provide three inputs") } @@ -801,7 +801,7 
@@ pub fn pad_config(node: &Node) -> PadConfig { // TODO: support int, boolean match &node.inputs[2].value { Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(), - _ => panic!("Pad: should provide a constant value input to pad with, for example 0.0"), + _ => 0.0, } } From 706b48b2c2df5a4a2a4a079aaa7f546a82a9f2e4 Mon Sep 17 00:00:00 2001 From: mepatrick73 Date: Tue, 23 Jul 2024 12:19:37 -0400 Subject: [PATCH 3/4] fixing bug in input len fix --- crates/burn-import/src/onnx/op_configuration.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index ca422a621d..c74a349efe 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -748,8 +748,8 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { /// Create a PadConfig from the attributes of the node pub fn pad_config(node: &Node) -> PadConfig { fn get_pads(node: &Node) -> Vec { - if node.inputs.len() > 1 { - panic!("Pad: must provide three inputs") + if node.inputs.len() < 2 { + panic!("Pad: must provide two inputs") } let input_dim = match &node.inputs.first().unwrap().ty { From f230da1b5ceea3a0742eb68accd5db501a4f333e Mon Sep 17 00:00:00 2001 From: mepatrick73 Date: Tue, 23 Jul 2024 13:49:47 -0400 Subject: [PATCH 4/4] change panic comment Change panic comment from needing two inputs. This comes from the fact that the ONNX spec requires two necessary inputs but could have up to two more optional arguments. 
--- crates/burn-import/src/onnx/op_configuration.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index c74a349efe..def018f4ae 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -749,7 +749,7 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { pub fn pad_config(node: &Node) -> PadConfig { fn get_pads(node: &Node) -> Vec { if node.inputs.len() < 2 { - panic!("Pad: must provide two inputs") + panic!("Pad: must provide at least two inputs") } let input_dim = match &node.inputs.first().unwrap().ty {