diff --git a/bespoke-dataset-viewer/package-lock.json b/bespoke-dataset-viewer/package-lock.json index 97c34801..f9e02c46 100644 --- a/bespoke-dataset-viewer/package-lock.json +++ b/bespoke-dataset-viewer/package-lock.json @@ -36,6 +36,7 @@ }, "devDependencies": { "@types/node": "^20", + "@types/prismjs": "^1.26.5", "@types/react": "^18", "@types/react-dom": "^18", "eslint": "^8", @@ -1860,6 +1861,13 @@ "undici-types": "~6.19.2" } }, + "node_modules/@types/prismjs": { + "version": "1.26.5", + "resolved": "https://registry.npmjs.org/@types/prismjs/-/prismjs-1.26.5.tgz", + "integrity": "sha512-AUZTa7hQ2KY5L7AmtSiqxlhWxb4ina0yd8hNbl4TWuqnv/pFP0nDMb3YrfSBf4hJVGLh2YEIBfKaBW/9UEl6IQ==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/prop-types": { "version": "15.7.13", "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.13.tgz", diff --git a/bespoke-dataset-viewer/package.json b/bespoke-dataset-viewer/package.json index 62643150..dee9be50 100644 --- a/bespoke-dataset-viewer/package.json +++ b/bespoke-dataset-viewer/package.json @@ -37,6 +37,7 @@ }, "devDependencies": { "@types/node": "^20", + "@types/prismjs": "^1.26.5", "@types/react": "^18", "@types/react-dom": "^18", "eslint": "^8", diff --git a/build_pkg.py b/build_pkg.py index 80de2549..b9a6e57e 100644 --- a/build_pkg.py +++ b/build_pkg.py @@ -81,7 +81,7 @@ def nextjs_build(): def run_pytest(): print("Running pytest") try: - run_command("pytest", cwd="tests") + run_command("pytest") except subprocess.CalledProcessError: print("Pytest failed. Aborting build.") sys.exit(1) diff --git a/package-lock.json b/package-lock.json deleted file mode 100644 index c1d2f741..00000000 --- a/package-lock.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "bella", - "lockfileVersion": 3, - "requires": true, - "packages": {} -} diff --git a/pyproject.toml b/pyproject.toml index d849e450..9691e3e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "bespokelabs-curator" -version = "0.1.11" +version = "0.1.12" description = "Bespoke Labs Curator" authors = ["Bespoke Labs "] readme = "README.md" diff --git a/src/bespokelabs/curator/batch_manager/openai_batch_manager.py b/src/bespokelabs/curator/batch_manager/openai_batch_manager.py index 9fc73356..e36550fb 100644 --- a/src/bespokelabs/curator/batch_manager/openai_batch_manager.py +++ b/src/bespokelabs/curator/batch_manager/openai_batch_manager.py @@ -200,9 +200,7 @@ def generic_response_file_from_responses( ) generic_response = GenericResponse( response_message=None, - response_errors=[ - f"Request {generic_request} failed with status code {raw_response['response']['status_code']}" - ], + response_errors=[raw_response["response"]["status_code"]], raw_response=raw_response, raw_request=None, generic_request=generic_request, diff --git a/src/bespokelabs/curator/llm/llm.py b/src/bespokelabs/curator/llm/llm.py index 57e3032a..fa6475b9 100644 --- a/src/bespokelabs/curator/llm/llm.py +++ b/src/bespokelabs/curator/llm/llm.py @@ -58,7 +58,7 @@ def __init__( presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, max_retries: Optional[int] = None, - require_all_responses: Optional[bool] = None, + require_all_responses: Optional[bool] = True, ): """Initialize a LLM. 
diff --git a/src/bespokelabs/curator/llm/prompt_formatter.py b/src/bespokelabs/curator/llm/prompt_formatter.py index 29b05b82..4dae93ce 100644 --- a/src/bespokelabs/curator/llm/prompt_formatter.py +++ b/src/bespokelabs/curator/llm/prompt_formatter.py @@ -1,13 +1,16 @@ import dataclasses import inspect +import json +import logging from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from bespokelabs.curator.request_processor.generic_request import GenericRequest T = TypeVar("T") _DictOrBaseModel = Union[Dict[str, Any], BaseModel] +logger = logging.getLogger(__name__) def _validate_messages(messages: list[dict]) -> None: @@ -82,3 +85,56 @@ def create_generic_request(self, row: _DictOrBaseModel, idx: int) -> GenericRequ self.response_format.model_json_schema() if self.response_format else None ), ) + + def response_to_response_format(self, response_message: str | dict) -> Optional[dict | str]: + """ + Converts a response message to the Pydantic model specified by `self.response_format`. + + This method takes a response message (either as a string or dict) and validates/converts it + according to `self.response_format`. If the response message is a string, + it first attempts to parse it as JSON. The resulting dict is then used to construct + an instance of the `self.response_format` Pydantic model. + + Args: + response_message (str | dict): The response message to convert, either as a JSON string + or a dictionary. + + Returns: + Optional[dict | str]: The response message as an instance of `self.response_format`, + or the original response message unchanged if `self.response_format` is None. + + Raises: + json.JSONDecodeError: If the response_message is a string but cannot be parsed as valid JSON. + ValidationError: If the parsed response does not match the schema defined by `self.response_format`. + """ + # Response message is a string, which is converted to a dict + # The dict is then used to construct the response_format Pydantic model + if self.response_format is None: + return response_message + + try: + # First try to parse the response message as JSON + if isinstance(response_message, str): + try: + response_dict = json.loads(response_message) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse response message as JSON: {response_message}. " + f"The model likely returned an invalid JSON format." + ) + raise e + else: + response_dict = response_message + + # Then construct the Pydantic model from the parsed dict + response_message = self.response_format(**response_dict) + return response_message + + except ValidationError as e: + schema_str = json.dumps(self.response_format.model_json_schema(), indent=2) + logger.warning( + f"Pydantic failed to parse response message {response_message} with `response_format` {schema_str}. " + f"The model likely returned a JSON that does not match the schema of the `response_format`."
+ ) + raise e diff --git a/src/bespokelabs/curator/request_processor/base_online_request_processor.py b/src/bespokelabs/curator/request_processor/base_online_request_processor.py index bc042054..bdcff800 100644 --- a/src/bespokelabs/curator/request_processor/base_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_online_request_processor.py @@ -26,6 +26,8 @@ DEFAULT_MAX_REQUESTS_PER_MINUTE = 100 DEFAULT_MAX_TOKENS_PER_MINUTE = 100_000 DEFAULT_MAX_RETRIES = 10 +SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10 +DEFAULT_REQUEST_TIMEOUT = 10 * 60  # 10 minutes @dataclass @@ -69,6 +71,7 @@ def __init__( self.max_retries = DEFAULT_MAX_RETRIES else: self.max_retries = max_retries + self.timeout = DEFAULT_REQUEST_TIMEOUT @property def max_requests_per_minute(self) -> int: @@ -127,6 +130,11 @@ def run( parse_func_hash: str, prompt_formatter: PromptFormatter, ) -> Dataset: """Run completions using the online API with async processing.""" + # load from already completed dataset + output_dataset = self.attempt_loading_cached_dataset(working_dir, parse_func_hash) + if output_dataset is not None: + return output_dataset + logger.info(f"Running {self.__class__.__name__} completions with model: {self.model}") @@ -147,7 +155,6 @@ def run( self.process_requests_from_file( generic_request_filepath=request_file, save_filepath=response_file, - max_attempts=self.max_retries, resume=True, ) ) @@ -158,7 +165,6 @@ async def process_requests_from_file( self, generic_request_filepath: str, save_filepath: str, - max_attempts: int, resume: bool, resume_no_retry: bool = False, ) -> None: @@ -182,7 +188,7 @@ async def process_requests_from_file( completed_request_ids = set() if os.path.exists(save_filepath): if resume: - logger.debug(f"Resuming progress from existing file: {save_filepath}") + logger.info(f"Resuming progress by reading existing file: {save_filepath}") logger.debug( f"Removing all failed requests from {save_filepath} so they can be retried" ) @@ -200,6 +206,11 @@ async def process_requests_from_file( f"{response.response_errors}, removing from output and will retry" ) num_previously_failed_requests += 1 + elif response.response_message is None: + logger.debug( + f"Request {response.generic_request.original_row_idx} previously failed due to no response, removing from output and will retry" + ) + num_previously_failed_requests += 1 else: completed_request_ids.add(response.generic_request.original_row_idx) output_file.write(line) @@ -273,7 +284,7 @@ async def process_requests_from_file( task_id=status_tracker.num_tasks_started, generic_request=generic_request, api_specific_request=self.create_api_specific_request(generic_request), - attempts_left=max_attempts, + attempts_left=self.max_retries, prompt_formatter=self.prompt_formatter, ) @@ -283,6 +294,19 @@ while not status_tracker.has_capacity(token_estimate): await asyncio.sleep(0.1) + # Wait for rate limits to cool down if needed + seconds_since_rate_limit_error = ( + time.time() - status_tracker.time_of_last_rate_limit_error + ) + if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT: + remaining_seconds_to_pause = ( + SECONDS_TO_PAUSE_ON_RATE_LIMIT - seconds_since_rate_limit_error + ) + logger.warning( + f"Pausing to cool down for {int(remaining_seconds_to_pause)} seconds" + ) + await asyncio.sleep(remaining_seconds_to_pause) + # Consume capacity before making request status_tracker.consume_capacity(token_estimate) @@ -313,10 +337,10 @@ async def process_requests_from_file( token_estimate =
self.estimate_total_tokens( retry_request.generic_request.messages ) - attempt_number = 1 + self.max_retries - retry_request.attempts_left - logger.info( - f"Processing retry for request {retry_request.task_id} " - f"(attempt #{attempt_number} of {self.max_retries}). " + attempt_number = self.max_retries - retry_request.attempts_left + logger.debug( + f"Retrying request {retry_request.task_id} " + f"(attempt #{attempt_number} of {self.max_retries}). " f"Previous errors: {retry_request.result}" ) @@ -381,6 +405,9 @@ async def handle_single_request_with_retries( status_tracker=status_tracker, ) + # Allows us to retry on responses that don't match the response format + self.prompt_formatter.response_to_response_format(generic_response.response_message) + # Save response in the base class await self.append_generic_response(generic_response, save_filepath) @@ -389,18 +416,15 @@ status_tracker.pbar.update(1) except Exception as e: - logger.warning( - f"Request {request.task_id} failed with Exception {e}, attempts left {request.attempts_left}" - ) status_tracker.num_other_errors += 1 request.result.append(e) if request.attempts_left > 0: request.attempts_left -= 1 - # Add retry queue logging - logger.info( - f"Adding request {request.task_id} to retry queue. Will retry in next available slot. " - f"Attempts remaining: {request.attempts_left}" + logger.warning( + f"Encountered '{e.__class__.__name__}: {e}' during attempt " + f"{self.max_retries - request.attempts_left} of {self.max_retries} " + f"while processing request {request.task_id}" ) retry_queue.put_nowait(request) else: diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 52e82e41..2c6acc41 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -6,7 +6,8 @@ import resource from abc import ABC, abstractmethod from math import ceil -from typing import Optional +from pathlib import Path +from typing import Optional, List import aiofiles import pyarrow @@ -76,6 +77,64 @@ def run( """ pass + def _verify_existing_request_files( + self, working_dir: str, dataset: Optional[Dataset] + ) -> List[int]: + """ + Verify integrity of the cache (each request file has associated metadata, and the number of rows is correct), + and return the indices of request files that need to be regenerated (so that no work is repeated).
+ + Args: + working_dir (str): Working directory where cache files are expected to be (requests.jsonl, metadata.json) + dataset (Optional[Dataset]): The dataset that we want to create requests from + + Returns: + List[int]: Indices of missing files + """ + + if self.batch_size is not None and dataset is not None: + expected_num_files = ceil(len(dataset) / self.batch_size) + else: + expected_num_files = 1 + + try: + incomplete_files = [] + for i in range(expected_num_files): + req_f = os.path.join(working_dir, f"requests_{i}.jsonl") + meta_f = os.path.join(working_dir, f"metadata_{i}.json") + + if not os.path.exists(req_f): + incomplete_files.append(i) + continue + + if not os.path.exists(meta_f): + logger.warning(f"Cache missing metadata file {meta_f} for request file {req_f}") + incomplete_files.append(i) + continue + + with open(req_f, "r") as f: + data = f.read() + num_jobs = len(data.splitlines()) + + with open(meta_f, "r") as f: + metadata = json.load(f) + + expected_num_jobs = metadata["num_jobs"] + if num_jobs != expected_num_jobs: + logger.warning( + f"Request file {req_f} has {num_jobs} jobs, but metadata file {meta_f} has {expected_num_jobs} jobs" + ) + incomplete_files.append(i) + + return incomplete_files + + except Exception as e: + logger.warning( + f"Cache verification failed due to {e} - regenerating all request files." + ) + incomplete_files = list(range(expected_num_files)) + return incomplete_files + def create_request_files( self, dataset: Optional[Dataset], @@ -96,7 +155,9 @@ def create_request_files( request_files = glob.glob(f"{working_dir}/requests_*.jsonl") # By default use existing requests in working_dir - if len(request_files) > 0: + incomplete_files = self._verify_existing_request_files(working_dir, dataset) + + if len(incomplete_files) == 0: logger.info(f"Using cached requests. 
{CACHE_MSG}") # count existing jobs in file and print first job with open(request_files[0], "r") as f: @@ -116,18 +177,27 @@ def create_request_files( return request_files # Create new requests file + logger.info(f"Preparing request file(s) in {working_dir}") request_file = f"{working_dir}/requests_0.jsonl" request_files = [request_file] + metadata_file = f"{working_dir}/metadata_0.json" + metadata_files = [metadata_file] + if dataset is None: with open(request_file, "w") as f: generic_request = prompt_formatter.create_generic_request(dict(), 0) f.write(json.dumps(generic_request.model_dump(), default=str) + "\n") + + metadata_dict = {"num_jobs": 1} + with open(metadata_file, "w") as f: + f.write(json.dumps(metadata_dict, indent=4) + "\n") return request_files if self.batch_size: num_batches = ceil(len(dataset) / self.batch_size) request_files = [f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)] + metadata_files = [f"{working_dir}/metadata_{i}.json" for i in range(num_batches)] async def create_all_request_files(): tasks = [ @@ -135,15 +205,19 @@ async def create_all_request_files(): dataset, prompt_formatter, request_files[i], + metadata_files[i], start_idx=i * self.batch_size, ) for i in range(num_batches) + if i in incomplete_files ] await asyncio.gather(*tasks) run_in_event_loop(create_all_request_files()) else: - run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file)) + run_in_event_loop( + self.acreate_request_file(dataset, prompt_formatter, request_file, metadata_file) + ) return request_files @@ -153,8 +227,9 @@ async def acreate_request_file( dataset: Dataset, prompt_formatter: PromptFormatter, request_file: str, + metadata_file: str, start_idx: int = 0, - ) -> str: + ) -> None: if self.batch_size is not None: end_idx = min(start_idx + self.batch_size, len(dataset)) dataset = dataset.select(range(start_idx, end_idx)) @@ -168,7 +243,13 @@ async def acreate_request_file( # Get the generic request from the map function request = prompt_formatter.create_generic_request(dataset_row, dataset_row_idx) await f.write(json.dumps(request.model_dump(), default=str) + "\n") - logger.info(f"Wrote {end_idx - start_idx} requests to {request_file}.") + + num_requests = end_idx - start_idx + metadata_dict = {"num_jobs": num_requests} + async with aiofiles.open(metadata_file, "w") as f: + await f.write(json.dumps(metadata_dict, indent=4) + "\n") + + logger.info(f"Wrote {num_requests} requests to {request_file}.") def attempt_loading_cached_dataset( self, working_dir: str, parse_func_hash: str @@ -234,41 +315,18 @@ def create_dataset_files( failed_responses_count += 1 continue - if prompt_formatter.response_format: - # Response message is a string, which is converted to a dict - # The dict is then used to construct the response_format Pydantic model - try: - # First try to parse the response message as JSON - if isinstance(response.response_message, str): - try: - response_dict = json.loads(response.response_message) - except json.JSONDecodeError as e: - warning_msg = ( - f"Failed to parse response message as JSON: {response.response_message}. " - f"The model likely returned an invalid JSON format. Will skip this response." 
- ) - logger.warning(warning_msg) - failed_responses_count += 1 - continue - else: - response_dict = response.response_message - - # Then construct the Pydantic model from the parsed dict - response.response_message = prompt_formatter.response_format( - **response_dict - ) - except ValidationError as e: - schema_str = json.dumps( - prompt_formatter.response_format.model_json_schema(), - indent=2, + try: + response.response_message = ( + prompt_formatter.response_to_response_format( + response.response_message ) - warning_msg = ( - f"Pydantic failed to parse response message {response.response_message} with `response_format` {schema_str}. " - f"The model likely returned a JSON that does not match the schema of the `response_format`. Will skip this response." - ) - logger.warning(warning_msg) - failed_responses_count += 1 - continue + ) + except (json.JSONDecodeError, ValidationError) as e: + logger.warning( + f"Skipping response due to error parsing response message into response format: {e}" + ) + failed_responses_count += 1 + continue # parse_func can return a single row or a list of rows if prompt_formatter.parse_func: diff --git a/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py b/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py index 42cc79fc..44ab4443 100644 --- a/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py @@ -13,6 +13,7 @@ from bespokelabs.curator.types.generic_request import GenericRequest from bespokelabs.curator.request_processor.generic_response import TokenUsage, GenericResponse from pydantic import BaseModel +import time logger = logging.getLogger(__name__) @@ -49,7 +50,7 @@ def __init__( frequency_penalty: Optional[float] = None, max_requests_per_minute: Optional[int] = None, max_tokens_per_minute: Optional[int] = None, - require_all_responses: bool = False, + require_all_responses: Optional[bool] = None, max_retries: Optional[int] = None, ): super().__init__( @@ -214,6 +215,31 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: if "frequency_penalty" in supported_params and self.frequency_penalty is not None: request["frequency_penalty"] = self.frequency_penalty + # Add safety settings for Gemini models + if "gemini" in generic_request.model.lower(): + request["safety_settings"] = [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_CIVIC_INTEGRITY", + "threshold": "BLOCK_NONE", + }, + ] + return request async def call_single_request( @@ -236,18 +262,29 @@ GenericResponse: The response from LiteLLM """ # Get response directly without extra logging - if request.generic_request.response_format: - response, completion_obj = await self.client.chat.completions.create_with_completion( - **request.api_specific_request, - response_model=request.prompt_formatter.response_format, - timeout=60.0, - ) - response_message = ( - response.model_dump() if hasattr(response, "model_dump") else response - ) - else: - completion_obj = await litellm.acompletion(**request.api_specific_request, timeout=60.0) - response_message =
completion_obj["choices"][0]["message"]["content"] + try: + if request.generic_request.response_format: + response, completion_obj = ( + await self.client.chat.completions.create_with_completion( + **request.api_specific_request, + response_model=request.prompt_formatter.response_format, + timeout=self.timeout, + ) + ) + response_message = ( + response.model_dump() if hasattr(response, "model_dump") else response + ) + else: + completion_obj = await litellm.acompletion( + **request.api_specific_request, timeout=self.timeout + ) + response_message = completion_obj["choices"][0]["message"]["content"] + except litellm.RateLimitError as e: + status_tracker.time_of_last_rate_limit_error = time.time() + status_tracker.num_rate_limit_errors += 1 + # because handle_single_request_with_retries will double count otherwise + status_tracker.num_api_errors -= 1 + raise e # Extract token usage usage = completion_obj.usage if hasattr(completion_obj, "usage") else {} @@ -263,6 +300,19 @@ async def call_single_request( except litellm.NotFoundError as e: cost = 0 + finish_reason = completion_obj.choices[0].finish_reason + invalid_finish_reasons = ["length", "content_filter"] + if finish_reason in invalid_finish_reasons: + logger.debug( + f"Invalid finish_reason {finish_reason}. Raw response {completion_obj.model_dump()} for request {request.generic_request.messages}" + ) + raise ValueError(f"finish_reason was {finish_reason}") + + if response_message is None: + raise ValueError( + f"response_message was None with raw response {completion_obj.model_dump()}" + ) + # Create and return response return GenericResponse( response_message=response_message, diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 113db762..0794685a 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -271,7 +271,7 @@ async def call_single_request( self.url, headers=request_header, json=request.api_specific_request, - timeout=60.0, + timeout=self.timeout, ) as response_obj: response = await response_obj.json() @@ -282,6 +282,8 @@ async def call_single_request( status_tracker.time_of_last_rate_limit_error = time.time() status_tracker.num_rate_limit_errors += 1 status_tracker.num_api_errors -= 1 + # because handle_single_request_with_retries will double count otherwise + status_tracker.num_other_errors -= 1 raise Exception(f"API error: {error}") if response_obj.status != 200: diff --git a/tests/batch/test_resume.py b/tests/batch/test_resume.py index 0248da20..9ac0c906 100644 --- a/tests/batch/test_resume.py +++ b/tests/batch/test_resume.py @@ -10,6 +10,7 @@ """ +@pytest.mark.skip(reason="Temporarily disabled, need to add mocking") @pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-resume")) @pytest.mark.usefixtures("prepare_test_cache") def test_batch_resume(): diff --git a/tests/batch/test_switch_keys.py b/tests/batch/test_switch_keys.py index e9026577..f1d9fc8b 100644 --- a/tests/batch/test_switch_keys.py +++ b/tests/batch/test_switch_keys.py @@ -10,6 +10,7 @@ """ +@pytest.mark.skip(reason="Temporarily disabled, need to add mocking") @pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-switch-keys")) @pytest.mark.usefixtures("prepare_test_cache") def test_batch_switch_keys(): diff --git a/tests/cache/different_files/one.py 
b/tests/cache/different_files/one.py index 10ff74d4..e5667add 100644 --- a/tests/cache/different_files/one.py +++ b/tests/cache/different_files/one.py @@ -1,32 +1,18 @@ from bespokelabs.curator import LLM from datasets import Dataset import logging -import argparse logger = logging.getLogger("bespokelabs.curator") logger.setLevel(logging.INFO) -def main(delete_cache: bool = False): - dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3}) +dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3}) - prompter = LLM( - prompt_func=lambda row: row["prompt"], - model_name="gpt-4o-mini", - response_format=None, - delete_cache=delete_cache, - ) +prompter = LLM( + prompt_func=lambda row: row["prompt"], + model_name="gpt-4o-mini", + response_format=None, +) - dataset = prompter(dataset) - print(dataset.to_pandas()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run prompter with cache control") - parser.add_argument( - "--delete-cache", - action="store_true", - help="Delete the cache before running", - ) - args = parser.parse_args() - main(delete_cache=args.delete_cache) +dataset = prompter(dataset) +print(dataset.to_pandas()) diff --git a/tests/cache/different_files/two.py b/tests/cache/different_files/two.py index 10ff74d4..e5667add 100644 --- a/tests/cache/different_files/two.py +++ b/tests/cache/different_files/two.py @@ -1,32 +1,18 @@ from bespokelabs.curator import LLM from datasets import Dataset import logging -import argparse logger = logging.getLogger("bespokelabs.curator") logger.setLevel(logging.INFO) -def main(delete_cache: bool = False): - dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3}) +dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3}) - prompter = LLM( - prompt_func=lambda row: row["prompt"], - model_name="gpt-4o-mini", - response_format=None, - delete_cache=delete_cache, - ) +prompter = LLM( + prompt_func=lambda row: row["prompt"], + model_name="gpt-4o-mini", + response_format=None, +) - dataset = prompter(dataset) - print(dataset.to_pandas()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run prompter with cache control") - parser.add_argument( - "--delete-cache", - action="store_true", - help="Delete the cache before running", - ) - args = parser.parse_args() - main(delete_cache=args.delete_cache) +dataset = prompter(dataset) +print(dataset.to_pandas()) diff --git a/tests/cache/test_different_files.py b/tests/cache/test_different_files.py index 6b18de07..31fe866b 100644 --- a/tests/cache/test_different_files.py +++ b/tests/cache/test_different_files.py @@ -16,17 +16,14 @@ def test_cache_behavior(): # Run one.py twice and check for cache behavior print("RUNNING ONE.PY") - output1, _ = run_script(["python", "tests/cache_tests/different_files/one.py"]) - print(output1) + output1, _ = run_script(["python", "tests/cache/different_files/one.py"]) assert cache_hit_log not in output1, "First run of one.py should not hit cache" print("RUNNING ONE.PY AGAIN") - output2, _ = run_script(["python", "tests/cache_tests/different_files/one.py"]) - print(output2) + output2, _ = run_script(["python", "tests/cache/different_files/one.py"]) assert cache_hit_log in output2, "Second run of one.py should hit cache" # Run two.py and check for cache behavior print("RUNNING TWO.PY") - output3, _ = run_script(["python", "tests/cache_tests/different_files/two.py"]) - print(output3) + output3, _ = run_script(["python", "tests/cache/different_files/two.py"]) assert cache_hit_log in output3, "First run of 
two.py should hit cache" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..012b8dc6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "cache_dir(path): mark test to use specific cache directory") diff --git a/tests/test_caching.py b/tests/test_caching.py index 25f7426e..15c3ebd6 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -123,7 +123,7 @@ def test_function_hash_dir_change(): import tempfile from pathlib import Path - from bespokelabs.curator.prompter.llm import _get_function_hash + from bespokelabs.curator.llm.llm import _get_function_hash # Set up logging to write to a file in the current directory debug_log = Path("function_debug.log") diff --git a/tests/test_litellm_models.py b/tests/test_litellm_models.py index d8b36cdd..972848c9 100644 --- a/tests/test_litellm_models.py +++ b/tests/test_litellm_models.py @@ -13,31 +13,40 @@ @pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-models")) @pytest.mark.usefixtures("prepare_test_cache") -def test_litellm_models(): +class TestLiteLLMModels: + @pytest.fixture(autouse=True) + def check_environment(self): + env = os.environ.copy() + required_keys = [ + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "GEMINI_API_KEY", + "TOGETHER_API_KEY", + ] + for key in required_keys: + assert key in env, f"{key} must be set" - env = os.environ.copy() - assert "ANTHROPIC_API_KEY" in env, "ANTHROPIC_API_KEY must be set" - assert "OPENAI_API_KEY" in env, "OPENAI_API_KEY must be set" - assert "GEMINI_API_KEY" in env, "GEMINI_API_KEY must be set" - assert "TOGETHER_API_KEY" in env, "TOGETHER_API_KEY must be set" - - models_list = [ - "claude-3-5-sonnet-20240620", # https://docs.litellm.ai/docs/providers/anthropic # anthropic has a different hidden param tokens structure. 
- "claude-3-5-haiku-20241022", - "claude-3-haiku-20240307", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "gpt-4o-mini", # https://docs.litellm.ai/docs/providers/openai - "gpt-4o-2024-08-06", - "gpt-4-0125-preview", - "gpt-3.5-turbo-1106", - "gemini/gemini-1.5-flash", # https://docs.litellm.ai/docs/providers/gemini; https://ai.google.dev/gemini-api/docs/models # 20-30 iter/s - "gemini/gemini-1.5-pro", # 20-30 iter/s - "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", # https://docs.together.ai/docs/serverless-models - "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - ] - - for model in models_list: + @pytest.mark.parametrize( + "model", + [ + pytest.param("claude-3-5-sonnet-20240620", id="claude-3-5-sonnet"), + pytest.param("claude-3-5-haiku-20241022", id="claude-3-5-haiku"), + pytest.param("claude-3-haiku-20240307", id="claude-3-haiku"), + pytest.param("claude-3-opus-20240229", id="claude-3-opus"), + pytest.param("claude-3-sonnet-20240229", id="claude-3-sonnet"), + pytest.param("gpt-4o-mini", id="gpt-4-mini"), + pytest.param("gpt-4o-2024-08-06", id="gpt-4"), + pytest.param("gpt-4-0125-preview", id="gpt-4-preview"), + pytest.param("gpt-3.5-turbo-1106", id="gpt-3.5"), + pytest.param("gemini/gemini-1.5-flash", id="gemini-flash"), + pytest.param("gemini/gemini-1.5-pro", id="gemini-pro"), + pytest.param("together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", id="llama-8b"), + pytest.param( + "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", id="llama-70b" + ), + ], + ) + def test_model(self, model): print(f"\n\n========== TESTING {model} ==========\n\n") logger = logging.getLogger("bespokelabs.curator") logger.setLevel(logging.DEBUG)