Skip to content

Commit

Permalink
Merge pull request #18 from mehdiataei/main
Browse files Browse the repository at this point in the history
Refactoring in the base class and conditions for setters
  • Loading branch information
mehdiataei authored Oct 18, 2023
2 parents 3cd7fd0 + f001d39 commit 2180c56
Show file tree
Hide file tree
Showing 18 changed files with 396 additions and 214 deletions.
27 changes: 12 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,22 @@ The following examples showcase the capabilities of XLB:

To use XLB, you must first install JAX and other dependencies using the following commands:

```bash
# Please refer to https://github.com/google/jax for the latest installation documentation

pip install --upgrade pip

# For CPU run
pip install --upgrade "jax[cpu]"
Please refer to https://github.com/google/jax for the latest installation documentation. The following table is taken from [JAX's Github page](https://github.com/google/jax).

# For GPU run
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

# CUDA 12 and cuDNN 8.8 or newer.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**Note:** We encountered challenges when executing XLB on Apple GPUs due to the lack of support for certain operations in the Metal backend. We advise using the CPU backend on Mac OS. We will be testing XLB on Apple's GPUs in the future and will update this section accordingly.

# CUDA 11 and cuDNN 8.6 or newer.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Run dependencies
Install dependencies:
```bash
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp
```

Expand All @@ -118,6 +117,4 @@ export PYTHONPATH=.
python3 examples/cavity2d.py
```
## Citing XLB
Accompanying publication coming soon:

**M. Ataei, H. Salehipour**. XLB: Hardware-Accelerated, Scalable, and Differentiable Lattice Boltzmann Simulation Framework based on JAX. TBA
Accompanying paper will be available soon.
10 changes: 3 additions & 7 deletions examples/CFD/airfoil3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
# from IPython import display
import matplotlib.pylab as plt
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *
from src.lattice import *
import numpy as np
from src.utils import *
from jax.config import config
Expand Down Expand Up @@ -105,15 +105,13 @@ def output_data(self, **kwargs):
airfoil_thickness = 30
airfoil_angle = 20
airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T

precision = 'f32/f32'
lattice = LatticeD3Q27(precision=precision)

lattice = LatticeD3Q27(precision)

nx = airfoil.shape[0]
ny = airfoil.shape[1]

print("airfoil shape: ", airfoil.shape)

ny = 3 * ny
nx = 4 * nx
nz = 101
Expand All @@ -124,7 +122,6 @@ def output_data(self, **kwargs):

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')

Expand All @@ -141,5 +138,4 @@ def output_data(self, **kwargs):
}

sim = Airfoil(**kwargs)
print('Domain size: ', sim.nx, sim.ny, sim.nz)
sim.run(20000)
13 changes: 6 additions & 7 deletions examples/CFD/cavity2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
4. Visualization: The simulation outputs data in VTK format for visualization. It also provides images of the velocity field and saves the boundary conditions at each time step. The data can be visualized using software like Paraview.
"""
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9
from src.utils import *

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax

class Cavity(KBCSim):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -71,11 +71,10 @@ def output_data(self, **kwargs):
clength = nx - 1

checkpoint_rate = 1000
checkpoint_dir = "./checkpoints"
checkpoint_dir = os.path.abspath("./checkpoints")

visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
14 changes: 7 additions & 7 deletions examples/CFD/cavity3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q27

import numpy as np
from src.utils import *
from jax.config import config
from src.boundary_conditions import *

precision = 'f32/f32'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *

class Cavity(KBCSim):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -68,8 +68,6 @@ def output_data(self, **kwargs):
# live_volume_randering(timestep, u_mag)

if __name__ == '__main__':
lattice = LatticeD3Q27(precision)

nx = 101
ny = 101
nz = 101
Expand All @@ -78,9 +76,11 @@ def output_data(self, **kwargs):
prescribed_vel = 0.1
clength = nx - 1

precision = 'f32/f32'
lattice = LatticeD3Q27(precision)

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
5 changes: 2 additions & 3 deletions examples/CFD/channel3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_dns_data():
}
return dns_dic

class turbulentChannel(KBCSim):
class TurbulentChannel(KBCSim):
def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand All @@ -68,7 +68,7 @@ def set_boundary_conditions(self):
def initialize_macroscopic_fields(self):
rho = self.precisionPolicy.cast_to_output(1.0)
u = self.distributed_array_init((self.nx, self.ny, self.nz, self.dim),
self.precisionPolicy.compute_dtype, initVal=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
self.precisionPolicy.compute_dtype, init_val=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
u = self.precisionPolicy.cast_to_output(u)
return rho, u

Expand Down Expand Up @@ -141,7 +141,6 @@ def output_data(self, **kwargs):
zz = np.minimum(zz, zz.max() - zz)
yplus = zz * u_tau / visc

print("omega = ", omega)
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
Expand Down
11 changes: 6 additions & 5 deletions examples/CFD/couette2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
This script performs a 2D simulation of Couette flow using the lattice Boltzmann method (LBM).
"""

