Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion hls4ml/converters/onnx/reshape.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler
from hls4ml.converters.onnx_to_hls import get_constant_value, get_onnx_attribute, onnx_handler


@onnx_handler('Transpose')
Expand Down Expand Up @@ -58,3 +58,75 @@ def parse_resize_layer(node, input_names, input_shapes, graph):
)

return layer


@onnx_handler('Pad')
def parse_pad_layer(node, input_names, input_shapes, graph):
layer = {}
layer['name'] = node.name
layer['class_name'] = 'ZeroPadding'
layer['inputs'] = input_names
layer['outputs'] = list(node.output)
layer['data_format'] = (
'channels_last' if any(node.domain == 'qonnx.custom_op.channels_last' for node in graph.node) else 'channels_first'
)

mode = get_onnx_attribute(node, 'mode')
if mode is not None and mode != 'constant':
raise RuntimeError(f'Unsupported padding mode: {mode} in node {node.name}')

pads = get_constant_value(graph, node.input[1])
if len(input_names) > 2:
const_val = get_constant_value(graph, node.input[2])
if const_val != 0:
raise RuntimeError(f'Only constant value of 0 supported for Pad node {node.name}, got {const_val}')

if len(input_names) > 3:
raise RuntimeError(f'Parsing axes input of Pad node {node.name} is not supported.')

dim = 0
if len(input_shapes[0]) == 3:
dim = 1 # 2D input (batch, channels, width), will use ZeroPadding1D
if layer['data_format'] == 'channels_first':
_, channels, width = input_shapes[0]
pad_left, pad_right = pads[2], pads[5]
else:
_, width, channels = input_shapes[0]
pad_left, pad_right = pads[1], pads[4]
out_width = width + pad_left + pad_right

layer['n_chan'] = channels
layer['in_width'] = width
layer['out_width'] = out_width

layer['pad_left'] = pad_left
layer['pad_right'] = pad_right
elif len(input_shapes[0]) == 4:
dim = 2 # 3D input (batch, channels, height, width), will use ZeroPadding2D
if layer['data_format'] == 'channels_first':
_, channels, height, width = input_shapes[0]
pad_top, pad_bottom = pads[2], pads[6]
pad_left, pad_right = pads[3], pads[7]
else:
_, height, width, channels = input_shapes[0]
pad_top, pad_bottom = pads[1], pads[5]
pad_left, pad_right = pads[2], pads[6]
out_height = height + pad_top + pad_bottom
out_width = width + pad_left + pad_right

layer['n_chan'] = channels
layer['in_height'] = height
layer['in_width'] = width
layer['out_height'] = out_height
layer['out_width'] = out_width

layer['pad_top'] = pad_top
layer['pad_bottom'] = pad_bottom
layer['pad_left'] = pad_left
layer['pad_right'] = pad_right
else:
raise RuntimeError(f'Unsupported input shape: {input_shapes[0]} for Pad node {node.name}')

layer['class_name'] += str(dim) + 'D'

