diff --git a/docs/frontend/keras.rst b/docs/frontend/keras.rst index d6d42cb4b8..9ede7b1d8c 100644 --- a/docs/frontend/keras.rst +++ b/docs/frontend/keras.rst @@ -1,11 +1,19 @@ ================ -Keras and QKeras +Keras and its quantized variants ================ -Keras and the quantization library QKeras are well supported in ``hls4ml``. Currently, the Keras v2 (``tf.keras``) is the preferred version, and the future versions of ``hls4ml`` will expand support for Keras v3. The frontend is based on the parsing the serialized json representation of the model. +Keras and the quantization library QKeras are well supported in ``hls4ml``. Both Keras v2 (``tf.keras``) and the new Keras v3 are supported. While the Keras v2 support is based on parsing the serialized json representation of the model, the Keras v3 support uses direct model inspection. -Currently, ``hls4ml`` can parse most Keras layers, including core layers, convolutional layers, pooling layers, recurrent layers, merging/reshaping layers and activation layers, implemented either via sequential or functional API. Notably missing are the attention and normalization layers. The equivalent QKeras API and quantizers are also supported. The ``Lambda`` layers don't save their state in the serialized format and are thus impossible to parse. In this case, the ``Lambda`` layers can be implemented as custom layers and parsed via the :ref:`Extension API`. +Currently, ``hls4ml`` can parse most Keras layers, including core layers, convolutional layers, pooling layers, recurrent layers, merging/reshaping layers and activation layers, implemented either via sequential or functional API. Notably missing are the attention and normalization layers. The ``Lambda`` layers don't save their state in the serialized format and are thus impossible to parse. In this case, the ``Lambda`` layers can be implemented as custom layers and parsed via the :ref:`Extension API`. The ``data_format='channels_first'`` parameter of Keras layers is supported, but not extensively tested. All HLS implementations in ``hls4ml`` are based on ``channels_last`` data format and need to be converted to that format before the HLS code can be emitted. We encourage users of ``channels_first`` to report their experiences to developers on GitHub. + +* `QKeras `_ + The equivalent QKeras API and its quantizers are also supported by ``hls4ml``. QKeras is not compatible with Keras v3. +* `HGQ `_ + The equivalent HGQ API is also supported. HGQ is not compatible with Keras v3. See `advanced/HGQ <../advanced/hgq.html>`__ for more information. +* `HGQ2 `_ + HGQ2 is based on Keras v3. Its support in hls4ml is currently under development. + The development team of ``hls4ml`` is currently exploring options for QKeras alternative and will provide a drop-in replacement API compatible with Keras v3. diff --git a/docs/intro/setup.rst b/docs/intro/setup.rst index eba9e2f45a..3c21290981 100644 --- a/docs/intro/setup.rst +++ b/docs/intro/setup.rst @@ -37,14 +37,26 @@ version can be installed directly from ``git``: Dependencies ============ -The ``hls4ml`` library requires python 3.10 or later, and depends on a number of Python packages and external tools for synthesis and simulation. Python dependencies are automatically managed -by ``pip`` or ``conda``. +.. note:: + As of version 1.1.0+, all conversion frontend specific packages are optional. Only install the packages you need. -* `TensorFlow `_ (version 2.8 to 2.14) and `QKeras `_ are required by the Keras converter. One may want to install newer versions of QKeras from GitHub. Newer versions of TensorFlow can be used, but QKeras and hl4ml do not currently support Keras v3. +The ``hls4ml`` library requires python 3.10 or later, and depends on a number of Python packages and external tools for synthesis and simulation. Python dependencies are automatically managed by ``pip`` or ``conda``. + +The following Python packages are all optional and are only required if you intend to use the corresponding converter. Only install the packages you need. + +* `Keras `_ is required by the Keras converter. + * `TensorFlow `_ (version 2.8 to 2.14) is required by the Keras v2 converter (keras v2 is included in TensorFlow). + * `Keras ` 3.0 or above is required by the Keras v3 converter. Keras v3 supports multiple backends for training and inference, and the conversion is not tied any specific backend. Notice that Keras v3 may **not** coexist with Keras v2 in the same Python environment. * `ONNX `_ (version 1.4.0 and newer) is required by the ONNX converter. -* `PyTorch `_ package is optional. If not installed, the PyTorch converter will not be available. +* `PyTorch `_ is required by the PyTorch converter. + +* Quantization support + * `QKeras `_: based on Keras v2. See `frontend/keras <../frontend/keras.html>`_ for more details + * `HGQ `_: Based on Keras v2. See `advanced/HGQ <../advanced/hgq.html>`_ for more details. + * `Brevitas `_: Based on PyTorch. See `frontend/pytorch <../frontend/pytorch.html>`_ for more details. + * `QONNX `_: Based on ONNX. See `frontend/onnx <../frontend/onnx.html>`_ for more details. Running C simulation from Python requires a C++11-compatible compiler. On Linux, a GCC C++ compiler ``g++`` is required. Any version from a recent Linux should work. On MacOS, the *clang*-based ``g++`` is enough. For the oneAPI backend, one must have oneAPI installed, along with the FPGA compiler, diff --git a/example-models b/example-models index c6bb3c0686..3cfbcfd062 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit c6bb3c0686d52439d8c53d7407903bf78e852562 +Subproject commit 3cfbcfd062f60492507d21ff0e91559b3bdd6550 diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index bd85937d89..95d900fd62 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -914,7 +914,7 @@ def generate_conv2d_line_buffer_fn( return generated_code @staticmethod - def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]): + def transpose_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]): """ Generate new shape and perm_strides for a permute operation. Operates by mapping the output index to input input index by: @@ -933,12 +933,20 @@ def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]) perm (tuple[int, ...]): The permutation of the dimensions. Returns: - (new_shape, perm_strides) (tuple, tuple): the output shape and permutation strides. + dict: Dictionary containing the configuration. """ new_shape = tuple(shape[i] for i in perm) strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1] perm_strides = tuple(int(strides[i]) for i in perm) - return (new_shape, perm_strides) + return dict( + dims=len(shape), + N=math.prod(shape), + from_shape=', '.join(str(x) for x in shape), + perm=', '.join(str(x) for x in perm), + perm_strides=', '.join(str(x) for x in perm_strides), + to_shape=', '.join(str(x) for x in new_shape), + config_name=name, + ) @model_optimizer() def write_hls(self, model): diff --git a/hls4ml/backends/oneapi/passes/reshaping_templates.py b/hls4ml/backends/oneapi/passes/reshaping_templates.py index 462758c228..80b467b944 100644 --- a/hls4ml/backends/oneapi/passes/reshaping_templates.py +++ b/hls4ml/backends/oneapi/passes/reshaping_templates.py @@ -185,16 +185,8 @@ def format(self, node): perm = tuple(node.get_attr('perm')) name = f'config{node.index}' - new_shape, perm_strides = node.model.config.backend.permute_config_gen(name, shape, perm) - return transpose_config_template.format( - dims=len(shape), - N=int(np.prod(shape)), - from_shape=', '.join(str(x) for x in shape), - perm=', '.join(str(x) for x in perm), - perm_strides=', '.join(str(x) for x in perm_strides), - to_shape=', '.join(str(x) for x in new_shape), - config_name=name, - ) + conf = node.model.config.backend.transpose_config_gen(name, shape, perm) + return transpose_config_template.format(**conf) class TransposeFunctionTemplate(FunctionCallTemplate): diff --git a/hls4ml/backends/vivado/passes/einsum.py b/hls4ml/backends/vivado/passes/einsum.py new file mode 100644 index 0000000000..4f976c63af --- /dev/null +++ b/hls4ml/backends/vivado/passes/einsum.py @@ -0,0 +1,108 @@ +from math import ceil + +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import Einsum + +from .reshaping_templates import transpose_config_template + +# Shared Dense template +# Einsum template + +einsum_config_template = ''' +struct config{index} {{ + typedef config{index}_tpose_inp0 tpose_inp0_conf; + typedef config{index}_tpose_inp1 tpose_inp1_conf; + typedef config{index}_tpose_out tpose_out_conf; + + typedef {accum_t.name} accum_t; + + // Layer Sizes + static const unsigned n_free0 = {n_free0}; + static const unsigned n_free1 = {n_free1}; + static const unsigned n_contract = {n_contract}; + static const unsigned n_inplace = {n_inplace}; + + // Resource reuse info + static const unsigned io_type = nnet::{iotype}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned reuse_factor = {reuse_factor}; + static const unsigned multiplier_limit = {multiplier_limit}; + static const bool store_weights_in_bram = false; // NOT USED + + template + using product = nnet::product::{product_type}; +}}; +''' + +einsum_function_template = 'nnet::einsum<{input0_t}, {input1_t}, {output_t}, {config}>({input0}, {input1}, {output});' + +einsum_include_list = ['nnet_utils/nnet_einsum.h'] + + +class EinsumConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Einsum) + self.template = einsum_config_template + + def format(self, node: Einsum): + default_params = self._default_config_params(node) + + strategy = node.attributes.attributes['strategy'] + io_type = node.model.config.get_config_value('IOType') + + assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now' + assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now' + + # EinsumDense config + params = default_params.copy() + params['strategy'] = strategy + params['n_free0'] = node.attributes.attributes['n_free0'] + params['n_free1'] = node.attributes.attributes['n_free1'] + params['n_contract'] = node.attributes.attributes['n_contract'] + params['n_inplace'] = node.attributes.attributes['n_inplace'] + inp0_t = node.get_input_variable(node.inputs[0]).type.precision + inp1_t = node.get_input_variable(node.inputs[1]).type.precision + params['product_type'] = get_backend('vivado').product_type(inp0_t, inp1_t) + + total_mults = params['n_free0'] * params['n_free1'] * params['n_contract'] * params['n_inplace'] + params['multiplier_limit'] = ceil(total_mults / params['reuse_factor']) + + einsum_conf = self.template.format(**params) + + # inp/out transpose config + inp0_shape = node.attributes.attributes['inp0_shape'] + inp1_shape = node.attributes.attributes['inp1_shape'] + out_interpert_shape = node.attributes.attributes['out_interpert_shape'] + inp0_tpose_idxs = node.attributes.attributes['inp0_tpose_idxs'] + inp1_tpose_idxs = node.attributes.attributes['inp1_tpose_idxs'] + out_tpose_idxs = node.attributes.attributes['out_tpose_idxs'] + tpose_inp0_conf_name = f'config{node.index}_tpose_inp0' + tpose_inp1_conf_name = f'config{node.index}_tpose_inp1' + tpose_out_conf_name = f'config{node.index}_tpose_out' + + conf = node.model.config.backend.transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs) + inp0_tpose_conf = transpose_config_template.format(**conf) + conf = node.model.config.backend.transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs) + inp1_tpose_conf = transpose_config_template.format(**conf) + conf = node.model.config.backend.transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + out_tpose_conf = transpose_config_template.format(**conf) + + return '\n\n'.join((inp0_tpose_conf, inp1_tpose_conf, out_tpose_conf, einsum_conf)) + + +class EinsumFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Einsum, include_header=einsum_include_list) + self.template = einsum_function_template + + def format(self, node: Einsum): + params = {} + params['config'] = f'config{node.index}' + params['input0_t'] = node.get_input_variable(node.inputs[0]).type.name + params['input1_t'] = node.get_input_variable(node.inputs[1]).type.name + params['output_t'] = node.get_output_variable().type.name + params['input0'] = node.get_input_variable(node.inputs[0]).name + params['input1'] = node.get_input_variable(node.inputs[1]).name + params['output'] = node.get_output_variable().name + return self.template.format(**params) diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py new file mode 100644 index 0000000000..1b4b183039 --- /dev/null +++ b/hls4ml/backends/vivado/passes/einsum_dense.py @@ -0,0 +1,147 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import EinsumDense + +from .reshaping_templates import transpose_config_template + +# Shared Dense template + +dense_config_template = """struct config{index}_dense : nnet::dense_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned reuse_factor = {reuse}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned n_zeros = {nzeros}; + static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + template + using kernel = nnet::{dense_function}; + template + using product = nnet::product::{product_type}; +}};\n""" + +# EinsumDense template + +einsum_dense_config_template = ''' +struct config{index} {{ + typedef config{index}_tpose_inp tpose_inp_conf; + typedef config{index}_tpose_out tpose_out_conf; + {kernel_config}; + + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + + // Layer Sizes + static const unsigned n_free_data = {n_free_data}; + static const unsigned n_free_kernel = {n_free_kernel}; + static const unsigned n_contract = {n_contract}; + static const unsigned n_inplace = {n_inplace}; + + // Resource reuse info + static const unsigned io_type = nnet::{iotype}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned reuse_factor = {reuse_factor}; + static const unsigned parallelization_factor = {parallelization_factor}; // Only useful when n_inplace > 1 + static const bool store_weights_in_bram = false; // NOT USED +}}; +''' + +einsum_dense_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +einsum_dense_da_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {b});' + +einsum_dense_include_list = ['nnet_utils/nnet_einsum_dense.h', 'nnet_utils/nnet_dense.h'] + + +class EinsumDenseConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(EinsumDense) + self.template = einsum_dense_config_template + self.dense_template = dense_config_template + + def dense_config(self, node: EinsumDense): + dense_params = self._default_config_params(node) + strategy = node.attributes['strategy'] + dense_params['strategy'] = strategy + dense_params['n_in'] = node.attributes.attributes['n_contract'] + dense_params['n_out'] = node.attributes.attributes['n_free_kernel'] + if node.attributes.attributes['n_inplace'] == 1: + dense_params['nzeros'] = node.get_weights('weight').nzeros # type: ignore + else: + dense_params['nzeros'] = '-1; // Not making sense when kernels are switching' + dense_params['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision # type: ignore + ) + + dense_params['dense_function'] = 'DenseLatency' # Latency only for now + + dense_config = self.dense_template.format(**dense_params) + return dense_config + + def format(self, node: EinsumDense): + default_params = self._default_config_params(node) + + strategy = node.attributes['strategy'] + io_type = node.model.config.get_config_value('IOType') + + assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel and distributed_arithmetic' + + # EinsumDense config + params = default_params.copy() + params['strategy'] = strategy + params['n_free_data'] = node.attributes.attributes['n_free_data'] + params['n_free_kernel'] = node.attributes.attributes['n_free_kernel'] + params['n_contract'] = node.attributes.attributes['n_contract'] + params['n_inplace'] = node.attributes.attributes['n_inplace'] + if strategy.lower() == 'latency': + params['kernel_config'] = f'typedef config{node.index}_dense dense_conf' + else: + assert strategy.lower() == 'distributed_arithmetic', 'EinsumDense layer only supports Latency strategy for now' + inp_t = node.get_input_variable().type.name + result_t = node.get_output_variable().type.name + index = node.index + conf = f'constexpr static auto da_kernel = nnet::einsum_dense{index}_da_kernel<{inp_t}, {result_t}>' + params['kernel_config'] = conf + pf = node.attributes.attributes['parallelization_factor'] + if pf < 0: + pf = params['n_inplace'] + params['parallelization_factor'] = pf + + einsum_conf = self.template.format(**params) + + # inp/out transpose config + inp_shape = node.attributes.attributes['inp_shape'] + out_interpert_shape = node.attributes.attributes['out_interpert_shape'] + inp_tpose_idxs = node.attributes.attributes['inp_tpose_idxs'] + out_tpose_idxs = node.attributes.attributes['out_tpose_idxs'] + tpose_inp_conf_name = f'config{node.index}_tpose_inp' + tpose_out_conf_name = f'config{node.index}_tpose_out' + + conf = node.model.config.backend.transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs) + inp_tpose_conf = transpose_config_template.format(**conf) + conf = node.model.config.backend.transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + out_tpose_conf = transpose_config_template.format(**conf) + + if strategy.lower() == 'distributed_arithmetic': + return '\n\n'.join((inp_tpose_conf, out_tpose_conf, einsum_conf)) + + dense_config = self.dense_config(node) + return '\n\n'.join((inp_tpose_conf, out_tpose_conf, dense_config, einsum_conf)) + + +class EinsumDenseFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(EinsumDense, include_header=einsum_dense_include_list) + self.template = einsum_dense_function_template + + def format(self, node): + params = self._default_function_params(node) + params['b'] = node.get_weights('bias').name + + strategy = node.attributes['strategy'] + if strategy == 'distributed_arithmetic': + return einsum_dense_da_function_template.format(**params) + + params['w'] = node.get_weights('weight').name + return einsum_dense_function_template.format(**params) diff --git a/hls4ml/backends/vivado/passes/reshaping_templates.py b/hls4ml/backends/vivado/passes/reshaping_templates.py index ff16d15c9d..69944e4497 100644 --- a/hls4ml/backends/vivado/passes/reshaping_templates.py +++ b/hls4ml/backends/vivado/passes/reshaping_templates.py @@ -1,5 +1,3 @@ -import numpy as np - from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import Resize, Transpose, ZeroPadding1D, ZeroPadding2D @@ -128,22 +126,13 @@ def format(self, node): class TransposeConfigTemplate(LayerConfigTemplate): def __init__(self): super().__init__(Transpose) - self.template = transpose_config_template def format(self, node): shape = tuple(node.get_input_variable().shape) perm = tuple(node.get_attr('perm')) name = f'config{node.index}' - new_shape, perm_strides = node.model.config.backend.permute_config_gen(name, shape, perm) - return transpose_config_template.format( - dims=len(shape), - N=np.prod(shape), - from_shape=', '.join(str(x) for x in shape), - perm=', '.join(str(x) for x in perm), - perm_strides=', '.join(str(x) for x in perm_strides), - to_shape=', '.join(str(x) for x in new_shape), - config_name=name, - ) + conf = node.model.config.backend.transpose_config_gen(name, shape, perm) + return transpose_config_template.format(**conf) class TransposeFunctionTemplate(FunctionCallTemplate): diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 0a18d4503d..fa564d2b0c 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -17,6 +17,8 @@ Dense, DepthwiseConv1D, DepthwiseConv2D, + Einsum, + EinsumDense, Embedding, GarNet, GarNetStack, @@ -685,3 +687,30 @@ def init_garnet(self, layer): @layer_optimizer(GarNetStack) def init_garnet_stack(self, layer): self.init_garnet(layer) + + @layer_optimizer(EinsumDense) + def init_einsum_dense(self, layer: EinsumDense) -> None: + strategy: str | None = layer.model.config.get_strategy(layer) + if not strategy: + layer.set_attr('strategy', 'latency') + return + if strategy in ('latency', 'resource', 'distributed_arithmetic'): + layer.set_attr('strategy', strategy) + return + warn(f'Invalid strategy "{strategy}" for EinsumDense layer "{layer.name}". Using "latency" strategy instead.') + layer.set_attr('strategy', 'latency') + + @layer_optimizer(Einsum) + def init_einsum(self, layer: Einsum) -> None: + strategy: str | None = layer.model.config.get_strategy(layer) + if not strategy: + layer.set_attr('strategy', 'latency') + return + if strategy.lower() == 'resource': + layer.set_attr('strategy', 'resource') + return + if strategy.lower() in ('latency', 'distributed_arithmetic'): + layer.set_attr('strategy', 'latency') + return + warn(f'Invalid strategy "{strategy}" for Einsum layer "{layer.name}". Using "latency" strategy instead.') + layer.set_attr('strategy', 'latency') diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 693a76f666..47569b1ad9 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -9,6 +9,7 @@ from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401 from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401 from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler +from hls4ml.converters.keras_v3_to_hls import parse_keras_v3_model # noqa: F401 from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401 from hls4ml.converters.onnx_to_hls import parse_onnx_model # noqa: F401 from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler @@ -17,6 +18,8 @@ pytorch_to_hls, register_pytorch_layer_handler, ) + +# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401 from hls4ml.model import ModelGraph from hls4ml.utils.config import create_config from hls4ml.utils.dependency import requires diff --git a/hls4ml/converters/keras_to_hls.py b/hls4ml/converters/keras_to_hls.py index a7260625e9..ea3f96c236 100644 --- a/hls4ml/converters/keras_to_hls.py +++ b/hls4ml/converters/keras_to_hls.py @@ -4,6 +4,8 @@ from hls4ml.model import ModelGraph +from .keras_v3_to_hls import parse_keras_v3_model + MAXMULT = 4096 @@ -357,6 +359,13 @@ def parse_keras_model(model_arch, reader): def keras_to_hls(config): + if 'KerasModel' in config: + import keras + + if keras.__version__ >= '3.0': + layer_list, input_layers, output_layers, _ = parse_keras_v3_model(config['KerasModel']) + return ModelGraph(config, layer_list, input_layers, output_layers) + model_arch, reader = get_model_arch(config) layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader) print('Creating HLS model') diff --git a/hls4ml/converters/keras_v3/__init__.py b/hls4ml/converters/keras_v3/__init__.py new file mode 100644 index 0000000000..6dffcb71d5 --- /dev/null +++ b/hls4ml/converters/keras_v3/__init__.py @@ -0,0 +1,6 @@ +from . import conv # noqa: F401 +from . import core # noqa: F401 +from . import einsum_dense # noqa: F401 +from ._base import registry as layer_handlers + +__all__ = ['layer_handlers'] diff --git a/hls4ml/converters/keras_v3/_base.py b/hls4ml/converters/keras_v3/_base.py new file mode 100644 index 0000000000..6f50ed6523 --- /dev/null +++ b/hls4ml/converters/keras_v3/_base.py @@ -0,0 +1,216 @@ +import typing +from types import FunctionType +from typing import Any, Callable, Sequence, TypedDict, overload + + +class DefaultConfig(TypedDict, total=False): + name: str + class_name: str + module: str + input_keras_tensor_names: list[str] + input_shape: list[list[int]] + output_keras_tensor_names: list[str] + epsilon: float + use_bias: bool + data_format: str + + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + +T_kv3_handler = Callable[ + ['keras.Layer', Sequence['keras.KerasTensor'], Sequence['keras.KerasTensor']], tuple[dict[str, Any], ...] +] + +registry: dict[str, T_kv3_handler] = {} + + +@overload +def register(cls: type) -> type: ... + + +@overload +def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ... + + +def register(cls: str | type): + """Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class. + + Parameters + ---------- + cls : str|type + If str, the key to register the handler under. If type, the class to register the handler for. + + Examples + -------- + ```python + @keras_dispatcher.register + class MyLayerHandler(KerasV3LayerHandler): + handles = ('my_package.src.submodule.MyLayer', 'MyLayer2') + + def handle(self, layer, inp_tensors, out_tensors): + # handler code + + + @keras_dispatcher.register('MyLayer3') + def my_layer_handler(layer, inp_tensors, out_tensors): + # handler code + ``` + """ + + def deco(func): + if isinstance(cls, str): + registry[cls] = func + for k in getattr(func, 'handles', ()): + registry[k] = func + if isinstance(cls, type): + return cls + return func + + if isinstance(cls, type): + return deco(cls()) + return deco + + +def maybe_add_attrs(config: dict[str, Any] | DefaultConfig, obj: Any, *attrs: str): + for attr in attrs: + if attr not in config and hasattr(obj, attr): + config[attr] = getattr(obj, attr) + + +class KerasV3LayerHandler: + """Base class for keras v3 layer handlers. Subclass this class to create a handler for a specific layer type.""" + + handles = () + default_config: DefaultConfig + + def __call__( + self, + layer: 'keras.Layer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ) -> tuple[dict[str, Any], ...]: + """Handle a keras layer. Return a tuple of dictionaries, each + dictionary representing a layer (module) in the HLS model. One + layer may correspond one or more dictionaries (e.g., layers with + activation functions will be split into two layers). + + Some common attributes are automatically added to the dictionary + if the handler returns a single dictionary. If the handler + returns multiple dictionaries, the attributes must be added + manually. Anything returned by the handler will override the + automatic attributes. + + Automatic attributes: - name - class_name - module - + input_keras_tensor_names - input_shape - + output_keras_tensor_names + + If the layer has an activation function, an additional + dictionary will be added to the return value representing the + activation function. + + + Parameters + ---------- + layer : keras.Layer + The layer to be converted to HLS configuration(s). + in_tensors : Sequence[KerasTensor] + The list of input tensors to the layer. + out_tensors : Sequence[KerasTensor] + The list of output tensors from the layer. + + Returns + ------- + dict[str, Any] | tuple[dict[str, Any], ...] + layer configuration(s) for the HLS model to be consumed by + the ModelGraph constructor + """ + + name = layer.name + class_name = layer.__class__.__name__ + module = layer.__module__ + + default_config: DefaultConfig = { + 'name': name, + 'class_name': class_name, + 'module': module, + 'input_keras_tensor_names': [t.name for t in in_tensors], + 'input_shape': [list(t.shape[1:]) for t in in_tensors], # type: ignore + 'output_keras_tensor_names': [t.name for t in out_tensors], + } + + maybe_add_attrs(default_config, layer, 'epsilon', 'use_bias', 'data_format') + + mandatory_keys = ['name', 'class_name', 'output_keras_tensor_names', 'input_keras_tensor_names'] + + self.default_config = default_config + config0 = self.handle(layer, in_tensors, out_tensors) + del self.default_config + + if isinstance(config0, tuple): + for conf in config0: + for key in mandatory_keys: + assert key in conf, f"Key {key} missing from layer {name} handled by {self.__class__.__name__}" + return config0 + + config = {} + config.update(default_config) + config.update(config0) + ret = (config,) + + # If activation exists, append it + + act_config, intermediate_tensor_name = self.maybe_get_activation_config(layer, out_tensors) + if act_config is not None: + ret[0]['output_keras_tensor_names'] = [intermediate_tensor_name] + ret = *ret, act_config + + return ret + + def maybe_get_activation_config(self, layer, out_tensors): + import keras + + activation = getattr(layer, 'activation', None) + name = layer.name + if activation not in (keras.activations.linear, None): + assert len(out_tensors) == 1, f"Layer {name} has more than one output, but has an activation function" + assert isinstance(activation, FunctionType), f"Activation function for layer {name} is not a function" + intermediate_tensor_name = f'{out_tensors[0].name}_activation' + act_cls_name = activation.__name__ + act_config = { + 'class_name': 'Activation', + 'activation': act_cls_name, + 'name': f'{name}_{act_cls_name}', + 'input_keras_tensor_names': [intermediate_tensor_name], + 'output_keras_tensor_names': [out_tensors[0].name], + } + return act_config, intermediate_tensor_name + return None, None + + def handle( + self, + layer: 'keras.Layer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ) -> dict[str, Any] | tuple[dict[str, Any], ...]: + return {} + + def load_weight(self, layer: 'keras.Layer', key: str): + """Load a weight from a layer. + + Parameters + ---------- + layer : keras.Layer + The layer to load the weight from. + key : str + The key of the weight to load. + + Returns + ------- + np.ndarray + The weight. + """ + import keras + + return keras.ops.convert_to_numpy(getattr(layer, key)) diff --git a/hls4ml/converters/keras_v3/conv.py b/hls4ml/converters/keras_v3/conv.py new file mode 100644 index 0000000000..adf6221822 --- /dev/null +++ b/hls4ml/converters/keras_v3/conv.py @@ -0,0 +1,119 @@ +import typing +from math import ceil +from typing import Sequence + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + + +@register +class KV3ConvHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.convolutional.conv1d.Conv1D', + 'keras.src.layers.convolutional.conv2d.Conv2D', + 'keras.src.layers.convolutional.depthwise_conv1d.DepthwiseConv1D', + 'keras.src.layers.convolutional.depthwise_conv2d.DepthwiseConv2D', + 'keras.src.layers.convolutional.separable_conv1d.SeparableConv1D', + 'keras.src.layers.convolutional.separable_conv2d.SeparableConv2D', + ) + + def handle( + self, + layer: 'keras.layers.Conv1D|keras.layers.Conv2D|keras.layers.DepthwiseConv1D|keras.layers.DepthwiseConv2D', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + from keras.src.layers.convolutional.base_conv import BaseConv + from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv + from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv + + assert len(in_tensors) == 1, f"Layer {layer.name} has more than one input" + assert len(out_tensors) == 1, f"Layer {layer.name} has more than one output" + + in_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore + assert all(isinstance(x, int) for x in in_shape), f"Layer {layer.name} has non-fixed size input: {in_shape}" + assert all(isinstance(x, int) for x in out_shape), f"Layer {layer.name} has non-fixed size output: {out_shape}" + + kernel = self.load_weight(layer, 'kernel') + if layer.use_bias: + bias = self.load_weight(layer, 'bias') + else: + bias = None + + ker_px_shape: tuple[int, ...] = layer.kernel_size + data_format = layer.data_format + + if data_format == 'channels_last': + *px_in_shape, ch_in = in_shape + *px_out_shape, ch_out = out_shape + else: + ch_in, *px_in_shape = in_shape + ch_out, *px_out_shape = out_shape + + if layer.padding == 'same': + n_padding = [ceil(N / n) * n - N for N, n in zip(px_in_shape, ker_px_shape)] + n_padding0 = [p // 2 for p in n_padding] + n_padding1 = [p - p0 for p, p0 in zip(n_padding, n_padding0)] + elif layer.padding == 'valid': + n_padding0 = [0] * len(px_in_shape) + n_padding1 = [0] * len(px_in_shape) + elif layer.padding == 'causal': + n_padding0 = [ker_px_shape[0] - 1] + [0] * (len(px_in_shape) - 1) + n_padding1 = [0] * len(px_in_shape) + else: + raise ValueError(f"Invalid padding mode {layer.padding} for layer {layer.name}") + + config = { + 'bias_data': bias, + 'data_format': data_format, + 'weight_data': kernel, + 'n_filt': ch_out, + 'n_chan': ch_in, + } + + if layer.rank == 1: + config.update( + { + 'filt_width': ker_px_shape[0], + 'stride_width': layer.strides[0], + 'pad_left': n_padding0[0], + 'pad_right': n_padding1[0], + 'in_width': px_in_shape[0], + 'out_width': px_out_shape[0], + } + ) + elif layer.rank == 2: + config.update( + { + 'filt_height': ker_px_shape[0], + 'filt_width': ker_px_shape[1], + 'stride_height': layer.strides[0], + 'stride_width': layer.strides[1], + 'pad_top': n_padding0[0], + 'pad_bottom': n_padding1[0], + 'pad_left': n_padding0[1], + 'pad_right': n_padding1[1], + 'in_height': px_in_shape[0], + 'in_width': px_in_shape[1], + 'out_height': px_out_shape[0], + 'out_width': px_out_shape[1], + } + ) + else: + _cls = f"{layer.__class__.__module__}.{layer.__class__.__name__}" + raise ValueError(f"Only 1D and 2D conv layers are supported, got {_cls} (rank={layer.rank})") + if isinstance(layer, BaseDepthwiseConv): + config['depthwise_data'] = kernel + config['depth_multiplier'] = layer.depth_multiplier + elif isinstance(layer, BaseSeparableConv): + config['depthwise_data'] = kernel + config['pointwise_data'] = self.load_weight(layer, 'pointwise_kernel') + config['depth_multiplier'] = layer.depth_multiplier + elif isinstance(layer, BaseConv): + config['weight_data'] = kernel + + return config diff --git a/hls4ml/converters/keras_v3/core.py b/hls4ml/converters/keras_v3/core.py new file mode 100644 index 0000000000..f3ac9a0d75 --- /dev/null +++ b/hls4ml/converters/keras_v3/core.py @@ -0,0 +1,222 @@ +import inspect +import typing +from math import prod +from typing import Any, Sequence + +import numpy as np + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + from keras.src.layers.merging.base_merge import Merge + + +@register +class KV3DenseHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.dense.Dense',) + + def handle( + self, + layer: 'keras.layers.Dense', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + + kernel = self.load_weight(layer, 'kernel') + bias = self.load_weight(layer, 'bias') if layer.use_bias else None + n_in, n_out = kernel.shape + + config = { + 'data_format': 'channels_last', + 'weight_data': kernel, + 'bias_data': bias, + 'n_out': n_out, + 'n_in': n_in, + } + return config + + +@register +class KV3InputHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.input_layer.InputLayer',) + + def handle( + self, + layer: 'keras.layers.InputLayer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {'input_shape': list(layer._batch_shape[1:])} + return config + + +@register +class KV3MergeHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.merging.add.Add', + 'keras.src.layers.merging.multiply.Multiply', + 'keras.src.layers.merging.average.Average', + 'keras.src.layers.merging.maximum.Maximum', + 'keras.src.layers.merging.minimum.Minimum', + 'keras.src.layers.merging.concatenate.Concatenate', + 'keras.src.layers.merging.subtract.Subtract', + 'keras.src.layers.merging.dot.Dot', + ) + + def handle( + self, + layer: 'Merge', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + cls_name: str | None = None, + ): + assert len(out_tensors) == 1, f"Merge layer {layer.name} has more than one output" + output_shape = list(out_tensors[0].shape[1:]) + + cls_name = cls_name or layer.__class__.__name__ + config: dict[str, Any] = { + 'output_shape': output_shape, + 'op': cls_name.lower(), + } + + match cls_name.lower(): + case 'Concatenate': + rank = len(output_shape) + class_name = f'Concatenate{rank}d' + config['axis'] = layer.axis + case 'Dot': + class_name = f'Dot{len(output_shape)}d' + rank = len(output_shape) + assert rank == 1, f"Dot product only supported for 1D tensors, got {rank}D on layer {layer.name}" + case _: + class_name = 'Merge' + + config['class_name'] = class_name + return config + + +@register +class KV3ActivationHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.activation.Activation',) + + def handle( + self, + layer: 'keras.layers.Activation', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + import keras + + config = {} + config.update(self.default_config) + + activation = getattr(layer, 'activation', keras.activations.linear) + match activation: + case keras.activations.softmax: + class_name = 'Softmax' + config['axis'] = -1 + case keras.activations.hard_sigmoid: + class_name = 'HardActivation' + case keras.activations.leaky_relu: + class_name = 'LeakyReLU' + signature = inspect.signature(keras.activations.leaky_relu) + config['activ_param'] = signature.parameters['negative_slope'].default + case keras.activations.elu: + class_name = 'ELU' + signature = inspect.signature(keras.activations.elu) + config['activ_param'] = signature.parameters['alpha'].default + case _: + class_name = 'Activation' + + config['activation'] = activation.__name__ + config['class_name'] = class_name + return (config,) + + +@register +class KV3ReLUHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.activations.leaky_relu.LeakyReLU', + 'keras.src.layers.activations.prelu.PReLU', + 'keras.src.layers.activations.relu.ReLU', + ) + + def handle( + self, + layer: 'keras.layers.ReLU', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + if layer.__class__.__name__ == 'ReLU': + config['class_name'] = 'Activation' + config['activation'] = 'relu' + return config + + if layer.__class__.__name__ == 'PReLU': + config['class_name'] = 'PReLU' + config['param_data'] = np.array(layer.alpha) + config['activation'] = 'prelu' + else: + config['class_name'] = 'LeakyReLU' + config['activ_param'] = float(layer.negative_slope) + config['activation'] = 'leaky_relu' + + return (config,) + + +@register +class KV3SoftmaxHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.softmax.Softmax',) + + def handle( + self, + layer: 'keras.layers.Softmax', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + ax = layer.axis + ax = ax if ax >= 0 else len(in_tensors[0].shape) + ax + # io_stream asserts axis=-1, convert to -1 when it is + n_outer: int = prod(in_tensors[0].shape[1:ax]) # type: ignore + n_inner: int = prod(in_tensors[0].shape[ax + 1 :]) # type: ignore + ax = -1 if ax == len(in_tensors[0].shape) - 1 else ax + config = {} + config.update(self.default_config) + if len(in_tensors) == 2: + raise NotImplementedError("Masked softmax not supported yet") + config['class_name'] = 'MaskedSoftmax' + elif len(in_tensors) == 1: + config['class_name'] = 'Softmax' + else: + raise ValueError(f"Too many inputs for softmax layer {layer.name}: expected 1 or 2, got {len(in_tensors)}") + config['axis'] = layer.axis + config['activation'] = 'softmax' + config['n_outer'] = (n_outer,) + config['n_inner'] = n_inner + + return (config,) + + +@register +class KV3HardActivationHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.elu.ELU',) + + def handle( + self, + layer: 'keras.layers.ELU', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + config['class_name'] = 'ELU' + config['activ_param'] = float(layer.alpha) + config['activation'] = 'elu' + + return (config,) diff --git a/hls4ml/converters/keras_v3/einsum_dense.py b/hls4ml/converters/keras_v3/einsum_dense.py new file mode 100644 index 0000000000..8eb000fcf7 --- /dev/null +++ b/hls4ml/converters/keras_v3/einsum_dense.py @@ -0,0 +1,75 @@ +import typing +from typing import Sequence + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + + +def strip_batch_dim(equation: str, einsum_dense: bool = True): + """Remove the batch dimension from the equation. + + Args: + equation (str): The einsum equation. + einsum_dense (bool): Whether the equation is for EinsumDense layer. + + Returns: + str: The einsum equation without the batch dimension. + """ + + _inps, out = equation.split('->') + inp0, inp1 = _inps.split(',') + if einsum_dense: + if inp0.startswith('...'): + assert out.startswith('...'), f'Error in eq: {equation}: Batch dim mismatch for the input and output.' + else: + assert inp0[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the input and output.' + assert inp0[0] not in inp1, f'Error in eq: {equation}: Batch dim is used in the kernel.' + inp0, out = inp0[1:], out[1:] + else: + assert inp0[0] == inp1[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the inputs and output.' + inp0, inp1, out = inp0[1:], inp1[1:], out[1:] + return f'{inp0},{inp1}->{out}' + + +@register +class KV3EinsumDenseHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.einsum_dense.EinsumDense',) + + def handle( + self, + layer: 'keras.layers.EinsumDense', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + assert len(in_tensors) == 1, 'EinsumDense layer must have exactly one input tensor' + assert len(out_tensors) == 1, 'EinsumDense layer must have exactly one output tensor' + + inp_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore + + # fmt: off + assert all(d is not None for d in inp_shape), \ + f'Error when processing {layer.name}: EinsumDense layer requires fully inp shapes' + assert all(d is not None for d in out_shape), \ + f'Error when processing {layer.name}: EinsumDense layer requires fully out shapes' + # fmt: on + + equation = strip_batch_dim(layer.equation, True) + + kernel = self.load_weight(layer, 'kernel') + + bias = None + if layer.bias_axes: + bias = self.load_weight(layer, 'bias') + + return { + 'class_name': 'EinsumDense', + 'equation': equation, + 'weight_data': kernel, + 'bias_data': bias, + 'inp_shape': inp_shape, + 'out_shape': out_shape, + } diff --git a/hls4ml/converters/keras_v3_to_hls.py b/hls4ml/converters/keras_v3_to_hls.py new file mode 100644 index 0000000000..5c0168cc1e --- /dev/null +++ b/hls4ml/converters/keras_v3_to_hls.py @@ -0,0 +1,284 @@ +import typing +from itertools import chain +from types import FunctionType +from typing import Any, Callable, Sequence + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + +import numpy as np + +from .keras_v3 import layer_handlers as v3_layer_handlers + +T_kv3_handler = Callable[ + ['keras.Layer', Sequence['keras.KerasTensor'], Sequence['keras.KerasTensor']], tuple[dict[str, Any], ...] +] + + +def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None): + """Given a keras layer, return a list of tuples of input and output + tensors. If the layer is called only once (i.e., no shared layers), + the list will contain only one tuple. + + The layer must have been built before calling this function. + + Parameters + ---------- + layer : keras.Layer + The layer to get input and output tensors from. + node_whitelist : set[int]|None, optional + If not None, only return tensors from nodes with ids in this + set, used to filter out nodes that are not part of the model, by + default None + + + Returns + ------- + list[tuple[tuple['KerasTensor', ...], tuple['KerasTensor', ...]]] + A list of tuples of input and output tensors. + """ + in_nodes = layer._inbound_nodes + if node_whitelist is not None: + in_nodes = [node for node in in_nodes if id(node) in node_whitelist] + + ret: list[tuple[tuple['KerasTensor', ...], tuple['KerasTensor', ...]]] = [] + for node in in_nodes: + in_tensors = tuple(node.arguments.keras_tensors) + out_tensors = tuple(node.outputs) + ret.append((in_tensors, out_tensors)) + return ret + + +def resolve_dependency_relation(model: 'keras.Model'): + """Given a keras model, return the following information: + - A list of input tensor names + - A list of output tensor names + - A list of (layer_name, input_tensor_names, output_tensor_names) tuples + - A dictionary of tensor_name -> KerasTensor + + Parameters + ---------- + model : keras.Model + The keras model to analyze. + + Returns + ------- + tuple[tuple[str, ...], tuple[str, ...], list[tuple[str, tuple[str, ...], tuple[str, ...]]], dict[str, KerasTensor]] + inp_tensor_names, out_tensor_names, layer_io, tensors + """ + tensors: dict[str, 'KerasTensor'] = {} + "tensor_name -> KerasTensor" + depends_on: dict[str, tuple[str, ...]] = {} + "tensor_name -> {tensor_name}" + layer_io: list[tuple[str, tuple[str, ...], tuple[str, ...]]] = [] + "layer_name -> ((input_tensor_names), (output_tensor_names))" + + inputs = tuple(t.name for t in model.inputs) + outputs = tuple(t.name for t in model.outputs) + node_whitelist = {id(node) for v in model._nodes_by_depth.values() for node in v} + + for layer in model.layers: + for in_tensors, out_tensors in get_io_tensors(layer, node_whitelist): + in_tensor_names = tuple(t.name for t in in_tensors) + out_tensor_names = tuple(t.name for t in out_tensors) + for t in chain(in_tensors, out_tensors): + tensors[t.name] = t + for o_name in out_tensor_names: + depends_on[o_name] = in_tensor_names + layer_io.append((layer.name, in_tensor_names, out_tensor_names)) + + return inputs, outputs, layer_io, tensors + + +class UniqueName: + """Helper class to generate unique names for layers, if one being used multiple times.""" + + def __init__(self): + self.used_names: set[str] = set() + + def next_name(self, name: str): + i = 0 + if name in self.used_names: + while f'{name}_{i}' in self.used_names: + i += 1 + name = f'{name}_{i}' + self.used_names.add(name) + return name + + def __call__(self, name: str): + return self.next_name(name) + + def reset(self): + self.used_names.clear() + + +class KerasV3HandlerDispatcher: + """Dispatcher class to handle different types of keras v3 layers.""" + + def __init__(self, layer_handlers: dict[str, T_kv3_handler], v2_layer_handlers=None): + self.registry = layer_handlers + self.v2_layer_handlers = v2_layer_handlers or {} + + def __call__( + self, layer: 'keras.Layer', in_tensors: Sequence['keras.KerasTensor'], out_tensors: Sequence['keras.KerasTensor'] + ) -> tuple[dict[str, Any], ...]: + assert layer.built, f"Layer {layer.name} is not built" + + ret = self.v3_call(layer, in_tensors, out_tensors) + if ret is not None: + return ret + ret = self.v2_call(layer, in_tensors, out_tensors) + if ret is not None: + return ret + + raise ValueError( + f"Layer {layer.__class__.__module__}.{layer.__class__.__name__} not found in either v3 or v2 handlers" + ) + + def v3_call( + self, layer: 'keras.layers.Layer', inp_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'] + ): + cls_name = layer.__class__.__name__ + module = layer.__module__ + key = f"{module}.{cls_name}" + + # keras v3 handlers + handler = self.registry.get(key, None) + handler = handler or self.registry.get(cls_name, None) + + if handler is None: + return None + return handler(layer, inp_tensors, out_tensors) + + def v2_call( + self, layer: 'keras.layers.Layer', inp_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'] + ): + # keras v2 handlers fallback + print(f"v2 handler used for layer {layer.name}") + + import keras + + config = layer.get_config() + layer_dict = {'config': config, 'class_name': layer.__class__.__name__} + + class DummyReader: + def get_weights_data(self, layer_name, var_name): + assert layer_name == layer.name, f"Processing {layer.name}, but handler tried to read {layer_name}" + for w in layer.weights: + if var_name in w.name: + return np.array(w) + return None + + reader = DummyReader() + input_shapes = [list(t.shape) for t in inp_tensors] + input_names = [t.name for t in inp_tensors] + output_names = [t.name for t in out_tensors] + key = layer.__class__.__name__ + handler = self.v2_layer_handlers.get(key, None) + if handler is None: + return None + + ret, _ = handler(layer_dict, input_names, input_shapes, reader) + ret['output_keras_tensor_names'] = output_names + ret['input_keras_tensor_names'] = input_names + ret = (ret,) + + activation = getattr(layer, 'activation', None) + if activation not in (keras.activations.linear, None): + assert isinstance(activation, FunctionType), f"Activation function for layer {layer.name} is not a function" + intermediate_tensor_name = f'{output_names[0]}_activation' + ret[0]['output_keras_tensor_names'] = (intermediate_tensor_name,) + act_cls_name = activation.__name__ + act_config = { + 'class_name': 'Activation', + 'activation': act_cls_name, + 'name': f'{layer.name}_{act_cls_name}', + 'input_keras_tensor_names': (intermediate_tensor_name,), + 'output_keras_tensor_names': output_names, + } + ret = *ret, act_config + return ret + + +def parse_keras_v3_model(model: 'keras.Model'): + """Parse a keras model into a list of dictionaries, each + representing a layer in the HLS model, and a list of input and + output layer names. + + Parameters + ---------- + model : keras.Model + + Returns + ------- + tuple[list[dict[str, Any]], list[str], list[str], list[list[int]]] + layer_list, input_layer_names, output_layer_names, + batch_output_shapes + + Raises + ------ + ValueError + If a circular dependency is detected. + """ + + assert model.built, "Model must be built before parsing" + + import keras + + if isinstance(model, keras.Sequential): + model = model._functional # everything is functional under the hood lol + + from .keras_to_hls import layer_handlers as v2_layer_handlers # Delayed import to avoid circular import + + keras_v3_dispatcher = KerasV3HandlerDispatcher(v3_layer_handlers, v2_layer_handlers) + + model_inputs, model_outputs, dependency, tensors = resolve_dependency_relation(model) + + satisfied = set() + + unique_name = UniqueName() + + layer_list: list[dict[str, Any]] = [] + + while any(t not in satisfied for t in model_outputs): + # Until all tensors in the model are satisfied + for i, (layer_name, in_tensor_names, out_tensor_names) in enumerate(dependency): + if not all(t in satisfied for t in in_tensor_names): + continue # Skip layer if some inputs are not ready + if all(t in satisfied for t in out_tensor_names): + continue # Skip layer if the outputs are already satisfied + + layer: 'keras.Layer' = model.get_layer(layer_name) + inp_tensors = [tensors[t] for t in in_tensor_names] + out_tensors = [tensors[t] for t in out_tensor_names] + + _configs = keras_v3_dispatcher(layer, inp_tensors, out_tensors) + # Dispatch to v3 handler if available, else fallback to v2 handler + + # Prevent name conflicts. If a layer is used multiple times, add a suffix to the name. + # At this stage connections between modules are recorded by i/o keras tensor names + for _conf in _configs: + _conf['name'] = unique_name(_conf['name']) + + layer_list.extend(_configs) # Add the layer to the list + satisfied.update(out_tensor_names) # Mark the outputs as satisfied + dependency.pop(i) + break # Restart the loop to add another layer + else: + # If no layer was added in the loop, then there is a circular dependency + raise ValueError("Circular dependency detected") + + # Mark inputs[inp layer name] for ModelGraph to parse from i/o keras tensor names + provides: dict[str, str] = {} # tensor_name -> src_layer_name + for conf in layer_list: + for out_name in conf['output_keras_tensor_names']: + provides[out_name] = conf['name'] + inputs = [provides[tname] for tname in conf['input_keras_tensor_names']] + conf['inputs'] = inputs + + input_layer_names = [provides[tname] for tname in model_inputs] + output_layer_names = [provides[tname] for tname in model_outputs] + batch_output_shapes = [list(tensors[tname].shape) for tname in model_outputs] + + return layer_list, input_layer_names, output_layer_names, batch_output_shapes diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 57c42f401f..eb61337c6f 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -157,3 +157,35 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node layer['n_filt'] = input_shapes[0][1] # Always channel first for Pytorch return layer, [shape for shape in input_shapes[0]] + + +@pytorch_handler('einsum') +def parse_einsum_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert 'einsum' in operation + + layer = {} + + if len(input_names) != 2: + raise Exception('Only einsum operations with two inputs are supported') + layer['class_name'] = 'Einsum' + layer['equation'] = node.args[0] + layer['name'] = layer_name + layer['inputs'] = input_names + + # Need to set batch size to a real value instead of 'None'. Using '1' as dummy value + import copy + + input_shapes_tmp = copy.deepcopy(input_shapes) + input_shapes_tmp[0][0] = 1 + input_shapes_tmp[1][0] = 1 + layer['inp0_shape'] = tuple(input_shapes_tmp[0]) + layer['inp1_shape'] = tuple(input_shapes_tmp[1]) + + # Run einsum to infer output shape + import torch + + a = torch.randn(input_shapes_tmp[0]) + b = torch.randn(input_shapes_tmp[1]) + layer['out_shape'] = tuple(torch.einsum(layer['equation'], a, b).shape) + + return layer, [shape for shape in input_shapes[0]] diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 0efeaafa3d..041ef8ab8d 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -27,10 +27,12 @@ find_minimum_width, ) from hls4ml.utils import attribute_descriptions as descriptions +from hls4ml.utils.einsum_utils import parse_einsum from hls4ml.utils.string_utils import convert_to_snake_case - # TODO move this to some utility module + + class classproperty: def __init__(self, func): self.func = func @@ -1016,16 +1018,21 @@ def initialize(self): dims = inp.dim_names self.add_output_variable(shape, dims) - gamma = self.get_attr('gamma_data') - beta = self.get_attr('beta_data') - mean = self.get_attr('mean_data') - var = self.get_attr('variance_data') - - scale = gamma / np.sqrt(var + self.get_attr('epsilon')) - bias = beta - scale * mean + if self.get_attr('scale_data') is None: + gamma = self.get_attr('gamma_data') + var = self.get_attr('variance_data') + scale = gamma / np.sqrt(var + self.get_attr('epsilon')) + self.add_weights_variable(name='scale', var_name='s{index}', data=scale) + else: + self.add_weights_variable(name='scale', var_name='s{index}') - self.add_weights_variable(name='scale', var_name='s{index}', data=scale) - self.add_weights_variable(name='bias', var_name='b{index}', data=bias) + if self.get_attr('bias_data') is None: + beta = self.get_attr('beta_data') + mean = self.get_attr('mean_data') + bias = beta - scale * mean + self.add_weights_variable(name='bias', var_name='b{index}', data=bias) + else: + self.add_weights_variable(name='bias', var_name='b{index}') # TODO: discuss whether this should be renamed to soemthing more descriptive, and whether the class hierarchy makes sense @@ -1643,6 +1650,131 @@ def initialize(self): self.add_output_variable([len(self.get_attr('expression'))], [f'N_OUTPUTS_{self.index}'], var_name='y') +class EinsumDense(Layer): + _expected_attributes = [ + WeightAttribute('weight'), + WeightAttribute('bias'), + TypeAttribute('weight'), + TypeAttribute('bias'), + TypeAttribute('accum'), + Attribute('equation', value_type=str), + Attribute('inp_shape', value_type=tuple), + Attribute('out_shape', value_type=tuple), + ] + + def initialize(self): + out_shape = self.attributes['out_shape'] + if len(out_shape) > 1: + dims = [f'N_LAYER_{self.index}_D{i}' for i in range(1, len(out_shape) + 1)] + else: + dims = [f'N_LAYER_{self.index}'] + self.add_output_variable(list(out_shape), dims) + + kernel: np.ndarray = self.attributes.attributes['weight_data'] + bias: np.ndarray | None = self.attributes.attributes['bias_data'] + equation = self.attributes['equation'] + inp_shape = self.attributes['inp_shape'] + out_shape = self.attributes['out_shape'] + + kernel_shape = kernel.shape + recipe = parse_einsum(equation, inp_shape, kernel_shape) + assert not any(recipe['direct_sum_axis']), ( + 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' + 'Use explicit addition operator before instead.' + ) + inp_tpose_idxs, ker_tpose_idxs = recipe['in_transpose_idxs'] + out_tpose_idxs = recipe['out_transpose_idxs'] + + # Pre-transpose kernel (and bias) to save a transpose in cpp. Shouldn't matter for latency strategy though. + # hls4ml dense acts like i,ij->j + # parser assumes ij,j->i, so we need to transpose the kernel to match + kernel = kernel.transpose(ker_tpose_idxs) + kernel = kernel.reshape(recipe['I'], recipe['L1'], recipe['C']).transpose(0, 2, 1) + + def to_original_kernel(tkernel: np.ndarray) -> np.ndarray: + _kernel = tkernel.transpose(0, 2, 1) + _kernel = _kernel.reshape(tuple(kernel_shape[i] for i in ker_tpose_idxs)) + return _kernel.transpose(np.argsort(ker_tpose_idxs)) + + # TODO: for weight in bram mode (resource), broadcasting bias here shall be avoided. + if bias is not None: + bias = np.broadcast_to(bias, out_shape).transpose(np.argsort(out_tpose_idxs)) + else: + # The automatically created bias is just the last dimension of the output shape + # Which is too small in general for einsum dense. + # The transpose is just to match the shape in case of have real bias, no real effect. + bias = np.zeros(out_shape).transpose(np.argsort(out_tpose_idxs)) + + self.attributes.attributes['weight_data'] = kernel + self.attributes.attributes['to_original_kernel'] = to_original_kernel + self.attributes.attributes['bias_data'] = bias + self.attributes['inp_tpose_idxs'] = inp_tpose_idxs + self.attributes['out_tpose_idxs'] = out_tpose_idxs + self.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] + self.attributes['n_free_data'] = recipe['L0'] + self.attributes['n_free_kernel'] = recipe['L1'] + self.attributes['n_inplace'] = recipe['I'] + self.attributes['n_contract'] = recipe['C'] + pf = self.attributes.attributes.get('parallelization_factor', recipe['L0']) + self.attributes['parallelization_factor'] = pf + + self.add_weights(compression=self.model.config.get_compression(self)) + self.add_bias() + + +class Matmul(Layer): + _expected_attributes = [ + TypeAttribute('accum'), + Attribute('inup1_shape', value_type=tuple), + Attribute('inp2_shape', value_type=tuple), + ] + + +class Einsum(Layer): + _expected_attributes = [ + TypeAttribute('accum'), + Attribute('equation', value_type=str), + Attribute('inp0_shape', value_type=tuple), + Attribute('inp1_shape', value_type=tuple), + Attribute('out_shape', value_type=tuple), + ] + + def initialize(self): + out_shape = self.attributes['out_shape'] + if len(out_shape) > 1: + dims = [f'N_LAYER_{self.index}_D{i}' for i in range(1, len(out_shape) + 1)] + else: + dims = [f'N_LAYER_{self.index}'] + self.add_output_variable(list(out_shape), dims) + + equation = self.attributes['equation'] + inp0_shape = self.attributes['inp0_shape'] + inp1_shape = self.attributes['inp1_shape'] + out_shape = self.attributes['out_shape'] + + recipe = parse_einsum(equation, inp0_shape, inp1_shape) + assert not any(recipe['direct_sum_axis']), ( + 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' + 'Use explicit addition operator before instead.' + ) + inp0_tpose_idxs, inp1_tpose_idxs = recipe['in_transpose_idxs'] + out_tpose_idxs = recipe['out_transpose_idxs'] + + self.attributes.attributes.update(recipe) + self.attributes['n_free0'] = recipe['L0'] + self.attributes['n_free1'] = recipe['L1'] + self.attributes['n_inplace'] = recipe['I'] + self.attributes['n_contract'] = recipe['C'] + self.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] + + self.attributes['inp0_tpose_idxs'] = inp0_tpose_idxs + self.attributes['inp1_tpose_idxs'] = inp1_tpose_idxs + self.attributes['out_tpose_idxs'] = out_tpose_idxs + + pf = self.attributes.attributes.get('parallelization_factor', recipe['L0']) + self.attributes['parallelization_factor'] = pf + + layer_map = { 'Input': Input, 'InputLayer': Input, @@ -1710,6 +1842,8 @@ def initialize(self): 'BatchNormOnnx': BatchNormOnnx, 'LayerGroup': LayerGroup, 'SymbolicExpression': SymbolicExpression, + 'EinsumDense': EinsumDense, + 'Einsum': Einsum, # TensorFlow-specific layers: 'BiasAdd': BiasAdd, } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_common.h b/hls4ml/templates/vivado/nnet_utils/nnet_common.h index 6db3f62f6e..308892ba49 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_common.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_common.h @@ -24,7 +24,7 @@ namespace nnet { // Common type definitions enum io_type { io_parallel = 0, io_stream }; -enum strategy { latency, resource, resource_unrolled }; +enum strategy { latency, resource, resource_unrolled, distributed_arithmetic }; /* --- * Balanced tree reduce implementation. diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h new file mode 100644 index 0000000000..cc2917783c --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h @@ -0,0 +1,83 @@ +#ifndef NNET_EINSUM_H_ +#define NNET_EINSUM_H_ + +#include "nnet_common.h" +#include "nnet_mult.h" +#include "nnet_transpose.h" + +namespace nnet { + +struct config_einsum { + typedef void tpose_inp0_conf; + typedef void tpose_inp1_conf; + typedef void tpose_out_conf; + + // Layer Sizes + static const unsigned n_free0; + static const unsigned n_free1; + static const unsigned n_contract; + static const unsigned n_inplace; + + // Resource reuse info + static const unsigned io_type; + static const unsigned strategy; + static const unsigned reuse_factor; + static const unsigned multiplier_limit; + static const bool store_weights_in_bram = false; // NOT USED + + template using product = nnet::product::mult; +}; + +template +void einsum(const data0_T data0[CONFIG_T::tpose_inp0_conf::N], const data1_T data1[CONFIG_T::tpose_inp1_conf::N], + res_T res[CONFIG_T::tpose_out_conf::N]) { + + #pragma HLS PIPELINE II = CONFIG_T::reuse_factor + #pragma HLS ALLOCATION operation instances = mul limit = CONFIG_T::multiplier_limit + + data0_T tpose_i0[CONFIG_T::tpose_inp0_conf::N]; + data1_T tpose_i1[CONFIG_T::tpose_inp1_conf::N]; + res_T tpose_o[CONFIG_T::tpose_out_conf::N]; + + #pragma HLS ARRAY_PARTITION variable = tpose_i0 complete + #pragma HLS ARRAY_PARTITION variable = tpose_i1 complete + #pragma HLS ARRAY_PARTITION variable = tpose_o complete + + nnet::transpose(data0, tpose_i0); + nnet::transpose(data1, tpose_i1); + + // for l0 in range(L0): + // for i in range(I): + // output[(i*L0+l0)*L1:(i*L0+l0+1)*L1] = input1[i*L1*C:(i+1)*L1*C].reshape((L1,C)) @ + // input0[(i*L0+l0)*C:(i*L0+l0+1)*C] + + constexpr unsigned L0 = CONFIG_T::n_free0; + constexpr unsigned L1 = CONFIG_T::n_free1; + constexpr unsigned C = CONFIG_T::n_contract; + constexpr unsigned I = CONFIG_T::n_inplace; + + typename CONFIG_T::accum_t accum_buf; + for (unsigned i = 0; i < I; i++) { + #pragma HLS UNROLL + for (unsigned l0 = 0; l0 < L0; l0++) { + #pragma HLS UNROLL + for (unsigned l1 = 0; l1 < L1; l1++) { + #pragma HLS UNROLL + accum_buf = 0; + for (unsigned c = 0; c < C; c++) { + #pragma HLS UNROLL + data0_T a = tpose_i0[(i * L0 + l0) * C + c]; + data1_T b = tpose_i1[i * L1 * C + l1 * C + c]; + accum_buf += CONFIG_T::template product::product(a, b); + } + tpose_o[(i * L0 + l0) * L1 + l1] = accum_buf; + } + } + } + + nnet::transpose(tpose_o, res); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h new file mode 100644 index 0000000000..9f26ff0bd7 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h @@ -0,0 +1,114 @@ +#ifndef NNET_EINSUM_DENSE_H_ +#define NNET_EINSUM_DENSE_H_ + +#include "hls_stream.h" +#include "nnet_common.h" +#include "nnet_dense_latency.h" +#include "nnet_dense_resource.h" +#include "nnet_function_stubs.h" +#include "nnet_helpers.h" +#include "nnet_mult.h" +#include "nnet_transpose.h" + +namespace nnet { + +struct einsum_dense_config { + // Internal data type definitions + + typedef void tpose_inp_conf; + typedef void tpose_out_conf; + typedef void dense_conf; + + // Layer Sizes + static const unsigned n_free_data = 1; + static const unsigned n_free_kernel = 1; + static const unsigned n_contract = 1; + static const unsigned n_inplace = 1; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned strategy = latency; + static const unsigned reuse_factor = 1; + static const unsigned parallelization_factor = 1000; // Only useful when n_inplace > 1 + static const bool store_weights_in_bram = false; // NOT USED + + // Product function to use + template using product = nnet::product::mult; +}; + +template +void einsum_dense( + data_T data[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace], + res_T res[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace], + typename CONFIG_T::dense_conf::weight_t weights[CONFIG_T::n_free_kernel * CONFIG_T::n_contract * CONFIG_T::n_inplace], + typename CONFIG_T::dense_conf::bias_t biases[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]) { + data_T inp_tpose[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace]; + res_T out_tpose[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]; + res_T out_buffer[CONFIG_T::n_free_kernel]; + #pragma HLS ARRAY_PARTITION variable = inp_tpose complete + #pragma HLS ARRAY_PARTITION variable = out_tpose complete + + nnet::transpose(data, inp_tpose); + + constexpr unsigned L0 = CONFIG_T::n_free_data; + constexpr unsigned L1 = CONFIG_T::n_free_kernel; + constexpr unsigned C = CONFIG_T::n_contract; + constexpr unsigned I = CONFIG_T::n_inplace; + + for (unsigned l0 = 0; l0 < L0; l0++) { + #pragma HLS UNROLL factor = CONFIG_T::parallelization_factor + for (unsigned i = 0; i < I; i++) { + #pragma HLS UNROLL + // even w/o explicit distributed arithmetic optimization, latency kernels are partially implemented as such + // so reusing the same multiplier for different weights doesn't really help... only full unrolling for now + dense(&inp_tpose[(i * L0 + l0) * C], out_buffer, + &weights[(i * L1 * C)], &biases[((i * L0 + l0) * L1)]); + for (unsigned j = 0; j < L1; j++) { + #pragma HLS UNROLL + out_tpose[(i * L0 + l0) * L1 + j] = out_buffer[j]; + } + } + } + + nnet::transpose(out_tpose, res); +} + +template +typename std::enable_if::type +einsum_dense(data_T data[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace], + res_T res[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace], + typename CONFIG_T::bias_t biases[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]) { + + data_T inp_tpose[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace]; + res_T out_tpose[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]; + res_T out_buffer[CONFIG_T::n_free_kernel]; + + #pragma HLS ARRAY_PARTITION variable = inp_tpose complete + #pragma HLS ARRAY_PARTITION variable = out_tpose complete + + nnet::transpose(data, inp_tpose); + + constexpr unsigned L0 = CONFIG_T::n_free_data; + constexpr unsigned L1 = CONFIG_T::n_free_kernel; + constexpr unsigned C = CONFIG_T::n_contract; + constexpr unsigned I = CONFIG_T::n_inplace; + + for (unsigned l0 = 0; l0 < L0; l0++) { + #pragma HLS UNROLL factor = CONFIG_T::parallelization_factor + // for (unsigned i = 0; i < I; i++) { + // #pragma HLS UNROLL + // inp_tpose[(i * L0 + l0) * C]->out_tpose[(i * L0 + l0) * L1]; + // } + CONFIG_T::da_kernel(inp_tpose, out_tpose, l0); + } + for (unsigned ii = 0; ii < (L0 * L1 * I); ii++) { + #pragma HLS UNROLL + out_tpose[ii] = out_tpose[ii] + biases[ii]; + } + + nnet::transpose(out_tpose, res); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index b14d1ce99d..8e2d9b7011 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -156,13 +156,21 @@ def config_from_keras_model( layer_list = [] if isinstance(model, dict): + # keras v2 only model_arch = model + reader = hls4ml.converters.KerasModelReader(model) + layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader) else: - model_arch = json.loads(model.to_json()) + import keras - reader = hls4ml.converters.KerasModelReader(model) + # model is keras.Model here - layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader) + if keras.__version__ > '3.0': + layer_list, *_ = hls4ml.converters.parse_keras_v3_model(model) + else: + model_arch = json.loads(model.to_json()) + reader = hls4ml.converters.KerasModelReader(model) + layer_list, *_ = hls4ml.converters.parse_keras_model(model_arch, reader) def make_layer_config(layer): cls_name = layer['class_name'] diff --git a/hls4ml/utils/einsum_utils.py b/hls4ml/utils/einsum_utils.py new file mode 100644 index 0000000000..43ceb2ba96 --- /dev/null +++ b/hls4ml/utils/einsum_utils.py @@ -0,0 +1,256 @@ +from math import prod +from typing import TypedDict + +import numpy as np + + +class EinsumRecipe(TypedDict): + direct_sum_axis: tuple[tuple[int, ...], tuple[int, ...]] + in_transpose_idxs: tuple[tuple[int, ...], tuple[int, ...]] + L0: int + L1: int + I: int + C: int + out_interpert_shape: tuple[int, ...] + out_transpose_idxs: tuple[int, ...] + + +def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]): + """Validate, resolve broadcasting, and compute output shape for einsum string + + Parameters + ---------- + fn : str + einsum string, e.g. 'ij,jk->ik' + shape0 : tuple[int,...] + shape of input0 + shape1 : tuple[int,...] + shape of input1 + + Returns + ------- + tuple[str, tuple[int,...]] + einsum string w/o broadcasting, and output shape + + Raises + ------ + ValueError + If the einsum string is invalid, or if it is incompatible with the input shapes + """ + inp, out = map(str.strip, fn.split('->')) + in0, in1 = map(str.strip, inp.split(',')) + alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' + s_alphabets = set(alphabets) + + # Invalid characters + if not (s_alphabets >= set(in0.replace('...', '') + in1.replace('...', '') + out.replace('...', ''))): + raise ValueError(f"einsum string {fn} is invalid: subscripts should be in [a-zA-Z] and '...' only") + + in0 = in0.replace('...', '0') + in1 = in1.replace('...', '0') + out = out.replace('...', '0') + ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out) + sax_in0, sax_in1, sax_out = set(ax_in0), set(ax_in1), set(ax_out) + free_indices = ''.join(sorted(s_alphabets - sax_in0 - sax_in1 - sax_out)) + + # Repeated indices + if len(sax_in0) != len(ax_in0): + for a in in0: + if in0.count(a) == 1: + continue + a = a if a != '0' else '...' + raise ValueError(f"einsum string {fn} is invalid: input0 subscripts includes '{a}' multiple times") + if len(sax_in1) != len(ax_in1): + for a in in1: + if in1.count(a) == 1: + continue + a = a if a != '0' else '...' + raise ValueError(f"einsum string {fn} is invalid: input1 subscripts includes '{a}' multiple times") + if len(sax_out) != len(ax_out): + for a in out: + if out.count(a) == 1: + continue + a = a if a != '0' else '...' + raise ValueError(f"einsum string {fn} is invalid: output subscripts includes '{a}' multiple times") + + # Invalid broadcasting + if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out: + if '0' in sax_in0 and '0' in sax_in1: + raise ValueError(f"einsum string {fn} is invalid: both input0 and input1 allows broadcasting") + if '0' not in sax_out: + raise ValueError(f"einsum string {fn} is invalid: output does not allow broadcasting, but inputs do") + if '0' not in sax_in0 and '0' not in sax_in1: + raise ValueError(f"einsum string {fn} is invalid: output allows broadcasting, but inputs do not") + + # Output index out of nowhere + if remaining := sax_out - sax_in0 - sax_in1: + raise ValueError(f"einsum string {fn} is invalid: output subscripts {remaining} not found in inputs") + + _common_in = sax_in0 & sax_in1 + + # Invalid input dimensions + if '0' in sax_in0: + if len(sax_in0) - 1 > len(shape0): + raise ValueError(f"Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given") + # Replace broadcasting indices with free indices + n_broadcast = len(shape0) - len(sax_in0) + 1 + in0 = in0.replace('0', free_indices[:n_broadcast]) + out = out.replace('0', free_indices[:n_broadcast]) + ax_in0 = list(in0) + ax_out = list(out) + else: + if len(sax_in0) != len(shape0): + raise ValueError(f"Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given") + if '0' in sax_in1: + if len(sax_in1) - 1 > len(shape1): + raise ValueError(f"Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given") + # Replace broadcasting indices with free indices + n_broadcast = len(shape1) - len(sax_in1) + 1 + in1 = in1.replace('0', free_indices[:n_broadcast]) + out = out.replace('0', free_indices[:n_broadcast]) + ax_in1 = list(in1) + ax_out = list(out) + else: + if len(sax_in1) != len(shape1): + raise ValueError(f"Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given") + + # Input dimension mismatch + for a in _common_in: + ax_0 = ax_in0.index(a) + ax_1 = ax_in1.index(a) + if shape0[ax_0] != shape1[ax_1]: + raise ValueError( + f"Input dimension size mismatches for common subscript '{a}': {shape0[ax_0]} and {shape1[ax_1]}" + ) + + out_shape = tuple(shape0[ax_in0.index(a)] if a in ax_in0 else shape1[ax_in1.index(a)] for a in ax_out) + return f'{in0},{in1}->{out}', out_shape + + +def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe: + """Parse einsum operation on two input arrays, return a recipe for execution + + Parameters + ---------- + fn : str + einsum string, e.g. 'ij,jk->ik' + input : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + EinsumRecipe + einsum recipe; executed by _exec_einsum + """ + + fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1) + + _in, _out = fn.split('->') + _in0, _in1 = _in.split(',') + + in0, in1, out = list(_in0), list(_in1), list(_out) + s_in0, s_in1, s_out = set(in0), set(in1), set(out) + _common = s_in0 & s_in1 + _contract = _common - s_out + _inplace = _common & s_out + contract = sorted(_contract, key=lambda x: in1.index(x)) + inplace = sorted(_inplace, key=lambda x: in1.index(x)) + invariant0 = sorted((s_out - _common) & s_in0, key=lambda x: in0.index(x)) + invariant1 = sorted((s_out - _common) & s_in1, key=lambda x: in1.index(x)) + direct_sum0 = s_in0 - s_out - _common + direct_sum1 = s_in1 - s_out - _common + direct_sum_axis = ( + tuple(sorted(in0.index(x) for x in direct_sum0)), + tuple(sorted(in1.index(x) for x in direct_sum1)), + ) + + contract_idxs = tuple(map(in0.index, contract)), tuple(map(in1.index, contract)) + inplace_idxs = tuple(map(in0.index, inplace)), tuple(map(in1.index, inplace)) + invariant_idxs = tuple(map(in0.index, invariant0)), tuple(map(in1.index, invariant1)) + + inplace_shape = tuple(input_shape0[i] for i in inplace_idxs[0]) + inplace_size = prod(inplace_shape) + contract_size = prod(input_shape0[i] for i in contract_idxs[0]) + invariant_shape0 = tuple(input_shape0[i] for i in invariant_idxs[0]) + invariant_shape1 = tuple(input_shape1[i] for i in invariant_idxs[1]) + invariant_size0, invariant_size1 = prod(invariant_shape0), prod(invariant_shape1) + + transpose_idx0 = inplace_idxs[0] + invariant_idxs[0] + contract_idxs[0] + transpose_idx1 = inplace_idxs[1] + invariant_idxs[1] + contract_idxs[1] + + out_shape_pretranspose = inplace_shape + invariant_shape0 + invariant_shape1 + _out_transpose_idx = np.argsort(tuple(map(out.index, inplace + invariant0 + invariant1))) + out_transpose_idx = tuple(int(i) for i in _out_transpose_idx) + + return EinsumRecipe( + direct_sum_axis=direct_sum_axis, + in_transpose_idxs=(transpose_idx0, transpose_idx1), + out_interpert_shape=out_shape_pretranspose, + out_transpose_idxs=out_transpose_idx, + L0=invariant_size0, + L1=invariant_size1, + I=inplace_size, + C=contract_size, + ) + + +def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: + """Execute einsum operation on two input arrays + + Parameters + ---------- + recipe : EinsumRecipe + einsum recipe + input0 : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + np.ndarray + output array + """ + sum_axis0, sum_axis1 = recipe['direct_sum_axis'] + if sum_axis0: + input0 = np.sum(input0, axis=sum_axis0) + if sum_axis1: + input1 = np.sum(input1, axis=sum_axis1) + input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel() + input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel() + output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=input0.dtype) + + L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C'] + + for l0 in range(L0): + for i in range(I): + A = input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C)) + B = input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C] + output[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = A @ B + + return output.reshape(recipe['out_interpert_shape']).transpose(recipe['out_transpose_idxs']) + + +def einsum(fn: str, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: + """Execute einsum operation on two input arrays. + + WARNING: Order of multiplication is reversed -- watchout if you are using non-commutative operators + + Parameters + ---------- + fn : str + einsum string, e.g. 'ij,jk->ik' + input : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + np.ndarray + output array + """ + recipe = parse_einsum(fn, input0.shape, input1.shape) + return _exec_einsum(recipe, input0, input1) diff --git a/test/pytest/test_einsum_dense.py b/test/pytest/test_einsum_dense.py new file mode 100644 index 0000000000..dbddf545ff --- /dev/null +++ b/test/pytest/test_einsum_dense.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import keras +import numpy as np +import pytest + +from hls4ml.converters import convert_from_keras_model + +if keras.__version__ < '3.0.0': + pytest.skip('Only keras v3 is supported for now', allow_module_level=True) + +from keras.api.layers import EinsumDense, Input + +test_root_path = Path(__file__).parent + + +@pytest.mark.parametrize('strategy', ['latency', 'distributed_arithmetic']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize( + 'operation', + [ + # eq, inp, out + ('bi,j->bij', (8,), (8, 7), None), + ('bi,j->bij', (8,), (8, 7), 'i'), + ('bi,j->bij', (8,), (8, 7), 'j'), + ('bi,io->bo', (8,), 7, None), + ('...i,oi->...o', (4, 3), (5,), None), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), None), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'aeb'), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'ab'), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'a'), + ], +) +def test_einsum_dense(backend, io_type, strategy, operation): + eq, inp_shape, out_shape, bias_axes = operation + model = keras.Sequential( + [Input(inp_shape), EinsumDense(eq, output_shape=out_shape, bias_axes=bias_axes, name='einsum_dense')] + ) + + if bias_axes is not None: + layer = model.get_layer('einsum_dense') + layer.bias.assign(keras.ops.convert_to_tensor(np.random.rand(*layer.bias.shape))) + + data = np.random.rand(1000, *inp_shape) + eq_name = eq.replace(',', '_').replace('->', '_') + ('' if bias_axes is None else f'_{bias_axes}') + output_dir = str(test_root_path / f'hls4mlprj_einsum_dense_{eq_name}_{backend}_{io_type}_{strategy}') + hls_config = {'Model': {'Precision': 'ap_fixed<32,8>', 'ReuseFactor': 1}, 'Strategy': strategy} + model_hls = convert_from_keras_model( + model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type + ) + + model_hls.compile() + r_keras = model.predict(data, verbose=0, batch_size=1000) # type: ignore + r_hls = model_hls.predict(data).reshape(r_keras.shape) # type: ignore + + np.testing.assert_allclose(r_hls, r_keras, atol=2e-6, rtol=0) diff --git a/test/pytest/test_keras_v3_api.py b/test/pytest/test_keras_v3_api.py new file mode 100644 index 0000000000..81ac5c240c --- /dev/null +++ b/test/pytest/test_keras_v3_api.py @@ -0,0 +1,516 @@ +import math +from pathlib import Path + +import keras +import numpy as np +import pytest + +if keras.__version__ < '3.0': + pytest.skip('Keras API tests are only for Keras 3.0 and above', allow_module_level=True) + +from keras.api.layers import ( + ELU, + Activation, + AveragePooling1D, + AveragePooling2D, + Conv1D, + Conv2D, + Dense, + DepthwiseConv1D, + DepthwiseConv2D, + LeakyReLU, + MaxPooling1D, + MaxPooling2D, + PReLU, +) + +import hls4ml + +test_root_path = Path('/tmp/tests') + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_dense(backend, io_type): + model = keras.Sequential( + [ + Dense( + 2, + input_shape=(1,), + name='Dense', + use_bias=True, + kernel_initializer=keras.initializers.RandomUniform(minval=1, maxval=10), # type: ignore + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + ), + Activation(activation='elu', name='Activation'), + ] + ) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(1000, 1) + + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_dense_{backend}_{io_type}') + + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + + hls_model.compile() + + hls_prediction = hls_model.predict(X_input) + + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.02) + + assert len(model.layers) + 1 == len(hls_model.get_layers()) + assert list(hls_model.get_layers())[0].attributes['class_name'] == "InputLayer" + assert list(hls_model.get_layers())[1].attributes["class_name"] == model.layers[0].name + assert list(hls_model.get_layers())[2].attributes['class_name'] == 'ELU' + + +# TODO: add ThresholdedReLU test when it can be made to pass +# https://github.com/fastmachinelearning/hls4ml/issues/376 + + +@pytest.mark.parametrize( + "activation_function", + [ + Activation(activation='relu', name='relu'), + LeakyReLU(negative_slope=0.5), + ELU(alpha=1.0), + PReLU( + alpha_initializer="zeros", + ), + Activation(activation='sigmoid', name='sigmoid'), + ], +) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_activations(activation_function, backend, io_type): + model = keras.models.Sequential() + model.add(Dense(64, input_shape=(1,), name='Dense', kernel_initializer='lecun_uniform', kernel_regularizer=None)) + model.add(activation_function) + + model.compile(optimizer='adam', loss='mse') + + model.summary() + + X_input = np.random.rand(1000, 1) + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_activations_{activation_function.name}_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input) + + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.02) + + for layer in hls_model.get_layers(): + print(layer.attributes.attributes['class_name']) + assert len(model.layers) + 1 == len(hls_model.get_layers()) + + assert list(hls_model.get_layers())[2].attributes['class_name'] == activation_function.__class__.__name__ + + +padds_options = ['same', 'valid'] + + +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_conv1d(padds, backend, io_type): + model = keras.models.Sequential() + input_shape = (10, 128, 4) + model.add( + Conv1D( + filters=32, + kernel_size=3, + strides=2, + padding=padds, + activation='relu', + input_shape=input_shape[1:], + kernel_initializer='normal', + use_bias=False, + data_format='channels_last', + name='conv', + ) + ) + model.add(Activation(activation='relu')) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(10, 128, 4) + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_conv1d_{padds}_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + + # 5e-2 might be too high + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=5e-2) + + if backend in ('Vivado', 'Vitis', 'Catapult') and io_type == 'io_stream' and padds == 'same': + # Vivado/Vitis inserts and additional layer for 'same' padding in io_stream + return + + conv: keras.layers.Conv1D = model.layers[0] + ker_w, ch_in, ch_out = conv.kernel.shape + inp_shape = model.inputs[0].shape[1:] + out_shape = model.outputs[0].shape[1:] + hls_attr = hls_model.graph['conv'].attributes + _stride = conv.strides[0] + + assert len(model.layers) + 2 == len(hls_model.get_layers()) + + assert hls_attr['name'] == model.layers[0].name + assert hls_attr['class_name'] == 'Conv1D' + assert hls_attr["in_width"] == inp_shape[0] + assert hls_attr['filt_width'] == ker_w + assert hls_attr['n_chan'] == ch_in + assert hls_attr['n_filt'] == ch_out + assert hls_attr['stride_width'] == _stride + assert hls_attr['data_format'] == conv.data_format + assert hls_attr["out_width"] == out_shape[0] + + w_pad = math.ceil(inp_shape[0] / ker_w) * ker_w - inp_shape[0] + + pad_left = w_pad // 2 + pad_right = w_pad - pad_left + + if model.layers[0].padding == 'same': + assert hls_attr['pad_left'] == pad_left + assert hls_attr['pad_right'] == pad_right + elif model.layers[0].padding == 'valid': + assert hls_attr['pad_left'] == 0 + assert hls_attr['pad_right'] == 0 + + +chans_options = ['channels_last'] +padds_options = ['same', 'valid'] + + +@pytest.mark.parametrize('chans', chans_options) +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_conv2d(chans, padds, backend, io_type): + input_shape = (32, 32, 3) + model = keras.Sequential( + [ + keras.layers.InputLayer(input_shape), + Conv2D( + filters=32, + kernel_size=(2, 3), + strides=(4, 5), + padding=padds, + kernel_initializer='normal', + use_bias=False, + data_format=chans, + name='conv', + ), + ] + ) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(1000, *input_shape) + keras_prediction = model.predict(X_input) + + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4ml_project_keras_api_conv2d_{backend}_{chans}_{padds}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + + # A high tolerance, simply to verify correct functionality + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=5e-2) + + hls_conv_attr = hls_model.graph['conv'].attributes + + conv: keras.layers.Conv2D = model.get_layer('conv') + + kh, kw, ch_in, ch_out = conv.kernel.shape # type: ignore + _stride = conv.strides + inp_shape = model.inputs[0].shape[1:] + out_shape = model.outputs[0].shape[1:] + + if io_type == 'io_stream' and padds == 'same' and backend in ('Vivado', 'Vitis', 'Catapult'): + return + + assert len(model.layers) + 1 == len(hls_model.get_layers()) + assert hls_conv_attr['name'] == conv.name + assert hls_conv_attr['class_name'] == 'Conv2D' + assert hls_conv_attr['filt_width'] == kw + assert hls_conv_attr['filt_height'] == kh + assert hls_conv_attr['n_filt'] == ch_out + assert hls_conv_attr['stride_width'] == _stride[1] + assert hls_conv_attr['stride_height'] == _stride[0] + assert hls_conv_attr['data_format'] == conv.data_format + + if conv.data_format == 'channels_first': + assert hls_conv_attr['n_chan'] == inp_shape[0] + assert hls_conv_attr['in_height'] == inp_shape[1] + assert hls_conv_attr['in_width'] == inp_shape[2] + assert hls_conv_attr['out_height'] == out_shape[1] + assert hls_conv_attr['out_width'] == out_shape[2] + elif model.layers[0].data_format == 'channels_last': + assert hls_conv_attr['n_chan'] == inp_shape[2] + assert hls_conv_attr['in_height'] == inp_shape[0] + assert hls_conv_attr['in_width'] == inp_shape[1] + assert hls_conv_attr['out_height'] == out_shape[0] + assert hls_conv_attr['out_width'] == out_shape[1] + + if conv.padding == 'same': + if conv.data_format == 'channels_first': + h_pad = math.ceil(inp_shape[1] / kh) * kh - inp_shape[1] + w_pad = math.ceil(inp_shape[2] / kw) * kw - inp_shape[2] + elif model.layers[0].data_format == 'channels_last': + h_pad = math.ceil(inp_shape[0] / kh) * kh - inp_shape[0] + w_pad = math.ceil(inp_shape[1] / kw) * kw - inp_shape[1] + else: + raise ValueError('Invalid data_format') + pad_top = h_pad // 2 + pad_bottom = h_pad - pad_top + pad_left = w_pad // 2 + pad_right = w_pad - pad_left + assert hls_conv_attr['pad_top'] == pad_top + assert hls_conv_attr['pad_bottom'] == pad_bottom + assert hls_conv_attr['pad_left'] == pad_left + assert hls_conv_attr['pad_right'] == pad_right + elif model.layers[0].padding == 'valid': + assert hls_conv_attr['pad_top'] == 0 + assert hls_conv_attr['pad_bottom'] == 0 + assert hls_conv_attr['pad_left'] == 0 + assert hls_conv_attr['pad_right'] == 0 + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel']) +def test_depthwise2d(backend, io_type): + ''' + Test proper handling of DepthwiseConv2D + ''' + X = np.random.rand(10, 32, 32, 3) + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = keras.models.Sequential([keras.layers.Input((32, 32, 3)), DepthwiseConv2D(kernel_size=(3, 3))]) + model.compile() + + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision='fixed<32,12>', backend=backend + ) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv2d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) # type: ignore + + +# Currently only Vivado and Vitis is supported for io_stream. +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_stream']) +def test_depthwise1d(backend, io_type): + ''' + Test proper handling of DepthwiseConv1D. + ''' + X = np.random.rand(10, 32, 3) + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = keras.Sequential([DepthwiseConv1D(kernel_size=3, input_shape=(32, 3))]) + model.compile() + + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv1d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) # type: ignore + + +pooling_layers = [MaxPooling1D, MaxPooling2D, AveragePooling1D, AveragePooling2D] + + +@pytest.mark.parametrize('pooling', pooling_layers) +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('chans', chans_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +def test_pooling(pooling, padds, chans, backend): + assert '1D' in pooling.__name__ or '2D' in pooling.__name__ + + input_shape = (18, 15, 3) if '2D' in pooling.__name__ else (121, 3) + pool_size = (4, 2) if '2D' in pooling.__name__ else 2 + + X_input = np.random.rand(100, *input_shape) + + keras_model = keras.Sequential([pooling(pool_size, padding=padds, input_shape=input_shape)]) + keras_model.compile() + + hls_cfg = hls4ml.utils.config_from_keras_model(keras_model) + output_dir = str( + test_root_path / f'hls4mlprj_keras_api_pooling_{pooling.__name__}_channels_{chans}_padds_{padds}_backend_{backend}' + ) + hls_model = hls4ml.converters.convert_from_keras_model( + keras_model, hls_config=hls_cfg, output_dir=output_dir, backend=backend + ) + hls_model.compile() + + # Verify accuracy + keras_prediction = keras_model.predict(X_input) + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=3e-2) + + # # Verify correct parsing of layer + # hls_pool = list(hls_model.get_layers())[-1] + # ker_pool = keras_model.layers[-1] + # if '2D' in pooling.__name__: + # assert hls_pool.attributes['name'] == ker_pool._name + # assert hls_pool.attributes['class_name'][-2] == str(2) + # assert hls_pool.attributes['stride_height'] == ker_pool.strides[0] + # assert hls_pool.attributes['stride_width'] == ker_pool.strides[1] + # assert hls_pool.attributes['pool_height'] == ker_pool.pool_size[1] + # assert hls_pool.attributes['pool_width'] == ker_pool.pool_size[0] + + # if hls_pool.attributes['data_format'] == 'channels_last': + # assert hls_pool.attributes['in_height'] == ker_pool.input_shape[1] + # assert hls_pool.attributes['in_width'] == ker_pool.input_shape[2] + # assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[3] + # elif hls_pool.attributes['data_format'] == 'channels_first': + # assert hls_pool.attributes['in_height'] == ker_pool.input_shape[2] + # assert hls_pool.attributes['in_width'] == ker_pool.input_shape[3] + # assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[1] + + # if ker_pool.padding == 'same': + # # Height + # in_height = ker_pool.input_shape[1] + # if ker_pool.data_format == 'channels_first': + # in_height = ker_pool.input_shape[2] + # out_height = int(math.ceil(float(in_height) / float(ker_pool.strides[0]))) + # assert out_height == hls_pool.attributes['out_height'] + # if in_height % ker_pool.strides[0] == 0: + # pad_along_height = max(ker_pool.pool_size[1] - ker_pool.strides[0], 0) + # else: + # pad_along_height = max(ker_pool.pool_size[1] - (in_height % ker_pool.strides[0]), 0) + # pad_top = pad_along_height // 2 + # pad_bottom = pad_along_height - pad_top + # assert pad_bottom == hls_pool.attributes['pad_bottom'] + # assert pad_top == hls_pool.attributes['pad_top'] + + # # Width + # in_width = ker_pool.input_shape[2] + # if ker_pool.data_format == 'channels_first': + # in_height = keras_model.layers[1].input_shape[-1] + # out_width = int(math.ceil(float(in_width) / float(ker_pool.strides[1]))) + # assert out_width == hls_pool.attributes['out_width'] + # if in_width % ker_pool.strides[1] == 0: + # pad_along_width = max(ker_pool.pool_size[0] - ker_pool.strides[1], 0) + # else: + # pad_along_width = max(ker_pool.pool_size[0] - (in_width % ker_pool.strides[1]), 0) + # pad_left = pad_along_width // 2 + # pad_right = pad_along_width - pad_left + # assert pad_left == hls_pool.attributes['pad_left'] + # assert pad_right == hls_pool.attributes['pad_right'] + + # elif ker_pool.padding == 'valid': + # if hls_pool.attributes['data_format'] == 'channels_first': + # in_height = ker_pool.input_shape[2] + # in_width = ker_pool.input_shape[3] + # elif hls_pool.attributes['data_format'] == 'channels_last': + # in_height = ker_pool.input_shape[1] + # in_width = ker_pool.input_shape[2] + # else: + # raise ValueError('Invalid data_format') + + # out_width = int(math.ceil(float(in_width - ker_pool.pool_size[0] + 1) / float(ker_pool.strides[1]))) + # out_height = int(math.ceil(float(in_height - ker_pool.pool_size[1] + 1) / float(ker_pool.strides[0]))) + + # assert hls_pool.attributes['out_height'] == out_height + # assert hls_pool.attributes['out_width'] == out_width + # assert hls_pool.attributes['pad_top'] == 0 + # assert hls_pool.attributes['pad_bottom'] == 0 + # assert hls_pool.attributes['pad_left'] == 0 + # assert hls_pool.attributes['pad_right'] == 0 + + # elif '1D' in pooling.__name__: + # assert hls_pool.attributes['name'] == ker_pool._name + # assert hls_pool.attributes['class_name'][-2] == str(1) + # assert hls_pool.attributes['n_in'] == ker_pool.input_shape[1] + # assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[2] + # assert hls_pool.attributes['pool_width'] == ker_pool.pool_size[0] + # assert hls_pool.attributes['stride_width'] == ker_pool.strides[0] + + # out_same = math.ceil(float(ker_pool.input_shape[1]) / float(ker_pool.strides[0])) + # out_valid = math.ceil(float(ker_pool.input_shape[1] - ker_pool.pool_size[0] + 1) / ker_pool.strides[0]) + + # if ker_pool.padding == 'same': + # assert hls_pool.attributes['n_out'] == out_same + # if ker_pool.input_shape[1] % ker_pool.strides[0] == 0: + # pad_along_width = max(ker_pool.pool_size[0] - ker_pool.strides[0], 0) + # else: + # pad_along_width = max(ker_pool.pool_size[0] - (ker_pool.input_shape[1] % ker_pool.strides[0]), 0) + # assert hls_pool.attributes['pad_left'] == pad_along_width // 2 + # assert hls_pool.attributes['pad_right'] == pad_along_width - pad_along_width // 2 + + # elif ker_pool.padding == 'valid': + # assert hls_pool.attributes['n_out'] == out_valid + # assert hls_pool.attributes['pad_left'] == 0 + # assert hls_pool.attributes['pad_right'] == 0 + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_reused_layer(backend, io_type): + + inp1 = keras.layers.Input(shape=(10, 10)) + inp2 = keras.layers.Input(shape=(10, 10)) + + conv = keras.layers.Conv1D(2, 3, activation='relu') + + o1 = conv(inp1) + o2 = conv(inp2) + o3 = keras.layers.Add()([o1, o2]) + o4 = keras.layers.Dense(5)(o3) + + _ = keras.layers.Dense(5)(o3) + + model = keras.models.Model(inputs=[inp1, inp2], outputs=[o1, o2, o3, o4]) + + _ = model([inp1, inp1]) + + hls_config = {'Model': {'Precision': 'ap_fixed<32,8>', 'ReuseFactor': 1}} + output_dir = str(test_root_path / f'hls4mlprj_keras_api_conv1d_{backend}_{io_type}') + + model_hls = hls4ml.converters.convert_from_keras_model( + model, backend=backend, io_type=io_type, hls_config=hls_config, output_dir=output_dir + ) + + model_hls.compile() + + data = [np.random.rand(1000, 10, 10).astype(np.float32), np.random.rand(1000, 10, 10).astype(np.float32)] + keras_pred = model.predict(data) + hls_pred = model_hls.predict(data) + + np.testing.assert_allclose(keras_pred[0].reshape(hls_pred[0].shape), hls_pred[0], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[1].reshape(hls_pred[1].shape), hls_pred[1], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[2].reshape(hls_pred[2].shape), hls_pred[2], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[3].reshape(hls_pred[3].shape), hls_pred[3], rtol=0, atol=1e-2) diff --git a/test/pytest/test_pytorch_api.py b/test/pytest/test_pytorch_api.py index d182d9ae16..e1cef48401 100644 --- a/test/pytest/test_pytorch_api.py +++ b/test/pytest/test_pytorch_api.py @@ -877,3 +877,79 @@ def forward(self, x): rtol = 0 atol = 5.0e-2 np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=rtol, atol=atol) + + +class EinsumOuterProduct(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.einsum('bi,bj->bij', x, y) + + +class EinsumBatchMatMul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.einsum('bij,bjk->bik', x, y) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +def test_einsum_outer_product(backend, io_type): + + model = EinsumOuterProduct() + model.eval() + + X_input = np.random.rand(3, 4) + Y_input = np.random.rand(3, 5) + + pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(Y_input)).detach().numpy() + + config = config_from_pytorch_model( + model, + [(None, 4), (None, 5)], + default_precision='ap_fixed<16,6>', + channels_last_conversion="internal", + transpose_outputs=False, + ) + output_dir = str(test_root_path / f'hls4mlprj_pytorch_einsum_outer_product_{backend}_{io_type}') + + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) + + hls_model.compile() + print(X_input, Y_input) + hls_prediction = np.reshape(hls_model.predict([X_input, Y_input]), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +def test_einsum_batch_matmul(backend, io_type): + + model = EinsumBatchMatMul() + model.eval() + + X_input = np.random.rand(3, 2, 5) + Y_input = np.random.rand(3, 5, 4) + + pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(Y_input)).detach().numpy() + + config = config_from_pytorch_model( + model, + [(None, 2, 5), (None, 5, 4)], + default_precision='ap_fixed<16,6>', + channels_last_conversion="internal", + transpose_outputs=False, + ) + output_dir = str(test_root_path / f'hls4mlprj_pytorch_einsum_batch_matmul_{backend}_{io_type}') + + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) + + hls_model.compile() + print(X_input, Y_input) + hls_prediction = np.reshape(hls_model.predict([X_input, Y_input]), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01)