Skip to content

Commit

Permalink
Merge pull request #22 from hsalehipour/main
Browse files Browse the repository at this point in the history
compute_bitmask was redundant.
  • Loading branch information
mehdiataei authored Nov 20, 2023
2 parents 30f6def + c62a420 commit 7b0555c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 91 deletions.
2 changes: 1 addition & 1 deletion examples/CFD/cavity2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def output_data(self, **kwargs):
'print_info_rate': 100,
'checkpoint_rate': checkpoint_rate,
'checkpoint_dir': checkpoint_dir,
'restore_checkpoint': True,
'restore_checkpoint': False,
}

sim = Cavity(**kwargs)
Expand Down
93 changes: 3 additions & 90 deletions src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ def __init__(self, **kwargs):
self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh,
in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False))

self.compute_bitmask = jit(shard_map(self.compute_bitmask_m, mesh=self.mesh,
in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False))

# Set up the sharding and streaming for 2D and 3D simulations
elif self.dim == 3:
self.devices = mesh_utils.create_device_mesh((self.nDevices, 1, 1, 1))
Expand All @@ -139,9 +136,7 @@ def __init__(self, **kwargs):

self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh,
in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False))

self.compute_bitmask = jit(shard_map(self.compute_bitmask_m, mesh=self.mesh,
in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False))

else:
raise ValueError(f"dim = {self.dim} not supported")

Expand Down Expand Up @@ -440,7 +435,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels):
solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y)
connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True)

connectivity_bitmask = self.compute_bitmask(connectivity_bitmask)
connectivity_bitmask = self.streaming(connectivity_bitmask)
return lax.with_sharding_constraint(connectivity_bitmask, self.sharding)

elif self.dim == 3:
Expand All @@ -451,7 +446,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels):
solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y)
solid_halo_voxels = solid_halo_voxels.at[:, 2].add(hw_z)
connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True)
connectivity_bitmask = self.compute_bitmask(connectivity_bitmask)
connectivity_bitmask = self.streaming(connectivity_bitmask)
return lax.with_sharding_constraint(connectivity_bitmask, self.sharding)

def bounding_box_indices(self):
Expand Down Expand Up @@ -687,88 +682,6 @@ def streaming_i(f, c):
return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2))

return vmap(streaming_i, in_axes=(-1, 0), out_axes=-1)(f, self.c.T)

def compute_bitmask_m(self, b):
"""
This function computes a bitmask for each direction in the lattice. The bitmask is used to
determine which nodes are fluid nodes and which are boundary nodes.
To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the
distribution functions that need to be communicated to the neighboring processes.
The function then sends the left boundary slice to the right neighboring process and the right
boundary slice to the left neighboring process. The received data is then set to the
corresponding indices in the receiving domain.
Parameters
----------
b: jax.numpy.ndarray
The array holding the bitmasks for the simulation.
Returns
-------
jax.numpy.ndarray
The bitmasks after the streaming operation.
"""
b = self.compute_bitmask_p(b)
left_comm, right_comm = b[:1, ..., self.lattice.right_indices], b[-1:, ..., self.lattice.left_indices]

left_comm, right_comm = self.send_right(left_comm, 'x'), self.send_left(right_comm, 'x')
b = b.at[:1, ..., self.lattice.right_indices].set(left_comm)
b = b.at[-1:, ..., self.lattice.left_indices].set(right_comm)
return b

def compute_bitmask_p(self, b):
"""
This function computes a bitmask for each direction in the lattice. The bitmask is used to
determine which nodes are fluid nodes and which are boundary nodes.
It does this by rolling the input bitmask (b) in the opposite direction of each lattice
direction. The rolling operation shifts the values of the bitmask along the specified axes.
The function uses the vmap operation provided by the JAX library to vectorize the computation
over all lattice directions.
Parameters
----------
b: ndarray
The input bitmask.
Returns
-------
jax.numpy.ndarray
The computed bitmask for each direction in the lattice.
"""
def compute_bitmask_i(b, i):
"""
This function computes the bitmask for a specific direction in the lattice.
It does this by rolling the input bitmask (b) in the opposite direction of the specified
lattice direction. The rolling operation shifts the values of the bitmask along the
specified axes.
Parameters
----------
b: jax.numpy.ndarray
The input bitmask.
i: int
The index of the lattice direction.
Returns
-------
jax.numpy.ndarray
The computed bitmask for the specified direction in the lattice.
"""
if self.dim == 2:
rolls = (self.c.T[i, 0], self.c.T[i, 1])
axes = (0, 1)
return jnp.roll(b[..., self.lattice.opp_indices[i]], rolls, axes)
elif self.dim == 3:
rolls = (self.c.T[i, 0], self.c.T[i, 1], self.c.T[i, 2])
axes = (0, 1, 2)
return jnp.roll(b[..., self.lattice.opp_indices[i]], rolls, axes)

return vmap(compute_bitmask_i, in_axes=(None, 0), out_axes=-1)(b, self.lattice.i_s)

@partial(jit, static_argnums=(0, 3), inline=True)
def equilibrium(self, rho, u, cast_output=True):
Expand Down

0 comments on commit 7b0555c

Please sign in to comment.