Skip to content

Commit bc520c8

Browse files
committed
optimize delete requests
optimization2 consolidate funcs more consolidation more optimizations add index fix unit tests address TODO cancel optimizations 1 revert count testfix
1 parent 33025aa commit bc520c8

File tree

9 files changed

+126
-73
lines changed

9 files changed

+126
-73
lines changed

sky/backends/backend_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3131,7 +3131,8 @@ def refresh_cluster_records() -> None:
31313131
requests = requests_lib.get_request_tasks(
31323132
req_filter=requests_lib.RequestTaskFilter(
31333133
status=[requests_lib.RequestStatus.RUNNING],
3134-
include_request_names=['sky.launch']))
3134+
include_request_names=['sky.launch'],
3135+
fields=['cluster_name']))
31353136
cluster_names_with_launch_request = {
31363137
request.cluster_name for request in requests
31373138
}
@@ -3360,7 +3361,8 @@ def _refresh_cluster_record(cluster_name):
33603361
req_filter=requests_lib.RequestTaskFilter(
33613362
status=[requests_lib.RequestStatus.RUNNING],
33623363
include_request_names=['sky.launch'],
3363-
cluster_names=cluster_names))
3364+
cluster_names=cluster_names,
3365+
fields=['cluster_name']))
33643366
cluster_names_with_launch_request = {
33653367
request.cluster_name for request in requests
33663368
}

sky/server/requests/executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def process_request(self, executor: process.BurstableExecutor,
214214
time.sleep(0.1)
215215
return
216216
request_id, ignore_return_value, _ = request_element
217-
request = api_requests.get_request(request_id)
217+
request = api_requests.get_request(request_id, exact_match=True)
218218
assert request is not None, f'Request with ID {request_id} is None'
219219
if request.status == api_requests.RequestStatus.CANCELLED:
220220
return

sky/server/requests/preconditions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,10 @@ async def check(self) -> Tuple[bool, Optional[str]]:
166166
api_requests.RequestStatus.RUNNING
167167
],
168168
include_request_names=['sky.launch', 'sky.start'],
169-
cluster_names=[self.cluster_name]))
169+
cluster_names=[self.cluster_name],
170+
# Only get the request ID to avoid fetching the whole request.
171+
# We're only interested in the count, not the whole request.
172+
fields=['request_id']))
170173
if len(requests) == 0:
171174
# No running or pending tasks, the start process is done.
172175
return True, None

sky/server/requests/requests.py

Lines changed: 76 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,8 @@ def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
400400
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
401401
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
402402
exclude_request_names=[exclude_request_name],
403-
cluster_names=[cluster_name]))
403+
cluster_names=[cluster_name],
404+
fields=['request_id']))
404405
]
405406
kill_requests(request_ids)
406407

