Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3131,7 +3131,8 @@ def refresh_cluster_records() -> None:
requests = requests_lib.get_request_tasks(
req_filter=requests_lib.RequestTaskFilter(
status=[requests_lib.RequestStatus.RUNNING],
include_request_names=['sky.launch']))
include_request_names=['sky.launch'],
fields=['cluster_name']))
cluster_names_with_launch_request = {
request.cluster_name for request in requests
}
Expand Down Expand Up @@ -3360,7 +3361,8 @@ def _refresh_cluster_record(cluster_name):
req_filter=requests_lib.RequestTaskFilter(
status=[requests_lib.RequestStatus.RUNNING],
include_request_names=['sky.launch'],
cluster_names=cluster_names))
cluster_names=cluster_names,
fields=['cluster_name']))
cluster_names_with_launch_request = {
request.cluster_name for request in requests
}
Expand Down
2 changes: 1 addition & 1 deletion sky/server/requests/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def process_request(self, executor: process.BurstableExecutor,
time.sleep(0.1)
return
request_id, ignore_return_value, _ = request_element
request = api_requests.get_request(request_id)
request = api_requests.get_request(request_id, exact_match=True)
assert request is not None, f'Request with ID {request_id} is None'
if request.status == api_requests.RequestStatus.CANCELLED:
return
Expand Down
5 changes: 4 additions & 1 deletion sky/server/requests/preconditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ async def check(self) -> Tuple[bool, Optional[str]]:
api_requests.RequestStatus.RUNNING
],
include_request_names=['sky.launch', 'sky.start'],
cluster_names=[self.cluster_name]))
cluster_names=[self.cluster_name],
# Only get the request ID to avoid fetching the whole request.
# We're only interested in the count, not the whole request.
fields=['request_id']))
if len(requests) == 0:
# No running or pending tasks, the start process is done.
return True, None
Expand Down
98 changes: 68 additions & 30 deletions sky/server/requests/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,8 @@ def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
exclude_request_names=[exclude_request_name],
cluster_names=[cluster_name]))
cluster_names=[cluster_name],
fields=['request_id']))
]
kill_requests(request_ids)

