diff --git a/.ci_support/environment-old.yml b/.ci_support/environment-old.yml index 60a88a1..8c07315 100644 --- a/.ci_support/environment-old.yml +++ b/.ci_support/environment-old.yml @@ -4,7 +4,7 @@ dependencies: - lammps =2022.06.23 - openmpi - numpy =1.23.5 -- mpi4py =3.1.4 -- pympipool =0.7.0 +- mpi4py =3.1.5 +- pympipool =0.7.2 - ase =3.20.1 - scipy =1.9.3 diff --git a/pylammpsmpi/wrapper/ase.py b/pylammpsmpi/wrapper/ase.py index 4bfa68b..cee4e16 100644 --- a/pylammpsmpi/wrapper/ase.py +++ b/pylammpsmpi/wrapper/ase.py @@ -22,6 +22,7 @@ def __init__( log_file=None, library=None, diable_log_file=True, + use_srun=False, ): self._logger = logger self._prism = None @@ -45,7 +46,9 @@ def __init__( ) else: self._interactive_library = LammpsBase( - cores=self._cores, working_directory=working_directory + cores=self._cores, + working_directory=working_directory, + use_srun=use_srun, ) def interactive_lib_command(self, command): diff --git a/pylammpsmpi/wrapper/concurrent.py b/pylammpsmpi/wrapper/concurrent.py index bc5f70e..2b5421e 100644 --- a/pylammpsmpi/wrapper/concurrent.py +++ b/pylammpsmpi/wrapper/concurrent.py @@ -11,6 +11,7 @@ interface_bootup, cancel_items_in_queue, MpiExecInterface, + SrunInterface, ) @@ -32,6 +33,7 @@ def execute_async( cores=1, oversubscribe=False, cwd=None, + use_srun=False, ): executable = os.path.join( os.path.dirname(os.path.abspath(__file__)), "..", "mpi", "lmpmpi.py" @@ -39,13 +41,21 @@ def execute_async( cmds = [sys.executable, executable] if cmdargs is not None: cmds.extend(cmdargs) - interface = interface_bootup( - command_lst=cmds, - connections=MpiExecInterface( + if use_srun: + connection_interface = SrunInterface( cwd=cwd, cores=cores, oversubscribe=oversubscribe, - ), + ) + else: + connection_interface = MpiExecInterface( + cwd=cwd, + cores=cores, + oversubscribe=oversubscribe, + ) + interface = interface_bootup( + command_lst=cmds, + connections=connection_interface, ) while True: task_dict = future_queue.get() @@ -65,6 +75,7 @@ def __init__( oversubscribe=False, working_directory=".", cmdargs=None, + use_srun=False, ): self.cores = cores self.working_directory = working_directory @@ -72,6 +83,7 @@ def __init__( self._process = None self._oversubscribe = oversubscribe self._cmdargs = cmdargs + self._use_srun = use_srun self._start_process() def _start_process(self): @@ -83,6 +95,7 @@ def _start_process(self): "cores": self.cores, "oversubscribe": self._oversubscribe, "cwd": self.working_directory, + "use_srun": self._use_srun, }, ) self._process.start() diff --git a/pylammpsmpi/wrapper/extended.py b/pylammpsmpi/wrapper/extended.py index d9d5723..29a776d 100644 --- a/pylammpsmpi/wrapper/extended.py +++ b/pylammpsmpi/wrapper/extended.py @@ -248,6 +248,7 @@ def __init__( client=None, mode="local", cmdargs=None, + use_srun=False, ): self.cores = cores self.working_directory = working_directory @@ -259,6 +260,7 @@ def __init__( oversubscribe=self.oversubscribe, working_directory=self.working_directory, cmdargs=cmdargs, + use_srun=use_srun, ) def __getattr__(self, name):