PyTorch: Unable to Compile Softmax or Sigmoid Layers Correctly #1083

Open
3 tasks done
sei-rquartiano opened this issue Oct 16, 2024 · 3 comments

@sei-rquartiano
Contributor

Prerequisites

Please make sure to check off these prerequisites before submitting a bug report.

  • Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
  • Check that the issue hasn't already been reported, by checking the currently open issues.
  • If there are steps to reproduce the problem, make sure to write them down below.

Quick summary

Unable to compile a PyTorch softmax or sigmoid layer. I attempted both module and functional calls and got different error messages, described in the next section.

Details

For the nn.Softmax() call I get a KeyError: 'axis' that stops config generation. The nn.Sigmoid() and the functional softmax/sigmoid calls get through compilation, but an AssertionError shows that the HLS values are out of tolerance. More important than the tolerance itself, the HLS model appears to output the same incorrect values every time, regardless of the input.
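
For reference, the two call styles are numerically identical in plain PyTorch, so the divergence only shows up after hls4ml conversion. A minimal check (the full reproduction script is in the next section):

import torch
import torch.nn as nn

x = torch.rand(3, 2, 10)
module_out = nn.Softmax(dim=-1)(x)                  # module call: KeyError: 'axis' during conversion
functional_out = nn.functional.softmax(x, dim=-1)   # functional call: converts, but HLS outputs are wrong
assert torch.allclose(module_out, functional_out)   # identical outputs in PyTorch itself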

Steps to Reproduce

Below is the code I've used to investigate the errors with the PyTorch softmax/sigmoid calls. To switch between them, comment/uncomment the appropriate lines in the model definition toward the top of the file. I also added a (commented-out) linear layer as an example of something that would normally precede a softmax; it doesn't seem to affect the error.

  1. Clone hls4ml repository at current master branch
  2. Run this test code
  3. Change model definition to a different softmax/sigmoid call, repeat step 2
from pathlib import Path

import numpy as np
import pytest
import tensorflow as tf
from sklearn.metrics import accuracy_score
import torch.nn as nn
import torch
#from torchsummary import summary

import hls4ml
import os

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

test_root_path = Path(__file__).parent

class SoftmaxModel(nn.Module):
    def __init__(self):
        super().__init__()
        #self.linear = nn.Linear(10,4)
        self.softmax = nn.Softmax(dim=-1)
        #self.softmax = nn.Sigmoid()
    def forward(self, x):
        #x = self.linear(x)
        return self.softmax(x)
        #return nn.functional.softmax(x,dim=-1)
        #return nn.functional.sigmoid(x)

if __name__ == "__main__":

    n_in = 2
    size_in = 10
    n_batch = 3

    model = SoftmaxModel()
    model.eval()
    print(model)

    X_input = np.random.rand(n_batch, n_in, size_in)
    with torch.no_grad():
        pytorch_prediction = model(torch.Tensor(X_input))

    config = config_from_pytorch_model(model,
                                       (None, n_in, size_in),
                                       channels_last_conversion='internal',
                                       transpose_outputs=False)
    config['Model']['Strategy'] = 'Resource'
    config['Model']['Precision'] = 'ap_fixed<32,12>'
    print(config)

    backend='Vitis'
    output_dir = str(test_root_path / f'hls4mlprj_softmax_{backend}_io_stream')

    hls_model = convert_from_pytorch_model(
        model,
        output_dir=output_dir,
        backend=backend,
        hls_config=config,
        io_type='io_stream',
    )
    print(list(hls_model.get_layers()))
    hls_model.compile()

    # X_input_hls is channels last
    X_input_hls = np.ascontiguousarray(X_input.transpose(0, 2, 1))
    # write tb data
    ipf = output_dir + "/tb_data/tb_input_features.dat"
    if os.path.isfile(ipf):
        os.remove(ipf)
    np.savetxt(ipf, X_input_hls.flatten(), newline=" ")

    hls_prediction = hls_model.predict(X_input_hls)

    print("X_input")
    print(X_input)

    print("pytorch_prediction")
    print(pytorch_prediction)
    # write tb data
    opf = output_dir + "/tb_data/tb_output_predictions.dat"
    if os.path.isfile(opf):
        os.remove(opf)
    with open(opf, "ab") as f:
        for p in pytorch_prediction:
            np.savetxt(f, p.flatten(), newline=" ")


    out_height = pytorch_prediction.shape[-1]
    n_out = n_in

    hls_prediction = np.transpose(
            np.reshape(hls_model.predict(X_input_hls),
                       (n_batch, out_height, n_out)),
            (0, 2, 1)
    )

    print("hls_prediction")
    print(hls_prediction)

    rtol = 1.0e-5
    atol = 5.0e-2
    for p, h in zip(pytorch_prediction, hls_prediction):
        np.testing.assert_allclose(p,
                                   h,
                                   rtol=rtol, atol=atol)

Expected behavior

Successful model compilation

Actual behavior

Output and error message for softmax module call:

python softmax_test.py 
2024-10-16 20:26:40.713174: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-16 20:26:40.748742: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-16 20:26:40.749201: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-10-16 20:26:41.325078: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
SoftmaxModel(
  (softmax): Softmax(dim=-1)
)
{'Model': {'Precision': 'ap_fixed<32,12>', 'ReuseFactor': 1, 'ChannelsLastConversion': 'internal', 'TransposeOutputs': False, 'Strategy': 'Resource'}, 'PytorchModel': SoftmaxModel(
  (softmax): Softmax(dim=-1)
), 'InputShape': (None, 2, 10)}
Interpreting Model ...
Topology:
Layer name: softmax, layer type: Softmax, input shape: [[None, 2, 10]]
Creating HLS model
WARNING: Changing pipeline style to "dataflow".
Traceback (most recent call last):
  File "/home/hls4ml-user/work/softmax_test.py", line 54, in <module>
    hls_model = convert_from_pytorch_model(
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/converters/__init__.py", line 308, in convert_from_pytorch_model
    return pytorch_to_hls(config)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/converters/pytorch_to_hls.py", line 374, in pytorch_to_hls
    hls_model = ModelGraph(config, layer_list, inputs=input_layers)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 390, in __init__
    self.apply_flow(flow)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 452, in apply_flow
    self._apply_sub_flow(flow, applied_flows)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 461, in _apply_sub_flow
    self._apply_sub_flow(sub_flow, applied_flows)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 464, in _apply_sub_flow
    applied_passes = optimize_model(self, flow.optimizers)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/optimizer/optimizer.py", line 318, in optimize_model
    res = opt.transform(model, node)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/backends/template.py", line 19, in transform
    formatted_template = self.format(node)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/backends/vivado/passes/core_templates.py", line 166, in format
    return self.template.format(**params)
KeyError: 'axis'

Output and error message for the sigmoid module call or any functional call:

python softmax_test.py 
2024-10-16 20:29:05.267343: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-16 20:29:05.302741: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-16 20:29:05.303172: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-10-16 20:29:05.888829: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
SoftmaxModel()
{'Model': {'Precision': 'ap_fixed<32,12>', 'ReuseFactor': 1, 'ChannelsLastConversion': 'internal', 'TransposeOutputs': False, 'Strategy': 'Resource'}, 'PytorchModel': SoftmaxModel(), 'InputShape': (None, 2, 10)}
Interpreting Model ...
Topology:
Layer name: softmax, layer type: Softmax, input shape: [[None, 2, 10]]
Creating HLS model
WARNING: Changing pipeline style to "dataflow".
[<hls4ml.backends.fpga.fpga_backend.VitisInput object at 0x7f99d0067ee0>, <hls4ml.backends.fpga.fpga_backend.VitisSoftmax object at 0x7f99d0067e50>]
Writing HLS project
Done
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
X_input
[[[0.05363817 0.31153485 0.75046616 0.58199667 0.20053846 0.72763441
   0.06146303 0.84705805 0.98926526 0.01218201]
  [0.21202604 0.46483977 0.80683202 0.04741267 0.55840604 0.81320857
   0.22513971 0.72937851 0.26433831 0.67033987]]

 [[0.156527   0.20964269 0.61042486 0.49824273 0.97225156 0.0401948
   0.45337303 0.02984117 0.93368881 0.39622952]
  [0.61855931 0.91478427 0.14109976 0.792937   0.89176405 0.99400871
   0.80888843 0.72983924 0.41402195 0.83616711]]

 [[0.12493313 0.96566315 0.94087154 0.54928708 0.17520313 0.03779079
   0.86541598 0.07419386 0.65669758 0.18357249]
  [0.65256606 0.98119094 0.50356697 0.76214902 0.12594131 0.78394119
   0.24394984 0.63036021 0.16529908 0.28997437]]]
pytorch_prediction
tensor([[[0.0631, 0.0817, 0.1267, 0.1070, 0.0731, 0.1238, 0.0636, 0.1395,
          0.1609, 0.0605],
         [0.0740, 0.0953, 0.1341, 0.0628, 0.1046, 0.1350, 0.0750, 0.1241,
          0.0780, 0.1170]],

        [[0.0721, 0.0761, 0.1136, 0.1015, 0.1631, 0.0642, 0.0971, 0.0636,
          0.1569, 0.0917],
         [0.0884, 0.1189, 0.0548, 0.1053, 0.1162, 0.1287, 0.1069, 0.0988,
          0.0721, 0.1099]],

        [[0.0672, 0.1557, 0.1519, 0.1027, 0.0706, 0.0616, 0.1409, 0.0639,
          0.1143, 0.0712],
         [0.1105, 0.1535, 0.0952, 0.1233, 0.0653, 0.1260, 0.0734, 0.1081,
          0.0679, 0.0769]]])
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
hls_prediction
[[[ 1.85546875e-02  1.90000000e+01 -1.02400000e+03  0.00000000e+00
    0.00000000e+00  1.85546875e-02  1.90000000e+01 -1.02400000e+03
    0.00000000e+00  0.00000000e+00]
  [ 1.00000000e+00  1.02400000e+03  0.00000000e+00  0.00000000e+00
    0.00000000e+00  1.00000000e+00  1.02400000e+03  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]

 [[ 1.85546875e-02  1.90000000e+01 -1.02400000e+03  0.00000000e+00
    0.00000000e+00  1.85546875e-02  1.90000000e+01 -1.02400000e+03
    0.00000000e+00  0.00000000e+00]
  [ 1.00000000e+00  1.02400000e+03  0.00000000e+00  0.00000000e+00
    0.00000000e+00  1.00000000e+00  1.02400000e+03  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]

 [[ 1.85546875e-02  1.90000000e+01 -1.02400000e+03  0.00000000e+00
    0.00000000e+00  1.85546875e-02  1.90000000e+01 -1.02400000e+03
    0.00000000e+00  0.00000000e+00]
  [ 1.00000000e+00  1.02400000e+03  0.00000000e+00  0.00000000e+00
    0.00000000e+00  1.00000000e+00  1.02400000e+03  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]]
Traceback (most recent call last):
  File "/home/hls4ml-user/softmax_test.py", line 103, in <module>
    np.testing.assert_allclose(p,
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/numpy/testing/_private/utils.py", line 1592, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/numpy/testing/_private/utils.py", line 862, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-05, atol=0.05

Mismatched elements: 19 / 20 (95%)
Max absolute difference: 1024.13953553
Max relative difference: 5.673692
 x: array([[0.063111, 0.081679, 0.126688, 0.107046, 0.073098, 0.123828,
        0.063607, 0.139536, 0.160859, 0.060548],
       [0.074005, 0.095292, 0.134147, 0.062772, 0.104638, 0.135005,
        0.074982, 0.124149, 0.077979, 0.117031]], dtype=float32)
 y: array([[ 1.855469e-02,  1.900000e+01, -1.024000e+03,  0.000000e+00,
         0.000000e+00,  1.855469e-02,  1.900000e+01, -1.024000e+03,
         0.000000e+00,  0.000000e+00],...

Note: the output values differ between nn.Sigmoid() and nn.functional.softmax/sigmoid, but for a given function they are identical on every run and are always outside the tolerance.

Optional

Possible Fix

This isn't a fix so much as something I noticed. When I originally encountered the module-call error, I wondered whether it had to do with the keyword 'axis', since Keras uses 'axis' where PyTorch uses 'dim'. However, when I looked at parse_activation_layer() in converters/pytorch/core.py, I noticed two things:

  1. The axis-vs-dim keyword discrepancy seems to be handled for both module and functional calls, but since I'm getting a KeyError, maybe it isn't working properly? (See the sketch after this list.)
  2. Softmax is missing from the ['layer_name'] if-statements inside the function (it is in the global activation_layers list outside it), whereas Sigmoid is there; adding Softmax to that list doesn't seem to change anything.
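
For illustration only, this is roughly the mapping I would expect the parser to perform. The node/class_object/layer names follow the hls4ml converter conventions, but the function below is a hypothetical sketch, not the actual hls4ml code:

def map_softmax_axis(node, class_object, layer):
    # Module call: nn.Softmax(dim=-1) stores the dimension on the module instance.
    if class_object is not None and hasattr(class_object, 'dim'):
        layer['axis'] = class_object.dim
    # Functional call: nn.functional.softmax(x, dim=-1) passes it as a keyword argument on the traced node.
    elif 'dim' in getattr(node, 'kwargs', {}):
        layer['axis'] = node.kwargs['dim']
    return layer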

Thanks for looking into this!

@returnwellbeing

Hi, I encountered this issue before. I added two lines after https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/converters/pytorch/core.py#L66-L67:

if hasattr(node, 'dim'):
    layer['axis'] = class_object.dim
if hasattr(class_object, 'dim'):  # temporary solution
    layer['axis'] = class_object.dim

When parsing PyTorch's softmax, there is no dim attribute on node; I found that dim is on class_object instead. I hope this helps you :)

@JanFSchulte
Contributor

Hi! Thanks for reporting these issues. The softmax issues are indeed a bug that we overlooked because softmax activation was missing from our tests. This is fixed in #1086.

For the issues with numerical correctness in sigmoid and softmax, this seems to be related to the io_type setting. If I run your setup with io_parallel, everything seems to work. The issues with io_stream are not reproduced in our tests, so I'm not quite sure yet what's going wrong there.
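
Concretely, that means changing only the io_type argument in the conversion call from the reproduction script above; a minimal sketch of the workaround:

hls_model = convert_from_pytorch_model(
    model,
    output_dir=output_dir,
    backend=backend,
    hls_config=config,
    io_type='io_parallel',  # 'io_stream' shows the out-of-tolerance outputs reported above
)
hls_model.compile()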

@sei-rquartiano
Contributor Author

Thank you both for the input! I will switch over to that PR branch until it's merged and change my io_type to 'io_parallel'. If I manage to figure out what's going on with 'io_stream' on my end, I'll be sure to let you know. Thanks again!
