Skip to content

Commit 49cb066

Browse files
committed
Added abstraction layer for boundary condition aux data and implementaiton, and the capability to add profiles to boundary conditions
1 parent 2b6355b commit 49cb066

14 files changed

+448
-194
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- XLB is now installable via pip
1919
- Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern
2020
- Added NVIDIA's Warp backend for state-of-the-art performance
21+
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
22+
- Added the capability to add profiles to boundary conditions

examples/cfd/flow_past_sphere_3d.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import numpy as np
2020
import jax.numpy as jnp
2121
import time
22+
from functools import partial
23+
from jax import jit
2224

2325

2426
class FlowOverSphere:
@@ -37,13 +39,13 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
3739
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
3840
self.stepper = None
3941
self.boundary_conditions = []
42+
self.u_max = 0.04
4043

4144
# Setup the simulation BC, its initial conditions, and the stepper
4245
self._setup(omega)
4346

4447
def _setup(self, omega):
4548
self.setup_boundary_conditions()
46-
self.setup_boundary_masker()
4749
self.initialize_fields()
4850
self.setup_stepper(omega)
4951

@@ -69,7 +71,7 @@ def define_boundary_indices(self):
6971

7072
def setup_boundary_conditions(self):
7173
inlet, outlet, walls, sphere = self.define_boundary_indices()
72-
bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)
74+
bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet)
7375
# bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet)
7476
bc_walls = FullwayBounceBackBC(indices=walls)
7577
# bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet)
@@ -78,22 +80,63 @@ def setup_boundary_conditions(self):
7880
bc_sphere = HalfwayBounceBackBC(indices=sphere)
7981
self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]
8082

81-
def setup_boundary_masker(self):
82-
# check boundary condition list for duplicate indices before creating bc mask
83-
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)
84-
85-
indices_boundary_masker = IndicesBoundaryMasker(
86-
velocity_set=self.velocity_set,
87-
precision_policy=self.precision_policy,
88-
compute_backend=self.backend,
89-
)
90-
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0))
91-
9283
def initialize_fields(self):
9384
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)
9485

9586
def setup_stepper(self, omega):
96-
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")
87+
self.stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
88+
f_0=self.f_0,
89+
f_1=self.f_1,
90+
bc_mask=self.bc_mask,
91+
missing_mask=self.missing_mask,
92+
omega=omega,
93+
boundary_conditions=self.boundary_conditions,
94+
collision_type="BGK",
95+
)
96+
97+
def bc_profile(self):
98+
u_max = self.u_max # u_max = 0.04
99+
# Get the grid dimensions for the y and z directions
100+
H_y = float(self.grid_shape[1] - 1) # Height in y direction
101+
H_z = float(self.grid_shape[2] - 1) # Height in z direction
102+
103+
@wp.func
104+
def bc_profile_warp(index: wp.vec3i):
105+
# Poiseuille flow profile: parabolic velocity distribution
106+
y = self.precision_policy.store_precision.wp_dtype(index[1])
107+
z = self.precision_policy.store_precision.wp_dtype(index[2])
108+
109+
# Calculate normalized distance from center
110+
y_center = y - (H_y / 2.0)
111+
z_center = z - (H_z / 2.0)
112+
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
113+
114+
# Parabolic profile: u = u_max * (1 - r²)
115+
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5)
116+
# return u_max
117+
118+
# @partial(jit, inline=True)
119+
def bc_profile_jax():
120+
y = jnp.arange(self.grid_shape[1])
121+
z = jnp.arange(self.grid_shape[2])
122+
Y, Z = jnp.meshgrid(y, z, indexing="ij")
123+
124+
# Calculate normalized distance from center
125+
y_center = Y - (H_y / 2.0)
126+
z_center = Z - (H_z / 2.0)
127+
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
128+
129+
# Parabolic profile for x velocity, zero for y and z
130+
u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared)
131+
u_y = jnp.zeros_like(u_x)
132+
u_z = jnp.zeros_like(u_x)
133+
134+
return jnp.stack([u_x, u_y, u_z])
135+
136+
if self.backend == ComputeBackend.JAX:
137+
return bc_profile_jax
138+
elif self.backend == ComputeBackend.WARP:
139+
return bc_profile_warp
97140

98141
def run(self, num_steps, post_process_interval=100):
99142
start_time = time.time()

examples/cfd/lid_driven_cavity_2d.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre
3636

3737
def _setup(self, omega):
3838
self.setup_boundary_conditions()
39-
self.setup_boundary_masker()
4039
self.initialize_fields()
4140
self.setup_stepper(omega)
4241

@@ -54,21 +53,19 @@ def setup_boundary_conditions(self):
5453
bc_walls = HalfwayBounceBackBC(indices=walls)
5554
self.boundary_conditions = [bc_walls, bc_top]
5655