@@ -418,26 +419,36 @@ def kill_requests(request_ids: Optional[List[str]] = None,
418419
Returns:
419420
A list of request IDs that were cancelled.
420421
"""
422+
request_id_is_from_db = False
421423
if request_ids is None:
422424
request_ids = [
423425
request_task.request_id
424426
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
425427
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
426428
# Avoid cancelling the cancel request itself.
427429
exclude_request_names=['sky.api_cancel'],
428-
user_id=user_id))
430+
user_id=user_id,
431+
fields=['request_id']))
429432
]
430-
cancelled_request_ids = []
433+
# Since we got these request IDs from the database, we can assume
434+
# they are exact matches.
435+
request_id_is_from_db = True
436+
internal_request_ids = set(
437+
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS)
438+
request_ids_to_cancel = []
431439
for request_id in request_ids:
432-
with update_request(request_id) as request_record:
440+
if request_id in internal_request_ids:
441+
continue
442+
request_ids_to_cancel.append(request_id)
443+
del request_ids
444+
cancelled_request_ids = []
445+
for request_id in request_ids_to_cancel:
446+
with update_request(
447+
request_id,
448+
exact_match=request_id_is_from_db) as request_record:
433449
if request_record is None:
434450
logger.debug(f'No request ID {request_id}')
435451
continue
436-
# Skip internal requests. The internal requests are scheduled with
437-
# request_id in range(len(INTERNAL_REQUEST_EVENTS)).
438-
if request_record.request_id in set(
439-
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
440-
continue
441452
if request_record.status > RequestStatus.RUNNING:
442453
logger.debug(f'Request {request_id} already finished')
443454
continue
@@ -581,12 +592,14 @@ def request_lock_path(request_id: str) -> str:
581592
@contextlib.contextmanager
582593
@init_db
583594
@metrics_lib.time_me
584-
def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
595+
def update_request(
596+
request_id: str,
597+
exact_match: bool = False) -> Generator[Optional[Request], None, None]:
585598
"""Get and update a SkyPilot API request."""
586599
# Acquire the lock to avoid race conditions between multiple request
587600
# operations, e.g. execute and cancel.
588601
with filelock.FileLock(request_lock_path(request_id)):
589-
request = _get_request_no_lock(request_id)
602+
request = _get_request_no_lock(request_id, exact_match=exact_match)
590603
yield request
591604
if request is not None:
592605
_add_or_update_request_no_lock(request)
@@ -604,27 +617,39 @@ async def update_status_msg_async(request_id: str, status_msg: str) -> None:
604617
await _add_or_update_request_no_lock_async(request)
605618

606619

607-
_get_request_sql = (f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE} '
608-
'WHERE request_id LIKE ?')
620+
def _get_request_sql(exact_match: bool = False) -> str:
621+
query = f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE}'
622+
if not exact_match:
623+
query += ' WHERE request_id LIKE ?'
624+
if exact_match:
625+
query += ' WHERE request_id = ?'
626+
return query
609627

610628

611-
def _get_request_no_lock(request_id: str) -> Optional[Request]:
629+
def _get_request_no_lock(request_id: str,
630+
exact_match: bool = False) -> Optional[Request]:
612631
"""Get a SkyPilot API request."""
613632
assert _DB is not None
633+
if not exact_match:
634+
request_id = request_id + '%'
614635
with _DB.conn:
615636
cursor = _DB.conn.cursor()
616-
cursor.execute(_get_request_sql, (request_id + '%',))
637+
cursor.execute(_get_request_sql(exact_match), (request_id,))
617638
row = cursor.fetchone()
618639
if row is None:
619640
return None
620641
return Request.from_row(row)
621642

622643

623-
async def _get_request_no_lock_async(request_id: str) -> Optional[Request]:
644+
async def _get_request_no_lock_async(request_id: str,
645+
exact_match: bool = False
646+
) -> Optional[Request]:
624647
"""Async version of _get_request_no_lock."""
625648
assert _DB is not None
626-
async with _DB.execute_fetchall_async(_get_request_sql,
627-
(request_id + '%',)) as rows:
649+
if not exact_match:
650+
request_id = request_id + '%'
651+
async with _DB.execute_fetchall_async(_get_request_sql(exact_match),
652+
(request_id,)) as rows:
628653
row = rows[0] if rows else None
629654
if row is None:
630655
return None
@@ -646,20 +671,23 @@ def get_latest_request_id() -> Optional[str]:
646671

647672
@init_db
648673
@metrics_lib.time_me
649-
def get_request(request_id: str) -> Optional[Request]:
674+
def get_request(request_id: str,
675+
exact_match: bool = False) -> Optional[Request]:
650676
"""Get a SkyPilot API request."""
651677
with filelock.FileLock(request_lock_path(request_id)):
652-
return _get_request_no_lock(request_id)
678+
return _get_request_no_lock(request_id, exact_match=exact_match)
653679

654680

655681
@init_db_async
656682
@metrics_lib.time_me_async
657683
@asyncio_utils.shield
658-
async def get_request_async(request_id: str) -> Optional[Request]:
684+
async def get_request_async(request_id: str,
685+
exact_match: bool = False) -> Optional[Request]:
659686
"""Async version of get_request."""
660687
# TODO(aylei): figure out how to remove FileLock here to avoid the overhead
661688
async with filelock.AsyncFileLock(request_lock_path(request_id)):
662-
return await _get_request_no_lock_async(request_id)
689+
return await _get_request_no_lock_async(request_id,
690+
exact_match=exact_match)
663691

664692

665693
class StatusWithMsg(NamedTuple):
@@ -672,12 +700,14 @@ class StatusWithMsg(NamedTuple):
672700
async def get_request_status_async(
673701
request_id: str,
674702
include_msg: bool = False,
703+
exact_match: bool = False,
675704
) -> Optional[StatusWithMsg]:
676705
"""Get the status of a request.
677706
678707
Args:
679708
request_id: The ID of the request.
680709
include_msg: Whether to include the status message.
710+
exact_match: Whether to match the request ID exactly.
681711
682712
Returns:
683713
The status of the request. If the request is not found, returns
@@ -687,8 +717,13 @@ async def get_request_status_async(
687717
columns = 'status'
688718
if include_msg:
689719
columns += ', status_msg'
690-
sql = f'SELECT {columns} FROM {REQUEST_TABLE} WHERE request_id LIKE ?'
691-
async with _DB.execute_fetchall_async(sql, (request_id + '%',)) as rows:
720+
sql = f'SELECT {columns} FROM {REQUEST_TABLE}'
721+
if not exact_match:
722+
sql += ' WHERE request_id LIKE ?'
723+
request_id = request_id + '%'
724+
if exact_match:
725+
sql += ' WHERE request_id = ?'
726+
async with _DB.execute_fetchall_async(sql, (request_id,)) as rows:
692727
if rows is None or len(rows) == 0:
693728
return None
694729
status = RequestStatus(rows[0][0])
@@ -701,7 +736,8 @@ async def get_request_status_async(
701736
def create_if_not_exists(request: Request) -> bool:
702737
"""Create a SkyPilot API request if it does not exist."""
703738
with filelock.FileLock(request_lock_path(request.request_id)):
704-
if _get_request_no_lock(request.request_id) is not None:
739+
if _get_request_no_lock(request.request_id,
740+
exact_match=True) is not None:
705741
return False
706742
_add_or_update_request_no_lock(request)
707743
return True
@@ -713,7 +749,8 @@ def create_if_not_exists(request: Request) -> bool:
713749
async def create_if_not_exists_async(request: Request) -> bool:
714750
"""Async version of create_if_not_exists."""
715751
async with filelock.AsyncFileLock(request_lock_path(request.request_id)):
716-
if await _get_request_no_lock_async(request.request_id) is not None:
752+
if await _get_request_no_lock_async(request.request_id,
753+
exact_match=True) is not None:
717754
return False
718755
await _add_or_update_request_no_lock_async(request)
719756
return True
@@ -748,6 +785,7 @@ class RequestTaskFilter:
748785
finished_before: Optional[float] = None
749786
limit: Optional[int] = None
750787
fields: Optional[List[str]] = None
788+
sort: bool = False
751789

752790
def __post_init__(self):
753791
if (self.exclude_request_names is not None and
@@ -792,8 +830,11 @@ def build_query(self) -> Tuple[str, List[Any]]:
792830
columns_str = ', '.join(REQUEST_COLUMNS)
793831
if self.fields:
794832
columns_str = ', '.join(self.fields)
795-
query_str = (f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str} '
796-
'ORDER BY created_at DESC')
833+
sort_str = ''
834+
if self.sort:
835+
sort_str = ' ORDER BY created_at DESC'
836+
query_str = (f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str}'
837+
f'{sort_str}')
797838
if self.limit is not None:
798839
query_str += f' LIMIT {self.limit}'
799840
return query_str, filter_params
@@ -915,9 +956,9 @@ def set_request_cancelled(request_id: str) -> None:
915956

916957
@init_db
917958
@metrics_lib.time_me
918-
async def _delete_requests(requests: List[Request]):
959+
async def _delete_requests(request_ids: List[str]):
919960
"""Clean up requests by their IDs."""
920-
id_list_str = ','.join(repr(req.request_id) for req in requests)
961+
id_list_str = ','.join(repr(request_id) for request_id in request_ids)
921962
assert _DB is not None
922963
await _DB.execute_and_commit_async(
923964
f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})')
@@ -936,16 +977,19 @@ async def clean_finished_requests_with_retention(retention_seconds: int):
936977
reqs = await get_request_tasks_async(
937978
req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
938979
finished_before=time.time() -
939-
retention_seconds))
980+
retention_seconds,
981+
fields=['request_id']))
940982