from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9
import os
import jax.numpy as jnp
import numpy as np
from src.utils import *
from jax.config import config
import os


from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9

# config.update('jax_disable_jit', True)
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
Expand Down Expand Up @@ -60,7 +62,6 @@ def output_data(self, **kwargs):
visc = prescribed_vel * clength / Re

omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)
assert omega < 1.98, "omega must be less than 2.0"
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
17 changes: 9 additions & 8 deletions examples/CFD/cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@
5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView.
"""

import os
import jax
from time import time
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
jax.config.update('jax_enable_x64', True)

class Cylinder(KBCSim):
Expand Down Expand Up @@ -93,9 +93,10 @@ def output_data(self, **kwargs):

if __name__ == '__main__':
precision = 'f64/f64'
lattice = LatticeD2Q9(precision)

prescribed_vel = 0.005
diam = 80
lattice = LatticeD2Q9(precision)

nx = int(22*diam)
ny = int(4.1*diam)
Expand Down
18 changes: 7 additions & 11 deletions examples/CFD/oscilating_cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@
"""


import os
import jax
from time import time
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
jax.config.update('jax_enable_x64', True)

class Cylinder(KBCSim):
Expand Down Expand Up @@ -119,7 +120,6 @@ def output_data(self, **kwargs):
if __name__ == '__main__':
precision = 'f64/f64'
lattice = LatticeD2Q9(precision)

prescribed_vel = 0.005
diam = 20
nx = int(22*diam)
Expand All @@ -129,10 +129,6 @@ def output_data(self, **kwargs):
visc = prescribed_vel * diam / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny)
print("Number of voxels: ", nx * ny)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')
kwargs = {
'lattice': lattice,
Expand Down
21 changes: 11 additions & 10 deletions examples/CFD/taylor_green_vortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
"""


from src.boundary_conditions import *
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK
import os
import matplotlib.pyplot as plt
import json
import jax
import numpy as np
import matplotlib.pyplot as plt

from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK
from src.lattice import LatticeD2Q9


# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
# disable JIt compilation

jax.config.update('jax_enable_x64', True)
Expand All @@ -37,9 +39,9 @@ def set_boundary_conditions(self):

def initialize_macroscopic_fields(self):
ux, uy, rho = taylor_green_initial_fields(xx, yy, vel_ref, 1, 0., 0.)
rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, initVal=1.0, sharding=self.sharding)
rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, init_val=1.0, sharding=self.sharding)
u = np.stack([ux, uy], axis=-1)
u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, initVal=u, sharding=self.sharding)
u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, init_val=u, sharding=self.sharding)
return rho, u

def initialize_populations(self, rho, u):
Expand Down Expand Up @@ -95,7 +97,6 @@ def output_data(self, **kwargs):

visc = vel_ref * nx / Re
omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)
os.system("rm -rf ./*.vtk && rm -rf ./*.png")
kwargs = {
'lattice': lattice,
Expand Down
18 changes: 8 additions & 10 deletions examples/CFD/windtunnel3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,21 @@
"""


from time import time
import os
import jax
import trimesh
from src.boundary_conditions import *
from time import time
import numpy as np
import jax.numpy as jnp
from jax.config import config

from src.utils import *
import numpy as np
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax

# disable JIt compilation

Expand Down Expand Up @@ -122,9 +123,6 @@ def output_data(self, **kwargs):
visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny, nz)
print("Number of voxels: ", nx * ny * nz)
os.system('rm -rf ./*.vtk && rm -rf ./*.png')

kwargs = {
Expand Down
Loading

0 comments on commit 2180c56

Please sign in to comment.