From c2b6318fc3232536c91910923e09ad0f9f30763a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20M=C3=BCller?= Date: Mon, 8 Jul 2024 22:11:59 +0200 Subject: [PATCH] Implement ONNX ConstantOfShape (#1815) * Feat: burn-import implement ONNX ConstantOfShape * Introduce shape type and use in ConstantOfShape and Shape * Add tests for bool and int tensors for ConstantOfShape * Fix ONNX test generation * Undo comment --------- Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- crates/burn-import/onnx-tests/build.rs | 2 + .../constant_of_shape/constant_of_shape.onnx | Bin 0 -> 161 bytes .../constant_of_shape/constant_of_shape.py | 53 +++++ .../constant_of_shape_full_like.onnx | Bin 0 -> 664 bytes .../constant_of_shape_full_like.py | 59 +++++ .../onnx-tests/tests/onnx_tests.rs | 41 +++- crates/burn-import/src/burn/graph.rs | 1 + crates/burn-import/src/burn/node/base.rs | 11 +- .../src/burn/node/constant_of_shape.rs | 203 ++++++++++-------- crates/burn-import/src/burn/node/mod.rs | 1 + crates/burn-import/src/burn/node/unary.rs | 45 ++-- crates/burn-import/src/burn/ty.rs | 27 +++ crates/burn-import/src/onnx/to_burn.rs | 42 +++- crates/onnx-ir/src/dim_inference.rs | 45 ++-- crates/onnx-ir/src/from_onnx.rs | 5 + crates/onnx-ir/src/ir.rs | 2 + crates/onnx-ir/src/util.rs | 38 ++++ 18 files changed, 446 insertions(+), 131 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.onnx create mode 100755 crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.py create mode 100644 crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape_full_like.onnx create mode 100755 crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape_full_like.py diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 33126a290e..571cd7b565 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -38,7 +38,7 @@ represent the corresponding Burn Op. 
| [Concat][30] | ✅ | ✅ | | [ConcatFromSequence][31] | ❌ | ❌ | | [Constant][32] | ✅ | ✅ | -| [ConstantOfShape][33] | ❌ | ❌ | +| [ConstantOfShape][33] | ✅ | ✅ | | [Conv1d][34] | ✅ | ✅ | | [Conv2d][34] | ✅ | ✅ | | [Conv3d][34] | ✅ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 7ed38509b4..497400369f 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -85,6 +85,8 @@ fn main() { .input("tests/squeeze/squeeze_opset13.onnx") .input("tests/random_uniform/random_uniform.onnx") .input("tests/random_normal/random_normal.onnx") + .input("tests/constant_of_shape/constant_of_shape.onnx") + .input("tests/constant_of_shape/constant_of_shape_full_like.onnx") .input("tests/range/range.onnx") .out_dir("model/") .run_from_script(); diff --git a/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.onnx b/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fe0c255ccabc83630313fce5dcee10a7acab92de GIT binary patch literal 161 zcmdhzBUl1d;;LPApsu0s;URLnYt< literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.py b/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.py new file mode 100755 index 0000000000..a5aea25806 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + +# used to generate model: constant_of_shape.onnx + +# torch simplifies simple usecases where it can statically determine the shape of the constant +# to use just ONNX constants instead of ConstantOfShape +# Hence this model is exported using onnx directly + +import onnx +import onnx.helper + + +def build_model(): + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 16)], + graph=onnx.helper.make_graph(name="main_graph", nodes=[ + onnx.helper.make_node( + "ConstantOfShape", + inputs=["input1"], + outputs=["output1"], + name="/ConstantOfShape", + value=onnx.helper.make_tensor("value", data_type=onnx.TensorProto.FLOAT, dims=[1], vals=[1.125]) + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT64, shape=[3] + ), + ) + ], + outputs=[ + onnx.helper.make_value_info( + name="output1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 3, 2] + ), + ) + ]), + ) + + +def main(): + onnx_model = build_model() + file_name = "constant_of_shape.onnx" + onnx.save(onnx_model, file_name) + onnx.checker.check_model(file_name) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape_full_like.onnx b/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape_full_like.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4214ab1b3f8681f8e95707f3661713612c772d66 GIT binary patch literal 664 zcma*kK}*9h6bJA|nznlmR-fRZHwQA9u&o6ZJPp0|;=x-HN;kM}3P{Np%i7w<~07u=)4B$)WOUnY!zio7Tl4yKQs6=^I>RUPA*XZ;L@w4Or* z?G5ZVyUR>6<+;;)nQs)!)x)xm3::ones([4, 2], &device); let output = model.forward(input); - let expected = TensorData::from([4i64, 2]); - - output.to_data().assert_eq(&expected, true); + let expected = [4, 2]; + assert_eq!(output, expected); } #[test] @@ -1658,4 +1659,36 @@ mod tests { let output = model.forward(); 
assert_eq!(expected_shape, output.shape()); } + + #[test] + fn constant_of_shape() { + // This tests shape is being passed directly to the model + let device = Default::default(); + let model = constant_of_shape::Model::::new(&device); + let input_shape = [2, 3, 2]; + let expected = Tensor::::full(input_shape, 1.125, &device).to_data(); + + let output = model.forward(input_shape); + + output.to_data().assert_approx_eq(&expected, 3); + } + + #[test] + fn constant_of_shape_full_like() { + // This tests shape is being passed from the input tensor + + let device = Default::default(); + let model = constant_of_shape_full_like::Model::::new(&device); + let shape = [2, 3, 2]; + let f_expected = Tensor::::full(shape, 3.0, &device); + let i_expected = Tensor::::full(shape, 5, &device); + let b_expected = Tensor::::ones(shape, &device).bool(); + + let input = Tensor::ones(shape, &device); + let (f_output, i_output, b_output) = model.forward(input); + + assert!(f_output.equal(f_expected).all().into_scalar()); + assert!(i_output.equal(i_expected).all().into_scalar()); + assert!(b_output.equal(b_expected).all().into_scalar()); + } } diff --git a/crates/burn-import/src/burn/graph.rs b/crates/burn-import/src/burn/graph.rs index e47f88d091..ee66399b3a 100644 --- a/crates/burn-import/src/burn/graph.rs +++ b/crates/burn-import/src/burn/graph.rs @@ -292,6 +292,7 @@ impl BurnGraph { Type::Tensor(tensor) => Some(tensor), Type::Scalar(_) => None, Type::Other(_) => None, + Type::Shape(_) => None, } } diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index a93bc5f2c7..ffb4d28d47 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -3,10 +3,10 @@ use std::marker::PhantomData; use super::{ argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode, - constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode, conv3d::Conv3dNode, - conv_transpose_2d::ConvTranspose2dNode, conv_transpose_3d::ConvTranspose3dNode, - dropout::DropoutNode, expand::ExpandNode, gather::GatherNode, - gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, + constant::ConstantNode, constant_of_shape::ConstantOfShapeNode, conv1d::Conv1dNode, + conv2d::Conv2dNode, conv3d::Conv3dNode, conv_transpose_2d::ConvTranspose2dNode, + conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode, + gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, @@ -116,6 +116,7 @@ pub enum Node { Where(WhereNode), RandomUniform(RandomUniformNode), RandomNormal(RandomNormalNode), + ConstantOfShape(ConstantOfShapeNode), // For now, we have to keep the precision settings in order to correctly serialize the fields // into the right data types. _Unreachable(std::convert::Infallible, PhantomData), @@ -160,6 +161,7 @@ macro_rules! 
match_all {
             Node::Where(node) => $func(node),
             Node::RandomNormal(node) => $func(node),
             Node::RandomUniform(node) => $func(node),
+            Node::ConstantOfShape(node) => $func(node),
             _ => unimplemented!(),
         }
     }};
@@ -212,6 +214,7 @@ impl Node {
             Node::Where(_) => "where",
             Node::RandomNormal(_) => "random_normal",
             Node::RandomUniform(_) => "random_uniform",
+            Node::ConstantOfShape(_) => "constant_of_shape",
             _ => unimplemented!(),
         }
     }
diff --git a/crates/burn-import/src/burn/node/constant_of_shape.rs b/crates/burn-import/src/burn/node/constant_of_shape.rs
index 2c24b8d99a..2740472d5f 100644
--- a/crates/burn-import/src/burn/node/constant_of_shape.rs
+++ b/crates/burn-import/src/burn/node/constant_of_shape.rs
@@ -1,18 +1,19 @@
 use super::{Node, NodeCodegen};
-use crate::burn::{BurnImports, Scope, ToTokens, Type};
+use crate::burn::{Scope, ToTokens, Type};
 use burn::record::PrecisionSettings;
 use proc_macro2::TokenStream;
 use quote::quote;
 
-#[derive(Debug, Clone, new)]
+/// Node for the ONNX ConstantOfShape operator.
+#[derive(Debug, Clone)]
 pub struct ConstantOfShapeNode {
-    pub value: ConstantOfShapeValue,
     pub input: Type,
     pub output: Type,
+    pub value: ConstantValue,
 }
 
-#[derive(Debug, Clone, new)]
-pub enum ConstantOfShapeValue {
+#[derive(Debug, Clone, PartialEq)]
+pub enum ConstantValue {
     /// Float constant.
     Float32(f32),
     Float64(f64),
@@ -25,54 +26,114 @@ pub enum ConstantOfShapeValue {
     Bool(bool),
 }
 
-impl ToTokens for ConstantOfShapeValue {
-    fn to_tokens(&self) -> TokenStream {
+impl ConstantOfShapeNode {
+    pub fn new(input: Type, output: Type, value: ConstantValue) -> Self {
+        assert!(
+            matches!(input, Type::Shape(_)),
+            "ConstantOfShape input needs to be a Shape!"
+        );
+        assert!(
+            matches!(output, Type::Tensor(_)),
+            "ConstantOfShape output needs to be a Tensor!"
+        );
+        Self {
+            input,
+            output,
+            value,
+        }
+    }
+}
+
+impl ConstantValue {
+    pub fn val_tokens(&self) -> TokenStream {
         match self {
-            ConstantOfShapeValue::Bool(val) => val.to_tokens(),
-            ConstantOfShapeValue::Float32(val) => val.to_tokens(),
-            ConstantOfShapeValue::Float64(val) => val.to_tokens(),
-            ConstantOfShapeValue::Int32(val) => val.to_tokens(),
-            ConstantOfShapeValue::Int64(val) => val.to_tokens(),
+            Self::Float32(val) => quote! { #val },
+            Self::Float64(val) => quote! { #val },
+            Self::Int32(val) => quote! { #val },
+            Self::Int64(val) => quote! { #val },
+            Self::Bool(val) => quote! { #val },
         }
     }
+
+    pub fn from_vec<T: Into<Self> + Copy>(mut source: Vec<T>) -> Self {
+        assert_eq!(
+            source.len(),
+            1,
+            "ConstantOfShape value from a vec needs to have exactly 1 element!"
+ ); + source.drain(..).next().unwrap().into() + } } -impl NodeCodegen for ConstantOfShapeNode { - fn output_types(&self) -> Vec { - vec![self.output.clone()] +impl From for ConstantValue { + fn from(value: f32) -> Self { + Self::Float32(value) + } +} +impl From for ConstantValue { + fn from(value: f64) -> Self { + Self::Float64(value) + } +} +impl From for ConstantValue { + fn from(value: i32) -> Self { + Self::Int32(value) + } +} +impl From for ConstantValue { + fn from(value: i64) -> Self { + Self::Int64(value) + } +} +impl From for ConstantValue { + fn from(value: bool) -> Self { + Self::Bool(value) } +} +impl NodeCodegen for ConstantOfShapeNode { fn input_types(&self) -> Vec { vec![self.input.clone()] } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let output = &self.output.name(); - let value = self.value.to_tokens(); + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream { + let output = self.output.name(); + let input = self.input.name(); - match (&self.input, &self.output) { - (Type::Tensor(input), Type::Tensor(_)) => { - let input = scope.tensor_use_owned(&input, node_position); - quote! { - let #output = Tensor::full(#input.to_data().value, #value, &#input.device()); - } - } - (Type::Scalar(_), Type::Scalar(_)) => { - quote! { - let #output = #value; + let output_rank = match &self.output { + Type::Tensor(tensor) => tensor.dim.to_tokens(), + _ => unreachable!(), + }; + + let value = self.value.val_tokens(); + // Note: in the generated code, self.device is a &module::Ignored, + // so to get a &Device, &* is needed + + match &self.value { + ConstantValue::Bool(bool) => { + // Currently there is no full bool tensor support in the backend + // So we use 0 or 1 with bool type casting + // See: https://github.com/tracel-ai/burn/issues/1535 + if *bool { + quote! { + let #output = Tensor::::ones(#input, &*self.device).bool(); + } + } else { + quote! { + let #output = Tensor::::zeros(#input, &*self.device).bool(); + } } } - _ => panic!( - "Invalid input/output type ({:?}, {:?})", - self.input, self.output - ), + _ => quote! { + let #output = Tensor::full(#input, #value, &*self.device); + }, } } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::tensor::Int"); - } - fn into_node(self) -> Node { Node::ConstantOfShape(self) } @@ -86,62 +147,33 @@ mod tests { use crate::burn::{ graph::BurnGraph, node::{constant_of_shape::ConstantOfShapeNode, test::assert_tokens}, - ScalarType, TensorType, + ShapeType, TensorType, }; #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(ConstantOfShapeNode::new( - ConstantOfShapeValue::new_float32(1.25), - Type::Tensor(TensorType::new_int("tensor1", 1)), - Type::Tensor(TensorType::new_float("tensor2", 3)), - )); - - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } - - impl Model { - #[allow(unused_variables)] - pub fn new(device: &B::Device) -> Self { - Self { - phantom: core::marker::PhantomData, - } - } - #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = Tensor::full(tensor1.to_data().value, 1.25, &tensor1.device()); - - tensor2 - } - } - }; - - assert_tokens(graph.codegen(), expected); + fn test_constant_val() { + assert_eq!(ConstantValue::from(1i32), ConstantValue::Int32(1i32)); + assert_eq!(ConstantValue::from(-1i64), ConstantValue::Int64(-1i64)); + assert_eq!(ConstantValue::from(0f32), ConstantValue::Float32(0f32)); + assert_eq!(ConstantValue::from(0f64), ConstantValue::Float64(0f64)); + assert_eq!(ConstantValue::from(true), ConstantValue::Bool(true)); + assert_eq!( + ConstantValue::from_vec(vec![2i32]), + ConstantValue::Int32(2i32) + ); } #[test] - fn test_codegen_scalar() { + fn test_codegen_nodes() { let mut graph = BurnGraph::::default(); graph.register(ConstantOfShapeNode::new( - ConstantOfShapeValue::new_float64(1.25), - Type::Scalar(ScalarType::new("scalar1", crate::burn::ScalarKind::Int64)), - Type::Scalar(ScalarType::new("scalar2", crate::burn::ScalarKind::Float64)), + Type::Shape(ShapeType::new("shape1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ConstantValue::Float32(1.25f32), )); - graph.register_input_output(vec!["scalar1".to_string()], vec!["scalar2".to_string()]); + graph.register_input_output(vec!["shape1".to_string()], vec!["tensor2".to_string()]); let expected = quote! { use burn::{ @@ -152,6 +184,7 @@ mod tests { #[derive(Module, Debug)] pub struct Model { phantom: core::marker::PhantomData, + device: burn::module::Ignored, } impl Model { @@ -159,13 +192,13 @@ mod tests { pub fn new(device: &B::Device) -> Self { Self { phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self, scalar1: i64) -> f64 { - let scalar2 = 1.25; - - scalar2 + pub fn forward(&self, shape1: [usize;4]) -> Tensor { + let tensor2 = Tensor::full(shape1, 1.25f32, &*self.device); + tensor2 } } }; diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 81919f85df..73c6a0201a 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -8,6 +8,7 @@ pub(crate) mod binary; pub(crate) mod clip; pub(crate) mod concat; pub(crate) mod constant; +pub(crate) mod constant_of_shape; pub(crate) mod conv1d; pub(crate) mod conv2d; pub(crate) mod conv3d; diff --git a/crates/burn-import/src/burn/node/unary.rs b/crates/burn-import/src/burn/node/unary.rs index 8efc5ffc11..9bb7866b90 100644 --- a/crates/burn-import/src/burn/node/unary.rs +++ b/crates/burn-import/src/burn/node/unary.rs @@ -116,12 +116,21 @@ impl NodeCodegen for UnaryNode { _ => panic!("lhs must be a tensor or scalar"), }; - // let input = scope.tensor_use_owned(&self.input, node_position); let output = &self.output.name(); let function = (self.function)(input); - quote! { - let #output = #function; + match &self.output { + Type::Shape(ref shape_type) => { + let dim = shape_type.dim.to_tokens(); + quote! { + let #output: [usize;#dim] = #function.try_into().unwrap(); + } + } + _ => { + quote! 
{ + let #output = #function; + } + } } } @@ -135,9 +144,6 @@ impl NodeCodegen for UnaryNode { UnaryNodeKind::Neg => { imports.register("core::ops::Neg"); } - UnaryNodeKind::Shape => { - imports.register("burn::tensor::Int"); - } UnaryNodeKind::Not => { imports.register("burn::tensor::Bool"); } @@ -451,15 +457,12 @@ impl UnaryNode { } pub(crate) fn shape(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self { - // Shape as defined by the ONNX op should return a tensor because other ops - // (e.g., Gather) will be used on a tensor + let start_dim = start_dim.to_tokens(); + let end_dim = end_dim.to_tokens(); + let function = move |input| { quote! { - Tensor::::from_data( - burn::tensor::TensorData::from(&#input.dims()[#start_dim..#end_dim]) - .convert::>(), - &#input.device(), - ) + #input.dims()[#start_dim..#end_dim] } }; Self::new(input, output, UnaryNodeKind::Shape, Rc::new(function)) @@ -475,7 +478,7 @@ impl UnaryNode { mod tests { use super::*; use crate::burn::node::tests::one_node_graph; - use crate::burn::{ScalarKind, ScalarType, TensorType}; + use crate::burn::{ScalarKind, ScalarType, ShapeType, TensorType}; #[test] fn test_unary_codegen_flatten() { @@ -1094,23 +1097,19 @@ mod tests { one_node_graph( UnaryNode::shape( Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_int("tensor2", 1)), + Type::Shape(ShapeType::new("shape1", 4)), 1, 3, ), quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = Tensor::::from_data( - burn::tensor::TensorData::from(&tensor1.dims()[1usize..3usize]) - .convert::>(), - &tensor1.device(), - ); + pub fn forward(&self, tensor1: Tensor) -> [usize; 4] { + let shape1: [usize; 4] = tensor1.dims()[1..3].try_into().unwrap(); - tensor2 + shape1 } }, vec!["tensor1".to_string()], - vec!["tensor2".to_string()], + vec!["shape1".to_string()], ); } diff --git a/crates/burn-import/src/burn/ty.rs b/crates/burn-import/src/burn/ty.rs index 82ee06246b..9bf04d3d05 100644 --- a/crates/burn-import/src/burn/ty.rs +++ b/crates/burn-import/src/burn/ty.rs @@ -35,6 +35,12 @@ pub struct ScalarType { pub kind: ScalarKind, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ShapeType { + pub name: Ident, + pub dim: usize, +} + #[derive(Debug, Clone)] pub struct OtherType { pub name: Ident, @@ -49,6 +55,9 @@ pub enum Type { /// Scalar type. Scalar(ScalarType), + /// Shape type. + Shape(ShapeType), + // Other type (more flexible type). Other(OtherType), } @@ -58,6 +67,7 @@ impl Type { match self { Type::Tensor(tensor) => &tensor.name, Type::Scalar(scalar) => &scalar.name, + Type::Shape(shape) => &shape.name, Type::Other(other) => &other.name, } } @@ -65,6 +75,7 @@ impl Type { match self { Type::Tensor(tensor) => tensor.ty(), Type::Scalar(scalar) => scalar.ty(), + Type::Shape(shape) => shape.ty(), Type::Other(other) => other.ty(), } } @@ -91,6 +102,22 @@ impl ScalarType { } } +impl ShapeType { + pub fn new>(name: S, dim: usize) -> Self { + if name.as_ref().is_empty() { + panic!("Shape was passed with empty name"); + } + Self { + name: Ident::new(name.as_ref(), Span::call_site()), + dim, + } + } + pub fn ty(&self) -> TokenStream { + let dim = self.dim.to_tokens(); + quote! 
{ [usize; #dim] } + } +} + impl TensorType { pub fn new>( name: S, diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 95b83e15f0..3b32ba9c4f 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -23,6 +23,7 @@ use crate::{ clip::ClipNode, concat::ConcatNode, constant::{ConstantNode, ConstantValue}, + constant_of_shape::ConstantOfShapeNode, conv1d::Conv1dNode, conv2d::Conv2dNode, conv3d::Conv3dNode, @@ -51,7 +52,7 @@ use crate::{ unary::UnaryNode, unsqueeze::UnsqueezeNode, }, - ScalarKind, ScalarType, TensorKind, TensorType, Type, + ScalarKind, ScalarType, ShapeType, TensorKind, TensorType, Type, }, format_tokens, logger::init_log, @@ -330,6 +331,9 @@ impl ParsedOnnxGraph { NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)), NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)), NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)), + NodeType::ConstantOfShape => { + graph.register(Self::constant_of_shape_conversion(node)) + } node_type => unsupported_ops.push(node_type), } } @@ -359,6 +363,9 @@ impl ParsedOnnxGraph { } fn constant_conversion(node: Node) -> ConstantNode { + // Additional types needed for Constant: + // use crate::burn::node::constant::{ConstantValue, TensorValue}; + let output = node.outputs.first().unwrap(); let attr = convert_constant_value(&node); @@ -461,6 +468,37 @@ impl ParsedOnnxGraph { RandomNormalNode::new(output_type, mean, scale) } + pub(crate) fn constant_of_shape_conversion(node: Node) -> ConstantOfShapeNode { + // Additional types needed for ConstantOfShape: + use crate::burn::node::constant_of_shape::ConstantValue; + + let input = node + .inputs + .first() + .expect("ConstantOfShape requires an input tensor"); + let output = node.outputs.first().unwrap(); + + // The value of the output elements.Should be a one-element tensor. 
+ // If not specified, it defaults to a tensor of value 0 and datatype float32 + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape + let value = node + .attrs + .get("value") + .and_then(|val| val.clone().into_tensor().data) + .map(|val_data| match val_data { + // TODO: Handle Float16 + Data::Float32s(vals) => ConstantValue::from_vec(vals), + Data::Float64s(vals) => ConstantValue::from_vec(vals), + Data::Int32s(vals) => ConstantValue::from_vec(vals), + Data::Int64s(vals) => ConstantValue::from_vec(vals), + Data::Bools(vals) => ConstantValue::from_vec(vals), + ty => panic!("Unsupported value type {:?} for ConstantOfShape!", ty), + }) + .unwrap_or(ConstantValue::Float32(0.0f32)); + + ConstantOfShapeNode::new(Type::from(input), Type::from(output), value) + } + fn add_conversion(node: Node) -> BinaryNode { let lhs = Type::from(node.inputs.first().unwrap()); let rhs = Type::from(node.inputs.get(1).unwrap()); @@ -1182,7 +1220,7 @@ impl From<&OnnxArgument> for Type { ArgType::Scalar(elem_type) => { Type::Scalar(ScalarType::new(arg.name.clone(), elem_type.into())) } - ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."), + ArgType::Shape(dim) => Type::Shape(ShapeType::new(arg.name.clone(), *dim)), } } } diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 385a66a585..ae11ff1ac3 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -6,7 +6,7 @@ use protobuf::Enum; use crate::{ ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, protos::tensor_proto::DataType, - util::flatten_config, + util::{flatten_config, shape_config}, }; /// Infer the dimension of each output tensor and update them. @@ -81,6 +81,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Squeeze => squeeze_update_output(node), NodeType::RandomUniform => random_update_output(node), NodeType::RandomNormal => random_update_output(node), + NodeType::ConstantOfShape => constant_of_shape_update_output(node), // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. 
_ => temporary_pass_through_stub(node), } @@ -128,6 +129,34 @@ fn constant_update_outputs(node: &mut Node) { }; } +fn constant_of_shape_update_output(node: &mut Node) { + let value_type = node + .attrs + .get("value") + .map(|v| v.clone().into_tensor().elem_type) + .unwrap_or(ElementType::Float32); // If not given, defaults to 0 as float32 + + let dim = match &node.inputs[0].ty { + ArgType::Shape(dim) => *dim, + ArgType::Tensor(tensor_type) => tensor_type + .shape + .as_ref() + .and_then(|shape| shape.first()) + .copied() + .expect("ConstantOfShape node must have a Tensor with a non-empty shape"), + _ => panic!("ConstantOfShape node must have a Tensor or Shape type input"), + }; + + // Fix the input type to be a shape + node.inputs[0].ty = ArgType::Shape(dim); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: value_type, + dim, + shape: None, + }); +} + /// Infer the shape of a node's output with an explicit shape attribute /// for the Random operations with explicit shape /// @@ -568,17 +597,9 @@ fn shape_update_outputs(node: &mut Node) { panic!("Shape: multiple inputs are not supported: {:?}", node); } - let node_input = &mut node.inputs[0]; - if let ArgType::Tensor(_tensor) = node_input.clone().ty { - // Output tensor is 1D int64 - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - dim: 1, - ..Default::default() - }); - } else { - panic!("Only tensor input is valid"); - } + let (start, end) = shape_config(node); + let dim = end - start; + node.outputs[0].ty = ArgType::Shape(dim); } /// Infers the shape of a Flatten node and replaces the shape of the output tensor. diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs index 0dc50dc7fb..fa30bcf83c 100644 --- a/crates/onnx-ir/src/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -235,6 +235,11 @@ impl OnnxGraphBuilder { i += 1; keep }); + + // TODO Update graph inputs and outputs to match the processed nodes inputs and outputs + // This is necessary for the graph to be valid + // ConstantOfShape updates input to be Shape argument and output Tensor dim is updated + OnnxGraph { nodes: processed_nodes, inputs, diff --git a/crates/onnx-ir/src/ir.rs b/crates/onnx-ir/src/ir.rs index 27c22a967e..8dd94d68dc 100644 --- a/crates/onnx-ir/src/ir.rs +++ b/crates/onnx-ir/src/ir.rs @@ -5,6 +5,7 @@ use strum_macros::{Display, EnumString}; use crate::protos::TensorProto; +// TODO: Rename Dim to Rank pub type Dim = usize; pub type Shape = Vec; @@ -108,6 +109,7 @@ pub struct TensorType { pub elem_type: ElementType, /// The dimension of the tensor. + /// TODO Rename to rank pub dim: Dim, /// The shape of the tensor. 
diff --git a/crates/onnx-ir/src/util.rs b/crates/onnx-ir/src/util.rs
index 98bf0871dd..9ad8fd3787 100644
--- a/crates/onnx-ir/src/util.rs
+++ b/crates/onnx-ir/src/util.rs
@@ -43,3 +43,41 @@ pub fn flatten_config(curr: &Node) -> (usize, usize) {
 
     (start_dim as usize, end_dim)
 }
+
+pub fn shape_config(curr: &Node) -> (usize, usize) {
+    if curr.inputs.len() != 1 {
+        panic!(
+            "Shape: multiple inputs are not supported (got {:?})",
+            curr.inputs.len()
+        );
+    }
+
+    // Extract the shape of the input tensor
+    let tensor = match curr.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Default: all axes up to the last one (included)
+    let mut start_dim: i64 = 0;
+    let mut end_dim: i64 = tensor.dim as i64;
+
+    // Extract the attributes
+    for (key, value) in curr.attrs.iter() {
+        match key.as_str() {
+            "start" => start_dim = value.clone().into_i64(),
+            "end" => end_dim = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    // If dim is negative, it is counted from the end
+    if start_dim < 0 {
+        start_dim += tensor.dim as i64;
+    }
+    if end_dim < 0 {
+        end_dim += tensor.dim as i64;
+    }
+
+    (start_dim as usize, end_dim as usize)
+}
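
The negative start/end handling in shape_config above mirrors the ONNX Shape operator's attribute semantics. A minimal standalone sketch of that normalization (normalize_shape_range is a stand-in written for illustration, not part of onnx-ir):

// Standalone sketch of the start/end normalization performed by shape_config.
// `rank` is the input tensor's rank; `start`/`end` follow ONNX Shape semantics,
// where negative indices count from the back.
fn normalize_shape_range(rank: i64, mut start: i64, mut end: i64) -> (usize, usize) {
    if start < 0 {
        start += rank;
    }
    if end < 0 {
        end += rank;
    }
    (start as usize, end as usize)
}

fn main() {
    // Shape with start = 1, end = -1 on a rank-4 tensor keeps dims 1..3,
    // so the output is typed as ArgType::Shape(2) and codegen emits roughly:
    //     let shape1: [usize; 2] = tensor1.dims()[1..3].try_into().unwrap();
    assert_eq!(normalize_shape_range(4, 1, -1), (1, 3));
    // With the defaults (start = 0, end = rank) the whole shape is kept.
    assert_eq!(normalize_shape_range(4, 0, 4), (0, 4));
}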
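
Taken together, the codegen changes turn a ConstantOfShape node whose input is a Shape argument into a plain array parameter on the generated model. A hand-written sketch of roughly what burn-import emits after this patch; the model name, rank 3, and fill values are illustrative (borrowed from the tests above), and bool fills go through an Int tensor because full bool tensor support is not available (see issue #1535):

use burn::module::{Ignored, Module};
use burn::tensor::backend::Backend;
use burn::tensor::{Bool, Int, Tensor};

// Hypothetical hand-written equivalent of a generated ConstantOfShape model.
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    phantom: core::marker::PhantomData<B>,
    device: Ignored<B::Device>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            phantom: core::marker::PhantomData,
            device: Ignored(device.clone()),
        }
    }

    // A Shape input arrives as a plain `[usize; N]` array.
    pub fn forward(&self, shape1: [usize; 3]) -> (Tensor<B, 3>, Tensor<B, 3, Bool>) {
        // Non-bool fill values lower to Tensor::full.
        let full_float: Tensor<B, 3> = Tensor::full(shape1, 1.125, &*self.device);
        // Bool fill values lower to an Int tensor of ones (or zeros) plus a cast.
        let full_bool: Tensor<B, 3, Bool> =
            Tensor::<B, 3, Int>::ones(shape1, &*self.device).bool();
        (full_float, full_bool)
    }
}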