diff --git a/src/bespokelabs/curator/request_processor/base_online_request_processor.py b/src/bespokelabs/curator/request_processor/base_online_request_processor.py index 3d8f9af0..2fd5c218 100644 --- a/src/bespokelabs/curator/request_processor/base_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_online_request_processor.py @@ -474,6 +474,11 @@ async def handle_single_request_with_retries( status_tracker=status_tracker, ) + # Allows us to retry on responses that don't match the response format + self.convert_response_to_response_format( + generic_response.response_message, self.prompt_formatter.response_format + ) + # Save response in the base class await self.append_generic_response(generic_response, save_filepath) diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 08743532..f02b5ab7 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -262,6 +262,58 @@ def attempt_loading_cached_dataset( "Deleted file and attempting to regenerate dataset from cached LLM responses." ) + def convert_response_to_response_format( + self, response_message: str | dict, response_format: Optional[BaseModel] + ) -> Optional[dict | str]: + """ + Converts a response message to a specified Pydantic model format. + + This method takes a response message (either as a string or dict) and validates/converts it + according to the provided Pydantic model format. If the response message is a string, + it first attempts to parse it as JSON. The resulting dict is then used to construct + an instance of the specified Pydantic model. + + Args: + response_message (str | dict): The response message to convert, either as a JSON string + or a dictionary. + response_format (Optional[BaseModel]): The Pydantic model class that defines the + expected format of the response. + + Returns: + Optional[dict | str]: The validated response message as a Pydantic model instance. + + Raises: + json.JSONDecodeError: If the response_message is a string but cannot be parsed as valid JSON. + ValidationError: If the parsed response does not match the schema defined by response_format. + """ + # Response message is a string, which is converted to a dict + # The dict is then used to construct the response_format Pydantic model + try: + # First try to parse the response message as JSON + if isinstance(response_message, str): + try: + response_dict = json.loads(response_message) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse response message as JSON: {response_message}. " + f"The model likely returned an invalid JSON format." + ) + raise e + else: + response_dict = response_message + + # Then construct the Pydantic model from the parsed dict + response_message = response_format(**response_dict) + return response_message + + except ValidationError as e: + schema_str = json.dumps(response_format.model_json_schema(), indent=2) + logger.warning( + f"Pydantic failed to parse response message {response_message} with `response_format` {schema_str}. " + f"The model likely returned a JSON that does not match the schema of the `response_format`." + ) + raise e + def create_dataset_files( self, working_dir: str, @@ -309,38 +361,16 @@ def create_dataset_files( continue if prompt_formatter.response_format: - # Response message is a string, which is converted to a dict - # The dict is then used to construct the response_format Pydantic model try: - # First try to parse the response message as JSON - if isinstance(response.response_message, str): - try: - response_dict = json.loads(response.response_message) - except json.JSONDecodeError as e: - warning_msg = ( - f"Failed to parse response message as JSON: {response.response_message}. " - f"The model likely returned an invalid JSON format. Will skip this response." - ) - logger.warning(warning_msg) - failed_responses_count += 1 - continue - else: - response_dict = response.response_message - - # Then construct the Pydantic model from the parsed dict - response.response_message = prompt_formatter.response_format( - **response_dict - ) - except ValidationError as e: - schema_str = json.dumps( - prompt_formatter.response_format.model_json_schema(), - indent=2, + response.response_message = ( + self.convert_response_to_response_format( + response.response_message, prompt_formatter.response_format + ) ) - warning_msg = ( - f"Pydantic failed to parse response message {response.response_message} with `response_format` {schema_str}. " - f"The model likely returned a JSON that does not match the schema of the `response_format`. Will skip this response." + except (json.JSONDecodeError, ValidationError) as e: + logger.warning( + "Skipping response due to error parsing response message into response format" ) - logger.warning(warning_msg) failed_responses_count += 1 continue