From 973d710365b27f97109bd9d37df94417f92be4a2 Mon Sep 17 00:00:00 2001
From: Ryan Marten
Date: Mon, 16 Dec 2024 21:27:42 -0800
Subject: [PATCH] fix test different files

---
 .../curator/llm/prompt_formatter.py          |  3 ++
 .../base_online_request_processor.py         |  5 ++++
 .../base_request_processor.py                | 23 +++++++-------
 tests/cache/different_files/one.py           | 30 +++++--------------
 tests/cache/different_files/two.py           | 30 +++++--------------
 tests/cache/test_different_files.py          |  9 ++----
 6 files changed, 38 insertions(+), 62 deletions(-)

diff --git a/src/bespokelabs/curator/llm/prompt_formatter.py b/src/bespokelabs/curator/llm/prompt_formatter.py
index 826398d2..4dae93ce 100644
--- a/src/bespokelabs/curator/llm/prompt_formatter.py
+++ b/src/bespokelabs/curator/llm/prompt_formatter.py
@@ -110,6 +110,9 @@ def response_to_response_format(self, response_message: str | dict) -> Optional[
         """
         # Response message is a string, which is converted to a dict
         # The dict is then used to construct the response_format Pydantic model
+        if self.response_format is None:
+            return response_message
+
         try:
             # First try to parse the response message as JSON
             if isinstance(response_message, str):
diff --git a/src/bespokelabs/curator/request_processor/base_online_request_processor.py b/src/bespokelabs/curator/request_processor/base_online_request_processor.py
index 7e8d922f..51537125 100644
--- a/src/bespokelabs/curator/request_processor/base_online_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/base_online_request_processor.py
@@ -204,6 +204,11 @@ def run(
         parse_func_hash: str,
         prompt_formatter: PromptFormatter,
     ) -> Dataset:
+        # load from already completed dataset
+        output_dataset = self.attempt_loading_cached_dataset(working_dir, parse_func_hash)
+        if output_dataset is not None:
+            return output_dataset
+
         """Run completions using the online API with async processing."""
         logger.info(f"Running {self.__class__.__name__} completions with model: {self.model}")

diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py
index 841223c0..6a5b2a30 100644
--- a/src/bespokelabs/curator/request_processor/base_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/base_request_processor.py
@@ -315,19 +315,18 @@ def create_dataset_files(
                     failed_responses_count += 1
                     continue

-                if prompt_formatter.response_format:
-                    try:
-                        response.response_message = (
-                            self.prompt_formatter.response_to_response_format(
-                                response.response_message
-                            )
-                        )
-                    except (json.JSONDecodeError, ValidationError) as e:
-                        logger.warning(
-                            "Skipping response due to error parsing response message into response format"
-                        )
-                        failed_responses_count += 1
-                        continue
+                try:
+                    response.response_message = (
+                        self.prompt_formatter.response_to_response_format(
+                            response.response_message
+                        )
+                    )
+                except (json.JSONDecodeError, ValidationError) as e:
+                    logger.warning(
+                        "Skipping response due to error parsing response message into response format"
+                    )
+                    failed_responses_count += 1
+                    continue

                 # parse_func can return a single row or a list of rows
                 if prompt_formatter.parse_func:
diff --git a/tests/cache/different_files/one.py b/tests/cache/different_files/one.py
index 10ff74d4..e5667add 100644
--- a/tests/cache/different_files/one.py
+++ b/tests/cache/different_files/one.py
@@ -1,32 +1,18 @@
 from bespokelabs.curator import LLM
 from datasets import Dataset
 import logging
-import argparse

 logger = logging.getLogger("bespokelabs.curator")
 logger.setLevel(logging.INFO)


-def main(delete_cache: bool = False):
-    dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})
+dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})

-    prompter = LLM(
-        prompt_func=lambda row: row["prompt"],
-        model_name="gpt-4o-mini",
-        response_format=None,
-        delete_cache=delete_cache,
-    )
+prompter = LLM(
+    prompt_func=lambda row: row["prompt"],
+    model_name="gpt-4o-mini",
+    response_format=None,
+)

-    dataset = prompter(dataset)
-    print(dataset.to_pandas())
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run prompter with cache control")
-    parser.add_argument(
-        "--delete-cache",
-        action="store_true",
-        help="Delete the cache before running",
-    )
-    args = parser.parse_args()
-    main(delete_cache=args.delete_cache)
+dataset = prompter(dataset)
+print(dataset.to_pandas())
diff --git a/tests/cache/different_files/two.py b/tests/cache/different_files/two.py
index 10ff74d4..e5667add 100644
--- a/tests/cache/different_files/two.py
+++ b/tests/cache/different_files/two.py
@@ -1,32 +1,18 @@
 from bespokelabs.curator import LLM
 from datasets import Dataset
 import logging
-import argparse

 logger = logging.getLogger("bespokelabs.curator")
 logger.setLevel(logging.INFO)


-def main(delete_cache: bool = False):
-    dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})
+dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})

-    prompter = LLM(
-        prompt_func=lambda row: row["prompt"],
-        model_name="gpt-4o-mini",
-        response_format=None,
-        delete_cache=delete_cache,
-    )
+prompter = LLM(
+    prompt_func=lambda row: row["prompt"],
+    model_name="gpt-4o-mini",
+    response_format=None,
+)

-    dataset = prompter(dataset)
-    print(dataset.to_pandas())
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run prompter with cache control")
-    parser.add_argument(
-        "--delete-cache",
-        action="store_true",
-        help="Delete the cache before running",
-    )
-    args = parser.parse_args()
-    main(delete_cache=args.delete_cache)
+dataset = prompter(dataset)
+print(dataset.to_pandas())
diff --git a/tests/cache/test_different_files.py b/tests/cache/test_different_files.py
index 6b18de07..31fe866b 100644
--- a/tests/cache/test_different_files.py
+++ b/tests/cache/test_different_files.py
@@ -16,17 +16,14 @@ def test_cache_behavior():
     # Run one.py twice and check for cache behavior
     print("RUNNING ONE.PY")
-    output1, _ = run_script(["python", "tests/cache_tests/different_files/one.py"])
-    print(output1)
+    output1, _ = run_script(["python", "tests/cache/different_files/one.py"])
     assert cache_hit_log not in output1, "First run of one.py should not hit cache"

     print("RUNNING ONE.PY AGAIN")
-    output2, _ = run_script(["python", "tests/cache_tests/different_files/one.py"])
-    print(output2)
+    output2, _ = run_script(["python", "tests/cache/different_files/one.py"])
     assert cache_hit_log in output2, "Second run of one.py should hit cache"

     # Run two.py and check for cache behavior
     print("RUNNING TWO.PY")
-    output3, _ = run_script(["python", "tests/cache_tests/different_files/two.py"])
-    print(output3)
+    output3, _ = run_script(["python", "tests/cache/different_files/two.py"])
     assert cache_hit_log in output3, "First run of two.py should hit cache"
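
Note: test_different_files.py relies on a run_script helper that is imported from elsewhere in the test suite and is not touched by this patch. For readers following along, below is a minimal sketch of what such a helper could look like using only the standard library; the helper's name is taken from the test, but the implementation, signature, and return values shown here are assumptions, not the project's actual code.

    import subprocess


    def run_script(command: list[str]) -> tuple[str, int]:
        """Hypothetical stand-in for the run_script helper used by the cache tests.

        Runs the command, captures stdout and stderr, and returns the combined
        output together with the exit code, so callers can assert on log lines
        such as the cache-hit message.
        """
        result = subprocess.run(command, capture_output=True, text=True)
        return result.stdout + result.stderr, result.returncode

Combining stdout and stderr is likely what matters for these assertions, since the scripts emit their status through the logging module, which writes to stderr by default.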