Skip to content

Commit a91e4d3

Browse files
authored
[db] only query for specific fields of requests (#7658)
* only query for specific fields * fix ut * apply the same trick to individual request gets * fix UT
1 parent 222f220 commit a91e4d3

File tree

10 files changed

+169
-57
lines changed

10 files changed

+169
-57
lines changed

sky/backends/backend_utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3128,12 +3128,12 @@ def refresh_cluster_records() -> None:
31283128
# request info in backend_utils.py.
31293129
# Refactor this to use some other info to
31303130
# determine if a launch is in progress.
3131-
requests = requests_lib.get_request_tasks(
3132-
req_filter=requests_lib.RequestTaskFilter(
3133-
status=[requests_lib.RequestStatus.RUNNING],
3134-
include_request_names=['sky.launch']))
31353131
cluster_names_with_launch_request = {
3136-
request.cluster_name for request in requests
3132+
request.cluster_name for request in requests_lib.get_request_tasks(
3133+
req_filter=requests_lib.RequestTaskFilter(
3134+
status=[requests_lib.RequestStatus.RUNNING],
3135+
include_request_names=['sky.launch'],
3136+
fields=['cluster_name']))
31373137
}
31383138
cluster_names_without_launch_request = (cluster_names -
31393139
cluster_names_with_launch_request)
@@ -3356,13 +3356,13 @@ def _refresh_cluster_record(cluster_name):
33563356
# request info in backend_utils.py.
33573357
# Refactor this to use some other info to
33583358
# determine if a launch is in progress.
3359-
requests = requests_lib.get_request_tasks(
3360-
req_filter=requests_lib.RequestTaskFilter(
3361-
status=[requests_lib.RequestStatus.RUNNING],
3362-
include_request_names=['sky.launch'],
3363-
cluster_names=cluster_names))
33643359
cluster_names_with_launch_request = {
3365-
request.cluster_name for request in requests
3360+
request.cluster_name for request in requests_lib.get_request_tasks(
3361+
req_filter=requests_lib.RequestTaskFilter(
3362+
status=[requests_lib.RequestStatus.RUNNING],
3363+
include_request_names=['sky.launch'],
3364+
cluster_names=cluster_names,
3365+
fields=['cluster_name']))
33663366
}
33673367
# Preserve the index of the cluster name as it appears on "records"
33683368
cluster_names_without_launch_request = [

sky/jobs/server/server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,13 @@ async def pool_tail_logs(
201201
request_cluster_name=common.JOB_CONTROLLER_NAME,
202202
)
203203

204-
request_task = api_requests.get_request(request.state.request_id)
204+
request_task = api_requests.get_request(request.state.request_id,
205+
fields=['request_id'])
205206

206207
return stream_utils.stream_response_for_long_request(
207208
request_id=request_task.request_id,
209+
# req.log_path is derived from request_id,
210+
# so it's ok to just grab the request_id in the above query.
208211
logs_path=request_task.log_path,
209212
background_tasks=background_tasks,
210213
)

sky/server/requests/executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,11 @@ def process_request(self, executor: process.BurstableExecutor,
214214
time.sleep(0.1)
215215
return
216216
request_id, ignore_return_value, _ = request_element
217-
request = api_requests.get_request(request_id)
217+
request = api_requests.get_request(request_id, fields=['status'])
218218
assert request is not None, f'Request with ID {request_id} is None'
219219
if request.status == api_requests.RequestStatus.CANCELLED:
220220
return
221+
del request
221222
logger.info(f'[{self}] Submitting request: {request_id}')
222223
# Start additional process to run the request, so that it can be
223224
# cancelled when requested by a user.

sky/server/requests/preconditions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,15 @@ async def _wait(self) -> bool:
9898
return False
9999

100100
# Check if the request has been cancelled
101-
request = await api_requests.get_request_async(self.request_id)
101+
request = await api_requests.get_request_async(self.request_id,
102+
fields=['status'])
102103
if request is None:
103104
logger.error(f'Request {self.request_id} not found')
104105
return False
105106
if request.status == api_requests.RequestStatus.CANCELLED:
106107
logger.debug(f'Request {self.request_id} cancelled')
107108
return False
109+
del request
108110

109111
try:
110112
met, status_msg = await self.check()
@@ -166,7 +168,10 @@ async def check(self) -> Tuple[bool, Optional[str]]:
166168
api_requests.RequestStatus.RUNNING
167169
],
168170
include_request_names=['sky.launch', 'sky.start'],
169-
cluster_names=[self.cluster_name]))
171+
cluster_names=[self.cluster_name],
172+
# Only get the request ID to avoid fetching the whole request.
173+
# We're only interested in the count, not the whole request.
174+
fields=['request_id']))
170175
if len(requests) == 0:
171176
# No running or pending tasks, the start process is done.
172177
return True, None

