Skip to content

Commit

Permalink
convert response to response format and throw
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanMarten committed Dec 16, 2024
1 parent df1e308 commit a3990f7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ async def handle_single_request_with_retries(
status_tracker=status_tracker,
)

# Allows us to retry on responses that don't match the response format
self.convert_response_to_response_format(
generic_response.response_message, self.prompt_formatter.response_format
)

# Save response in the base class
await self.append_generic_response(generic_response, save_filepath)

Expand Down
88 changes: 59 additions & 29 deletions src/bespokelabs/curator/request_processor/base_request_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,58 @@ def attempt_loading_cached_dataset(
"Deleted file and attempting to regenerate dataset from cached LLM responses."
)

def convert_response_to_response_format(
    self, response_message: str | dict, response_format: Optional[type[BaseModel]]
) -> Optional[dict | str | BaseModel]:
    """Validate a response message against a Pydantic ``response_format``.

    When no ``response_format`` is given there is nothing to validate
    against, so the message is returned untouched. Otherwise a string
    message is first parsed as JSON, and the parsed payload is used to
    construct an instance of ``response_format``.

    Args:
        response_message: The response message to convert, either a JSON
            string or an already-parsed dictionary.
        response_format: The Pydantic model class describing the expected
            structure of the response, or ``None`` for free-form responses.

    Returns:
        The original ``response_message`` if ``response_format`` is ``None``;
        otherwise an instance of ``response_format`` built from the message.

    Raises:
        json.JSONDecodeError: If ``response_message`` is a string that is not
            valid JSON.
        ValidationError: If the parsed payload does not match the schema
            defined by ``response_format``.
    """
    # Free-form responses have no schema; attempting to parse or construct
    # here would raise on every plain-text reply and on every call with None.
    if response_format is None:
        return response_message

    # A string message must be decoded to a Python object before validation.
    if isinstance(response_message, str):
        try:
            response_dict = json.loads(response_message)
        except json.JSONDecodeError:
            logger.warning(
                f"Failed to parse response message as JSON: {response_message}. "
                "The model likely returned an invalid JSON format."
            )
            raise
    else:
        response_dict = response_message

    try:
        if isinstance(response_dict, dict):
            return response_format(**response_dict)
        # Valid JSON that is not an object (list, number, string, ...) cannot
        # be splatted as keyword arguments; let Pydantic reject it so callers
        # see a ValidationError instead of an unrelated TypeError.
        return response_format.model_validate(response_dict)
    except ValidationError:
        schema_str = json.dumps(response_format.model_json_schema(), indent=2)
        logger.warning(
            f"Pydantic failed to parse response message {response_message} with `response_format` {schema_str}. "
            f"The model likely returned a JSON that does not match the schema of the `response_format`."
        )
        raise

def create_dataset_files(
self,
working_dir: str,
Expand Down Expand Up @@ -309,38 +361,16 @@ def create_dataset_files(
continue

if prompt_formatter.response_format:
# Response message is a string, which is converted to a dict
# The dict is then used to construct the response_format Pydantic model
try:
# First try to parse the response message as JSON
if isinstance(response.response_message, str):
try:
response_dict = json.loads(response.response_message)
except json.JSONDecodeError as e:
warning_msg = (
f"Failed to parse response message as JSON: {response.response_message}. "
f"The model likely returned an invalid JSON format. Will skip this response."
)
logger.warning(warning_msg)
failed_responses_count += 1
continue
else:
response_dict = response.response_message

# Then construct the Pydantic model from the parsed dict
response.response_message = prompt_formatter.response_format(
**response_dict
)
except ValidationError as e:
schema_str = json.dumps(
prompt_formatter.response_format.model_json_schema(),
indent=2,
response.response_message = (
self.convert_response_to_response_format(
response.response_message, prompt_formatter.response_format
)
)
warning_msg = (
f"Pydantic failed to parse response message {response.response_message} with `response_format` {schema_str}. "
f"The model likely returned a JSON that does not match the schema of the `response_format`. Will skip this response."
except (json.JSONDecodeError, ValidationError) as e:
logger.warning(
"Skipping response due to error parsing response message into response format"
)
logger.warning(warning_msg)
failed_responses_count += 1
continue

Expand Down

0 comments on commit a3990f7

Please sign in to comment.