-
Notifications
You must be signed in to change notification settings - Fork 493
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement ONNX ConstantOfShape (#1815)
* Feat: burn-import implement ONNX ConstantOfShape * Introduce shape type and use in ConstantOfShape and Shape * Add tests for bool and int tensors for ConstantOfShape * Fix ONNX test generation * Undo comment --------- Co-authored-by: Dilshod Tadjibaev <[email protected]>
- Loading branch information
Showing
18 changed files
with
446 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+161 Bytes
crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.onnx
Binary file not shown.
53 changes: 53 additions & 0 deletions
53
crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# used to generate model: constant_of_shape.onnx | ||
|
||
# torch simplifies simple usecases where it can statically determine the shape of the constant | ||
# to use just ONNX constants instead of ConstantOfShape | ||
# Hence this model is exported using onnx directly | ||
|
||
import onnx | ||
import onnx.helper | ||
|
||
|
||
def build_model(): | ||
return onnx.helper.make_model( | ||
ir_version=8, | ||
opset_imports=[onnx.helper.make_operatorsetid("", 16)], | ||
graph=onnx.helper.make_graph(name="main_graph", nodes=[ | ||
onnx.helper.make_node( | ||
"ConstantOfShape", | ||
inputs=["input1"], | ||
outputs=["output1"], | ||
name="/ConstantOfShape", | ||
value=onnx.helper.make_tensor("value", data_type=onnx.TensorProto.FLOAT, dims=[1], vals=[1.125]) | ||
), | ||
], | ||
inputs=[ | ||
onnx.helper.make_value_info( | ||
name="input1", | ||
type_proto=onnx.helper.make_tensor_type_proto( | ||
elem_type=onnx.TensorProto.INT64, shape=[3] | ||
), | ||
) | ||
], | ||
outputs=[ | ||
onnx.helper.make_value_info( | ||
name="output1", | ||
type_proto=onnx.helper.make_tensor_type_proto( | ||
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3, 2] | ||
), | ||
) | ||
]), | ||
) | ||
|
||
|
||
def main(): | ||
onnx_model = build_model() | ||
file_name = "constant_of_shape.onnx" | ||
onnx.save(onnx_model, file_name) | ||
onnx.checker.check_model(file_name) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Binary file added
BIN
+664 Bytes
crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape_full_like.onnx
Binary file not shown.
59 changes: 59 additions & 0 deletions
59
crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape_full_like.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#!/usr/bin/env python3 | ||
import torch | ||
import torch.nn as nn | ||
|
||
class Model(nn.Module): | ||
def __init__(self, fill_value_float, fill_value_int, fill_value_bool): | ||
super(Model, self).__init__() | ||
self.fill_value_float = fill_value_float | ||
self.fill_value_int = fill_value_int | ||
self.fill_value_bool = fill_value_bool | ||
|
||
def forward(self, x): | ||
# Use full_like, which will be exported as ConstantOfShape | ||
f = torch.full_like(x, self.fill_value_float, dtype=torch.float) | ||
i = torch.full_like(x, self.fill_value_int, dtype=torch.int) | ||
# Convert bool to int (1 or 0) for compatibility | ||
b = torch.full_like(x, int(self.fill_value_bool), dtype=torch.bool) | ||
return f, i, b | ||
|
||
def main(): | ||
# Set random seed for reproducibility | ||
torch.manual_seed(0) | ||
|
||
# Create an instance of the model | ||
model = Model(3.0, 5, True) | ||
|
||
# Create a dummy input | ||
test_input = torch.randn(2, 3, 4) | ||
|
||
file_name = "constant_of_shape_full_like.onnx" | ||
|
||
# Export the model to ONNX | ||
torch.onnx.export(model, test_input, file_name, | ||
verbose=False, opset_version=16, | ||
input_names=['input'], | ||
output_names=['output_float', 'output_int', 'output_bool'], | ||
dynamic_axes={'input': {0: 'batch_size', 1: 'height', 2: 'width'}, | ||
'output_float': {0: 'batch_size', 1: 'height', 2: 'width'}, | ||
'output_int': {0: 'batch_size', 1: 'height', 2: 'width'}, | ||
'output_bool': {0: 'batch_size', 1: 'height', 2: 'width'}}) | ||
|
||
print(f"Finished exporting model to {file_name}") | ||
|
||
# Output some test data for use in the test | ||
print(f"Test input data shape: {test_input.shape}") | ||
f, i, b = model.forward(test_input) | ||
print(f"Test output data shape of float: {f.shape}") | ||
print(f"Test output data shape of int: {i.shape}") | ||
print(f"Test output data shape of bool: {b.shape}") | ||
|
||
sum_f = f.sum().item() | ||
sum_i = i.sum().item() | ||
all_b = b.all().item() | ||
print(f"Test output sum of float: {sum_f}") | ||
print(f"Test output sum of int: {sum_i}") | ||
print(f"Test output all of bool: {all_b}") | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.