sky/server/requests/requests.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,8 @@ def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
400400
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
401401
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
402402
exclude_request_names=[exclude_request_name],
403-
cluster_names=[cluster_name]))
403+
cluster_names=[cluster_name],
404+
fields=['request_id']))
404405
]
405406
kill_requests(request_ids)
406407

@@ -425,7 +426,8 @@ def kill_requests(request_ids: Optional[List[str]] = None,
425426
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
426427
# Avoid cancelling the cancel request itself.
427428
exclude_request_names=['sky.api_cancel'],
428-
user_id=user_id))
429+
user_id=user_id,
430+
fields=['request_id']))
429431
]
430432
cancelled_request_ids = []
431433
for request_id in request_ids:
@@ -604,30 +606,42 @@ async def update_status_msg_async(request_id: str, status_msg: str) -> None:
604606
await _add_or_update_request_no_lock_async(request)
605607

606608

607-
_get_request_sql = (f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE} '
608-
'WHERE request_id LIKE ?')
609-
610-
611-
def _get_request_no_lock(request_id: str) -> Optional[Request]:
609+
def _get_request_no_lock(
610+
request_id: str,
611+
fields: Optional[List[str]] = None) -> Optional[Request]:
612612
"""Get a SkyPilot API request."""
613613
assert _DB is not None
614+
columns_str = ', '.join(REQUEST_COLUMNS)
615+
if fields:
616+
columns_str = ', '.join(fields)
614617
with _DB.conn:
615618
cursor = _DB.conn.cursor()
616-
cursor.execute(_get_request_sql, (request_id + '%',))
619+
cursor.execute((f'SELECT {columns_str} FROM {REQUEST_TABLE} '
620+
'WHERE request_id LIKE ?'), (request_id + '%',))
617621
row = cursor.fetchone()
618622
if row is None:
619623
return None
624+
if fields:
625+
row = _update_request_row_fields(row, fields)
620626
return Request.from_row(row)
621627

622628

623-
async def _get_request_no_lock_async(request_id: str) -> Optional[Request]:
629+
async def _get_request_no_lock_async(
630+
request_id: str,
631+
fields: Optional[List[str]] = None) -> Optional[Request]:
624632
"""Async version of _get_request_no_lock."""
625633
assert _DB is not None
626-
async with _DB.execute_fetchall_async(_get_request_sql,
627-
(request_id + '%',)) as rows:
634+
columns_str = ', '.join(REQUEST_COLUMNS)
635+
if fields:
636+
columns_str = ', '.join(fields)
637+
async with _DB.execute_fetchall_async(
638+
(f'SELECT {columns_str} FROM {REQUEST_TABLE} '
639+
'WHERE request_id LIKE ?'), (request_id + '%',)) as rows:
628640
row = rows[0] if rows else None
629641
if row is None:
630642
return None
643+
if fields:
644+
row = _update_request_row_fields(row, fields)
631645
return Request.from_row(row)
632646

633647

@@ -646,20 +660,23 @@ def get_latest_request_id() -> Optional[str]:
646660

647661
@init_db
648662
@metrics_lib.time_me
649-
def get_request(request_id: str) -> Optional[Request]:
663+
def get_request(request_id: str,
664+
fields: Optional[List[str]] = None) -> Optional[Request]:
650665
"""Get a SkyPilot API request."""
651666
with filelock.FileLock(request_lock_path(request_id)):
652-
return _get_request_no_lock(request_id)
667+
return _get_request_no_lock(request_id, fields)
653668

654669

655670
@init_db_async
656671
@metrics_lib.time_me_async
657672
@asyncio_utils.shield
658-
async def get_request_async(request_id: str) -> Optional[Request]:
673+
async def get_request_async(
674+
request_id: str,
675+
fields: Optional[List[str]] = None) -> Optional[Request]:
659676
"""Async version of get_request."""
660677
# TODO(aylei): figure out how to remove FileLock here to avoid the overhead
661678
async with filelock.AsyncFileLock(request_lock_path(request_id)):
662-
return await _get_request_no_lock_async(request_id)
679+
return await _get_request_no_lock_async(request_id, fields)
663680

