Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple ways to initialize from the command line. #20

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions examples/performance/MLUPS3d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,6 @@
import argparse
import os
import jax
# Initialize JAX distributed. The IP address, number of processes, and process ID must be updated for your setup.
# Currently set to localhost for testing purposes.
# Can be tested on a two GPU system as follows:
# (export PYTHONPATH=.; CUDA_VISIBLE_DEVICES=0 python3 examples/performance/MLUPS3d_distributed.py 100 100 & CUDA_VISIBLE_DEVICES=1 python3 examples/performance/MLUPS3d_distributed.py 100 100 &)
# IMPORTANT: jax.distributed must be initialized before any JAX computation is performed.
jax.distributed.initialize(f'127.0.0.1:1234', 2, process_id=int(os.environ['CUDA_VISIBLE_DEVICES']))

print('Process id: ', jax.process_index())
print('Number of total devices (over all processes): ', jax.device_count())
print('Number of local devices:', jax.local_device_count())


import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -56,17 +45,41 @@ def set_boundary_conditions(self):
self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall))

if __name__ == '__main__':
precision = 'f32/f32'
lattice = LatticeD3Q19(precision)

# Create a parser that will read the command line arguments
parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation")
parser.add_argument("N", help="The total number of voxels in one direction. The final dimension will be N*NxN", default=100, type=int)
parser.add_argument("N_ITERS", help="Number of timesteps", default=10000, type=int)
parser.add_argument("N", help="The total number of voxels in one direction. The final dimension will be N*NxN",
default=100, type=int)
parser.add_argument("N_ITERS", help="Number of iterations", default=10000, type=int)
parser.add_argument("N_PROCESSES", help="Number of processes. If >1, call jax.distributed.initialize with that number of process. If -1 will call jax.distributed.initialize without any arsgument. So it should pick up the values from SLURM env variable.",
default=1, type=int)
parser.add_argument("IP", help="IP of the master node for multi-node. Useless if using SLURM.",
default='127.0.0.1', type=str, nargs='?')
parser.add_argument("PROCESS_ID_INCREMENT", help="For multi-node only. Useless if using SLURM.",
default=0, type=int, nargs='?')

args = parser.parse_args()
n = args.N
n_iters = args.N_ITERS
n_processes = args.N_PROCESSES
# Initialize JAX distributed. The IP, number of processes and process id must be set correctly.
print("N processes, ", n_processes)
print("N iter, ", n_iters)
if n_processes > 1:
process_id = int(os.environ.get('CUDA_VISIBLE_DEVICES', 0)) + args.PROCESS_ID_INCREMENT
print("ip, num_processes, process_id, ", args.IP, n_processes, process_id)
jax.distributed.initialize(args.IP, num_processes=n_processes,
process_id=process_id)
elif n_processes == -1:
print("Will call jax.distributed.initialize()")
jax.distributed.initialize()
print("jax.distributed.initialize() ended")
else:
print("No call to jax.distributed.initialize")
print("JAX local devices: ", jax.local_devices())

precision = 'f32/f32'
# Create a 3D lattice with the D3Q19 scheme
lattice = LatticeD3Q19(precision)

# Store the Reynolds number in the variable Re
Re = 100.0
Expand All @@ -92,4 +105,4 @@ def set_boundary_conditions(self):
}

sim = Cavity(**kwargs) # Run the simulation
sim.run(n_iters)
sim.run(n_iters)