Skip to content

Commit 6236ce8

Browse files
committed
retrieve required fields only from db
1 parent a1acf43 commit 6236ce8

File tree

3 files changed

+330
-397
lines changed

3 files changed

+330
-397
lines changed

sky/server/requests/requests.py

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,7 @@ def decode(cls, payload: payloads.RequestPayload) -> 'Request':
292292
raise
293293

294294

295-
def encode_requests(
296-
requests: List[Request],
297-
fields: Optional[List[str]] = None) -> List[payloads.RequestPayload]:
295+
def encode_requests(requests: List[Request]) -> List[payloads.RequestPayload]:
298296
"""Serialize the SkyPilot API request for display purposes.
299297
300298
This function should be called on the server side to serialize the
@@ -312,15 +310,18 @@ def encode_requests(
312310
all_users = global_user_state.get_all_users()
313311
all_users_map = {user.id: user.name for user in all_users}
314312
for request in requests:
315-
assert isinstance(request.request_body,
316-
payloads.RequestBody), (request.name,
317-
request.request_body)
313+
if request.request_body is not None:
314+
assert isinstance(request.request_body,
315+
payloads.RequestBody), (request.name,
316+
request.request_body)
318317
user_name = all_users_map.get(request.user_id)
319318
payload = payloads.RequestPayload(
320319
request_id=request.request_id,
321320
name=request.name,
322-
entrypoint=request.entrypoint.__name__,
323-
request_body=request.request_body.model_dump_json(),
321+
entrypoint=request.entrypoint.__name__
322+
if request.entrypoint is not None else '',
323+
request_body=request.request_body.model_dump_json()
324+
if request.request_body is not None else json.dumps(None),
324325
status=request.status.value,
325326
return_value=json.dumps(None),
326327
error=json.dumps(None),
@@ -334,51 +335,59 @@ def encode_requests(
334335
should_retry=request.should_retry,
335336
finished_at=request.finished_at,
336337
)
337-
payload = _update_request_payload_fields(payload, fields)
338338
encoded_requests.append(payload)
339339
return encoded_requests
340340

341341

342-
def _update_request_payload_fields(
343-
payload: payloads.RequestPayload,
344-
fields: Optional[List[str]] = None) -> payloads.RequestPayload:
345-
"""Update the request payload fields."""
342+
def _update_request_row_fields(
343+
row: Tuple[Any, ...],
344+
fields: Optional[List[str]] = None) -> Tuple[Any, ...]:
345+
"""Update the request row fields."""
346346
if not fields:
347-
return payload
347+
return row
348+
349+
# Convert tuple to dictionary for easier manipulation
350+
content = dict(zip(fields, row))
351+
352+
# Valid empty values for pickled fields (base64-encoded pickled None)
353+
# base64.b64encode(pickle.dumps(None)).decode('utf-8')
354+
empty_pickled_value = 'gAROLg=='
355+
348356
# Required fields in RequestPayload
349357
if 'request_id' not in fields:
350-
payload.request_id = ''
358+
content['request_id'] = ''
351359
if 'name' not in fields:
352-
payload.name = ''
360+
content['name'] = ''
353361
if 'entrypoint' not in fields:
354-
payload.entrypoint = ''
362+
content['entrypoint'] = empty_pickled_value
355363
if 'request_body' not in fields:
356-
payload.request_body = json.dumps(None)
364+
content['request_body'] = empty_pickled_value
357365
if 'status' not in fields:
358-
payload.status = ''
366+
content['status'] = RequestStatus.PENDING.value
359367
if 'created_at' not in fields:
360-
payload.created_at = 0
368+
content['created_at'] = 0
361369
if 'user_id' not in fields:
362-
payload.user_id = ''
363-
payload.user_name = None
370+
content['user_id'] = ''
364371
if 'return_value' not in fields:
365-
payload.return_value = json.dumps(None)
372+
content['return_value'] = json.dumps(None)
366373
if 'error' not in fields:
367-
payload.error = json.dumps(None)
374+
content['error'] = json.dumps(None)
368375
if 'schedule_type' not in fields:
369-
payload.schedule_type = ''
376+
content['schedule_type'] = ScheduleType.SHORT.value
370377
# Optional fields in RequestPayload
371378
if 'pid' not in fields:
372-
payload.pid = None
379+
content['pid'] = None
373380
if 'cluster_name' not in fields:
374-
payload.cluster_name = None
381+
content['cluster_name'] = None
375382
if 'status_msg' not in fields:
376-
payload.status_msg = None
383+
content['status_msg'] = None
377384
if 'should_retry' not in fields:
378-
payload.should_retry = False
385+
content['should_retry'] = False
379386
if 'finished_at' not in fields:
380-
payload.finished_at = None
381-
return payload
387+
content['finished_at'] = None
388+
389+
# Convert back to tuple in the same order as REQUEST_COLUMNS
390+
return tuple(content[col] for col in REQUEST_COLUMNS)
382391

383392

384393
def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
@@ -736,6 +745,7 @@ class RequestTaskFilter:
736745
include_request_names: Optional[List[str]] = None
737746
finished_before: Optional[float] = None
738747
request_limit: int = 0
748+
fields: Optional[List[str]] = None
739749

740750
def __post_init__(self):
741751
if (self.exclude_request_names is not None and
@@ -778,6 +788,8 @@ def build_query(self) -> Tuple[str, List[Any]]:
778788
if filter_str:
779789
filter_str = f' WHERE {filter_str}'
780790
columns_str = ', '.join(REQUEST_COLUMNS)
791+
if self.fields:
792+
columns_str = ', '.join(self.fields)
781793
query_str = (f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str} '
782794
'ORDER BY created_at DESC')
783795
if self.request_limit > 0:
@@ -816,6 +828,21 @@ async def get_request_tasks_async(
816828
return [Request.from_row(row) for row in rows]
817829

818830

831+
@init_db_async
832+
@metrics_lib.time_me_async
833+
async def get_request_tasks_with_fields_async(
834+
req_filter: RequestTaskFilter,
835+
fields: Optional[List[str]] = None,
836+
) -> List[Request]:
837+
"""Async version of get_request_tasks."""
838+
assert _DB is not None
839+
async with _DB.execute_fetchall_async(*req_filter.build_query()) as rows:
840+
if not rows:
841+
return []
842+
rows = [_update_request_row_fields(row, fields) for row in rows]
843+
return [Request.from_row(row) for row in rows]
844+
845+
819846
@init_db_async
820847
@metrics_lib.time_me_async
821848
async def get_api_request_ids_start_with(incomplete: str) -> List[str]:

sky/server/server.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,12 +1638,15 @@ async def api_status(
16381638
requests_lib.RequestStatus.PENDING,
16391639
requests_lib.RequestStatus.RUNNING,
16401640
]
1641-
request_tasks = await requests_lib.get_request_tasks_async(
1641+
request_tasks = await requests_lib.get_request_tasks_with_fields_async(
16421642
req_filter=requests_lib.RequestTaskFilter(
16431643
status=statuses,
16441644
request_limit=request_limit,
1645-
))
1646-
return requests_lib.encode_requests(request_tasks, fields=fields)
1645+
fields=fields,
1646+
),
1647+
fields=fields,
1648+
)
1649+
return requests_lib.encode_requests(request_tasks)
16471650
else:
16481651
encoded_request_tasks = []
16491652
for request_id in request_ids:

0 commit comments

Comments
 (0)