664681

665682
class StatusWithMsg(NamedTuple):
@@ -919,9 +936,9 @@ def set_request_cancelled(request_id: str) -> None:
919936

920937
@init_db
921938
@metrics_lib.time_me
922-
async def _delete_requests(requests: List[Request]):
939+
async def _delete_requests(request_ids: List[str]):
923940
"""Clean up requests by their IDs."""
924-
id_list_str = ','.join(repr(req.request_id) for req in requests)
941+
id_list_str = ','.join(repr(request_id) for request_id in request_ids)
925942
assert _DB is not None
926943
await _DB.execute_and_commit_async(
927944
f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})')
@@ -949,18 +966,21 @@ async def clean_finished_requests_with_retention(retention_seconds: int,
949966
req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
950967
finished_before=time.time() -
951968
retention_seconds,
952-
limit=batch_size))
969+
limit=batch_size,
970+
fields=['request_id']))
953971
if len(reqs) == 0:
954972
break
955973
futs = []
956974
for req in reqs:
975+
# req.log_path is derived from request_id,
976+
# so it's ok to just grab the request_id in the above query.
957977
futs.append(
958978
asyncio.create_task(
959979
anyio.Path(
960980
req.log_path.absolute()).unlink(missing_ok=True)))
961981
await asyncio.gather(*futs)
962982

963-
await _delete_requests(reqs)
983+
await _delete_requests([req.request_id for req in reqs])
964984
total_deleted += len(reqs)
965985
if len(reqs) < batch_size:
966986
break

