-
Notifications
You must be signed in to change notification settings - Fork 482
Support for parsing ONNX Pad node #1352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vloncar
wants to merge
5
commits into
fastmachinelearning:main
Choose a base branch
from
vloncar:onnx_pad
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
a7e1e4e
Support for parsing ONNX Pad node
vloncar 673ac3e
Merge remote-tracking branch 'upstream/main' into onnx_pad
vloncar 562fd49
Use dynamo onnx export
vloncar 25f80d9
Merge branch 'main' into onnx_pad
JanFSchulte 8c4b669
Parse Pad node with ONNX opset >= 11
vloncar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from hls4ml.model.layers import Constant, ZeroPadding1D, ZeroPadding2D | ||
from hls4ml.model.optimizer import OptimizerPass | ||
|
||
|
||
class PaddingConstant(OptimizerPass): | ||
""" | ||
ONNX has the padding come as an input, not a parameter. This removes the Constant node from the input. | ||
The constant value was already used; this is just a cleanup uptimization. | ||
""" | ||
|
||
def match(self, node): | ||
is_match = ( | ||
isinstance(node, (ZeroPadding1D, ZeroPadding2D)) | ||
and len(node.inputs) > 1 | ||
and isinstance(node.get_input_node(node.inputs[1]), Constant) | ||
) | ||
|
||
return is_match | ||
|
||
def transform(self, model, node): | ||
""" | ||
Remove Constant node(s) from the graph. Note, padding is already present in ZeroPadding node. | ||
""" | ||
if len(node.inputs) > 2: | ||
const_val_node = node.get_input_node(node.inputs[2]) | ||
if not isinstance(const_val_node, Constant): | ||
raise RuntimeError(f'Non-constant padding inputs are not currently supported ({node.name})') | ||
model.remove_node(const_val_node) | ||
node.inputs.pop(2) | ||
|
||
pad_node = node.get_input_node(node.inputs[1]) | ||
if not isinstance(pad_node, Constant): | ||
raise RuntimeError(f'Non-constant padding inputs are not currently supported ({node.name})') | ||
model.remove_node(pad_node) | ||
node.inputs.pop(1) | ||
|
||
return True |
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import qonnx.util.cleanup | ||
import torch | ||
import torch.nn as nn | ||
from qonnx.core.modelwrapper import ModelWrapper | ||
|
||
from hls4ml.converters import convert_from_onnx_model, convert_from_pytorch_model | ||
from hls4ml.utils.config import config_from_onnx_model, config_from_pytorch_model | ||
|
||
test_root_path = Path(__file__).parent | ||
|
||
|
||
def test_constantpad_1d(): | ||
class Pad1DModel(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.pad = nn.ConstantPad1d((2, 3), 0) # pad 2 left, 3 right | ||
|
||
def forward(self, x): | ||
return self.pad(x) | ||
|
||
model = Pad1DModel() | ||
model.eval() | ||
config_pytorch = config_from_pytorch_model(model, (2, 4), channels_last_conversion='off') | ||
hls_model_pytorch = convert_from_pytorch_model( | ||
model, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/pytorch'), hls_config=config_pytorch | ||
) | ||
|
||
hls_model_pytorch.compile() | ||
|
||
onnx_path = str(test_root_path / 'hls4mlprj_constpad_1d/pad1d.onnx') | ||
torch.onnx.export(model, torch.randn(1, 2, 4), onnx_path, dynamo=True) | ||
qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path) | ||
pad1d_onnx = ModelWrapper(onnx_path) | ||
|
||
config_onnx = config_from_onnx_model(pad1d_onnx) | ||
hls_model_onnx = convert_from_onnx_model( | ||
pad1d_onnx, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/onnx'), hls_config=config_onnx | ||
) | ||
|
||
hls_model_onnx.compile() | ||
|
||
input_data = np.random.randn(10, 2, 4) | ||
pred_pytorch = hls_model_pytorch.predict(input_data) | ||
pred_onnx = hls_model_onnx.predict(input_data) | ||
|
||
np.testing.assert_allclose(pred_pytorch, pred_onnx, rtol=0, atol=1e-5) | ||
|
||
|
||
def test_constantpad_2d(): | ||
class Pad2DModel(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0) # left, right, top, bottom | ||
|
||
def forward(self, x): | ||
return self.pad(x) | ||
|
||
model = Pad2DModel() | ||
model.eval() | ||
config_pytorch = config_from_pytorch_model(model, (2, 3, 4), channels_last_conversion='off') | ||
hls_model_pytorch = convert_from_pytorch_model( | ||
model, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/pytorch'), hls_config=config_pytorch | ||
) | ||
|
||
hls_model_pytorch.compile() | ||
|
||
onnx_path = str(test_root_path / 'hls4mlprj_constpad_2d/pad2d.onnx') | ||
torch.onnx.export(model, torch.randn(1, 2, 3, 4), onnx_path, dynamo=True) | ||
qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path) | ||
pad2d_onnx = ModelWrapper(onnx_path) | ||
|
||
config_onnx = config_from_onnx_model(pad2d_onnx) | ||
hls_model_onnx = convert_from_onnx_model( | ||
pad2d_onnx, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/onnx'), hls_config=config_onnx | ||
) | ||
|
||
hls_model_onnx.compile() | ||
|
||
input_data = np.random.randn(10, 2, 3, 4) | ||
pred_pytorch = hls_model_pytorch.predict(input_data) | ||
pred_onnx = hls_model_onnx.predict(input_data) | ||
|
||
np.testing.assert_allclose(pred_pytorch, pred_onnx, rtol=0, atol=1e-5) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
torch.onnx.export
needs theonnxscript
module that we don't have in the test environment. So that needs to be added for these tests to work.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just adding onnxscript to the toml file should be fine, right? There's no need for a new container for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that should work. We would also need to require torch>=2.5 to make sure that the
dynamo
argument is available for the export. Trying that out just royally blew up my testing environment so I think I need to rebuilt that from scratch and I'll try to come with a setup that works.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apparently there is an issue right now with ONNX and ml-dtypes, so adding these 3 lines to the pyproject.toml gets the test to work:
Now that the test actually runs it reveals an actual issue:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In their infinite wisdom, ONNX team changed the
pads
to be an input and not an attribute, then depending on the PyTorch setup you have it will export one or the other.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that they are at opset 23 already, I think requiring >= 11 seems reasonable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See here. Tools move far slower than opset versions. We can support both, supporting it as input is far more annoying though as that's a separate node
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can do both, of course. But it looks that if we want to support the preferred opset version, the one where it's an input is required either way :/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apparently FINN already requires 13 or later, so there is a qonnx pr to update the preferred opset to 13. (Probably qonnx should remove the warning on the GEMM to matmul converter that suggests an old opset.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We now support the opset >= 11. But the problem with the dependencies remains. If I put the dependency on
ml-dtypes>=.0.5.3
in thetesting
optional the TF breaks so most of the tests fails. Should we split the env into two in a separate PR?