From 4a3fc9d4a094bc7ce4205fbc83b61da383308a43 Mon Sep 17 00:00:00 2001 From: johnhuichen Date: Tue, 23 Jul 2024 17:50:20 +0000 Subject: [PATCH] Implement ONNX Pad Operator (#2007) * Implement ONNX pad * ONNX pad arguments fix pad now requires 2 or more arguments if the third argument is not given, it will default to 0 * fixing bug in input len fix * change panic comment Change panic comment from needing two inputs. This comes from the fact that the ONNX spec requires two necessary inputs but could have more two more optional argument. --------- Co-authored-by: JC Co-authored-by: mepatrick73 --- crates/burn-import/onnx-tests/build.rs | 1 + .../onnx-tests/tests/onnx_tests.rs | 21 +++ .../burn-import/onnx-tests/tests/pad/pad.onnx | Bin 0 -> 215 bytes .../burn-import/onnx-tests/tests/pad/pad.py | 158 ++++++++++++++++++ crates/burn-import/src/burn/node/base.rs | 5 +- crates/burn-import/src/burn/node/mod.rs | 1 + crates/burn-import/src/burn/node/pad.rs | 104 ++++++++++++ .../burn-import/src/onnx/op_configuration.rs | 68 +++++++- crates/burn-import/src/onnx/to_burn.rs | 12 +- crates/burn-tensor/src/tensor/api/numeric.rs | 2 +- crates/onnx-ir/src/dim_inference.rs | 1 + 11 files changed, 369 insertions(+), 4 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/pad/pad.onnx create mode 100755 crates/burn-import/onnx-tests/tests/pad/pad.py create mode 100644 crates/burn-import/src/burn/node/pad.rs diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index d221d40a71..2ccca2e55b 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -43,6 +43,7 @@ fn main() { .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") .input("tests/not/not.onnx") + .input("tests/pad/pad.onnx") .input("tests/expand/expand.onnx") .input("tests/greater/greater.onnx") .input("tests/greater_or_equal/greater_or_equal.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index 577f12ac85..b6cefac8c1 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -55,6 +55,7 @@ include_models!( mul, neg, not, + pad, greater, greater_or_equal, less, @@ -1407,6 +1408,26 @@ mod tests { output.assert_eq(&expected, true); } + #[test] + fn pad() { + let device = Default::default(); + let model: pad::Model = pad::Model::new(&device); + + let input = Tensor::::from_floats([[1., 2.], [3., 4.], [5., 6.]], &device); + let output = model.forward(input).to_data(); + let expected = TensorData::from([ + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 1., 2., 0., 0., 0., 0.], + [0.0_f32, 0., 3., 4., 0., 0., 0., 0.], + [0.0_f32, 0., 5., 6., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + ]); + + output.assert_eq(&expected, true); + } + #[test] fn greater() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/pad/pad.onnx b/crates/burn-import/onnx-tests/tests/pad/pad.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4c8c265c42c2fb63932bdef8125748c41b3fd2a3 GIT binary patch literal 215 zcmdlxBg_ sT2dU00*p>*_IZJ&qoh#XCd9?X!NDlR!o|SFkR-wdbrZ6v6O(`_01baHe*gdg literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/pad/pad.py b/crates/burn-import/onnx-tests/tests/pad/pad.py new file mode 100755 index 0000000000..0600b89ce7 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/pad/pad.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/pad/pad.onnx + +### Helper Functions ### +from pathlib import Path +from typing import Any +import numpy +from numpy.core.multiarray import dtype +import onnx +from onnx import ModelProto, TensorProto, ValueInfoProto +from onnx.reference import ReferenceEvaluator +from onnx.checker import check_model +from onnx.helper import ( + make_model, + make_node, + make_graph, +) + + +def build_test_save( + name: str, + inputs: list[ValueInfoProto], + outputs: list[ValueInfoProto], + initializers: list[TensorProto] = [], + attributes: dict[str, Any] = {}, +) -> None: + node_inputs = [input.name for input in inputs + initializers] + node_outputs = [output.name for output in outputs] + + node = make_node( + name.capitalize(), + inputs=node_inputs, + outputs=node_outputs, + **attributes, + ) + + graph = make_graph( + nodes=[node], + name=f"{name.capitalize()}Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + + onnx_model = make_model(graph) + check_model(onnx_model) + + run_tests(onnx_model) + + onnx.save(onnx_model, Path(__file__).with_name(f"{name}.onnx")) + + +class TestCase: + def __init__( + self, name: str, feeds: dict[str, numpy.ndarray], expected: numpy.ndarray + ): + self.name = name + self.feeds = feeds + self.expected = expected + + def test_model(self, model: ModelProto): + sess = ReferenceEvaluator(model) + + result = numpy.array(sess.run(None, self.feeds)) + + if not numpy.array_equal(result, self.expected): + print( + f"""{self.name} +Expected result: {self.expected} +Got: {result}""" + ) + raise Exception("Test failed") + + +def test_positive_pads(model: ModelProto) -> None: + input_tensor = numpy.arange(1, 7, dtype="float32").reshape(3, 2) + pads = numpy.array([1, 2, 3, 4], dtype="int") + constant_value = 0.0 + feeds = { + "input_tensor": input_tensor, + "pads": pads, + "constant_value": constant_value, + } + expected = numpy.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ) + + TestCase("test_positive_constant_pads", feeds, expected).test_model(model) + + +def test_1d_input(model: ModelProto) -> None: + input_tensor = numpy.arange(1, 5, dtype="float32") + pads = numpy.array([1, 2], dtype="int") + constant_value = 0.0 + feeds = { + "input_tensor": input_tensor, + "pads": pads, + "constant_value": constant_value, + } + expected = numpy.array([[0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0]]) + + TestCase("test_1d_input", feeds, expected).test_model(model) + + +def run_tests(model: ModelProto) -> None: + test_positive_pads(model) + test_1d_input(model) + # TODO: test_negative_pads + # TODO: support other modes: reflect, edge, wrap + + +### Helper Functions End ### + +import numpy +from onnx import TensorProto, numpy_helper +from onnx.helper import make_tensor_value_info + + +def get_initializers() -> list[TensorProto]: + pads = numpy_helper.from_array( + numpy.array([1, 2, 3, 4]).astype(numpy.int64), name="pads" + ) + constant_value = numpy_helper.from_array( + numpy.array([0.0]).astype(numpy.float32), name="constant_value" + ) + + return [pads, constant_value] + + +def main() -> None: + name = "pad" + + inputs = [make_tensor_value_info("input_tensor", TensorProto.FLOAT, [None, None])] + outputs = [make_tensor_value_info("output", TensorProto.FLOAT, [None, None])] + initializers = get_initializers() + + build_test_save( + name=name, + inputs=inputs, + outputs=outputs, + initializers=initializers, + attributes={"mode": "constant"}, + ) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index ffb4d28d47..751cbcb471 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -8,7 +8,7 @@ use super::{ conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode, gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, - max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, + max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, @@ -105,6 +105,7 @@ pub enum Node { Matmul(MatmulNode), MaxPool1d(MaxPool1dNode), MaxPool2d(MaxPool2dNode), + Pad(PadNode), Range(RangeNode), Reshape(ReshapeNode), Resize(ResizeNode), @@ -150,6 +151,7 @@ macro_rules! match_all { Node::Matmul(node) => $func(node), Node::MaxPool1d(node) => $func(node), Node::MaxPool2d(node) => $func(node), + Node::Pad(node) => $func(node), Node::Range(node) => $func(node), Node::Reshape(node) => $func(node), Node::Resize(node) => $func(node), @@ -203,6 +205,7 @@ impl Node { Node::Matmul(_) => "matmul", Node::MaxPool1d(_) => "max_pool1d", Node::MaxPool2d(_) => "max_pool2d", + Node::Pad(_) => "pad", Node::Range(_) => "range", Node::Reshape(_) => "reshape", Node::Resize(_) => "resize", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 73c6a0201a..9d1fdce591 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -25,6 +25,7 @@ pub(crate) mod mask_where; pub(crate) mod matmul; pub(crate) mod max_pool1d; pub(crate) mod max_pool2d; +pub(crate) mod pad; pub(crate) mod prelu; pub(crate) mod random_normal; pub(crate) mod random_uniform; diff --git a/crates/burn-import/src/burn/node/pad.rs b/crates/burn-import/src/burn/node/pad.rs new file mode 100644 index 0000000000..eabe77d7f1 --- /dev/null +++ b/crates/burn-import/src/burn/node/pad.rs @@ -0,0 +1,104 @@ +use std::str::FromStr; + +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; +use burn::config::Config; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Config, Debug)] +pub struct PadConfig { + pub pads: Vec, + pub constant_value: f32, +} + +#[derive(Debug, Clone, new)] +pub struct PadNode { + pub input: TensorType, + pub output: TensorType, + pub config: PadConfig, +} + +impl NodeCodegen for PadNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + + let pads = self.config.pads.iter().map(|p| p.to_tokens()); + let constant_value_string = format!("{}_f32.elem()", self.config.constant_value); + let constant_value = TokenStream::from_str(&constant_value_string).unwrap(); + + quote! { + let #output = #input.pad((#(#pads),*), #constant_value); + } + } + fn into_node(self) -> Node { + Node::Pad(self) + } + + fn register_imports(&self, imports: &mut crate::burn::BurnImports) { + imports.register("burn::tensor::ElementConversion"); + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{pad::PadNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_pad() { + let mut graph = BurnGraph::::default(); + let config = PadConfig::new(vec![1, 2, 3, 4], -1.0); + graph.register(PadNode::new( + TensorType::new_float("input", 2), + TensorType::new_float("output", 2), + config, + )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::tensor::ElementConversion; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.pad((1, 2, 3, 4), -1_f32.elem()); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index f69dfe4c26..8ba031e04d 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,7 +7,7 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::resize::ResizeMode; +use crate::burn::node::{pad::PadConfig, resize::ResizeMode}; use onnx_ir::ir::{ArgType, AttributeValue, Data, Node}; /// Create a Conv1dConfig from the attributes of the node @@ -745,6 +745,72 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { ) } +/// Create a PadConfig from the attributes of the node +pub fn pad_config(node: &Node) -> PadConfig { + fn get_pads(node: &Node) -> Vec { + if node.inputs.len() < 2 { + panic!("Pad: must provide at least two inputs") + } + + let input_dim = match &node.inputs.first().unwrap().ty { + ArgType::Tensor(tensor) => tensor.dim, + _ => panic!("Pad: Only tensor input is valid"), + }; + + let pads: Vec = match &node.inputs[1].value { + Some(Data::Int64s(shape)) => shape + .iter() + .map(|&x| { + if x < 0 { + // TODO: support negative pads + panic!("Pad: Negative pad is not supported"); + } + x as usize + }) + .collect(), + _ => panic!("Pad: pads data type must be int64"), + }; + + if pads.len() != input_dim * 2 { + panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); + } + // TODO: Burn's pad should support 1D tensor + if input_dim < 2 { + panic!("Pad: input tensor should be rank 2 or higher"); + } + + let left_index = input_dim - 1; + let top_index = input_dim - 2; + let right_index = pads.len() - 1; + let bottom_index = pads.len() - 2; + let index_list = [left_index, top_index, right_index, bottom_index]; + + for (index, &item) in pads.iter().enumerate() { + if !index_list.contains(&index) && item != 0 { + panic!("Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions"); + } + } + + let left = pads[left_index]; + let top = pads[top_index]; + let right = pads[right_index]; + let bottom = pads[bottom_index]; + vec![left, right, top, bottom] + } + fn get_constant_value(node: &Node) -> f32 { + // TODO: support int, boolean + match &node.inputs[2].value { + Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(), + _ => 0.0, + } + } + + let pads = get_pads(node); + let constant_value = get_constant_value(node); + + PadConfig::new(pads, constant_value) +} + /// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. /// /// # Arguments diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 77ec889ea2..9ab92e8b36 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -40,6 +40,7 @@ use crate::{ matmul::MatmulNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, + pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, @@ -63,7 +64,7 @@ use super::op_configuration::{ concat_config, conv1d_config, conv2d_config, conv3d_config, conv_transpose2d_config, conv_transpose3d_config, dropout_config, expand_config, flatten_config, gather_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config, - max_pool2d_config, reduce_max_config, reduce_mean_config, reduce_min_config, + max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config, }; @@ -324,6 +325,7 @@ impl ParsedOnnxGraph { NodeType::ConvTranspose3d => { graph.register(Self::conv_transpose3d_conversion::(node)) } + NodeType::Pad => graph.register(Self::pad_conversion(node)), NodeType::Pow => graph.register(Self::pow_conversion(node)), NodeType::Unsqueeze => graph.register(Self::unsqueeze_conversion(node)), NodeType::Where => graph.register(Self::where_conversion(node)), @@ -1108,6 +1110,14 @@ impl ParsedOnnxGraph { BinaryNode::lower_equal(lhs, rhs, output) } + fn pad_conversion(node: Node) -> PadNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let config = pad_config(&node); + + PadNode::new(input, output, config) + } + fn pow_conversion(node: Node) -> BinaryNode { let lhs = Type::from(node.inputs.first().unwrap()); let rhs = Type::from(node.inputs.get(1).unwrap()); diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index ed05b16422..0f9432cc0a 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -736,7 +736,7 @@ where ) } - /// Pad the tensor with the given value on the last two dimensions. + /// Pad the tensor of rank two or higher with the given value on the last two dimensions. /// /// # Arguments /// diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index d032569ed2..8b0ea3029f 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -48,6 +48,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Mul => same_as_input(node), NodeType::Neg => same_as_input(node), NodeType::Not => same_as_input(node), + NodeType::Pad => same_as_input(node), NodeType::Greater => greater_update_outputs(node), NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node), NodeType::Less => less_update_outputs(node),