
Add ONNX export for Pytorch Model #93

Open
H-G-11 opened this issue Sep 20, 2023 · 8 comments


H-G-11 commented Sep 20, 2023

Hello,

I have been using the Pytorch model on a Raspberry Pi. I run it on 2 seconds of audio to detect a wake word every 200 ms (see this issue for why I am not running it on the independent 200 ms chunks). Each run takes between 10 ms and 40 ms.

The performance is still good, but it could be improved with ONNX. I have therefore tried to export the model to ONNX (code below) and got several errors (also below).

This is not an issue with this repository: PyTorch simply does not support exporting either stft or istft to ONNX. See this issue that tracks it.

Nonetheless, on our end, we could perhaps use the ONNX STFT operator directly. For the istft, it seems they very recently started considering adding it (see this issue), but it is still not available.

What are your thoughts on this?

Note: once a model is exported to ONNX, its parameters cannot be changed as far as I know. So a good approach here would probably be to allow export via a to_onnx method on an instantiated TorchGate object. If we find a solution for the istft and stft issue, I'd be willing to make a PR for it :)

Note 2: from here and here, it seems we might just have to wait until torch.onnx supports opset 19, which should contain the missing operators... Not sure though.

Annex:

Code to export to ONNX (to be put here):

if __name__ == "__main__":
    import torch
    import torchaudio

    # assumes TorchGate is in scope, e.g. from noisereduce.torchgate import TorchGate

    data, _ = torchaudio.load("path/to/test/file.wav")

    model = TorchGate(
        sr=16000,
        nonstationary=False,
        n_fft=1024,
        prop_decrease=0.8,
        n_std_thresh_stationary=2,
        freq_mask_smooth_hz=None,
        time_mask_smooth_ms=None,
    )

    torch.onnx.export(
        model,
        data,
        "noise_suppression.onnx",
        verbose=True,
        input_names=["x"],
        output_names=["y"],
        opset_version=17
    )

Error:

Exporting the operator 'aten::stft' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
nuniz (Collaborator) commented Sep 23, 2023

Hi HuguesGallier,
Thanks for the feedback, I'm glad to see that you are using the spectral gate as an nn.Module :-).

The STFT and iSTFT operations can be performed externally (you would remove the STFT and iSTFT computations inside the spectral gating code), or you can implement the STFT as an nn.Module using conv1d with a precomputed Fourier basis and integrate it with the spectral gate (see this issue: pytorch/pytorch#31317). Since support is scheduled for the next opset, we think it is unnecessary to add it to noisereduce.
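A minimal sketch of that second option (my own illustration, not code from noisereduce): the windowed real and imaginary DFT bases are baked into a conv1d weight, so the whole module exports with standard ONNX ops.

```python
import math

import torch
import torch.nn as nn


class ConvSTFT(nn.Module):
    """STFT as a single conv1d over precomputed, windowed Fourier bases."""

    def __init__(self, n_fft: int = 1024, hop_length: int = 256):
        super().__init__()
        window = torch.hann_window(n_fft)
        n = torch.arange(n_fft, dtype=torch.float32)
        k = torch.arange(n_fft // 2 + 1, dtype=torch.float32).unsqueeze(1)
        # One output channel per frequency bin: frame . (window * cos/sin basis)
        real = (window * torch.cos(2 * math.pi * k * n / n_fft)).unsqueeze(1)
        imag = (-window * torch.sin(2 * math.pi * k * n / n_fft)).unsqueeze(1)
        self.register_buffer("weight", torch.cat([real, imag], dim=0))
        self.hop_length = hop_length

    def forward(self, x: torch.Tensor):
        # x: (batch, samples) -> two (batch, n_fft // 2 + 1, frames) tensors
        spec = torch.conv1d(x.unsqueeze(1), self.weight, stride=self.hop_length)
        return spec.chunk(2, dim=1)
```

This matches torch.stft with center=False; centered behavior would need explicit padding of the input first.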

By the way, if you are only running the spectral gating on 2 seconds of audio, that may not be enough, as it expects both noise and speech to be present in the same recording. I suggest capturing the noise profile externally and passing it via the y_noise argument. We may add the ability to continuously stream noise statistics to the streamer function in the future.

I hope this is helpful!


DakeQQ commented Dec 20, 2024

Feel free to use this repo to export your custom STFT or ISTFT process to ONNX format. There’s no need to separate the STFT and ISTFT from the model anymore.


kabyanil commented Dec 20, 2024

> Feel free to use this repo to export your custom STFT or ISTFT process to ONNX format. There's no need to separate the STFT and ISTFT from the model anymore.

I am trying to export the QuartzNet model from NeMo. During preprocessing of the audio signal, torch.stft() is used (line 437 in the source code), which causes errors when exporting the preprocessor class to ONNX. I am looking for a replacement for the torch.stft() function.

When I pass a dummy audio signal as input to the preprocessor, torch.stft() outputs a tensor of shape torch.Size([1, 257, 61]), while your package outputs real and imaginary parts, each of shape torch.Size([257, 61]).

The NeMo code then passes the torch.stft() output to torch.view_as_real(). I'm not sure how to pass the real and imaginary parts output by your package to this function.
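For what it's worth, the layout torch.view_as_real produces can be reconstructed from a separate real/imag pair by stacking on a new last dimension (a sketch under that assumption; the helper name is mine):

```python
import torch


def to_view_as_real_layout(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
    """Stack real/imag parts into the (..., 2) layout torch.view_as_real returns."""
    # view_as_real puts the real part at [..., 0] and the imaginary part at [..., 1].
    return torch.stack([real, imag], dim=-1)
```

In this case the (257, 61) pair would presumably also need an unsqueeze(0) to match the batched (1, 257, 61, 2) shape the NeMo code expects.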

Another question: since your package uses conv1d layers, I assume these layers would get updated during model training. Since they are initialized randomly, wouldn't that affect the outputs? I'm using your package only for inference.


DakeQQ commented Dec 20, 2024

You can use stft_A, which outputs only the real part, if your task requires just the real component.

custom_stft = STFT_Process(
    model_type='stft_A', 
    n_fft=NFFT, 
    n_mels=N_MELS, 
    hop_len=HOP_LENGTH, 
    max_frames=0, 
    window_type=WINDOW_TYPE
).eval()

real_part = custom_stft(audio, pad_mode='constant')

In random tests, this custom STFT/ISTFT shows almost no difference from torch.stft(). In real-world tasks such as ASR, denoising, VAD, speaker identification, TTS, and other audio applications, it performs exceptionally well with no noticeable differences from the PyTorch version. Feel free to use it with confidence.

@kabyanil

Thanks. Do I need to train it? I'm using a pretrained model for inference only.


DakeQQ commented Dec 20, 2024

No, you don't need to train it.

@kabyanil

Hi @DakeQQ, I successfully exported the preprocessor to ONNX by replacing torch.stft() with your custom implementation. However, when I load it in onnxruntime-web and try to run inference, I get the following error:

Uncaught (in promise) 20088320 vad.html:133:53
    inference http://127.0.0.1:5500/templates/vad.html:133
    AsyncFunctionThrow self-hosted:804
    (Async: async)
    onSpeechEnd http://127.0.0.1:5500/templates/vad.html:82
    handleFrameProcessorEvent https://cdn.jsdelivr.net/npm/@ricky0123/[email protected]/dist/bundle.min.js:1
    processFrame https://cdn.jsdelivr.net/npm/@ricky0123/[email protected]/dist/bundle.min.js:1
    AsyncFunctionNext self-hosted:800
    (Async: async)
    onmessage https://cdn.jsdelivr.net/npm/@ricky0123/[email protected]/dist/bundle.min.js:1
    (Async: EventHandlerNonNull)
    new https://cdn.jsdelivr.net/npm/@ricky0123/[email protected]/dist/bundle.min.js:1
    InterpretGeneratorResume self-hosted:1413
    AsyncFunctionNext self-hosted:800
    (Async: async)
    new https://cdn.jsdelivr.net/npm/@ricky0123/[email protected]/dist/bundle.min.js:1
    InterpretGeneratorResume self-hosted:1413
    AsyncFunctionNext self-hosted:800
    (Async: async)
    <anonymous> http://127.0.0.1:5500/templates/vad.html:45
    InterpretGeneratorResume self-hosted:1413
    AsyncFunctionNext self-hosted:800

I tried loading the same ONNX model in Python like so:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('./preprocessor.onnx')

test_1 = np.random.rand(1, 19820).astype(np.float32)
test_2 = np.random.randint(low=0, high=100, size=(19820,), dtype=np.int64)

res = session.run(None, {"input_signal": test_1, "length": test_2})

(res_1, res_2) = res

print(res_1.shape, res_2.shape)

Here I get the output as expected:

(19820, 80, 123) (19820,)

Why is the inference working in python, but not in web?


DakeQQ commented Dec 21, 2024

@kabyanil
I’m not sure about the issue you encountered, as I’m not familiar with ONNX Runtime Web.
