Skip to content

Commit

Permalink
Added abstraction layer for boundary condition application and steppe…
Browse files Browse the repository at this point in the history
…r initialization, and the capability to add profiles to boundary conditions
  • Loading branch information
mehdiataei committed Nov 30, 2024
1 parent 2b6355b commit 53f626c
Show file tree
Hide file tree
Showing 15 changed files with 542 additions and 308 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- XLB is now installable via pip
- Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern
- Added NVIDIA's Warp backend for state-of-the-art performance
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
- Added the capability to add profiles to boundary conditions
98 changes: 63 additions & 35 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps
from xlb.grid import grid_factory
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import (
FullwayBounceBackBC,
HalfwayBounceBackBC,
ZouHeBC,
RegularizedBC,
EquilibriumBC,
DoNothingBC,
ExtrapolationOutflowBC,
)
from xlb.operator.macroscopic import Macroscopic
from xlb.operator.boundary_masker import IndicesBoundaryMasker
from xlb.utils import save_fields_vtk, save_image
from xlb.utils import save_image
import warp as wp
import numpy as np
import jax.numpy as jnp
Expand All @@ -34,18 +30,19 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
self.velocity_set = velocity_set
self.backend = backend
self.precision_policy = precision_policy
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
self.stepper = None
self.omega = omega
self.boundary_conditions = []
self.u_max = 0.04

# Setup the simulation BC, its initial conditions, and the stepper
self._setup(omega)
# Create grid using factory
self.grid = grid_factory(grid_shape, compute_backend=backend)

def _setup(self, omega):
# Setup the simulation BC and stepper
self._setup()

def _setup(self):
self.setup_boundary_conditions()
self.setup_boundary_masker()
self.initialize_fields()
self.setup_stepper(omega)
self.setup_stepper()

def define_boundary_indices(self):
box = self.grid.bounding_box_indices()
Expand All @@ -69,31 +66,62 @@ def define_boundary_indices(self):

def setup_boundary_conditions(self):
inlet, outlet, walls, sphere = self.define_boundary_indices()
bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)
# bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet)
bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
# bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet)
# bc_outlet = DoNothingBC(indices=outlet)
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
bc_sphere = HalfwayBounceBackBC(indices=sphere)
self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)

indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0))

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")
self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.init_fields()

def bc_profile(self):
u_max = self.u_max # u_max = 0.04
# Get the grid dimensions for the y and z directions
H_y = float(self.grid_shape[1] - 1) # Height in y direction
H_z = float(self.grid_shape[2] - 1) # Height in z direction

@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = self.precision_policy.store_precision.wp_dtype(index[1])
z = self.precision_policy.store_precision.wp_dtype(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
z_center = z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile: u = u_max * (1 - r²)
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5)

def bc_profile_jax():
y = jnp.arange(self.grid_shape[1])
z = jnp.arange(self.grid_shape[2])
Y, Z = jnp.meshgrid(y, z, indexing="ij")

# Calculate normalized distance from center
y_center = Y - (H_y / 2.0)
z_center = Z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile for x velocity, zero for y and z
u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared)
u_y = jnp.zeros_like(u_x)
u_z = jnp.zeros_like(u_x)

return jnp.stack([u_x, u_y, u_z])

if self.backend == ComputeBackend.JAX:
return bc_profile_jax
elif self.backend == ComputeBackend.WARP:
return bc_profile_warp

def run(self, num_steps, post_process_interval=100):
start_time = time.time()
Expand Down Expand Up @@ -134,8 +162,8 @@ def post_process(self, i):

if __name__ == "__main__":
# Running the simulation
grid_shape = (512 // 2, 128 // 2, 128 // 2)
backend = ComputeBackend.WARP
grid_shape = (256 // 2, 128 // 2, 128 // 2)
backend = ComputeBackend.JAX
precision_policy = PrecisionPolicy.FP32FP32

velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
Expand Down
41 changes: 17 additions & 24 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps
from xlb.operator.boundary_masker import IndicesBoundaryMasker
from xlb.grid import grid_factory
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC
from xlb.operator.macroscopic import Macroscopic
Expand All @@ -26,19 +25,21 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre
self.velocity_set = velocity_set
self.backend = backend
self.precision_policy = precision_policy
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
self.stepper = None
self.omega = omega
self.boundary_conditions = []
self.prescribed_vel = prescribed_vel

# Setup the simulation BC, its initial conditions, and the stepper
self._setup(omega)
# Create grid using factory
self.grid = grid_factory(grid_shape, compute_backend=backend)

def _setup(self, omega):
# Setup the simulation BC and stepper
self._setup()

def _setup(self):
self.setup_boundary_conditions()
self.setup_boundary_masker()
self.initialize_fields()
self.setup_stepper(omega)
self.setup_stepper()
# Initialize fields using the stepper
self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.init_fields()

def define_boundary_indices(self):
box = self.grid.bounding_box_indices()
Expand All @@ -54,21 +55,13 @@ def setup_boundary_conditions(self):
bc_walls = HalfwayBounceBackBC(indices=walls)
self.boundary_conditions = [bc_walls, bc_top]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)
indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask)

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
Expand Down
19 changes: 14 additions & 5 deletions examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,24 @@ class LidDrivenCavity2D_distributed(LidDrivenCavity2D):
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)