Expand All @@ -418,26 +419,36 @@ def kill_requests(request_ids: Optional[List[str]] = None,
Returns:
A list of request IDs that were cancelled.
"""
request_id_is_from_db = False
if request_ids is None:
request_ids = [
request_task.request_id
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
# Avoid cancelling the cancel request itself.
exclude_request_names=['sky.api_cancel'],
user_id=user_id))
user_id=user_id,
fields=['request_id']))
]
cancelled_request_ids = []
# Since we got these request IDs from the database, we can assume
# they are exact matches.
request_id_is_from_db = True
internal_request_ids = set(
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS)
request_ids_to_cancel = []
for request_id in request_ids:
with update_request(request_id) as request_record:
if request_id in internal_request_ids:
continue
request_ids_to_cancel.append(request_id)
del request_ids
cancelled_request_ids = []
for request_id in request_ids_to_cancel:
with update_request(
request_id,
exact_match=request_id_is_from_db) as request_record:
if request_record is None:
logger.debug(f'No request ID {request_id}')
continue
# Skip internal requests. The internal requests are scheduled with
# request_id in range(len(INTERNAL_REQUEST_EVENTS)).
if request_record.request_id in set(
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
continue
if request_record.status > RequestStatus.RUNNING:
logger.debug(f'Request {request_id} already finished')
continue
Expand Down Expand Up @@ -581,12 +592,14 @@ def request_lock_path(request_id: str) -> str:
@contextlib.contextmanager
@init_db
@metrics_lib.time_me
def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
def update_request(
request_id: str,
exact_match: bool = False) -> Generator[Optional[Request], None, None]:
"""Get and update a SkyPilot API request."""
# Acquire the lock to avoid race conditions between multiple request
# operations, e.g. execute and cancel.
with filelock.FileLock(request_lock_path(request_id)):
request = _get_request_no_lock(request_id)
request = _get_request_no_lock(request_id, exact_match=exact_match)
yield request
if request is not None:
_add_or_update_request_no_lock(request)
Expand All @@ -604,27 +617,39 @@ async def update_status_msg_async(request_id: str, status_msg: str) -> None:
await _add_or_update_request_no_lock_async(request)


_get_request_sql = (f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE} '
'WHERE request_id LIKE ?')
def _get_request_sql(exact_match: bool = False) -> str:
query = f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE}'
if not exact_match:
query += ' WHERE request_id LIKE ?'
if exact_match:
query += ' WHERE request_id = ?'
return query


def _get_request_no_lock(request_id: str) -> Optional[Request]:
def _get_request_no_lock(request_id: str,
exact_match: bool = False) -> Optional[Request]:
"""Get a SkyPilot API request."""
assert _DB is not None
if not exact_match:
request_id = request_id + '%'
with _DB.conn:
cursor = _DB.conn.cursor()
cursor.execute(_get_request_sql, (request_id + '%',))
cursor.execute(_get_request_sql(exact_match), (request_id,))
row = cursor.fetchone()
if row is None:
return None
return Request.from_row(row)


async def _get_request_no_lock_async(request_id: str) -> Optional[Request]:
async def _get_request_no_lock_async(request_id: str,
exact_match: bool = False
) -> Optional[Request]:
"""Async version of _get_request_no_lock."""
assert _DB is not None
async with _DB.execute_fetchall_async(_get_request_sql,
(request_id + '%',)) as rows:
if not exact_match:
request_id = request_id + '%'
async with _DB.execute_fetchall_async(_get_request_sql(exact_match),
(request_id,)) as rows:
row = rows[0] if rows else None
if row is None:
return None
Expand All @@ -646,20 +671,23 @@ def get_latest_request_id() -> Optional[str]:

@init_db
@metrics_lib.time_me
def get_request(request_id: str) -> Optional[Request]:
def get_request(request_id: str,
exact_match: bool = False) -> Optional[Request]:
"""Get a SkyPilot API request."""
with filelock.FileLock(request_lock_path(request_id)):
return _get_request_no_lock(request_id)
return _get_request_no_lock(request_id, exact_match=exact_match)


@init_db_async
@metrics_lib.time_me_async
@asyncio_utils.shield
async def get_request_async(request_id: str) -> Optional[Request]:
async def get_request_async(request_id: str,
exact_match: bool = False) -> Optional[Request]:
"""Async version of get_request."""
# TODO(aylei): figure out how to remove FileLock here to avoid the overhead
async with filelock.AsyncFileLock(request_lock_path(request_id)):
return await _get_request_no_lock_async(request_id)
return await _get_request_no_lock_async(request_id,
exact_match=exact_match)


class StatusWithMsg(NamedTuple):
Expand All @@ -672,12 +700,14 @@ class StatusWithMsg(NamedTuple):
async def get_request_status_async(
request_id: str,
include_msg: bool = False,
exact_match: bool = False,
) -> Optional[StatusWithMsg]:
"""Get the status of a request.

Args:
request_id: The ID of the request.
include_msg: Whether to include the status message.
exact_match: Whether to match the request ID exactly.

Returns:
The status of the request. If the request is not found, returns
Expand All @@ -687,8 +717,13 @@ async def get_request_status_async(
columns = 'status'
if include_msg:
columns += ', status_msg'
sql = f'SELECT {columns} FROM {REQUEST_TABLE} WHERE request_id LIKE ?'
async with _DB.execute_fetchall_async(sql, (request_id + '%',)) as rows:
sql = f'SELECT {columns} FROM {REQUEST_TABLE}'
if not exact_match:
sql += ' WHERE request_id LIKE ?'
request_id = request_id + '%'
if exact_match:
sql += ' WHERE request_id = ?'
async with _DB.execute_fetchall_async(sql, (request_id,)) as rows:
if rows is None or len(rows) == 0:
return None
status = RequestStatus(rows[0][0])
Expand All @@ -701,7 +736,8 @@ async def get_request_status_async(
def create_if_not_exists(request: Request) -> bool:
"""Create a SkyPilot API request if it does not exist."""
with filelock.FileLock(request_lock_path(request.request_id)):
if _get_request_no_lock(request.request_id) is not None:
if _get_request_no_lock(request.request_id,
exact_match=True) is not None:
return False
_add_or_update_request_no_lock(request)
return True
Expand All @@ -713,7 +749,8 @@ def create_if_not_exists(request: Request) -> bool:
async def create_if_not_exists_async(request: Request) -> bool:
"""Async version of create_if_not_exists."""
async with filelock.AsyncFileLock(request_lock_path(request.request_id)):
if await _get_request_no_lock_async(request.request_id) is not None:
if await _get_request_no_lock_async(request.request_id,
exact_match=True) is not None:
return False
await _add_or_update_request_no_lock_async(request)
return True
Expand Down Expand Up @@ -919,9 +956,9 @@ def set_request_cancelled(request_id: str) -> None:

@init_db
@metrics_lib.time_me
async def _delete_requests(requests: List[Request]):
async def _delete_requests(request_ids: List[str]):
"""Clean up requests by their IDs."""
id_list_str = ','.join(repr(req.request_id) for req in requests)
id_list_str = ','.join(repr(request_id) for request_id in request_ids)
assert _DB is not None
await _DB.execute_and_commit_async(
f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})')
Expand Down Expand Up @@ -949,7 +986,8 @@ async def clean_finished_requests_with_retention(retention_seconds: int,
req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
finished_before=time.time() -
retention_seconds,
limit=batch_size))
limit=batch_size,
fields=['request_id']))
if len(reqs) == 0:
break
futs = []
Expand All @@ -960,7 +998,7 @@ async def clean_finished_requests_with_retention(retention_seconds: int,
req.log_path.absolute()).unlink(missing_ok=True)))
await asyncio.gather(*futs)

