Hi!

I know that this is something that has been asked before, but I just wanted to ask again: is there any plan to eventually support exporting Flax models in the ONNX format?
ONNX has become a very popular way of distributing models, serving as a kind of lingua franca supported by inference engines like TensorRT, NCNN, OpenVINO, and ONNX Runtime.
The current workaround is to first convert your Flax model to TensorFlow using jax2tf, and then convert from that to ONNX using tf2onnx. While this works, the resulting ONNX models often contain unnecessary steps, and even simple things aren't mapped to the expected corresponding operations. I've also only been able to get this to work with enable_xla=False in the jax2tf conversion, which has been deprecated.
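For reference, here's a minimal sketch of that workaround (the model, shapes, and opset are placeholders for illustration, and enable_xla=False is the deprecated flag mentioned above):

```python
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf
import tf2onnx
import flax.linen as nn

# Placeholder model standing in for the real network.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(10)(nn.relu(nn.Dense(128)(x)))

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 64)))

# Close over the params so we export a pure forward function.
def forward(x):
    return model.apply(params, x)

spec = [tf.TensorSpec([1, 64], tf.float32)]

# enable_xla=False keeps the ops as plain TF ops that tf2onnx can map,
# but the flag is deprecated, as noted above.
tf_fn = tf.function(jax2tf.convert(forward, enable_xla=False),
                    input_signature=spec, autograph=False)

onnx_model, _ = tf2onnx.convert.from_function(
    tf_fn, input_signature=spec, opset=13, output_path="model.onnx")
```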
I understand that there's an argument to be made that it would perhaps make more sense to have this at the JAX level instead, but honestly I don't think ONNX is very popular outside of ML, and doing it at the Flax level would maybe make it easier to map the operations 1:1.

FWIW, Equinox supports this by bridging through TF first as well, so that seems to be the status quo everywhere.

Thanks in advance!
Hi @Artoriuz, is there a notion of a "Module" in ONNX, or do you mean that we should provide a helper function to easily map to ONNX? If it's the latter, I agree we could add that; I've been wanting to add a saved_model helper as well.
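(For context, a saved_model helper would presumably wrap up something like the manual pattern below; the forward function and shapes here are just placeholders, not an existing Flax API.)

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Placeholder forward function standing in for model.apply(params, x).
def forward(x):
    return jnp.tanh(x)

# Wrap the converted function in a tf.Module so it can be saved.
module = tf.Module()
module.fwd = tf.function(jax2tf.convert(forward),
                         input_signature=[tf.TensorSpec([1, 64], tf.float32)],
                         autograph=False)
tf.saved_model.save(module, "/tmp/flax_saved_model")
```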
"Should" is probably too strong a word; I don't think I'm qualified to tell you what you should do. I was just looking forward to adopting Flax as my main ML library going forward and found some rough edges along the way (but I really liked everything else!).
And yes, I just think it would be very convenient to have a "native" way of mapping JAX operations (and Flax modules by extension) into the corresponding ONNX operations without having to use TF as a bridge. Most things should have a 1:1 counterpart anyway.
Since this would be more oriented towards inference, "all" we need is to export the final forward step. The notion of a "Module" would be lost, but that's fine (the entire upper-level nnx.Module would become a single ONNX model).
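To make the 1:1 idea concrete, here's a sketch of the kind of graph a hypothetical direct exporter could emit for a single Dense + relu forward step, built with the onnx helper API (the Gemm/Relu mapping and zero-valued weights are just my illustration, not an existing Flax feature):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Hypothetical output of a direct exporter for Dense(10) + relu:
# Dense -> Gemm, relu -> Relu, parameters -> graph initializers.
kernel = numpy_helper.from_array(np.zeros((64, 10), np.float32), name="kernel")
bias = numpy_helper.from_array(np.zeros((10,), np.float32), name="bias")

graph = helper.make_graph(
    nodes=[
        helper.make_node("Gemm", ["x", "kernel", "bias"], ["h"]),
        helper.make_node("Relu", ["h"], ["y"]),
    ],
    name="forward",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 64])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 10])],
    initializer=[kernel, bias],
)
onnx.checker.check_model(helper.make_model(graph))
```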