@@ -400,7 +400,8 @@ def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
400400 for request_task in get_request_tasks (req_filter = RequestTaskFilter (
401401 status = [RequestStatus .PENDING , RequestStatus .RUNNING ],
402402 exclude_request_names = [exclude_request_name ],
403- cluster_names = [cluster_name ]))
403+ cluster_names = [cluster_name ],
404+ fields = ['request_id' ]))
404405 ]
405406 kill_requests (request_ids )
406407
@@ -425,7 +426,8 @@ def kill_requests(request_ids: Optional[List[str]] = None,
425426 status = [RequestStatus .PENDING , RequestStatus .RUNNING ],
426427 # Avoid cancelling the cancel request itself.
427428 exclude_request_names = ['sky.api_cancel' ],
428- user_id = user_id ))
429+ user_id = user_id ,
430+ fields = ['request_id' ]))
429431 ]
430432 cancelled_request_ids = []
431433 for request_id in request_ids :
@@ -604,30 +606,42 @@ async def update_status_msg_async(request_id: str, status_msg: str) -> None:
604606 await _add_or_update_request_no_lock_async (request )
605607
606608
607- _get_request_sql = (f'SELECT { ", " .join (REQUEST_COLUMNS )} FROM { REQUEST_TABLE } '
608- 'WHERE request_id LIKE ?' )
609-
610-
611- def _get_request_no_lock (request_id : str ) -> Optional [Request ]:
609+ def _get_request_no_lock (
610+ request_id : str ,
611+ fields : Optional [List [str ]] = None ) -> Optional [Request ]:
612612 """Get a SkyPilot API request."""
613613 assert _DB is not None
614+ columns_str = ', ' .join (REQUEST_COLUMNS )
615+ if fields :
616+ columns_str = ', ' .join (fields )
614617 with _DB .conn :
615618 cursor = _DB .conn .cursor ()
616- cursor .execute (_get_request_sql , (request_id + '%' ,))
619+ cursor .execute ((f'SELECT { columns_str } FROM { REQUEST_TABLE } '
620+ 'WHERE request_id LIKE ?' ), (request_id + '%' ,))
617621 row = cursor .fetchone ()
618622 if row is None :
619623 return None
624+ if fields :
625+ row = _update_request_row_fields (row , fields )
620626 return Request .from_row (row )
621627
622628
623- async def _get_request_no_lock_async (request_id : str ) -> Optional [Request ]:
629+ async def _get_request_no_lock_async (
630+ request_id : str ,
631+ fields : Optional [List [str ]] = None ) -> Optional [Request ]:
624632 """Async version of _get_request_no_lock."""
625633 assert _DB is not None
626- async with _DB .execute_fetchall_async (_get_request_sql ,
627- (request_id + '%' ,)) as rows :
634+ columns_str = ', ' .join (REQUEST_COLUMNS )
635+ if fields :
636+ columns_str = ', ' .join (fields )
637+ async with _DB .execute_fetchall_async (
638+ (f'SELECT { columns_str } FROM { REQUEST_TABLE } '
639+ 'WHERE request_id LIKE ?' ), (request_id + '%' ,)) as rows :
628640 row = rows [0 ] if rows else None
629641 if row is None :
630642 return None
643+ if fields :
644+ row = _update_request_row_fields (row , fields )
631645 return Request .from_row (row )
632646
633647
@@ -646,20 +660,23 @@ def get_latest_request_id() -> Optional[str]:
646660
647661@init_db
648662@metrics_lib .time_me
649- def get_request (request_id : str ) -> Optional [Request ]:
663+ def get_request (request_id : str ,
664+ fields : Optional [List [str ]] = None ) -> Optional [Request ]:
650665 """Get a SkyPilot API request."""
651666 with filelock .FileLock (request_lock_path (request_id )):
652- return _get_request_no_lock (request_id )
667+ return _get_request_no_lock (request_id , fields )
653668
654669
655670@init_db_async
656671@metrics_lib .time_me_async
657672@asyncio_utils .shield
658- async def get_request_async (request_id : str ) -> Optional [Request ]:
673+ async def get_request_async (
674+ request_id : str ,
675+ fields : Optional [List [str ]] = None ) -> Optional [Request ]:
659676 """Async version of get_request."""
660677 # TODO(aylei): figure out how to remove FileLock here to avoid the overhead
661678 async with filelock .AsyncFileLock (request_lock_path (request_id )):
662- return await _get_request_no_lock_async (request_id )
679+ return await _get_request_no_lock_async (request_id , fields )
663680
664681
665682class StatusWithMsg (NamedTuple ):
@@ -919,9 +936,9 @@ def set_request_cancelled(request_id: str) -> None:
919936
920937@init_db
921938@metrics_lib .time_me
922- async def _delete_requests (requests : List [Request ]):
939+ async def _delete_requests (request_ids : List [str ]):
923940 """Clean up requests by their IDs."""
924- id_list_str = ',' .join (repr (req . request_id ) for req in requests )
941+ id_list_str = ',' .join (repr (request_id ) for request_id in request_ids )
925942 assert _DB is not None
926943 await _DB .execute_and_commit_async (
927944 f'DELETE FROM { REQUEST_TABLE } WHERE request_id IN ({ id_list_str } )' )
@@ -949,18 +966,21 @@ async def clean_finished_requests_with_retention(retention_seconds: int,
949966 req_filter = RequestTaskFilter (status = RequestStatus .finished_status (),
950967 finished_before = time .time () -
951968 retention_seconds ,
952- limit = batch_size ))
969+ limit = batch_size ,
970+ fields = ['request_id' ]))
953971 if len (reqs ) == 0 :
954972 break
955973 futs = []
956974 for req in reqs :
975+ # req.log_path is derived from request_id,
976+ # so it's ok to just grab the request_id in the above query.
957977 futs .append (
958978 asyncio .create_task (
959979 anyio .Path (
960980 req .log_path .absolute ()).unlink (missing_ok = True )))
961981 await asyncio .gather (* futs )
962982
963- await _delete_requests (reqs )
983+ await _delete_requests ([ req . request_id for req in reqs ] )
964984 total_deleted += len (reqs )
965985 if len (reqs ) < batch_size :
966986 break
0 commit comments