Skip to content

Commit

Permalink
refactor CJM
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Sep 18, 2024
1 parent 17fa844 commit 11503ef
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 15 deletions.
4 changes: 2 additions & 2 deletions pioreactor/actions/leader/experiment_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,8 +892,8 @@ def execute_experiment_profile(profile_filename: str, experiment: str, dry_run:
# stop all jobs started
# we can use active workers in experiment, since if a worker leaves an experiment or goes inactive, it's jobs are stopped
workers = get_active_workers_in_experiment(experiment)
with ClusterJobManager(workers) as jm:
jm.kill_jobs(experiment=experiment, job_source="experiment_profile")
with ClusterJobManager() as cjm:
cjm.kill_jobs(workers, experiment=experiment, job_source="experiment_profile")

else:
if dry_run:
Expand Down
6 changes: 4 additions & 2 deletions pioreactor/cli/pios.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,10 @@ def kill(
if confirm != "Y":
raise click.Abort()

with ClusterJobManager(units) as cm:
results = cm.kill_jobs(all_jobs=all_jobs, experiment=experiment, job_source=job_source, name=name)
with ClusterJobManager() as cm:
results = cm.kill_jobs(
units, all_jobs=all_jobs, experiment=experiment, job_source=job_source, name=name
)

if json:
for success, api_result in results:
Expand Down
8 changes: 4 additions & 4 deletions pioreactor/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ def collect(msg):
def test_ClusterJobManager_sends_requests():
workers = ["pio01", "pio02", "pio03"]
with capture_requests() as bucket:
with ClusterJobManager(workers) as cm:
cm.kill_jobs(name="stirring")
with ClusterJobManager() as cm:
cm.kill_jobs(workers, name="stirring")

assert len(bucket) == len(workers)
assert bucket[0].body is None
Expand All @@ -283,7 +283,7 @@ def test_ClusterJobManager_sends_requests():
def test_empty_ClusterJobManager():
workers = []
with capture_requests() as bucket:
with ClusterJobManager(workers) as cm:
cm.kill_jobs(name="stirring")
with ClusterJobManager() as cm:
cm.kill_jobs(workers, name="stirring")

assert len(bucket) == len(workers)
14 changes: 7 additions & 7 deletions pioreactor/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,20 +646,20 @@ def __exit__(self, *args) -> None:


class ClusterJobManager:
def __init__(self, units: tuple[str, ...]) -> None:
# this is a context manager to mimic the API for JobManager.
def __init__(self) -> None:
if not whoami.am_I_leader():
raise RoleError("Must be leader to use this. Maybe you want JobManager?")

self.units = units

@staticmethod
def kill_jobs(
self,
units: tuple[str, ...],
all_jobs: bool = False,
experiment: str | None = None,
name: str | None = None,
job_source: str | None = None,
) -> list[tuple[bool, dict]]:
if len(self.units) == 0:
if len(units) == 0:
return []

if experiment:
Expand All @@ -680,8 +680,8 @@ def _thread_function(unit: str) -> tuple[bool, dict]:
print(f"Failed to send kill command to {unit}: {e}")
return False, {"unit": unit}

with ThreadPoolExecutor(max_workers=len(self.units)) as executor:
results = executor.map(_thread_function, self.units)
with ThreadPoolExecutor(max_workers=len(units)) as executor:
results = executor.map(_thread_function, units)

return list(results)

Expand Down

0 comments on commit 11503ef

Please sign in to comment.