return layer
4 changes: 4 additions & 0 deletions hls4ml/converters/pytorch/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def parse_constantpad2d_layer(operation, layer_name, input_names, input_shapes,
layer['out_height'] = out_height
layer['out_width'] = out_width

layer['data_format'] = 'channels_first' # Default data format in PyTorch

return layer, output_shape


Expand Down Expand Up @@ -246,4 +248,6 @@ def parse_constantpad1d_layer(operation, layer_name, input_names, input_shapes,
layer['in_width'] = width
layer['out_width'] = out_width

layer['data_format'] = 'channels_first' # Default data format in PyTorch

return layer, output_shape
1 change: 1 addition & 0 deletions hls4ml/model/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
'parse_qonnx',
[
'reshape_constant',
'padding_constant',
'resize_remove_constants',
'quant_constant_parameters',
'bipolar_quant_constant_parameters',
Expand Down
37 changes: 37 additions & 0 deletions hls4ml/model/optimizer/passes/pad_const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from hls4ml.model.layers import Constant, ZeroPadding1D, ZeroPadding2D
from hls4ml.model.optimizer import OptimizerPass


class PaddingConstant(OptimizerPass):
"""
ONNX has the padding come as an input, not a parameter. This removes the Constant node from the input.
The constant value was already used; this is just a cleanup uptimization.
"""

def match(self, node):
is_match = (
isinstance(node, (ZeroPadding1D, ZeroPadding2D))
and len(node.inputs) > 1
and isinstance(node.get_input_node(node.inputs[1]), Constant)
)

return is_match

def transform(self, model, node):
"""
Remove Constant node(s) from the graph. Note, padding is already present in ZeroPadding node.
"""
if len(node.inputs) > 2:
const_val_node = node.get_input_node(node.inputs[2])
if not isinstance(const_val_node, Constant):
raise RuntimeError(f'Non-constant padding inputs are not currently supported ({node.name})')
model.remove_node(const_val_node)
node.inputs.pop(2)

pad_node = node.get_input_node(node.inputs[1])
if not isinstance(pad_node, Constant):
raise RuntimeError(f'Non-constant padding inputs are not currently supported ({node.name})')
model.remove_node(pad_node)
node.inputs.pop(1)

return True
44 changes: 0 additions & 44 deletions test/pytest/test_pytorch_constpadmapping.py

This file was deleted.

86 changes: 86 additions & 0 deletions test/pytest/test_zeropadding_pytorch_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from pathlib import Path

import numpy as np
import qonnx.util.cleanup
import torch
import torch.nn as nn
from qonnx.core.modelwrapper import ModelWrapper

from hls4ml.converters import convert_from_onnx_model, convert_from_pytorch_model
from hls4ml.utils.config import config_from_onnx_model, config_from_pytorch_model

test_root_path = Path(__file__).parent


def test_constantpad_1d():
class Pad1DModel(nn.Module):
def __init__(self):
super().__init__()
self.pad = nn.ConstantPad1d((2, 3), 0) # pad 2 left, 3 right

def forward(self, x):
return self.pad(x)

model = Pad1DModel()
model.eval()
config_pytorch = config_from_pytorch_model(model, (2, 4), channels_last_conversion='off')
hls_model_pytorch = convert_from_pytorch_model(
model, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/pytorch'), hls_config=config_pytorch
)

hls_model_pytorch.compile()

onnx_path = str(test_root_path / 'hls4mlprj_constpad_1d/pad1d.onnx')
torch.onnx.export(model, torch.randn(1, 2, 4), onnx_path, dynamo=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The torch.onnx.export needs the onnxscript module that we don't have in the test environment. So that needs to be added for these tests to work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just adding onnxscript to the toml file should be fine, right? There's no need for a new container for this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that should work. We would also need to require torch>=2.5 to make sure that the dynamo argument is available for the export. Trying that out just royally blew up my testing environment so I think I need to rebuilt that from scratch and I'll try to come with a setup that works.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently there is an issue right now with ONNX and ml-dtypes, so adding these 3 lines to the pyproject.toml gets the test to work:

  "torch>=2.5",
  "onnxscript",
  "ml-dtype>=0.5.3"

Now that the test actually runs it reveals an actual issue:

    @onnx_handler('Pad')
    def parse_pad_layer(node, input_names, input_shapes, graph):
        layer = {}
        layer['name'] = node.name
        layer['class_name'] = 'ZeroPadding'
        layer['inputs'] = input_names
        layer['outputs'] = list(node.output)
        layer['data_format'] = (
            'channels_last' if any(node.domain == 'qonnx.custom_op.channels_last' for node in graph.node) else 'channels_first'
        )
    
        mode = get_onnx_attribute(node, 'mode')
        if mode is not None and mode != 'constant':
            raise RuntimeError(f'Unsupported padding mode: {mode} in node {node.name}')
    
        pads = get_onnx_attribute(node, 'pads')
    
        dim = 0
        if len(input_shapes[0]) == 3:
            dim = 1  # 2D input (batch, channels, width), will use ZeroPadding1D
            if layer['data_format'] == 'channels_first':
                _, channels, width = input_shapes[0]
>               pad_left, pad_right = pads[2], pads[5]
E               TypeError: 'NoneType' object is not subscriptable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In their infinite wisdom, ONNX team changed the pads to be an input and not an attribute, then depending on the PyTorch setup you have it will export one or the other.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that they are at opset 23 already, I think requiring >= 11 seems reasonable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See here. Tools move far slower than opset versions. We can support both, supporting it as input is far more annoying though as that's a separate node

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do both, of course. But it looks that if we want to support the preferred opset version, the one where it's an input is required either way :/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently FINN already requires 13 or later, so there is a qonnx pr to update the preferred opset to 13. (Probably qonnx should remove the warning on the GEMM to matmul converter that suggests an old opset.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now support the opset >= 11. But the problem with the dependencies remains. If I put the dependency on ml-dtypes>=.0.5.3 in the testing optional the TF breaks so most of the tests fails. Should we split the env into two in a separate PR?

qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path)
pad1d_onnx = ModelWrapper(onnx_path)

config_onnx = config_from_onnx_model(pad1d_onnx)
hls_model_onnx = convert_from_onnx_model(
pad1d_onnx, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/onnx'), hls_config=config_onnx
)

hls_model_onnx.compile()

input_data = np.random.randn(10, 2, 4)
pred_pytorch = hls_model_pytorch.predict(input_data)
pred_onnx = hls_model_onnx.predict(input_data)

np.testing.assert_allclose(pred_pytorch, pred_onnx, rtol=0, atol=1e-5)


def test_constantpad_2d():
class Pad2DModel(nn.Module):
def __init__(self):
super().__init__()
self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0) # left, right, top, bottom

def forward(self, x):
return self.pad(x)

model = Pad2DModel()
model.eval()
config_pytorch = config_from_pytorch_model(model, (2, 3, 4), channels_last_conversion='off')
hls_model_pytorch = convert_from_pytorch_model(
model, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/pytorch'), hls_config=config_pytorch
)

hls_model_pytorch.compile()

onnx_path = str(test_root_path / 'hls4mlprj_constpad_2d/pad2d.onnx')
torch.onnx.export(model, torch.randn(1, 2, 3, 4), onnx_path, dynamo=True)
qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path)
pad2d_onnx = ModelWrapper(onnx_path)

config_onnx = config_from_onnx_model(pad2d_onnx)
hls_model_onnx = convert_from_onnx_model(
pad2d_onnx, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/onnx'), hls_config=config_onnx
)

hls_model_onnx.compile()

input_data = np.random.randn(10, 2, 3, 4)
pred_pytorch = hls_model_pytorch.predict(input_data)
pred_onnx = hls_model_onnx.predict(input_data)

np.testing.assert_allclose(pred_pytorch, pred_onnx, rtol=0, atol=1e-5)
Loading