await _delete_requests(reqs)
await _delete_requests([req.request_id for req in reqs])
total_deleted += len(reqs)
if len(reqs) < batch_size:
break
Expand Down
16 changes: 9 additions & 7 deletions sky/server/stream_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ async def _tail_log_file(
polling_interval: float = DEFAULT_POLL_INTERVAL
) -> AsyncGenerator[str, None]:
"""Tail the opened log file, buffer the lines and flush in chunks."""

exact_request_id = None
if request_id is not None:
request_task = await requests_lib.get_request_async(request_id)
exact_request_id = request_task.request_id
if tail is not None:
# Find last n lines of the log file. Do not read the whole file into
# memory.
Expand Down Expand Up @@ -189,23 +192,22 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
# check the status so that we display the final request status
# if the request is complete.
should_check_status = True
if request_id is not None and should_check_status:
if exact_request_id is not None and should_check_status:
last_status_check_time = current_time
req_status = await requests_lib.get_request_status_async(
request_id)
exact_request_id, exact_match=True)
if req_status.status > requests_lib.RequestStatus.RUNNING:
if (req_status.status ==
requests_lib.RequestStatus.CANCELLED):
request_task = await requests_lib.get_request_async(
request_id)
exact_request_id, exact_match=True)
if request_task.should_retry:
buffer.append(
message_utils.encode_payload(
rich_utils.Control.RETRY.encode('')))
else:
buffer.append(
f'{request_task.name!r} request {request_id}'
' cancelled\n')
buffer.append(f'{request_task.name!r} request '
f'{exact_request_id} cancelled\n')
break
if not follow:
# The below checks (cluster status, heartbeat) are not needed
Expand Down
11 changes: 6 additions & 5 deletions sky/server/uvicorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@

# TODO(aylei): use decorator to register requests that need to be proactively
# cancelled instead of hardcoding here.
_RETRIABLE_REQUEST_NAMES = [
_RETRIABLE_REQUEST_NAMES = {
'sky.logs',
'sky.jobs.logs',
'sky.serve.logs',
]
}


def add_timestamp_prefix_for_server_logs() -> None:
Expand Down Expand Up @@ -152,16 +152,17 @@ def _wait_requests(self) -> None:
requests_lib.RequestStatus.RUNNING,
]
reqs = requests_lib.get_request_tasks(
req_filter=requests_lib.RequestTaskFilter(status=statuses))
req_filter=requests_lib.RequestTaskFilter(
status=statuses, fields=['request_id', 'name']))
if not reqs:
break
logger.info(f'{len(reqs)} on-going requests '
'found, waiting for them to finish...')
# Proactively cancel internal requests and logs requests since
# they can run for infinite time.
internal_request_ids = [
internal_request_ids = {
d.id for d in daemons.INTERNAL_REQUEST_DAEMONS
]
}
if time.time() - start_time > _WAIT_REQUESTS_TIMEOUT_SECONDS:
logger.warning('Timeout waiting for on-going requests to '
'finish, cancelling all on-going requests.')
Expand Down
4 changes: 2 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self):
async def mock_get_request(request_id):
return MockRequest()

async def mock_get_request_status(request_id):
async def mock_get_request_status(request_id, exact_match=False):
return requests_lib.StatusWithMsg(MockRequest().status,
MockRequest().status_msg)

Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(self):
async def mock_get_request(request_id):
return MockRequest()

async def mock_get_request_status(request_id):
async def mock_get_request_status(request_id, exact_match=False):
return requests_lib.StatusWithMsg(MockRequest().status,
MockRequest().status_msg)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_sky/server/requests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,7 @@ def test_update_request_row_fields_maintains_order():
@pytest.mark.asyncio
async def test_cancel_get_request_async():

async def mock_get_request_async_no_lock(id: str):
async def mock_get_request_async_no_lock(id: str, exact_match: bool = False):
await asyncio.sleep(1)
return None

Expand Down
Loading