Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions sky/client/cli/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,12 +1350,16 @@ def _handle_jobs_queue_request(
show_all: bool,
show_user: bool,
max_num_jobs_to_show: Optional[int],
pool_status_request_id: Optional[server_common.RequestId[List[Dict[
str, Any]]]] = None,
is_called_by_user: bool = False,
only_in_progress: bool = False,
) -> Tuple[Optional[int], str]:
"""Get the in-progress managed jobs.

Args:
request_id: The request ID for managed jobs.
pool_status_request_id: The request ID for pool status, or None.
show_all: Show all information of each job (e.g., region, price).
show_user: Show the user who submitted the job.
max_num_jobs_to_show: If not None, limit the number of jobs to show to
Expand All @@ -1375,6 +1379,7 @@ def _handle_jobs_queue_request(
num_in_progress_jobs = None
msg = ''
status_counts: Optional[Dict[str, int]] = None
pool_status_result = None
try:
if not is_called_by_user:
usage_lib.messages.usage.set_internal()
Expand All @@ -1395,6 +1400,13 @@ def _handle_jobs_queue_request(
managed_jobs_ = result
num_in_progress_jobs = len(
set(job['job_id'] for job in managed_jobs_))
# Try to get pool status if request was made
if pool_status_request_id is not None:
try:
pool_status_result = sdk.stream_and_get(pool_status_request_id)
except Exception: # pylint: disable=broad-except
# If getting pool status fails, just continue without it
pool_status_result = None
except exceptions.ClusterNotUpError as e:
controller_status = e.cluster_status
msg = str(e)
Expand Down Expand Up @@ -1440,6 +1452,7 @@ def _handle_jobs_queue_request(
else:
msg = table_utils.format_job_table(
managed_jobs_,
pool_status=pool_status_result,
show_all=show_all,
show_user=show_user,
max_jobs=max_num_jobs_to_show,
Expand Down Expand Up @@ -1945,6 +1958,7 @@ def submit_enabled_clouds():
try:
num_in_progress_jobs, msg = _handle_jobs_queue_request(
managed_jobs_queue_request_id,
pool_status_request_id=pool_status_request_id,
show_all=False,
show_user=all_users,
max_num_jobs_to_show=_NUM_MANAGED_JOBS_TO_SHOW_IN_STATUS,
Expand Down Expand Up @@ -4621,21 +4635,6 @@ def jobs_launch(

job_ids = [job_id_handle[0]] if isinstance(job_id_handle[0],
int) else job_id_handle[0]
if pool:
# Display the worker assignment for the jobs.
logger.debug(f'Getting service records for pool: {pool}')
records_request_id = managed_jobs.pool_status(pool_names=pool)
service_records = _async_call_or_wait(records_request_id, async_call,
'sky.jobs.pool_status')
logger.debug(f'Pool status: {service_records}')
replica_infos = service_records[0]['replica_info']
for replica_info in replica_infos:
job_id = replica_info.get('used_by', None)
if job_id in job_ids:
worker_id = replica_info['replica_id']
version = replica_info['version']
logger.info(f'Job ID: {job_id} assigned to pool {pool} '
f'(worker: {worker_id}, version: {version})')

if not detach_run:
if len(job_ids) == 1:
Expand Down Expand Up @@ -4768,8 +4767,16 @@ def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool,
all_users=all_users,
limit=max_num_jobs_to_show,
fields=fields)
# Try to get pool status for worker information
pool_status_request_id = None
try:
pool_status_request_id = managed_jobs.pool_status(pool_names=None)
except Exception: # pylint: disable=broad-except
# If pool_status fails, we'll just skip the worker information
pass
num_jobs, msg = _handle_jobs_queue_request(
managed_jobs_request_id,
pool_status_request_id=pool_status_request_id,
show_all=verbose,
show_user=all_users,
max_num_jobs_to_show=max_num_jobs_to_show,
Expand Down
4 changes: 3 additions & 1 deletion sky/client/cli/table_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Utilities for formatting tables for CLI output."""
import abc
from datetime import datetime
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import prettytable

Expand Down Expand Up @@ -92,12 +92,14 @@ def format_job_table(
jobs: List[responses.ManagedJobRecord],
show_all: bool,
show_user: bool,
pool_status: Optional[List[Dict[str, Any]]] = None,
max_jobs: Optional[int] = None,
status_counts: Optional[Dict[str, int]] = None,
):
jobs = [job.model_dump() for job in jobs]
return managed_jobs.format_job_table(
jobs,
pool_status=pool_status,
show_all=show_all,
show_user=show_user,
max_jobs=max_jobs,
Expand Down
5 changes: 4 additions & 1 deletion sky/jobs/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def check_version_mismatch_and_non_terminal_jobs() -> None:
if not version_matches and has_non_terminal_jobs:
# Format job table locally using the same method as queue()
formatted_job_table = managed_job_utils.format_job_table(
non_terminal_jobs, show_all=False, show_user=False)
non_terminal_jobs,
pool_status=None,
show_all=False,
show_user=False)

