Support for exporting models in the ONNX format #4430

Open
Artoriuz opened this issue Dec 12, 2024 · 2 comments

Artoriuz commented Dec 12, 2024

Hi!

I know this is something that has been asked before, but I just wanted to ask again: is there any plan to eventually support exporting Flax models in the ONNX format?

ONNX has become a very popular way of distributing models, acting as a kind of lingua franca supported by various inference engines like TensorRT, NCNN, OpenVINO, ONNX Runtime, et cetera.

The current workaround is to convert your Flax model to TensorFlow first using jax2tf, and then convert that to ONNX using tf2onnx. While this works, the resulting ONNX models often contain various unnecessary steps, and even simple things aren't mapped to the expected corresponding operations. I've also only been able to get this to work with enable_xla=False in the jax2tf conversion, a flag which has since been deprecated. A sketch of the pipeline is below.
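For concreteness, here is a minimal sketch of that workaround, assuming a toy linen MLP (the model, input shape, and opset are arbitrary placeholders):

```python
# Minimal sketch of the jax2tf + tf2onnx workaround. The MLP, input
# shape, and opset below are arbitrary placeholders for illustration.
import flax.linen as nn
import jax
import jax.numpy as jnp
import tensorflow as tf
import tf2onnx
from jax.experimental import jax2tf

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        return nn.Dense(10)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))

# enable_xla=False lowers to plain TF ops that tf2onnx understands,
# but the flag is deprecated on the JAX side.
tf_forward = tf.function(
    jax2tf.convert(lambda x: model.apply(params, x), enable_xla=False),
    autograph=False,
    input_signature=[tf.TensorSpec((1, 784), tf.float32)],
)

tf2onnx.convert.from_function(
    tf_forward,
    input_signature=[tf.TensorSpec((1, 784), tf.float32)],
    opset=17,
    output_path="model.onnx",
)
```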

I understand there's an argument to be made that it would make more sense to have this at the JAX level instead, but honestly I don't think ONNX is very popular outside of ML, and doing it at the Flax level might make it easier to map operations 1:1.

FWIW, Equinox supports this by bridging through TF first as well, so that seems to be the status quo everywhere.

Thanks in advance!

cgarciae (Collaborator) commented Dec 12, 2024

Hi @Artoriuz, is there a notion of a "Module" in ONNX, or do you mean that we should provide a helper function to easily map to ONNX? If it's the latter, I agree we could add that; I've been wanting to add a saved_model helper as well.

Artoriuz (Author) commented Dec 12, 2024

"Should" is probably too strong a word; I don't think I'm qualified to tell you what you should do. I was just hoping to adopt Flax as my main ML library going forward, and I found some rough edges along the way (but I really liked everything else!).

And yes, I just think it would be very convenient to have a "native" way of mapping JAX operations (and Flax modules by extension) into the corresponding ONNX operations without having to use TF as a bridge. Most things should have a 1:1 counterpart anyway.

Since this would be more oriented towards inference, "all" we need is to export the final forward step. The notion of a "Module" would be lost, but that's fine (the entire top-level nnx.Module would become a single ONNX model), as in the hypothetical sketch below.
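To illustrate the shape of the API I'm imagining (this is purely hypothetical; nnx.export_onnx does not exist in Flax today):

```python
# Purely hypothetical sketch of a "native" exporter; neither
# `nnx.export_onnx` nor `MyModel` exists, they only illustrate the idea.
import jax.numpy as jnp
from flax import nnx

model = MyModel(rngs=nnx.Rngs(0))  # any top-level nnx.Module

nnx.export_onnx(                    # hypothetical helper
    model,
    args=(jnp.ones((1, 784)),),     # example input fixing shapes/dtypes
    path="model.onnx",
    opset=17,
)
```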

Just for reference, PyTorch has two distinct exporters:
https://pytorch.org/docs/stable/onnx_dynamo.html
https://pytorch.org/docs/stable/onnx.html
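For comparison, the classic TorchScript-based exporter is invoked roughly like this (toy module and shapes, for illustration only):

```python
# Sketch of PyTorch's classic ONNX exporter, for comparison only;
# the module and shapes are arbitrary.
import torch

class TinyNet(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).sum(dim=-1)

torch.onnx.export(TinyNet(), (torch.ones(1, 784),), "torch_model.onnx", opset_version=17)
```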

And this is what is generally used to convert from either TF or Keras:
https://github.com/onnx/tensorflow-onnx
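Whichever route produces the file, a quick way to sanity-check an export is to run it with ONNX Runtime (the file name and input shape here match the sketches above and are otherwise arbitrary):

```python
# Smoke-test an exported model with ONNX Runtime. "model.onnx" and the
# input shape match the hypothetical sketches above.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
(logits,) = session.run(None, {input_name: np.ones((1, 784), dtype=np.float32)})
print(logits.shape)  # (1, 10) for the toy MLP above
```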
