Skip to content

Commit f0f81c3

Browse files
committed
Combined changes for profile_inlet
1 parent 00c24e3 commit f0f81c3

File tree

14 files changed

+323
-101
lines changed

14 files changed

+323
-101
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- XLB is now installable via pip
1919
- Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern
2020
- Added NVIDIA's Warp backend for state-of-the-art performance
21+
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
22+
- Added the capability to add profiles to boundary conditions

examples/cfd/flow_past_sphere_3d.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import numpy as np
2020
import jax.numpy as jnp
2121
import time
22+
from functools import partial
23+
from jax import jit
2224

2325

2426
class FlowOverSphere:
@@ -37,6 +39,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
3739
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
3840
self.stepper = None
3941
self.boundary_conditions = []
42+
self.u_max = 0.04
4043

4144
# Setup the simulation BC, its initial conditions, and the stepper
4245
self._setup(omega)
@@ -69,7 +72,7 @@ def define_boundary_indices(self):
6972

7073
def setup_boundary_conditions(self):
7174
inlet, outlet, walls, sphere = self.define_boundary_indices()
72-
bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)
75+
bc_left = RegularizedBC("velocity", indices=inlet)
7376
# bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet)
7477
bc_walls = FullwayBounceBackBC(indices=walls)
7578
# bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet)
@@ -95,8 +98,58 @@ def initialize_fields(self):
9598
def setup_stepper(self, omega):
9699
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")
97100

101+
def bc_profile(self):
102+
u_max = self.u_max # u_max = 0.04
103+
# Get the grid dimensions for the y and z directions
104+
H_y = float(self.grid_shape[1] - 1) # Height in y direction
105+
H_z = float(self.grid_shape[2] - 1) # Height in z direction
106+
107+
@wp.func
108+
def bc_profile_warp(index: wp.vec3i):
109+
# Poiseuille flow profile: parabolic velocity distribution
110+
y = self.precision_policy.store_precision.wp_dtype(index[1])
111+
z = self.precision_policy.store_precision.wp_dtype(index[2])
112+
113+
# Calculate normalized distance from center
114+
y_center = y - (H_y / 2.0)
115+
z_center = z - (H_z / 2.0)
116+
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
117+
118+
# Parabolic profile: u = u_max * (1 - r²)
119+
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5)
120+
# return u_max
121+
122+
# @partial(jit, inline=True)
123+
def bc_profile_jax():
124+
y = jnp.arange(self.grid_shape[1])
125+
z = jnp.arange(self.grid_shape[2])
126+
Y, Z = jnp.meshgrid(y, z, indexing="ij")
127+
128+
# Calculate normalized distance from center
129+
y_center = Y - (H_y / 2.0)
130+
z_center = Z - (H_z / 2.0)
131+
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
132+
133+
# Parabolic profile for x velocity, zero for y and z
134+
u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared)
135+
u_y = jnp.zeros_like(u_x)
136+
u_z = jnp.zeros_like(u_x)
137+
138+
return jnp.stack([u_x, u_y, u_z])
139+
140+
if self.backend == ComputeBackend.JAX:
141+
return bc_profile_jax
142+
elif self.backend == ComputeBackend.WARP:
143+
return bc_profile_warp
144+
145+
def initialize_bc_aux_data(self):
146+
for bc in self.boundary_conditions:
147+
if bc.needs_aux_init:
148+
self.f_0, self.f_1 = bc.aux_data_init(self.bc_profile(), self.f_0, self.f_1, self.bc_mask, self.missing_mask)
149+
98150
def run(self, num_steps, post_process_interval=100):
99151
start_time = time.time()
152+
self.initialize_bc_aux_data()
100153
for i in range(num_steps):
101154
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
102155
self.f_0, self.f_1 = self.f_1, self.f_0

examples/cfd/lid_driven_cavity_2d.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
class LidDrivenCavity2D:
17-
def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
17+
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
1818
# initialize backend
1919
xlb.init(
2020
velocity_set=velocity_set,
@@ -29,6 +29,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
2929
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
3030
self.stepper = None
3131
self.boundary_conditions = []
32+
self.prescribed_vel = prescribed_vel
3233

3334
# Setup the simulation BC, its initial conditions, and the stepper
3435
self._setup(omega)
@@ -49,7 +50,7 @@ def define_boundary_indices(self):
4950

5051
def setup_boundary_conditions(self):
5152
lid, walls = self.define_boundary_indices()
52-
bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid)
53+
bc_top = EquilibriumBC(rho=1.0, u=(self.prescribed_vel, 0.0), indices=lid)
5354
bc_walls = HalfwayBounceBackBC(indices=walls)
5455
self.boundary_conditions = [bc_walls, bc_top]
5556

@@ -112,7 +113,13 @@ def post_process(self, i):
112113
precision_policy = PrecisionPolicy.FP32FP32
113114

