From d27f9679d752d9c98a6374ae8d46f3435673add6 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Wed, 31 May 2023 10:33:48 -0400 Subject: [PATCH 1/6] Removed padding to avoid OOM errors on large runs --- src/base.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/src/base.py b/src/base.py index 7875a32..bd692c7 100644 --- a/src/base.py +++ b/src/base.py @@ -98,7 +98,9 @@ def __init__(self, lattice, omega, nx, ny, nz=0, precision="f32/f32", optimize=F self.positionalSharding = PositionalSharding(self.devices) self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh, in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False)) - + + self.compute_bitmask = jit(shard_map(self.compute_bitmask_m, mesh=self.mesh, + in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False)) # Set up the sharding and streaming for 2D and 3D simulations elif self.dim == 3: @@ -109,6 +111,10 @@ def __init__(self, lattice, omega, nx, ny, nz=0, precision="f32/f32", optimize=F self.positionalSharding = PositionalSharding(self.devices) self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh, in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False)) + + self.compute_bitmask = jit(shard_map(self.compute_bitmask_m, mesh=self.mesh, + in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False)) + else: raise ValueError(f"dim = {self.dim} not supported") @@ -130,9 +136,8 @@ def _create_boundary_data(self): solid_halo_voxels = np.unique(np.vstack(solid_halo_list), axis=0) if solid_halo_list else None # Create the grid connectivity bitmask on each process - create_grid_connectivity_bitmask = jit((self.create_grid_connectivity_bitmask)) start = time.time() - connectivity_bitmask = create_grid_connectivity_bitmask(solid_halo_voxels) + connectivity_bitmask = self.create_grid_connectivity_bitmask(solid_halo_voxels) print("Time to create the grid connectivity bitmask:", time.time() - start) start = time.time() @@ -197,24 +202,27 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels): hw_x = self.n_devices hw_y = hw_z = 1 if self.dim == 2: - connectivity_bitmask = jnp.zeros((self.nx, self.ny, self.lattice.q), dtype=bool) + connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), jnp.bool_, initVal=True) + connectivity_bitmask = connectivity_bitmask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(None))].set(False) if solid_halo_voxels is not None: - connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True) - connectivity_bitmask = jnp.pad(connectivity_bitmask, ((hw_x, hw_x), (hw_y, hw_y), (0, 0)), - 'constant', constant_values=True) - connectivity_bitmask = self.compute_bitmask(connectivity_bitmask) + solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) + solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y) + connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True) - return connectivity_bitmask[hw_x:-hw_x, hw_y:-hw_y] + connectivity_bitmask = self.compute_bitmask(connectivity_bitmask) + return lax.with_sharding_constraint(connectivity_bitmask[hw_x:-hw_x, hw_y:-hw_y], self.namedSharding) elif self.dim == 3: - connectivity_bitmask = jnp.zeros((self.nx, self.ny, self.nz, self.lattice.q), dtype=bool) + connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.nz + 2 * hw_z, self.lattice.q), jnp.bool_, initVal=True) + connectivity_bitmask = connectivity_bitmask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(hw_z, -hw_z), slice(None))].set(False) if solid_halo_voxels is not None: + solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) + solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y) + solid_halo_voxels = solid_halo_voxels.at[:, 2].add(hw_z) connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True) - connectivity_bitmask = jnp.pad(connectivity_bitmask, ((hw_x, hw_x), (hw_y, hw_y), (hw_z, hw_z), (0, 0)), - 'constant', constant_values=True) - connectivity_bitmask = self.compute_bitmask(connectivity_bitmask) - return connectivity_bitmask[hw_x:-hw_x, hw_y:-hw_y, hw_z:-hw_z] + connectivity_bitmask = self.compute_bitmask(connectivity_bitmask) + return lax.with_sharding_constraint(connectivity_bitmask[hw_x:-hw_x, hw_y:-hw_y, hw_z:-hw_z], self.namedSharding) def bounding_box_indices(self): """ @@ -450,7 +458,7 @@ def streaming_i(f, c): return vmap(streaming_i, in_axes=(-1, 0), out_axes=-1)(f, self.c.T) - def compute_bitmask(self, b): + def compute_bitmask_m(self, b): """ This function computes a bitmask for each direction in the lattice. The bitmask is used to determine which nodes are fluid nodes and which are boundary nodes. From 18c408ad941623d3d490ee35cf8bdb66c82caef1 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Wed, 11 Oct 2023 23:05:12 -0400 Subject: [PATCH 2/6] Some formatting fixes based on Python conventions --- examples/CFD/channel3d.py | 7 +- examples/CFD/taylor_green_vortex.py | 4 +- examples/performance/MLUPS3d_distributed.py | 23 +++-- src/base.py | 101 ++++++++++---------- src/lattice.py | 12 +-- src/models.py | 6 +- src/utils.py | 2 +- 7 files changed, 81 insertions(+), 74 deletions(-) diff --git a/examples/CFD/channel3d.py b/examples/CFD/channel3d.py index 521baa0..366764c 100644 --- a/examples/CFD/channel3d.py +++ b/examples/CFD/channel3d.py @@ -55,7 +55,7 @@ def get_dns_data(): } return dns_dic -class turbulentChannel(KBCSim): +class TurbulentChannel(KBCSim): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -63,12 +63,11 @@ def set_boundary_conditions(self): # top and bottom sides of the channel are no-slip and the other directions are periodic wall = np.concatenate((self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'])) self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) - return def initialize_macroscopic_fields(self): rho = self.precisionPolicy.cast_to_output(1.0) u = self.distributed_array_init((self.nx, self.ny, self.nz, self.dim), - self.precisionPolicy.compute_dtype, initVal=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim))) + self.precisionPolicy.compute_dtype, init_val=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim))) u = self.precisionPolicy.cast_to_output(u) return rho, u @@ -149,5 +148,5 @@ def output_data(self, **kwargs): 'io_rate': 20000, 'print_info_rate': 20000 } - sim = turbulentChannel(**kwargs) + sim = TurbulentChannel(**kwargs) sim.run(4000000) diff --git a/examples/CFD/taylor_green_vortex.py b/examples/CFD/taylor_green_vortex.py index 9434d68..f6142bf 100644 --- a/examples/CFD/taylor_green_vortex.py +++ b/examples/CFD/taylor_green_vortex.py @@ -40,9 +40,9 @@ def set_boundary_conditions(self): def initialize_macroscopic_fields(self): ux, uy, rho = taylor_green_initial_fields(xx, yy, vel_ref, 1, 0., 0.) - rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, initVal=1.0, sharding=self.sharding) + rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, init_val=1.0, sharding=self.sharding) u = np.stack([ux, uy], axis=-1) - u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, initVal=u, sharding=self.sharding) + u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, init_val=u, sharding=self.sharding) return rho, u def initialize_populations(self, rho, u): diff --git a/examples/performance/MLUPS3d_distributed.py b/examples/performance/MLUPS3d_distributed.py index 2aa5814..3cf4731 100644 --- a/examples/performance/MLUPS3d_distributed.py +++ b/examples/performance/MLUPS3d_distributed.py @@ -5,22 +5,27 @@ """ -from src.models import BGKSim -from src.lattice import LatticeD3Q19 -import jax.numpy as jnp -import numpy as np -from src.utils import * -from jax.config import config +# Standard Libraries +import argparse import os from time import time -import argparse -import jax import portpicker + +import jax +import jax.numpy as jnp +import numpy as np + +from jax.config import config + +from src.boundary_conditions import * +from src.lattice import LatticeD3Q19 +from src.models import BGKSim +from src.utils import * + #config.update('jax_disable_jit', True) # Use 8 CPU devices #os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' #config.update("jax_enable_x64", True) -from src.boundary_conditions import * precision = 'f32/f32' diff --git a/src/base.py b/src/base.py index c4dbf36..115168f 100644 --- a/src/base.py +++ b/src/base.py @@ -1,23 +1,29 @@ -from src.boundary_conditions import * -from jax.config import config -from src.utils import * -from functools import partial -from jax.sharding import NamedSharding -from jax.sharding import PartitionSpec -from jax.sharding import PositionalSharding -from jax.sharding import Mesh -from jax.experimental import mesh_utils -from jax.experimental.shard_map import shard_map -from jax.experimental.multihost_utils import process_allgather -from jax import jit, lax, vmap -from termcolor import colored -from orbax.checkpoint import * +# Standard Libraries +import os import time + +# Third-Party Libraries +import jax import jax.numpy as jnp -import numpy as np import jmp -import os -import jax +import numpy as np +from termcolor import colored + +# JAX-related imports +from jax import jit, lax, vmap +from jax.config import config +from jax.experimental import mesh_utils +from jax.experimental.multihost_utils import process_allgather +from jax.experimental.shard_map import shard_map +from jax.sharding import NamedSharding, PartitionSpec, PositionalSharding, Mesh +import orbax.checkpoint as orb +# functools imports +from functools import partial + +# Local/Custom Libraries +# from src.boundary_conditions import * +from src.utils import downsample_field + jax.config.update("jax_spmd_mode", 'allow_all') @@ -88,8 +94,8 @@ def __init__(self, **kwargs): # Set the checkpoint manager if self.checkpointRate > 0: - mngr_options = CheckpointManagerOptions(save_interval_steps=self.checkpointRate, max_to_keep=1) - self.mngr = CheckpointManager(self.checkpointDir, PyTreeCheckpointer(), options=mngr_options) + mngr_options = orb.CheckpointManagerOptions(save_interval_steps=self.checkpointRate, max_to_keep=1) + self.mngr = orb.CheckpointManager(self.checkpointDir, orb.PyTreeCheckpointer(), options=mngr_options) else: print("WARNING: Checkpointing is disabled for this simulation.") self.mngr = None @@ -152,7 +158,7 @@ def __init__(self, **kwargs): raise ValueError(f"dim = {self.dim} not supported") # Compute the bounding box indices for boundary conditions - self.boundingBoxIndices = self.bounding_box_indices() + self.boundingBoxIndices= self.bounding_box_indices() # Create boundary data for the simulation self._create_boundary_data() self.force = self.get_force() @@ -182,7 +188,7 @@ def _create_boundary_data(self): print("Time to create the local bitmasks and normal arrays:", time.time() - start) # This is another non-JITed way of creating the distributed arrays. It is not used at the moment. - # def distributed_array_init(self, shape, type, initVal=None): + # def distributed_array_init(self, shape, type, init_val=None): # sharding_dim = shape[0] // self.nDevices # sharded_shape = (self.nDevices, sharding_dim, *shape[1:]) # device_shape = sharded_shape[1:] @@ -190,16 +196,16 @@ def _create_boundary_data(self): # for d, index in self.sharding.addressable_devices_indices_map(sharded_shape).items(): # jax.default_device = d - # if initVal is None: + # if init_val is None: # x = jnp.zeros(shape=device_shape, dtype=type) # else: - # x = jnp.full(shape=device_shape, fill_value=initVal, dtype=type) + # x = jnp.full(shape=device_shape, fill_value=init_val, dtype=type) # arrays += [jax.device_put(x, d)] # jax.default_device = jax.devices()[0] # return jax.make_array_from_single_device_arrays(shape, self.sharding, arrays) @partial(jit, static_argnums=(0, 1, 2, 4)) - def distributed_array_init(self, shape, type, initVal=0, sharding=None): + def distributed_array_init(self, shape, type, init_val=0, sharding=None): """ Initialize a distributed array using JAX, with a specified shape, data type, and initial value. Optionally, provide a custom sharding strategy. @@ -208,7 +214,7 @@ def distributed_array_init(self, shape, type, initVal=0, sharding=None): ---------- shape (tuple): The shape of the array to be created. type (dtype): The data type of the array to be created. - initVal (scalar, optional): The initial value to fill the array with. Defaults to 0. + init_val (scalar, optional): The initial value to fill the array with. Defaults to 0. sharding (Sharding, optional): The sharding strategy to use. Defaults to `self.sharding`. Returns @@ -217,7 +223,7 @@ def distributed_array_init(self, shape, type, initVal=0, sharding=None): """ if sharding is None: sharding = self.sharding - x = jnp.full(shape=shape, fill_value=initVal, dtype=type) + x = jnp.full(shape=shape, fill_value=init_val, dtype=type) return jax.lax.with_sharding_constraint(x, sharding) @partial(jit, static_argnums=(0,)) @@ -237,7 +243,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels): hw_x = self.nDevices hw_y = hw_z = 1 if self.dim == 2: - connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), jnp.bool_, initVal=True) + connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), jnp.bool_, init_val=True) connectivity_bitmask = connectivity_bitmask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(None))].set(False) if solid_halo_voxels is not None: solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) @@ -248,7 +254,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels): return lax.with_sharding_constraint(connectivity_bitmask, self.sharding) elif self.dim == 3: - connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.nz + 2 * hw_z, self.lattice.q), jnp.bool_, initVal=True) + connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.nz + 2 * hw_z, self.lattice.q), jnp.bool_, init_val=True) connectivity_bitmask = connectivity_bitmask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(hw_z, -hw_z), slice(None))].set(False) if solid_halo_voxels is not None: solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) @@ -273,18 +279,18 @@ def bounding_box_indices(self): # For a 2D grid, the bounding box consists of four edges: bottom, top, left, and right. # Each edge is represented as an array of indices. For example, the bottom edge includes # all points where the y-coordinate is 0, so its indices are [[i, 0] for i in range(self.nx)]. - boundingBox = {"bottom": np.array([[i, 0] for i in range(self.nx)], dtype=int), + bounding_box = {"bottom": np.array([[i, 0] for i in range(self.nx)], dtype=int), "top": np.array([[i, self.ny - 1] for i in range(self.nx)], dtype=int), "left": np.array([[0, i] for i in range(self.ny)], dtype=int), "right": np.array([[self.nx - 1, i] for i in range(self.ny)], dtype=int)} - return boundingBox + return bounding_box elif self.dim == 3: # For a 3D grid, the bounding box consists of six faces: bottom, top, left, right, front, and back. # Each face is represented as an array of indices. For example, the bottom face includes all points # where the z-coordinate is 0, so its indices are [[i, j, 0] for i in range(self.nx) for j in range(self.ny)]. - boundingBox = { + bounding_box = { "bottom": np.array([[i, j, 0] for i in range(self.nx) for j in range(self.ny)], dtype=int), "top": np.array([[i, j, self.nz - 1] for i in range(self.nx) for j in range(self.ny)],dtype=int), "left": np.array([[0, j, k] for j in range(self.ny) for k in range(self.nz)], dtype=int), @@ -292,7 +298,7 @@ def bounding_box_indices(self): "front": np.array([[i, 0, k] for i in range(self.nx) for k in range(self.nz)], dtype=int), "back": np.array([[i, self.ny - 1, k] for i in range(self.nx) for k in range(self.nz)], dtype=int)} - return boundingBox + return bounding_box def set_precisions(self, precision): """ @@ -335,7 +341,7 @@ def initialize_macroscopic_fields(self): print(" To set explicit initial density and velocity, use self.initialize_macroscopic_fields.") return None, None - def assign_fields_sharded(self, checkpoint=None): + def assign_fields_sharded(self): """ This function is used to initialize the simulation by assigning the macroscopic fields and populations. @@ -362,7 +368,7 @@ def assign_fields_sharded(self, checkpoint=None): shape = (self.nx, self.ny, self.nz, self.lattice.q) if rho0 is None or u0 is None: - f = self.distributed_array_init(shape, self.precisionPolicy.output_dtype, initVal=self.w) + f = self.distributed_array_init(shape, self.precisionPolicy.output_dtype, init_val=self.w) else: f = self.initialize_populations(rho0, u0) @@ -575,13 +581,13 @@ def compute_bitmask_i(b, i): return vmap(compute_bitmask_i, in_axes=(None, 0), out_axes=-1)(b, self.lattice.i_s) @partial(jit, static_argnums=(0, 3), inline=True) - def equilibrium(self, rho, u, castOutput=True): + def equilibrium(self, rho, u, cast_output=True): """ This function computes the equilibrium distribution function in the Lattice Boltzmann Method. The equilibrium distribution function is a function of the macroscopic density and velocity. - The function first casts the density and velocity to the compute precision if the castOutput flag is True. - The function finally casts the equilibrium distribution function to the output precision if the castOutput + The function first casts the density and velocity to the compute precision if the cast_output flag is True. + The function finally casts the equilibrium distribution function to the output precision if the cast_output flag is True. Parameters @@ -590,7 +596,7 @@ def equilibrium(self, rho, u, castOutput=True): The macroscopic density. u: jax.numpy.ndarray The macroscopic velocity. - castOutput: bool, optional + cast_output: bool, optional A flag indicating whether to cast the density, velocity, and equilibrium distribution function to the compute and output precisions. Default is True. @@ -599,8 +605,8 @@ def equilibrium(self, rho, u, castOutput=True): feq: ja.numpy.ndarray The equilibrium distribution function. """ - # Cast the density and velocity to the compute precision if the castOutput flag is True - if castOutput: + # Cast the density and velocity to the compute precision if the cast_output flag is True + if cast_output: rho, u = self.precisionPolicy.cast_to_compute((rho, u)) # Cast c to compute precision so that XLA call FXX matmul, @@ -610,7 +616,7 @@ def equilibrium(self, rho, u, castOutput=True): usqr = 1.5 * jnp.sum(jnp.square(u), axis=-1, keepdims=True) feq = rho * self.w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) - if castOutput: + if cast_output: return self.precisionPolicy.cast_to_output(feq) else: return feq @@ -736,9 +742,6 @@ def step(self, f_poststreaming, timestep, return_fpost=False): return f_poststreaming, f_postcollision else: return f_poststreaming, None - - def checkpoint_manager(self): - pass def run(self, t_max): """ @@ -768,11 +771,11 @@ def run(self, t_max): assert self.mngr is not None, "Checkpoint manager does not exist." state = {'f': f} shardings = jax.tree_map(lambda x: x.sharding, state) - restore_args = checkpoint_utils.construct_restore_args(state, shardings) + restore_args = orb.checkpoint_utils.construct_restore_args(state, shardings) try: f = self.mngr.restore(latest_step, restore_kwargs={'restore_args': restore_args})['f'] print(f"Restored checkpoint at step {latest_step}.") - except: + except ValueError: raise ValueError(f"Failed to restore checkpoint at step {latest_step}.") start_step = latest_step + 1 @@ -940,7 +943,7 @@ def get_force(self): force: jax.numpy.ndarray The force to be applied to the fluid. """ - return + pass @partial(jit, static_argnums=(0,), inline=True) def apply_force(self, f_postcollision, feq, rho, u): @@ -972,8 +975,8 @@ def apply_force(self, f_postcollision, feq, rho, u): Boundary conditions. Physica A, 392, 1925-1930. Krüger, T., et al. (2017). The lattice Boltzmann method. Springer International Publishing, 10.978-3, 4-15. """ - deltaU = self.get_force() - feq_force = self.equilibrium(rho, u + deltaU, castOutput=False) + delta_u = self.get_force() + feq_force = self.equilibrium(rho, u + delta_u, cast_output=False) f_postcollision = f_postcollision + feq_force - feq return f_postcollision diff --git a/src/lattice.py b/src/lattice.py index 3052795..f3c4107 100644 --- a/src/lattice.py +++ b/src/lattice.py @@ -1,6 +1,6 @@ +import re import numpy as np import jax.numpy as jnp -import re class Lattice(object): @@ -75,7 +75,7 @@ def construct_right_indices(self): The indices of the right velocities. """ c = self.c.T - return np.where(c[:, 0] == 1)[0] + return np.nonzero(c[:, 0] == 1)[0] def construct_left_indices(self): """ @@ -88,7 +88,7 @@ def construct_left_indices(self): The indices of the left velocities. """ c = self.c.T - return np.where(c[:, 0] == -1)[0] + return np.nonzero(c[:, 0] == -1)[0] def construct_main_indices(self): """ @@ -103,10 +103,10 @@ def construct_main_indices(self): """ c = self.c.T if self.d == 2: - return np.where((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] + return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] elif self.d == 3: - return np.where((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] + return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] def construct_lattice_velocity(self): """ @@ -165,7 +165,7 @@ def construct_lattice_weight(self): w[0] = 1.0 / 3.0 elif self.name == "D3Q27": cl = np.linalg.norm(c, axis=1) - w[cl == 1.0] = 2.0 / 27.0 + w[np.isclose(cl, 1.0, atol=1e-8)] = 2.0 / 27.0 w[(cl > 1) & (cl <= np.sqrt(2))] = 1.0 / 54.0 w[(cl > np.sqrt(2)) & (cl <= np.sqrt(3))] = 1.0 / 216.0 w[0] = 8.0 / 27.0 diff --git a/src/models.py b/src/models.py index 7e5e825..e9a5d81 100644 --- a/src/models.py +++ b/src/models.py @@ -26,7 +26,7 @@ def collision(self, f): """ f = self.precisionPolicy.cast_to_compute(f) rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, castOutput=False) + feq = self.equilibrium(rho, u, cast_output=False) fneq = f - feq fout = f - self.omega * fneq if self.force is not None: @@ -52,7 +52,7 @@ def collision(self, f): tiny = 1e-32 beta = self.omega * 0.5 rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, castOutput=False) + feq = self.equilibrium(rho, u, cast_output=False) fneq = f - feq if self.dim == 2: deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0 @@ -212,7 +212,7 @@ def collision(self, f): """ f = self.precisionPolicy.cast_to_compute(f) rho =jnp.sum(f, axis=-1, keepdims=True) - feq = self.equilibrium(rho, self.vel, castOutput=False) + feq = self.equilibrium(rho, self.vel, cast_output=False) fneq = f - feq fout = f - self.omega * fneq return self.precisionPolicy.cast_to_output(fout) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index d7b2eb1..d937042 100644 --- a/src/utils.py +++ b/src/utils.py @@ -73,7 +73,7 @@ def save_image(timestep, fld, prefix=None): fname = fname + "_" + str(timestep).zfill(4) if len(fld.shape) > 3: - raise ValueError(f"The input field should be 2D!") + raise ValueError("The input field should be 2D!") elif len(fld.shape) == 3: fld = np.sqrt(fld[..., 0] ** 2 + fld[..., 1] ** 2) From 0bbeb5d43080871ddf64f403f9c7aaf93d299b7b Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Sun, 15 Oct 2023 21:29:06 -0400 Subject: [PATCH 3/6] Updated requirements and installation guide --- README.md | 27 ++++++++++++--------------- requirements.txt | 16 ++++++++-------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 6b66008..4a7bac7 100644 --- a/README.md +++ b/README.md @@ -90,23 +90,22 @@ The following examples showcase the capabilities of XLB: To use XLB, you must first install JAX and other dependencies using the following commands: -```bash -# Please refer to https://github.com/google/jax for the latest installation documentation - -pip install --upgrade pip -# For CPU run -pip install --upgrade "jax[cpu]" +Please refer to https://github.com/google/jax for the latest installation documentation. The following table is taken from [JAX's Github page](https://github.com/google/jax). -# For GPU run +| Hardware | Instructions | +|------------|-----------------------------------------------------------------------------------------------------------------| +| CPU | `pip install -U "jax[cpu]"` | +| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` | +| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` | +| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). | +| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | -# CUDA 12 and cuDNN 8.8 or newer. -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +**Note:** We encountered challenges when executing XLB on Apple GPUs due to the lack of support for certain operations in the Metal backend. We advise using the CPU backend on Mac OS. We will be testing XLB on Apple's GPUs in the future and will update this section accordingly. -# CUDA 11 and cuDNN 8.6 or newer. -pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -# Run dependencies +Install dependencies: +```bash pip install jmp pyvista numpy matplotlib Rtree trimesh jmp ``` @@ -118,6 +117,4 @@ export PYTHONPATH=. python3 examples/cavity2d.py ``` ## Citing XLB -Accompanying publication coming soon: - -**M. Ataei, H. Salehipour**. XLB: Hardware-Accelerated, Scalable, and Differentiable Lattice Boltzmann Simulation Framework based on JAX. TBA +Accompanying paper will be available soon. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2c912ed..0794ff7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -jax==0.4.11 -jaxlib==0.4.11 +jax==0.4.19 +jaxlib==0.4.19 jmp==0.0.4 -matplotlib==3.7.1 -numpy==1.24.2 -pyvista==0.38.5 +matplotlib==3.8.0 +numpy==1.26.1 +pyvista==0.42.3 Rtree==1.0.1 -trimesh==3.20.2 -orbax-checkpoint==0.2.3 -portpicker===1.5.2 +trimesh==4.0.0 +orbax-checkpoint==0.4.1 +portpicker===1.6.0 termcolor==2.3.0 \ No newline at end of file From dcf54ff25b0787c0c322c83e1ee0d267d61172cf Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Mon, 16 Oct 2023 12:38:30 -0400 Subject: [PATCH 4/6] Refactoring the base class to better check the attributes and print them in each run --- examples/CFD/airfoil3d.py | 8 - examples/CFD/cavity2d.py | 7 +- examples/CFD/cavity3d.py | 5 - examples/CFD/channel3d.py | 3 - examples/CFD/couette2d.py | 4 - examples/CFD/cylinder2d.py | 3 - examples/CFD/oscilating_cylinder2d.py | 7 - examples/CFD/taylor_green_vortex.py | 1 - examples/CFD/windtunnel3d.py | 6 - src/base.py | 251 +++++++++++++++++++++++--- src/lattice.py | 4 +- 11 files changed, 227 insertions(+), 72 deletions(-) diff --git a/examples/CFD/airfoil3d.py b/examples/CFD/airfoil3d.py index e601551..db538f6 100644 --- a/examples/CFD/airfoil3d.py +++ b/examples/CFD/airfoil3d.py @@ -32,7 +32,6 @@ import matplotlib.pylab as plt from src.models import BGKSim, KBCSim from src.boundary_conditions import * -from src.lattice import * import numpy as np from src.utils import * from jax.config import config @@ -105,15 +104,11 @@ def output_data(self, **kwargs): airfoil_thickness = 30 airfoil_angle = 20 airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T - precision = 'f32/f32' - lattice = LatticeD3Q27(precision=precision) nx = airfoil.shape[0] ny = airfoil.shape[1] - print("airfoil shape: ", airfoil.shape) - ny = 3 * ny nx = 4 * nx nz = 101 @@ -124,13 +119,11 @@ def output_data(self, **kwargs): visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) os.system('rm -rf ./*.vtk && rm -rf ./*.png') # Set the parameters for the simulation kwargs = { - 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, @@ -141,5 +134,4 @@ def output_data(self, **kwargs): } sim = Airfoil(**kwargs) - print('Domain size: ', sim.nx, sim.ny, sim.nz) sim.run(20000) \ No newline at end of file diff --git a/examples/CFD/cavity2d.py b/examples/CFD/cavity2d.py index a7fae66..4aaa384 100644 --- a/examples/CFD/cavity2d.py +++ b/examples/CFD/cavity2d.py @@ -20,14 +20,12 @@ from jax.config import config from src.utils import * import numpy as np -from src.lattice import LatticeD2Q9 from src.models import BGKSim, KBCSim import jax.numpy as jnp import os # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax class Cavity(KBCSim): def __init__(self, **kwargs): @@ -61,7 +59,6 @@ def output_data(self, **kwargs): if __name__ == "__main__": precision = "f32/f32" - lattice = LatticeD2Q9(precision) nx = 200 ny = 200 @@ -71,16 +68,14 @@ def output_data(self, **kwargs): clength = nx - 1 checkpoint_rate = 1000 - checkpoint_dir = "./checkpoints" + checkpoint_dir = os.path.abspath("./checkpoints") visc = prescribed_vel * clength / Re omega = 1.0 / (3.0 * visc + 0.5) - print("omega = ", omega) os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { - 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/cavity3d.py b/examples/CFD/cavity3d.py index d912786..4137290 100644 --- a/examples/CFD/cavity3d.py +++ b/examples/CFD/cavity3d.py @@ -20,7 +20,6 @@ # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' from src.models import BGKSim, KBCSim -from src.lattice import LatticeD3Q27 import numpy as np from src.utils import * from jax.config import config @@ -68,8 +67,6 @@ def output_data(self, **kwargs): # live_volume_randering(timestep, u_mag) if __name__ == '__main__': - lattice = LatticeD3Q27(precision) - nx = 101 ny = 101 nz = 101 @@ -80,12 +77,10 @@ def output_data(self, **kwargs): visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { - 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/channel3d.py b/examples/CFD/channel3d.py index 366764c..78cb049 100644 --- a/examples/CFD/channel3d.py +++ b/examples/CFD/channel3d.py @@ -115,7 +115,6 @@ def output_data(self, **kwargs): if __name__ == "__main__": precision = "f64/f64" - lattice = LatticeD3Q27(precision) # h: channel half-width h = 10 # Define channel geometry based on h @@ -135,11 +134,9 @@ def output_data(self, **kwargs): zz = np.minimum(zz, zz.max() - zz) yplus = zz * u_tau / visc - print("omega = ", omega) os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { - 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/couette2d.py b/examples/CFD/couette2d.py index e14765c..88316d2 100644 --- a/examples/CFD/couette2d.py +++ b/examples/CFD/couette2d.py @@ -4,7 +4,6 @@ from src.models import BGKSim from src.boundary_conditions import * -from src.lattice import LatticeD2Q9 import jax.numpy as jnp import numpy as np from src.utils import * @@ -49,7 +48,6 @@ def output_data(self, **kwargs): if __name__ == "__main__": precision = "f32/f32" - lattice = LatticeD2Q9(precision) nx = 501 ny = 101 @@ -60,12 +58,10 @@ def output_data(self, **kwargs): visc = prescribed_vel * clength / Re omega = 1.0 / (3.0 * visc + 0.5) - print("omega = ", omega) assert omega < 1.98, "omega must be less than 2.0" os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { - 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/cylinder2d.py b/examples/CFD/cylinder2d.py index 8196378..7f2e0da 100644 --- a/examples/CFD/cylinder2d.py +++ b/examples/CFD/cylinder2d.py @@ -23,7 +23,6 @@ from jax.config import config from src.utils import * import numpy as np -from src.lattice import LatticeD2Q9 from src.models import BGKSim, KBCSim import jax.numpy as jnp import os @@ -95,7 +94,6 @@ def output_data(self, **kwargs): precision = 'f64/f64' prescribed_vel = 0.005 diam = 80 - lattice = LatticeD2Q9(precision) nx = int(22*diam) ny = int(4.1*diam) @@ -111,7 +109,6 @@ def output_data(self, **kwargs): os.system('rm -rf ./*.vtk && rm -rf ./*.png') kwargs = { - 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/oscilating_cylinder2d.py b/examples/CFD/oscilating_cylinder2d.py index 8c7fcc3..200f91a 100644 --- a/examples/CFD/oscilating_cylinder2d.py +++ b/examples/CFD/oscilating_cylinder2d.py @@ -24,7 +24,6 @@ from jax.config import config from src.utils import * import numpy as np -from src.lattice import LatticeD2Q9 from src.models import BGKSim, KBCSim import jax.numpy as jnp import os @@ -118,7 +117,6 @@ def output_data(self, **kwargs): if __name__ == '__main__': precision = 'f64/f64' - lattice = LatticeD2Q9(precision) prescribed_vel = 0.005 diam = 20 @@ -129,13 +127,8 @@ def output_data(self, **kwargs): visc = prescribed_vel * diam / Re omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) - print("Mesh size: ", nx, ny) - print("Number of voxels: ", nx * ny) - os.system('rm -rf ./*.vtk && rm -rf ./*.png') kwargs = { - 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/taylor_green_vortex.py b/examples/CFD/taylor_green_vortex.py index f6142bf..9f6b870 100644 --- a/examples/CFD/taylor_green_vortex.py +++ b/examples/CFD/taylor_green_vortex.py @@ -92,7 +92,6 @@ def output_data(self, **kwargs): visc = vel_ref * nx / Re omega = 1.0 / (3.0 * visc + 0.5) - print("omega = ", omega) os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { 'lattice': lattice, diff --git a/examples/CFD/windtunnel3d.py b/examples/CFD/windtunnel3d.py index 780d04a..79bfebe 100644 --- a/examples/CFD/windtunnel3d.py +++ b/examples/CFD/windtunnel3d.py @@ -18,7 +18,6 @@ from jax.config import config from src.utils import * import numpy as np -from src.lattice import LatticeD3Q19, LatticeD3Q27 from src.models import BGKSim, KBCSim import jax.numpy as jnp import os @@ -109,7 +108,6 @@ def output_data(self, **kwargs): if __name__ == '__main__': precision = 'f32/f32' - lattice = LatticeD3Q27(precision) nx = 601 ny = 351 @@ -122,13 +120,9 @@ def output_data(self, **kwargs): visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) - print("Mesh size: ", nx, ny, nz) - print("Number of voxels: ", nx * ny * nz) os.system('rm -rf ./*.vtk && rm -rf ./*.png') kwargs = { - 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/src/base.py b/src/base.py index 115168f..298071a 100644 --- a/src/base.py +++ b/src/base.py @@ -21,10 +21,9 @@ from functools import partial # Local/Custom Libraries -# from src.boundary_conditions import * +import src.models from src.utils import downsample_field - - +from src.lattice import LatticeD2Q9, LatticeD3Q19, LatticeD3Q27 jax.config.update("jax_spmd_mode", 'allow_all') # Disables annoying TF warnings @@ -42,26 +41,30 @@ class LBMBase(object): ny (int): Number of grid points in the y-direction. nz (int, optional): Number of grid points in the z-direction. Defaults to 0. precision (str, optional): A string specifying the precision used for the simulation. Defaults to "f32/f32". - optimize (bool, optional): Whether or not to run adjoint optimization (not functional yet). Defaults to False. """ def __init__(self, **kwargs): - # Set the precision for computation and storage - precision = kwargs.get("precision", "f32/f32") - computedType, storedType = self.set_precisions(precision) + self.omega = kwargs.get("omega") + self.nx = kwargs.get("nx") + self.ny = kwargs.get("ny") + self.nz = kwargs.get("nz") + + self.precision = kwargs.get("precision") + computedType, storedType = self.set_precisions(self.precision) self.precisionPolicy = jmp.Policy(compute_dtype=computedType, param_dtype=computedType, output_dtype=storedType) - self.optimize = kwargs.get("optimize", False) + self.lattice = kwargs.get("lattice") self.checkpointRate = kwargs.get("checkpoint_rate", 0) self.checkpointDir = kwargs.get("checkpoint_dir", './checkpoints') self.downsamplingFactor = kwargs.get("downsampling_factor", 1) - self.printInfoRate= kwargs.get("print_info_rate", 100) + self.printInfoRate = kwargs.get("print_info_rate", 100) self.ioRate = kwargs.get("io_rate", 0) self.returnFpost = kwargs.get("return_fpost", False) self.computeMLUPS = kwargs.get("compute_MLUPS", False) self.restore_checkpoint = kwargs.get("restore_checkpoint", False) self.nDevices = jax.device_count() + self.backend = jax.default_backend() if self.computeMLUPS: self.restore_checkpoint = False @@ -72,21 +75,7 @@ def __init__(self, **kwargs): # Check for distributed mode if self.nDevices > jax.local_device_count(): print("WARNING: Running in distributed mode. Make sure that jax.distributed.initialize is called before performing any JAX computations.") - print("XLA backend:", jax.default_backend()) - print("Number of XLA devices available: " + colored(f'{self.nDevices}', 'green')) - self.p_i = np.arange(self.nDevices) - - # Set the lattice and relaxation parameter - lattice = kwargs.get("lattice", None) - if lattice is None: - raise ValueError("lattice must be provided") - - omega = kwargs.get("omega", None) - if omega is None: - raise ValueError("omega must be provided") - - self.lattice = lattice - self.omega = omega + self.c = self.lattice.c self.q = self.lattice.q self.w = self.lattice.w @@ -97,7 +86,6 @@ def __init__(self, **kwargs): mngr_options = orb.CheckpointManagerOptions(save_interval_steps=self.checkpointRate, max_to_keep=1) self.mngr = orb.CheckpointManager(self.checkpointDir, orb.PyTreeCheckpointer(), options=mngr_options) else: - print("WARNING: Checkpointing is disabled for this simulation.") self.mngr = None # Adjust the number of grid points in the x direction, if necessary. @@ -113,14 +101,16 @@ def __init__(self, **kwargs): print("WARNING: nx increased from {} to {} in order to accommodate domain sharding per XLA device.".format(nx, self.nx)) self.ny = ny self.nz = nz + + self.show_simulation_parameters() # Store grid information self.gridInfo = { "nx": self.nx, "ny": self.ny, "nz": self.nz, - "dim": lattice.d, - "lattice": lattice + "dim": self.lattice.d, + "lattice": self.lattice } P = PartitionSpec @@ -130,7 +120,6 @@ def __init__(self, **kwargs): # Define the left permutation self.leftPerm = [((i + 1) % self.nDevices, i) for i in range(self.nDevices)] - # Set up the sharding and streaming for 2D and 3D simulations if self.dim == 2: self.devices = mesh_utils.create_device_mesh((self.nDevices, 1, 1)) @@ -163,6 +152,212 @@ def __init__(self, **kwargs): self._create_boundary_data() self.force = self.get_force() + @property + def lattice(self): + return self._lattice + + @lattice.setter + def lattice(self, value): + if value is None: + if isinstance(self, src.models.BGKSim): + lattice_class = LatticeD2Q9 if self.nz == 0 else LatticeD3Q19 + elif isinstance(self, src.models.KBCSim): + lattice_class = LatticeD2Q9 if self.nz == 0 else LatticeD3Q27 + else: + # Default values for other base classes (e.g., advection-diffusion) + lattice_class = LatticeD2Q9 if self.nz == 0 else LatticeD3Q19 + + value = lattice_class(self._precision) + + self._lattice = value + + @property + def omega(self): + return self._omega + + @omega.setter + def omega(self, value): + if value is None: + raise ValueError("omega must be provided") + if not isinstance(value, float): + raise TypeError("omega must be a float") + self._omega = value + + @property + def nx(self): + return self._nx + + @nx.setter + def nx(self, value): + if value is None: + raise ValueError("nx must be provided") + if not isinstance(value, int): + raise TypeError("nx must be an integer") + self._nx = value + + @property + def ny(self): + return self._ny + + @ny.setter + def ny(self, value): + if value is None: + raise ValueError("ny must be provided") + if not isinstance(value, int): + raise TypeError("ny must be an integer") + self._ny = value + + @property + def nz(self): + return self._nz + + @nz.setter + def nz(self, value): + if value is None: + raise ValueError("nz must be provided") + if not isinstance(value, int): + raise TypeError("nz must be an integer") + self._nz = value + + @property + def precision(self): + return self._precision + + @precision.setter + def precision(self, value): + if not isinstance(value, str): + raise TypeError("precision must be a string") + self._precision = value + + @property + def checkpointRate(self): + return self._checkpointRate + + @checkpointRate.setter + def checkpointRate(self, value): + if not isinstance(value, int): + raise TypeError("checkpointRate must be an integer") + self._checkpointRate = value + + @property + def checkpointDir(self): + return self._checkpointDir + + @checkpointDir.setter + def checkpointDir(self, value): + if not isinstance(value, str): + raise TypeError("checkpointDir must be a string") + self._checkpointDir = value + + @property + def downsamplingFactor(self): + return self._downsamplingFactor + + @downsamplingFactor.setter + def downsamplingFactor(self, value): + if not isinstance(value, int): + raise TypeError("downsamplingFactor must be an integer") + self._downsamplingFactor = value + + @property + def printInfoRate(self): + return self._printInfoRate + + @printInfoRate.setter + def printInfoRate(self, value): + if not isinstance(value, int): + raise TypeError("printInfoRate must be an integer") + self._printInfoRate = value + + @property + def ioRate(self): + return self._ioRate + + @ioRate.setter + def ioRate(self, value): + if not isinstance(value, int): + raise TypeError("ioRate must be an integer") + self._ioRate = value + + @property + def returnFpost(self): + return self._returnFpost + + @returnFpost.setter + def returnFpost(self, value): + if not isinstance(value, bool): + raise TypeError("returnFpost must be a boolean") + self._returnFpost = value + + @property + def computeMLUPS(self): + return self._computeMLUPS + + @computeMLUPS.setter + def computeMLUPS(self, value): + if not isinstance(value, bool): + raise TypeError("computeMLUPS must be a boolean") + self._computeMLUPS = value + + @property + def restore_checkpoint(self): + return self._restore_checkpoint + + @restore_checkpoint.setter + def restore_checkpoint(self, value): + if not isinstance(value, bool): + raise TypeError("restore_checkpoint must be a boolean") + self._restore_checkpoint = value + + @property + def nDevices(self): + return self._nDevices + + @nDevices.setter + def nDevices(self, value): + if not isinstance(value, int): + raise TypeError("nDevices must be an integer") + self._nDevices = value + + def show_simulation_parameters(self): + attributes_to_show = [ + 'omega', 'nx', 'ny', 'nz', 'dim', 'precision', 'lattice', + 'checkpointRate', 'checkpointDir', 'downsamplingFactor', + 'printInfoRate', 'ioRate', 'computeMLUPS', + 'restore_checkpoint', 'backend', 'nDevices' + ] + + descriptive_names = { + 'omega': 'Omega', + 'nx': 'Grid Points in X', + 'ny': 'Grid Points in Y', + 'nz': 'Grid Points in Z', + 'dim': 'Dimensionality', + 'precision': 'Precision Policy', + 'lattice': 'Lattice Type', + 'checkpointRate': 'Checkpoint Rate', + 'checkpointDir': 'Checkpoint Directory', + 'downsamplingFactor': 'Downsampling Factor', + 'printInfoRate': 'Print Info Rate', + 'ioRate': 'I/O Rate', + 'computeMLUPS': 'Compute MLUPS', + 'restore_checkpoint': 'Restore Checkpoint', + 'backend': 'Backend', + 'nDevices': 'Number of Devices' + } + simulation_name = self.__class__.__name__ + + print(colored(f'**** Simulation Parameters for {simulation_name} ****', 'green')) + + header = f"{colored('Parameter', 'blue'):>30} | {colored('Value', 'yellow')}" + print(header) + print('-' * 50) + + for attr in attributes_to_show: + value = getattr(self, attr, 'Attribute not set') + descriptive_name = descriptive_names.get(attr, attr) # Use the attribute name as a fallback + row = f"{colored(descriptive_name, 'blue'):>30} | {colored(value, 'yellow')}" + print(row) def _create_boundary_data(self): """ diff --git a/src/lattice.py b/src/lattice.py index f3c4107..788796b 100644 --- a/src/lattice.py +++ b/src/lattice.py @@ -202,7 +202,9 @@ def construct_lattice_moment(self): cntr += 1 return cc - + + def __str__(self): + return self.name class LatticeD2Q9(Lattice): """ From 9caa97b963b48fa05e297cf6105480455be5c905 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Mon, 16 Oct 2023 15:05:54 -0400 Subject: [PATCH 5/6] Added Lattice import to all the examples and changed the property setter errors, as well as fixed a bug in the distributed processing. Removed portpicker as a dependency as it can't be used for multi-process computations. --- examples/CFD/airfoil3d.py | 4 +++ examples/CFD/cavity2d.py | 10 +++++--- examples/CFD/cavity3d.py | 11 +++++--- examples/CFD/channel3d.py | 1 + examples/CFD/couette2d.py | 11 +++++--- examples/CFD/cylinder2d.py | 16 +++++++----- examples/CFD/oscilating_cylinder2d.py | 15 ++++++----- examples/CFD/taylor_green_vortex.py | 16 ++++++------ examples/CFD/windtunnel3d.py | 16 +++++++----- examples/performance/MLUPS2d.py | 7 ++++-- examples/performance/MLUPS3d.py | 16 ++++++------ examples/performance/MLUPS3d_distributed.py | 28 ++++++++++++--------- requirements.txt | 1 - src/base.py | 16 +++++------- 14 files changed, 102 insertions(+), 66 deletions(-) diff --git a/examples/CFD/airfoil3d.py b/examples/CFD/airfoil3d.py index db538f6..c33879a 100644 --- a/examples/CFD/airfoil3d.py +++ b/examples/CFD/airfoil3d.py @@ -31,6 +31,7 @@ # from IPython import display import matplotlib.pylab as plt from src.models import BGKSim, KBCSim +from src.lattice import LatticeD3Q19, LatticeD3Q27 from src.boundary_conditions import * import numpy as np from src.utils import * @@ -106,6 +107,8 @@ def output_data(self, **kwargs): airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T precision = 'f32/f32' + lattice = LatticeD3Q27(precision) + nx = airfoil.shape[0] ny = airfoil.shape[1] @@ -124,6 +127,7 @@ def output_data(self, **kwargs): # Set the parameters for the simulation kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/cavity2d.py b/examples/CFD/cavity2d.py index 4aaa384..5692027 100644 --- a/examples/CFD/cavity2d.py +++ b/examples/CFD/cavity2d.py @@ -16,14 +16,16 @@ 4. Visualization: The simulation outputs data in VTK format for visualization. It also provides images of the velocity field and saves the boundary conditions at each time step. The data can be visualized using software like Paraview. """ -from src.boundary_conditions import * from jax.config import config -from src.utils import * import numpy as np -from src.models import BGKSim, KBCSim import jax.numpy as jnp import os +from src.boundary_conditions import * +from src.models import BGKSim, KBCSim +from src.lattice import LatticeD2Q9 +from src.utils import * + # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' @@ -59,6 +61,7 @@ def output_data(self, **kwargs): if __name__ == "__main__": precision = "f32/f32" + lattice = LatticeD2Q9(precision) nx = 200 ny = 200 @@ -76,6 +79,7 @@ def output_data(self, **kwargs): os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/cavity3d.py b/examples/CFD/cavity3d.py index 4137290..58db262 100644 --- a/examples/CFD/cavity3d.py +++ b/examples/CFD/cavity3d.py @@ -19,13 +19,14 @@ # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -from src.models import BGKSim, KBCSim + import numpy as np from src.utils import * from jax.config import config -from src.boundary_conditions import * -precision = 'f32/f32' +from src.models import BGKSim, KBCSim +from src.lattice import LatticeD3Q19, LatticeD3Q27 +from src.boundary_conditions import * class Cavity(KBCSim): def __init__(self, **kwargs): @@ -75,12 +76,16 @@ def output_data(self, **kwargs): prescribed_vel = 0.1 clength = nx - 1 + precision = 'f32/f32' + lattice = LatticeD3Q27(precision) + visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/channel3d.py b/examples/CFD/channel3d.py index c9a4249..e1a0cec 100644 --- a/examples/CFD/channel3d.py +++ b/examples/CFD/channel3d.py @@ -144,6 +144,7 @@ def output_data(self, **kwargs): os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/couette2d.py b/examples/CFD/couette2d.py index 88316d2..05b60c5 100644 --- a/examples/CFD/couette2d.py +++ b/examples/CFD/couette2d.py @@ -2,13 +2,16 @@ This script performs a 2D simulation of Couette flow using the lattice Boltzmann method (LBM). """ -from src.models import BGKSim -from src.boundary_conditions import * +import os import jax.numpy as jnp import numpy as np from src.utils import * from jax.config import config -import os + + +from src.models import BGKSim +from src.boundary_conditions import * +from src.lattice import LatticeD2Q9 # config.update('jax_disable_jit', True) # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' @@ -48,6 +51,7 @@ def output_data(self, **kwargs): if __name__ == "__main__": precision = "f32/f32" + lattice = LatticeD2Q9(precision) nx = 501 ny = 101 @@ -62,6 +66,7 @@ def output_data(self, **kwargs): os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/cylinder2d.py b/examples/CFD/cylinder2d.py index 7f2e0da..9fa9779 100644 --- a/examples/CFD/cylinder2d.py +++ b/examples/CFD/cylinder2d.py @@ -17,19 +17,20 @@ 5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView. """ - +import os +import jax from time import time -from src.boundary_conditions import * from jax.config import config -from src.utils import * import numpy as np -from src.models import BGKSim, KBCSim import jax.numpy as jnp -import os + +from src.utils import * +from src.boundary_conditions import * +from src.models import BGKSim, KBCSim +from src.lattice import LatticeD2Q9 # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax jax.config.update('jax_enable_x64', True) class Cylinder(KBCSim): @@ -92,6 +93,8 @@ def output_data(self, **kwargs): if __name__ == '__main__': precision = 'f64/f64' + lattice = LatticeD2Q9(precision) + prescribed_vel = 0.005 diam = 80 @@ -109,6 +112,7 @@ def output_data(self, **kwargs): os.system('rm -rf ./*.vtk && rm -rf ./*.png') kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/oscilating_cylinder2d.py b/examples/CFD/oscilating_cylinder2d.py index 200f91a..f6db4d4 100644 --- a/examples/CFD/oscilating_cylinder2d.py +++ b/examples/CFD/oscilating_cylinder2d.py @@ -19,18 +19,20 @@ """ +import os +import jax from time import time -from src.boundary_conditions import * from jax.config import config -from src.utils import * import numpy as np -from src.models import BGKSim, KBCSim import jax.numpy as jnp -import os + +from src.utils import * +from src.boundary_conditions import * +from src.models import BGKSim, KBCSim +from src.lattice import LatticeD2Q9 # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax jax.config.update('jax_enable_x64', True) class Cylinder(KBCSim): @@ -117,7 +119,7 @@ def output_data(self, **kwargs): if __name__ == '__main__': precision = 'f64/f64' - + lattice = LatticeD2Q9(precision) prescribed_vel = 0.005 diam = 20 nx = int(22*diam) @@ -129,6 +131,7 @@ def output_data(self, **kwargs): os.system('rm -rf ./*.vtk && rm -rf ./*.png') kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/CFD/taylor_green_vortex.py b/examples/CFD/taylor_green_vortex.py index e593e6b..e52b7d3 100644 --- a/examples/CFD/taylor_green_vortex.py +++ b/examples/CFD/taylor_green_vortex.py @@ -5,18 +5,20 @@ """ -from src.boundary_conditions import * -from src.utils import * -import numpy as np -from src.lattice import LatticeD2Q9 -from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK import os -import matplotlib.pyplot as plt import json +import jax +import numpy as np +import matplotlib.pyplot as plt + +from src.utils import * +from src.boundary_conditions import * +from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK +from src.lattice import LatticeD2Q9 + # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax # disable JIt compilation jax.config.update('jax_enable_x64', True) diff --git a/examples/CFD/windtunnel3d.py b/examples/CFD/windtunnel3d.py index 79bfebe..1c78951 100644 --- a/examples/CFD/windtunnel3d.py +++ b/examples/CFD/windtunnel3d.py @@ -12,19 +12,21 @@ """ -from time import time +import os +import jax import trimesh -from src.boundary_conditions import * +from time import time +import numpy as np +import jax.numpy as jnp from jax.config import config + from src.utils import * -import numpy as np from src.models import BGKSim, KBCSim -import jax.numpy as jnp -import os +from src.lattice import LatticeD3Q19, LatticeD3Q27 +from src.boundary_conditions import * # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax # disable JIt compilation @@ -108,6 +110,7 @@ def output_data(self, **kwargs): if __name__ == '__main__': precision = 'f32/f32' + lattice = LatticeD3Q19(precision) nx = 601 ny = 351 @@ -123,6 +126,7 @@ def output_data(self, **kwargs): os.system('rm -rf ./*.vtk && rm -rf ./*.png') kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': nx, 'ny': ny, diff --git a/examples/performance/MLUPS2d.py b/examples/performance/MLUPS2d.py index ec7dd48..5eb7e86 100644 --- a/examples/performance/MLUPS2d.py +++ b/examples/performance/MLUPS2d.py @@ -3,14 +3,15 @@ """ import os +import argparse import jax.numpy as jnp import numpy as np -from src.utils import * from jax.config import config from time import time -import argparse +from src.utils import * from src.boundary_conditions import * +from src.lattice import LatticeD2Q9 from src.models import BGKSim class Cavity(BGKSim): @@ -34,6 +35,7 @@ def set_boundary_conditions(self): if __name__ == '__main__': precision = 'f32/f32' + lattice = LatticeD2Q9(precision) parser = argparse.ArgumentParser("simple_example") parser.add_argument("N", help="The total number of voxels will be NxN", type=int) @@ -51,6 +53,7 @@ def set_boundary_conditions(self): print('omega = ', omega) kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': n, 'ny': n, diff --git a/examples/performance/MLUPS3d.py b/examples/performance/MLUPS3d.py index 6492928..164afe6 100644 --- a/examples/performance/MLUPS3d.py +++ b/examples/performance/MLUPS3d.py @@ -2,21 +2,22 @@ This script computes the MLUPS (Million Lattice Updates per Second) in 3D by simulating fluid flow inside a 2D cavity. """ -from src.models import BGKSim +import os +import argparse + +import jax import jax.numpy as jnp import numpy as np -from src.utils import * from jax.config import config -import os from time import time -import argparse -import jax #config.update('jax_disable_jit', True) # Use 8 CPU devices #os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' #config.update("jax_enable_x64", True) +from src.utils import * from src.boundary_conditions import * - +from src.models import BGKSim +from src.lattice import LatticeD3Q19 class Cavity(BGKSim): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -37,7 +38,7 @@ def set_boundary_conditions(self): if __name__ == '__main__': precision = 'f32/f32' - + lattice = LatticeD3Q19(precision) # Create a parser that will read the command line arguments parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation") parser.add_argument("N", help="The total number of voxels all directions. The final dimension will be N*NxN", default=100, type=int) @@ -60,6 +61,7 @@ def set_boundary_conditions(self): omega = 1.0 / (3. * visc + 0.5) kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': n, 'ny': n, diff --git a/examples/performance/MLUPS3d_distributed.py b/examples/performance/MLUPS3d_distributed.py index 36d30bf..4a418db 100644 --- a/examples/performance/MLUPS3d_distributed.py +++ b/examples/performance/MLUPS3d_distributed.py @@ -8,10 +8,19 @@ # Standard Libraries import argparse import os -from time import time -import portpicker - import jax +# Initialize JAX distributed. The IP, number of processes and process id must be updated. +# Currently set on local host for testing purposes. +# Can be tested on a two GPU system as follows: +# (export PYTHONPATH=.; CUDA_VISIBLE_DEVICES=0 python3 examples/performance/MLUPS3d_distributed.py 100 100 & CUDA_VISIBLE_DEVICES=1 python3 examples/performance/MLUPS3d_distributed.py 100 100 &) +#IMPORTANT: jax distributed must be initialized before any jax computation is performed +jax.distributed.initialize(f'127.0.0.1:1234', 2, process_id=int(os.environ['CUDA_VISIBLE_DEVICES'])) + +print('Process id: ', jax.process_index()) +print('Number of total devices (over all processes): ', jax.device_count()) +print('Number of local devices:', jax.local_device_count()) + + import jax.numpy as jnp import numpy as np @@ -19,6 +28,7 @@ from src.boundary_conditions import * from src.models import BGKSim +from src.lattice import LatticeD3Q19 from src.utils import * #config.update('jax_disable_jit', True) @@ -26,8 +36,6 @@ #os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' #config.update("jax_enable_x64", True) -precision = 'f32/f32' - class Cavity(BGKSim): def __init__(self, **kwargs): @@ -48,13 +56,8 @@ def set_boundary_conditions(self): self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) if __name__ == '__main__': - - # Initialize JAX distributed. The IP, number of processes and process id must be updated. - # Currently set on local host for testing purposes. - # Can be tested with - # (export PYTHONPATH=.; CUDA_VISIBLE_DEVICES=0 python3 examples/performance/MLUPS3d_distributed.py 100 100 & CUDA_VISIBLE_DEVICES=1 python3 examples/performance/MLUPS3d_distributed.py 100 100 &) - port = portpicker.pick_unused_port() - jax.distributed.initialize(f'127.0.0.1:1234', 2, int(os.environ['CUDA_VISIBLE_DEVICES'])) + precision = 'f32/f32' + lattice = LatticeD3Q19(precision) # Create a parser that will read the command line arguments parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation") @@ -79,6 +82,7 @@ def set_boundary_conditions(self): # Create a new instance of the Cavity class kwargs = { + 'lattice': lattice, 'omega': omega, 'nx': n, 'ny': n, diff --git a/requirements.txt b/requirements.txt index 0794ff7..bc453d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,4 @@ pyvista==0.42.3 Rtree==1.0.1 trimesh==4.0.0 orbax-checkpoint==0.4.1 -portpicker===1.6.0 termcolor==2.3.0 \ No newline at end of file diff --git a/src/base.py b/src/base.py index 298071a..9cabda7 100644 --- a/src/base.py +++ b/src/base.py @@ -159,16 +159,12 @@ def lattice(self): @lattice.setter def lattice(self, value): if value is None: - if isinstance(self, src.models.BGKSim): - lattice_class = LatticeD2Q9 if self.nz == 0 else LatticeD3Q19 - elif isinstance(self, src.models.KBCSim): - lattice_class = LatticeD2Q9 if self.nz == 0 else LatticeD3Q27 - else: - # Default values for other base classes (e.g., advection-diffusion) - lattice_class = LatticeD2Q9 if self.nz == 0 else LatticeD3Q19 - - value = lattice_class(self._precision) - + raise ValueError("Lattice type must be provided.") + if self.nz == 0 and not isinstance(value, LatticeD2Q9): + raise ValueError("For 2D simulations, lattice type must be LatticeD2Q9.") + if self.nz != 0 and isinstance(self, src.models.KBCSim) and not isinstance(value, LatticeD3Q27): + raise ValueError("For 3D KBC simulations, lattice type must be LatticeD3Q19,") + self._lattice = value @property From f001d394f3a492caf4c39364bbb3adfce01df889 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Wed, 18 Oct 2023 02:19:51 -0400 Subject: [PATCH 6/6] Fixed the issues and better handle the lattice cases --- examples/CFD/windtunnel3d.py | 2 +- src/base.py | 7 +++---- src/models.py | 3 ++- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/CFD/windtunnel3d.py b/examples/CFD/windtunnel3d.py index 1c78951..2f94b60 100644 --- a/examples/CFD/windtunnel3d.py +++ b/examples/CFD/windtunnel3d.py @@ -110,7 +110,7 @@ def output_data(self, **kwargs): if __name__ == '__main__': precision = 'f32/f32' - lattice = LatticeD3Q19(precision) + lattice = LatticeD3Q27(precision) nx = 601 ny = 351 diff --git a/src/base.py b/src/base.py index 9cabda7..b359b6a 100644 --- a/src/base.py +++ b/src/base.py @@ -23,7 +23,6 @@ # Local/Custom Libraries import src.models from src.utils import downsample_field -from src.lattice import LatticeD2Q9, LatticeD3Q19, LatticeD3Q27 jax.config.update("jax_spmd_mode", 'allow_all') # Disables annoying TF warnings @@ -160,10 +159,10 @@ def lattice(self): def lattice(self, value): if value is None: raise ValueError("Lattice type must be provided.") - if self.nz == 0 and not isinstance(value, LatticeD2Q9): + if self.nz == 0 and value.name not in ['D2Q9']: raise ValueError("For 2D simulations, lattice type must be LatticeD2Q9.") - if self.nz != 0 and isinstance(self, src.models.KBCSim) and not isinstance(value, LatticeD3Q27): - raise ValueError("For 3D KBC simulations, lattice type must be LatticeD3Q19,") + if self.nz != 0 and value.name not in ['D3Q19', 'D3Q27']: + raise ValueError("For 3D simulations, lattice type must be LatticeD3Q19, or LatticeD3Q27.") self._lattice = value diff --git a/src/models.py b/src/models.py index e9a5d81..a7af7b6 100644 --- a/src/models.py +++ b/src/models.py @@ -39,8 +39,9 @@ class KBCSim(LBMBase): This class implements the Karlin-Bösch-Chikatamarla (KBC) model for the collision step in the Lattice Boltzmann Method. """ - def __init__(self, **kwargs): + if kwargs.get('lattice').name != 'D3Q27' and kwargs.get('nz') > 0: + raise ValueError("KBC collision operator in 3D must only be used with D3Q27 lattice.") super().__init__(**kwargs) @partial(jit, static_argnums=(0,), donate_argnums=(1,))