error_msg = (
f'Controller SKYLET_VERSION ({controller_version}) does not match '
Expand Down
39 changes: 39 additions & 0 deletions sky/jobs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,7 @@ def format_job_table(
show_all: bool,
show_user: bool,
return_rows: Literal[False] = False,
pool_status: Optional[List[Dict[str, Any]]] = None,
max_jobs: Optional[int] = None,
job_status_counts: Optional[Dict[str, int]] = None,
) -> str:
Expand All @@ -1727,6 +1728,7 @@ def format_job_table(
show_all: bool,
show_user: bool,
return_rows: Literal[True],
pool_status: Optional[List[Dict[str, Any]]] = None,
max_jobs: Optional[int] = None,
job_status_counts: Optional[Dict[str, int]] = None,
) -> List[List[str]]:
Expand All @@ -1738,6 +1740,7 @@ def format_job_table(
show_all: bool,
show_user: bool,
return_rows: bool = False,
pool_status: Optional[List[Dict[str, Any]]] = None,
max_jobs: Optional[int] = None,
job_status_counts: Optional[Dict[str, int]] = None,
) -> Union[str, List[List[str]]]:
Expand All @@ -1749,6 +1752,7 @@ def format_job_table(
max_jobs: The maximum number of jobs to show in the table.
return_rows: If True, return the rows as a list of strings instead of
all rows concatenated into a single string.
pool_status: List of pool status dictionaries with replica_info.
job_status_counts: The counts of each job status.

Returns: A formatted string of managed jobs, if not `return_rows`; otherwise
Expand All @@ -1766,6 +1770,30 @@ def get_hash(task):
return (task['user'], task['job_id'])
return task['job_id']

def _get_job_id_to_worker_map(
pool_status: Optional[List[Dict[str, Any]]]) -> Dict[int, int]:
"""Create a mapping from job_id to worker replica_id.

Args:
pool_status: List of pool status dictionaries with replica_info.

Returns:
Dictionary mapping job_id to replica_id (worker ID).
"""
job_to_worker: Dict[int, int] = {}
if pool_status is None:
return job_to_worker
for pool in pool_status:
replica_info = pool.get('replica_info', [])
for replica in replica_info:
used_by = replica.get('used_by')
if used_by is not None:
job_to_worker[used_by] = replica.get('replica_id')
return job_to_worker

# Create mapping from job_id to worker replica_id
job_to_worker = _get_job_id_to_worker_map(pool_status)

for task in tasks:
# The tasks within the same job_id are already sorted
# by the task_id.
Expand Down Expand Up @@ -1909,7 +1937,12 @@ def get_user_column_values(task: Dict[str, Any]) -> List[str]:
if pool is None:
pool = '-'

# Add worker information if job is assigned to a worker
job_id = job_hash[1] if tasks_have_k8s_user else job_hash
# job_id is now always an integer, use it to look up worker
if job_id in job_to_worker and pool != '-':
pool = f'{pool} (worker={job_to_worker[job_id]})'

job_values = [
job_id,
'',
Expand Down Expand Up @@ -1952,6 +1985,12 @@ def get_user_column_values(task: Dict[str, Any]) -> List[str]:
pool = task.get('pool')
if pool is None:
pool = '-'

# Add worker information if task is assigned to a worker
task_job_id = task['job_id']
if task_job_id in job_to_worker and pool != '-':
pool = f'{pool} (worker={job_to_worker[task_job_id]})'

values = [
task['job_id'] if len(job_tasks) == 1 else ' \u21B3',
task['task_id'] if len(job_tasks) > 1 else '-',
Expand Down
37 changes: 37 additions & 0 deletions tests/smoke_tests/test_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,3 +1010,40 @@ def test_pools_num_jobs(generic_cloud: str):
teardown=cancel_jobs_and_teardown_pool(pool_name, timeout=5),
)
smoke_tests_utils.run_one_test(test)


def test_pool_worker_assignment_in_queue(generic_cloud: str):
    """Test that sky jobs queue shows the worker assignment for running jobs.

    Launches a single-worker pool, submits one long-running job to it, waits
    for the job to reach RUNNING, then greps the `sky jobs queue` output for
    the `<pool_name> (worker=1)` annotation on the job's row.

    Args:
        generic_cloud: Cloud name used for the pool infra and to pick the
            per-cloud timeout.
    """
    timeout = smoke_tests_utils.get_timeout(generic_cloud)
    pool_config = basic_pool_conf(num_workers=1, infra=generic_cloud)

    job_name = f'{smoke_tests_utils.get_cluster_name()}-job'
    # 'sleep infinity' keeps the job in RUNNING so the worker assignment
    # stays visible in the queue output; teardown cancels the job.
    job_config = basic_job_conf(
        job_name=job_name,
        run_cmd='echo "Hello, world!"; sleep infinity',
    )
    with tempfile.NamedTemporaryFile(delete=True) as pool_yaml:
        with tempfile.NamedTemporaryFile(delete=True) as job_yaml:
            write_yaml(pool_yaml, pool_config)
            write_yaml(job_yaml, job_config)

            name = smoke_tests_utils.get_cluster_name()
            pool_name = f'{name}-pool'

            test = smoke_tests_utils.Test(
                'test_pool_worker_assignment_in_queue',
                [
                    _LAUNCH_POOL_AND_CHECK_SUCCESS.format(
                        pool_name=pool_name, pool_yaml=pool_yaml.name),
                    wait_until_pool_ready(pool_name, timeout=timeout),
                    _LAUNCH_JOB_AND_CHECK_SUCCESS.format(
                        pool_name=pool_name, job_yaml=job_yaml.name),
                    wait_until_job_status(job_name, ['RUNNING'],
                                          timeout=timeout),
                    # Check that the worker assignment is shown in the queue
                    # output; with a single-worker pool the job must be on
                    # worker 1.
                    f's=$(sky jobs queue); echo "$s"; echo; echo; echo "$s" | grep "{job_name}" | grep "{pool_name} (worker=1)"',
                ],
                timeout=timeout,
                teardown=cancel_jobs_and_teardown_pool(pool_name, timeout=5),
            )
            smoke_tests_utils.run_one_test(test)
86 changes: 85 additions & 1 deletion tests/unit_tests/test_sky/test_cli_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This module contains tests for CLI helper functions in sky.client.cli.command.
"""
import traceback
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest import mock

import colorama
Expand Down Expand Up @@ -75,6 +75,7 @@ def test_handle_jobs_queue_request_success_tuple_response():
mock_stream.assert_called_once_with(request_id)
mock_format.assert_called_once_with(
managed_jobs_list,
pool_status=None,
show_all=False,
show_user=False,
max_jobs=10,
Expand Down Expand Up @@ -128,6 +129,89 @@ def test_handle_jobs_queue_request_success_list_response():
mock_stream.assert_called_once_with(request_id)
mock_format.assert_called_once_with(
mock_job_records,
pool_status=None,
show_all=True,
show_user=True,
max_jobs=None,
status_counts=None,
)


def test_handle_jobs_queue_request_success_list_response_with_pool_status():
"""Test _handle_jobs_queue_request with list response (legacy API)."""
# Create mock managed job records as dicts
mock_jobs = [
{
'job_id': 1,
'job_name': 'test-job-1'
},
{
'job_id': 2,
'job_name': 'test-job-2'
},
{
'job_id': 3,
'job_name': 'test-job-3'
},
]

# Mock job records using the model
mock_job_records = [responses.ManagedJobRecord(**job) for job in mock_jobs]

# Mock pool status records using the model
mock_pool_statuses = [
{
'replica_info': [
{
'replica_id': 1,
'used_by': 3,
},
{
'replica_id': 2,
'used_by': 2,
},
{
'replica_id': 3,
'used_by': 1,
},
],
},
]

request_id = server_common.RequestId[List[responses.ManagedJobRecord]](
'test-request-id')

pool_status_request_id = server_common.RequestId[List[Dict[str, Any]]](
'test-pool-status-request-id')

with mock.patch.object(client_sdk,
'stream_and_get',
side_effect=[mock_job_records,
mock_pool_statuses]) as mock_stream:
with mock.patch.object(usage_lib.messages.usage, 'set_internal'):
with mock.patch.object(
table_utils, 'format_job_table',
return_value='formatted table') as mock_format:
num_jobs, msg = command._handle_jobs_queue_request(
request_id=request_id,
show_all=True,
show_user=True,
max_num_jobs_to_show=None,
pool_status_request_id=pool_status_request_id,
is_called_by_user=True,
only_in_progress=False,
)

# Verify the result - should count unique job IDs
assert num_jobs == 3
assert msg == 'formatted table'
mock_stream.assert_has_calls([
mock.call(request_id),
mock.call(pool_status_request_id),
])
mock_format.assert_called_once_with(
mock_job_records,
pool_status=mock_pool_statuses,
show_all=True,
show_user=True,
max_jobs=None,
Expand Down