114115
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
115-
omega = 1.6
116116

117-
simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy)
118-
simulation.run(num_steps=5000, post_process_interval=1000)
117+
# Setting fluid viscosity and relaxation parameter.
118+
Re = 200.0
119+
prescribed_vel = 0.05
120+
clength = grid_shape[0] - 1
121+
visc = prescribed_vel * clength / Re
122+
omega = 1.0 / (3.0 * visc + 0.5)
123+
124+
simulation = LidDrivenCavity2D(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
125+
simulation.run(num_steps=50000, post_process_interval=1000)

examples/cfd/lid_driven_cavity_2d_distributed.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88

99
class LidDrivenCavity2D_distributed(LidDrivenCavity2D):
10-
def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
11-
super().__init__(omega, grid_shape, velocity_set, backend, precision_policy)
10+
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
11+
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
1212

1313
def setup_stepper(self, omega):
1414
stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
@@ -29,7 +29,13 @@ def setup_stepper(self, omega):
2929
precision_policy = PrecisionPolicy.FP32FP32
3030

3131
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
32-
omega = 1.6
3332

34-
simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy)
35-
simulation.run(num_steps=5000, post_process_interval=1000)
33+
# Setting fluid viscosity and relaxation parameter.
34+
Re = 200.0
35+
prescribed_vel = 0.05
36+
clength = grid_shape[0] - 1
37+
visc = prescribed_vel * clength / Re
38+
omega = 1.0 / (3.0 * visc + 0.5)
39+
40+
simulation = LidDrivenCavity2D_distributed(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
41+
simulation.run(num_steps=50000, post_process_interval=1000)

examples/cfd/windtunnel_3d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def define_boundary_indices(self):
9595
def setup_boundary_conditions(self):
9696
inlet, outlet, walls, car = self.define_boundary_indices()
9797
bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet)
98-
# bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet)
9998
bc_walls = FullwayBounceBackBC(indices=walls)
10099
bc_do_nothing = ExtrapolationOutflowBC(indices=outlet)
101100
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)

tests/boundary_conditions/mask/test_bc_indices_masker_warp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape):
6060
[test_bc],
6161
bc_mask,
6262
missing_mask,
63-
start_index=(0, 0, 0) if dim == 3 else (0, 0),
6463
)
6564
assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype
6665

@@ -69,9 +68,12 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape):
6968
bc_mask = bc_mask.numpy()
7069
missing_mask = missing_mask.numpy()
7170

72-
assert bc_mask.shape == (1,) + grid_shape
73-
74-
assert missing_mask.shape == (velocity_set.q,) + grid_shape
71+
if len(grid_shape) == 2:
72+
assert bc_mask.shape == (1,) + grid_shape + (1,), "bc_mask shape is incorrect got {}".format(bc_mask.shape)
73+
assert missing_mask.shape == (velocity_set.q,) + grid_shape + (1,), "missing_mask shape is incorrect got {}".format(missing_mask.shape)
74+
else:
75+
assert bc_mask.shape == (1,) + grid_shape, "bc_mask shape is incorrect got {}".format(bc_mask.shape)
76+
assert missing_mask.shape == (velocity_set.q,) + grid_shape, "missing_mask shape is incorrect got {}".format(missing_mask.shape)
7577

7678
if dim == 2:
7779
assert np.all(bc_mask[0, indices[0], indices[1]] == test_bc.id)

xlb/operator/boundary_condition/bc_extrapolation_outflow.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,15 @@ def __init__(
5656
mesh_vertices,
5757
)
5858

59+
# Set the flag for auxilary data recovery
60+
self.needs_aux_recovery = True
61+
5962
# find and store the normal vector using indices
6063
self._get_normal_vec(indices)
6164

6265
# Unpack the two warp functionals needed for this BC!
6366
if self.compute_backend == ComputeBackend.WARP:
64-
self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional
67+
self.warp_functional, self.update_bc_auxilary_data = self.warp_functional
6568

6669
def _get_normal_vec(self, indices):
6770
# Get the frequency count and most common element directly
@@ -92,9 +95,9 @@ def _roll(self, fld, vec):
9295
return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3))
9396

9497
@partial(jit, static_argnums=(0,), inline=True)
95-
def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
98+
def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
9699
"""
97-
Prepare the auxilary distribution functions for the boundary condition.
100+
Update the auxilary distribution functions for the boundary condition.
98101
Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision
99102
"""
100103
sound_speed = 1.0 / jnp.sqrt(3.0)
@@ -134,7 +137,6 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
134137
def _construct_warp(self):
135138
# Set local constants
136139
sound_speed = self.compute_dtype(1.0 / wp.sqrt(3.0))
137-
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
138140
_c = self.velocity_set.c
139141
_q = self.velocity_set.q
140142
_opp_indices = self.velocity_set.opp_indices
@@ -143,9 +145,14 @@ def _construct_warp(self):
143145
def get_normal_vectors(
144146
missing_mask: Any,
145147
):
146-
for l in range(_q):
147-
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
148-
return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l])
148+
if wp.static(self.velocity_set.d == 3):
149+
for l in range(_q):
150+
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
151+
return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l])
152+
else:
153+
for l in range(_q):
154+
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
155+
return -wp.vec2i(_c[0, l], _c[1, l])
149156

