Get the PyTree structure of a Flax model without initializing weights #1421
patil-suraj asked this question in Show and tell
Thanks for this great tip, @patil-suraj!
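The `jax.eval_shape` function can be used to get the PyTree structure of a Flax model's params and optimizer state (or of the output of any JAX function, for that matter) without actually initializing them. It traces the function abstractly and returns a PyTree whose leaves are `jax.ShapeDtypeStruct` objects carrying only a shape and a dtype, so no device memory is allocated for the weights and no computation is run. This should give: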
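A minimal sketch of the idea in plain JAX, so that it runs without Flax or Optax installed. `init_params` (a small two-layer MLP) and `init_opt_state` (one momentum buffer per parameter) are hypothetical stand-ins, not code from the original post. With a Flax Linen model you would pass `model.init` in the same way, e.g. `jax.eval_shape(model.init, rng, dummy_input)`. Because `eval_shape` also accepts `ShapeDtypeStruct` inputs, the optimizer-state shapes can be derived directly from the abstract params.

```python
import jax
import jax.numpy as jnp


# Hypothetical init function standing in for a Flax `model.init`:
# builds the parameters of a small two-layer MLP from a PRNG key.
def init_params(rng, in_dim=784, hidden=512, out_dim=10):
    k1, k2 = jax.random.split(rng)
    return {
        "dense1": {
            "kernel": jax.random.normal(k1, (in_dim, hidden)) * 0.01,
            "bias": jnp.zeros((hidden,)),
        },
        "dense2": {
            "kernel": jax.random.normal(k2, (hidden, out_dim)) * 0.01,
            "bias": jnp.zeros((out_dim,)),
        },
    }


# Hypothetical optimizer-state builder: one zero-initialized momentum
# buffer with the same shape as each parameter.
def init_opt_state(params):
    return {"momentum": jax.tree_util.tree_map(jnp.zeros_like, params)}


rng = jax.random.PRNGKey(0)

# eval_shape traces init_params abstractly: no weight arrays are materialized;
# every leaf of the result is a jax.ShapeDtypeStruct (shape + dtype only).
param_shapes = jax.eval_shape(init_params, rng)

# ShapeDtypeStruct trees are valid eval_shape inputs, so the optimizer-state
# structure can be computed from the abstract params without real weights.
opt_shapes = jax.eval_shape(init_opt_state, param_shapes)

# Print (shape, dtype) for every leaf of the params PyTree.
print(jax.tree_util.tree_map(lambda s: (s.shape, s.dtype), param_shapes))
```

The same pattern is handy for checking parameter counts or building sharding specs before committing memory to a large model.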