From 18be7ad97135ee3c4f891c8d5910907d1925cd61 Mon Sep 17 00:00:00 2001 From: lloydbrownjr Date: Thu, 9 Oct 2025 11:36:49 -0700 Subject: [PATCH 1/5] Fix worker assignment. --- sky/client/cli/command.py | 37 ++++++++++++++++++------------- sky/client/cli/table_utils.py | 4 +++- sky/jobs/server/utils.py | 5 ++++- sky/jobs/utils.py | 39 +++++++++++++++++++++++++++++++++ tests/smoke_tests/test_pools.py | 38 ++++++++++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 17 deletions(-) diff --git a/sky/client/cli/command.py b/sky/client/cli/command.py index a92ece858c3..de45048e76e 100644 --- a/sky/client/cli/command.py +++ b/sky/client/cli/command.py @@ -1347,6 +1347,8 @@ def _handle_jobs_queue_request( request_id: server_common.RequestId[Union[ List[responses.ManagedJobRecord], Tuple[List[responses.ManagedJobRecord], int, Dict[str, int], int]]], + pool_status_request_id: Optional[server_common.RequestId[List[Dict[ + str, Any]]]], show_all: bool, show_user: bool, max_num_jobs_to_show: Optional[int], @@ -1356,6 +1358,8 @@ def _handle_jobs_queue_request( """Get the in-progress managed jobs. Args: + request_id: The request ID for managed jobs. + pool_status_request_id: The request ID for pool status, or None. show_all: Show all information of each job (e.g., region, price). show_user: Show the user who submitted the job. max_num_jobs_to_show: If not None, limit the number of jobs to show to @@ -1375,6 +1379,7 @@ def _handle_jobs_queue_request( num_in_progress_jobs = None msg = '' status_counts: Optional[Dict[str, int]] = None + pool_status_result = None try: if not is_called_by_user: usage_lib.messages.usage.set_internal() @@ -1395,6 +1400,13 @@ def _handle_jobs_queue_request( managed_jobs_ = result num_in_progress_jobs = len( set(job['job_id'] for job in managed_jobs_)) + # Try to get pool status if request was made + if pool_status_request_id is not None: + try: + pool_status_result = sdk.stream_and_get(pool_status_request_id) + except Exception: # pylint: disable=broad-except + # If getting pool status fails, just continue without it + pool_status_result = None except exceptions.ClusterNotUpError as e: controller_status = e.cluster_status msg = str(e) @@ -1440,6 +1452,7 @@ def _handle_jobs_queue_request( else: msg = table_utils.format_job_table( managed_jobs_, + pool_status=pool_status_result, show_all=show_all, show_user=show_user, max_jobs=max_num_jobs_to_show, @@ -1945,6 +1958,7 @@ def submit_enabled_clouds(): try: num_in_progress_jobs, msg = _handle_jobs_queue_request( managed_jobs_queue_request_id, + pool_status_request_id, show_all=False, show_user=all_users, max_num_jobs_to_show=_NUM_MANAGED_JOBS_TO_SHOW_IN_STATUS, @@ -4621,21 +4635,6 @@ def jobs_launch( job_ids = [job_id_handle[0]] if isinstance(job_id_handle[0], int) else job_id_handle[0] - if pool: - # Display the worker assignment for the jobs. 
- logger.debug(f'Getting service records for pool: {pool}') - records_request_id = managed_jobs.pool_status(pool_names=pool) - service_records = _async_call_or_wait(records_request_id, async_call, - 'sky.jobs.pool_status') - logger.debug(f'Pool status: {service_records}') - replica_infos = service_records[0]['replica_info'] - for replica_info in replica_infos: - job_id = replica_info.get('used_by', None) - if job_id in job_ids: - worker_id = replica_info['replica_id'] - version = replica_info['version'] - logger.info(f'Job ID: {job_id} assigned to pool {pool} ' - f'(worker: {worker_id}, version: {version})') if not detach_run: if len(job_ids) == 1: @@ -4768,8 +4767,16 @@ def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool, all_users=all_users, limit=max_num_jobs_to_show, fields=fields) + # Try to get pool status for worker information + pool_status_request_id = None + try: + pool_status_request_id = managed_jobs.pool_status(pool_names=None) + except Exception: # pylint: disable=broad-except + # If pool_status fails, we'll just skip the worker information + pass num_jobs, msg = _handle_jobs_queue_request( managed_jobs_request_id, + pool_status_request_id, show_all=verbose, show_user=all_users, max_num_jobs_to_show=max_num_jobs_to_show, diff --git a/sky/client/cli/table_utils.py b/sky/client/cli/table_utils.py index 238ef4d3089..baf3f5600ca 100644 --- a/sky/client/cli/table_utils.py +++ b/sky/client/cli/table_utils.py @@ -1,7 +1,7 @@ """Utilities for formatting tables for CLI output.""" import abc from datetime import datetime -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import prettytable @@ -92,12 +92,14 @@ def format_job_table( jobs: List[responses.ManagedJobRecord], show_all: bool, show_user: bool, + pool_status: Optional[List[Dict[str, Any]]] = None, max_jobs: Optional[int] = None, status_counts: Optional[Dict[str, int]] = None, ): jobs = [job.model_dump() for job in jobs] return managed_jobs.format_job_table( jobs, + pool_status=pool_status, show_all=show_all, show_user=show_user, max_jobs=max_jobs, diff --git a/sky/jobs/server/utils.py b/sky/jobs/server/utils.py index 65b08a174f7..ebef0778cc4 100644 --- a/sky/jobs/server/utils.py +++ b/sky/jobs/server/utils.py @@ -103,7 +103,10 @@ def check_version_mismatch_and_non_terminal_jobs() -> None: if not version_matches and has_non_terminal_jobs: # Format job table locally using the same method as queue() formatted_job_table = managed_job_utils.format_job_table( - non_terminal_jobs, show_all=False, show_user=False) + non_terminal_jobs, + pool_status=None, + show_all=False, + show_user=False) error_msg = ( f'Controller SKYLET_VERSION ({controller_version}) does not match ' diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index ce0226defb4..bebabff45fe 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -1714,6 +1714,7 @@ def format_job_table( tasks: List[Dict[str, Any]], show_all: bool, show_user: bool, + pool_status: Optional[List[Dict[str, Any]]] = None, return_rows: Literal[False] = False, max_jobs: Optional[int] = None, job_status_counts: Optional[Dict[str, int]] = None, @@ -1727,6 +1728,7 @@ def format_job_table( show_all: bool, show_user: bool, return_rows: Literal[True], + pool_status: Optional[List[Dict[str, Any]]] = None, max_jobs: Optional[int] = None, job_status_counts: Optional[Dict[str, int]] = None, ) -> List[List[str]]: @@ -1738,6 +1740,7 @@ def format_job_table( show_all: bool, show_user: bool, return_rows: bool = False, + pool_status: Optional[List[Dict[str, Any]]] 
= None, max_jobs: Optional[int] = None, job_status_counts: Optional[Dict[str, int]] = None, ) -> Union[str, List[List[str]]]: @@ -1749,6 +1752,7 @@ def format_job_table( max_jobs: The maximum number of jobs to show in the table. return_rows: If True, return the rows as a list of strings instead of all rows concatenated into a single string. + pool_status: List of pool status dictionaries with replica_info. job_status_counts: The counts of each job status. Returns: A formatted string of managed jobs, if not `return_rows`; otherwise @@ -1766,6 +1770,30 @@ def get_hash(task): return (task['user'], task['job_id']) return task['job_id'] + def get_job_id_to_worker_map( + pool_status: Optional[List[Dict[str, Any]]]) -> Dict[int, int]: + """Create a mapping from job_id to worker replica_id. + + Args: + pool_status: List of pool status dictionaries with replica_info. + + Returns: + Dictionary mapping job_id to replica_id (worker ID). + """ + job_to_worker: Dict[int, int] = {} + if pool_status is None: + return job_to_worker + for pool in pool_status: + replica_info = pool.get('replica_info', []) + for replica in replica_info: + used_by = replica.get('used_by') + if used_by is not None: + job_to_worker[used_by] = replica.get('replica_id') + return job_to_worker + + # Create mapping from job_id to worker replica_id + job_to_worker = get_job_id_to_worker_map(pool_status) + for task in tasks: # The tasks within the same job_id are already sorted # by the task_id. @@ -1909,7 +1937,12 @@ def get_user_column_values(task: Dict[str, Any]) -> List[str]: if pool is None: pool = '-' + # Add worker information if job is assigned to a worker job_id = job_hash[1] if tasks_have_k8s_user else job_hash + # job_id is now always an integer, use it to look up worker + if job_id in job_to_worker and pool != '-': + pool = f'{pool} (worker={job_to_worker[job_id]})' + job_values = [ job_id, '', @@ -1952,6 +1985,12 @@ def get_user_column_values(task: Dict[str, Any]) -> List[str]: pool = task.get('pool') if pool is None: pool = '-' + + # Add worker information if task is assigned to a worker + task_job_id = task['job_id'] + if task_job_id in job_to_worker and pool != '-': + pool = f'{pool} (worker={job_to_worker[task_job_id]})' + values = [ task['job_id'] if len(job_tasks) == 1 else ' \u21B3', task['task_id'] if len(job_tasks) > 1 else '-', diff --git a/tests/smoke_tests/test_pools.py b/tests/smoke_tests/test_pools.py index a254d77410c..d26d39124b8 100644 --- a/tests/smoke_tests/test_pools.py +++ b/tests/smoke_tests/test_pools.py @@ -1010,3 +1010,41 @@ def test_pools_num_jobs(generic_cloud: str): teardown=cancel_jobs_and_teardown_pool(pool_name, timeout=5), ) smoke_tests_utils.run_one_test(test) + + +def test_pool_worker_assignment_in_queue(generic_cloud: str): + """Test that sky jobs queue shows the worker assignment for running jobs.""" + timeout = smoke_tests_utils.get_timeout(generic_cloud) + pool_config = basic_pool_conf(num_workers=1, infra=generic_cloud) + + job_name = f'{smoke_tests_utils.get_cluster_name()}-job' + job_config = basic_job_conf( + job_name=job_name, + run_cmd='echo "Hello, world!"; sleep infinity', + ) + with tempfile.NamedTemporaryFile(delete=True) as pool_yaml: + with tempfile.NamedTemporaryFile(delete=True) as job_yaml: + write_yaml(pool_yaml, pool_config) + write_yaml(job_yaml, job_config) + + name = smoke_tests_utils.get_cluster_name() + pool_name = f'{name}-pool' + + test = smoke_tests_utils.Test( + 'test_pool_worker_assignment_in_queue', + [ + _LAUNCH_POOL_AND_CHECK_SUCCESS.format( + 
pool_name=pool_name, pool_yaml=pool_yaml.name), + wait_until_pool_ready(pool_name, timeout=timeout), + _LAUNCH_JOB_AND_CHECK_SUCCESS.format( + pool_name=pool_name, job_yaml=job_yaml.name), + wait_until_job_status(job_name, ['RUNNING'], + timeout=timeout), + # Check that the worker assignment is shown in the queue output + f's=$(sky jobs queue); echo "$s"; echo; echo; echo "$s" | grep "{job_name}" | grep "{pool_name} (worker=1)"', + ], + timeout=timeout, + teardown=cancel_jobs_and_teardown_pool(pool_name, timeout=5), + ) + + smoke_tests_utils.run_one_test(test) \ No newline at end of file From 6ec355b9e49d9eae631ec4e2c916da252bd6ca99 Mon Sep 17 00:00:00 2001 From: lloyd-brown Date: Wed, 22 Oct 2025 15:16:57 -0700 Subject: [PATCH 2/5] Make function private. --- sky/jobs/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index bebabff45fe..21f291ea3bb 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -1770,7 +1770,7 @@ def get_hash(task): return (task['user'], task['job_id']) return task['job_id'] - def get_job_id_to_worker_map( + def _get_job_id_to_worker_map( pool_status: Optional[List[Dict[str, Any]]]) -> Dict[int, int]: """Create a mapping from job_id to worker replica_id. @@ -1792,7 +1792,7 @@ def get_job_id_to_worker_map( return job_to_worker # Create mapping from job_id to worker replica_id - job_to_worker = get_job_id_to_worker_map(pool_status) + job_to_worker = _get_job_id_to_worker_map(pool_status) for task in tasks: # The tasks within the same job_id are already sorted From d256f0f9097312d5eed948f6755c42c9ba1c2ef6 Mon Sep 17 00:00:00 2001 From: lloyd-brown Date: Thu, 23 Oct 2025 17:54:24 -0700 Subject: [PATCH 3/5] Fix calls. --- sky/client/cli/command.py | 4 ++-- tests/smoke_tests/test_pools.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sky/client/cli/command.py b/sky/client/cli/command.py index de45048e76e..65b8c9b0d08 100644 --- a/sky/client/cli/command.py +++ b/sky/client/cli/command.py @@ -1347,11 +1347,11 @@ def _handle_jobs_queue_request( request_id: server_common.RequestId[Union[ List[responses.ManagedJobRecord], Tuple[List[responses.ManagedJobRecord], int, Dict[str, int], int]]], - pool_status_request_id: Optional[server_common.RequestId[List[Dict[ - str, Any]]]], show_all: bool, show_user: bool, max_num_jobs_to_show: Optional[int], + pool_status_request_id: Optional[server_common.RequestId[List[Dict[str, + Any]]]] = None, is_called_by_user: bool = False, only_in_progress: bool = False, ) -> Tuple[Optional[int], str]: diff --git a/tests/smoke_tests/test_pools.py b/tests/smoke_tests/test_pools.py index d26d39124b8..7f758498461 100644 --- a/tests/smoke_tests/test_pools.py +++ b/tests/smoke_tests/test_pools.py @@ -1046,5 +1046,4 @@ def test_pool_worker_assignment_in_queue(generic_cloud: str): timeout=timeout, teardown=cancel_jobs_and_teardown_pool(pool_name, timeout=5), ) - - smoke_tests_utils.run_one_test(test) \ No newline at end of file + smoke_tests_utils.run_one_test(test) From e7e025a5db62138fe7e2f39ccd501213e048a603 Mon Sep 17 00:00:00 2001 From: lloyd-brown Date: Thu, 23 Oct 2025 18:03:20 -0700 Subject: [PATCH 4/5] Fix calls. 
--- sky/client/cli/command.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sky/client/cli/command.py b/sky/client/cli/command.py index 65b8c9b0d08..18583ac64b9 100644 --- a/sky/client/cli/command.py +++ b/sky/client/cli/command.py @@ -1354,6 +1354,8 @@ def _handle_jobs_queue_request( Any]]]] = None, is_called_by_user: bool = False, only_in_progress: bool = False, + pool_status_request_id: Optional[server_common.RequestId[List[Dict[ + str, Any]]]] = None, ) -> Tuple[Optional[int], str]: """Get the in-progress managed jobs. @@ -1958,7 +1960,7 @@ def submit_enabled_clouds(): try: num_in_progress_jobs, msg = _handle_jobs_queue_request( managed_jobs_queue_request_id, - pool_status_request_id, + pool_status_request_id=pool_status_request_id, show_all=False, show_user=all_users, max_num_jobs_to_show=_NUM_MANAGED_JOBS_TO_SHOW_IN_STATUS, @@ -4776,7 +4778,7 @@ def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool, pass num_jobs, msg = _handle_jobs_queue_request( managed_jobs_request_id, - pool_status_request_id, + pool_status_request_id=pool_status_request_id, show_all=verbose, show_user=all_users, max_num_jobs_to_show=max_num_jobs_to_show, From ae857a4c9b4f42a97265a43de2ad922b2916a1bb Mon Sep 17 00:00:00 2001 From: lloyd-brown Date: Fri, 24 Oct 2025 12:29:33 -0700 Subject: [PATCH 5/5] Make fixes add test. --- sky/client/cli/command.py | 6 +- sky/jobs/utils.py | 2 +- tests/unit_tests/test_sky/test_cli_helpers.py | 86 ++++++++++++++++++- 3 files changed, 88 insertions(+), 6 deletions(-) diff --git a/sky/client/cli/command.py b/sky/client/cli/command.py index 18583ac64b9..4f27d685d09 100644 --- a/sky/client/cli/command.py +++ b/sky/client/cli/command.py @@ -1350,12 +1350,10 @@ def _handle_jobs_queue_request( show_all: bool, show_user: bool, max_num_jobs_to_show: Optional[int], - pool_status_request_id: Optional[server_common.RequestId[List[Dict[str, - Any]]]] = None, - is_called_by_user: bool = False, - only_in_progress: bool = False, pool_status_request_id: Optional[server_common.RequestId[List[Dict[ str, Any]]]] = None, + is_called_by_user: bool = False, + only_in_progress: bool = False, ) -> Tuple[Optional[int], str]: """Get the in-progress managed jobs. diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 21f291ea3bb..aa72f0d70ed 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -1714,8 +1714,8 @@ def format_job_table( tasks: List[Dict[str, Any]], show_all: bool, show_user: bool, - pool_status: Optional[List[Dict[str, Any]]] = None, return_rows: Literal[False] = False, + pool_status: Optional[List[Dict[str, Any]]] = None, max_jobs: Optional[int] = None, job_status_counts: Optional[Dict[str, int]] = None, ) -> str: diff --git a/tests/unit_tests/test_sky/test_cli_helpers.py b/tests/unit_tests/test_sky/test_cli_helpers.py index a65bc0225ff..ac160fc550a 100644 --- a/tests/unit_tests/test_sky/test_cli_helpers.py +++ b/tests/unit_tests/test_sky/test_cli_helpers.py @@ -3,7 +3,7 @@ This module contains tests for CLI helper functions in sky.client.cli.command. 
""" import traceback -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from unittest import mock import colorama @@ -75,6 +75,7 @@ def test_handle_jobs_queue_request_success_tuple_response(): mock_stream.assert_called_once_with(request_id) mock_format.assert_called_once_with( managed_jobs_list, + pool_status=None, show_all=False, show_user=False, max_jobs=10, @@ -128,6 +129,89 @@ def test_handle_jobs_queue_request_success_list_response(): mock_stream.assert_called_once_with(request_id) mock_format.assert_called_once_with( mock_job_records, + pool_status=None, + show_all=True, + show_user=True, + max_jobs=None, + status_counts=None, + ) + + +def test_handle_jobs_queue_request_success_list_response_with_pool_status(): + """Test _handle_jobs_queue_request with list response (legacy API).""" + # Create mock managed job records as dicts + mock_jobs = [ + { + 'job_id': 1, + 'job_name': 'test-job-1' + }, + { + 'job_id': 2, + 'job_name': 'test-job-2' + }, + { + 'job_id': 3, + 'job_name': 'test-job-3' + }, + ] + + # Mock job records using the model + mock_job_records = [responses.ManagedJobRecord(**job) for job in mock_jobs] + + # Mock pool status records using the model + mock_pool_statuses = [ + { + 'replica_info': [ + { + 'replica_id': 1, + 'used_by': 3, + }, + { + 'replica_id': 2, + 'used_by': 2, + }, + { + 'replica_id': 3, + 'used_by': 1, + }, + ], + }, + ] + + request_id = server_common.RequestId[List[responses.ManagedJobRecord]]( + 'test-request-id') + + pool_status_request_id = server_common.RequestId[List[Dict[str, Any]]]( + 'test-pool-status-request-id') + + with mock.patch.object(client_sdk, + 'stream_and_get', + side_effect=[mock_job_records, + mock_pool_statuses]) as mock_stream: + with mock.patch.object(usage_lib.messages.usage, 'set_internal'): + with mock.patch.object( + table_utils, 'format_job_table', + return_value='formatted table') as mock_format: + num_jobs, msg = command._handle_jobs_queue_request( + request_id=request_id, + show_all=True, + show_user=True, + max_num_jobs_to_show=None, + pool_status_request_id=pool_status_request_id, + is_called_by_user=True, + only_in_progress=False, + ) + + # Verify the result - should count unique job IDs + assert num_jobs == 3 + assert msg == 'formatted table' + mock_stream.assert_has_calls([ + mock.call(request_id), + mock.call(pool_status_request_id), + ]) + mock_format.assert_called_once_with( + mock_job_records, + pool_status=mock_pool_statuses, show_all=True, show_user=True, max_jobs=None,