diff --git a/pioreactor/cli/pios.py b/pioreactor/cli/pios.py index f54f60f6..5f6c14e1 100644 --- a/pioreactor/cli/pios.py +++ b/pioreactor/cli/pios.py @@ -28,6 +28,7 @@ from pioreactor.whoami import am_I_leader from pioreactor.whoami import get_assigned_experiment_name from pioreactor.whoami import get_unit_name +from pioreactor.whoami import is_testing_env from pioreactor.whoami import UNIVERSAL_EXPERIMENT from pioreactor.whoami import UNIVERSAL_IDENTIFIER @@ -52,7 +53,7 @@ def pios(ctx) -> None: raise click.Abort() -if am_I_leader(): +if am_I_leader() or is_testing_env(): which_units = click.option( "--units", multiple=True, @@ -94,14 +95,14 @@ def universal_identifier_to_all_active_workers(workers: tuple[str, ...]) -> tupl if workers == (UNIVERSAL_IDENTIFIER,): return active_workers else: - return tuple(u for u in workers if u in active_workers) + return tuple(u for u in set(workers) if u in active_workers) def universal_identifier_to_all_workers(workers: tuple[str, ...]) -> tuple[str, ...]: all_workers = get_workers_in_inventory() if workers == (UNIVERSAL_IDENTIFIER,): return all_workers else: - return tuple(u for u in workers if u in all_workers) + return tuple(u for u in set(workers) if u in all_workers) def add_leader(units: tuple[str, ...]) -> tuple[str, ...]: leader = get_leader_hostname() diff --git a/pioreactor/tests/test_cli.py b/pioreactor/tests/test_cli.py index f38fc406..4197f4db 100644 --- a/pioreactor/tests/test_cli.py +++ b/pioreactor/tests/test_cli.py @@ -4,15 +4,20 @@ import time +import click import pytest from click.testing import CliRunner from pioreactor import whoami from pioreactor.background_jobs.dosing_automation import start_dosing_automation from pioreactor.cli.pio import pio +from pioreactor.cli.pios import kill from pioreactor.cli.pios import pios +from pioreactor.cli.pios import reboot +from pioreactor.cli.pios import run from pioreactor.pubsub import collect_all_logs_of_level from pioreactor.pubsub import subscribe_and_callback +from pioreactor.tests.conftest import capture_requests from pioreactor.utils import is_pio_job_running from pioreactor.utils import local_intermittent_storage @@ -112,3 +117,42 @@ def test_pio_kill_cleans_up_automations_correctly() -> None: pause() assert not is_pio_job_running("dosing_automation") + + +def test_pios_run_requests(): + with capture_requests() as bucket: + ctx = click.Context(run, allow_extra_args=True) + ctx.forward(run, job="stirring", y=True) + + assert len(bucket) == 2 + assert bucket[0].url == "http://unit1.local:4999/unit_api/jobs/run/job_name/stirring" + + +def test_pios_run_requests_dedup_and_filter(): + units = ("unit1", "unit1", "notaunit") + + with capture_requests() as bucket: + ctx = click.Context(run, allow_extra_args=True) + ctx.forward(run, job="stirring", y=True, units=units) + + assert len(bucket) == 1 + assert bucket[0].url == "http://unit1.local:4999/unit_api/jobs/run/job_name/stirring" + + +def test_pios_kill_requests(): + with capture_requests() as bucket: + ctx = click.Context(kill, allow_extra_args=True) + ctx.forward(kill, experiment="demo", y=True) + + assert len(bucket) == 2 + assert bucket[0].url == "http://unit1.local:4999/unit_api/jobs/stop/experiment/demo" + assert bucket[1].url == "http://unit2.local:4999/unit_api/jobs/stop/experiment/demo" + + +def test_pios_reboot_requests(): + with capture_requests() as bucket: + ctx = click.Context(reboot, allow_extra_args=True) + ctx.forward(reboot, y=True, units=("unit1",)) + + assert len(bucket) == 1 + assert bucket[0].url == "http://unit1.local:4999/unit_api/system/reboot" diff --git a/pioreactor/tests/test_utils.py b/pioreactor/tests/test_utils.py index fb906ff0..59612ee8 100644 --- a/pioreactor/tests/test_utils.py +++ b/pioreactor/tests/test_utils.py @@ -267,13 +267,23 @@ def collect(msg): def test_ClusterJobManager_sends_requests(): + workers = ["pio01", "pio02", "pio03"] with capture_requests() as bucket: - with ClusterJobManager(["pio01", "pio02", "pio03"]) as cm: + with ClusterJobManager(workers) as cm: cm.kill_jobs(name="stirring") - assert len(bucket) == 3 + assert len(bucket) == len(workers) assert bucket[0].body is None assert bucket[0].method == "PATCH" - assert bucket[0].url == "http://pio01.local:4999/unit_api/jobs/stop/job_name/stirring" - assert bucket[1].url == "http://pio02.local:4999/unit_api/jobs/stop/job_name/stirring" - assert bucket[2].url == "http://pio03.local:4999/unit_api/jobs/stop/job_name/stirring" + + for request, worker in zip(bucket, workers): + assert request.url == f"http://{worker}.local:4999/unit_api/jobs/stop/job_name/stirring" + + +def test_empty_ClusterJobManager(): + workers = [] + with capture_requests() as bucket: + with ClusterJobManager(workers) as cm: + cm.kill_jobs(name="stirring") + + assert len(bucket) == len(workers)