@@ -292,9 +292,7 @@ def decode(cls, payload: payloads.RequestPayload) -> 'Request':
292292 raise
293293
294294
295- def encode_requests (
296- requests : List [Request ],
297- fields : Optional [List [str ]] = None ) -> List [payloads .RequestPayload ]:
295+ def encode_requests (requests : List [Request ]) -> List [payloads .RequestPayload ]:
298296 """Serialize the SkyPilot API request for display purposes.
299297
300298 This function should be called on the server side to serialize the
@@ -312,15 +310,18 @@ def encode_requests(
312310 all_users = global_user_state .get_all_users ()
313311 all_users_map = {user .id : user .name for user in all_users }
314312 for request in requests :
315- assert isinstance (request .request_body ,
316- payloads .RequestBody ), (request .name ,
317- request .request_body )
313+ if request .request_body is not None :
314+ assert isinstance (request .request_body ,
315+ payloads .RequestBody ), (request .name ,
316+ request .request_body )
318317 user_name = all_users_map .get (request .user_id )
319318 payload = payloads .RequestPayload (
320319 request_id = request .request_id ,
321320 name = request .name ,
322- entrypoint = request .entrypoint .__name__ ,
323- request_body = request .request_body .model_dump_json (),
321+ entrypoint = request .entrypoint .__name__
322+ if request .entrypoint is not None else '' ,
323+ request_body = request .request_body .model_dump_json ()
324+ if request .request_body is not None else json .dumps (None ),
324325 status = request .status .value ,
325326 return_value = json .dumps (None ),
326327 error = json .dumps (None ),
@@ -334,51 +335,59 @@ def encode_requests(
334335 should_retry = request .should_retry ,
335336 finished_at = request .finished_at ,
336337 )
337- payload = _update_request_payload_fields (payload , fields )
338338 encoded_requests .append (payload )
339339 return encoded_requests
340340
341341
342- def _update_request_payload_fields (
343- payload : payloads . RequestPayload ,
344- fields : Optional [List [str ]] = None ) -> payloads . RequestPayload :
345- """Update the request payload fields."""
342+ def _update_request_row_fields (
343+ row : Tuple [ Any , ...] ,
344+ fields : Optional [List [str ]] = None ) -> Tuple [ Any , ...] :
345+ """Update the request row fields."""
346346 if not fields :
347- return payload
347+ return row
348+
349+ # Convert tuple to dictionary for easier manipulation
350+ content = dict (zip (fields , row ))
351+
352+ # Valid empty values for pickled fields (base64-encoded pickled None)
353+ # base64.b64encode(pickle.dumps(None)).decode('utf-8')
354+ empty_pickled_value = 'gAROLg=='
355+
348356 # Required fields in RequestPayload
349357 if 'request_id' not in fields :
350- payload . request_id = ''
358+ content [ ' request_id' ] = ''
351359 if 'name' not in fields :
352- payload . name = ''
360+ content [ ' name' ] = ''
353361 if 'entrypoint' not in fields :
354- payload . entrypoint = ''
362+ content [ ' entrypoint' ] = empty_pickled_value
355363 if 'request_body' not in fields :
356- payload . request_body = json . dumps ( None )
364+ content [ ' request_body' ] = empty_pickled_value
357365 if 'status' not in fields :
358- payload . status = ''
366+ content [ ' status' ] = RequestStatus . PENDING . value
359367 if 'created_at' not in fields :
360- payload . created_at = 0
368+ content [ ' created_at' ] = 0
361369 if 'user_id' not in fields :
362- payload .user_id = ''
363- payload .user_name = None
370+ content ['user_id' ] = ''
364371 if 'return_value' not in fields :
365- payload . return_value = json .dumps (None )
372+ content [ ' return_value' ] = json .dumps (None )
366373 if 'error' not in fields :
367- payload . error = json .dumps (None )
374+ content [ ' error' ] = json .dumps (None )
368375 if 'schedule_type' not in fields :
369- payload . schedule_type = ''
376+ content [ ' schedule_type' ] = ScheduleType . SHORT . value
370377 # Optional fields in RequestPayload
371378 if 'pid' not in fields :
372- payload . pid = None
379+ content [ ' pid' ] = None
373380 if 'cluster_name' not in fields :
374- payload . cluster_name = None
381+ content [ ' cluster_name' ] = None
375382 if 'status_msg' not in fields :
376- payload . status_msg = None
383+ content [ ' status_msg' ] = None
377384 if 'should_retry' not in fields :
378- payload . should_retry = False
385+ content [ ' should_retry' ] = False
379386 if 'finished_at' not in fields :
380- payload .finished_at = None
381- return payload
387+ content ['finished_at' ] = None
388+
389+ # Convert back to tuple in the same order as REQUEST_COLUMNS
390+ return tuple (content [col ] for col in REQUEST_COLUMNS )
382391
383392
384393def kill_cluster_requests (cluster_name : str , exclude_request_name : str ):
@@ -736,6 +745,7 @@ class RequestTaskFilter:
736745 include_request_names : Optional [List [str ]] = None
737746 finished_before : Optional [float ] = None
738747 request_limit : int = 0
748+ fields : Optional [List [str ]] = None
739749
740750 def __post_init__ (self ):
741751 if (self .exclude_request_names is not None and
@@ -778,6 +788,8 @@ def build_query(self) -> Tuple[str, List[Any]]:
778788 if filter_str :
779789 filter_str = f' WHERE { filter_str } '
780790 columns_str = ', ' .join (REQUEST_COLUMNS )
791+ if self .fields :
792+ columns_str = ', ' .join (self .fields )
781793 query_str = (f'SELECT { columns_str } FROM { REQUEST_TABLE } { filter_str } '
782794 'ORDER BY created_at DESC' )
783795 if self .request_limit > 0 :
@@ -816,6 +828,21 @@ async def get_request_tasks_async(
816828 return [Request .from_row (row ) for row in rows ]
817829
818830
831+ @init_db_async
832+ @metrics_lib .time_me_async
833+ async def get_request_tasks_with_fields_async (
834+ req_filter : RequestTaskFilter ,
835+ fields : Optional [List [str ]] = None ,
836+ ) -> List [Request ]:
837+ """Async version of get_request_tasks."""
838+ assert _DB is not None
839+ async with _DB .execute_fetchall_async (* req_filter .build_query ()) as rows :
840+ if not rows :
841+ return []
842+ rows = [_update_request_row_fields (row , fields ) for row in rows ]
843+ return [Request .from_row (row ) for row in rows ]
844+
845+
819846@init_db_async
820847@metrics_lib .time_me_async
821848async def get_api_request_ids_start_with (incomplete : str ) -> List [str ]:
0 commit comments