def setup_stepper(self, omega):
stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
distributed_stepper = distribute(
def setup_stepper(self):
# Create the base stepper
stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)

# Distribute the stepper
self.stepper = distribute(
stepper,
self.grid,
self.velocity_set,
)
self.stepper = distributed_stepper
return

# Initialize fields using the distributed stepper
self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.init_fields()


if __name__ == "__main__":
Expand Down
57 changes: 32 additions & 25 deletions examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import time
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq
from xlb.grid import grid_factory
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import RegularizedBC
from xlb.operator.macroscopic import Macroscopic
from xlb.operator.boundary_masker import IndicesBoundaryMasker
from xlb.utils import save_fields_vtk, save_image
from xlb.helper import initialize_eq
import warp as wp
import numpy as np
import jax.numpy as jnp
Expand Down Expand Up @@ -48,18 +48,16 @@ def __init__(self, channel_half_width, Re_tau, u_tau, grid_shape, velocity_set,
self.u_tau = u_tau
self.visc = u_tau * channel_half_width / Re_tau
self.omega = 1.0 / (3.0 * self.visc + 0.5)
# DeltaPlus = Re_tau / channel_half_width
# DeltaPlus = u_tau / nu * Delta where u_tau / nu = Re_tau / channel_half_width

self.grid_shape = grid_shape
self.velocity_set = velocity_set
self.backend = backend
self.precision_policy = precision_policy
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
self.stepper = None
self.boundary_conditions = []

# Setup the simulation BC, its initial conditions, and the stepper
# Create grid using factory
self.grid = grid_factory(grid_shape, compute_backend=backend)

# Setup the simulation BC and stepper
self._setup()

def get_force(self):
Expand All @@ -71,31 +69,38 @@ def get_force(self):

def _setup(self):
self.setup_boundary_conditions()
self.setup_boundary_masker()
self.initialize_fields()
self.setup_stepper()
# Initialize fields using the stepper
self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.init_fields()
self.initialize_fields()

def define_boundary_indices(self):
# top and bottom sides of the channel are no-slip and the other directions are periodic
box = self.grid.bounding_box_indices(remove_edges=True)
walls = [box["bottom"][i] + box["top"][i] for i in range(self.velocity_set.d)]
return walls

def bc_profile(self):
@wp.func
def bc_profile_warp(index: wp.vec3i):
return wp.vec(0.0, length=1)

def bc_profile_jax():
return jnp.zeros(1)

if self.backend == ComputeBackend.JAX:
return bc_profile_jax
elif self.backend == ComputeBackend.WARP:
return bc_profile_warp

def setup_boundary_conditions(self):
walls = self.define_boundary_indices()
bc_walls = RegularizedBC("velocity", (0.0, 0.0, 0.0), indices=walls)
bc_walls = RegularizedBC("velocity", profile=self.bc_profile(), indices=walls)
self.boundary_conditions = [bc_walls]

def setup_boundary_masker(self):
indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
)
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask)

def initialize_fields(self):
shape = (self.velocity_set.d,) + (self.grid_shape)
# Initialize with random velocity field
shape = (self.velocity_set.d,) + self.grid_shape
np.random.seed(0)
u_init = np.random.random(shape)
if self.backend == ComputeBackend.JAX:
Expand All @@ -105,9 +110,12 @@ def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init)

def setup_stepper(self):
force = self.get_force()
self.stepper = IncompressibleNavierStokesStepper(
self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC", forcing_scheme="exact_difference", force_vector=force
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
force_vector=self.get_force(),
)

def run(self, num_steps, print_interval, post_process_interval=100):
Expand Down Expand Up @@ -142,14 +150,12 @@ def post_process(self, i):
u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5
fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_z": u[2], "u_magnitude": u_magnitude}
save_fields_vtk(fields, timestep=i)
save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i)
save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i)

# Save monitor plot
self.plot_uplus(u, i)
return

def plot_uplus(self, u, timestep):
# Compute moving average of drag coefficient, 100, 1000, 10000
# mean streamwise velocity in wall units u^+(z)
# Wall distance in wall units to be used inside output_data
zz = np.arange(self.grid_shape[-1])
Expand All @@ -165,6 +171,7 @@ def plot_uplus(self, u, timestep):
ax.set_ylim([0, 20])
fname = "uplus_" + str(timestep // 10000).zfill(5) + ".png"
plt.savefig(fname, format="png")
plt.close()


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 53f626c

Please sign in to comment.