6
6
7
7
8
8
from src .boundary_conditions import *
9
- from jax .config import config
10
9
from src .utils import *
11
10
import numpy as np
12
11
from src .lattice import LatticeD2Q9
13
12
from src .models import BGKSim , KBCSim , AdvectionDiffusionBGK
14
- import jax .numpy as jnp
15
- from jax .experimental import mesh_utils
16
- from jax .sharding import PositionalSharding
17
13
import os
18
14
import matplotlib .pyplot as plt
15
+ import json
19
16
20
17
# Use 8 CPU devices
21
18
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
25
22
jax .config .update ('jax_enable_x64' , True )
26
23
27
24
def taylor_green_initial_fields (xx , yy , u0 , rho0 , nu , time ):
28
- ux = - u0 * np .cos (xx ) * np .sin (yy ) * np .exp (- 2 * nu * time )
29
- uy = u0 * np .sin (xx ) * np .cos (yy ) * np .exp (- 2 * nu * time )
25
+ ux = u0 * np .sin (xx ) * np .cos (yy ) * np .exp (- 2 * nu * time )
26
+ uy = - u0 * np .cos (xx ) * np .sin (yy ) * np .exp (- 2 * nu * time )
30
27
rho = 1.0 - rho0 * u0 ** 2 / 12. * (np .cos (2. * xx ) + np .cos (2. * yy )) * np .exp (- 4 * nu * time )
31
28
return ux , uy , np .expand_dims (rho , axis = - 1 )
32
29
@@ -51,7 +48,7 @@ def initialize_populations(self, rho, u):
51
48
ADE = AdvectionDiffusionBGK (** kwargs )
52
49
ADE .initialize_macroscopic_fields = self .initialize_macroscopic_fields
53
50
print ("Initializing the distribution functions using the specified macroscopic fields...." )
54
- f = ADE .run (20000 )
51
+ f = ADE .run (int ( 20000 * 32 / nx ) )
55
52
return f
56
53
57
54
def output_data (self , ** kwargs ):
@@ -64,49 +61,67 @@ def output_data(self, **kwargs):
64
61
time = timestep * (kx ** 2 + ky ** 2 )/ 2.
65
62
ux_th , uy_th , rho_th = taylor_green_initial_fields (xx , yy , vel_ref , 1 , visc , time )
66
63
vel_err_L2 = np .sqrt (np .sum ((u [..., 0 ]- ux_th )** 2 + (u [..., 1 ]- uy_th )** 2 ) / np .sum (ux_th ** 2 + uy_th ** 2 ))
67
- print ("error= {:07.6f}" .format (vel_err_L2 ))
64
+ rho_err_L2 = np .sqrt (np .sum ((rho - rho_th )** 2 ) / np .sum (rho_th ** 2 ))
65
+ print ("Vel error= {:07.6f}, Pressure error= {:07.6f}" .format (vel_err_L2 , rho_err_L2 ))
68
66
if timestep == endTime :
69
67
ErrL2ResList .append (vel_err_L2 )
68
+ ErrL2ResListRho .append (rho_err_L2 )
70
69
# save_image(timestep, u)
71
70
72
71
73
72
if __name__ == "__main__" :
74
- precision = "f64/f64"
75
- lattice = LatticeD2Q9 (precision )
76
-
77
- resList = [32 , 64 , 128 , 256 , 512 ]
78
- ErrL2ResList = []
79
-
80
- for nx in resList :
81
- print ("Running at nx = ny = {:07.6f}" .format (nx ))
82
- ny = nx
83
- twopi = 2.0 * np .pi
84
- coord = np .array ([(i , j ) for i in range (nx ) for j in range (ny )])
85
- xx , yy = coord [:, 0 ], coord [:, 1 ]
86
- kx , ky = twopi / nx , twopi / ny
87
- xx = xx .reshape ((nx , ny )) * kx
88
- yy = yy .reshape ((nx , ny )) * ky
89
-
90
- Re = 1000.0
91
- vel_ref = 0.04 * 32 / nx
92
-
93
- visc = vel_ref * nx / Re
94
- omega = 1.0 / (3.0 * visc + 0.5 )
95
- print ("omega = " , omega )
96
- os .system ("rm -rf ./*.vtk && rm -rf ./*.png" )
97
- kwargs = {
98
- 'lattice' : lattice ,
99
- 'omega' : omega ,
100
- 'nx' : nx ,
101
- 'ny' : ny ,
102
- 'nz' : 0 ,
103
- 'precision' : precision ,
104
- 'io_rate' : 500 ,
105
- 'print_info_rate' : 500
106
- }
107
- sim = TaylorGreenVortex (** kwargs )
108
- endTime = int (20000 * nx / 32.0 )
109
- sim .run (endTime )
110
- plt .loglog (resList , ErrL2ResList , '-o' )
111
- plt .loglog (resList , 1e-3 * (np .array (resList )/ 128 )** (- 2 ), '--' )
112
- plt .savefig ('Error.png' ); plt .savefig ('Error.pdf' , format = 'pdf' )
73
+ precision_list = ["f32/f32" , "f64/f32" , "f64/f64" ]
74
+ resList = [32 , 64 , 128 , 256 , 512 , 1024 ]
75
+ result_dict = dict .fromkeys (precision_list )
76
+ result_dict ['resolution_list' ] = resList
77
+
78
+ for precision in precision_list :
79
+ lattice = LatticeD2Q9 (precision )
80
+ ErrL2ResList = []
81
+ ErrL2ResListRho = []
82
+ result_dict [precision ] = dict .fromkeys (['vel_error' , 'rho_error' ])
83
+ for nx in resList :
84
+ print ("Running at nx = ny = {:07.6f}" .format (nx ))
85
+ ny = nx
86
+ twopi = 2.0 * np .pi
87
+ coord = np .array ([(i , j ) for i in range (nx ) for j in range (ny )])
88
+ xx , yy = coord [:, 0 ], coord [:, 1 ]
89
+ kx , ky = twopi / nx , twopi / ny
90
+ xx = xx .reshape ((nx , ny )) * kx
91
+ yy = yy .reshape ((nx , ny )) * ky
92
+
93
+ Re = 1600.0
94
+ vel_ref = 0.04 * 32 / nx
95
+
96
+ visc = vel_ref * nx / Re
97
+ omega = 1.0 / (3.0 * visc + 0.5 )
98
+ print ("omega = " , omega )
99
+ os .system ("rm -rf ./*.vtk && rm -rf ./*.png" )
100
+ kwargs = {
101
+ 'lattice' : lattice ,
102
+ 'omega' : omega ,
103
+ 'nx' : nx ,
104
+ 'ny' : ny ,
105
+ 'nz' : 0 ,
106
+ 'precision' : precision ,
107
+ 'io_rate' : 5000 ,
108
+ 'print_info_rate' : 1000
109
+ }
110
+ sim = TaylorGreenVortex (** kwargs )
111
+ tc = 2.0 / (2. * visc * (kx ** 2 + ky ** 2 ))
112
+ endTime = int (0.05 * tc )
113
+ sim .run (endTime )
114
+ result_dict [precision ]['vel_error' ] = ErrL2ResList
115
+ result_dict [precision ]['rho_error' ] = ErrL2ResListRho
116
+
117
+ with open ('data.json' , 'w' ) as fp :
118
+ json .dump (result_dict , fp )
119
+
120
+ # plt.loglog(resList, ErrL2ResList, '-o')
121
+ # plt.loglog(resList, 1e-3*(np.array(resList)/128)**(-2), '--')
122
+ # plt.savefig('ErrorVel.png'); plt.savefig('ErrorVel.pdf', format='pdf')
123
+
124
+ # plt.figure()
125
+ # plt.loglog(resList, ErrL2ResListRho, '-o')
126
+ # plt.loglog(resList, 1e-3*(np.array(resList)/128)**(-2), '--')
127
+ # plt.savefig('ErrorRho.png'); plt.savefig('ErrorRho.pdf', format='pdf')
0 commit comments