Implement ONNX ConstantOfShape (#1815)
* Feat: burn-import implement ONNX ConstantOfShape

* Introduce shape type and use in ConstantOfShape and Shape

* Add tests for bool and int tensors for ConstantOfShape

* Fix ONNX test generation

* Undo comment

---------

Co-authored-by: Dilshod Tadjibaev <[email protected]>
hexd0t and antimora authored Jul 8, 2024
1 parent 924e357 commit c2b6318
Showing 18 changed files with 446 additions and 131 deletions.
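For context: ONNX ConstantOfShape consumes a 1-D int64 tensor holding the target shape and produces a tensor of that shape filled with its one-element value attribute (a float32 zero when the attribute is absent). In Burn terms this boils down to Tensor::full, with the shape arriving as a plain array via the new shape type. A minimal hand-written sketch of that mapping — illustrative only, not the code this commit generates:

use burn::tensor::{backend::Backend, Tensor};

// Equivalent of a ConstantOfShape node whose one-element float `value` is 1.125;
// the rank (here 3) is fixed when the model code is generated.
fn constant_of_shape<B: Backend>(shape: [usize; 3], device: &B::Device) -> Tensor<B, 3> {
    Tensor::full(shape, 1.125, device)
}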
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -38,7 +38,7 @@ represent the corresponding Burn Op.
| [Concat][30]             | ✅ | ✅ |
| [ConcatFromSequence][31] | ❌ | ❌ |
| [Constant][32]           | ✅ | ✅ |
-| [ConstantOfShape][33]    | ❌ | ❌ |
+| [ConstantOfShape][33]    | ✅ | ✅ |
| [Conv1d][34]             | ✅ | ✅ |
| [Conv2d][34]             | ✅ | ✅ |
| [Conv3d][34]             | ✅ | ✅ |
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -85,6 +85,8 @@ fn main() {
        .input("tests/squeeze/squeeze_opset13.onnx")
        .input("tests/random_uniform/random_uniform.onnx")
        .input("tests/random_normal/random_normal.onnx")
        .input("tests/constant_of_shape/constant_of_shape.onnx")
        .input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
        .input("tests/range/range.onnx")
        .out_dir("model/")
        .run_from_script();
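Each .input(...) call registers an ONNX fixture with the test build: the script converts every listed model into Rust source under model/, which the include_models! block in onnx_tests.rs (updated below) pulls into the test crate.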
Binary file not shown.
@@ -0,0 +1,53 @@
#!/usr/bin/env python3

# used to generate model: constant_of_shape.onnx

# torch simplifies simple use cases where it can statically determine the
# constant's shape, emitting plain ONNX Constant nodes instead of
# ConstantOfShape. Hence this model is built with the onnx package directly.

import onnx
import onnx.helper


def build_model():
    return onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 16)],
        graph=onnx.helper.make_graph(name="main_graph", nodes=[
            onnx.helper.make_node(
                "ConstantOfShape",
                inputs=["input1"],
                outputs=["output1"],
                name="/ConstantOfShape",
                value=onnx.helper.make_tensor("value", data_type=onnx.TensorProto.FLOAT, dims=[1], vals=[1.125])
            ),
        ],
        inputs=[
            onnx.helper.make_value_info(
                name="input1",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.INT64, shape=[3]
                ),
            )
        ],
        outputs=[
            onnx.helper.make_value_info(
                name="output1",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.FLOAT, shape=[2, 3, 2]
                ),
            )
        ]),
    )


def main():
    onnx_model = build_model()
    file_name = "constant_of_shape.onnx"
    onnx.save(onnx_model, file_name)
    onnx.checker.check_model(file_name)


if __name__ == "__main__":
    main()
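Given the shape input [2, 3, 2], this graph yields a 2 × 3 × 2 float32 tensor with every element equal to 1.125. Note that the one-element value tensor fixes both the fill value and the output dtype, which is why the Rust test below reconstructs the expected output with Tensor::full(input_shape, 1.125, &device).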
Binary file not shown.
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, fill_value_float, fill_value_int, fill_value_bool):
        super(Model, self).__init__()
        self.fill_value_float = fill_value_float
        self.fill_value_int = fill_value_int
        self.fill_value_bool = fill_value_bool

    def forward(self, x):
        # Use full_like, which will be exported as ConstantOfShape
        f = torch.full_like(x, self.fill_value_float, dtype=torch.float)
        i = torch.full_like(x, self.fill_value_int, dtype=torch.int)
        # Convert bool to int (1 or 0) for compatibility
        b = torch.full_like(x, int(self.fill_value_bool), dtype=torch.bool)
        return f, i, b

def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Create an instance of the model
    model = Model(3.0, 5, True)

    # Create a dummy input
    test_input = torch.randn(2, 3, 4)

    file_name = "constant_of_shape_full_like.onnx"

    # Export the model to ONNX
    torch.onnx.export(model, test_input, file_name,
                      verbose=False, opset_version=16,
                      input_names=['input'],
                      output_names=['output_float', 'output_int', 'output_bool'],
                      dynamic_axes={'input': {0: 'batch_size', 1: 'height', 2: 'width'},
                                    'output_float': {0: 'batch_size', 1: 'height', 2: 'width'},
                                    'output_int': {0: 'batch_size', 1: 'height', 2: 'width'},
                                    'output_bool': {0: 'batch_size', 1: 'height', 2: 'width'}})

    print(f"Finished exporting model to {file_name}")

    # Output some test data for use in the test
    print(f"Test input data shape: {test_input.shape}")
    f, i, b = model.forward(test_input)
    print(f"Test output data shape of float: {f.shape}")
    print(f"Test output data shape of int: {i.shape}")
    print(f"Test output data shape of bool: {b.shape}")

    sum_f = f.sum().item()
    sum_i = i.sum().item()
    all_b = b.all().item()
    print(f"Test output sum of float: {sum_f}")
    print(f"Test output sum of int: {sum_i}")
    print(f"Test output all of bool: {all_b}")

if __name__ == "__main__":
    main()
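For reference, with the fixed 2 × 3 × 4 test input (24 elements), the printed checks work out to sum_f = 3.0 × 24 = 72.0, sum_i = 5 × 24 = 120, and all_b = True — the same fill values (3.0, 5, true) that the constant_of_shape_full_like test below reconstructs with Tensor::full and compares element-wise.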
41 changes: 37 additions & 4 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -94,7 +94,9 @@ include_models!(
    squeeze_opset16,
    squeeze_opset13,
    random_uniform,
-    random_normal
+    random_normal,
+    constant_of_shape,
+    constant_of_shape_full_like
);

#[cfg(test)]
@@ -892,9 +894,8 @@ mod tests {
        // Run the model
        let input = Tensor::<Backend, 2>::ones([4, 2], &device);
        let output = model.forward(input);
-        let expected = TensorData::from([4i64, 2]);
-
-        output.to_data().assert_eq(&expected, true);
+        let expected = [4, 2];
+        assert_eq!(output, expected);
    }

#[test]
@@ -1658,4 +1659,36 @@ mod tests {
        let output = model.forward();
        assert_eq!(expected_shape, output.shape());
    }

    #[test]
    fn constant_of_shape() {
        // This tests that the shape is passed directly to the model
        let device = Default::default();
        let model = constant_of_shape::Model::<Backend>::new(&device);
        let input_shape = [2, 3, 2];
        let expected = Tensor::<Backend, 3>::full(input_shape, 1.125, &device).to_data();

        let output = model.forward(input_shape);

        output.to_data().assert_approx_eq(&expected, 3);
    }

    #[test]
    fn constant_of_shape_full_like() {
        // This tests that the shape is taken from the runtime input tensor

        let device = Default::default();
        let model = constant_of_shape_full_like::Model::<Backend>::new(&device);
        let shape = [2, 3, 2];
        let f_expected = Tensor::<Backend, 3>::full(shape, 3.0, &device);
        let i_expected = Tensor::<Backend, 3, Int>::full(shape, 5, &device);
        let b_expected = Tensor::<Backend, 3, Int>::ones(shape, &device).bool();

        let input = Tensor::ones(shape, &device);
        let (f_output, i_output, b_output) = model.forward(input);

        assert!(f_output.equal(f_expected).all().into_scalar());
        assert!(i_output.equal(i_expected).all().into_scalar());
        assert!(b_output.equal(b_expected).all().into_scalar());
    }
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/graph.rs
@@ -292,6 +292,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
            Type::Tensor(tensor) => Some(tensor),
            Type::Scalar(_) => None,
            Type::Other(_) => None,
            Type::Shape(_) => None,
        }
    }
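This match collects the tensor-typed graph inputs; scalars, opaque types, and now shapes all yield None, since a shape value is emitted as a plain array rather than a tensor. Roughly, the enum being matched now has the following form — a sketch with stand-in payload types, not the real burn-import definitions:

// Stand-in payloads; the actual types live elsewhere in burn-import.
struct TensorType;
struct ScalarType;
struct ShapeType;
struct OtherType;

enum Type {
    Tensor(TensorType),
    Scalar(ScalarType),
    // New in this commit: a statically-ranked shape carried as a plain array
    // (e.g. the [2, 3, 2] the tests pass to forward), not as a tensor.
    Shape(ShapeType),
    Other(OtherType),
}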

11 changes: 7 additions & 4 deletions crates/burn-import/src/burn/node/base.rs
@@ -3,10 +3,10 @@ use std::marker::PhantomData;
use super::{
    argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
    batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
-    constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode, conv3d::Conv3dNode,
-    conv_transpose_2d::ConvTranspose2dNode, conv_transpose_3d::ConvTranspose3dNode,
-    dropout::DropoutNode, expand::ExpandNode, gather::GatherNode,
-    gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
+    constant::ConstantNode, constant_of_shape::ConstantOfShapeNode, conv1d::Conv1dNode,
+    conv2d::Conv2dNode, conv3d::Conv3dNode, conv_transpose_2d::ConvTranspose2dNode,
+    conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
+    gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
    layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
    random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
@@ -116,6 +116,7 @@ pub enum Node<PS: PrecisionSettings> {
    Where(WhereNode),
    RandomUniform(RandomUniformNode),
    RandomNormal(RandomNormalNode),
    ConstantOfShape(ConstantOfShapeNode),
    // For now, we have to keep the precision settings in order to correctly serialize the fields
    // into the right data types.
    _Unreachable(std::convert::Infallible, PhantomData<PS>),
@@ -160,6 +161,7 @@ macro_rules! match_all {
            Node::Where(node) => $func(node),
            Node::RandomNormal(node) => $func(node),
            Node::RandomUniform(node) => $func(node),
            Node::ConstantOfShape(node) => $func(node),
            _ => unimplemented!(),
        }
    }};
@@ -212,6 +214,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::Where(_) => "where",
            Node::RandomNormal(_) => "random_normal",
            Node::RandomUniform(_) => "random_uniform",
            Node::ConstantOfShape(_) => "constant_of_shape",
            _ => unimplemented!(),
        }
    }
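Taken together, the base.rs changes show the registration checklist for a new codegen node: import its type, add a variant to the Node enum, route it through the match_all! macro, and give it a snake_case name string (presumably used when naming variables in the generated code).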