diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 232c5de1b9f..aa75fdbd0c2 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -3131,7 +3131,8 @@ def refresh_cluster_records() -> None: requests = requests_lib.get_request_tasks( req_filter=requests_lib.RequestTaskFilter( status=[requests_lib.RequestStatus.RUNNING], - include_request_names=['sky.launch'])) + include_request_names=['sky.launch'], + fields=['cluster_name'])) cluster_names_with_launch_request = { request.cluster_name for request in requests } @@ -3360,7 +3361,8 @@ def _refresh_cluster_record(cluster_name): req_filter=requests_lib.RequestTaskFilter( status=[requests_lib.RequestStatus.RUNNING], include_request_names=['sky.launch'], - cluster_names=cluster_names)) + cluster_names=cluster_names, + fields=['cluster_name'])) cluster_names_with_launch_request = { request.cluster_name for request in requests } diff --git a/sky/server/requests/executor.py b/sky/server/requests/executor.py index 6c0402faf4a..aa3c8274579 100644 --- a/sky/server/requests/executor.py +++ b/sky/server/requests/executor.py @@ -214,7 +214,7 @@ def process_request(self, executor: process.BurstableExecutor, time.sleep(0.1) return request_id, ignore_return_value, _ = request_element - request = api_requests.get_request(request_id) + request = api_requests.get_request(request_id, exact_match=True) assert request is not None, f'Request with ID {request_id} is None' if request.status == api_requests.RequestStatus.CANCELLED: return diff --git a/sky/server/requests/preconditions.py b/sky/server/requests/preconditions.py index 0298d83f310..13a6a160173 100644 --- a/sky/server/requests/preconditions.py +++ b/sky/server/requests/preconditions.py @@ -166,7 +166,10 @@ async def check(self) -> Tuple[bool, Optional[str]]: api_requests.RequestStatus.RUNNING ], include_request_names=['sky.launch', 'sky.start'], - cluster_names=[self.cluster_name])) + cluster_names=[self.cluster_name], + # Only get the request ID to avoid fetching the whole request. + # We're only interested in the count, not the whole request. + fields=['request_id'])) if len(requests) == 0: # No running or pending tasks, the start process is done. return True, None diff --git a/sky/server/requests/requests.py b/sky/server/requests/requests.py index df86bef861d..4732f8589d6 100644 --- a/sky/server/requests/requests.py +++ b/sky/server/requests/requests.py @@ -400,7 +400,8 @@ def kill_cluster_requests(cluster_name: str, exclude_request_name: str): for request_task in get_request_tasks(req_filter=RequestTaskFilter( status=[RequestStatus.PENDING, RequestStatus.RUNNING], exclude_request_names=[exclude_request_name], - cluster_names=[cluster_name])) + cluster_names=[cluster_name], + fields=['request_id'])) ] kill_requests(request_ids) @@ -418,6 +419,7 @@ def kill_requests(request_ids: Optional[List[str]] = None, Returns: A list of request IDs that were cancelled. """ + request_id_is_from_db = False if request_ids is None: request_ids = [ request_task.request_id @@ -425,19 +427,28 @@ def kill_requests(request_ids: Optional[List[str]] = None, status=[RequestStatus.PENDING, RequestStatus.RUNNING], # Avoid cancelling the cancel request itself. exclude_request_names=['sky.api_cancel'], - user_id=user_id)) + user_id=user_id, + fields=['request_id'])) ] - cancelled_request_ids = [] + # Since we got these request IDs from the database, we can assume + # they are exact matches. + request_id_is_from_db = True + internal_request_ids = set( + event.id for event in daemons.INTERNAL_REQUEST_DAEMONS) + request_ids_to_cancel = [] for request_id in request_ids: - with update_request(request_id) as request_record: + if request_id in internal_request_ids: + continue + request_ids_to_cancel.append(request_id) + del request_ids + cancelled_request_ids = [] + for request_id in request_ids_to_cancel: + with update_request( + request_id, + exact_match=request_id_is_from_db) as request_record: if request_record is None: logger.debug(f'No request ID {request_id}') continue - # Skip internal requests. The internal requests are scheduled with - # request_id in range(len(INTERNAL_REQUEST_EVENTS)). - if request_record.request_id in set( - event.id for event in daemons.INTERNAL_REQUEST_DAEMONS): - continue if request_record.status > RequestStatus.RUNNING: logger.debug(f'Request {request_id} already finished') continue @@ -581,12 +592,14 @@ def request_lock_path(request_id: str) -> str: @contextlib.contextmanager @init_db @metrics_lib.time_me -def update_request(request_id: str) -> Generator[Optional[Request], None, None]: +def update_request( + request_id: str, + exact_match: bool = False) -> Generator[Optional[Request], None, None]: """Get and update a SkyPilot API request.""" # Acquire the lock to avoid race conditions between multiple request # operations, e.g. execute and cancel. with filelock.FileLock(request_lock_path(request_id)): - request = _get_request_no_lock(request_id) + request = _get_request_no_lock(request_id, exact_match=exact_match) yield request if request is not None: _add_or_update_request_no_lock(request) @@ -604,27 +617,39 @@ async def update_status_msg_async(request_id: str, status_msg: str) -> None: await _add_or_update_request_no_lock_async(request) -_get_request_sql = (f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE} ' - 'WHERE request_id LIKE ?') +def _get_request_sql(exact_match: bool = False) -> str: + query = f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE}' + if not exact_match: + query += ' WHERE request_id LIKE ?' + if exact_match: + query += ' WHERE request_id = ?' + return query -def _get_request_no_lock(request_id: str) -> Optional[Request]: +def _get_request_no_lock(request_id: str, + exact_match: bool = False) -> Optional[Request]: """Get a SkyPilot API request.""" assert _DB is not None + if not exact_match: + request_id = request_id + '%' with _DB.conn: cursor = _DB.conn.cursor() - cursor.execute(_get_request_sql, (request_id + '%',)) + cursor.execute(_get_request_sql(exact_match), (request_id,)) row = cursor.fetchone() if row is None: return None return Request.from_row(row) -async def _get_request_no_lock_async(request_id: str) -> Optional[Request]: +async def _get_request_no_lock_async(request_id: str, + exact_match: bool = False + ) -> Optional[Request]: """Async version of _get_request_no_lock.""" assert _DB is not None - async with _DB.execute_fetchall_async(_get_request_sql, - (request_id + '%',)) as rows: + if not exact_match: + request_id = request_id + '%' + async with _DB.execute_fetchall_async(_get_request_sql(exact_match), + (request_id,)) as rows: row = rows[0] if rows else None if row is None: return None @@ -646,20 +671,23 @@ def get_latest_request_id() -> Optional[str]: @init_db @metrics_lib.time_me -def get_request(request_id: str) -> Optional[Request]: +def get_request(request_id: str, + exact_match: bool = False) -> Optional[Request]: """Get a SkyPilot API request.""" with filelock.FileLock(request_lock_path(request_id)): - return _get_request_no_lock(request_id) + return _get_request_no_lock(request_id, exact_match=exact_match) @init_db_async @metrics_lib.time_me_async @asyncio_utils.shield -async def get_request_async(request_id: str) -> Optional[Request]: +async def get_request_async(request_id: str, + exact_match: bool = False) -> Optional[Request]: """Async version of get_request.""" # TODO(aylei): figure out how to remove FileLock here to avoid the overhead async with filelock.AsyncFileLock(request_lock_path(request_id)): - return await _get_request_no_lock_async(request_id) + return await _get_request_no_lock_async(request_id, + exact_match=exact_match) class StatusWithMsg(NamedTuple): @@ -672,12 +700,14 @@ class StatusWithMsg(NamedTuple): async def get_request_status_async( request_id: str, include_msg: bool = False, + exact_match: bool = False, ) -> Optional[StatusWithMsg]: """Get the status of a request. Args: request_id: The ID of the request. include_msg: Whether to include the status message. + exact_match: Whether to match the request ID exactly. Returns: The status of the request. If the request is not found, returns @@ -687,8 +717,13 @@ async def get_request_status_async( columns = 'status' if include_msg: columns += ', status_msg' - sql = f'SELECT {columns} FROM {REQUEST_TABLE} WHERE request_id LIKE ?' - async with _DB.execute_fetchall_async(sql, (request_id + '%',)) as rows: + sql = f'SELECT {columns} FROM {REQUEST_TABLE}' + if not exact_match: + sql += ' WHERE request_id LIKE ?' + request_id = request_id + '%' + if exact_match: + sql += ' WHERE request_id = ?' + async with _DB.execute_fetchall_async(sql, (request_id,)) as rows: if rows is None or len(rows) == 0: return None status = RequestStatus(rows[0][0]) @@ -701,7 +736,8 @@ async def get_request_status_async( def create_if_not_exists(request: Request) -> bool: """Create a SkyPilot API request if it does not exist.""" with filelock.FileLock(request_lock_path(request.request_id)): - if _get_request_no_lock(request.request_id) is not None: + if _get_request_no_lock(request.request_id, + exact_match=True) is not None: return False _add_or_update_request_no_lock(request) return True @@ -713,7 +749,8 @@ def create_if_not_exists(request: Request) -> bool: async def create_if_not_exists_async(request: Request) -> bool: """Async version of create_if_not_exists.""" async with filelock.AsyncFileLock(request_lock_path(request.request_id)): - if await _get_request_no_lock_async(request.request_id) is not None: + if await _get_request_no_lock_async(request.request_id, + exact_match=True) is not None: return False await _add_or_update_request_no_lock_async(request) return True @@ -919,9 +956,9 @@ def set_request_cancelled(request_id: str) -> None: @init_db @metrics_lib.time_me -async def _delete_requests(requests: List[Request]): +async def _delete_requests(request_ids: List[str]): """Clean up requests by their IDs.""" - id_list_str = ','.join(repr(req.request_id) for req in requests) + id_list_str = ','.join(repr(request_id) for request_id in request_ids) assert _DB is not None await _DB.execute_and_commit_async( f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})') @@ -949,7 +986,8 @@ async def clean_finished_requests_with_retention(retention_seconds: int, req_filter=RequestTaskFilter(status=RequestStatus.finished_status(), finished_before=time.time() - retention_seconds, - limit=batch_size)) + limit=batch_size, + fields=['request_id'])) if len(reqs) == 0: break futs = [] @@ -960,7 +998,7 @@ async def clean_finished_requests_with_retention(retention_seconds: int, req.log_path.absolute()).unlink(missing_ok=True))) await asyncio.gather(*futs) - await _delete_requests(reqs) + await _delete_requests([req.request_id for req in reqs]) total_deleted += len(reqs) if len(reqs) < batch_size: break diff --git a/sky/server/stream_utils.py b/sky/server/stream_utils.py index 6e8485fb625..9559e72c1fd 100644 --- a/sky/server/stream_utils.py +++ b/sky/server/stream_utils.py @@ -136,7 +136,10 @@ async def _tail_log_file( polling_interval: float = DEFAULT_POLL_INTERVAL ) -> AsyncGenerator[str, None]: """Tail the opened log file, buffer the lines and flush in chunks.""" - + exact_request_id = None + if request_id is not None: + request_task = await requests_lib.get_request_async(request_id) + exact_request_id = request_task.request_id if tail is not None: # Find last n lines of the log file. Do not read the whole file into # memory. @@ -189,23 +192,22 @@ async def flush_buffer() -> AsyncGenerator[str, None]: # check the status so that we display the final request status # if the request is complete. should_check_status = True - if request_id is not None and should_check_status: + if exact_request_id is not None and should_check_status: last_status_check_time = current_time req_status = await requests_lib.get_request_status_async( - request_id) + exact_request_id, exact_match=True) if req_status.status > requests_lib.RequestStatus.RUNNING: if (req_status.status == requests_lib.RequestStatus.CANCELLED): request_task = await requests_lib.get_request_async( - request_id) + exact_request_id, exact_match=True) if request_task.should_retry: buffer.append( message_utils.encode_payload( rich_utils.Control.RETRY.encode(''))) else: - buffer.append( - f'{request_task.name!r} request {request_id}' - ' cancelled\n') + buffer.append(f'{request_task.name!r} request ' + f'{exact_request_id} cancelled\n') break if not follow: # The below checks (cluster status, heartbeat) are not needed diff --git a/sky/server/uvicorn.py b/sky/server/uvicorn.py index 3c447867f77..74912ede063 100644 --- a/sky/server/uvicorn.py +++ b/sky/server/uvicorn.py @@ -46,11 +46,11 @@ # TODO(aylei): use decorator to register requests that need to be proactively # cancelled instead of hardcoding here. -_RETRIABLE_REQUEST_NAMES = [ +_RETRIABLE_REQUEST_NAMES = { 'sky.logs', 'sky.jobs.logs', 'sky.serve.logs', -] +} def add_timestamp_prefix_for_server_logs() -> None: @@ -152,16 +152,17 @@ def _wait_requests(self) -> None: requests_lib.RequestStatus.RUNNING, ] reqs = requests_lib.get_request_tasks( - req_filter=requests_lib.RequestTaskFilter(status=statuses)) + req_filter=requests_lib.RequestTaskFilter( + status=statuses, fields=['request_id', 'name'])) if not reqs: break logger.info(f'{len(reqs)} on-going requests ' 'found, waiting for them to finish...') # Proactively cancel internal requests and logs requests since # they can run for infinite time. - internal_request_ids = [ + internal_request_ids = { d.id for d in daemons.INTERNAL_REQUEST_DAEMONS - ] + } if time.time() - start_time > _WAIT_REQUESTS_TIMEOUT_SECONDS: logger.warning('Timeout waiting for on-going requests to ' 'finish, cancelling all on-going requests.') diff --git a/tests/test_api.py b/tests/test_api.py index d3bdc81657e..563e599bb9a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -61,7 +61,7 @@ def __init__(self): async def mock_get_request(request_id): return MockRequest() - async def mock_get_request_status(request_id): + async def mock_get_request_status(request_id, exact_match=False): return requests_lib.StatusWithMsg(MockRequest().status, MockRequest().status_msg) @@ -153,7 +153,7 @@ def __init__(self): async def mock_get_request(request_id): return MockRequest() - async def mock_get_request_status(request_id): + async def mock_get_request_status(request_id, exact_match=False): return requests_lib.StatusWithMsg(MockRequest().status, MockRequest().status_msg) diff --git a/tests/unit_tests/test_sky/server/requests/test_requests.py b/tests/unit_tests/test_sky/server/requests/test_requests.py index 12852e1cded..96dd2dab819 100644 --- a/tests/unit_tests/test_sky/server/requests/test_requests.py +++ b/tests/unit_tests/test_sky/server/requests/test_requests.py @@ -1404,7 +1404,7 @@ def test_update_request_row_fields_maintains_order(): @pytest.mark.asyncio async def test_cancel_get_request_async(): - async def mock_get_request_async_no_lock(id: str): + async def mock_get_request_async_no_lock(id: str, exact_match: bool = False): await asyncio.sleep(1) return None