Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions sky/jobs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@
# update the state.
_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 120

# After enabling consolidation mode, we need to restart the API server to get
# the jobs refresh daemon and the correct number of executors. We use this file to
# indicate that the API server has been restarted after enabling consolidation
# mode.
_JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE = (
'~/.sky/.jobs_controller_consolidation_reloaded_signal')


class ManagedJobQueueResultType(enum.Enum):
"""The type of the managed job queue result."""
Expand Down Expand Up @@ -203,13 +210,39 @@ def _validate_consolidation_mode_config(
# API Server. Under the hood, we submit the job monitoring logic as processes
# directly in the API Server.
# Use LRU Cache so that the check is only done once.
@annotations.lru_cache(scope='request', maxsize=1)
def is_consolidation_mode() -> bool:
@annotations.lru_cache(scope='request', maxsize=2)
def is_consolidation_mode(on_api_restart: bool = False) -> bool:
if os.environ.get(constants.OVERRIDE_CONSOLIDATION_MODE) is not None:
return True

consolidation_mode = skypilot_config.get_nested(
config_consolidation_mode = skypilot_config.get_nested(
('jobs', 'controller', 'consolidation_mode'), default_value=False)

signal_file = pathlib.Path(
_JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE).expanduser()

restart_signal_file_exists = signal_file.exists()
consolidation_mode = (config_consolidation_mode and
restart_signal_file_exists)

if on_api_restart:
if config_consolidation_mode:
signal_file.touch()
else:
if not restart_signal_file_exists:
if config_consolidation_mode:
logger.warning(f'{colorama.Fore.YELLOW}Consolidation mode for '
'managed jobs is enabled in the server config, '
'but the API server has not been restarted yet. '
'Please restart the API server to enable it.'
f'{colorama.Style.RESET_ALL}')
return False
elif not config_consolidation_mode:
# Cleanup the signal file if the consolidation mode is disabled in
# the config. This allows the user to disable the consolidation mode
# without restarting the API server.
signal_file.unlink()

# We should only do this check on the API server, as the controller will not
# have the related config and will always appear to have consolidation mode
# disabled. Check #6611 for more details.
Expand Down
6 changes: 4 additions & 2 deletions sky/server/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,8 @@ def _start_api_server(deploy: bool = False,
# pylint: disable=import-outside-toplevel
import sky.jobs.utils as job_utils
max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
if job_utils.is_consolidation_mode() else
server_constants.MIN_AVAIL_MEM_GB)
if job_utils.is_consolidation_mode(on_api_restart=True)
else server_constants.MIN_AVAIL_MEM_GB)
if avail_mem_size_gb <= max_memory:
logger.warning(
f'{colorama.Fore.YELLOW}Your SkyPilot API server machine only '
Expand All @@ -571,6 +571,8 @@ def _start_api_server(deploy: bool = False,
args += [f'--host={host}']
if metrics_port is not None:
args += [f'--metrics-port={metrics_port}']
# Use this argument to disable the internal signal file check.
args += ['--start-with-python']

if foreground:
# Replaces the current process with the API server
Expand Down
5 changes: 5 additions & 0 deletions sky/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,7 @@ def apply_user_hash(user_hash: str) -> None:
# Serve metrics on a separate port to isolate it from the application APIs:
# metrics port will not be exposed to the public network typically.
parser.add_argument('--metrics-port', default=9090, type=int)
parser.add_argument('--start-with-python', action='store_true')
cmd_args = parser.parse_args()
if cmd_args.port == cmd_args.metrics_port:
logger.error('port and metrics-port cannot be the same, exiting.')
Expand All @@ -1982,6 +1983,10 @@ def apply_user_hash(user_hash: str) -> None:
logger.error(f'Port {cmd_args.port} is not available, exiting.')
raise RuntimeError(f'Port {cmd_args.port} is not available')

if not cmd_args.start_with_python:
# Maybe touch the signal file on API server startup.
managed_job_utils.is_consolidation_mode(on_api_restart=True)

# Show the privacy policy if it is not already shown. We place it here so
# that it is shown only when the API server is started.
usage_lib.maybe_show_privacy_policy()
Expand Down
15 changes: 11 additions & 4 deletions sky/skylet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,17 @@
]
# When overriding the SkyPilot configs on the API server with the client one,
# we skip the following keys because they are meant to be client-side configs.
SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [('api_server',),
('allowed_clouds',),
('workspaces',), ('db',),
('daemons',)]
# Also, we skip the consolidation mode config as those should be only set on
# the API server side.
# Each entry is a tuple of nested config keys (a key path) that is skipped when
# the client config overrides the API server config.
SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [
('api_server',),
('allowed_clouds',),
('workspaces',),
('db',),
('daemons',),
# Parent of the ('jobs', 'controller', 'consolidation_mode') setting, which
# should only be set on the API server side.
('jobs', 'controller'),
# NOTE(review): presumably the serve counterpart of the jobs entry above
# (its consolidation mode config) -- confirm against the serve controller.
('serve', 'controller'),
]

# Constants for Azure blob storage
WAIT_FOR_STORAGE_ACCOUNT_CREATION = 60
Expand Down
39 changes: 39 additions & 0 deletions tests/unit_tests/test_jobs_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import pathlib
import tempfile
import time
from unittest import mock

Expand Down Expand Up @@ -88,3 +90,40 @@ def slow_get_job_status(*args, **kwargs):
first_call = mock_logger.info.call_args_list[0][0][0]
assert 'Failed to get job status:' in first_call
assert 'timed out after 30s' in first_call


@mock.patch('sky.jobs.utils.logger')
@mock.patch('sky.jobs.utils.skypilot_config')
def test_consolidation_mode_warning_without_restart(mock_config, mock_logger):
    """Warn and report "disabled" when the restart signal file is missing.

    Consolidation mode is switched on in the (mocked) config, but the signal
    file that marks an API server restart does not exist. In that state
    ``is_consolidation_mode()`` must return ``False``, log exactly one
    warning asking for a restart, and must not create the signal file.

    Args:
        mock_config: Mock replacing ``sky.jobs.utils.skypilot_config``.
        mock_logger: Mock replacing ``sky.jobs.utils.logger``.
    """
    # Drop any cached result so the mocked config is actually consulted.
    utils.is_consolidation_mode.cache_clear()

    # Every config lookup reports consolidation mode as enabled.
    mock_config.get_nested.return_value = True

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # A path inside a freshly created directory cannot exist yet, so
            # no explicit removal is needed before the call.
            signal_file = pathlib.Path(tmpdir) / 'consolidation_signal'

            # Point the module at the temporary signal file location.
            with mock.patch(
                    'sky.jobs.utils._JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE',
                    str(signal_file)):
                result = utils.is_consolidation_mode()

            # Only an API server restart may create the signal file; a plain
            # query must leave it absent.
            assert not signal_file.exists()

        # Should return False because the signal file doesn't exist.
        assert result is False

        # Verify that exactly one warning was logged, with the expected text.
        assert mock_logger.warning.call_count == 1
        warning_message = mock_logger.warning.call_args[0][0]
        assert ('Consolidation mode for managed jobs is enabled'
                in warning_message)
        assert 'API server has not been restarted yet' in warning_message
        assert 'Please restart the API server to enable it' in warning_message
    finally:
        # The result above was computed under mocks; do not let it leak into
        # other tests through the function's cache.
        utils.is_consolidation_mode.cache_clear()