diff --git a/pioreactor/actions/leader/experiment_profile.py b/pioreactor/actions/leader/experiment_profile.py index 725248f3..2b2540f1 100644 --- a/pioreactor/actions/leader/experiment_profile.py +++ b/pioreactor/actions/leader/experiment_profile.py @@ -892,8 +892,8 @@ def execute_experiment_profile(profile_filename: str, experiment: str, dry_run: # stop all jobs started # we can use active workers in experiment, since if a worker leaves an experiment or goes inactive, it's jobs are stopped workers = get_active_workers_in_experiment(experiment) - with ClusterJobManager(workers) as jm: - jm.kill_jobs(experiment=experiment, job_source="experiment_profile") + with ClusterJobManager() as cjm: + cjm.kill_jobs(workers, experiment=experiment, job_source="experiment_profile") else: if dry_run: diff --git a/pioreactor/cli/pios.py b/pioreactor/cli/pios.py index eb2485ba..1c663bc9 100644 --- a/pioreactor/cli/pios.py +++ b/pioreactor/cli/pios.py @@ -667,8 +667,10 @@ def kill( if confirm != "Y": raise click.Abort() - with ClusterJobManager(units) as cm: - results = cm.kill_jobs(all_jobs=all_jobs, experiment=experiment, job_source=job_source, name=name) + with ClusterJobManager() as cm: + results = cm.kill_jobs( + units, all_jobs=all_jobs, experiment=experiment, job_source=job_source, name=name + ) if json: for success, api_result in results: diff --git a/pioreactor/tests/test_utils.py b/pioreactor/tests/test_utils.py index 59612ee8..dd532007 100644 --- a/pioreactor/tests/test_utils.py +++ b/pioreactor/tests/test_utils.py @@ -269,8 +269,8 @@ def collect(msg): def test_ClusterJobManager_sends_requests(): workers = ["pio01", "pio02", "pio03"] with capture_requests() as bucket: - with ClusterJobManager(workers) as cm: - cm.kill_jobs(name="stirring") + with ClusterJobManager() as cm: + cm.kill_jobs(workers, name="stirring") assert len(bucket) == len(workers) assert bucket[0].body is None @@ -283,7 +283,7 @@ def test_ClusterJobManager_sends_requests(): def test_empty_ClusterJobManager(): workers = [] with capture_requests() as bucket: - with ClusterJobManager(workers) as cm: - cm.kill_jobs(name="stirring") + with ClusterJobManager() as cm: + cm.kill_jobs(workers, name="stirring") assert len(bucket) == len(workers) diff --git a/pioreactor/utils/__init__.py b/pioreactor/utils/__init__.py index f5275295..c6e4b787 100644 --- a/pioreactor/utils/__init__.py +++ b/pioreactor/utils/__init__.py @@ -646,20 +646,20 @@ def __exit__(self, *args) -> None: class ClusterJobManager: - def __init__(self, units: tuple[str, ...]) -> None: + # this is a context manager to mimic the API for JobManager. + def __init__(self) -> None: if not whoami.am_I_leader(): raise RoleError("Must be leader to use this. Maybe you want JobManager?") - self.units = units - + @staticmethod def kill_jobs( - self, + units: tuple[str, ...], all_jobs: bool = False, experiment: str | None = None, name: str | None = None, job_source: str | None = None, ) -> list[tuple[bool, dict]]: - if len(self.units) == 0: + if len(units) == 0: return [] if experiment: @@ -680,8 +680,8 @@ def _thread_function(unit: str) -> tuple[bool, dict]: print(f"Failed to send kill command to {unit}: {e}") return False, {"unit": unit} - with ThreadPoolExecutor(max_workers=len(self.units)) as executor: - results = executor.map(_thread_function, self.units) + with ThreadPoolExecutor(max_workers=len(units)) as executor: + results = executor.map(_thread_function, units) return list(results)