Skip to content

Commit 22a38b0

Browse files
update develop equation code
1 parent d319e80 commit 22a38b0

File tree

11 files changed

+92
-16
lines changed

11 files changed

+92
-16
lines changed

ppsci/equation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def build_equation(cfg):
5454
"""Build equation(s)
5555
5656
Args:
57-
cfg (List[DictConfig]): Equation(s) config list.
57+
cfg (List[AttrDict]): Equation(s) config list.
5858
5959
Returns:
6060
Dict[str, Equation]: Equation(s) in dict.

ppsci/equation/pde/base.py

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import Union
2323

2424
import paddle
25-
import sympy
25+
import sympy as sp
2626
from paddle import nn
2727

2828
DETACH_FUNC_NAME = "detach"
@@ -33,7 +33,7 @@ class PDE:
3333

3434
def __init__(self):
3535
super().__init__()
36-
self.equations = {}
36+
self.equations: Dict[str, Union[Callable, sp.Basic]] = {}
3737
# for PDE which has learnable parameter(s)
3838
self.learnable_parameters = nn.ParameterList()
3939

@@ -42,7 +42,7 @@ def __init__(self):
4242
@staticmethod
4343
def create_symbols(
4444
symbol_str: str,
45-
) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
45+
) -> Union[sp.Symbol, Tuple[sp.Symbol, ...]]:
4646
"""create symbolic variables.
4747
4848
Args:
@@ -61,11 +61,9 @@ def create_symbols(
6161
>>> print(symbols_xyz)
6262
(x, y, z)
6363
"""
64-
return sympy.symbols(symbol_str)
64+
return sp.symbols(symbol_str)
6565

66-
def create_function(
67-
self, name: str, invars: Tuple[sympy.Symbol, ...]
68-
) -> sympy.Function:
66+
def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Function:
6967
"""Create named function depending on given invars.
7068
7169
Args:
@@ -86,14 +84,73 @@ def create_function(
8684
>>> print(f)
8785
f(x, y, z)
8886
"""
89-
expr = sympy.Function(name)(*invars)
87+
expr = sp.Function(name)(*invars)
9088

91-
# wrap `expression(...)` to `detach(expression(...))`
92-
# if name of expression is in given detach_keys
93-
if self.detach_keys and name in self.detach_keys:
94-
expr = sympy.Function(DETACH_FUNC_NAME)(expr)
9589
return expr
9690

91+
def _apply_detach(self):
92+
"""
93+
Wrap detached sub_expr into detach(sub_expr) to prevent gradient
94+
back-propagation, only for those items speicified in self.detach_keys.
95+
96+
NOTE: This function is expected to be called after self.equations is ready in PDE.__init__.
97+
98+
Examples:
99+
>>> import ppsci
100+
>>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False)
101+
>>> print(ns)
102+
NavierStokes
103+
continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y)
104+
momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
105+
momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
106+
>>> detach_keys = ("u", "v__y")
107+
>>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys)
108+
>>> print(ns)
109+
NavierStokes
110+
continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x)
111+
momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
112+
momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
113+
"""
114+
if self.detach_keys is None:
115+
return
116+
117+
from copy import deepcopy
118+
119+
from sympy.core.traversal import postorder_traversal
120+
121+
from ppsci.utils.symbolic import _cvt_to_key
122+
123+
for name, expr in self.equations.items():
124+
if not isinstance(expr, sp.Basic):
125+
continue
126+
# only process sympy expression
127+
expr_ = deepcopy(expr)
128+
for item in postorder_traversal(expr):
129+
if _cvt_to_key(item) in self.detach_keys:
130+
# inplace all related sub_expr into detach(sub_expr)
131+
expr_ = expr_.replace(item, sp.Function(DETACH_FUNC_NAME)(item))
132+
133+
# remove all detach wrapper for more-than-once wrapped items to prevent duplicated wrapping
134+
expr_ = expr_.replace(
135+
sp.Function(DETACH_FUNC_NAME)(
136+
sp.Function(DETACH_FUNC_NAME)(item)
137+
),
138+
sp.Function(DETACH_FUNC_NAME)(item),
139+
)
140+
141+
# remove unccessary detach wrapping for the first arg of Derivative
142+
for item_ in list(postorder_traversal(expr_)):
143+
if isinstance(item_, sp.Derivative):
144+
if item_.args[0].name == DETACH_FUNC_NAME:
145+
expr_ = expr_.replace(
146+
item_,
147+
sp.Derivative(
148+
item_.args[0].args[0], *item_.args[1:]
149+
),
150+
)
151+
152+
self.equations[name] = expr_
153+
97154
def add_equation(self, name: str, equation: Callable):
98155
"""Add an equation.
99156
@@ -110,7 +167,8 @@ def add_equation(self, name: str, equation: Callable):
110167
>>> equation = sympy.diff(u, x) + sympy.diff(u, y)
111168
>>> pde.add_equation('linear_pde', equation)
112169
>>> print(pde)
113-
PDE, linear_pde: 2*x + 2*y
170+
PDE
171+
linear_pde: 2*x + 2*y
114172
"""
115173
self.equations.update({name: equation})
116174

@@ -181,7 +239,7 @@ def set_state_dict(
181239
return self.learnable_parameters.set_state_dict(state_dict)
182240

183241
def __str__(self):
184-
return ", ".join(
242+
return "\n".join(
185243
[self.__class__.__name__]
186-
+ [f"{name}: {eq}" for name, eq in self.equations.items()]
244+
+ [f" {name}: {eq}" for name, eq in self.equations.items()]
187245
)

ppsci/equation/pde/biharmonic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,5 @@ def __init__(
7070
biharmonic += u.diff(invar_i, 2).diff(invar_j, 2)
7171

7272
self.add_equation("biharmonic", biharmonic)
73+
74+
self._apply_detach()

ppsci/equation/pde/heat_exchanger.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,5 @@ def __init__(
9090
self.add_equation("heat_boundary", heat_boundary)
9191
self.add_equation("cold_boundary", cold_boundary)
9292
self.add_equation("wall", wall)
93+
94+
self._apply_detach()

ppsci/equation/pde/laplace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
5151
laplace += u.diff(invar, 2)
5252

5353
self.add_equation("laplace", laplace)
54+
55+
self._apply_detach()

ppsci/equation/pde/linear_elasticity.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,5 @@ def __init__(
179179
self.add_equation("traction_y", traction_y)
180180
if self.dim == 3:
181181
self.add_equation("traction_z", traction_z)
182+
183+
self._apply_detach()

ppsci/equation/pde/navier_stokes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,5 @@ def __init__(
147147
self.add_equation("momentum_y", momentum_y)
148148
if self.dim == 3:
149149
self.add_equation("momentum_z", momentum_z)
150+
151+
self._apply_detach()

ppsci/equation/pde/nls_m_b.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,5 @@ def __init__(
9797
self.add_equation("Maxwell_1", Maxwell_1)
9898
self.add_equation("Maxwell_2", Maxwell_2)
9999
self.add_equation("Bloch", Bloch)
100+
101+
self._apply_detach()

ppsci/equation/pde/normal_dot_vec.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,5 @@ def __init__(
5555
normal_dot_vec += normal * vec
5656

5757
self.add_equation("normal_dot_vec", normal_dot_vec)
58+
59+
self._apply_detach()

ppsci/equation/pde/poisson.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
4949
poisson += p.diff(invar, 2)
5050

5151
self.add_equation("poisson", poisson)
52+
53+
self._apply_detach()

0 commit comments

Comments
 (0)