-
Notifications
You must be signed in to change notification settings - Fork 755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add mixed precision support to deepxde #1650
base: master
Are you sure you want to change the base?
Conversation
d3590e4
to
1352213
Compare
70d6c15
to
b30dc78
Compare
If I use the API here to set mixed precision, then all the demo code can run in mixed precision? |
Not all. The L-BFGS optimizer doesn't work in mixed precision. But if you add the line |
@@ -74,7 +74,7 @@ def set_default_float(value): | |||
The default floating point type is 'float32'. | |||
|
|||
Args: | |||
value (String): 'float16', 'float32', or 'float64'. | |||
value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision in https://arxiv.org/abs/2401.16645). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision).
@@ -74,7 +74,7 @@ def set_default_float(value): | |||
The default floating point type is 'float32'. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default floating point type is 'float32'. Mixed precision uses the method in the paper: `J. Hayford, J. Goldman-Wetzler, E. Wang, & L. Lu. Speeding up and reducing memory usage for scientific machine learning via mixed precision. Computer Methods in Applied Mechanics and Engineering, 428, 117093, 2024 <https://doi.org/10.1016/j.cma.2024.117093>`_.
|
||
self.opt.step(closure) | ||
def closure_mixed(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why delete line 360
@@ -85,6 +85,20 @@ def set_default_float(value): | |||
elif value == "float64": | |||
print("Set the default float type to float64") | |||
real.set_float64() | |||
elif value == "mixed": | |||
print("Set the float type to mixed precision of float16 and float32") | |||
real.set_mixed() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code is confusing. Here you do real.set_mixed()
, but later you do either real.set_float16()
or real.set_float32()
. It seems you only need a flag mixed
. You can do this flag after line 42.
No description provided.