diff --git a/examples/CFD/cavity2d.py b/examples/CFD/cavity2d.py index 5692027..6673390 100644 --- a/examples/CFD/cavity2d.py +++ b/examples/CFD/cavity2d.py @@ -89,7 +89,7 @@ def output_data(self, **kwargs): 'print_info_rate': 100, 'checkpoint_rate': checkpoint_rate, 'checkpoint_dir': checkpoint_dir, - 'restore_checkpoint': True, + 'restore_checkpoint': False, } sim = Cavity(**kwargs) diff --git a/src/base.py b/src/base.py index b359b6a..fd7d32e 100644 --- a/src/base.py +++ b/src/base.py @@ -128,9 +128,6 @@ def __init__(self, **kwargs): self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh, in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False)) - self.compute_bitmask = jit(shard_map(self.compute_bitmask_m, mesh=self.mesh, - in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False)) - # Set up the sharding and streaming for 2D and 3D simulations elif self.dim == 3: self.devices = mesh_utils.create_device_mesh((self.nDevices, 1, 1, 1)) @@ -139,9 +136,7 @@ def __init__(self, **kwargs): self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh, in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False)) - - self.compute_bitmask = jit(shard_map(self.compute_bitmask_m, mesh=self.mesh, - in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False)) + else: raise ValueError(f"dim = {self.dim} not supported") @@ -440,7 +435,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels): solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y) connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True) - connectivity_bitmask = self.compute_bitmask(connectivity_bitmask) + connectivity_bitmask = self.streaming(connectivity_bitmask) return lax.with_sharding_constraint(connectivity_bitmask, self.sharding) elif self.dim == 3: @@ -451,7 +446,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels): solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y) solid_halo_voxels = solid_halo_voxels.at[:, 2].add(hw_z) connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True) - connectivity_bitmask = self.compute_bitmask(connectivity_bitmask) + connectivity_bitmask = self.streaming(connectivity_bitmask) return lax.with_sharding_constraint(connectivity_bitmask, self.sharding) def bounding_box_indices(self): @@ -687,88 +682,6 @@ def streaming_i(f, c): return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) return vmap(streaming_i, in_axes=(-1, 0), out_axes=-1)(f, self.c.T) - - def compute_bitmask_m(self, b): - """ - This function computes a bitmask for each direction in the lattice. The bitmask is used to - determine which nodes are fluid nodes and which are boundary nodes. - - To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the - distribution functions that need to be communicated to the neighboring processes. - - The function then sends the left boundary slice to the right neighboring process and the right - boundary slice to the left neighboring process. The received data is then set to the - corresponding indices in the receiving domain. - - Parameters - ---------- - b: jax.numpy.ndarray - The array holding the bitmasks for the simulation. - - Returns - ------- - jax.numpy.ndarray - The bitmasks after the streaming operation. - """ - b = self.compute_bitmask_p(b) - left_comm, right_comm = b[:1, ..., self.lattice.right_indices], b[-1:, ..., self.lattice.left_indices] - - left_comm, right_comm = self.send_right(left_comm, 'x'), self.send_left(right_comm, 'x') - b = b.at[:1, ..., self.lattice.right_indices].set(left_comm) - b = b.at[-1:, ..., self.lattice.left_indices].set(right_comm) - return b - - def compute_bitmask_p(self, b): - """ - This function computes a bitmask for each direction in the lattice. The bitmask is used to - determine which nodes are fluid nodes and which are boundary nodes. - - It does this by rolling the input bitmask (b) in the opposite direction of each lattice - direction. The rolling operation shifts the values of the bitmask along the specified axes. - - The function uses the vmap operation provided by the JAX library to vectorize the computation - over all lattice directions. - - Parameters - ---------- - b: ndarray - The input bitmask. - - Returns - ------- - jax.numpy.ndarray - The computed bitmask for each direction in the lattice. - """ - def compute_bitmask_i(b, i): - """ - This function computes the bitmask for a specific direction in the lattice. - - It does this by rolling the input bitmask (b) in the opposite direction of the specified - lattice direction. The rolling operation shifts the values of the bitmask along the - specified axes. - - Parameters - ---------- - b: jax.numpy.ndarray - The input bitmask. - i: int - The index of the lattice direction. - - Returns - ------- - jax.numpy.ndarray - The computed bitmask for the specified direction in the lattice. - """ - if self.dim == 2: - rolls = (self.c.T[i, 0], self.c.T[i, 1]) - axes = (0, 1) - return jnp.roll(b[..., self.lattice.opp_indices[i]], rolls, axes) - elif self.dim == 3: - rolls = (self.c.T[i, 0], self.c.T[i, 1], self.c.T[i, 2]) - axes = (0, 1, 2) - return jnp.roll(b[..., self.lattice.opp_indices[i]], rolls, axes) - - return vmap(compute_bitmask_i, in_axes=(None, 0), out_axes=-1)(b, self.lattice.i_s) @partial(jit, static_argnums=(0, 3), inline=True) def equilibrium(self, rho, u, cast_output=True):