diff --git a/.gitignore b/.gitignore
index 892e0790..e95f9df8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 .venv
+.DS_Store
 __pycache__
 .vscode
diff --git a/README.md b/README.md
index e129be08..81224738 100644
--- a/README.md
+++ b/README.md
@@ -24,9 +24,12 @@
     Discord
+
+    Code style: black
+

-### Overview
+## Overview
 
 Bespoke Curator makes it very easy to create high-quality synthetic data at scale, which you can use to finetune models or use for structured data extraction at scale.
 
@@ -35,7 +38,7 @@ Bespoke Curator is an open-source project:
 * A Curator Viewer which makes it easy to view the datasets, thus aiding in the dataset creation.
 * We will also be releasing high-quality datasets that should move the needle on post-training.
 
-### Key Features
+## Key Features
 
 1. **Programmability and Structured Outputs**: Synthetic data generation is a lot more than just using a single prompt -- it involves calling LLMs multiple times and orchestrating control flow. Curator treats structured outputs as first-class citizens and helps you design complex pipelines.
 2. **Built-in Performance Optimization**: We often see LLMs being called in loops, or inefficient multi-threading implementations. We have baked in performance optimizations so that you don't need to worry about those!
@@ -43,48 +46,91 @@ Bespoke Curator is an open-source project:
 4. **Native HuggingFace Dataset Integration**: Work directly on HuggingFace Dataset objects throughout your pipeline. Your synthetic data is immediately ready for fine-tuning!
 5. **Interactive Curator Viewer**: Improve and iterate on your prompts using our built-in viewer. Inspect LLM requests and responses in real time, allowing you to iterate and refine your data generation strategy with immediate feedback.
 
-### Installation
+## Installation
 
 ```bash
 pip install bespokelabs-curator
 ```
 
-### Usage
+## Usage
+To run the examples below, make sure to set your OpenAI API key in
+the environment variable `OPENAI_API_KEY` by running `export OPENAI_API_KEY=sk-...` in your terminal.
+
+### Hello World with `SimpleLLM`: A simple interface for calling LLMs
+
+```python
+from bespokelabs import curator
+llm = curator.SimpleLLM(model_name="gpt-4o-mini")
+poem = llm("Write a poem about the importance of data in AI.")
+print(poem)
+# Or you can pass a list of prompts to generate multiple responses.
+poems = llm(["Write a poem about the importance of data in AI.",
+             "Write a haiku about the importance of data in AI."])
+print(poems)
+```
+Note that retries and caching are enabled by default.
+So if you run the same prompt again, you will get the same response, pretty much instantly.
+You can delete the cache at `~/.cache/curator`.
+
+#### Use the LiteLLM backend for calling other models
+You can use the [LiteLLM](https://docs.litellm.ai/docs/providers) backend for calling other models.
+
+```python
+from bespokelabs import curator
+llm = curator.SimpleLLM(model_name="claude-3-5-sonnet-20240620", backend="litellm")
+poem = llm("Write a poem about the importance of data in AI.")
+print(poem)
+```
+
+### Visualize in Curator Viewer
+Run `curator-viewer` on the command line to see the dataset in the viewer.
+You can click on a run and then click on a specific row to see the LLM request and response.
+![Curator Responses](docs/curator-responses.png)
+More examples below.
+
+### `LLM`: A more powerful interface for synthetic data generation
+
+Let's use structured outputs to generate poems.
 ```python
 from bespokelabs import curator
 from datasets import Dataset
 from pydantic import BaseModel, Field
 from typing import List
 
-# Create a dataset object for the topics you want to create the poems.
 topics = Dataset.from_dict({"topic": [
     "Urban loneliness in a bustling city",
     "Beauty of Bespoke Labs's Curator library"
 ]})
+```
 
-# Define a class to encapsulate a list of poems.
+Define a class to encapsulate a list of poems.
+```python
 class Poem(BaseModel):
     poem: str = Field(description="A poem.")
 
 class Poems(BaseModel):
     poems_list: List[Poem] = Field(description="A list of poems.")
+```
 
-
-# We define a Prompter that generates poems which gets applied to the topics dataset.
-poet = curator.Prompter(
-    # `prompt_func` takes a row of the dataset as input.
-    # `row` is a dictionary with a single key 'topic' in this case.
+We define an `LLM` object that generates poems, which gets applied to the topics dataset.
+```python
+poet = curator.LLM(
     prompt_func=lambda row: f"Write two poems about {row['topic']}.",
     model_name="gpt-4o-mini",
     response_format=Poems,
-    # `row` is the input row, and `poems` is the `Poems` class which
-    # is parsed from the structured output from the LLM.
     parse_func=lambda row, poems: [
         {"topic": row["topic"], "poem": p.poem} for p in poems.poems_list
     ],
 )
+```
+Here:
+* `prompt_func` takes a row of the dataset as input and returns the prompt for the LLM.
+* `response_format` is the structured output class we defined above.
+* `parse_func` takes the input (`row`) and the structured output (`poems`) and converts them into a list of dictionaries, so that we can easily convert the output to a HuggingFace Dataset object.
+Now we can apply the `LLM` object to the dataset, which reads very Pythonically.
+```python
 poem = poet(topics)
 print(poem.to_pandas())
 # Example output:
@@ -94,14 +140,11 @@ print(poem.to_pandas())
 #  2   Beauty of Bespoke Labs's Curator library   In whispers of design and crafted grace,\nBesp...
 #  3   Beauty of Bespoke Labs's Curator library   In the hushed breath of parchment and ink,\nBe...
 ```
-Note that `topics` can be created with `curator.Prompter` as well,
+Note that `topics` can be created with `curator.LLM` as well,
 and we can scale this up to create tens of thousands of diverse poems.
 You can see a more detailed example in the [examples/poem.py](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples/poem.py) file,
 and other examples in the [examples](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples) directory.
 
-To run the examples, make sure to set your OpenAI API key in
-the environment variable `OPENAI_API_KEY` by running `export OPENAI_API_KEY=sk-...` in your terminal.
-
 See the [docs](https://docs.bespokelabs.ai/) for more details as well as
 for troubleshooting information.
 
@@ -115,6 +158,12 @@ curator-viewer
 
 This will pop up a browser window with the viewer running on `127.0.0.1:3000` by default
 if you haven't specified a different host and port.
 
+The dataset viewer shows all the different runs you have made.
+![Curator Runs](docs/curator-runs.png)
+
+You can also see the dataset and the responses from the LLM.
+![Curator Dataset](docs/curator-dataset.png)
+
 Optional parameters to run the viewer on a different host and port:
 
 ```bash
@@ -152,4 +201,4 @@ npm -v # should print `10.9.0`
 ```
 
 ## Contributing
-Contributions are welcome!
\ No newline at end of file
+Contributions are welcome!
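The README additions above stop at the structured-output example, but the `LLM` constructor introduced later in this diff (`src/bespokelabs/curator/llm/llm.py`) also gains rate-limit and retry controls: `max_requests_per_minute`, `max_tokens_per_minute`, `max_retries`, and `require_all_responses`. A minimal sketch of how those new arguments might be combined with the poem example; the numeric values are placeholders for illustration, not recommended settings:

```python
from typing import List

from pydantic import BaseModel, Field

from bespokelabs import curator


class Poem(BaseModel):
    poem: str = Field(description="A poem.")


class Poems(BaseModel):
    poems_list: List[Poem] = Field(description="A list of poems.")


# Placeholder limits -- tune them to your provider's real quotas.
poet = curator.LLM(
    prompt_func=lambda row: f"Write two poems about {row['topic']}.",
    model_name="gpt-4o-mini",
    response_format=Poems,
    parse_func=lambda row, poems: [
        {"topic": row["topic"], "poem": p.poem} for p in poems.poems_list
    ],
    backend="openai",               # set explicitly to skip backend auto-detection
    max_requests_per_minute=500,    # manual override; otherwise detected from response headers
    max_tokens_per_minute=200_000,  # manual override; otherwise detected from response headers
    max_retries=5,                  # per the diff, defaults to 10 when unset
    require_all_responses=True,     # fail the run if any request ends up without a response
)
```

Per the constructor changes in this diff, the two rate-limit arguments are ignored (with a warning) when `batch=True`.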
diff --git a/bespoke-dataset-viewer/app/dataset/[runHash]/page.tsx b/bespoke-dataset-viewer/app/dataset/[runHash]/page.tsx index f4a74ca6..43f13ff9 100644 --- a/bespoke-dataset-viewer/app/dataset/[runHash]/page.tsx +++ b/bespoke-dataset-viewer/app/dataset/[runHash]/page.tsx @@ -10,11 +10,6 @@ export default async function DatasetPage({ const { runHash } = await params const { batchMode } = await searchParams const isBatchMode = batchMode === '1' - return ( - - - - - - ) + + return } diff --git a/bespoke-dataset-viewer/app/layout.tsx b/bespoke-dataset-viewer/app/layout.tsx index 03590f7a..a0b8e263 100644 --- a/bespoke-dataset-viewer/app/layout.tsx +++ b/bespoke-dataset-viewer/app/layout.tsx @@ -1,6 +1,6 @@ import type { Metadata } from "next"; import "./globals.css"; - +import { Toaster } from "@/components/ui/toaster" export const metadata: Metadata = { title: "Curator Viewer", @@ -13,10 +13,11 @@ export default function RootLayout({ children: React.ReactNode }) { return ( - - + + {children} + ) -} +} \ No newline at end of file diff --git a/bespoke-dataset-viewer/components/dataset-viewer/DetailsSidebar.tsx b/bespoke-dataset-viewer/components/dataset-viewer/DetailsSidebar.tsx index c3f673cf..ea36a394 100644 --- a/bespoke-dataset-viewer/components/dataset-viewer/DetailsSidebar.tsx +++ b/bespoke-dataset-viewer/components/dataset-viewer/DetailsSidebar.tsx @@ -7,6 +7,7 @@ import { Copy } from "lucide-react" import { DataItem } from "@/types/dataset" import { useCallback } from "react" import { Sheet, SheetContent } from "@/components/ui/sheet" +import { useToast } from "@/components/ui/use-toast" interface DetailsSidebarProps { item: DataItem | null @@ -14,15 +15,26 @@ interface DetailsSidebarProps { } export function DetailsSidebar({ item, onClose }: DetailsSidebarProps) { + const { toast } = useToast() + const copyToClipboard = useCallback(async (text: string) => { try { await navigator.clipboard.writeText(text) - alert("Copied to clipboard!") + toast({ + title: "Success", + description: "Copied to clipboard!", + duration: 2000, + }) } catch (err) { console.error("Failed to copy:", err) - alert("Failed to copy to clipboard") + toast({ + variant: "destructive", + title: "Error", + description: "Failed to copy to clipboard", + duration: 2000, + }) } - }, []) + }, [toast]) if (!item) return null diff --git a/bespoke-dataset-viewer/components/dataset-viewer/RunsTable.tsx b/bespoke-dataset-viewer/components/dataset-viewer/RunsTable.tsx index d85a42be..299fef0c 100644 --- a/bespoke-dataset-viewer/components/dataset-viewer/RunsTable.tsx +++ b/bespoke-dataset-viewer/components/dataset-viewer/RunsTable.tsx @@ -39,8 +39,8 @@ class Poems(BaseModel): poems_list: List[Poem] = Field(description="A list of poems.") -# We define a Prompter that generates poems which gets applied to the topics dataset. -poet = curator.Prompter( +# We define an LLM object that generates poems which gets applied to the topics dataset. +poet = curator.LLM( # prompt_func takes a row of the dataset as input. # row is a dictionary with a single key 'topic' in this case. 
prompt_func=lambda row: f"Write two poems about {row['topic']}.", diff --git a/bespoke-dataset-viewer/components/ui/use-toast.ts b/bespoke-dataset-viewer/components/ui/use-toast.ts index c2fbf3f9..05d1fefe 100644 --- a/bespoke-dataset-viewer/components/ui/use-toast.ts +++ b/bespoke-dataset-viewer/components/ui/use-toast.ts @@ -6,7 +6,7 @@ import type { } from "@/components/ui/toast" const TOAST_LIMIT = 1 -const TOAST_REMOVE_DELAY = 1000000 +const TOAST_REMOVE_DELAY = 3000 type ToasterToast = ToastProps & { id: string diff --git a/bespoke-dataset-viewer/package-lock.json b/bespoke-dataset-viewer/package-lock.json index 97c34801..f9e02c46 100644 --- a/bespoke-dataset-viewer/package-lock.json +++ b/bespoke-dataset-viewer/package-lock.json @@ -36,6 +36,7 @@ }, "devDependencies": { "@types/node": "^20", + "@types/prismjs": "^1.26.5", "@types/react": "^18", "@types/react-dom": "^18", "eslint": "^8", @@ -1860,6 +1861,13 @@ "undici-types": "~6.19.2" } }, + "node_modules/@types/prismjs": { + "version": "1.26.5", + "resolved": "https://registry.npmjs.org/@types/prismjs/-/prismjs-1.26.5.tgz", + "integrity": "sha512-AUZTa7hQ2KY5L7AmtSiqxlhWxb4ina0yd8hNbl4TWuqnv/pFP0nDMb3YrfSBf4hJVGLh2YEIBfKaBW/9UEl6IQ==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/prop-types": { "version": "15.7.13", "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.13.tgz", diff --git a/bespoke-dataset-viewer/package.json b/bespoke-dataset-viewer/package.json index 62643150..dee9be50 100644 --- a/bespoke-dataset-viewer/package.json +++ b/bespoke-dataset-viewer/package.json @@ -37,6 +37,7 @@ }, "devDependencies": { "@types/node": "^20", + "@types/prismjs": "^1.26.5", "@types/react": "^18", "@types/react-dom": "^18", "eslint": "^8", diff --git a/build_pkg.py b/build_pkg.py index 80de2549..b9a6e57e 100644 --- a/build_pkg.py +++ b/build_pkg.py @@ -81,7 +81,7 @@ def nextjs_build(): def run_pytest(): print("Running pytest") try: - run_command("pytest", cwd="tests") + run_command("pytest") except subprocess.CalledProcessError: print("Pytest failed. Aborting build.") sys.exit(1) diff --git a/docs/curator-dataset.png b/docs/curator-dataset.png new file mode 100644 index 00000000..33138ac3 Binary files /dev/null and b/docs/curator-dataset.png differ diff --git a/docs/curator-responses.png b/docs/curator-responses.png new file mode 100644 index 00000000..a78277e0 Binary files /dev/null and b/docs/curator-responses.png differ diff --git a/docs/curator-runs.png b/docs/curator-runs.png new file mode 100644 index 00000000..d076d9b1 Binary files /dev/null and b/docs/curator-runs.png differ diff --git a/examples/camel.py b/examples/camel.py index bffa0507..b9bdfee1 100644 --- a/examples/camel.py +++ b/examples/camel.py @@ -22,14 +22,14 @@ class QAs(BaseModel): qas: List[QA] = Field(description="A list of QAs") -subject_prompter = curator.Prompter( +subject_prompter = curator.LLM( prompt_func=lambda: f"Generate a diverse list of 3 subjects. Keep it high-level (e.g. Math, Science).", parse_func=lambda _, subjects: [subject for subject in subjects.subjects], model_name="gpt-4o-mini", response_format=Subjects, ) subject_dataset = subject_prompter() -subsubject_prompter = curator.Prompter( +subsubject_prompter = curator.LLM( prompt_func=lambda subject: f"For the given subject {subject}. Generate 3 diverse subsubjects. 
No explanation.", parse_func=lambda subject, subsubjects: [ {"subject": subject["subject"], "subsubject": subsubject.subject} @@ -40,7 +40,7 @@ class QAs(BaseModel): ) subsubject_dataset = subsubject_prompter(subject_dataset) -qa_prompter = curator.Prompter( +qa_prompter = curator.LLM( prompt_func=lambda subsubject: f"For the given subsubject {subsubject}. Generate 3 diverse questions and answers. No explanation.", model_name="gpt-4o-mini", response_format=QAs, diff --git a/examples/distill.py b/examples/distill.py index 20cfd53b..b9e7c7bb 100644 --- a/examples/distill.py +++ b/examples/distill.py @@ -21,7 +21,7 @@ def parse_func(row, response): return {"instruction": instruction, "new_response": response} -distill_prompter = curator.Prompter( +distill_prompter = curator.LLM( prompt_func=prompt_func, parse_func=parse_func, model_name="gpt-4o-mini", diff --git a/examples/litellm_recipe_prompting.py b/examples/litellm_recipe_prompting.py index 87446e01..85449389 100644 --- a/examples/litellm_recipe_prompting.py +++ b/examples/litellm_recipe_prompting.py @@ -31,7 +31,7 @@ def main(): # 3. Set environment variable: GEMINI_API_KEY ############################################# - recipe_prompter = curator.Prompter( + recipe_prompter = curator.LLM( model_name="gemini/gemini-1.5-flash", prompt_func=lambda row: f"Generate a random {row['cuisine']} recipe. Be creative but keep it realistic.", parse_func=lambda row, response: { diff --git a/examples/litellm_recipe_structured_output.py b/examples/litellm_recipe_structured_output.py index 747411e9..bb9ad12b 100644 --- a/examples/litellm_recipe_structured_output.py +++ b/examples/litellm_recipe_structured_output.py @@ -28,7 +28,7 @@ def main(): # 2. Generate an API key or use an existing API key # 3. Set environment variable: ANTHROPIC_API_KEY ############################################# - cuisines_generator = curator.Prompter( + cuisines_generator = curator.LLM( prompt_func=lambda: f"Generate 10 diverse cuisines.", model_name="claude-3-5-haiku-20241022", response_format=Cuisines, @@ -44,7 +44,7 @@ def main(): # 2. Generate an API key or use an existing API key # 3. Set environment variable: GEMINI_API_KEY ############################################# - recipe_prompter = curator.Prompter( + recipe_prompter = curator.LLM( model_name="gemini/gemini-1.5-flash", prompt_func=lambda row: f"Generate a random {row['cuisine']} recipe. Be creative but keep it realistic.", parse_func=lambda row, response: { diff --git a/examples/persona-hub/synthesize.py b/examples/persona-hub/synthesize.py index 232d6e1f..a2ef68dd 100644 --- a/examples/persona-hub/synthesize.py +++ b/examples/persona-hub/synthesize.py @@ -31,7 +31,7 @@ def get_generator(template): def prompt_func(row): return template.format(persona=row["persona"]) - generator = curator.Prompter( + generator = curator.LLM( prompt_func=prompt_func, model_name="gpt-4o", temperature=0.7, diff --git a/examples/poem.py b/examples/poem.py index e8e50d07..1b59e562 100644 --- a/examples/poem.py +++ b/examples/poem.py @@ -17,7 +17,7 @@ class Topics(BaseModel): # We define a prompter that generates topics. -topic_generator = curator.Prompter( +topic_generator = curator.LLM( prompt_func=lambda: "Generate 10 diverse topics that are suitable for writing poems about.", model_name="gpt-4o-mini", response_format=Topics, @@ -35,8 +35,8 @@ class Poems(BaseModel): poems_list: List[str] = Field(description="A list of poems.") -# We define a prompter that generates poems which gets applied to the topics dataset. 
-poet = curator.Prompter( +# We define an `LLM` object that generates poems which gets applied to the topics dataset. +poet = curator.LLM( # The prompt_func takes a row of the dataset as input. # The row is a dictionary with a single key 'topic' in this case. prompt_func=lambda row: f"Write two poems about {row['topic']}.", diff --git a/examples/simple_poem.py b/examples/simple_poem.py new file mode 100644 index 00000000..8b1f5106 --- /dev/null +++ b/examples/simple_poem.py @@ -0,0 +1,25 @@ +"""Curator example that uses `SimpleLLM` to generate poems. + +Please see the poem.py for more complex use cases. +""" + +from bespokelabs import curator + +# Use GPT-4o-mini for this example. +llm = curator.SimpleLLM(model_name="gpt-4o-mini") +poem = llm("Write a poem about the importance of data in AI.") +print(poem) + +# Use Claude 3.5 Sonnet for this example. +llm = curator.SimpleLLM(model_name="claude-3-5-sonnet-20240620", backend="litellm") +poem = llm("Write a poem about the importance of data in AI.") +print(poem) + +# Note that we can also pass a list of prompts to generate multiple responses. +poems = llm( + [ + "Write a sonnet about the importance of data in AI.", + "Write a haiku about the importance of data in AI.", + ] +) +print(poems) diff --git a/pyproject.toml b/pyproject.toml index 54a6b10d..c594ccd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "bespokelabs-curator" -version = "0.1.11" +version = "0.1.12" description = "Bespoke Labs Curator" authors = ["Bespoke Labs "] readme = "README.md" diff --git a/src/bespokelabs/curator/__init__.py b/src/bespokelabs/curator/__init__.py index bb0b7aa2..5ef73092 100644 --- a/src/bespokelabs/curator/__init__.py +++ b/src/bespokelabs/curator/__init__.py @@ -1,2 +1,3 @@ from .dataset import Dataset -from .prompter.prompter import Prompter +from .llm.llm import LLM +from .llm.simple_llm import SimpleLLM diff --git a/src/bespokelabs/curator/dataset.py b/src/bespokelabs/curator/dataset.py index b0abece0..180be4c6 100644 --- a/src/bespokelabs/curator/dataset.py +++ b/src/bespokelabs/curator/dataset.py @@ -1,15 +1,13 @@ import glob -import json import logging import os from typing import Any, Dict, Iterable, Iterator, List, TypeVar -import pandas as pd from datasets import Dataset as HFDataset from datasets.arrow_writer import ArrowWriter, SchemaInferenceError from pydantic import BaseModel -from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter +from bespokelabs.curator.llm.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.generic_response import GenericResponse T = TypeVar("T") diff --git a/src/bespokelabs/curator/file_utilities.py b/src/bespokelabs/curator/file_utilities.py new file mode 100644 index 00000000..6ee606e7 --- /dev/null +++ b/src/bespokelabs/curator/file_utilities.py @@ -0,0 +1,14 @@ +# https://stackoverflow.com/questions/845058/how-to-get-the-line-count-of-a-large-file-cheaply-in-python +# https://stackoverflow.com/a/68385697 +def _file_gen(reader): + b = reader(1024 * 1024) + while b: + yield b + b = reader(1024 * 1024) + + +# Instead of requiring counting lines, we can store metadata file that has the number of requests in each file +def count_lines(filename): + f = open(filename, "rb") + f_gen = _file_gen(f.raw.read) + return sum(buf.count(b"\n") for buf in f_gen) diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/llm/llm.py similarity index 78% rename from src/bespokelabs/curator/prompter/prompter.py rename to 
src/bespokelabs/curator/llm/llm.py index 61e9e99e..9a53beeb 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/llm/llm.py @@ -7,82 +7,47 @@ from io import BytesIO from typing import Any, Callable, Dict, Iterable, Optional, Type, TypeVar, Union -import dill from datasets import Dataset +from datasets.utils._dill import Pickler from pydantic import BaseModel from xxhash import xxh64 from bespokelabs.curator.db import MetadataDB -from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter -from bespokelabs.curator.request_processor.base_request_processor import BaseRequestProcessor +from bespokelabs.curator.llm.prompt_formatter import PromptFormatter +from bespokelabs.curator.request_processor.base_request_processor import ( + BaseRequestProcessor, +) +from bespokelabs.curator.request_processor.litellm_online_request_processor import ( + LiteLLMOnlineRequestProcessor, +) from bespokelabs.curator.request_processor.openai_batch_request_processor import ( OpenAIBatchRequestProcessor, ) from bespokelabs.curator.request_processor.openai_online_request_processor import ( OpenAIOnlineRequestProcessor, ) -from bespokelabs.curator.request_processor.litellm_online_request_processor import ( - LiteLLMOnlineRequestProcessor, -) _CURATOR_DEFAULT_CACHE_DIR = "~/.cache/curator" T = TypeVar("T") +_DictOrBaseModel = Union[Dict[str, Any], BaseModel] logger = logger = logging.getLogger(__name__) -class Prompter: +class LLM: """Interface for prompting LLMs.""" - @staticmethod - def _determine_backend( - model_name: str, response_format: Optional[Type[BaseModel]] = None - ) -> str: - """Determine which backend to use based on model name and response format. - - Args: - model_name (str): Name of the model - response_format (Optional[Type[BaseModel]]): Response format if specified - - Returns: - str: Backend to use ("openai" or "litellm") - """ - model_name = model_name.lower() - - # GPT-4o models with response format should use OpenAI - if ( - response_format - and OpenAIOnlineRequestProcessor(model_name).check_structured_output_support() - ): - logger.info(f"Requesting structured output from {model_name}, using OpenAI backend") - return "openai" - - # GPT models and O1 models without response format should use OpenAI - if not response_format and any(x in model_name for x in ["gpt-", "o1-preview", "o1-mini"]): - logger.info(f"Requesting text output from {model_name}, using OpenAI backend") - return "openai" - - # Default to LiteLLM for all other cases - logger.info( - f"Requesting {f'structured' if response_format else 'text'} output from {model_name}, using LiteLLM backend" - ) - return "litellm" - def __init__( self, model_name: str, - prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]], + prompt_func: Callable[[_DictOrBaseModel], _DictOrBaseModel], parse_func: Optional[ - Callable[ - [ - Union[Dict[str, Any], BaseModel], - Union[Dict[str, Any], BaseModel], - ], - T, - ] + Callable[[_DictOrBaseModel, _DictOrBaseModel], _DictOrBaseModel] ] = None, response_format: Optional[Type[BaseModel]] = None, backend: Optional[str] = None, + max_requests_per_minute: Optional[int] = None, + max_tokens_per_minute: Optional[int] = None, batch: bool = False, batch_size: Optional[int] = None, batch_check_interval: Optional[int] = 60, @@ -92,38 +57,32 @@ def __init__( top_p: Optional[float] = None, presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, + max_retries: Optional[int] = None, + require_all_responses: Optional[bool] = 
True, ): - """Initialize a Prompter. + """Initialize a LLM. Args: - model_name (str): The name of the LLM to use - prompt_func (Callable[[Dict[str, Any]], Union[str, List[Dict[str, Any]]]]): A function that takes a single row + model_name: The name of the LLM to use + prompt_func: A function that takes a single row and returns either a string (assumed to be a user prompt) or messages list - parse_func (Callable[[Dict[str, Any], Any], T]): A function that takes the input row and + parse_func: A function that takes the input row and response object and returns the parsed output - response_format (Optional[Type[BaseModel]]): A Pydantic model specifying the + response_format: A Pydantic model specifying the response format from the LLM. - backend (Optional[str]): The backend to use ("openai" or "litellm"). If None, will be auto-determined - batch (bool): Whether to use batch processing - batch_size (Optional[int]): The size of the batch to use, only used if batch is True - temperature (Optional[float]): The temperature to use for the LLM, only used if batch is False - top_p (Optional[float]): The top_p to use for the LLM, only used if batch is False - presence_penalty (Optional[float]): The presence_penalty to use for the LLM, only used if batch is False - frequency_penalty (Optional[float]): The frequency_penalty to use for the LLM, only used if batch is False + backend: The backend to use ("openai" or "litellm"). If None, will be auto-determined + batch: Whether to use batch processing + batch_size: The size of the batch to use, only used if batch is True + batch_check_interval: The interval to check for batch completions, only used if batch is True + delete_successful_batch_files: Whether to delete successful batch files, only used if batch is True + delete_failed_batch_files: Whether to delete failed batch files, only used if batch is True + temperature: The temperature to use for the LLM, only used if batch is False + top_p: The top_p to use for the LLM, only used if batch is False + presence_penalty: The presence_penalty to use for the LLM, only used if batch is False + frequency_penalty: The frequency_penalty to use for the LLM, only used if batch is False + max_retries: The maximum number of retries to use for the LLM + require_all_responses: Whether to require all responses """ - prompt_sig = inspect.signature(prompt_func) - if len(prompt_sig.parameters) > 1: - raise ValueError( - f"prompt_func must take one argument or less, got {len(prompt_sig.parameters)}" - ) - - if parse_func is not None: - parse_sig = inspect.signature(parse_func) - if len(parse_sig.parameters) != 2: - raise ValueError( - f"parse_func must take exactly 2 arguments, got {len(parse_sig.parameters)}" - ) - self.prompt_formatter = PromptFormatter( model_name, prompt_func, parse_func, response_format ) @@ -144,6 +103,10 @@ def __init__( logger.info( f"batch=True but no batch_size provided, using default batch_size of {batch_size:,}" ) + if max_requests_per_minute is not None or max_tokens_per_minute is not None: + logger.warning( + "max_requests_per_minute and max_tokens_per_minute not supported with batch mode" + ) self._request_processor = OpenAIBatchRequestProcessor( model=model_name, batch_size=batch_size, @@ -154,11 +117,13 @@ def __init__( frequency_penalty=frequency_penalty, delete_successful_batch_files=delete_successful_batch_files, delete_failed_batch_files=delete_failed_batch_files, + max_retries=max_retries, + require_all_responses=require_all_responses, ) else: if batch_size is not None: logger.warning( 
- f"Prompter argument `batch_size` {batch_size} is ignored because `batch` is False" + f"LLM argument `batch_size` {batch_size} is ignored because `batch` is False" ) self._request_processor = OpenAIOnlineRequestProcessor( model=model_name, @@ -166,6 +131,10 @@ def __init__( top_p=top_p, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, + max_requests_per_minute=max_requests_per_minute, + max_tokens_per_minute=max_tokens_per_minute, + max_retries=max_retries, + require_all_responses=require_all_responses, ) elif self.backend == "litellm": if batch: @@ -178,10 +147,48 @@ def __init__( top_p=top_p, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, + max_requests_per_minute=max_requests_per_minute, + max_tokens_per_minute=max_tokens_per_minute, + max_retries=max_retries, + require_all_responses=require_all_responses, ) else: raise ValueError(f"Unknown backend: {self.backend}") + @staticmethod + def _determine_backend( + model_name: str, response_format: Optional[Type[BaseModel]] = None + ) -> str: + """Determine which backend to use based on model name and response format. + + Args: + model_name (str): Name of the model + response_format (Optional[Type[BaseModel]]): Response format if specified + + Returns: + str: Backend to use ("openai" or "litellm") + """ + model_name = model_name.lower() + + # GPT-4o models with response format should use OpenAI + if ( + response_format + and OpenAIOnlineRequestProcessor(model_name).check_structured_output_support() + ): + logger.info(f"Requesting structured output from {model_name}, using OpenAI backend") + return "openai" + + # GPT models and O1 models without response format should use OpenAI + if not response_format and any(x in model_name for x in ["gpt-", "o1-preview", "o1-mini"]): + logger.info(f"Requesting text output from {model_name}, using OpenAI backend") + return "openai" + + # Default to LiteLLM for all other cases + logger.info( + f"Requesting {f'structured' if response_format else 'text'} output from {model_name}, using LiteLLM backend" + ) + return "litellm" + def __call__( self, dataset: Optional[Iterable] = None, @@ -211,7 +218,7 @@ def _completions( Args: dataset (Iterable): A dataset consisting of a list of items to apply completions - prompter (Prompter): A Prompter that contains the logic for formatting each + prompter (LLM): A LLM that contains the logic for formatting each item in the dataset working_dir (str): The working directory to save the requests.jsonl, responses.jsonl, and dataset.arrow files. 
@@ -223,7 +230,7 @@ def _completions( dataset = Dataset.from_generator(dataset) if self is None: - raise ValueError("Prompter must be provided") + raise ValueError("LLM must be provided") if working_dir is None: curator_cache_dir = os.environ.get( @@ -311,7 +318,7 @@ def _get_function_hash(func) -> str: return xxh64("").hexdigest() file = BytesIO() - dill.Pickler(file, recurse=True).dump(func) + Pickler(file, recurse=True).dump(func) return xxh64(file.getvalue()).hexdigest() diff --git a/src/bespokelabs/curator/llm/prompt_formatter.py b/src/bespokelabs/curator/llm/prompt_formatter.py new file mode 100644 index 00000000..4dae93ce --- /dev/null +++ b/src/bespokelabs/curator/llm/prompt_formatter.py @@ -0,0 +1,140 @@ +import dataclasses +import inspect +import json +import logging +from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union + +from pydantic import BaseModel, ValidationError + +from bespokelabs.curator.request_processor.generic_request import GenericRequest + +T = TypeVar("T") +_DictOrBaseModel = Union[Dict[str, Any], BaseModel] +logger = logging.getLogger(__name__) + + +def _validate_messages(messages: list[dict]) -> None: + """Validates that messages conform to the expected chat format. + + Args: + messages: A list of message dictionaries to validate. + + Raises: + ValueError: If messages don't meet the required format: + - Must be a list of dictionaries + - Each message must have 'role' and 'content' keys + - Role must be one of: 'system', 'user', 'assistant' + """ + valid_roles = {"system", "user", "assistant"} + + for msg in messages: + if not isinstance(msg, dict): + raise ValueError( + "In the return value (a list) of the prompt_func, each " + "message must be a dictionary" + ) + + if "role" not in msg or "content" not in msg: + raise ValueError( + "In the return value (a list) of the prompt_func, each " + "message must contain 'role' and 'content' keys" + ) + + if msg["role"] not in valid_roles: + raise ValueError( + f"In the return value (a list) of the prompt_func, " + f"each message role must be one of: {', '.join(sorted(valid_roles))}" + ) + + +@dataclasses.dataclass +class PromptFormatter: + model_name: str + prompt_func: Callable[[_DictOrBaseModel], Dict[str, str]] + parse_func: Optional[Callable[[_DictOrBaseModel, _DictOrBaseModel], T]] = None + response_format: Optional[Type[BaseModel]] = None + + def create_generic_request(self, row: _DictOrBaseModel, idx: int) -> GenericRequest: + """Format the request object based off of `LLM` attributes.""" + sig = inspect.signature(self.prompt_func) + if len(sig.parameters) == 0: + prompts = self.prompt_func() + elif len(sig.parameters) == 1: + prompts = self.prompt_func(row) + else: + raise ValueError(f"Prompting function {self.prompt_func} must have 0 or 1 arguments.") + + if isinstance(prompts, str): + messages = [{"role": "user", "content": prompts}] + elif isinstance(prompts, list): + _validate_messages(prompts) + messages = prompts + else: + raise ValueError("The return value of the prompt_func must be a list of dictionaries.") + + # Convert BaseModel to dict for serialization + if isinstance(row, BaseModel): + row = row.model_dump() + + return GenericRequest( + model=self.model_name, + messages=messages, + original_row=row, + original_row_idx=idx, + response_format=( + self.response_format.model_json_schema() if self.response_format else None + ), + ) + + def response_to_response_format(self, response_message: str | dict) -> Optional[dict | str]: + """ + Converts a response message to a specified Pydantic 
model format. + + This method takes a response message (either as a string or dict) and validates/converts it + according to the provided Pydantic model format. If the response message is a string, + it first attempts to parse it as JSON. The resulting dict is then used to construct + an instance of the specified Pydantic model. + + Args: + response_message (str | dict): The response message to convert, either as a JSON string + or a dictionary. + response_format (Optional[BaseModel]): The Pydantic model class that defines the + expected format of the response. + + Returns: + Optional[dict | str]: The validated response message as a Pydantic model instance. + + Raises: + json.JSONDecodeError: If the response_message is a string but cannot be parsed as valid JSON. + ValidationError: If the parsed response does not match the schema defined by response_format. + """ + # Response message is a string, which is converted to a dict + # The dict is then used to construct the response_format Pydantic model + if self.response_format is None: + return response_message + + try: + # First try to parse the response message as JSON + if isinstance(response_message, str): + try: + response_dict = json.loads(response_message) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse response message as JSON: {response_message}. " + f"The model likely returned an invalid JSON format." + ) + raise e + else: + response_dict = response_message + + # Then construct the Pydantic model from the parsed dict + response_message = self.response_format(**response_dict) + return response_message + + except ValidationError as e: + schema_str = json.dumps(self.response_format.model_json_schema(), indent=2) + logger.warning( + f"Pydantic failed to parse response message {response_message} with `response_format` {schema_str}. " + f"The model likely returned a JSON that does not match the schema of the `response_format`." 
+ ) + raise e diff --git a/src/bespokelabs/curator/llm/prompt_formatter_test.py b/src/bespokelabs/curator/llm/prompt_formatter_test.py new file mode 100644 index 00000000..b39b2226 --- /dev/null +++ b/src/bespokelabs/curator/llm/prompt_formatter_test.py @@ -0,0 +1,79 @@ +import pytest +from pydantic import BaseModel + +from bespokelabs.curator.llm.prompt_formatter import PromptFormatter, _validate_messages + + +def test_validate_messages_valid(): + """Tests that valid message formats pass validation.""" + valid_messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + # Should not raise any exceptions + _validate_messages(valid_messages) + + +def test_validate_messages_invalid_format(): + """Tests that invalid message formats raise appropriate errors.""" + # Test non-dict message + with pytest.raises(ValueError, match="must be a dictionary"): + _validate_messages([["role", "content"]]) + + # Test missing required keys + with pytest.raises(ValueError, match="must contain 'role' and 'content' keys"): + _validate_messages([{"role": "user"}]) + + # Test invalid role + with pytest.raises(ValueError, match="must be one of: assistant, system, user"): + _validate_messages([{"role": "invalid", "content": "test"}]) + + +class TestResponse(BaseModel): + text: str + + +def test_prompt_formatter_create_generic_request(): + """Tests that PromptFormatter correctly creates GenericRequest objects.""" + # Test with string prompt + formatter = PromptFormatter( + model_name="test-model", prompt_func=lambda x: "Hello", response_format=TestResponse + ) + request = formatter.create_generic_request({"input": "test"}, 0) + + assert request.model == "test-model" + assert request.messages == [{"role": "user", "content": "Hello"}] + assert request.original_row == {"input": "test"} + assert request.original_row_idx == 0 + assert request.response_format is not None + + # Test with message list prompt + formatter = PromptFormatter( + model_name="test-model", + prompt_func=lambda x: [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hi"}, + ], + ) + request = formatter.create_generic_request({"input": "test"}, 1) + + assert len(request.messages) == 2 + assert request.messages[0]["role"] == "system" + assert request.messages[1]["role"] == "user" + assert request.original_row_idx == 1 + + +def test_prompt_formatter_invalid_prompt_func(): + """Tests that PromptFormatter raises errors for invalid prompt functions.""" + # Test prompt function with too many parameters + with pytest.raises(ValueError, match="must have 0 or 1 arguments"): + PromptFormatter(model_name="test", prompt_func=lambda x, y: "test").create_generic_request( + {}, 0 + ) + + # Test invalid prompt function return type + with pytest.raises(ValueError, match="must be a list of dictionaries"): + PromptFormatter( + model_name="test", prompt_func=lambda x: {"invalid": "format"} + ).create_generic_request({}, 0) diff --git a/src/bespokelabs/curator/llm/simple_llm.py b/src/bespokelabs/curator/llm/simple_llm.py new file mode 100644 index 00000000..7cc62dd0 --- /dev/null +++ b/src/bespokelabs/curator/llm/simple_llm.py @@ -0,0 +1,33 @@ +from bespokelabs.curator.llm.llm import LLM +from datasets import Dataset +from typing import Union, List + + +class SimpleLLM: + """A simpler interface for the LLM class. 
+ + Usage: + llm = SimpleLLM(model_name="gpt-4o-mini") + llm("Do you know about the bitter lesson?") + llm(["What is the capital of France?", "What is the capital of Germany?"]) + For more complex use cases (e.g. structured outputs and custom prompt functions), see the LLM class. + """ + + def __init__(self, model_name: str, backend: str = "openai"): + self._model_name = model_name + self._backend = backend + + def __call__(self, prompt: Union[str, List[str]]) -> Union[str, List[str]]: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + dataset: Dataset = Dataset.from_dict({"prompt": prompt_list}) + + llm = LLM( + prompt_func=lambda row: row["prompt"], + model_name=self._model_name, + response_format=None, + backend=self._backend, + ) + response = llm(dataset) + if isinstance(prompt, str): + return response["response"][0] + return response["response"] diff --git a/src/bespokelabs/curator/prompter/prompt_formatter.py b/src/bespokelabs/curator/prompter/prompt_formatter.py deleted file mode 100644 index 5682c978..00000000 --- a/src/bespokelabs/curator/prompter/prompt_formatter.py +++ /dev/null @@ -1,73 +0,0 @@ -import inspect -from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union - -from pydantic import BaseModel - -from bespokelabs.curator.request_processor.generic_request import GenericRequest - -T = TypeVar("T") - - -class PromptFormatter: - model_name: str - prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]] - parse_func: Optional[ - Callable[ - [ - Union[Dict[str, Any], BaseModel], - Union[Dict[str, Any], BaseModel], - ], - T, - ] - ] = None - response_format: Optional[Type[BaseModel]] = None - - def __init__( - self, - model_name: str, - prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]], - parse_func: Optional[ - Callable[ - [ - Union[Dict[str, Any], BaseModel], - Union[Dict[str, Any], BaseModel], - ], - T, - ] - ] = None, - response_format: Optional[Type[BaseModel]] = None, - ): - self.model_name = model_name - self.prompt_func = prompt_func - self.parse_func = parse_func - self.response_format = response_format - - def create_generic_request(self, row: Dict[str, Any] | BaseModel, idx: int) -> GenericRequest: - """Format the request object based off Prompter attributes.""" - sig = inspect.signature(self.prompt_func) - if len(sig.parameters) == 0: - prompts = self.prompt_func() - elif len(sig.parameters) == 1: - prompts = self.prompt_func(row) - else: - raise ValueError(f"Prompting function {self.prompt_func} must have 0 or 1 arguments.") - - if isinstance(prompts, str): - messages = [{"role": "user", "content": prompts}] - else: - # TODO(Ryan): Add validation here - messages = prompts - - # Convert BaseModel to dict for serialization - if isinstance(row, BaseModel): - row = row.model_dump() - - return GenericRequest( - model=self.model_name, - messages=messages, - original_row=row, - original_row_idx=idx, - response_format=( - self.response_format.model_json_schema() if self.response_format else None - ), - ) diff --git a/src/bespokelabs/curator/request_processor/base_online_request_processor.py b/src/bespokelabs/curator/request_processor/base_online_request_processor.py index 7e95cbc0..51537125 100644 --- a/src/bespokelabs/curator/request_processor/base_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_online_request_processor.py @@ -13,7 +13,7 @@ from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.request_processor.base_request_processor import 
BaseRequestProcessor -from bespokelabs.curator.prompter.prompter import PromptFormatter +from bespokelabs.curator.llm.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.event_loop import run_in_event_loop from bespokelabs.curator.request_processor.generic_response import GenericResponse @@ -22,6 +22,12 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +DEFAULT_MAX_REQUESTS_PER_MINUTE = 100 +DEFAULT_MAX_TOKENS_PER_MINUTE = 100_000 +DEFAULT_MAX_RETRIES = 10 +SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10 +DEFAULT_REQUEST_TIMEOUT = 10 * 60 # 10 minutes + @dataclass class StatusTracker: @@ -42,7 +48,9 @@ class StatusTracker: max_tokens_per_minute: int = 0 pbar: tqdm = field(default=None) response_cost: float = 0 - time_of_last_rate_limit_error: float = field(default=None) + time_of_last_rate_limit_error: float = field( + default=time.time() - SECONDS_TO_PAUSE_ON_RATE_LIMIT + ) def __str__(self): return ( @@ -119,14 +127,61 @@ def __init__( top_p: Optional[float] = None, presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, + max_requests_per_minute: Optional[int] = None, + max_tokens_per_minute: Optional[int] = None, + require_all_responses: bool = None, + max_retries: Optional[int] = None, ): - super().__init__(batch_size=None) + super().__init__(batch_size=None, require_all_responses=require_all_responses) self.model: str = model self.temperature: float | None = temperature self.top_p: float | None = top_p self.presence_penalty: float | None = presence_penalty self.frequency_penalty: float | None = frequency_penalty self.prompt_formatter: Optional[PromptFormatter] = None + self.manual_max_requests_per_minute: Optional[int] = max_requests_per_minute + self.manual_max_tokens_per_minute: Optional[int] = max_tokens_per_minute + if max_retries is None: + self.max_retries = DEFAULT_MAX_RETRIES + else: + self.max_retries = max_retries + self.timeout = DEFAULT_REQUEST_TIMEOUT + + @property + def max_requests_per_minute(self) -> int: + if self.manual_max_requests_per_minute: + logger.info( + f"Manually set max_requests_per_minute to {self.manual_max_requests_per_minute}" + ) + return self.manual_max_requests_per_minute + elif self.header_based_max_requests_per_minute: + logger.info( + f"Automatically set max_requests_per_minute to {self.header_based_max_requests_per_minute}" + ) + return self.header_based_max_requests_per_minute + else: + logger.warning( + f"No manual max_requests_per_minute set, and headers based detection failed, using default value of {DEFAULT_MAX_REQUESTS_PER_MINUTE}" + ) + return DEFAULT_MAX_REQUESTS_PER_MINUTE + + @property + def max_tokens_per_minute(self) -> int: + if self.manual_max_tokens_per_minute: + logger.info( + f"Manually set max_tokens_per_minute to {self.manual_max_tokens_per_minute}" + ) + return self.manual_max_tokens_per_minute + elif self.header_based_max_tokens_per_minute: + logger.info( + f"Automatically set max_tokens_per_minute to {self.header_based_max_tokens_per_minute}" + ) + return self.header_based_max_tokens_per_minute + else: + logger.warning( + f"No manual max_tokens_per_minute set, and headers based detection failed, using default value of {DEFAULT_MAX_TOKENS_PER_MINUTE}" + ) + return DEFAULT_MAX_TOKENS_PER_MINUTE @abstractmethod def estimate_total_tokens(self, messages: list) -> int: @@ -149,6 +204,11 @@ def run( parse_func_hash: str, prompt_formatter: PromptFormatter, ) -> Dataset: + # load from 
already completed dataset + output_dataset = self.attempt_loading_cached_dataset(working_dir, parse_func_hash) + if output_dataset is not None: + return output_dataset + """Run completions using the online API with async processing.""" logger.info(f"Running {self.__class__.__name__} completions with model: {self.model}") @@ -169,7 +229,6 @@ def run( self.process_requests_from_file( generic_request_filepath=request_file, save_filepath=response_file, - max_attempts=5, resume=True, ) ) @@ -180,7 +239,6 @@ async def process_requests_from_file( self, generic_request_filepath: str, save_filepath: str, - max_attempts: int, resume: bool, resume_no_retry: bool = False, ) -> None: @@ -191,10 +249,8 @@ async def process_requests_from_file( status_tracker = StatusTracker() # Get rate limits - rate_limits = self.get_rate_limits() - status_tracker.max_requests_per_minute = rate_limits["max_requests_per_minute"] - status_tracker.max_tokens_per_minute = rate_limits["max_tokens_per_minute"] - rpm = rate_limits["max_requests_per_minute"] + status_tracker.max_requests_per_minute = self.max_requests_per_minute + status_tracker.max_tokens_per_minute = self.max_tokens_per_minute soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) resource.setrlimit( @@ -206,7 +262,7 @@ async def process_requests_from_file( completed_request_ids = set() if os.path.exists(save_filepath): if resume: - logger.debug(f"Resuming progress from existing file: {save_filepath}") + logger.info(f"Resuming progress by reading existing file: {save_filepath}") logger.debug( f"Removing all failed requests from {save_filepath} so they can be retried" ) @@ -224,6 +280,11 @@ async def process_requests_from_file( f"{response.response_errors}, removing from output and will retry" ) num_previously_failed_requests += 1 + if response.response_message is None: + logger.debug( + f"Request {response.generic_request.original_row_idx} previously failed due to no response, removing from output and will retry" + ) + num_previously_failed_requests += 1 else: completed_request_ids.add(response.generic_request.original_row_idx) output_file.write(line) @@ -279,7 +340,7 @@ async def process_requests_from_file( ) # Use higher connector limit for better throughput - connector = aiohttp.TCPConnector(limit=10 * rpm) + connector = aiohttp.TCPConnector(limit=10 * status_tracker.max_requests_per_minute) async with aiohttp.ClientSession( connector=connector ) as session: # Initialize ClientSession here @@ -297,7 +358,7 @@ async def process_requests_from_file( task_id=status_tracker.num_tasks_started, generic_request=generic_request, api_specific_request=self.create_api_specific_request(generic_request), - attempts_left=max_attempts, + attempts_left=self.max_retries, prompt_formatter=self.prompt_formatter, ) @@ -307,6 +368,19 @@ async def process_requests_from_file( while not status_tracker.has_capacity(token_estimate): await asyncio.sleep(0.1) + # Wait for rate limits cool down if needed + seconds_since_rate_limit_error = ( + time.time() - status_tracker.time_of_last_rate_limit_error + ) + if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT: + remaining_seconds_to_pause = ( + SECONDS_TO_PAUSE_ON_RATE_LIMIT - seconds_since_rate_limit_error + ) + await asyncio.sleep(remaining_seconds_to_pause) + logger.warn( + f"Pausing to cool down for {int(remaining_seconds_to_pause)} seconds" + ) + # Consume capacity before making request status_tracker.consume_capacity(token_estimate) @@ -337,10 +411,10 @@ async def process_requests_from_file( token_estimate = 
self.estimate_total_tokens( retry_request.generic_request.messages ) - attempt_number = 6 - retry_request.attempts_left - logger.info( - f"Processing retry for request {retry_request.task_id} " - f"(attempt #{attempt_number} of 5). " + attempt_number = self.max_retries - retry_request.attempts_left + logger.debug( + f"Retrying request {retry_request.task_id} " + f"(attempt #{attempt_number} of {self.max_retries})" f"Previous errors: {retry_request.result}" ) @@ -405,6 +479,9 @@ async def handle_single_request_with_retries( status_tracker=status_tracker, ) + # Allows us to retry on responses that don't match the response format + self.prompt_formatter.response_to_response_format(generic_response.response_message) + # Save response in the base class await self.append_generic_response(generic_response, save_filepath) @@ -413,23 +490,20 @@ async def handle_single_request_with_retries( status_tracker.pbar.update(1) except Exception as e: - logger.warning( - f"Request {request.task_id} failed with Exception {e}, attempts left {request.attempts_left}" - ) status_tracker.num_other_errors += 1 request.result.append(e) if request.attempts_left > 0: request.attempts_left -= 1 - # Add retry queue logging - logger.info( - f"Adding request {request.task_id} to retry queue. Will retry in next available slot. " - f"Attempts remaining: {request.attempts_left}" + logger.warning( + f"Encountered '{e.__class__.__name__}: {e}' during attempt " + f"{self.max_retries - request.attempts_left} of {self.max_retries} " + f"while processing request {request.task_id}" ) retry_queue.put_nowait(request) else: logger.error( - f"Request {request.task_id} failed permanently after exhausting all 5 retry attempts. " + f"Request {request.task_id} failed permanently after exhausting all {self.max_retries} retry attempts. " f"Errors: {[str(e) for e in request.result]}" ) generic_response = GenericResponse( diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index a19fbf8f..6a5b2a30 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -6,7 +6,8 @@ import resource from abc import ABC, abstractmethod from math import ceil -from typing import Optional +from pathlib import Path +from typing import Optional, List import aiofiles import pyarrow @@ -14,7 +15,8 @@ from datasets.arrow_writer import ArrowWriter from pydantic import BaseModel, ValidationError -from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter +from bespokelabs.curator.file_utilities import count_lines +from bespokelabs.curator.llm.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.event_loop import run_in_event_loop from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.generic_response import GenericResponse @@ -29,8 +31,9 @@ class BaseRequestProcessor(ABC): Base class for all request processors. 
""" - def __init__(self, batch_size: Optional[int] = None): + def __init__(self, batch_size: Optional[int] = None, require_all_responses: bool = True): self.batch_size = batch_size + self.require_all_responses = require_all_responses # Increase the number of open file descriptors to avoid "Too many open files" errors soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) desired_limit = min(10_000_000, hard) @@ -39,16 +42,6 @@ def __init__(self, batch_size: Optional[int] = None): ) resource.setrlimit(resource.RLIMIT_NOFILE, (desired_limit, hard)) - @abstractmethod - def get_rate_limits(self) -> dict: - """ - Returns the rate limits for the API. - - Returns: - dict: A dictionary containing the rate limit information. - """ - pass - @abstractmethod def create_api_specific_request(self, generic_request: GenericRequest) -> dict: """ @@ -84,6 +77,64 @@ def run( """ pass + def _verify_existing_request_files( + self, working_dir: str, dataset: Optional[Dataset] + ) -> List[int]: + """ + Verify integrity of the cache (each request file has associated metadata, and the number of rows is correct), + and return the indices of request files that need to be regenerated (so that no work is repeated). + + Args: + working_dir (str): Working directory where cache files are expected to be (requests.jsonl, metadata.json) + dataset (Optional[Dataset]): The dataset that we want to create requests from + + Returns: + List[int]: Indices of missing files + """ + + if self.batch_size is not None and dataset is not None: + expected_num_files = ceil(len(dataset) / self.batch_size) + else: + expected_num_files = 1 + + try: + incomplete_files = [] + for i in range(expected_num_files): + req_f = os.path.join(working_dir, f"requests_{i}.jsonl") + meta_f = os.path.join(working_dir, f"metadata_{i}.json") + + if not os.path.exists(req_f): + incomplete_files.append(i) + continue + + if not os.path.exists(meta_f): + logger.warning(f"Cache missing metadata file {meta_f} for request file {req_f}") + incomplete_files.append(i) + continue + + with open(req_f, "r") as f: + data = f.read() + num_jobs = len(data.splitlines()) + + with open(meta_f, "r") as f: + metadata = json.load(f) + + expected_num_jobs = metadata["num_jobs"] + if num_jobs != expected_num_jobs: + logger.warning( + f"Request file {req_f} has {num_jobs} jobs, but metadata file {meta_f} has {expected_num_jobs} jobs" + ) + incomplete_files.append(i) + + return incomplete_files + + except Exception as e: + logger.warning( + f"Cache verification failed due to {e} - regenerating all request files." + ) + incomplete_files = list(range(expected_num_files)) + return incomplete_files + def create_request_files( self, dataset: Optional[Dataset], @@ -104,7 +155,9 @@ def create_request_files( request_files = glob.glob(f"{working_dir}/requests_*.jsonl") # By default use existing requests in working_dir - if len(request_files) > 0: + incomplete_files = self._verify_existing_request_files(working_dir, dataset) + + if len(incomplete_files) == 0: logger.info(f"Using cached requests. 
{CACHE_MSG}") # count existing jobs in file and print first job with open(request_files[0], "r") as f: @@ -124,18 +177,27 @@ def create_request_files( return request_files # Create new requests file + logger.info(f"Preparing request file(s) in {working_dir}") request_file = f"{working_dir}/requests_0.jsonl" request_files = [request_file] + metadata_file = f"{working_dir}/metadata_0.json" + metadata_files = [metadata_file] + if dataset is None: with open(request_file, "w") as f: generic_request = prompt_formatter.create_generic_request(dict(), 0) f.write(json.dumps(generic_request.model_dump(), default=str) + "\n") + + metadata_dict = {"num_jobs": 1} + with open(metadata_file, "w") as f: + f.write(json.dumps(metadata_dict, indent=4) + "\n") return request_files if self.batch_size: num_batches = ceil(len(dataset) / self.batch_size) request_files = [f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)] + metadata_files = [f"{working_dir}/metadata_{i}.json" for i in range(num_batches)] async def create_all_request_files(): tasks = [ @@ -143,15 +205,19 @@ async def create_all_request_files(): dataset, prompt_formatter, request_files[i], + metadata_files[i], start_idx=i * self.batch_size, ) for i in range(num_batches) + if i in incomplete_files ] await asyncio.gather(*tasks) run_in_event_loop(create_all_request_files()) else: - run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file)) + run_in_event_loop( + self.acreate_request_file(dataset, prompt_formatter, request_file, metadata_file) + ) return request_files @@ -161,8 +227,9 @@ async def acreate_request_file( dataset: Dataset, prompt_formatter: PromptFormatter, request_file: str, + metadata_file: str, start_idx: int = 0, - ) -> str: + ) -> None: if self.batch_size is not None: end_idx = min(start_idx + self.batch_size, len(dataset)) dataset = dataset.select(range(start_idx, end_idx)) @@ -176,7 +243,13 @@ async def acreate_request_file( # Get the generic request from the map function request = prompt_formatter.create_generic_request(dataset_row, dataset_row_idx) await f.write(json.dumps(request.model_dump(), default=str) + "\n") - logger.info(f"Wrote {end_idx - start_idx} requests to {request_file}.") + + num_requests = end_idx - start_idx + metadata_dict = {"num_jobs": num_requests} + async with aiofiles.open(metadata_file, "w") as f: + await f.write(json.dumps(metadata_dict, indent=4) + "\n") + + logger.info(f"Wrote {num_requests} requests to {request_file}.") def attempt_loading_cached_dataset( self, working_dir: str, parse_func_hash: str @@ -216,9 +289,6 @@ def create_dataset_files( Returns: Dataset: Completed dataset """ - total_responses_count = 0 - failed_responses_count = 0 - responses_files = glob.glob(f"{working_dir}/responses_*.jsonl") if len(responses_files) == 0: raise ValueError(f"No responses files found in {working_dir}") @@ -230,6 +300,8 @@ def create_dataset_files( ) # Process all response files + total_responses_count = 0 + failed_responses_count = 0 dataset_file = f"{working_dir}/{parse_func_hash}.arrow" with ArrowWriter(path=dataset_file) as writer: for responses_file in responses_files: @@ -243,41 +315,18 @@ def create_dataset_files( failed_responses_count += 1 continue - if prompt_formatter.response_format: - # Response message is a string, which is converted to a dict - # The dict is then used to construct the response_format Pydantic model - try: - # First try to parse the response message as JSON - if isinstance(response.response_message, str): - try: - response_dict = 
json.loads(response.response_message) - except json.JSONDecodeError as e: - warning_msg = ( - f"Failed to parse response message as JSON: {response.response_message}. " - f"The model likely returned an invalid JSON format. Will skip this response." - ) - logger.warning(warning_msg) - failed_responses_count += 1 - continue - else: - response_dict = response.response_message - - # Then construct the Pydantic model from the parsed dict - response.response_message = prompt_formatter.response_format( - **response_dict - ) - except ValidationError as e: - schema_str = json.dumps( - prompt_formatter.response_format.model_json_schema(), - indent=2, + try: + response.response_message = ( + self.prompt_formatter.response_to_response_format( + response.response_message ) - warning_msg = ( - f"Pydantic failed to parse response message {response.response_message} with `response_format` {schema_str}. " - f"The model likely returned a JSON that does not match the schema of the `response_format`. Will skip this response." - ) - logger.warning(warning_msg) - failed_responses_count += 1 - continue + ) + except (json.JSONDecodeError, ValidationError) as e: + logger.warning( + "Skipping response due to error parsing response message into response format" + ) + failed_responses_count += 1 + continue # parse_func can return a single row or a list of rows if prompt_formatter.parse_func: @@ -293,7 +342,13 @@ def create_dataset_files( if not isinstance(dataset_rows, list): dataset_rows = [dataset_rows] else: - dataset_rows = [{"response": response.response_message}] + # Convert response to dict before adding to dataset + response_value = response.response_message + if hasattr(response_value, "model_dump"): + response_value = response_value.model_dump() + elif hasattr(response_value, "__dict__"): + response_value = response_value.__dict__ + dataset_rows = [{"response": response_value}] for row in dataset_rows: if isinstance(row, BaseModel): @@ -313,14 +368,35 @@ def create_dataset_files( writer.write(row) - logger.info(f"Read {total_responses_count} responses, {failed_responses_count} failed") + logger.info("Finalizing writer") + writer.finalize() + + logger.info(f"Read {total_responses_count} responses.") if failed_responses_count == total_responses_count: os.remove(dataset_file) raise ValueError("All requests failed") - logger.info("Finalizing writer") + if failed_responses_count > 0: + logger.warning(f"{failed_responses_count} requests failed.") + if self.require_all_responses: + os.remove(dataset_file) + raise ValueError(f"Some requests failed and require_all_responses is True") - writer.finalize() + # number of responses matches number of requests + request_files = glob.glob(f"{working_dir}/requests_*.jsonl") + n_requests = 0 + for request_file in request_files: + n_requests += count_lines(request_file) + + if n_requests != total_responses_count: + logger.warning( + f"{n_requests - total_responses_count} requests do not have responses. n_requests is {n_requests} and n_responses is {total_responses_count}" + ) + if self.require_all_responses: + os.remove(dataset_file) + raise ValueError( + f"Some requests do not have responses and require_all_responses is True." 
+ ) return Dataset.from_file(dataset_file) diff --git a/src/bespokelabs/curator/request_processor/event_loop.py b/src/bespokelabs/curator/request_processor/event_loop.py index 92120e7d..6bc8bda7 100644 --- a/src/bespokelabs/curator/request_processor/event_loop.py +++ b/src/bespokelabs/curator/request_processor/event_loop.py @@ -1,5 +1,4 @@ import asyncio -from time import sleep import nest_asyncio diff --git a/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py b/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py index 4b346fcf..28c888e8 100644 --- a/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py @@ -1,6 +1,5 @@ import logging from typing import Optional -import asyncio import aiohttp import litellm from litellm import get_supported_openai_params @@ -14,7 +13,7 @@ from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.generic_response import TokenUsage, GenericResponse from pydantic import BaseModel -from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter +import time logger = logging.getLogger(__name__) @@ -49,6 +48,10 @@ def __init__( top_p: Optional[float] = None, presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, + max_requests_per_minute: Optional[int] = None, + max_tokens_per_minute: Optional[int] = None, + require_all_responses: Optional[bool] = None, + max_retries: Optional[int] = None, ): super().__init__( model=model, @@ -56,8 +59,15 @@ def __init__( top_p=top_p, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, + max_requests_per_minute=max_requests_per_minute, + max_tokens_per_minute=max_tokens_per_minute, + require_all_responses=require_all_responses, + max_retries=max_retries, ) self.client = instructor.from_litellm(litellm.acompletion) + self.header_based_max_requests_per_minute, self.header_based_max_tokens_per_minute = ( + self.get_header_based_rate_limits() + ) def check_structured_output_support(self): """Verify if the model supports structured output via instructor. @@ -134,20 +144,7 @@ def estimate_total_tokens(self, messages: list) -> int: output_tokens = self.estimate_output_tokens() return input_tokens + output_tokens - def get_rate_limits(self) -> dict: - """Retrieve rate limits from the LLM provider via LiteLLM. - - Makes a test request to get rate limit information from response headers. - - Returns: - dict: Contains 'max_requests_per_minute' and 'max_tokens_per_minute' - - Note: - - Falls back to default values if headers are missing - - Some providers (e.g., Claude) require non-empty messages - """ - logger.info(f"Getting rate limits for model: {self.model}") - + def test_call(self): completion = litellm.completion( model=self.model, messages=[ @@ -155,15 +152,33 @@ def get_rate_limits(self) -> dict: ], # Some models (e.g. Claude) require an non-empty message to get rate limits. 
        )

+        # Try calculating the cost
+        try:
+            litellm.completion_cost(completion_response=completion.model_dump())
+        except litellm.NotFoundError as e:
+            logger.warning(f"LiteLLM does not support cost estimation for model {self.model}: {e}")
+
        headers = completion._hidden_params.get("additional_headers", {})
-        logger.info(f"Rate limit headers: {headers}")
+        logger.info(f"Test call headers: {headers}")
+        return headers
+
+    def get_header_based_rate_limits(self) -> tuple[int, int]:
+        """Retrieve rate limits from the LLM provider via LiteLLM.
+
+        Returns:
+            tuple[int, int]: Contains 'max_requests_per_minute' and 'max_tokens_per_minute'
-        rpm = int(headers.get("x-ratelimit-limit-requests", 3000))
-        tpm = int(headers.get("x-ratelimit-limit-tokens", 150_000))
+        Note:
+            - Makes a test request to get rate limit information from response headers.
+            - Some providers (e.g., Claude) require non-empty messages
+        """
+        logger.info(f"Getting rate limits for model: {self.model}")
-        logger.info(f"Rate limits - Requests/min: {rpm}, Tokens/min: {tpm}")
+        headers = self.test_call()
+        rpm = int(headers.get("x-ratelimit-limit-requests", 0))
+        tpm = int(headers.get("x-ratelimit-limit-tokens", 0))
-        return {"max_requests_per_minute": rpm, "max_tokens_per_minute": tpm}
+        return rpm, tpm

     def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
         """Convert a generic request into a LiteLLM-compatible format.
@@ -200,6 +215,31 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
         if "frequency_penalty" in supported_params and self.frequency_penalty is not None:
             request["frequency_penalty"] = self.frequency_penalty

+        # Add safety settings for Gemini models
+        if "gemini" in generic_request.model.lower():
+            request["safety_settings"] = [
+                {
+                    "category": "HARM_CATEGORY_HARASSMENT",
+                    "threshold": "BLOCK_NONE",
+                },
+                {
+                    "category": "HARM_CATEGORY_HATE_SPEECH",
+                    "threshold": "BLOCK_NONE",
+                },
+                {
+                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                    "threshold": "BLOCK_NONE",
+                },
+                {
+                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                    "threshold": "BLOCK_NONE",
+                },
+                {
+                    "category": "HARM_CATEGORY_CIVIC_INTEGRITY",
+                    "threshold": "BLOCK_NONE",
+                },
+            ]
+
         return request

     async def call_single_request(
@@ -222,18 +262,29 @@
             GenericResponse: The response from LiteLLM
         """
         # Get response directly without extra logging
-        if request.generic_request.response_format:
-            response, completion_obj = await self.client.chat.completions.create_with_completion(
-                **request.api_specific_request,
-                response_model=request.prompt_formatter.response_format,
-                timeout=60.0,
-            )
-            response_message = (
-                response.model_dump() if hasattr(response, "model_dump") else response
-            )
-        else:
-            completion_obj = await litellm.acompletion(**request.api_specific_request, timeout=60.0)
-            response_message = completion_obj["choices"][0]["message"]["content"]
+        try:
+            if request.generic_request.response_format:
+                response, completion_obj = (
+                    await self.client.chat.completions.create_with_completion(
+                        **request.api_specific_request,
+                        response_model=request.prompt_formatter.response_format,
+                        timeout=self.timeout,
+                    )
+                )
+                response_message = (
+                    response.model_dump() if hasattr(response, "model_dump") else response
+                )
+            else:
+                completion_obj = await litellm.acompletion(
+                    **request.api_specific_request, timeout=self.timeout
+                )
+                response_message = completion_obj["choices"][0]["message"]["content"]
+        except litellm.RateLimitError as e:
+
status_tracker.time_of_last_rate_limit_error = time.time() + status_tracker.num_rate_limit_errors += 1 + # because handle_single_request_with_retries will double count otherwise + status_tracker.num_api_errors -= 1 + raise e # Extract token usage usage = completion_obj.usage if hasattr(completion_obj, "usage") else {} @@ -247,9 +298,21 @@ async def call_single_request( try: cost = litellm.completion_cost(completion_response=completion_obj.model_dump()) except litellm.NotFoundError as e: - logger.info(f"LiteLLM does not support cost estimation for model {self.model}: {e}") cost = 0 + finish_reason = completion_obj.choices[0].finish_reason + invalid_finish_reasons = ["length", "content_filter"] + if finish_reason in invalid_finish_reasons: + logger.debug( + f"Invalid finish_reason {finish_reason}. Raw response {completion_obj.model_dump()} for request {request.generic_request.messages}" + ) + raise ValueError(f"finish_reason was {finish_reason}") + + if response_message is None: + raise ValueError( + f"response_message was None with raw response {completion_obj.model_dump()}" + ) + # Create and return response return GenericResponse( response_message=response_message, diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index 6312a2d7..1aaf27e3 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -3,18 +3,18 @@ import glob import json import logging +import os from dataclasses import dataclass, field -from typing import Callable +from typing import Callable, Optional -import glob -import os import litellm -from openai import AsyncOpenAI +from openai import AsyncOpenAI, NotFoundError from openai.types import Batch from tqdm import tqdm +from bespokelabs.curator.file_utilities import count_lines from bespokelabs.curator.dataset import Dataset -from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter +from bespokelabs.curator.llm.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.base_request_processor import ( BaseRequestProcessor, GenericRequest, @@ -48,6 +48,8 @@ def __init__( url: str = "https://api.openai.com/v1/chat/completions", presence_penalty: float | None = None, frequency_penalty: float | None = None, + require_all_responses: bool = None, + max_retries: Optional[int] = None, ): if batch_size > MAX_REQUESTS_PER_BATCH: raise ValueError( @@ -55,7 +57,7 @@ def __init__( f"{MAX_REQUESTS_PER_BATCH:,} requests per batch that OpenAI supports. " f"Please set your batch_size to be less than or equal to {MAX_REQUESTS_PER_BATCH:,}." ) - super().__init__(batch_size) + super().__init__(batch_size, require_all_responses=require_all_responses) self.model = model self.url: str = url self.check_interval: int = batch_check_interval @@ -65,48 +67,10 @@ def __init__( self.frequency_penalty: float | None = frequency_penalty self.delete_successful_batch_files: bool = delete_successful_batch_files self.delete_failed_batch_files: bool = delete_failed_batch_files - - def get_rate_limits(self) -> dict: - """ - Function to get rate limits for a given annotator. Not available via response headers, so - the following is based on tier 5 limits on Nov 6th, 2024. - - These rate limits vary per model - and are determined by your organization's usage tier. 
View the following: - https://platform.openai.com/docs/guides/rate-limits/usage-tiers - https://platform.openai.com/settings/organization/limits - - Args: - model (str): The model for which to get the rate limits. - request_url (str): The request URL for which to get the rate limits. - - Returns: - tuple[int, int]: A tuple containing the maximum number of requests and tokens per minute. - """ - model_tpd = { - "gpt-3.5-turbo": 5_000_000_000, - "gpt-3.5-turbo-0125": 5_000_000_000, - "gpt-3.5-turbo-1106": 5_000_000_000, - "gpt-3.5-turbo-16k": 5_000_000_000, - "gpt-3.5-turbo-instruct": 200_000, - "gpt-3.5-turbo-instruct-0914": 200_000, - "gpt-4": 150_000_000, - "gpt-4-0613": 150_000_000, - "gpt-4-turbo": 300_000_000, - "gpt-4o": 10_000_000_000, - "gpt-4o-mini": 15_000_000_000, - } - - if self.model not in model_tpd: - tpd = 1_000_000_000 + if max_retries is None: + self.max_retries = MAX_RETRIES_PER_OPERATION else: - tpd = model_tpd[self.model] - - logger.info(f"Automatically set max_tokens_per_day to {tpd}, model: {self.model} ") - - rate_limits = {"max_tokens_per_day": tpd} - - return rate_limits + self.max_retries = max_retries def create_api_specific_request(self, generic_request: GenericRequest) -> dict: """ @@ -221,7 +185,14 @@ def generic_response_file_from_responses( for raw_response in responses.text.splitlines(): raw_response = json.loads(raw_response) request_idx = int(raw_response["custom_id"]) - generic_request = generic_request_map[request_idx] + + if request_idx not in generic_request_map: + logger.warning( + f"Request {request_idx} not found in generic_request_map. response_file: {response_file}, " + f"request_file: {request_file}. The request files might have been incomplete. Will skip " + f"this response." + ) + continue if raw_response["response"]["status_code"] != 200: logger.warning( @@ -229,9 +200,7 @@ def generic_response_file_from_responses( ) generic_response = GenericResponse( response_message=None, - response_errors=[ - f"Request {generic_request} failed with status code {raw_response['response']['status_code']}" - ], + response_errors=[raw_response["response"]["status_code"]], raw_response=raw_response, raw_request=None, generic_request=generic_request, @@ -329,6 +298,7 @@ def run( prompt_formatter, delete_successful_batch_files=self.delete_successful_batch_files, delete_failed_batch_files=self.delete_failed_batch_files, + max_retries=self.max_retries, ) run_in_event_loop(self.run_batch_operations(batch_manager, request_files)) @@ -347,6 +317,7 @@ def cancel_batches(self, working_dir: str) -> Dataset: self.check_interval, delete_successful_batch_files=self.delete_successful_batch_files, delete_failed_batch_files=self.delete_failed_batch_files, + max_retries=self.max_retries, ) run_in_event_loop(batch_manager.cancel_batches()) @@ -504,6 +475,7 @@ def __init__( prompt_formatter: PromptFormatter | None = None, delete_successful_batch_files: bool = False, delete_failed_batch_files: bool = False, + max_retries: Optional[int] = None, ) -> None: """Initialize BatchManager to handle OpenAI batch processing operations. @@ -517,7 +489,7 @@ def __init__( delete_failed_batch_files (bool): Whether to delete input/error files from OpenAI after batch failure. 
""" - self.client = AsyncOpenAI() + self.client = AsyncOpenAI(max_retries=max_retries) self.check_interval = check_interval self.working_dir = working_dir self.tracker = BatchStatusTracker() @@ -677,7 +649,6 @@ async def retrieve_batch(self, batch_id: str) -> Batch: try: batch_object = await self.client.batches.retrieve(batch_id) except Exception as e: - logger.error(f"Error checking previously submitted batch: {e}") raise e return batch_object @@ -721,12 +692,22 @@ async def submit_batch_from_request_file( async def track_already_submitted_batches(self): """ Tracks previously submitted batches from the submitted batch objects file. + We need to check all submitted batch objects files because we might be looking at a cancelled batch + or a batch from another key but same project. Side Effects: - Updates tracker with previously submitted batch statuses """ - if os.path.exists(self.submitted_batch_objects_file): - with open(self.submitted_batch_objects_file, "r") as f: + all_submitted_batches_files = set( + glob.glob(f"{self.working_dir}/batch_objects_submitted_*.jsonl") + ) + + existing_submitted_batches = {} + for submitted_batch_objects_file in all_submitted_batches_files: + logger.info( + f"Processing submitted batch objects file: {submitted_batch_objects_file} Your API key is ***{self.client.api_key[-4:]}." + ) + with open(submitted_batch_objects_file, "r") as f: for line in f: batch_object = Batch.model_validate(json.loads(line)) request_file_name = batch_object.metadata["request_file_name"] @@ -734,27 +715,79 @@ async def track_already_submitted_batches(self): f"Already submitted batch {batch_object.id} for request file {request_file_name}. " f"Getting batch object to update tracker." ) - batch_object = await self.retrieve_batch(batch_object.id) + try: + batch_object = await self.retrieve_batch(batch_object.id) + except NotFoundError: + logger.warning( + f"Already submitted batch object {batch_object.id} not found. This might be fine since we might be " + "looking at a batch object submitted by another project. Will ignore this batch object..." + ) + continue + + if not self._validate_batch_status(batch_object.status): + logger.warning( + f"Already submitted batch {batch_object.id} has an invalid status {batch_object.status}. " + f"Will ignore this batch object..." + ) + continue + + # We skip the batch if it has a status that means it can no longer be used. + if batch_object.status in ["expired", "cancelling", "cancelled"]: + logger.info( + f"Batch {batch_object.id} has status {batch_object.status}, which means it can " + "no longer be used. Will ignore this batch object..." + ) + continue # Edge case where the batch is still validating, and we need to know the total number of requests if batch_object.status == "validating": - n_requests = len(open(request_file_name, "r").readlines()) - batch_object.request_counts.total = n_requests + batch_object.request_counts.total = count_lines(request_file_name) else: n_requests = batch_object.request_counts.total - if request_file_name in self.tracker.unsubmitted_request_files: - self.tracker.mark_as_submitted(request_file_name, batch_object, n_requests) - else: - # batch objects if not unsubmitted, should be downloaded - assert batch_object.id in self.tracker.downloaded_batches + # For each request file, we only want to keep the latest batch object. 
+ if ( + request_file_name not in existing_submitted_batches + or existing_submitted_batches[request_file_name].created_at + < batch_object.created_at + ): + existing_submitted_batches[request_file_name] = batch_object + + for request_file_name, batch_object in existing_submitted_batches.items(): + + output_file_id = batch_object.output_file_id + if output_file_id is not None: + try: + await self.client.files.retrieve(output_file_id) + except NotFoundError: + logger.warning( + f"Output file {output_file_id} exists in batch object but cannot be found " + "in OpenAI storage. The file may have been deleted. Will resubmit this batch..." + ) + continue + + if request_file_name in self.tracker.unsubmitted_request_files: + self.tracker.mark_as_submitted(request_file_name, batch_object, n_requests) + else: + response_file = request_file_to_response_file(request_file_name, self.working_dir) + if not os.path.exists(response_file): + raise ValueError( + f"While processing {batch_object.id}, we found that its corresponding request_file_name {request_file_name} is " + f"not in tracker.unsubmitted_request_files, but its corresponding response_file {response_file} does not exist. " + f"This is an invalid state. \n" + f"batch_object: {batch_object} \n" + f"request_file_name: {request_file_name} \n" + f"tracker.unsubmitted_request_files: {self.tracker.unsubmitted_request_files} \n" + f"tracker.submitted_batches: {self.tracker.submitted_batches} \n" + f"tracker.downloaded_batches: {self.tracker.downloaded_batches} \n" + ) if self.tracker.n_submitted_batches > 0: logger.info( f"{self.tracker.n_submitted_batches:,} out of {self.tracker.n_total_batches - self.tracker.n_downloaded_batches:,} remaining batches are already submitted." ) - def track_already_downloaded_batches(self): + async def track_already_downloaded_batches(self): """ Tracks previously downloaded batches from the downloaded batch objects files. @@ -765,13 +798,24 @@ def track_already_downloaded_batches(self): glob.glob(f"{self.working_dir}/batch_objects_downloaded_*.jsonl") ) for downloaded_batch_object_file in downloaded_batch_object_files: + logger.info( + f"Processing downloaded batch objects file: {downloaded_batch_object_file} Your API key is ***{self.client.api_key[-4:]}." + ) with open(downloaded_batch_object_file, "r") as f: for line in f: batch_object = Batch.model_validate(json.loads(line)) request_file = batch_object.metadata["request_file_name"] response_file = request_file_to_response_file(request_file, self.working_dir) - assert request_file in self.tracker.unsubmitted_request_files - assert os.path.exists(response_file) + assert ( + request_file in self.tracker.unsubmitted_request_files + ), f"request_file {request_file} not in unsubmitted_request_files: {self.tracker.unsubmitted_request_files}" + if not os.path.exists(response_file): + logger.warning( + f"Downloaded batch object {batch_object.id} has a response_file {response_file} that does not exist. " + "Will resubmit this batch..." 
+                    )
+                    continue
+
                 self.tracker.mark_as_submitted(
                     request_file, batch_object, batch_object.request_counts.total
                 )
@@ -800,7 +844,7 @@ async def submit_batches_from_request_files(
             - Creates and updates batch submission progress bar
         """
         self.tracker.unsubmitted_request_files = request_files
-        self.track_already_downloaded_batches()
+        await self.track_already_downloaded_batches()
         await self.track_already_submitted_batches()
         # exit early
         if self.tracker.n_unsubmitted_request_files == 0:
@@ -853,9 +897,8 @@ async def check_batch_status(self, batch_id: str) -> Batch | None:
         )

         finished_statuses = ["completed", "failed", "expired", "cancelled"]
-        in_progress_statuses = ["validating", "finalizing", "cancelling", "in_progress"]
         batch_returned = batch.status in finished_statuses
-        if batch.status not in in_progress_statuses + finished_statuses:
+        if not self._validate_batch_status(batch.status):
             logger.warning(f"Unknown batch status: {batch.status}")

         if batch_returned:
@@ -902,7 +945,7 @@ async def poll_and_process_batches(
             batches_to_download = await asyncio.gather(*status_tasks)
             batches_to_download = filter(None, batches_to_download)

             # update progress bar
             self.request_pbar.n = self.tracker.n_finished_or_downloaded_requests
             self.request_pbar.refresh()
@@ -937,11 +980,15 @@ async def delete_file(self, file_id: str, semaphore: asyncio.Semaphore):
             semaphore (asyncio.Semaphore): Semaphore to limit concurrent operations
         """
         async with semaphore:
-            delete_response = await self.client.files.delete(file_id)
-            if delete_response.deleted:
-                logger.debug(f"Deleted file {file_id}")
-            else:
-                logger.warning(f"Failed to delete file {file_id}")
+            try:
+                delete_response = await self.client.files.delete(file_id)
+                if delete_response.deleted:
+                    logger.debug(f"Deleted file {file_id}")
+                else:
+                    logger.warning(f"Failed to delete file {file_id}")
+            except NotFoundError:
+                # This is fine, the file may have been deleted already. Deletion should be best-effort.
+ logger.warning(f"Trying to delete file {file_id} but it was not found.") async def download_batch(self, batch: Batch) -> str | None: file_content = None @@ -1027,3 +1074,18 @@ async def download_batch_to_response_file( self.tracker.mark_as_downloaded(batch) return response_file + + @staticmethod + def _validate_batch_status(status: str) -> bool: + # See https://github.com/openai/openai-python/blob/995cce048f9427bba4f7ac1e5fc60abbf1f8f0b7/src/openai/types/batch.py#L40C1-L41C1 + # for all possible batch statuses + return status in [ + "completed", + "failed", + "expired", + "cancelled", + "validating", + "finalizing", + "cancelling", + "in_progress", + ] diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 14fc27ea..a8416906 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -79,6 +79,10 @@ def __init__( top_p: Optional[float] = None, presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, + max_requests_per_minute: Optional[int] = None, + max_tokens_per_minute: Optional[int] = None, + require_all_responses: bool = None, + max_retries: Optional[int] = None, ): super().__init__( model=model, @@ -86,43 +90,41 @@ def __init__( top_p=top_p, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, + max_requests_per_minute=max_requests_per_minute, + max_tokens_per_minute=max_tokens_per_minute, + require_all_responses=require_all_responses, + max_retries=max_retries, ) self.url = url self.api_key = api_key self.token_encoding = tiktoken.get_encoding(get_token_encoding_name(model)) + self.header_based_max_requests_per_minute, self.header_based_max_tokens_per_minute = ( + self.get_header_based_rate_limits() + ) - def get_rate_limits(self) -> dict: + def get_header_based_rate_limits(self) -> tuple[int, int]: """Get rate limits from OpenAI API headers. Returns: - dict: Contains 'max_requests_per_minute' and 'max_tokens_per_minute' + tuple[int, int]: Contains 'max_requests_per_minute' and 'max_tokens_per_minute' Note: - Makes a dummy request to get actual rate limits - - Falls back to default values if headers are missing - - Supports both OpenAI and Azure endpoints """ + if not self.api_key: + raise ValueError( + "Missing OpenAI API Key - Please set OPENAI_API_KEY in your environment vars" + ) + response = requests.post( self.url, headers={"Authorization": f"Bearer {self.api_key}"}, json={"model": self.model, "messages": []}, ) - rpm = int(response.headers.get("x-ratelimit-limit-requests", 0)) tpm = int(response.headers.get("x-ratelimit-limit-tokens", 0)) - if not rpm or not tpm: - logger.warning("Failed to get rate limits from OpenAI API, using default values") - rpm = 30_000 - tpm = 150_000_000 - - logger.info(f"Automatically set max_requests_per_minute to {rpm}") - logger.info(f"Automatically set max_tokens_per_minute to {tpm}") - - return { - "max_requests_per_minute": rpm, - "max_tokens_per_minute": tpm, - } + return rpm, tpm def estimate_output_tokens(self) -> int: """Estimate number of tokens in the response. 
@@ -270,7 +272,7 @@ async def call_single_request( self.url, headers=request_header, json=request.api_specific_request, - timeout=60.0, + timeout=self.timeout, ) as response_obj: response = await response_obj.json() @@ -281,6 +283,8 @@ async def call_single_request( status_tracker.time_of_last_rate_limit_error = time.time() status_tracker.num_rate_limit_errors += 1 status_tracker.num_api_errors -= 1 + # because handle_single_request_with_retries will double count otherwise + status_tracker.num_other_errors -= 1 raise Exception(f"API error: {error}") if response_obj.status != 200: diff --git a/tests/batch/simple_batch.py b/tests/batch/simple_batch.py index 68fbd38c..251296ae 100644 --- a/tests/batch/simple_batch.py +++ b/tests/batch/simple_batch.py @@ -1,4 +1,4 @@ -from bespokelabs.curator import Prompter +from bespokelabs.curator import LLM from datasets import Dataset import logging import argparse @@ -13,7 +13,7 @@ def main(args): dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * args.n_requests}) - prompter = Prompter( + prompter = LLM( prompt_func=lambda row: row["prompt"], model_name="gpt-4o-mini", response_format=None, diff --git a/tests/batch/test_resume.py b/tests/batch/test_resume.py index 0248da20..9ac0c906 100644 --- a/tests/batch/test_resume.py +++ b/tests/batch/test_resume.py @@ -10,6 +10,7 @@ """ +@pytest.mark.skip(reason="Temporarily disabled, need to add mocking") @pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-resume")) @pytest.mark.usefixtures("prepare_test_cache") def test_batch_resume(): diff --git a/tests/batch/test_switch_keys.py b/tests/batch/test_switch_keys.py index 80eb3984..f1d9fc8b 100644 --- a/tests/batch/test_switch_keys.py +++ b/tests/batch/test_switch_keys.py @@ -10,6 +10,7 @@ """ +@pytest.mark.skip(reason="Temporarily disabled, need to add mocking") @pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-switch-keys")) @pytest.mark.usefixtures("prepare_test_cache") def test_batch_switch_keys(): @@ -46,4 +47,4 @@ def test_batch_switch_keys(): print(output2) # checks - assert "1 out of 1 batches already downloaded." in output2 + assert "1 out of 2 batches already downloaded." 
in output2 diff --git a/tests/cache/different_files/one.py b/tests/cache/different_files/one.py new file mode 100644 index 00000000..e5667add --- /dev/null +++ b/tests/cache/different_files/one.py @@ -0,0 +1,18 @@ +from bespokelabs.curator import LLM +from datasets import Dataset +import logging + +logger = logging.getLogger("bespokelabs.curator") +logger.setLevel(logging.INFO) + + +dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3}) + +prompter = LLM( + prompt_func=lambda row: row["prompt"], + model_name="gpt-4o-mini", + response_format=None, +) + +dataset = prompter(dataset) +print(dataset.to_pandas()) diff --git a/tests/cache/different_files/two.py b/tests/cache/different_files/two.py new file mode 100644 index 00000000..e5667add --- /dev/null +++ b/tests/cache/different_files/two.py @@ -0,0 +1,18 @@ +from bespokelabs.curator import LLM +from datasets import Dataset +import logging + +logger = logging.getLogger("bespokelabs.curator") +logger.setLevel(logging.INFO) + + +dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3}) + +prompter = LLM( + prompt_func=lambda row: row["prompt"], + model_name="gpt-4o-mini", + response_format=None, +) + +dataset = prompter(dataset) +print(dataset.to_pandas()) diff --git a/tests/cache/one.py b/tests/cache/one.py index 090b5b44..10ff74d4 100644 --- a/tests/cache/one.py +++ b/tests/cache/one.py @@ -1,4 +1,4 @@ -from bespokelabs.curator import Prompter +from bespokelabs.curator import LLM from datasets import Dataset import logging import argparse @@ -10,7 +10,7 @@ def main(delete_cache: bool = False): dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3}) - prompter = Prompter( + prompter = LLM( prompt_func=lambda row: row["prompt"], model_name="gpt-4o-mini", response_format=None, diff --git a/tests/cache/test_different_files.py b/tests/cache/test_different_files.py index 6b18de07..31fe866b 100644 --- a/tests/cache/test_different_files.py +++ b/tests/cache/test_different_files.py @@ -16,17 +16,14 @@ def test_cache_behavior(): # Run one.py twice and check for cache behavior print("RUNNING ONE.PY") - output1, _ = run_script(["python", "tests/cache_tests/different_files/one.py"]) - print(output1) + output1, _ = run_script(["python", "tests/cache/different_files/one.py"]) assert cache_hit_log not in output1, "First run of one.py should not hit cache" print("RUNNING ONE.PY AGAIN") - output2, _ = run_script(["python", "tests/cache_tests/different_files/one.py"]) - print(output2) + output2, _ = run_script(["python", "tests/cache/different_files/one.py"]) assert cache_hit_log in output2, "Second run of one.py should hit cache" # Run two.py and check for cache behavior print("RUNNING TWO.PY") - output3, _ = run_script(["python", "tests/cache_tests/different_files/two.py"]) - print(output3) + output3, _ = run_script(["python", "tests/cache/different_files/two.py"]) assert cache_hit_log in output3, "First run of two.py should hit cache" diff --git a/tests/cache/two.py b/tests/cache/two.py index 090b5b44..10ff74d4 100644 --- a/tests/cache/two.py +++ b/tests/cache/two.py @@ -1,4 +1,4 @@ -from bespokelabs.curator import Prompter +from bespokelabs.curator import LLM from datasets import Dataset import logging import argparse @@ -10,7 +10,7 @@ def main(delete_cache: bool = False): dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3}) - prompter = Prompter( + prompter = LLM( prompt_func=lambda row: row["prompt"], model_name="gpt-4o-mini", response_format=None, diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 
00000000..012b8dc6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "cache_dir(path): mark test to use specific cache directory") diff --git a/tests/litellm/__init__.py b/tests/litellm/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/litellm/test_models.py b/tests/litellm/test_models.py deleted file mode 100644 index 05bb9b7c..00000000 --- a/tests/litellm/test_models.py +++ /dev/null @@ -1,55 +0,0 @@ -import pytest -import os -import logging -from datasets import Dataset -from bespokelabs.curator import Prompter -from tests.helpers import prepare_test_cache - -""" -USAGE: -pytest -s tests/litellm/test_models.py -""" - - -@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-models")) -@pytest.mark.usefixtures("prepare_test_cache") -def test_litellm_models(): - - env = os.environ.copy() - assert "ANTHROPIC_API_KEY" in env, "ANTHROPIC_API_KEY must be set" - assert "OPENAI_API_KEY" in env, "OPENAI_API_KEY must be set" - assert "GEMINI_API_KEY" in env, "GEMINI_API_KEY must be set" - assert "TOGETHER_API_KEY" in env, "TOGETHER_API_KEY must be set" - - models_list = [ - "claude-3-5-sonnet-20240620", # https://docs.litellm.ai/docs/providers/anthropic # anthropic has a different hidden param tokens structure. - "claude-3-5-haiku-20241022", - "claude-3-haiku-20240307", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "gpt-4o-mini", # https://docs.litellm.ai/docs/providers/openai - "gpt-4o-2024-08-06", - "gpt-4-0125-preview", - "gpt-3.5-turbo-1106", - "gemini/gemini-1.5-flash", # https://docs.litellm.ai/docs/providers/gemini; https://ai.google.dev/gemini-api/docs/models # 20-30 iter/s - "gemini/gemini-1.5-pro", # 20-30 iter/s - "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", # https://docs.together.ai/docs/serverless-models - "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - ] - - for model in models_list: - print(f"\n\n========== TESTING {model} ==========\n\n") - logger = logging.getLogger("bespokelabs.curator") - logger.setLevel(logging.DEBUG) - - dataset = Dataset.from_dict({"prompt": ["just say 'hi'"]}) - - prompter = Prompter( - prompt_func=lambda row: row["prompt"], - model_name=model, - response_format=None, - backend="litellm", - ) - - dataset = prompter(dataset) - print(dataset.to_pandas()) diff --git a/tests/simple_online.py b/tests/simple_online.py new file mode 100644 index 00000000..4d5f90df --- /dev/null +++ b/tests/simple_online.py @@ -0,0 +1,54 @@ +from bespokelabs.curator import LLM +from datasets import Dataset +import logging +import argparse + +# python tests/simple_online.py --log-level DEBUG --model claude-3-5-haiku-20241022 + + +def main(args): + if args.log_level is not None: + logger = logging.getLogger("bespokelabs.curator") + logger.setLevel(args.log_level) + + dataset = Dataset.from_dict({"prompt": ["write me a poem"] * args.n_requests}) + + prompter = LLM( + prompt_func=lambda row: row["prompt"], + model_name=args.model, + max_requests_per_minute=args.max_requests_per_minute, + max_tokens_per_minute=args.max_tokens_per_minute, + max_retries=args.max_retries, + require_all_responses=not args.partial_responses, + ) + + dataset = prompter(dataset, batch_cancel=args.cancel) + print(dataset.to_pandas()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Simple batch test bed") + parser.add_argument("--cancel", action="store_true", default=False, help="Cancel the batches") + 
    parser.add_argument("--n-requests", type=int, help="Number of requests to process", default=3)
+    parser.add_argument(
+        "--log-level",
+        type=lambda x: getattr(logging, x.upper()),
+        default=None,
+        help="Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)",
+    )
+    parser.add_argument("--model", type=str, help="Model to use", default="gemini/gemini-1.5-flash")
+    parser.add_argument(
+        "--max-requests-per-minute", type=int, help="Max requests per minute", default=None
+    )
+    parser.add_argument(
+        "--max-tokens-per-minute", type=int, help="Max tokens per minute", default=None
+    )
+    parser.add_argument("--max-retries", type=int, help="Max retries", default=None)
+    parser.add_argument(
+        "--partial-responses",
+        action="store_true",
+        default=False,
+        help="Do not require all responses (allow partial results)",
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/tests/test_caching.py b/tests/test_caching.py
index 73803465..15c3ebd6 100644
--- a/tests/test_caching.py
+++ b/tests/test_caching.py
@@ -1,6 +1,6 @@
 from datasets import Dataset

-from bespokelabs.curator import Prompter
+from bespokelabs import curator


 def test_same_value_caching(tmp_path):
@@ -13,7 +13,7 @@ def test_same_value_caching(tmp_path):
     def prompt_func():
         return f"Say '1'. Do not explain."

-    prompter = Prompter(
+    prompter = curator.LLM(
         prompt_func=prompt_func,
         model_name="gpt-4o-mini",
     )
@@ -36,7 +36,7 @@ def test_different_values_caching(tmp_path):
     def prompt_func():
         return f"Say '{x}'. Do not explain."

-    prompter = Prompter(
+    prompter = curator.LLM(
         prompt_func=prompt_func,
         model_name="gpt-4o-mini",
     )
@@ -52,7 +52,7 @@ def prompt_func():
 def test_same_dataset_caching(tmp_path):
     """Test that using the same dataset multiple times uses cache."""
     dataset = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}])
-    prompter = Prompter(
+    prompter = curator.LLM(
         prompt_func=lambda x: x["instruction"],
         model_name="gpt-4o-mini",
     )
@@ -72,7 +72,7 @@ def test_different_dataset_caching(tmp_path):
     """Test that using different datasets creates different cache entries."""
     dataset1 = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}])
     dataset2 = Dataset.from_list([{"instruction": "Say '2'. Do not explain."}])
-    prompter = Prompter(
+    prompter = curator.LLM(
         prompt_func=lambda x: x["instruction"],
         model_name="gpt-4o-mini",
     )
@@ -97,7 +97,7 @@ def value_generator():

     def prompt_func():
         return f"Say '{value_generator()}'. Do not explain."
- prompter = Prompter( + prompter = curator.LLM( prompt_func=prompt_func, model_name="gpt-4o-mini", ) @@ -113,3 +113,95 @@ def value_generator(): # Count cache directories, excluding metadata.db cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"] assert len(cache_dirs) == 2, f"Expected 2 cache directory but found {len(cache_dirs)}" + + +def test_function_hash_dir_change(): + """Test that identical functions in different directories but same base filename produce the same hash.""" + import logging + import os + import sys + import tempfile + from pathlib import Path + + from bespokelabs.curator.llm.llm import _get_function_hash + + # Set up logging to write to a file in the current directory + debug_log = Path("function_debug.log") + logging.basicConfig( + level=logging.DEBUG, format="%(message)s", filename=str(debug_log), filemode="w" + ) + logger = logging.getLogger(__name__) + + def dump_function_details(func, prefix): + """Helper to dump all function details.""" + print(f"\n{prefix} details:") # Print to stdout as well + logger.debug(f"\n{prefix} details:") + # Basic attributes + details = { + "__name__": func.__name__, + "__module__": func.__module__, + "__qualname__": func.__qualname__, + "__code__.co_filename": func.__code__.co_filename, + "__code__.co_name": func.__code__.co_name, + "__code__.co_firstlineno": func.__code__.co_firstlineno, + "__code__.co_consts": func.__code__.co_consts, + "__code__.co_names": func.__code__.co_names, + "__code__.co_varnames": func.__code__.co_varnames, + "__code__.co_code": func.__code__.co_code.hex(), + "__code__.co_flags": func.__code__.co_flags, + "__code__.co_stacksize": func.__code__.co_stacksize, + "__code__.co_freevars": func.__code__.co_freevars, + "__code__.co_cellvars": func.__code__.co_cellvars, + "__globals__ keys": sorted(func.__globals__.keys()), + "__closure__": func.__closure__, + "__defaults__": func.__defaults__, + "__kwdefaults__": func.__kwdefaults__, + } + + for key, value in details.items(): + msg = f" {key}: {value}" + print(msg) # Print to stdout + logger.debug(msg) # Log to file + + def create_function(name, tmp_path): + # Create a temporary file with a function definition + path = tmp_path / f"{name}.py" + with open(path, "w") as f: + f.write( + """ +def test_func(): + x = 42 # Add a constant + y = "Hello" # Add a string constant + z = [1, 2, 3] # Add a list constant + return f"{y}, {x}! 
{z}" # Use all constants +""" + ) + + # Import the function from the file + import importlib.util + + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.test_func + + # Create two identical functions in different files + with tempfile.TemporaryDirectory() as tmp_dir: + func1 = create_function("module1", Path(tmp_dir)) + func2 = create_function("module1", Path(tmp_dir)) + + # Dump detailed information about both functions + dump_function_details(func1, "Function 1") + dump_function_details(func2, "Function 2") + + # Both should produce the same hash + hash1 = _get_function_hash(func1) + hash2 = _get_function_hash(func2) + print(f"\nHash comparison:") # Print to stdout + print(f" hash1: {hash1}") + print(f" hash2: {hash2}") + logger.debug(f"\nHash comparison:") + logger.debug(f" hash1: {hash1}") + logger.debug(f" hash2: {hash2}") + + assert hash1 == hash2, "Identical functions should produce the same hash" diff --git a/tests/test_litellm_models.py b/tests/test_litellm_models.py new file mode 100644 index 00000000..972848c9 --- /dev/null +++ b/tests/test_litellm_models.py @@ -0,0 +1,64 @@ +import pytest +import os +import logging +from datasets import Dataset +from bespokelabs.curator import LLM +from tests.helpers import prepare_test_cache + +""" +USAGE: +pytest -s tests/test_litellm_models.py +""" + + +@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-models")) +@pytest.mark.usefixtures("prepare_test_cache") +class TestLiteLLMModels: + @pytest.fixture(autouse=True) + def check_environment(self): + env = os.environ.copy() + required_keys = [ + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "GEMINI_API_KEY", + "TOGETHER_API_KEY", + ] + for key in required_keys: + assert key in env, f"{key} must be set" + + @pytest.mark.parametrize( + "model", + [ + pytest.param("claude-3-5-sonnet-20240620", id="claude-3-5-sonnet"), + pytest.param("claude-3-5-haiku-20241022", id="claude-3-5-haiku"), + pytest.param("claude-3-haiku-20240307", id="claude-3-haiku"), + pytest.param("claude-3-opus-20240229", id="claude-3-opus"), + pytest.param("claude-3-sonnet-20240229", id="claude-3-sonnet"), + pytest.param("gpt-4o-mini", id="gpt-4-mini"), + pytest.param("gpt-4o-2024-08-06", id="gpt-4"), + pytest.param("gpt-4-0125-preview", id="gpt-4-preview"), + pytest.param("gpt-3.5-turbo-1106", id="gpt-3.5"), + pytest.param("gemini/gemini-1.5-flash", id="gemini-flash"), + pytest.param("gemini/gemini-1.5-pro", id="gemini-pro"), + pytest.param("together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", id="llama-8b"), + pytest.param( + "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", id="llama-70b" + ), + ], + ) + def test_model(self, model): + print(f"\n\n========== TESTING {model} ==========\n\n") + logger = logging.getLogger("bespokelabs.curator") + logger.setLevel(logging.DEBUG) + + dataset = Dataset.from_dict({"prompt": ["just say 'hi'"]}) + + prompter = LLM( + prompt_func=lambda row: row["prompt"], + model_name=model, + response_format=None, + backend="litellm", + ) + + dataset = prompter(dataset) + print(dataset.to_pandas()) diff --git a/tests/test_prompt.py b/tests/test_prompt.py index f1c327cd..84c2d640 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -1,11 +1,12 @@ import os from typing import Optional +from unittest.mock import patch, MagicMock import pytest from datasets import Dataset from pydantic import BaseModel -from bespokelabs.curator import Prompter +from 
bespokelabs.curator import LLM class MockResponseFormat(BaseModel): @@ -16,7 +17,7 @@ class MockResponseFormat(BaseModel): @pytest.fixture -def prompter() -> Prompter: +def prompter() -> LLM: """Create a Prompter instance for testing. Returns: @@ -24,12 +25,18 @@ def prompter() -> Prompter: """ def prompt_func(row): - return { - "user_prompt": f"Context: {row['context']} Answer this question: {row['question']}", - "system_prompt": "You are a helpful assistant.", - } + return [ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": f"Context: {row['context']} Answer this question: {row['question']}", + }, + ] - return Prompter( + return LLM( model_name="gpt-4o-mini", prompt_func=prompt_func, response_format=MockResponseFormat, @@ -37,7 +44,7 @@ def prompt_func(row): @pytest.mark.test -def test_completions(prompter: Prompter, tmp_path): +def test_completions(prompter: LLM, tmp_path): """Test that completions processes a dataset correctly. Args: @@ -54,17 +61,31 @@ def test_completions(prompter: Prompter, tmp_path): # Set up temporary cache directory os.environ["BELLA_CACHE_DIR"] = str(tmp_path) - result_dataset = prompter(dataset) - result_dataset = result_dataset.to_huggingface() + # Mock OpenAI API response + mock_response = { + "choices": [{"message": {"content": "1 + 1 equals 2."}, "finish_reason": "stop"}] + } + + with patch("openai.resources.chat.completions.Completions.create", return_value=mock_response): + # Process dataset and get responses + result_dataset = prompter(dataset) - # Assertions - assert len(result_dataset) == len(dataset) - assert "message" in result_dataset.column_names - assert "confidence" in result_dataset.column_names + # Verify the dataset structure + assert len(result_dataset) == len(dataset) + assert "response" in result_dataset.column_names + # Check that each response has the required fields + for row in result_dataset: + response = row["response"] + if isinstance(response, dict): + assert "message" in response + assert "confidence" in response + else: + assert hasattr(response, "message") + assert hasattr(response, "confidence") @pytest.mark.test -def test_single_completion_batch(prompter: Prompter): +def test_single_completion_batch(prompter: LLM): """Test that a single completion works with batch=True. Args: @@ -84,24 +105,36 @@ def simple_prompt_func(): }, ] - batch_prompter = Prompter( + batch_prompter = LLM( model_name="gpt-4o-mini", prompt_func=simple_prompt_func, response_format=MockResponseFormat, batch=True, ) - # Get single completion - result = batch_prompter() + # Mock response data + mock_dataset = Dataset.from_list( + [{"response": {"message": "This is a test message.", "confidence": 0.9}}] + ) + + # Mock the run method of OpenAIBatchRequestProcessor + with patch( + "bespokelabs.curator.request_processor.openai_batch_request_processor.OpenAIBatchRequestProcessor.run", + return_value=mock_dataset, + ): + # Get single completion + result = batch_prompter() - # Assertions - assert isinstance(result, MockResponseFormat) - assert hasattr(result, "message") - assert hasattr(result, "confidence") + # Assertions + assert isinstance(result, Dataset) + assert len(result) == 1 + assert isinstance(result[0]["response"], dict) + assert result[0]["response"]["message"] == "This is a test message." 
+ assert result[0]["response"]["confidence"] == 0.9 @pytest.mark.test -def test_single_completion_no_batch(prompter: Prompter): +def test_single_completion_no_batch(prompter: LLM): """Test that a single completion works without batch parameter. Args: @@ -121,16 +154,28 @@ def simple_prompt_func(): }, ] - non_batch_prompter = Prompter( + non_batch_prompter = LLM( model_name="gpt-4o-mini", prompt_func=simple_prompt_func, response_format=MockResponseFormat, ) - # Get single completion - result = non_batch_prompter() + # Mock response data + mock_dataset = Dataset.from_list( + [{"response": {"message": "This is a test message.", "confidence": 0.9}}] + ) - # Assertions - assert isinstance(result, MockResponseFormat) - assert hasattr(result, "message") - assert hasattr(result, "confidence") + # Mock the run method of OpenAIOnlineRequestProcessor + with patch( + "bespokelabs.curator.request_processor.openai_online_request_processor.OpenAIOnlineRequestProcessor.run", + return_value=mock_dataset, + ): + # Get single completion + result = non_batch_prompter() + + # Assertions + assert isinstance(result, Dataset) + assert len(result) == 1 + assert isinstance(result[0]["response"], dict) + assert result[0]["response"]["message"] == "This is a test message." + assert result[0]["response"]["confidence"] == 0.9
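For reviewers trying out this change, here is a minimal usage sketch (not part of the patch) of the new throttling, retry, and completeness options that this diff threads through the `LLM` interface. It mirrors `tests/simple_online.py` above; the model name and limit values are illustrative placeholders, not recommended defaults.

```python
from bespokelabs.curator import LLM
from datasets import Dataset

# Placeholder dataset; any HuggingFace Dataset with a "prompt" column works.
dataset = Dataset.from_dict({"prompt": ["write me a poem"] * 3})

# The keyword arguments below are the options added in this patch;
# the values shown are illustrative only.
llm = LLM(
    prompt_func=lambda row: row["prompt"],
    model_name="gpt-4o-mini",
    max_requests_per_minute=500,      # explicit client-side request cap
    max_tokens_per_minute=100_000,    # explicit client-side token cap
    max_retries=3,                    # retry budget passed to the request processor
    require_all_responses=False,      # do not fail the run if some requests fail
)

dataset = llm(dataset)
print(dataset.to_pandas())
```

With `require_all_responses` left at its default of `True` in `BaseRequestProcessor`, the same run raises a `ValueError` if any request fails or goes unanswered.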