@@ -400,7 +400,8 @@ def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
400400 for request_task in get_request_tasks (req_filter = RequestTaskFilter (
401401 status = [RequestStatus .PENDING , RequestStatus .RUNNING ],
402402 exclude_request_names = [exclude_request_name ],
403- cluster_names = [cluster_name ]))
403+ cluster_names = [cluster_name ],
404+ fields = ['request_id' ]))
404405 ]
405406 kill_requests (request_ids )
406407
@@ -418,26 +419,36 @@ def kill_requests(request_ids: Optional[List[str]] = None,
418419 Returns:
419420 A list of request IDs that were cancelled.
420421 """
422+ request_id_is_from_db = False
421423 if request_ids is None :
422424 request_ids = [
423425 request_task .request_id
424426 for request_task in get_request_tasks (req_filter = RequestTaskFilter (
425427 status = [RequestStatus .PENDING , RequestStatus .RUNNING ],
426428 # Avoid cancelling the cancel request itself.
427429 exclude_request_names = ['sky.api_cancel' ],
428- user_id = user_id ))
430+ user_id = user_id ,
431+ fields = ['request_id' ]))
429432 ]
430- cancelled_request_ids = []
433+ # Since we got these request IDs from the database, we can assume
434+ # they are exact matches.
435+ request_id_is_from_db = True
436+ internal_request_ids = set (
437+ event .id for event in daemons .INTERNAL_REQUEST_DAEMONS )
438+ request_ids_to_cancel = []
431439 for request_id in request_ids :
432- with update_request (request_id ) as request_record :
440+ if request_id in internal_request_ids :
441+ continue
442+ request_ids_to_cancel .append (request_id )
443+ del request_ids
444+ cancelled_request_ids = []
445+ for request_id in request_ids_to_cancel :
446+ with update_request (
447+ request_id ,
448+ exact_match = request_id_is_from_db ) as request_record :
433449 if request_record is None :
434450 logger .debug (f'No request ID { request_id } ' )
435451 continue
436- # Skip internal requests. The internal requests are scheduled with
437- # request_id in range(len(INTERNAL_REQUEST_EVENTS)).
438- if request_record .request_id in set (
439- event .id for event in daemons .INTERNAL_REQUEST_DAEMONS ):
440- continue
441452 if request_record .status > RequestStatus .RUNNING :
442453 logger .debug (f'Request { request_id } already finished' )
443454 continue
@@ -581,12 +592,14 @@ def request_lock_path(request_id: str) -> str:
581592@contextlib .contextmanager
582593@init_db
583594@metrics_lib .time_me
584- def update_request (request_id : str ) -> Generator [Optional [Request ], None , None ]:
595+ def update_request (
596+ request_id : str ,
597+ exact_match : bool = False ) -> Generator [Optional [Request ], None , None ]:
585598 """Get and update a SkyPilot API request."""
586599 # Acquire the lock to avoid race conditions between multiple request
587600 # operations, e.g. execute and cancel.
588601 with filelock .FileLock (request_lock_path (request_id )):
589- request = _get_request_no_lock (request_id )
602+ request = _get_request_no_lock (request_id , exact_match = exact_match )
590603 yield request
591604 if request is not None :
592605 _add_or_update_request_no_lock (request )
@@ -604,27 +617,39 @@ async def update_status_msg_async(request_id: str, status_msg: str) -> None:
604617 await _add_or_update_request_no_lock_async (request )
605618
606619
607- _get_request_sql = (f'SELECT { ", " .join (REQUEST_COLUMNS )} FROM { REQUEST_TABLE } '
608- 'WHERE request_id LIKE ?' )
620+ def _get_request_sql (exact_match : bool = False ) -> str :
621+ query = f'SELECT { ", " .join (REQUEST_COLUMNS )} FROM { REQUEST_TABLE } '
622+ if not exact_match :
623+ query += ' WHERE request_id LIKE ?'
624+ if exact_match :
625+ query += ' WHERE request_id = ?'
626+ return query
609627
610628
611- def _get_request_no_lock (request_id : str ) -> Optional [Request ]:
629+ def _get_request_no_lock (request_id : str ,
630+ exact_match : bool = False ) -> Optional [Request ]:
612631 """Get a SkyPilot API request."""
613632 assert _DB is not None
633+ if not exact_match :
634+ request_id = request_id + '%'
614635 with _DB .conn :
615636 cursor = _DB .conn .cursor ()
616- cursor .execute (_get_request_sql , (request_id + '%' ,))
637+ cursor .execute (_get_request_sql ( exact_match ) , (request_id ,))
617638 row = cursor .fetchone ()
618639 if row is None :
619640 return None
620641 return Request .from_row (row )
621642
622643
623- async def _get_request_no_lock_async (request_id : str ) -> Optional [Request ]:
644+ async def _get_request_no_lock_async (request_id : str ,
645+ exact_match : bool = False
646+ ) -> Optional [Request ]:
624647 """Async version of _get_request_no_lock."""
625648 assert _DB is not None
626- async with _DB .execute_fetchall_async (_get_request_sql ,
627- (request_id + '%' ,)) as rows :
649+ if not exact_match :
650+ request_id = request_id + '%'
651+ async with _DB .execute_fetchall_async (_get_request_sql (exact_match ),
652+ (request_id ,)) as rows :
628653 row = rows [0 ] if rows else None
629654 if row is None :
630655 return None
@@ -646,20 +671,23 @@ def get_latest_request_id() -> Optional[str]:
646671
647672@init_db
648673@metrics_lib .time_me
649- def get_request (request_id : str ) -> Optional [Request ]:
674+ def get_request (request_id : str ,
675+ exact_match : bool = False ) -> Optional [Request ]:
650676 """Get a SkyPilot API request."""
651677 with filelock .FileLock (request_lock_path (request_id )):
652- return _get_request_no_lock (request_id )
678+ return _get_request_no_lock (request_id , exact_match = exact_match )
653679
654680
655681@init_db_async
656682@metrics_lib .time_me_async
657683@asyncio_utils .shield
658- async def get_request_async (request_id : str ) -> Optional [Request ]:
684+ async def get_request_async (request_id : str ,
685+ exact_match : bool = False ) -> Optional [Request ]:
659686 """Async version of get_request."""
660687 # TODO(aylei): figure out how to remove FileLock here to avoid the overhead
661688 async with filelock .AsyncFileLock (request_lock_path (request_id )):
662- return await _get_request_no_lock_async (request_id )
689+ return await _get_request_no_lock_async (request_id ,
690+ exact_match = exact_match )
663691
664692
665693class StatusWithMsg (NamedTuple ):
@@ -672,12 +700,14 @@ class StatusWithMsg(NamedTuple):
672700async def get_request_status_async (
673701 request_id : str ,
674702 include_msg : bool = False ,
703+ exact_match : bool = False ,
675704) -> Optional [StatusWithMsg ]:
676705 """Get the status of a request.
677706
678707 Args:
679708 request_id: The ID of the request.
680709 include_msg: Whether to include the status message.
710+ exact_match: Whether to match the request ID exactly.
681711
682712 Returns:
683713 The status of the request. If the request is not found, returns
@@ -687,8 +717,13 @@ async def get_request_status_async(
687717 columns = 'status'
688718 if include_msg :
689719 columns += ', status_msg'
690- sql = f'SELECT { columns } FROM { REQUEST_TABLE } WHERE request_id LIKE ?'
691- async with _DB .execute_fetchall_async (sql , (request_id + '%' ,)) as rows :
720+ sql = f'SELECT { columns } FROM { REQUEST_TABLE } '
721+ if not exact_match :
722+ sql += ' WHERE request_id LIKE ?'
723+ request_id = request_id + '%'
724+ if exact_match :
725+ sql += ' WHERE request_id = ?'
726+ async with _DB .execute_fetchall_async (sql , (request_id ,)) as rows :
692727 if rows is None or len (rows ) == 0 :
693728 return None
694729 status = RequestStatus (rows [0 ][0 ])
@@ -701,7 +736,8 @@ async def get_request_status_async(
701736def create_if_not_exists (request : Request ) -> bool :
702737 """Create a SkyPilot API request if it does not exist."""
703738 with filelock .FileLock (request_lock_path (request .request_id )):
704- if _get_request_no_lock (request .request_id ) is not None :
739+ if _get_request_no_lock (request .request_id ,
740+ exact_match = True ) is not None :
705741 return False
706742 _add_or_update_request_no_lock (request )
707743 return True
@@ -713,7 +749,8 @@ def create_if_not_exists(request: Request) -> bool:
713749async def create_if_not_exists_async (request : Request ) -> bool :
714750 """Async version of create_if_not_exists."""
715751 async with filelock .AsyncFileLock (request_lock_path (request .request_id )):
716- if await _get_request_no_lock_async (request .request_id ) is not None :
752+ if await _get_request_no_lock_async (request .request_id ,
753+ exact_match = True ) is not None :
717754 return False
718755 await _add_or_update_request_no_lock_async (request )
719756 return True
@@ -748,6 +785,7 @@ class RequestTaskFilter:
748785 finished_before : Optional [float ] = None
749786 limit : Optional [int ] = None
750787 fields : Optional [List [str ]] = None
788+ sort : bool = False
751789
752790 def __post_init__ (self ):
753791 if (self .exclude_request_names is not None and
@@ -792,8 +830,11 @@ def build_query(self) -> Tuple[str, List[Any]]:
792830 columns_str = ', ' .join (REQUEST_COLUMNS )
793831 if self .fields :
794832 columns_str = ', ' .join (self .fields )
795- query_str = (f'SELECT { columns_str } FROM { REQUEST_TABLE } { filter_str } '
796- 'ORDER BY created_at DESC' )
833+ sort_str = ''
834+ if self .sort :
835+ sort_str = ' ORDER BY created_at DESC'
836+ query_str = (f'SELECT { columns_str } FROM { REQUEST_TABLE } { filter_str } '
837+ f'{ sort_str } ' )
797838 if self .limit is not None :
798839 query_str += f' LIMIT { self .limit } '
799840 return query_str , filter_params
@@ -915,9 +956,9 @@ def set_request_cancelled(request_id: str) -> None:
915956
916957@init_db
917958@metrics_lib .time_me
918- async def _delete_requests (requests : List [Request ]):
959+ async def _delete_requests (request_ids : List [str ]):
919960 """Clean up requests by their IDs."""
920- id_list_str = ',' .join (repr (req . request_id ) for req in requests )
961+ id_list_str = ',' .join (repr (request_id ) for request_id in request_ids )
921962 assert _DB is not None
922963 await _DB .execute_and_commit_async (
923964 f'DELETE FROM { REQUEST_TABLE } WHERE request_id IN ({ id_list_str } )' )
@@ -936,16 +977,19 @@ async def clean_finished_requests_with_retention(retention_seconds: int):
936977 reqs = await get_request_tasks_async (
937978 req_filter = RequestTaskFilter (status = RequestStatus .finished_status (),
938979 finished_before = time .time () -
939- retention_seconds ))
980+ retention_seconds ,
981+ fields = ['request_id' ]))
940982
941983 futs = []
942984 for req in reqs :
985+ # req.log_path is derived from request_id,
986+ # so it's ok to just grab the request_id in the above query.
943987 futs .append (
944988 asyncio .create_task (
945989 anyio .Path (req .log_path .absolute ()).unlink (missing_ok = True )))
946990 await asyncio .gather (* futs )
947991
948- await _delete_requests (reqs )
992+ await _delete_requests ([ req . request_id for req in reqs ] )
949993
950994 # To avoid leakage of the log file, logs must be deleted before the
951995 # request task in the database.
0 commit comments