sky/server/server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1567,11 +1567,14 @@ async def stream(
15671567
polling_interval = stream_utils.DEFAULT_POLL_INTERVAL
15681568
# Original plain text streaming logic
15691569
if request_id is not None:
1570-
request_task = await requests_lib.get_request_async(request_id)
1570+
request_task = await requests_lib.get_request_async(
1571+
request_id, fields=['request_id', 'schedule_type'])
15711572
if request_task is None:
15721573
print(f'No task with request ID {request_id}')
15731574
raise fastapi.HTTPException(
15741575
status_code=404, detail=f'Request {request_id!r} not found')
1576+
# req.log_path is derived from request_id,
1577+
# so it's ok to just grab the request_id in the above query.
15751578
log_path_to_stream = request_task.log_path
15761579
if not log_path_to_stream.exists():
15771580
# The log file might be deleted by the request GC daemon but the
@@ -1581,6 +1584,7 @@ async def stream(
15811584
detail=f'Log of request {request_id!r} has been deleted')
15821585
if request_task.schedule_type == requests_lib.ScheduleType.LONG:
15831586
polling_interval = stream_utils.LONG_REQUEST_POLL_INTERVAL
1587+
del request_task
15841588
else:
15851589
assert log_path is not None, (request_id, log_path)
15861590
if log_path == constants.API_SERVER_LOGS:

sky/server/stream_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,14 @@ async def log_streamer(
6868
if request_id is not None:
6969
status_msg = rich_utils.EncodedStatusMessage(
7070
f'[dim]Checking request: {request_id}[/dim]')
71-
request_task = await requests_lib.get_request_async(request_id)
71+
request_task = await requests_lib.get_request_async(request_id,
72+
fields=[
73+
'request_id',
74+
'name',
75+
'schedule_type',
76+
'status',
77+
'status_msg'
78+
])
7279

7380
if request_task is None:
7481
raise fastapi.HTTPException(
@@ -89,14 +96,15 @@ async def log_streamer(
8996
f'scheduled: {request_id}')
9097
req_status = request_task.status
9198
req_msg = request_task.status_msg
99+
del request_task
92100
# Slowly back off the database polling up to every 1 second, to avoid
93101
# overloading the CPU and DB.
94102
backoff = common_utils.Backoff(initial_backoff=polling_interval,
95103
max_backoff_factor=10,
96104
multiplier=1.2)
97105
while req_status < requests_lib.RequestStatus.RUNNING:
98106
if req_msg is not None:
99-
waiting_msg = request_task.status_msg
107+
waiting_msg = req_msg
100108
if show_request_waiting_spinner:
101109
yield status_msg.update(f'[dim]{waiting_msg}[/dim]')
102110
elif plain_logs and waiting_msg != last_waiting_msg:
@@ -197,7 +205,7 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
197205
if (req_status.status ==
198206
requests_lib.RequestStatus.CANCELLED):
199207
request_task = await requests_lib.get_request_async(
200-
request_id)
208+
request_id, fields=['name', 'should_retry'])
201209
if request_task.should_retry:
202210
buffer.append(
203211
message_utils.encode_payload(
@@ -206,6 +214,7 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
206214
buffer.append(
207215
f'{request_task.name!r} request {request_id}'
208216
' cancelled\n')
217+
del request_task
209218
break
210219
if not follow:
211220
# The below checks (cluster status, heartbeat) are not needed

sky/server/uvicorn.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@
4646

4747
# TODO(aylei): use decorator to register requests that need to be proactively
4848
# cancelled instead of hardcoding here.
49-
_RETRIABLE_REQUEST_NAMES = [
49+
_RETRIABLE_REQUEST_NAMES = {
5050
'sky.logs',
5151
'sky.jobs.logs',
5252
'sky.serve.logs',
53-
]
53+
}
5454

5555

5656
def add_timestamp_prefix_for_server_logs() -> None:
@@ -151,37 +151,38 @@ def _wait_requests(self) -> None:
151151
requests_lib.RequestStatus.PENDING,
152152
requests_lib.RequestStatus.RUNNING,
153153
]
154-
reqs = requests_lib.get_request_tasks(
155-
req_filter=requests_lib.RequestTaskFilter(status=statuses))
156-
if not reqs:
154+
requests = [(request_task.request_id, request_task.name)
155+
for request_task in requests_lib.get_request_tasks(
156+
req_filter=requests_lib.RequestTaskFilter(
157+
status=statuses, fields=['request_id', 'name']))
158+
]
159+
if not requests:
157160
break
158-
logger.info(f'{len(reqs)} on-going requests '
161+
logger.info(f'{len(requests)} on-going requests '
159162
'found, waiting for them to finish...')
160163
# Proactively cancel internal requests and logs requests since
161164
# they can run for infinite time.
162-
internal_request_ids = [
165+
internal_request_ids = {
163166
d.id for d in daemons.INTERNAL_REQUEST_DAEMONS
164-
]
167+
}
165168
if time.time() - start_time > _WAIT_REQUESTS_TIMEOUT_SECONDS:
166169
logger.warning('Timeout waiting for on-going requests to '
167170
'finish, cancelling all on-going requests.')
168-
for req in reqs:
169-
self.interrupt_request_for_retry(req.request_id)
171+
for request_id, _ in requests:
172+
self.interrupt_request_for_retry(request_id)
170173
break
171174
interrupted = 0
172-
for req in reqs:
173-
if req.request_id in internal_request_ids:
174-
self.interrupt_request_for_retry(req.request_id)
175-
interrupted += 1
176-
elif req.name in _RETRIABLE_REQUEST_NAMES:
177-
self.interrupt_request_for_retry(req.request_id)
175+
for request_id, name in requests:
176+
if (name in _RETRIABLE_REQUEST_NAMES or
177+
request_id in internal_request_ids):
178+
self.interrupt_request_for_retry(request_id)
178179
interrupted += 1
179180
# TODO(aylei): interrupt pending requests to accelerate the
180181
# shutdown.
181182
# If some requests are not interrupted, wait for them to finish,
182183
# otherwise we just check again immediately to accelerate the
183184
# shutdown process.
184-
if interrupted < len(reqs):
185+
if interrupted < len(requests):
185186
time.sleep(_WAIT_REQUESTS_INTERVAL_SECONDS)
186187

187188
def interrupt_request_for_retry(self, request_id: str) -> None:

tests/test_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(self):
5858
self.schedule_type = requests_lib.ScheduleType.LONG
5959
self.status_msg = None
6060

61-
async def mock_get_request(request_id):
61+
async def mock_get_request(request_id, fields):
6262
return MockRequest()
6363

6464
async def mock_get_request_status(request_id):
@@ -150,7 +150,7 @@ def __init__(self):
150150
self.schedule_type = requests_lib.ScheduleType.LONG
151151
self.status_msg = None
152152

153-
async def mock_get_request(request_id):
153+
async def mock_get_request(request_id, fields):
154154
return MockRequest()
155155

156156
async def mock_get_request_status(request_id):

0 commit comments

Comments
 (0)