Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added maxpool1d onnx operator #1725

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/burn-core/src/nn/pool/max_pool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::tensor::Tensor;
use burn_tensor::module::max_pool1d;

/// Configuration to create a [1D max pooling](MaxPool1d) layer.
#[derive(Config)]
#[derive(Config, Debug)]
pub struct MaxPool1dConfig {
/// The size of the kernel.
pub kernel_size: usize,
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ represent the corresponding Burn Op.
| [MatMul][94] | ✅ | ✅ |
| [MatMulInteger][95] | ❌ | ✅ |
| [Max][96] | ❌ | ✅ |
| [MaxPool1d][97] | | ✅ |
| [MaxPool1d][97] | | ✅ |
| [MaxPool2d][98] | ✅ | ✅ |
| [MaxRoiPool][99] | ❌ | ❌ |
| [MaxUnpool][100] | ❌ | ❌ |
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fn main() {
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/log/log.onnx")
.input("tests/matmul/matmul.onnx")
.input("tests/maxpool1d/maxpool1d.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
Expand Down
Binary file not shown.
49 changes: 49 additions & 0 deletions crates/burn-import/onnx-tests/tests/maxpool1d/maxpool1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: maxpool2d1.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.maxpool = nn.MaxPool1d(5, stride=2, padding=2, dilation=1)

def forward(self, x):
x = self.maxpool(x)
return x


def main():
# Set seed for reproducibility
torch.manual_seed(42)

# Print options
torch.set_printoptions(precision=3)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "maxpool1d.onnx"
test_input = torch.randn(1, 5, 5, device=device)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
print("Test input data of ones: {}".format(test_input))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))


if __name__ == '__main__':
main()

26 changes: 26 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ include_models!(
log,
mask_where,
matmul,
maxpool1d,
maxpool2d,
mul,
neg,
Expand Down Expand Up @@ -442,6 +443,31 @@ mod tests {
assert_eq!(output1.to_data(), expected1);
assert_eq!(output2, expected2);
}
#[test]
fn maxpool1d() {
let device = Default::default();

let model: maxpool1d::Model<Backend> = maxpool1d::Model::new(&device);
let input = Tensor::<Backend, 3>::from_floats(
[[
[1.927, 1.487, 0.901, -2.106, 0.678],
[-1.235, -0.043, -1.605, -0.752, -0.687],
[-0.493, 0.241, -1.111, 0.092, -2.317],
[-0.217, -1.385, -0.396, 0.803, -0.622],
[-0.592, -0.063, -0.829, 0.331, -1.558],
]],
&device,
);
let output = model.forward(input);
let expected = Data::from([[
[1.927, 1.927, 0.901],
[-0.043, -0.043, -0.687],
[0.241, 0.241, 0.092],
[-0.217, 0.803, 0.803],
[-0.063, 0.331, 0.331],
]]);
assert_eq!(output.to_data(), expected);
}

#[test]
fn maxpool2d() {
Expand Down
13 changes: 7 additions & 6 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use super::layer_norm::LayerNormNode;
use super::mask_where::WhereNode;
use super::prelu::PReluNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, linear::LinearNode, matmul::MatmulNode,
max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -93,6 +91,7 @@ pub enum Node<PS: PrecisionSettings> {
LayerNorm(LayerNormNode<PS>),
Linear(LinearNode<PS>),
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Reshape(ReshapeNode),
Unary(UnaryNode),
Expand Down Expand Up @@ -120,6 +119,7 @@ macro_rules! match_all {
Node::LayerNorm(node) => $func(node),
Node::Linear(node) => $func(node),
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Unary(node) => $func(node),
Expand Down Expand Up @@ -157,6 +157,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::LayerNorm(_) => "layer_norm",
Node::Linear(_) => "linear",
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Reshape(_) => "reshape",
Node::Unary(unary) => unary.kind.as_str(),
Expand Down
158 changes: 158 additions & 0 deletions crates/burn-import/src/burn/node/max_pool1d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
use proc_macro2::TokenStream;
use quote::quote;

use burn::{nn::pool::MaxPool1dConfig, record::PrecisionSettings};

use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};

#[derive(Debug, Clone)]
pub struct MaxPool1dNode {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub config: MaxPool1dConfig,
}

impl MaxPool1dNode {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
config: MaxPool1dConfig,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
MaxPool1d
},
),
input,
output,
config,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool1dNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.stride.to_tokens();
let padding = self.config.padding.to_tokens();
let dilation = self.config.dilation.to_tokens();
let tokens = quote! {
let #name = MaxPool1dConfig::new(#kernel_size)
.with_stride(#strides)
.with_padding(#padding)
.with_dilation(#dilation)
.init();
};

Some(tokens)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;

quote! {
let #output = self.#field.forward(#input);
}
}

fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::PaddingConfig1d");
imports.register("burn::nn::pool::MaxPool1d");
imports.register("burn::nn::pool::MaxPool1dConfig");
}

fn into_node(self) -> Node<PS> {
Node::MaxPool1d(self)
}

fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
S::serialize_none(serializer)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
use burn::{
nn::{pool::MaxPool1dConfig, PaddingConfig1d},
record::FullPrecisionSettings,
};

#[test]
fn test_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(MaxPool1dNode::new(
"max_pool1d",
TensorType::new_float("input", 3),
TensorType::new_float("output", 3),
MaxPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Valid)
.with_dilation(1),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::nn::PaddingConfig1d;
use burn::nn::pool::MaxPool1d;
use burn::nn::pool::MaxPool1dConfig;

#[derive(Module, Debug)]
pub struct Model <B: Backend> {
max_pool1d: MaxPool1d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let max_pool1d = MaxPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Valid)
.with_dilation(1)
.init();

Self {
max_pool1d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output = self.max_pool1d.forward(input);

output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub(crate) mod layer_norm;
pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod reshape;
Expand Down
29 changes: 28 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use burn::nn::{
conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
pool::{AvgPool2dConfig, MaxPool2dConfig},
pool::{AvgPool2dConfig, MaxPool1dConfig, MaxPool2dConfig},
BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
PaddingConfig2d,
};
Expand Down Expand Up @@ -96,6 +96,33 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
.with_padding(padding)
}

/// Create a MaxPool2dConfig from the attributes of the node
pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig {
let mut kernel_shape = Vec::new();
let mut stride = vec![1];
let mut pads = vec![0, 0];
let mut dilation = vec![1];

for (key, value) in curr.attrs.iter() {
match key.as_str() {
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
"strides" => stride = value.clone().into_i64s(),
"pads" => pads = value.clone().into_i64s(),
"dilations" => dilation = value.clone().into_i64s(),
_ => {}
}
}
assert_eq!(kernel_shape.len(), 1);
assert_eq!(dilation.len(), 1);
assert_eq!(stride.len(), 1);
let padding = padding_config_1d(&pads);

MaxPool1dConfig::new(kernel_shape[0] as usize)
.with_stride(stride[0] as usize)
.with_padding(padding)
.with_dilation(dilation[0] as usize)
}

/// Create a MaxPool2dConfig from the attributes of the node
pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
let mut kernel_shape = Vec::new();
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::{
linear::LinearNode,
mask_where::WhereNode,
matmul::MatmulNode,
max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode,
prelu::PReluNode,
reshape::ReshapeNode,
Expand Down Expand Up @@ -239,6 +240,7 @@ impl OnnxGraph {
NodeType::Cos => graph.register(Self::cos_conversion(node)),
NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
NodeType::Conv2d => graph.register(Self::conv2d_conversion::<PS>(node)),
NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)),
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
Expand Down Expand Up @@ -694,6 +696,14 @@ impl OnnxGraph {
let name = &node.name;
Conv2dNode::<PS>::new(name, input, output, weight, bias, config)
}
fn max_pool1d_conversion(node: Node) -> MaxPool1dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let config = max_pool1d_config(&node);

let name = &node.name;
MaxPool1dNode::new(name, input, output, config)
}

fn max_pool2d_conversion(node: Node) -> MaxPool2dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
Expand Down
Loading