57-
def setup_boundary_masker(self):
58-
# check boundary condition list for duplicate indices before creating bc mask
59-
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)
60-
indices_boundary_masker = IndicesBoundaryMasker(
61-
velocity_set=self.velocity_set,
62-
precision_policy=self.precision_policy,
63-
compute_backend=self.backend,
64-
)
65-
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask)
66-
6756
def initialize_fields(self):
6857
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)
6958

7059
def setup_stepper(self, omega):
71-
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
60+
self.stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
61+
f_0=self.f_0,
62+
f_1=self.f_1,
63+
bc_mask=self.bc_mask,
64+
missing_mask=self.missing_mask,
65+
omega=omega,
66+
boundary_conditions=self.boundary_conditions,
67+
collision_type="BGK",
68+
)
7269

7370
def run(self, num_steps, post_process_interval=100):
7471
for i in range(num_steps):
@@ -109,7 +106,7 @@ def post_process(self, i):
109106
# Running the simulation
110107
grid_size = 500
111108
grid_shape = (grid_size, grid_size)
112-
backend = ComputeBackend.WARP
109+
backend = ComputeBackend.JAX
113110
precision_policy = PrecisionPolicy.FP32FP32
114111

115112
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)

examples/cfd/lid_driven_cavity_2d_distributed.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,20 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre
1111
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
1212

1313
def setup_stepper(self, omega):
14-
stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
15-
distributed_stepper = distribute(
14+
stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
15+
f_0=self.f_0,
16+
f_1=self.f_1,
17+
bc_mask=self.bc_mask,
18+
missing_mask=self.missing_mask,
19+
omega=omega,
20+
boundary_conditions=self.boundary_conditions,
21+
collision_type="BGK",
22+
)
23+
self.stepper = distribute(
1624
stepper,
1725
self.grid,
1826
self.velocity_set,
1927
)
20-
self.stepper = distributed_stepper
2128
return
2229

2330

examples/cfd/turbulent_channel_3d.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def get_force(self):
7171

7272
def _setup(self):
7373
self.setup_boundary_conditions()
74-
self.setup_boundary_masker()
7574
self.initialize_fields()
7675
self.setup_stepper()
7776

@@ -86,14 +85,6 @@ def setup_boundary_conditions(self):
8685
bc_walls = RegularizedBC("velocity", (0.0, 0.0, 0.0), indices=walls)
8786
self.boundary_conditions = [bc_walls]
8887

89-
def setup_boundary_masker(self):
90-
indices_boundary_masker = IndicesBoundaryMasker(
91-
velocity_set=self.velocity_set,
92-
precision_policy=self.precision_policy,
93-
compute_backend=self.backend,
94-
)
95-
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask)
96-
9788
def initialize_fields(self):
9889
shape = (self.velocity_set.d,) + (self.grid_shape)
9990
np.random.seed(0)
@@ -104,10 +95,16 @@ def initialize_fields(self):
10495
u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype)
10596
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init)
10697

107-
def setup_stepper(self):
108-
force = self.get_force()
109-
self.stepper = IncompressibleNavierStokesStepper(
110-
self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC", forcing_scheme="exact_difference", force_vector=force
98+
def setup_stepper(self, omega):
99+
self.stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
100+
f_0=self.f_0,
101+
f_1=self.f_1,
102+
bc_mask=self.bc_mask,
103+
missing_mask=self.missing_mask,
104+
omega=omega,
105+
boundary_conditions=self.boundary_conditions,
106+
collision_type="BGK",
107+
force=self.get_force(),
111108
)
112109

113110
def run(self, num_steps, print_interval, post_process_interval=100):

examples/cfd/windtunnel_3d.py

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import numpy as np
2323
import jax.numpy as jnp
2424
import matplotlib.pyplot as plt
25+
from functools import partial
26+
from jax import jit
2527

2628

2729
class WindTunnel3D:
@@ -55,8 +57,7 @@ def _setup(self):
5557
# NOTE: it is important to initialize fields before setup_boundary_masker is called because f_0 or f_1 might be used to store BC information
5658
self.initialize_fields()
5759
self.setup_boundary_conditions()
58-
self.setup_boundary_masker()
59-
self.setup_stepper()
60+
self.setup_stepper(self.omega)
6061

6162
def voxelize_stl(self, stl_filename, length_lbm_unit):
6263
mesh = trimesh.load_mesh(stl_filename, process=False)
@@ -85,57 +86,77 @@ def define_boundary_indices(self):
8586
length_phys_unit = mesh_extents.max()
8687
length_lbm_unit = self.grid_shape[0] / 4
8788
dx = length_phys_unit / length_lbm_unit
88-
shift = np.array([self.grid_shape[0] * dx / 4, (self.grid_shape[1] * dx - mesh_extents[1]) / 2, 0.0])
89+
mesh_vertices = mesh_vertices / dx
90+
shift = np.array([self.grid_shape[0] / 4, (self.grid_shape[1] - mesh_extents[1] / dx) / 2, 0.0])
8991
car = mesh_vertices + shift
90-
self.grid_spacing = dx
9192
self.car_cross_section = np.prod(mesh_extents[1:]) / dx**2
9293