941983
futs = []
942984
for req in reqs:
985+
# req.log_path is derived from request_id,
986+
# so it's ok to just grab the request_id in the above query.
943987
futs.append(
944988
asyncio.create_task(
945989
anyio.Path(req.log_path.absolute()).unlink(missing_ok=True)))
946990
await asyncio.gather(*futs)
947991

948-
await _delete_requests(reqs)
992+
await _delete_requests([req.request_id for req in reqs])
949993

950994
# To avoid leakage of the log file, logs must be deleted before the
951995
# request task in the database.

sky/server/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,6 +1672,7 @@ async def api_status(
16721672
status=statuses,
16731673
limit=limit,
16741674
fields=fields,
1675+
sort=True,
16751676
))
16761677
return requests_lib.encode_requests(request_tasks)
16771678
else:

sky/server/stream_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ async def _tail_log_file(
136136
polling_interval: float = DEFAULT_POLL_INTERVAL
137137
) -> AsyncGenerator[str, None]:
138138
"""Tail the opened log file, buffer the lines and flush in chunks."""
139-
139+
exact_request_id = None
140+
if request_id is not None:
141+
request_task = await requests_lib.get_request_async(request_id)
142+
exact_request_id = request_task.request_id
140143
if tail is not None:
141144
# Find last n lines of the log file. Do not read the whole file into
142145
# memory.
@@ -189,23 +192,22 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
189192
# check the status so that we display the final request status
190193
# if the request is complete.
191194
should_check_status = True
192-
if request_id is not None and should_check_status:
195+
if exact_request_id is not None and should_check_status:
193196
last_status_check_time = current_time
194197
req_status = await requests_lib.get_request_status_async(
195-
request_id)
198+
exact_request_id, exact_match=True)
196199
if req_status.status > requests_lib.RequestStatus.RUNNING:
197200
if (req_status.status ==
198201
requests_lib.RequestStatus.CANCELLED):
199202
request_task = await requests_lib.get_request_async(
200-
request_id)
203+
exact_request_id, exact_match=True)
201204
if request_task.should_retry:
202205
buffer.append(
203206
message_utils.encode_payload(
204207
rich_utils.Control.RETRY.encode('')))
205208
else:
206-
buffer.append(
207-
f'{request_task.name!r} request {request_id}'
208-
' cancelled\n')
209+
buffer.append(f'{request_task.name!r} request '
210+
f'{exact_request_id} cancelled\n')
209211
break
210212
if not follow:
211213
# The below checks (cluster status, heartbeat) are not needed