150157
# Construct the functionals for this BC
151158
@wp.func
@@ -167,7 +174,7 @@ def functional(
167174
return _f
168175

169176
@wp.func
170-
def prepare_bc_auxilary_data(
177+
def update_bc_auxilary_data(
171178
index: Any,
172179
timestep: Any,
173180
missing_mask: Any,
@@ -176,7 +183,7 @@ def prepare_bc_auxilary_data(
176183
f_pre: Any,
177184
f_post: Any,
178185
):
179-
# Preparing the formulation for this BC using the neighbour's populations stored in f_aux and
186+
# Update the auxilary data for this BC using the neighbour's populations stored in f_aux and
180187
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
181188
# for storing this prepared data.
182189
_f = f_post
@@ -195,7 +202,7 @@ def prepare_bc_auxilary_data(
195202

196203
kernel = self._construct_kernel(functional)
197204

198-
return (functional, prepare_bc_auxilary_data), kernel
205+
return (functional, update_bc_auxilary_data), kernel
199206

200207
@Operator.register_backend(ComputeBackend.WARP)
201208
def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):

xlb/operator/boundary_condition/bc_regularized.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class RegularizedBC(ZouHeBC):
4444
def __init__(
4545
self,
4646
bc_type,
47-
prescribed_value,
4847
velocity_set: VelocitySet = None,
4948
precision_policy: PrecisionPolicy = None,
5049
compute_backend: ComputeBackend = None,
@@ -54,7 +53,6 @@ def __init__(
5453
# Call the parent constructor
5554
super().__init__(
5655
bc_type,
57-
prescribed_value,
5856
velocity_set,
5957
precision_policy,
6058
compute_backend,
@@ -127,16 +125,11 @@ def _construct_warp(self):
127125
# assign placeholders for both u and rho based on prescribed_value
128126
_d = self.velocity_set.d
129127
_q = self.velocity_set.q
130-
u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d
131-
rho = self.prescribed_value if self.bc_type == "pressure" else 0.0
132128

133129
# Set local constants TODO: This is a hack and should be fixed with warp update
134130
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
135131
# compute Qi tensor and store it in self
136-
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
137132
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
138-
_rho = self.compute_dtype(rho)
139-
_u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1])
140133
_opp_indices = self.velocity_set.opp_indices
141134
_w = self.velocity_set.w
142135
_c = self.velocity_set.c
@@ -162,9 +155,14 @@ def _get_fsum(
162155
def get_normal_vectors(
163156
missing_mask: Any,
164157
):
165-
for l in range(_q):
166-
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
167-
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
158+
if wp.static(_d == 3):
159+
for l in range(_q):
160+
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
161+
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
162+
else:
163+
for l in range(_q):
164+
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
165+
return -_u_vec(_c_float[0, l], _c_float[1, l])
168166

169167
@wp.func
170168
def bounceback_nonequilibrium(
@@ -218,6 +216,15 @@ def functional_velocity(
218216
# Find normal vector
219217
normals = get_normal_vectors(missing_mask)
220218

219+
# Find the value of u from the missing directions
220+
for l in range(wp.static(_q)):
221+
# Since we are only considering normal velocity, we only need to find one value
222+
if missing_mask[l] == wp.uint8(1):
223+
# Create velocity vector by multiplying the prescribed value with the normal vector
224+
prescribed_value = f_1[_opp_indices[l], index[0], index[1], index[2]]
225+
_u = -prescribed_value * normals
226+
break
227+
221228
# calculate rho
222229
fsum = _get_fsum(_f, missing_mask)
223230
unormal = self.compute_dtype(0.0)
@@ -249,6 +256,13 @@ def functional_pressure(
249256
# Find normal vector
250257
normals = get_normal_vectors(missing_mask)
251258

259+
# Find the value of rho from the missing directions
260+
for q in range(wp.static(_q)):
261+
# Since we need only one scalar value, we only need to find one value
262+
if missing_mask[q] == wp.uint8(1):
263+
_rho = f_0[_opp_indices[q], index[0], index[1], index[2]]
264+
break
265+
252266
# calculate velocity
253267
fsum = _get_fsum(_f, missing_mask)
254268
unormal = -self.compute_dtype(1.0) + fsum / _rho

0 commit comments

Comments
 (0)