9394
return inlet, outlet, walls, car
9495

9596
def setup_boundary_conditions(self):
9697
inlet, outlet, walls, car = self.define_boundary_indices()
97-
bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet)
98-
# bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet)
98+
bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet)
9999
bc_walls = FullwayBounceBackBC(indices=walls)
100100
bc_do_nothing = ExtrapolationOutflowBC(indices=outlet)
101-
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)
102-
bc_car = GradsApproximationBC(mesh_vertices=car)
103-
# bc_car = FullwayBounceBackBC(mesh_vertices=car)
101+
bc_car = FullwayBounceBackBC(mesh_vertices=car)
104102
self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car]
105103

106-
def setup_boundary_masker(self):
107-
# check boundary condition list for duplicate indices before creating bc mask
108-
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)
109-
110-
indices_boundary_masker = IndicesBoundaryMasker(
111-
velocity_set=self.velocity_set,
112-
precision_policy=self.precision_policy,
113-
compute_backend=self.backend,
114-
)
115-
# mesh_boundary_masker = MeshBoundaryMasker(
116-
# velocity_set=self.velocity_set,
117-
# precision_policy=self.precision_policy,
118-
# compute_backend=self.backend,
119-
# )
120-
mesh_distance_boundary_masker = MeshDistanceBoundaryMasker(
121-
velocity_set=self.velocity_set,
122-
precision_policy=self.precision_policy,
123-
compute_backend=self.backend,
124-
)
125-
bclist_other = self.boundary_conditions[:-1]
126-
bc_mesh = self.boundary_conditions[-1]
127-
dx = self.grid_spacing
128-
origin, spacing = (0, 0, 0), (dx, dx, dx)
129-
self.bc_mask, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_mask, self.missing_mask)
130-
# self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask)
131-
self.bc_mask, self.missing_mask, self.f_1 = mesh_distance_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask, self.f_1)
132-
133104
def initialize_fields(self):
134105
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)
135106
self.f_1 = initialize_eq(self.f_1, self.grid, self.velocity_set, self.precision_policy, self.backend)
136107

137-
def setup_stepper(self):
138-
self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC")
108+
def bc_profile(self):
109+
u_max = self.wind_speed
110+
# Get the grid dimensions for the y and z directions
111+
H_y = float(self.grid_shape[1] - 1) # Height in y direction
112+
H_z = float(self.grid_shape[2] - 1) # Height in z direction
113+
114+
@wp.func
115+
def bc_profile_warp(index: wp.vec3i):
116+
# Poiseuille flow profile: parabolic velocity distribution
117+
y = self.precision_policy.store_precision.wp_dtype(index[1])
118+
z = self.precision_policy.store_precision.wp_dtype(index[2])
119+
120+
# Calculate normalized distance from center
121+
y_center = y - (H_y / 2.0)
122+
z_center = z - (H_z / 2.0)
123+
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
124+
125+
# Parabolic profile: u = u_max * (1 - r²)
126+
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5)
127+
128+
def bc_profile_jax():
129+
y = jnp.arange(self.grid_shape[1])
130+
z = jnp.arange(self.grid_shape[2])
131+
Y, Z = jnp.meshgrid(y, z, indexing="ij")
132+
133+
# Calculate normalized distance from center
134+
y_center = Y - (H_y / 2.0)
135+
z_center = Z - (H_z / 2.0)
136+
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
137+
138+
# Parabolic profile for x velocity, zero for y and z
139+
u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared)
140+
u_y = jnp.zeros_like(u_x)
141+
u_z = jnp.zeros_like(u_x)
142+
143+
return jnp.stack([u_x, u_y, u_z])
144+
145+
if self.backend == ComputeBackend.JAX:
146+
return bc_profile_jax
147+
elif self.backend == ComputeBackend.WARP:
148+
return bc_profile_warp
149+
150+
def setup_stepper(self, omega):
151+
self.stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
152+
f_0=self.f_0,
153+
f_1=self.f_1,
154+
bc_mask=self.bc_mask,
155+
missing_mask=self.missing_mask,
156+
omega=omega,
157+
boundary_conditions=self.boundary_conditions,
158+
collision_type="BGK",
159+
)
139160

140161
def run(self, num_steps, print_interval, post_process_interval=100):
141162
# Setup the operator for computing surface forces at the interface of the specified BC
@@ -236,8 +257,7 @@ def plot_drag_coefficient(self):
236257
print_interval = 1000
237258

238259
# Set up Reynolds number and deduce relaxation time (omega)
239-
# Re = 50000.0
240-
Re = 500000000000.0
260+
Re = 5000.0
241261
clength = grid_size_x - 1
242262
visc = wind_speed * clength / Re
243263
omega = 1.0 / (3.0 * visc + 0.5)

0 commit comments

Comments
 (0)