sky/server/uvicorn.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@
4646

4747
# TODO(aylei): use decorator to register requests that need to be proactively
4848
# cancelled instead of hardcoding here.
49-
_RETRIABLE_REQUEST_NAMES = [
49+
_RETRIABLE_REQUEST_NAMES = {
5050
'sky.logs',
5151
'sky.jobs.logs',
5252
'sky.serve.logs',
53-
]
53+
}
5454

5555

5656
def add_timestamp_prefix_for_server_logs() -> None:
@@ -152,16 +152,17 @@ def _wait_requests(self) -> None:
152152
requests_lib.RequestStatus.RUNNING,
153153
]
154154
reqs = requests_lib.get_request_tasks(
155-
req_filter=requests_lib.RequestTaskFilter(status=statuses))
155+
req_filter=requests_lib.RequestTaskFilter(
156+
status=statuses, fields=['request_id', 'name']))
156157
if not reqs:
157158
break
158159
logger.info(f'{len(reqs)} on-going requests '
159160
'found, waiting for them to finish...')
160161
# Proactively cancel internal requests and logs requests since
161162
# they can run indefinitely.
162-
internal_request_ids = [
163+
internal_request_ids = {
163164
d.id for d in daemons.INTERNAL_REQUEST_DAEMONS
164-
]
165+
}
165166
if time.time() - start_time > _WAIT_REQUESTS_TIMEOUT_SECONDS:
166167
logger.warning('Timeout waiting for on-going requests to '
167168
'finish, cancelling all on-going requests.')

tests/test_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self):
6161
async def mock_get_request(request_id):
6262
return MockRequest()
6363

64-
async def mock_get_request_status(request_id):
64+
async def mock_get_request_status(request_id, exact_match=False):
6565
return requests_lib.StatusWithMsg(MockRequest().status,
6666
MockRequest().status_msg)
6767

@@ -153,7 +153,7 @@ def __init__(self):
153153
async def mock_get_request(request_id):
154154
return MockRequest()
155155

156-
async def mock_get_request_status(request_id):
156+
async def mock_get_request_status(request_id, exact_match=False):
157157
return requests_lib.StatusWithMsg(MockRequest().status,
158158
MockRequest().status_msg)
159159

0 commit comments

Comments
 (0)