Confused about dtype and precision #3987
Unanswered
davidshen84 asked this question in Q&A
Replies: 0 comments
Hi,

I am a bit confused about the `dtype`, `param_dtype` and `precision` parameters in some of the `flax.linen` modules.

According to the documentation of `Conv`, it has these parameters to control the precision:

- `dtype`: can be inferred from the input
- `param_dtype`: defaults to `float32`
- `precision`: defaults to `None`; I guess it resolves to `default`? https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision

If I want to use `bfloat16` for my model, which parameter should I use?

Also, I found that the return value of `nn.Conv.apply` is controlled not by `dtype` or `precision`, but by `param_dtype`.

For example, if I create a simple 2-layer conv net and do not set any of these parameters, then the 1st `conv` layer's precision can be controlled by the input type, but the 2nd `conv` layer's precision is controlled by the output type of the first layer, which is always `float32`.

Should I explicitly set the `param_dtype` parameter of all the layers?

Is there a way to control the precision globally? I guess it would cause trouble for some layers, like `BatchNorm`, which always prefers higher precision.

Do we have official guidelines on controlling the model precision and utilising hardware features?
Thanks!