diff --git a/.gitignore b/.gitignore
index 892e0790..e95f9df8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
.venv
+.DS_Store
__pycache__
.vscode
diff --git a/README.md b/README.md
index e129be08..81224738 100644
--- a/README.md
+++ b/README.md
@@ -24,9 +24,12 @@
+
+
+
-### Overview
+## Overview
Bespoke Curator makes it very easy to create high-quality synthetic data at scale, which you can use to fine-tune models or for structured data extraction.
@@ -35,7 +38,7 @@ Bespoke Curator is an open-source project:
* A Curator Viewer, which makes it easy to view the datasets and thus aids in dataset creation.
* We will also be releasing high-quality datasets that should move the needle on post-training.
-### Key Features
+## Key Features
1. **Programmability and Structured Outputs**: Synthetic data generation is a lot more than just using a single prompt -- it involves calling LLMs multiple times and orchestrating control flow. Curator treats structured outputs as first-class citizens and helps you design complex pipelines.
2. **Built-in Performance Optimization**: We often see LLMs being called in loops, or multi-threading implemented inefficiently. We have baked in performance optimizations so that you don't need to worry about those!
@@ -43,48 +46,91 @@ Bespoke Curator is an open-source project:
4. **Native HuggingFace Dataset Integration**: Work directly on HuggingFace Dataset objects throughout your pipeline. Your synthetic data is immediately ready for fine-tuning!
5. **Interactive Curator Viewer**: Improve and iterate on your prompts using our built-in viewer. Inspect LLM requests and responses in real-time, allowing you to iterate and refine your data generation strategy with immediate feedback.
-### Installation
+## Installation
```bash
pip install bespokelabs-curator
```
-### Usage
+## Usage
+To run the examples below, set your OpenAI API key in the environment variable `OPENAI_API_KEY`:
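+```bash
+export OPENAI_API_KEY=sk-...
+```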
+
+### Hello World with `SimpleLLM`: A simple interface for calling LLMs
+
+```python
+from bespokelabs import curator
+llm = curator.SimpleLLM(model_name="gpt-4o-mini")
+poem = llm("Write a poem about the importance of data in AI.")
+print(poem)
+# Or you can pass a list of prompts to generate multiple responses.
+poems = llm(["Write a poem about the importance of data in AI.",
+ "Write a haiku about the importance of data in AI."])
+print(poems)
+```
+Note that retries and caching are enabled by default, so if you run the same prompt again you will get the same response almost instantly.
+The cache is stored at `~/.cache/curator` and can be deleted to regenerate responses.
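+For example, to clear the default cache directory:
+```bash
+rm -rf ~/.cache/curator
+```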
+
+#### Use LiteLLM backend for calling other models
+You can use the [LiteLLM](https://docs.litellm.ai/docs/providers) backend for calling other models.
+
+```python
+from bespokelabs import curator
+llm = curator.SimpleLLM(model_name="claude-3-5-sonnet-20240620", backend="litellm")
+poem = llm("Write a poem about the importance of data in AI.")
+print(poem)
+```
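+When using another provider through LiteLLM, set that provider's API key as well; for the Claude example above this is `ANTHROPIC_API_KEY`:
+```bash
+export ANTHROPIC_API_KEY=...
+```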
+
+### Visualize in Curator Viewer
+Run `curator-viewer` on the command line to see the dataset in the viewer.
+You can click on a run and then click on a specific row to see the LLM request and response.
+![Curator Responses](docs/curator-responses.png)
+More examples below.
+
+### `LLM`: A more powerful interface for synthetic data generation
+
+Let's use structured outputs to generate poems.
```python
from bespokelabs import curator
from datasets import Dataset
from pydantic import BaseModel, Field
from typing import List
-# Create a dataset object for the topics you want to create the poems.
topics = Dataset.from_dict({"topic": [
"Urban loneliness in a bustling city",
"Beauty of Bespoke Labs's Curator library"
]})
+```
-# Define a class to encapsulate a list of poems.
+Define a class to encapsulate a list of poems.
+```python
class Poem(BaseModel):
poem: str = Field(description="A poem.")
class Poems(BaseModel):
poems_list: List[Poem] = Field(description="A list of poems.")
+```
-
-# We define a Prompter that generates poems which gets applied to the topics dataset.
-poet = curator.Prompter(
- # `prompt_func` takes a row of the dataset as input.
- # `row` is a dictionary with a single key 'topic' in this case.
+We define an `LLM` object that generates poems, which we then apply to the topics dataset.
+```python
+poet = curator.LLM(
prompt_func=lambda row: f"Write two poems about {row['topic']}.",
model_name="gpt-4o-mini",
response_format=Poems,
- # `row` is the input row, and `poems` is the `Poems` class which
- # is parsed from the structured output from the LLM.
parse_func=lambda row, poems: [
{"topic": row["topic"], "poem": p.poem} for p in poems.poems_list
],
)
+```
+Here:
+* `prompt_func` takes a row of the dataset as input and returns the prompt for the LLM.
+* `response_format` is the structured output class we defined above.
+* `parse_func` takes the input (`row`) and the structured output (`poems`) and converts them into a list of dictionaries, so the output can easily be converted into a HuggingFace Dataset object.
+
+Now we can apply the `LLM` object to the dataset; the call reads like idiomatic Python.
+```python
poem = poet(topics)
print(poem.to_pandas())
# Example output:
@@ -94,14 +140,11 @@ print(poem.to_pandas())
# 2 Beauty of Bespoke Labs's Curator library In whispers of design and crafted grace,\nBesp...
# 3 Beauty of Bespoke Labs's Curator library In the hushed breath of parchment and ink,\nBe...
```
-Note that `topics` can be created with `curator.Prompter` as well,
+Note that `topics` can be created with `curator.LLM` as well,
and we can scale this up to create tens of thousands of diverse poems.
You can see a more detailed example in the [examples/poem.py](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples/poem.py) file,
and other examples in the [examples](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples) directory.
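+For instance, the `topics` dataset above could itself be generated with an `LLM` object, along the lines of [examples/poem.py](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples/poem.py). Here is a rough sketch, assuming a `Topics` response model with a `topics_list` field and a hypothetical `parse_func`:
+```python
+class Topics(BaseModel):
+    topics_list: List[str] = Field(description="A list of topics.")
+
+topic_generator = curator.LLM(
+    prompt_func=lambda: "Generate 10 diverse topics that are suitable for writing poems about.",
+    model_name="gpt-4o-mini",
+    response_format=Topics,
+    # Hypothetical parse_func: turn the structured output into rows with a "topic" column.
+    parse_func=lambda _, topics: [{"topic": t} for t in topics.topics_list],
+)
+topics = topic_generator()
+```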
-To run the examples, make sure to set your OpenAI API key in
-the environment variable `OPENAI_API_KEY` by running `export OPENAI_API_KEY=sk-...` in your terminal.
-
See the [docs](https://docs.bespokelabs.ai/) for more details as well as
for troubleshooting information.
@@ -115,6 +158,12 @@ curator-viewer
This will pop up a browser window with the viewer running on `127.0.0.1:3000` by default if you haven't specified a different host and port.
+The dataset viewer shows all the different runs you have made.
+![Curator Runs](docs/curator-runs.png)
+
+You can also see the dataset and the responses from the LLM.
+![Curator Dataset](docs/curator-dataset.png)
+
Optional parameters to run the viewer on a different host and port:
```bash
@@ -152,4 +201,4 @@ npm -v # should print `10.9.0`
```
## Contributing
-Contributions are welcome!
\ No newline at end of file
+Contributions are welcome!
diff --git a/bespoke-dataset-viewer/app/dataset/[runHash]/page.tsx b/bespoke-dataset-viewer/app/dataset/[runHash]/page.tsx
index f4a74ca6..43f13ff9 100644
--- a/bespoke-dataset-viewer/app/dataset/[runHash]/page.tsx
+++ b/bespoke-dataset-viewer/app/dataset/[runHash]/page.tsx
@@ -10,11 +10,6 @@ export default async function DatasetPage({
const { runHash } = await params
const { batchMode } = await searchParams
const isBatchMode = batchMode === '1'
- return (
-
-
-
-
-
- )
+
+ return
}
diff --git a/bespoke-dataset-viewer/app/layout.tsx b/bespoke-dataset-viewer/app/layout.tsx
index 03590f7a..a0b8e263 100644
--- a/bespoke-dataset-viewer/app/layout.tsx
+++ b/bespoke-dataset-viewer/app/layout.tsx
@@ -1,6 +1,6 @@
import type { Metadata } from "next";
import "./globals.css";
-
+import { Toaster } from "@/components/ui/toaster"
export const metadata: Metadata = {
title: "Curator Viewer",
@@ -13,10 +13,11 @@ export default function RootLayout({
children: React.ReactNode
}) {
return (
-
-
+
+
{children}
+
)
-}
+}
\ No newline at end of file
diff --git a/bespoke-dataset-viewer/components/dataset-viewer/DetailsSidebar.tsx b/bespoke-dataset-viewer/components/dataset-viewer/DetailsSidebar.tsx
index c3f673cf..ea36a394 100644
--- a/bespoke-dataset-viewer/components/dataset-viewer/DetailsSidebar.tsx
+++ b/bespoke-dataset-viewer/components/dataset-viewer/DetailsSidebar.tsx
@@ -7,6 +7,7 @@ import { Copy } from "lucide-react"
import { DataItem } from "@/types/dataset"
import { useCallback } from "react"
import { Sheet, SheetContent } from "@/components/ui/sheet"
+import { useToast } from "@/components/ui/use-toast"
interface DetailsSidebarProps {
item: DataItem | null
@@ -14,15 +15,26 @@ interface DetailsSidebarProps {
}
export function DetailsSidebar({ item, onClose }: DetailsSidebarProps) {
+ const { toast } = useToast()
+
const copyToClipboard = useCallback(async (text: string) => {
try {
await navigator.clipboard.writeText(text)
- alert("Copied to clipboard!")
+ toast({
+ title: "Success",
+ description: "Copied to clipboard!",
+ duration: 2000,
+ })
} catch (err) {
console.error("Failed to copy:", err)
- alert("Failed to copy to clipboard")
+ toast({
+ variant: "destructive",
+ title: "Error",
+ description: "Failed to copy to clipboard",
+ duration: 2000,
+ })
}
- }, [])
+ }, [toast])
if (!item) return null
diff --git a/bespoke-dataset-viewer/components/dataset-viewer/RunsTable.tsx b/bespoke-dataset-viewer/components/dataset-viewer/RunsTable.tsx
index d85a42be..299fef0c 100644
--- a/bespoke-dataset-viewer/components/dataset-viewer/RunsTable.tsx
+++ b/bespoke-dataset-viewer/components/dataset-viewer/RunsTable.tsx
@@ -39,8 +39,8 @@ class Poems(BaseModel):
poems_list: List[Poem] = Field(description="A list of poems.")
-# We define a Prompter that generates poems which gets applied to the topics dataset.
-poet = curator.Prompter(
+# We define an LLM object that generates poems which gets applied to the topics dataset.
+poet = curator.LLM(
# prompt_func takes a row of the dataset as input.
# row is a dictionary with a single key 'topic' in this case.
prompt_func=lambda row: f"Write two poems about {row['topic']}.",
diff --git a/bespoke-dataset-viewer/components/ui/use-toast.ts b/bespoke-dataset-viewer/components/ui/use-toast.ts
index c2fbf3f9..05d1fefe 100644
--- a/bespoke-dataset-viewer/components/ui/use-toast.ts
+++ b/bespoke-dataset-viewer/components/ui/use-toast.ts
@@ -6,7 +6,7 @@ import type {
} from "@/components/ui/toast"
const TOAST_LIMIT = 1
-const TOAST_REMOVE_DELAY = 1000000
+const TOAST_REMOVE_DELAY = 3000
type ToasterToast = ToastProps & {
id: string
diff --git a/bespoke-dataset-viewer/package-lock.json b/bespoke-dataset-viewer/package-lock.json
index 97c34801..f9e02c46 100644
--- a/bespoke-dataset-viewer/package-lock.json
+++ b/bespoke-dataset-viewer/package-lock.json
@@ -36,6 +36,7 @@
},
"devDependencies": {
"@types/node": "^20",
+ "@types/prismjs": "^1.26.5",
"@types/react": "^18",
"@types/react-dom": "^18",
"eslint": "^8",
@@ -1860,6 +1861,13 @@
"undici-types": "~6.19.2"
}
},
+ "node_modules/@types/prismjs": {
+ "version": "1.26.5",
+ "resolved": "https://registry.npmjs.org/@types/prismjs/-/prismjs-1.26.5.tgz",
+ "integrity": "sha512-AUZTa7hQ2KY5L7AmtSiqxlhWxb4ina0yd8hNbl4TWuqnv/pFP0nDMb3YrfSBf4hJVGLh2YEIBfKaBW/9UEl6IQ==",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/@types/prop-types": {
"version": "15.7.13",
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.13.tgz",
diff --git a/bespoke-dataset-viewer/package.json b/bespoke-dataset-viewer/package.json
index 62643150..dee9be50 100644
--- a/bespoke-dataset-viewer/package.json
+++ b/bespoke-dataset-viewer/package.json
@@ -37,6 +37,7 @@
},
"devDependencies": {
"@types/node": "^20",
+ "@types/prismjs": "^1.26.5",
"@types/react": "^18",
"@types/react-dom": "^18",
"eslint": "^8",
diff --git a/build_pkg.py b/build_pkg.py
index 80de2549..b9a6e57e 100644
--- a/build_pkg.py
+++ b/build_pkg.py
@@ -81,7 +81,7 @@ def nextjs_build():
def run_pytest():
print("Running pytest")
try:
- run_command("pytest", cwd="tests")
+ run_command("pytest")
except subprocess.CalledProcessError:
print("Pytest failed. Aborting build.")
sys.exit(1)
diff --git a/docs/curator-dataset.png b/docs/curator-dataset.png
new file mode 100644
index 00000000..33138ac3
Binary files /dev/null and b/docs/curator-dataset.png differ
diff --git a/docs/curator-responses.png b/docs/curator-responses.png
new file mode 100644
index 00000000..a78277e0
Binary files /dev/null and b/docs/curator-responses.png differ
diff --git a/docs/curator-runs.png b/docs/curator-runs.png
new file mode 100644
index 00000000..d076d9b1
Binary files /dev/null and b/docs/curator-runs.png differ
diff --git a/examples/camel.py b/examples/camel.py
index bffa0507..b9bdfee1 100644
--- a/examples/camel.py
+++ b/examples/camel.py
@@ -22,14 +22,14 @@ class QAs(BaseModel):
qas: List[QA] = Field(description="A list of QAs")
-subject_prompter = curator.Prompter(
+subject_prompter = curator.LLM(
prompt_func=lambda: f"Generate a diverse list of 3 subjects. Keep it high-level (e.g. Math, Science).",
parse_func=lambda _, subjects: [subject for subject in subjects.subjects],
model_name="gpt-4o-mini",
response_format=Subjects,
)
subject_dataset = subject_prompter()
-subsubject_prompter = curator.Prompter(
+subsubject_prompter = curator.LLM(
prompt_func=lambda subject: f"For the given subject {subject}. Generate 3 diverse subsubjects. No explanation.",
parse_func=lambda subject, subsubjects: [
{"subject": subject["subject"], "subsubject": subsubject.subject}
@@ -40,7 +40,7 @@ class QAs(BaseModel):
)
subsubject_dataset = subsubject_prompter(subject_dataset)
-qa_prompter = curator.Prompter(
+qa_prompter = curator.LLM(
prompt_func=lambda subsubject: f"For the given subsubject {subsubject}. Generate 3 diverse questions and answers. No explanation.",
model_name="gpt-4o-mini",
response_format=QAs,
diff --git a/examples/distill.py b/examples/distill.py
index 20cfd53b..b9e7c7bb 100644
--- a/examples/distill.py
+++ b/examples/distill.py
@@ -21,7 +21,7 @@ def parse_func(row, response):
return {"instruction": instruction, "new_response": response}
-distill_prompter = curator.Prompter(
+distill_prompter = curator.LLM(
prompt_func=prompt_func,
parse_func=parse_func,
model_name="gpt-4o-mini",
diff --git a/examples/litellm_recipe_prompting.py b/examples/litellm_recipe_prompting.py
index 87446e01..85449389 100644
--- a/examples/litellm_recipe_prompting.py
+++ b/examples/litellm_recipe_prompting.py
@@ -31,7 +31,7 @@ def main():
# 3. Set environment variable: GEMINI_API_KEY
#############################################
- recipe_prompter = curator.Prompter(
+ recipe_prompter = curator.LLM(
model_name="gemini/gemini-1.5-flash",
prompt_func=lambda row: f"Generate a random {row['cuisine']} recipe. Be creative but keep it realistic.",
parse_func=lambda row, response: {
diff --git a/examples/litellm_recipe_structured_output.py b/examples/litellm_recipe_structured_output.py
index 747411e9..bb9ad12b 100644
--- a/examples/litellm_recipe_structured_output.py
+++ b/examples/litellm_recipe_structured_output.py
@@ -28,7 +28,7 @@ def main():
# 2. Generate an API key or use an existing API key
# 3. Set environment variable: ANTHROPIC_API_KEY
#############################################
- cuisines_generator = curator.Prompter(
+ cuisines_generator = curator.LLM(
prompt_func=lambda: f"Generate 10 diverse cuisines.",
model_name="claude-3-5-haiku-20241022",
response_format=Cuisines,
@@ -44,7 +44,7 @@ def main():
# 2. Generate an API key or use an existing API key
# 3. Set environment variable: GEMINI_API_KEY
#############################################
- recipe_prompter = curator.Prompter(
+ recipe_prompter = curator.LLM(
model_name="gemini/gemini-1.5-flash",
prompt_func=lambda row: f"Generate a random {row['cuisine']} recipe. Be creative but keep it realistic.",
parse_func=lambda row, response: {
diff --git a/examples/persona-hub/synthesize.py b/examples/persona-hub/synthesize.py
index 232d6e1f..a2ef68dd 100644
--- a/examples/persona-hub/synthesize.py
+++ b/examples/persona-hub/synthesize.py
@@ -31,7 +31,7 @@ def get_generator(template):
def prompt_func(row):
return template.format(persona=row["persona"])
- generator = curator.Prompter(
+ generator = curator.LLM(
prompt_func=prompt_func,
model_name="gpt-4o",
temperature=0.7,
diff --git a/examples/poem.py b/examples/poem.py
index e8e50d07..1b59e562 100644
--- a/examples/poem.py
+++ b/examples/poem.py
@@ -17,7 +17,7 @@ class Topics(BaseModel):
# We define a prompter that generates topics.
-topic_generator = curator.Prompter(
+topic_generator = curator.LLM(
prompt_func=lambda: "Generate 10 diverse topics that are suitable for writing poems about.",
model_name="gpt-4o-mini",
response_format=Topics,
@@ -35,8 +35,8 @@ class Poems(BaseModel):
poems_list: List[str] = Field(description="A list of poems.")
-# We define a prompter that generates poems which gets applied to the topics dataset.
-poet = curator.Prompter(
+# We define an `LLM` object that generates poems which gets applied to the topics dataset.
+poet = curator.LLM(
# The prompt_func takes a row of the dataset as input.
# The row is a dictionary with a single key 'topic' in this case.
prompt_func=lambda row: f"Write two poems about {row['topic']}.",
diff --git a/examples/simple_poem.py b/examples/simple_poem.py
new file mode 100644
index 00000000..8b1f5106
--- /dev/null
+++ b/examples/simple_poem.py
@@ -0,0 +1,25 @@
+"""Curator example that uses `SimpleLLM` to generate poems.
+
+Please see poem.py for more complex use cases.
+"""
+
+from bespokelabs import curator
+
+# Use GPT-4o-mini for this example.
+llm = curator.SimpleLLM(model_name="gpt-4o-mini")
+poem = llm("Write a poem about the importance of data in AI.")
+print(poem)
+
+# Use Claude 3.5 Sonnet for this example.
+llm = curator.SimpleLLM(model_name="claude-3-5-sonnet-20240620", backend="litellm")
+poem = llm("Write a poem about the importance of data in AI.")
+print(poem)
+
+# Note that we can also pass a list of prompts to generate multiple responses.
+poems = llm(
+ [
+ "Write a sonnet about the importance of data in AI.",
+ "Write a haiku about the importance of data in AI.",
+ ]
+)
+print(poems)
diff --git a/pyproject.toml b/pyproject.toml
index 54a6b10d..c594ccd3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "bespokelabs-curator"
-version = "0.1.11"
+version = "0.1.12"
description = "Bespoke Labs Curator"
authors = ["Bespoke Labs "]
readme = "README.md"
diff --git a/src/bespokelabs/curator/__init__.py b/src/bespokelabs/curator/__init__.py
index bb0b7aa2..5ef73092 100644
--- a/src/bespokelabs/curator/__init__.py
+++ b/src/bespokelabs/curator/__init__.py
@@ -1,2 +1,3 @@
from .dataset import Dataset
-from .prompter.prompter import Prompter
+from .llm.llm import LLM
+from .llm.simple_llm import SimpleLLM
diff --git a/src/bespokelabs/curator/dataset.py b/src/bespokelabs/curator/dataset.py
index b0abece0..180be4c6 100644
--- a/src/bespokelabs/curator/dataset.py
+++ b/src/bespokelabs/curator/dataset.py
@@ -1,15 +1,13 @@
import glob
-import json
import logging
import os
from typing import Any, Dict, Iterable, Iterator, List, TypeVar
-import pandas as pd
from datasets import Dataset as HFDataset
from datasets.arrow_writer import ArrowWriter, SchemaInferenceError
from pydantic import BaseModel
-from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
+from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.generic_response import GenericResponse
T = TypeVar("T")
diff --git a/src/bespokelabs/curator/file_utilities.py b/src/bespokelabs/curator/file_utilities.py
new file mode 100644
index 00000000..6ee606e7
--- /dev/null
+++ b/src/bespokelabs/curator/file_utilities.py
@@ -0,0 +1,14 @@
+# https://stackoverflow.com/questions/845058/how-to-get-the-line-count-of-a-large-file-cheaply-in-python
+# https://stackoverflow.com/a/68385697
+def _file_gen(reader):
+ b = reader(1024 * 1024)
+ while b:
+ yield b
+ b = reader(1024 * 1024)
+
+
+# Instead of counting lines here, we could store a metadata file that records the number of requests in each file.
+def count_lines(filename):
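+    """Count the number of newline characters in `filename`, reading it in 1 MiB chunks."""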
+    with open(filename, "rb") as f:
+        f_gen = _file_gen(f.raw.read)
+        return sum(buf.count(b"\n") for buf in f_gen)
diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/llm/llm.py
similarity index 78%
rename from src/bespokelabs/curator/prompter/prompter.py
rename to src/bespokelabs/curator/llm/llm.py
index 61e9e99e..9a53beeb 100644
--- a/src/bespokelabs/curator/prompter/prompter.py
+++ b/src/bespokelabs/curator/llm/llm.py
@@ -7,82 +7,47 @@
from io import BytesIO
from typing import Any, Callable, Dict, Iterable, Optional, Type, TypeVar, Union
-import dill
from datasets import Dataset
+from datasets.utils._dill import Pickler
from pydantic import BaseModel
from xxhash import xxh64
from bespokelabs.curator.db import MetadataDB
-from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
-from bespokelabs.curator.request_processor.base_request_processor import BaseRequestProcessor
+from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
+from bespokelabs.curator.request_processor.base_request_processor import (
+ BaseRequestProcessor,
+)
+from bespokelabs.curator.request_processor.litellm_online_request_processor import (
+ LiteLLMOnlineRequestProcessor,
+)
from bespokelabs.curator.request_processor.openai_batch_request_processor import (
OpenAIBatchRequestProcessor,
)
from bespokelabs.curator.request_processor.openai_online_request_processor import (
OpenAIOnlineRequestProcessor,
)
-from bespokelabs.curator.request_processor.litellm_online_request_processor import (
- LiteLLMOnlineRequestProcessor,
-)
_CURATOR_DEFAULT_CACHE_DIR = "~/.cache/curator"
T = TypeVar("T")
+_DictOrBaseModel = Union[Dict[str, Any], BaseModel]
logger = logging.getLogger(__name__)
-class Prompter:
+class LLM:
"""Interface for prompting LLMs."""
- @staticmethod
- def _determine_backend(
- model_name: str, response_format: Optional[Type[BaseModel]] = None
- ) -> str:
- """Determine which backend to use based on model name and response format.
-
- Args:
- model_name (str): Name of the model
- response_format (Optional[Type[BaseModel]]): Response format if specified
-
- Returns:
- str: Backend to use ("openai" or "litellm")
- """
- model_name = model_name.lower()
-
- # GPT-4o models with response format should use OpenAI
- if (
- response_format
- and OpenAIOnlineRequestProcessor(model_name).check_structured_output_support()
- ):
- logger.info(f"Requesting structured output from {model_name}, using OpenAI backend")
- return "openai"
-
- # GPT models and O1 models without response format should use OpenAI
- if not response_format and any(x in model_name for x in ["gpt-", "o1-preview", "o1-mini"]):
- logger.info(f"Requesting text output from {model_name}, using OpenAI backend")
- return "openai"
-
- # Default to LiteLLM for all other cases
- logger.info(
- f"Requesting {f'structured' if response_format else 'text'} output from {model_name}, using LiteLLM backend"
- )
- return "litellm"
-
def __init__(
self,
model_name: str,
- prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]],
+ prompt_func: Callable[[_DictOrBaseModel], _DictOrBaseModel],
parse_func: Optional[
- Callable[
- [
- Union[Dict[str, Any], BaseModel],
- Union[Dict[str, Any], BaseModel],
- ],
- T,
- ]
+ Callable[[_DictOrBaseModel, _DictOrBaseModel], _DictOrBaseModel]
] = None,
response_format: Optional[Type[BaseModel]] = None,
backend: Optional[str] = None,
+ max_requests_per_minute: Optional[int] = None,
+ max_tokens_per_minute: Optional[int] = None,
batch: bool = False,
batch_size: Optional[int] = None,
batch_check_interval: Optional[int] = 60,
@@ -92,38 +57,32 @@ def __init__(
top_p: Optional[float] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
+ max_retries: Optional[int] = None,
+ require_all_responses: Optional[bool] = True,
):
- """Initialize a Prompter.
+ """Initialize a LLM.
Args:
- model_name (str): The name of the LLM to use
- prompt_func (Callable[[Dict[str, Any]], Union[str, List[Dict[str, Any]]]]): A function that takes a single row
+ model_name: The name of the LLM to use
+ prompt_func: A function that takes a single row
and returns either a string (assumed to be a user prompt) or messages list
- parse_func (Callable[[Dict[str, Any], Any], T]): A function that takes the input row and
+ parse_func: A function that takes the input row and
response object and returns the parsed output
- response_format (Optional[Type[BaseModel]]): A Pydantic model specifying the
+ response_format: A Pydantic model specifying the
response format from the LLM.
- backend (Optional[str]): The backend to use ("openai" or "litellm"). If None, will be auto-determined
- batch (bool): Whether to use batch processing
- batch_size (Optional[int]): The size of the batch to use, only used if batch is True
- temperature (Optional[float]): The temperature to use for the LLM, only used if batch is False
- top_p (Optional[float]): The top_p to use for the LLM, only used if batch is False
- presence_penalty (Optional[float]): The presence_penalty to use for the LLM, only used if batch is False
- frequency_penalty (Optional[float]): The frequency_penalty to use for the LLM, only used if batch is False
+ backend: The backend to use ("openai" or "litellm"). If None, will be auto-determined
+ batch: Whether to use batch processing
+ batch_size: The size of the batch to use, only used if batch is True
+ batch_check_interval: The interval to check for batch completions, only used if batch is True
+ delete_successful_batch_files: Whether to delete successful batch files, only used if batch is True
+ delete_failed_batch_files: Whether to delete failed batch files, only used if batch is True
+ temperature: The temperature to use for the LLM, only used if batch is False
+ top_p: The top_p to use for the LLM, only used if batch is False
+ presence_penalty: The presence_penalty to use for the LLM, only used if batch is False
+ frequency_penalty: The frequency_penalty to use for the LLM, only used if batch is False
+ max_retries: The maximum number of retries to use for the LLM
+            require_all_responses: Whether to require a response for every request
"""
- prompt_sig = inspect.signature(prompt_func)
- if len(prompt_sig.parameters) > 1:
- raise ValueError(
- f"prompt_func must take one argument or less, got {len(prompt_sig.parameters)}"
- )
-
- if parse_func is not None:
- parse_sig = inspect.signature(parse_func)
- if len(parse_sig.parameters) != 2:
- raise ValueError(
- f"parse_func must take exactly 2 arguments, got {len(parse_sig.parameters)}"
- )
-
self.prompt_formatter = PromptFormatter(
model_name, prompt_func, parse_func, response_format
)
@@ -144,6 +103,10 @@ def __init__(
logger.info(
f"batch=True but no batch_size provided, using default batch_size of {batch_size:,}"
)
+ if max_requests_per_minute is not None or max_tokens_per_minute is not None:
+ logger.warning(
+ "max_requests_per_minute and max_tokens_per_minute not supported with batch mode"
+ )
self._request_processor = OpenAIBatchRequestProcessor(
model=model_name,
batch_size=batch_size,
@@ -154,11 +117,13 @@ def __init__(
frequency_penalty=frequency_penalty,
delete_successful_batch_files=delete_successful_batch_files,
delete_failed_batch_files=delete_failed_batch_files,
+ max_retries=max_retries,
+ require_all_responses=require_all_responses,
)
else:
if batch_size is not None:
logger.warning(
- f"Prompter argument `batch_size` {batch_size} is ignored because `batch` is False"
+ f"LLM argument `batch_size` {batch_size} is ignored because `batch` is False"
)
self._request_processor = OpenAIOnlineRequestProcessor(
model=model_name,
@@ -166,6 +131,10 @@ def __init__(
top_p=top_p,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
+ max_requests_per_minute=max_requests_per_minute,
+ max_tokens_per_minute=max_tokens_per_minute,
+ max_retries=max_retries,
+ require_all_responses=require_all_responses,
)
elif self.backend == "litellm":
if batch:
@@ -178,10 +147,48 @@ def __init__(
top_p=top_p,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
+ max_requests_per_minute=max_requests_per_minute,
+ max_tokens_per_minute=max_tokens_per_minute,
+ max_retries=max_retries,
+ require_all_responses=require_all_responses,
)
else:
raise ValueError(f"Unknown backend: {self.backend}")
+ @staticmethod
+ def _determine_backend(
+ model_name: str, response_format: Optional[Type[BaseModel]] = None
+ ) -> str:
+ """Determine which backend to use based on model name and response format.
+
+ Args:
+ model_name (str): Name of the model
+ response_format (Optional[Type[BaseModel]]): Response format if specified
+
+ Returns:
+ str: Backend to use ("openai" or "litellm")
+ """
+ model_name = model_name.lower()
+
+ # GPT-4o models with response format should use OpenAI
+ if (
+ response_format
+ and OpenAIOnlineRequestProcessor(model_name).check_structured_output_support()
+ ):
+ logger.info(f"Requesting structured output from {model_name}, using OpenAI backend")
+ return "openai"
+
+ # GPT models and O1 models without response format should use OpenAI
+ if not response_format and any(x in model_name for x in ["gpt-", "o1-preview", "o1-mini"]):
+ logger.info(f"Requesting text output from {model_name}, using OpenAI backend")
+ return "openai"
+
+ # Default to LiteLLM for all other cases
+ logger.info(
+ f"Requesting {f'structured' if response_format else 'text'} output from {model_name}, using LiteLLM backend"
+ )
+ return "litellm"
+
def __call__(
self,
dataset: Optional[Iterable] = None,
@@ -211,7 +218,7 @@ def _completions(
Args:
dataset (Iterable): A dataset consisting of a list of items to apply completions
- prompter (Prompter): A Prompter that contains the logic for formatting each
+            prompter (LLM): An LLM that contains the logic for formatting each
item in the dataset
working_dir (str): The working directory to save the requests.jsonl, responses.jsonl, and dataset.arrow files.
@@ -223,7 +230,7 @@ def _completions(
dataset = Dataset.from_generator(dataset)
if self is None:
- raise ValueError("Prompter must be provided")
+ raise ValueError("LLM must be provided")
if working_dir is None:
curator_cache_dir = os.environ.get(
@@ -311,7 +318,7 @@ def _get_function_hash(func) -> str:
return xxh64("").hexdigest()
file = BytesIO()
- dill.Pickler(file, recurse=True).dump(func)
+ Pickler(file, recurse=True).dump(func)
return xxh64(file.getvalue()).hexdigest()
diff --git a/src/bespokelabs/curator/llm/prompt_formatter.py b/src/bespokelabs/curator/llm/prompt_formatter.py
new file mode 100644
index 00000000..4dae93ce
--- /dev/null
+++ b/src/bespokelabs/curator/llm/prompt_formatter.py
@@ -0,0 +1,140 @@
+import dataclasses
+import inspect
+import json
+import logging
+from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union
+
+from pydantic import BaseModel, ValidationError
+
+from bespokelabs.curator.request_processor.generic_request import GenericRequest
+
+T = TypeVar("T")
+_DictOrBaseModel = Union[Dict[str, Any], BaseModel]
+logger = logging.getLogger(__name__)
+
+
+def _validate_messages(messages: list[dict]) -> None:
+ """Validates that messages conform to the expected chat format.
+
+ Args:
+ messages: A list of message dictionaries to validate.
+
+ Raises:
+ ValueError: If messages don't meet the required format:
+ - Must be a list of dictionaries
+ - Each message must have 'role' and 'content' keys
+ - Role must be one of: 'system', 'user', 'assistant'
+ """
+ valid_roles = {"system", "user", "assistant"}
+
+ for msg in messages:
+ if not isinstance(msg, dict):
+ raise ValueError(
+ "In the return value (a list) of the prompt_func, each "
+ "message must be a dictionary"
+ )
+
+ if "role" not in msg or "content" not in msg:
+ raise ValueError(
+ "In the return value (a list) of the prompt_func, each "
+ "message must contain 'role' and 'content' keys"
+ )
+
+ if msg["role"] not in valid_roles:
+ raise ValueError(
+ f"In the return value (a list) of the prompt_func, "
+ f"each message role must be one of: {', '.join(sorted(valid_roles))}"
+ )
+
+
+@dataclasses.dataclass
+class PromptFormatter:
+ model_name: str
+ prompt_func: Callable[[_DictOrBaseModel], Dict[str, str]]
+ parse_func: Optional[Callable[[_DictOrBaseModel, _DictOrBaseModel], T]] = None
+ response_format: Optional[Type[BaseModel]] = None
+
+ def create_generic_request(self, row: _DictOrBaseModel, idx: int) -> GenericRequest:
+ """Format the request object based off of `LLM` attributes."""
+ sig = inspect.signature(self.prompt_func)
+ if len(sig.parameters) == 0:
+ prompts = self.prompt_func()
+ elif len(sig.parameters) == 1:
+ prompts = self.prompt_func(row)
+ else:
+ raise ValueError(f"Prompting function {self.prompt_func} must have 0 or 1 arguments.")
+
+ if isinstance(prompts, str):
+ messages = [{"role": "user", "content": prompts}]
+ elif isinstance(prompts, list):
+ _validate_messages(prompts)
+ messages = prompts
+ else:
+ raise ValueError("The return value of the prompt_func must be a list of dictionaries.")
+
+ # Convert BaseModel to dict for serialization
+ if isinstance(row, BaseModel):
+ row = row.model_dump()
+
+ return GenericRequest(
+ model=self.model_name,
+ messages=messages,
+ original_row=row,
+ original_row_idx=idx,
+ response_format=(
+ self.response_format.model_json_schema() if self.response_format else None
+ ),
+ )
+
+ def response_to_response_format(self, response_message: str | dict) -> Optional[dict | str]:
+ """
+ Converts a response message to a specified Pydantic model format.
+
+ This method takes a response message (either as a string or dict) and validates/converts it
+ according to the provided Pydantic model format. If the response message is a string,
+ it first attempts to parse it as JSON. The resulting dict is then used to construct
+ an instance of the specified Pydantic model.
+
+ Args:
+ response_message (str | dict): The response message to convert, either as a JSON string
+ or a dictionary.
+
+        Returns:
+            The raw response message if `self.response_format` is None, otherwise an instance
+                of `self.response_format` constructed from the parsed response.
+
+ Raises:
+ json.JSONDecodeError: If the response_message is a string but cannot be parsed as valid JSON.
+ ValidationError: If the parsed response does not match the schema defined by response_format.
+ """
+ # Response message is a string, which is converted to a dict
+ # The dict is then used to construct the response_format Pydantic model
+ if self.response_format is None:
+ return response_message
+
+ try:
+ # First try to parse the response message as JSON
+ if isinstance(response_message, str):
+ try:
+ response_dict = json.loads(response_message)
+ except json.JSONDecodeError as e:
+ logger.warning(
+ f"Failed to parse response message as JSON: {response_message}. "
+ f"The model likely returned an invalid JSON format."
+ )
+ raise e
+ else:
+ response_dict = response_message
+
+ # Then construct the Pydantic model from the parsed dict
+ response_message = self.response_format(**response_dict)
+ return response_message
+
+ except ValidationError as e:
+ schema_str = json.dumps(self.response_format.model_json_schema(), indent=2)
+ logger.warning(
+ f"Pydantic failed to parse response message {response_message} with `response_format` {schema_str}. "
+ f"The model likely returned a JSON that does not match the schema of the `response_format`."
+ )
+ raise e
diff --git a/src/bespokelabs/curator/llm/prompt_formatter_test.py b/src/bespokelabs/curator/llm/prompt_formatter_test.py
new file mode 100644
index 00000000..b39b2226
--- /dev/null
+++ b/src/bespokelabs/curator/llm/prompt_formatter_test.py
@@ -0,0 +1,79 @@
+import pytest
+from pydantic import BaseModel
+
+from bespokelabs.curator.llm.prompt_formatter import PromptFormatter, _validate_messages
+
+
+def test_validate_messages_valid():
+ """Tests that valid message formats pass validation."""
+ valid_messages = [
+ {"role": "system", "content": "You are a helpful assistant"},
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there!"},
+ ]
+ # Should not raise any exceptions
+ _validate_messages(valid_messages)
+
+
+def test_validate_messages_invalid_format():
+ """Tests that invalid message formats raise appropriate errors."""
+ # Test non-dict message
+ with pytest.raises(ValueError, match="must be a dictionary"):
+ _validate_messages([["role", "content"]])
+
+ # Test missing required keys
+ with pytest.raises(ValueError, match="must contain 'role' and 'content' keys"):
+ _validate_messages([{"role": "user"}])
+
+ # Test invalid role
+ with pytest.raises(ValueError, match="must be one of: assistant, system, user"):
+ _validate_messages([{"role": "invalid", "content": "test"}])
+
+
+class TestResponse(BaseModel):
+ text: str
+
+
+def test_prompt_formatter_create_generic_request():
+ """Tests that PromptFormatter correctly creates GenericRequest objects."""
+ # Test with string prompt
+ formatter = PromptFormatter(
+ model_name="test-model", prompt_func=lambda x: "Hello", response_format=TestResponse
+ )
+ request = formatter.create_generic_request({"input": "test"}, 0)
+
+ assert request.model == "test-model"
+ assert request.messages == [{"role": "user", "content": "Hello"}]
+ assert request.original_row == {"input": "test"}
+ assert request.original_row_idx == 0
+ assert request.response_format is not None
+
+ # Test with message list prompt
+ formatter = PromptFormatter(
+ model_name="test-model",
+ prompt_func=lambda x: [
+ {"role": "system", "content": "You are helpful"},
+ {"role": "user", "content": "Hi"},
+ ],
+ )
+ request = formatter.create_generic_request({"input": "test"}, 1)
+
+ assert len(request.messages) == 2
+ assert request.messages[0]["role"] == "system"
+ assert request.messages[1]["role"] == "user"
+ assert request.original_row_idx == 1
+
+
+def test_prompt_formatter_invalid_prompt_func():
+ """Tests that PromptFormatter raises errors for invalid prompt functions."""
+ # Test prompt function with too many parameters
+ with pytest.raises(ValueError, match="must have 0 or 1 arguments"):
+ PromptFormatter(model_name="test", prompt_func=lambda x, y: "test").create_generic_request(
+ {}, 0
+ )
+
+ # Test invalid prompt function return type
+ with pytest.raises(ValueError, match="must be a list of dictionaries"):
+ PromptFormatter(
+ model_name="test", prompt_func=lambda x: {"invalid": "format"}
+ ).create_generic_request({}, 0)
diff --git a/src/bespokelabs/curator/llm/simple_llm.py b/src/bespokelabs/curator/llm/simple_llm.py
new file mode 100644
index 00000000..7cc62dd0
--- /dev/null
+++ b/src/bespokelabs/curator/llm/simple_llm.py
@@ -0,0 +1,33 @@
+from bespokelabs.curator.llm.llm import LLM
+from datasets import Dataset
+from typing import Union, List
+
+
+class SimpleLLM:
+ """A simpler interface for the LLM class.
+
+ Usage:
+ llm = SimpleLLM(model_name="gpt-4o-mini")
+ llm("Do you know about the bitter lesson?")
+ llm(["What is the capital of France?", "What is the capital of Germany?"])
+ For more complex use cases (e.g. structured outputs and custom prompt functions), see the LLM class.
+ """
+
+ def __init__(self, model_name: str, backend: str = "openai"):
+ self._model_name = model_name
+ self._backend = backend
+
+ def __call__(self, prompt: Union[str, List[str]]) -> Union[str, List[str]]:
+ prompt_list = [prompt] if isinstance(prompt, str) else prompt
+ dataset: Dataset = Dataset.from_dict({"prompt": prompt_list})
+
+ llm = LLM(
+ prompt_func=lambda row: row["prompt"],
+ model_name=self._model_name,
+ response_format=None,
+ backend=self._backend,
+ )
+ response = llm(dataset)
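+        # The returned dataset has a "response" column holding the completion for each prompt.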
+ if isinstance(prompt, str):
+ return response["response"][0]
+ return response["response"]
diff --git a/src/bespokelabs/curator/prompter/prompt_formatter.py b/src/bespokelabs/curator/prompter/prompt_formatter.py
deleted file mode 100644
index 5682c978..00000000
--- a/src/bespokelabs/curator/prompter/prompt_formatter.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import inspect
-from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union
-
-from pydantic import BaseModel
-
-from bespokelabs.curator.request_processor.generic_request import GenericRequest
-
-T = TypeVar("T")
-
-
-class PromptFormatter:
- model_name: str
- prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]]
- parse_func: Optional[
- Callable[
- [
- Union[Dict[str, Any], BaseModel],
- Union[Dict[str, Any], BaseModel],
- ],
- T,
- ]
- ] = None
- response_format: Optional[Type[BaseModel]] = None
-
- def __init__(
- self,
- model_name: str,
- prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]],
- parse_func: Optional[
- Callable[
- [
- Union[Dict[str, Any], BaseModel],
- Union[Dict[str, Any], BaseModel],
- ],
- T,
- ]
- ] = None,
- response_format: Optional[Type[BaseModel]] = None,
- ):
- self.model_name = model_name
- self.prompt_func = prompt_func
- self.parse_func = parse_func
- self.response_format = response_format
-
- def create_generic_request(self, row: Dict[str, Any] | BaseModel, idx: int) -> GenericRequest:
- """Format the request object based off Prompter attributes."""
- sig = inspect.signature(self.prompt_func)
- if len(sig.parameters) == 0:
- prompts = self.prompt_func()
- elif len(sig.parameters) == 1:
- prompts = self.prompt_func(row)
- else:
- raise ValueError(f"Prompting function {self.prompt_func} must have 0 or 1 arguments.")
-
- if isinstance(prompts, str):
- messages = [{"role": "user", "content": prompts}]
- else:
- # TODO(Ryan): Add validation here
- messages = prompts
-
- # Convert BaseModel to dict for serialization
- if isinstance(row, BaseModel):
- row = row.model_dump()
-
- return GenericRequest(
- model=self.model_name,
- messages=messages,
- original_row=row,
- original_row_idx=idx,
- response_format=(
- self.response_format.model_json_schema() if self.response_format else None
- ),
- )
diff --git a/src/bespokelabs/curator/request_processor/base_online_request_processor.py b/src/bespokelabs/curator/request_processor/base_online_request_processor.py
index 7e95cbc0..51537125 100644
--- a/src/bespokelabs/curator/request_processor/base_online_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/base_online_request_processor.py
@@ -13,7 +13,7 @@
from bespokelabs.curator.dataset import Dataset
from bespokelabs.curator.request_processor.base_request_processor import BaseRequestProcessor
-from bespokelabs.curator.prompter.prompter import PromptFormatter
+from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.generic_request import GenericRequest
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
from bespokelabs.curator.request_processor.generic_response import GenericResponse
@@ -22,6 +22,12 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
+DEFAULT_MAX_REQUESTS_PER_MINUTE = 100
+DEFAULT_MAX_TOKENS_PER_MINUTE = 100_000
+DEFAULT_MAX_RETRIES = 10
+SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10
+DEFAULT_REQUEST_TIMEOUT = 10 * 60 # 10 minutes
+
@dataclass
class StatusTracker:
@@ -42,7 +48,9 @@ class StatusTracker:
max_tokens_per_minute: int = 0
pbar: tqdm = field(default=None)
response_cost: float = 0
- time_of_last_rate_limit_error: float = field(default=None)
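+    # Default lies in the past so the rate-limit cooldown check does not delay the first requests.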
+ time_of_last_rate_limit_error: float = field(
+ default=time.time() - SECONDS_TO_PAUSE_ON_RATE_LIMIT
+ )
def __str__(self):
return (
@@ -119,14 +127,61 @@ def __init__(
top_p: Optional[float] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
+ max_requests_per_minute: Optional[int] = None,
+ max_tokens_per_minute: Optional[int] = None,
+        require_all_responses: Optional[bool] = None,
+ max_retries: Optional[int] = None,
):
- super().__init__(batch_size=None)
+ super().__init__(batch_size=None, require_all_responses=require_all_responses)
self.model: str = model
self.temperature: float | None = temperature
self.top_p: float | None = top_p
self.presence_penalty: float | None = presence_penalty
self.frequency_penalty: float | None = frequency_penalty
self.prompt_formatter: Optional[PromptFormatter] = None
+ self.manual_max_requests_per_minute: Optional[int] = max_requests_per_minute
+ self.manual_max_tokens_per_minute: Optional[int] = max_tokens_per_minute
+ if max_retries is None:
+ self.max_retries = DEFAULT_MAX_RETRIES
+ else:
+ self.max_retries = max_retries
+ self.timeout = DEFAULT_REQUEST_TIMEOUT
+
+ @property
+ def max_requests_per_minute(self) -> int:
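+        """Return the manually set limit if provided, else the header-detected limit, else the default."""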
+ if self.manual_max_requests_per_minute:
+ logger.info(
+ f"Manually set max_requests_per_minute to {self.manual_max_requests_per_minute}"
+ )
+ return self.manual_max_requests_per_minute
+ elif self.header_based_max_requests_per_minute:
+ logger.info(
+ f"Automatically set max_requests_per_minute to {self.header_based_max_requests_per_minute}"
+ )
+ return self.header_based_max_requests_per_minute
+ else:
+ logger.warning(
+ f"No manual max_requests_per_minute set, and headers based detection failed, using default value of {DEFAULT_MAX_REQUESTS_PER_MINUTE}"
+ )
+ return DEFAULT_MAX_REQUESTS_PER_MINUTE
+
+ @property
+ def max_tokens_per_minute(self) -> int:
+ if self.manual_max_tokens_per_minute:
+ logger.info(
+ f"Manually set max_tokens_per_minute to {self.manual_max_tokens_per_minute}"
+ )
+ return self.manual_max_tokens_per_minute
+ elif self.header_based_max_tokens_per_minute:
+ logger.info(
+ f"Automatically set max_tokens_per_minute to {self.header_based_max_tokens_per_minute}"
+ )
+ return self.header_based_max_tokens_per_minute
+ else:
+ logger.warning(
+ f"No manual max_tokens_per_minute set, and headers based detection failed, using default value of {DEFAULT_MAX_TOKENS_PER_MINUTE}"
+ )
+ return DEFAULT_MAX_TOKENS_PER_MINUTE
@abstractmethod
def estimate_total_tokens(self, messages: list) -> int:
@@ -149,6 +204,11 @@ def run(
parse_func_hash: str,
prompt_formatter: PromptFormatter,
) -> Dataset:
+ # load from already completed dataset
+ output_dataset = self.attempt_loading_cached_dataset(working_dir, parse_func_hash)
+ if output_dataset is not None:
+ return output_dataset
+
"""Run completions using the online API with async processing."""
logger.info(f"Running {self.__class__.__name__} completions with model: {self.model}")
@@ -169,7 +229,6 @@ def run(
self.process_requests_from_file(
generic_request_filepath=request_file,
save_filepath=response_file,
- max_attempts=5,
resume=True,
)
)
@@ -180,7 +239,6 @@ async def process_requests_from_file(
self,
generic_request_filepath: str,
save_filepath: str,
- max_attempts: int,
resume: bool,
resume_no_retry: bool = False,
) -> None:
@@ -191,10 +249,8 @@ async def process_requests_from_file(
status_tracker = StatusTracker()
# Get rate limits
- rate_limits = self.get_rate_limits()
- status_tracker.max_requests_per_minute = rate_limits["max_requests_per_minute"]
- status_tracker.max_tokens_per_minute = rate_limits["max_tokens_per_minute"]
- rpm = rate_limits["max_requests_per_minute"]
+ status_tracker.max_requests_per_minute = self.max_requests_per_minute
+ status_tracker.max_tokens_per_minute = self.max_tokens_per_minute
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(
@@ -206,7 +262,7 @@ async def process_requests_from_file(
completed_request_ids = set()
if os.path.exists(save_filepath):
if resume:
- logger.debug(f"Resuming progress from existing file: {save_filepath}")
+ logger.info(f"Resuming progress by reading existing file: {save_filepath}")
logger.debug(
f"Removing all failed requests from {save_filepath} so they can be retried"
)
@@ -224,6 +280,11 @@ async def process_requests_from_file(
f"{response.response_errors}, removing from output and will retry"
)
num_previously_failed_requests += 1
+ if response.response_message is None:
+ logger.debug(
+ f"Request {response.generic_request.original_row_idx} previously failed due to no response, removing from output and will retry"
+ )
+ num_previously_failed_requests += 1
else:
completed_request_ids.add(response.generic_request.original_row_idx)
output_file.write(line)
@@ -279,7 +340,7 @@ async def process_requests_from_file(
)
# Use higher connector limit for better throughput
- connector = aiohttp.TCPConnector(limit=10 * rpm)
+ connector = aiohttp.TCPConnector(limit=10 * status_tracker.max_requests_per_minute)
async with aiohttp.ClientSession(
connector=connector
) as session: # Initialize ClientSession here
@@ -297,7 +358,7 @@ async def process_requests_from_file(
task_id=status_tracker.num_tasks_started,
generic_request=generic_request,
api_specific_request=self.create_api_specific_request(generic_request),
- attempts_left=max_attempts,
+ attempts_left=self.max_retries,
prompt_formatter=self.prompt_formatter,
)
@@ -307,6 +368,19 @@ async def process_requests_from_file(
while not status_tracker.has_capacity(token_estimate):
await asyncio.sleep(0.1)
+                    # Wait for the rate-limit cooldown if needed
+ seconds_since_rate_limit_error = (
+ time.time() - status_tracker.time_of_last_rate_limit_error
+ )
+ if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT:
+ remaining_seconds_to_pause = (
+ SECONDS_TO_PAUSE_ON_RATE_LIMIT - seconds_since_rate_limit_error
+ )
+                        logger.warning(
+                            f"Pausing to cool down for {int(remaining_seconds_to_pause)} seconds"
+                        )
+                        await asyncio.sleep(remaining_seconds_to_pause)
+
# Consume capacity before making request
status_tracker.consume_capacity(token_estimate)
@@ -337,10 +411,10 @@ async def process_requests_from_file(
token_estimate = self.estimate_total_tokens(
retry_request.generic_request.messages
)
- attempt_number = 6 - retry_request.attempts_left
- logger.info(
- f"Processing retry for request {retry_request.task_id} "
- f"(attempt #{attempt_number} of 5). "
+ attempt_number = self.max_retries - retry_request.attempts_left
+ logger.debug(
+ f"Retrying request {retry_request.task_id} "
+                        f"(attempt #{attempt_number} of {self.max_retries}). "
f"Previous errors: {retry_request.result}"
)
@@ -405,6 +479,9 @@ async def handle_single_request_with_retries(
status_tracker=status_tracker,
)
+ # Allows us to retry on responses that don't match the response format
+ self.prompt_formatter.response_to_response_format(generic_response.response_message)
+
# Save response in the base class
await self.append_generic_response(generic_response, save_filepath)
@@ -413,23 +490,20 @@ async def handle_single_request_with_retries(
status_tracker.pbar.update(1)
except Exception as e:
- logger.warning(
- f"Request {request.task_id} failed with Exception {e}, attempts left {request.attempts_left}"
- )
status_tracker.num_other_errors += 1
request.result.append(e)
if request.attempts_left > 0:
request.attempts_left -= 1
- # Add retry queue logging
- logger.info(
- f"Adding request {request.task_id} to retry queue. Will retry in next available slot. "
- f"Attempts remaining: {request.attempts_left}"
+ logger.warning(
+ f"Encountered '{e.__class__.__name__}: {e}' during attempt "
+ f"{self.max_retries - request.attempts_left} of {self.max_retries} "
+ f"while processing request {request.task_id}"
)
retry_queue.put_nowait(request)
else:
logger.error(
- f"Request {request.task_id} failed permanently after exhausting all 5 retry attempts. "
+ f"Request {request.task_id} failed permanently after exhausting all {self.max_retries} retry attempts. "
f"Errors: {[str(e) for e in request.result]}"
)
generic_response = GenericResponse(
diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py
index a19fbf8f..6a5b2a30 100644
--- a/src/bespokelabs/curator/request_processor/base_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/base_request_processor.py
@@ -6,7 +6,8 @@
import resource
from abc import ABC, abstractmethod
from math import ceil
-from typing import Optional
+from pathlib import Path
+from typing import Optional, List
import aiofiles
import pyarrow
@@ -14,7 +15,8 @@
from datasets.arrow_writer import ArrowWriter
from pydantic import BaseModel, ValidationError
-from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
+from bespokelabs.curator.file_utilities import count_lines
+from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
from bespokelabs.curator.request_processor.generic_request import GenericRequest
from bespokelabs.curator.request_processor.generic_response import GenericResponse
@@ -29,8 +31,9 @@ class BaseRequestProcessor(ABC):
Base class for all request processors.
"""
- def __init__(self, batch_size: Optional[int] = None):
+ def __init__(self, batch_size: Optional[int] = None, require_all_responses: bool = True):
self.batch_size = batch_size
+ self.require_all_responses = require_all_responses
# Increase the number of open file descriptors to avoid "Too many open files" errors
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
desired_limit = min(10_000_000, hard)
@@ -39,16 +42,6 @@ def __init__(self, batch_size: Optional[int] = None):
)
resource.setrlimit(resource.RLIMIT_NOFILE, (desired_limit, hard))
- @abstractmethod
- def get_rate_limits(self) -> dict:
- """
- Returns the rate limits for the API.
-
- Returns:
- dict: A dictionary containing the rate limit information.
- """
- pass
-
@abstractmethod
def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
"""
@@ -84,6 +77,64 @@ def run(
"""
pass
+ def _verify_existing_request_files(
+ self, working_dir: str, dataset: Optional[Dataset]
+ ) -> List[int]:
+ """
+ Verify integrity of the cache (each request file has associated metadata, and the number of rows is correct),
+ and return the indices of request files that need to be regenerated (so that no work is repeated).
+
+ Args:
+ working_dir (str): Working directory where cache files are expected to be (requests.jsonl, metadata.json)
+ dataset (Optional[Dataset]): The dataset that we want to create requests from
+
+ Returns:
+            List[int]: Indices of request files that are missing or invalid
+ """
+
+ if self.batch_size is not None and dataset is not None:
+ expected_num_files = ceil(len(dataset) / self.batch_size)
+ else:
+ expected_num_files = 1
+
+ try:
+ incomplete_files = []
+ for i in range(expected_num_files):
+ req_f = os.path.join(working_dir, f"requests_{i}.jsonl")
+ meta_f = os.path.join(working_dir, f"metadata_{i}.json")
+
+ if not os.path.exists(req_f):
+ incomplete_files.append(i)
+ continue
+
+ if not os.path.exists(meta_f):
+ logger.warning(f"Cache missing metadata file {meta_f} for request file {req_f}")
+ incomplete_files.append(i)
+ continue
+
+ with open(req_f, "r") as f:
+ data = f.read()
+ num_jobs = len(data.splitlines())
+
+ with open(meta_f, "r") as f:
+ metadata = json.load(f)
+
+ expected_num_jobs = metadata["num_jobs"]
+ if num_jobs != expected_num_jobs:
+ logger.warning(
+ f"Request file {req_f} has {num_jobs} jobs, but metadata file {meta_f} has {expected_num_jobs} jobs"
+ )
+ incomplete_files.append(i)
+
+ return incomplete_files
+
+ except Exception as e:
+ logger.warning(
+ f"Cache verification failed due to {e} - regenerating all request files."
+ )
+ incomplete_files = list(range(expected_num_files))
+ return incomplete_files
+
def create_request_files(
self,
dataset: Optional[Dataset],
@@ -104,7 +155,9 @@ def create_request_files(
request_files = glob.glob(f"{working_dir}/requests_*.jsonl")
# By default use existing requests in working_dir
- if len(request_files) > 0:
+ incomplete_files = self._verify_existing_request_files(working_dir, dataset)
+
+ if len(incomplete_files) == 0:
logger.info(f"Using cached requests. {CACHE_MSG}")
# count existing jobs in file and print first job
with open(request_files[0], "r") as f:
@@ -124,18 +177,27 @@ def create_request_files(
return request_files
# Create new requests file
+ logger.info(f"Preparing request file(s) in {working_dir}")
request_file = f"{working_dir}/requests_0.jsonl"
request_files = [request_file]
+ metadata_file = f"{working_dir}/metadata_0.json"
+ metadata_files = [metadata_file]
+
if dataset is None:
with open(request_file, "w") as f:
generic_request = prompt_formatter.create_generic_request(dict(), 0)
f.write(json.dumps(generic_request.model_dump(), default=str) + "\n")
+
+ metadata_dict = {"num_jobs": 1}
+ with open(metadata_file, "w") as f:
+ f.write(json.dumps(metadata_dict, indent=4) + "\n")
return request_files
if self.batch_size:
num_batches = ceil(len(dataset) / self.batch_size)
request_files = [f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)]
+ metadata_files = [f"{working_dir}/metadata_{i}.json" for i in range(num_batches)]
async def create_all_request_files():
tasks = [
@@ -143,15 +205,19 @@ async def create_all_request_files():
dataset,
prompt_formatter,
request_files[i],
+ metadata_files[i],
start_idx=i * self.batch_size,
)
for i in range(num_batches)
+ if i in incomplete_files
]
await asyncio.gather(*tasks)
run_in_event_loop(create_all_request_files())
else:
- run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file))
+ run_in_event_loop(
+ self.acreate_request_file(dataset, prompt_formatter, request_file, metadata_file)
+ )
return request_files
@@ -161,8 +227,9 @@ async def acreate_request_file(
dataset: Dataset,
prompt_formatter: PromptFormatter,
request_file: str,
+ metadata_file: str,
start_idx: int = 0,
- ) -> str:
+ ) -> None:
if self.batch_size is not None:
end_idx = min(start_idx + self.batch_size, len(dataset))
dataset = dataset.select(range(start_idx, end_idx))
@@ -176,7 +243,13 @@ async def acreate_request_file(
# Get the generic request from the map function
request = prompt_formatter.create_generic_request(dataset_row, dataset_row_idx)
await f.write(json.dumps(request.model_dump(), default=str) + "\n")
- logger.info(f"Wrote {end_idx - start_idx} requests to {request_file}.")
+
+ num_requests = end_idx - start_idx
+ metadata_dict = {"num_jobs": num_requests}
+ async with aiofiles.open(metadata_file, "w") as f:
+ await f.write(json.dumps(metadata_dict, indent=4) + "\n")
+
+ logger.info(f"Wrote {num_requests} requests to {request_file}.")
def attempt_loading_cached_dataset(
self, working_dir: str, parse_func_hash: str
@@ -216,9 +289,6 @@ def create_dataset_files(
Returns:
Dataset: Completed dataset
"""
- total_responses_count = 0
- failed_responses_count = 0
-
responses_files = glob.glob(f"{working_dir}/responses_*.jsonl")
if len(responses_files) == 0:
raise ValueError(f"No responses files found in {working_dir}")
@@ -230,6 +300,8 @@ def create_dataset_files(
)
# Process all response files
+ total_responses_count = 0
+ failed_responses_count = 0
dataset_file = f"{working_dir}/{parse_func_hash}.arrow"
with ArrowWriter(path=dataset_file) as writer:
for responses_file in responses_files:
@@ -243,41 +315,18 @@ def create_dataset_files(
failed_responses_count += 1
continue
- if prompt_formatter.response_format:
- # Response message is a string, which is converted to a dict
- # The dict is then used to construct the response_format Pydantic model
- try:
- # First try to parse the response message as JSON
- if isinstance(response.response_message, str):
- try:
- response_dict = json.loads(response.response_message)
- except json.JSONDecodeError as e:
- warning_msg = (
- f"Failed to parse response message as JSON: {response.response_message}. "
- f"The model likely returned an invalid JSON format. Will skip this response."
- )
- logger.warning(warning_msg)
- failed_responses_count += 1
- continue
- else:
- response_dict = response.response_message
-
- # Then construct the Pydantic model from the parsed dict
- response.response_message = prompt_formatter.response_format(
- **response_dict
- )
- except ValidationError as e:
- schema_str = json.dumps(
- prompt_formatter.response_format.model_json_schema(),
- indent=2,
+ try:
+ response.response_message = (
+ self.prompt_formatter.response_to_response_format(
+ response.response_message
)
- warning_msg = (
- f"Pydantic failed to parse response message {response.response_message} with `response_format` {schema_str}. "
- f"The model likely returned a JSON that does not match the schema of the `response_format`. Will skip this response."
- )
- logger.warning(warning_msg)
- failed_responses_count += 1
- continue
+ )
+                    except (json.JSONDecodeError, ValidationError) as e:
+                        logger.warning(
+                            f"Skipping response: failed to parse the response message into the response format: {e}"
+                        )
+ failed_responses_count += 1
+ continue
# parse_func can return a single row or a list of rows
if prompt_formatter.parse_func:
@@ -293,7 +342,13 @@ def create_dataset_files(
if not isinstance(dataset_rows, list):
dataset_rows = [dataset_rows]
else:
- dataset_rows = [{"response": response.response_message}]
+ # Convert response to dict before adding to dataset
+ response_value = response.response_message
+ if hasattr(response_value, "model_dump"):
+ response_value = response_value.model_dump()
+ elif hasattr(response_value, "__dict__"):
+ response_value = response_value.__dict__
+ dataset_rows = [{"response": response_value}]
for row in dataset_rows:
if isinstance(row, BaseModel):
@@ -313,14 +368,35 @@ def create_dataset_files(
writer.write(row)
- logger.info(f"Read {total_responses_count} responses, {failed_responses_count} failed")
+ logger.info("Finalizing writer")
+ writer.finalize()
+
+ logger.info(f"Read {total_responses_count} responses.")
if failed_responses_count == total_responses_count:
os.remove(dataset_file)
raise ValueError("All requests failed")
- logger.info("Finalizing writer")
+ if failed_responses_count > 0:
+ logger.warning(f"{failed_responses_count} requests failed.")
+ if self.require_all_responses:
+ os.remove(dataset_file)
+ raise ValueError(f"Some requests failed and require_all_responses is True")
- writer.finalize()
+    # Sanity check: the number of responses should match the number of requests.
+ request_files = glob.glob(f"{working_dir}/requests_*.jsonl")
+ n_requests = 0
+ for request_file in request_files:
+ n_requests += count_lines(request_file)
+
+ if n_requests != total_responses_count:
+ logger.warning(
+ f"{n_requests - total_responses_count} requests do not have responses. n_requests is {n_requests} and n_responses is {total_responses_count}"
+ )
+ if self.require_all_responses:
+ os.remove(dataset_file)
+ raise ValueError(
+ f"Some requests do not have responses and require_all_responses is True."
+ )
return Dataset.from_file(dataset_file)
diff --git a/src/bespokelabs/curator/request_processor/event_loop.py b/src/bespokelabs/curator/request_processor/event_loop.py
index 92120e7d..6bc8bda7 100644
--- a/src/bespokelabs/curator/request_processor/event_loop.py
+++ b/src/bespokelabs/curator/request_processor/event_loop.py
@@ -1,5 +1,4 @@
import asyncio
-from time import sleep
import nest_asyncio
diff --git a/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py b/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py
index 4b346fcf..28c888e8 100644
--- a/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/litellm_online_request_processor.py
@@ -1,6 +1,5 @@
import logging
from typing import Optional
-import asyncio
import aiohttp
import litellm
from litellm import get_supported_openai_params
@@ -14,7 +13,7 @@
from bespokelabs.curator.request_processor.generic_request import GenericRequest
from bespokelabs.curator.request_processor.generic_response import TokenUsage, GenericResponse
from pydantic import BaseModel
-from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
+import time
logger = logging.getLogger(__name__)
@@ -49,6 +48,10 @@ def __init__(
top_p: Optional[float] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
+ max_requests_per_minute: Optional[int] = None,
+ max_tokens_per_minute: Optional[int] = None,
+ require_all_responses: Optional[bool] = None,
+ max_retries: Optional[int] = None,
):
super().__init__(
model=model,
@@ -56,8 +59,15 @@ def __init__(
top_p=top_p,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
+ max_requests_per_minute=max_requests_per_minute,
+ max_tokens_per_minute=max_tokens_per_minute,
+ require_all_responses=require_all_responses,
+ max_retries=max_retries,
)
self.client = instructor.from_litellm(litellm.acompletion)
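+        # Probe the provider once at construction time to record any rate limits advertised in response headers.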
+ self.header_based_max_requests_per_minute, self.header_based_max_tokens_per_minute = (
+ self.get_header_based_rate_limits()
+ )
def check_structured_output_support(self):
"""Verify if the model supports structured output via instructor.
@@ -134,20 +144,7 @@ def estimate_total_tokens(self, messages: list) -> int:
output_tokens = self.estimate_output_tokens()
return input_tokens + output_tokens
- def get_rate_limits(self) -> dict:
- """Retrieve rate limits from the LLM provider via LiteLLM.
-
- Makes a test request to get rate limit information from response headers.
-
- Returns:
- dict: Contains 'max_requests_per_minute' and 'max_tokens_per_minute'
-
- Note:
- - Falls back to default values if headers are missing
- - Some providers (e.g., Claude) require non-empty messages
- """
- logger.info(f"Getting rate limits for model: {self.model}")
-
+ def test_call(self):
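+        # Issue a minimal completion to probe cost-estimation support and capture the provider's response headers.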
completion = litellm.completion(
model=self.model,
messages=[
@@ -155,15 +152,33 @@ def get_rate_limits(self) -> dict:
], # Some models (e.g. Claude) require an non-empty message to get rate limits.
)
+        # Try calculating the cost to check whether LiteLLM supports cost estimation for this model.
+ try:
+ litellm.completion_cost(completion_response=completion.model_dump())
+ except litellm.NotFoundError as e:
+ logger.warning(f"LiteLLM does not support cost estimation for model {self.model}: {e}")
+
headers = completion._hidden_params.get("additional_headers", {})
- logger.info(f"Rate limit headers: {headers}")
+ logger.info(f"Test call headers: {headers}")
+ return headers
+
+ def get_header_based_rate_limits(self) -> tuple[int, int]:
+ """Retrieve rate limits from the LLM provider via LiteLLM.
+
+ Returns:
+ tuple[int, int]: Contains 'max_requests_per_minute' and 'max_tokens_per_minute'
- rpm = int(headers.get("x-ratelimit-limit-requests", 3000))
- tpm = int(headers.get("x-ratelimit-limit-tokens", 150_000))
+ Note:
+ - Makes a test request to get rate limit information from response headers.
+ - Some providers (e.g., Claude) require non-empty messages
+ """
+ logger.info(f"Getting rate limits for model: {self.model}")
- logger.info(f"Rate limits - Requests/min: {rpm}, Tokens/min: {tpm}")
+ headers = self.test_call()
+ rpm = int(headers.get("x-ratelimit-limit-requests", 0))
+ tpm = int(headers.get("x-ratelimit-limit-tokens", 0))
- return {"max_requests_per_minute": rpm, "max_tokens_per_minute": tpm}
+ return rpm, tpm
def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
"""Convert a generic request into a LiteLLM-compatible format.
@@ -200,6 +215,31 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
if "frequency_penalty" in supported_params and self.frequency_penalty is not None:
request["frequency_penalty"] = self.frequency_penalty
+ # Add safety settings for Gemini models
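+        # BLOCK_NONE disables Gemini's per-category safety filtering so completions are not withheld by the content filter.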
+ if "gemini" in generic_request.model.lower():
+ request["safety_settings"] = [
+ {
+ "category": "HARM_CATEGORY_HARASSMENT",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_HATE_SPEECH",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_CIVIC_INTEGRITY",
+ "threshold": "BLOCK_NONE",
+ },
+ ]
+
return request
async def call_single_request(
@@ -222,18 +262,29 @@ async def call_single_request(
GenericResponse: The response from LiteLLM
"""
# Get response directly without extra logging
- if request.generic_request.response_format:
- response, completion_obj = await self.client.chat.completions.create_with_completion(
- **request.api_specific_request,
- response_model=request.prompt_formatter.response_format,
- timeout=60.0,
- )
- response_message = (
- response.model_dump() if hasattr(response, "model_dump") else response
- )
- else:
- completion_obj = await litellm.acompletion(**request.api_specific_request, timeout=60.0)
- response_message = completion_obj["choices"][0]["message"]["content"]
+ try:
+ if request.generic_request.response_format:
+ response, completion_obj = (
+ await self.client.chat.completions.create_with_completion(
+ **request.api_specific_request,
+ response_model=request.prompt_formatter.response_format,
+ timeout=self.timeout,
+ )
+ )
+ response_message = (
+ response.model_dump() if hasattr(response, "model_dump") else response
+ )
+ else:
+ completion_obj = await litellm.acompletion(
+ **request.api_specific_request, timeout=self.timeout
+ )
+ response_message = completion_obj["choices"][0]["message"]["content"]
+ except litellm.RateLimitError as e:
+ status_tracker.time_of_last_rate_limit_error = time.time()
+ status_tracker.num_rate_limit_errors += 1
+            # Decrement num_api_errors so that handle_single_request_with_retries does not double count this error
+ status_tracker.num_api_errors -= 1
+ raise e
# Extract token usage
usage = completion_obj.usage if hasattr(completion_obj, "usage") else {}
@@ -247,9 +298,21 @@ async def call_single_request(
try:
cost = litellm.completion_cost(completion_response=completion_obj.model_dump())
except litellm.NotFoundError as e:
- logger.info(f"LiteLLM does not support cost estimation for model {self.model}: {e}")
cost = 0
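+        # Completions that were truncated ("length") or blocked ("content_filter") are raised as errors
+        # rather than returned as successful responses.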
+ finish_reason = completion_obj.choices[0].finish_reason
+ invalid_finish_reasons = ["length", "content_filter"]
+ if finish_reason in invalid_finish_reasons:
+ logger.debug(
+ f"Invalid finish_reason {finish_reason}. Raw response {completion_obj.model_dump()} for request {request.generic_request.messages}"
+ )
+ raise ValueError(f"finish_reason was {finish_reason}")
+
+ if response_message is None:
+ raise ValueError(
+ f"response_message was None with raw response {completion_obj.model_dump()}"
+ )
+
# Create and return response
return GenericResponse(
response_message=response_message,
diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py
index 6312a2d7..1aaf27e3 100644
--- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py
@@ -3,18 +3,18 @@
import glob
import json
import logging
+import os
from dataclasses import dataclass, field
-from typing import Callable
+from typing import Callable, Optional
-import glob
-import os
import litellm
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, NotFoundError
from openai.types import Batch
from tqdm import tqdm
+from bespokelabs.curator.file_utilities import count_lines
from bespokelabs.curator.dataset import Dataset
-from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
+from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.base_request_processor import (
BaseRequestProcessor,
GenericRequest,
@@ -48,6 +48,8 @@ def __init__(
url: str = "https://api.openai.com/v1/chat/completions",
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
+        require_all_responses: bool | None = None,
+ max_retries: Optional[int] = None,
):
if batch_size > MAX_REQUESTS_PER_BATCH:
raise ValueError(
@@ -55,7 +57,7 @@ def __init__(
f"{MAX_REQUESTS_PER_BATCH:,} requests per batch that OpenAI supports. "
f"Please set your batch_size to be less than or equal to {MAX_REQUESTS_PER_BATCH:,}."
)
- super().__init__(batch_size)
+ super().__init__(batch_size, require_all_responses=require_all_responses)
self.model = model
self.url: str = url
self.check_interval: int = batch_check_interval
@@ -65,48 +67,10 @@ def __init__(
self.frequency_penalty: float | None = frequency_penalty
self.delete_successful_batch_files: bool = delete_successful_batch_files
self.delete_failed_batch_files: bool = delete_failed_batch_files
-
- def get_rate_limits(self) -> dict:
- """
- Function to get rate limits for a given annotator. Not available via response headers, so
- the following is based on tier 5 limits on Nov 6th, 2024.
-
- These rate limits vary per model
- and are determined by your organization's usage tier. View the following:
- https://platform.openai.com/docs/guides/rate-limits/usage-tiers
- https://platform.openai.com/settings/organization/limits
-
- Args:
- model (str): The model for which to get the rate limits.
- request_url (str): The request URL for which to get the rate limits.
-
- Returns:
- tuple[int, int]: A tuple containing the maximum number of requests and tokens per minute.
- """
- model_tpd = {
- "gpt-3.5-turbo": 5_000_000_000,
- "gpt-3.5-turbo-0125": 5_000_000_000,
- "gpt-3.5-turbo-1106": 5_000_000_000,
- "gpt-3.5-turbo-16k": 5_000_000_000,
- "gpt-3.5-turbo-instruct": 200_000,
- "gpt-3.5-turbo-instruct-0914": 200_000,
- "gpt-4": 150_000_000,
- "gpt-4-0613": 150_000_000,
- "gpt-4-turbo": 300_000_000,
- "gpt-4o": 10_000_000_000,
- "gpt-4o-mini": 15_000_000_000,
- }
-
- if self.model not in model_tpd:
- tpd = 1_000_000_000
+ if max_retries is None:
+ self.max_retries = MAX_RETRIES_PER_OPERATION
else:
- tpd = model_tpd[self.model]
-
- logger.info(f"Automatically set max_tokens_per_day to {tpd}, model: {self.model} ")
-
- rate_limits = {"max_tokens_per_day": tpd}
-
- return rate_limits
+ self.max_retries = max_retries
def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
"""
@@ -221,7 +185,14 @@ def generic_response_file_from_responses(
for raw_response in responses.text.splitlines():
raw_response = json.loads(raw_response)
request_idx = int(raw_response["custom_id"])
- generic_request = generic_request_map[request_idx]
+
+ if request_idx not in generic_request_map:
+ logger.warning(
+ f"Request {request_idx} not found in generic_request_map. response_file: {response_file}, "
+ f"request_file: {request_file}. The request files might have been incomplete. Will skip "
+ f"this response."
+ )
+ continue
if raw_response["response"]["status_code"] != 200:
logger.warning(
@@ -229,9 +200,7 @@ def generic_response_file_from_responses(
)
generic_response = GenericResponse(
response_message=None,
- response_errors=[
- f"Request {generic_request} failed with status code {raw_response['response']['status_code']}"
- ],
+ response_errors=[raw_response["response"]["status_code"]],
raw_response=raw_response,
raw_request=None,
generic_request=generic_request,
@@ -329,6 +298,7 @@ def run(
prompt_formatter,
delete_successful_batch_files=self.delete_successful_batch_files,
delete_failed_batch_files=self.delete_failed_batch_files,
+ max_retries=self.max_retries,
)
run_in_event_loop(self.run_batch_operations(batch_manager, request_files))
@@ -347,6 +317,7 @@ def cancel_batches(self, working_dir: str) -> Dataset:
self.check_interval,
delete_successful_batch_files=self.delete_successful_batch_files,
delete_failed_batch_files=self.delete_failed_batch_files,
+ max_retries=self.max_retries,
)
run_in_event_loop(batch_manager.cancel_batches())
@@ -504,6 +475,7 @@ def __init__(
prompt_formatter: PromptFormatter | None = None,
delete_successful_batch_files: bool = False,
delete_failed_batch_files: bool = False,
+ max_retries: Optional[int] = None,
) -> None:
"""Initialize BatchManager to handle OpenAI batch processing operations.
@@ -517,7 +489,7 @@ def __init__(
delete_failed_batch_files (bool): Whether to delete input/error files from OpenAI
after batch failure.
"""
- self.client = AsyncOpenAI()
+ self.client = AsyncOpenAI(max_retries=max_retries)
self.check_interval = check_interval
self.working_dir = working_dir
self.tracker = BatchStatusTracker()
@@ -677,7 +649,6 @@ async def retrieve_batch(self, batch_id: str) -> Batch:
try:
batch_object = await self.client.batches.retrieve(batch_id)
except Exception as e:
- logger.error(f"Error checking previously submitted batch: {e}")
raise e
return batch_object
@@ -721,12 +692,22 @@ async def submit_batch_from_request_file(
async def track_already_submitted_batches(self):
"""
Tracks previously submitted batches from the submitted batch objects file.
+        We need to check all submitted batch object files because we might be looking at a cancelled batch
+        or a batch submitted with a different API key for the same project.
Side Effects:
- Updates tracker with previously submitted batch statuses
"""
- if os.path.exists(self.submitted_batch_objects_file):
- with open(self.submitted_batch_objects_file, "r") as f:
+ all_submitted_batches_files = set(
+ glob.glob(f"{self.working_dir}/batch_objects_submitted_*.jsonl")
+ )
+
+ existing_submitted_batches = {}
+ for submitted_batch_objects_file in all_submitted_batches_files:
+ logger.info(
+ f"Processing submitted batch objects file: {submitted_batch_objects_file} Your API key is ***{self.client.api_key[-4:]}."
+ )
+ with open(submitted_batch_objects_file, "r") as f:
for line in f:
batch_object = Batch.model_validate(json.loads(line))
request_file_name = batch_object.metadata["request_file_name"]
@@ -734,27 +715,79 @@ async def track_already_submitted_batches(self):
f"Already submitted batch {batch_object.id} for request file {request_file_name}. "
f"Getting batch object to update tracker."
)
- batch_object = await self.retrieve_batch(batch_object.id)
+ try:
+ batch_object = await self.retrieve_batch(batch_object.id)
+ except NotFoundError:
+ logger.warning(
+ f"Already submitted batch object {batch_object.id} not found. This might be fine since we might be "
+ "looking at a batch object submitted by another project. Will ignore this batch object..."
+ )
+ continue
+
+ if not self._validate_batch_status(batch_object.status):
+ logger.warning(
+ f"Already submitted batch {batch_object.id} has an invalid status {batch_object.status}. "
+ f"Will ignore this batch object..."
+ )
+ continue
+
+ # We skip the batch if it has a status that means it can no longer be used.
+ if batch_object.status in ["expired", "cancelling", "cancelled"]:
+ logger.info(
+ f"Batch {batch_object.id} has status {batch_object.status}, which means it can "
+ "no longer be used. Will ignore this batch object..."
+ )
+ continue
# Edge case where the batch is still validating, and we need to know the total number of requests
if batch_object.status == "validating":
- n_requests = len(open(request_file_name, "r").readlines())
- batch_object.request_counts.total = n_requests
+ batch_object.request_counts.total = count_lines(request_file_name)
else:
n_requests = batch_object.request_counts.total
- if request_file_name in self.tracker.unsubmitted_request_files:
- self.tracker.mark_as_submitted(request_file_name, batch_object, n_requests)
- else:
- # batch objects if not unsubmitted, should be downloaded
- assert batch_object.id in self.tracker.downloaded_batches
+ # For each request file, we only want to keep the latest batch object.
+ if (
+ request_file_name not in existing_submitted_batches
+ or existing_submitted_batches[request_file_name].created_at
+ < batch_object.created_at
+ ):
+ existing_submitted_batches[request_file_name] = batch_object
+
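+            # Second pass: for the most recent batch recorded for each request file, confirm its output file
+            # still exists before trusting it; otherwise the batch is left unsubmitted so it gets resubmitted.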
+ for request_file_name, batch_object in existing_submitted_batches.items():
+
+ output_file_id = batch_object.output_file_id
+ if output_file_id is not None:
+ try:
+ await self.client.files.retrieve(output_file_id)
+ except NotFoundError:
+ logger.warning(
+ f"Output file {output_file_id} exists in batch object but cannot be found "
+ "in OpenAI storage. The file may have been deleted. Will resubmit this batch..."
+ )
+ continue
+
+ if request_file_name in self.tracker.unsubmitted_request_files:
+                    self.tracker.mark_as_submitted(
+                        request_file_name, batch_object, batch_object.request_counts.total
+                    )
+ else:
+ response_file = request_file_to_response_file(request_file_name, self.working_dir)
+ if not os.path.exists(response_file):
+ raise ValueError(
+ f"While processing {batch_object.id}, we found that its corresponding request_file_name {request_file_name} is "
+ f"not in tracker.unsubmitted_request_files, but its corresponding response_file {response_file} does not exist. "
+ f"This is an invalid state. \n"
+ f"batch_object: {batch_object} \n"
+ f"request_file_name: {request_file_name} \n"
+ f"tracker.unsubmitted_request_files: {self.tracker.unsubmitted_request_files} \n"
+ f"tracker.submitted_batches: {self.tracker.submitted_batches} \n"
+ f"tracker.downloaded_batches: {self.tracker.downloaded_batches} \n"
+ )
if self.tracker.n_submitted_batches > 0:
logger.info(
f"{self.tracker.n_submitted_batches:,} out of {self.tracker.n_total_batches - self.tracker.n_downloaded_batches:,} remaining batches are already submitted."
)
- def track_already_downloaded_batches(self):
+ async def track_already_downloaded_batches(self):
"""
Tracks previously downloaded batches from the downloaded batch objects files.
@@ -765,13 +798,24 @@ def track_already_downloaded_batches(self):
glob.glob(f"{self.working_dir}/batch_objects_downloaded_*.jsonl")
)
for downloaded_batch_object_file in downloaded_batch_object_files:
+ logger.info(
+ f"Processing downloaded batch objects file: {downloaded_batch_object_file} Your API key is ***{self.client.api_key[-4:]}."
+ )
with open(downloaded_batch_object_file, "r") as f:
for line in f:
batch_object = Batch.model_validate(json.loads(line))
request_file = batch_object.metadata["request_file_name"]
response_file = request_file_to_response_file(request_file, self.working_dir)
- assert request_file in self.tracker.unsubmitted_request_files
- assert os.path.exists(response_file)
+ assert (
+ request_file in self.tracker.unsubmitted_request_files
+ ), f"request_file {request_file} not in unsubmitted_request_files: {self.tracker.unsubmitted_request_files}"
+ if not os.path.exists(response_file):
+ logger.warning(
+ f"Downloaded batch object {batch_object.id} has a response_file {response_file} that does not exist. "
+ "Will resubmit this batch..."
+ )
+ continue
+
self.tracker.mark_as_submitted(
request_file, batch_object, batch_object.request_counts.total
)
@@ -800,7 +844,7 @@ async def submit_batches_from_request_files(
- Creates and updates batch submission progress bar
"""
self.tracker.unsubmitted_request_files = request_files
- self.track_already_downloaded_batches()
+ await self.track_already_downloaded_batches()
await self.track_already_submitted_batches()
# exit early
if self.tracker.n_unsubmitted_request_files == 0:
@@ -853,9 +897,8 @@ async def check_batch_status(self, batch_id: str) -> Batch | None:
)
finished_statuses = ["completed", "failed", "expired", "cancelled"]
- in_progress_statuses = ["validating", "finalizing", "cancelling", "in_progress"]
batch_returned = batch.status in finished_statuses
- if batch.status not in in_progress_statuses + finished_statuses:
+ if not self._validate_batch_status(batch.status):
logger.warning(f"Unknown batch status: {batch.status}")
if batch_returned:
@@ -902,7 +945,7 @@ async def poll_and_process_batches(
batches_to_download = await asyncio.gather(*status_tasks)
batches_to_download = filter(None, batches_to_download)
            # update progress bar
self.request_pbar.n = self.tracker.n_finished_or_downloaded_requests
self.request_pbar.refresh()
@@ -937,11 +980,15 @@ async def delete_file(self, file_id: str, semaphore: asyncio.Semaphore):
semaphore (asyncio.Semaphore): Semaphore to limit concurrent operations
"""
async with semaphore:
- delete_response = await self.client.files.delete(file_id)
- if delete_response.deleted:
- logger.debug(f"Deleted file {file_id}")
- else:
- logger.warning(f"Failed to delete file {file_id}")
+ try:
+ delete_response = await self.client.files.delete(file_id)
+ if delete_response.deleted:
+ logger.debug(f"Deleted file {file_id}")
+ else:
+ logger.warning(f"Failed to delete file {file_id}")
+ except NotFoundError:
+ # This is fine, the file may have been deleted already. Deletion should be best-effort.
+ logger.warning(f"Trying to delete file {file_id} but it was not found.")
async def download_batch(self, batch: Batch) -> str | None:
file_content = None
@@ -1027,3 +1074,18 @@ async def download_batch_to_response_file(
self.tracker.mark_as_downloaded(batch)
return response_file
+
+ @staticmethod
+ def _validate_batch_status(status: str) -> bool:
+ # See https://github.com/openai/openai-python/blob/995cce048f9427bba4f7ac1e5fc60abbf1f8f0b7/src/openai/types/batch.py#L40C1-L41C1
+ # for all possible batch statuses
+ return status in [
+ "completed",
+ "failed",
+ "expired",
+ "cancelled",
+ "validating",
+ "finalizing",
+ "cancelling",
+ "in_progress",
+ ]
diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py
index 14fc27ea..a8416906 100644
--- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py
@@ -79,6 +79,10 @@ def __init__(
top_p: Optional[float] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
+ max_requests_per_minute: Optional[int] = None,
+ max_tokens_per_minute: Optional[int] = None,
+        require_all_responses: Optional[bool] = None,
+ max_retries: Optional[int] = None,
):
super().__init__(
model=model,
@@ -86,43 +90,41 @@ def __init__(
top_p=top_p,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
+ max_requests_per_minute=max_requests_per_minute,
+ max_tokens_per_minute=max_tokens_per_minute,
+ require_all_responses=require_all_responses,
+ max_retries=max_retries,
)
self.url = url
self.api_key = api_key
self.token_encoding = tiktoken.get_encoding(get_token_encoding_name(model))
+ self.header_based_max_requests_per_minute, self.header_based_max_tokens_per_minute = (
+ self.get_header_based_rate_limits()
+ )
- def get_rate_limits(self) -> dict:
+ def get_header_based_rate_limits(self) -> tuple[int, int]:
"""Get rate limits from OpenAI API headers.
Returns:
- dict: Contains 'max_requests_per_minute' and 'max_tokens_per_minute'
+ tuple[int, int]: Contains 'max_requests_per_minute' and 'max_tokens_per_minute'
Note:
- Makes a dummy request to get actual rate limits
- - Falls back to default values if headers are missing
- - Supports both OpenAI and Azure endpoints
"""
+ if not self.api_key:
+ raise ValueError(
+ "Missing OpenAI API Key - Please set OPENAI_API_KEY in your environment vars"
+ )
+
response = requests.post(
self.url,
headers={"Authorization": f"Bearer {self.api_key}"},
json={"model": self.model, "messages": []},
)
-
rpm = int(response.headers.get("x-ratelimit-limit-requests", 0))
tpm = int(response.headers.get("x-ratelimit-limit-tokens", 0))
- if not rpm or not tpm:
- logger.warning("Failed to get rate limits from OpenAI API, using default values")
- rpm = 30_000
- tpm = 150_000_000
-
- logger.info(f"Automatically set max_requests_per_minute to {rpm}")
- logger.info(f"Automatically set max_tokens_per_minute to {tpm}")
-
- return {
- "max_requests_per_minute": rpm,
- "max_tokens_per_minute": tpm,
- }
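+        # A value of 0 means the corresponding rate-limit header was not present in the response.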
+ return rpm, tpm
def estimate_output_tokens(self) -> int:
"""Estimate number of tokens in the response.
@@ -270,7 +272,7 @@ async def call_single_request(
self.url,
headers=request_header,
json=request.api_specific_request,
- timeout=60.0,
+ timeout=self.timeout,
) as response_obj:
response = await response_obj.json()
@@ -281,6 +283,8 @@ async def call_single_request(
status_tracker.time_of_last_rate_limit_error = time.time()
status_tracker.num_rate_limit_errors += 1
status_tracker.num_api_errors -= 1
+            # Decrement num_other_errors so that handle_single_request_with_retries does not double count this error
+ status_tracker.num_other_errors -= 1
raise Exception(f"API error: {error}")
if response_obj.status != 200:
diff --git a/tests/batch/simple_batch.py b/tests/batch/simple_batch.py
index 68fbd38c..251296ae 100644
--- a/tests/batch/simple_batch.py
+++ b/tests/batch/simple_batch.py
@@ -1,4 +1,4 @@
-from bespokelabs.curator import Prompter
+from bespokelabs.curator import LLM
from datasets import Dataset
import logging
import argparse
@@ -13,7 +13,7 @@ def main(args):
dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * args.n_requests})
- prompter = Prompter(
+ prompter = LLM(
prompt_func=lambda row: row["prompt"],
model_name="gpt-4o-mini",
response_format=None,
diff --git a/tests/batch/test_resume.py b/tests/batch/test_resume.py
index 0248da20..9ac0c906 100644
--- a/tests/batch/test_resume.py
+++ b/tests/batch/test_resume.py
@@ -10,6 +10,7 @@
"""
+@pytest.mark.skip(reason="Temporarily disabled, need to add mocking")
@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-resume"))
@pytest.mark.usefixtures("prepare_test_cache")
def test_batch_resume():
diff --git a/tests/batch/test_switch_keys.py b/tests/batch/test_switch_keys.py
index 80eb3984..f1d9fc8b 100644
--- a/tests/batch/test_switch_keys.py
+++ b/tests/batch/test_switch_keys.py
@@ -10,6 +10,7 @@
"""
+@pytest.mark.skip(reason="Temporarily disabled, need to add mocking")
@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-switch-keys"))
@pytest.mark.usefixtures("prepare_test_cache")
def test_batch_switch_keys():
@@ -46,4 +47,4 @@ def test_batch_switch_keys():
print(output2)
# checks
- assert "1 out of 1 batches already downloaded." in output2
+ assert "1 out of 2 batches already downloaded." in output2
diff --git a/tests/cache/different_files/one.py b/tests/cache/different_files/one.py
new file mode 100644
index 00000000..e5667add
--- /dev/null
+++ b/tests/cache/different_files/one.py
@@ -0,0 +1,18 @@
+from bespokelabs.curator import LLM
+from datasets import Dataset
+import logging
+
+logger = logging.getLogger("bespokelabs.curator")
+logger.setLevel(logging.INFO)
+
+
+dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})
+
+prompter = LLM(
+ prompt_func=lambda row: row["prompt"],
+ model_name="gpt-4o-mini",
+ response_format=None,
+)
+
+dataset = prompter(dataset)
+print(dataset.to_pandas())
diff --git a/tests/cache/different_files/two.py b/tests/cache/different_files/two.py
new file mode 100644
index 00000000..e5667add
--- /dev/null
+++ b/tests/cache/different_files/two.py
@@ -0,0 +1,18 @@
+from bespokelabs.curator import LLM
+from datasets import Dataset
+import logging
+
+logger = logging.getLogger("bespokelabs.curator")
+logger.setLevel(logging.INFO)
+
+
+dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})
+
+prompter = LLM(
+ prompt_func=lambda row: row["prompt"],
+ model_name="gpt-4o-mini",
+ response_format=None,
+)
+
+dataset = prompter(dataset)
+print(dataset.to_pandas())
diff --git a/tests/cache/one.py b/tests/cache/one.py
index 090b5b44..10ff74d4 100644
--- a/tests/cache/one.py
+++ b/tests/cache/one.py
@@ -1,4 +1,4 @@
-from bespokelabs.curator import Prompter
+from bespokelabs.curator import LLM
from datasets import Dataset
import logging
import argparse
@@ -10,7 +10,7 @@
def main(delete_cache: bool = False):
dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})
- prompter = Prompter(
+ prompter = LLM(
prompt_func=lambda row: row["prompt"],
model_name="gpt-4o-mini",
response_format=None,
diff --git a/tests/cache/test_different_files.py b/tests/cache/test_different_files.py
index 6b18de07..31fe866b 100644
--- a/tests/cache/test_different_files.py
+++ b/tests/cache/test_different_files.py
@@ -16,17 +16,14 @@ def test_cache_behavior():
# Run one.py twice and check for cache behavior
print("RUNNING ONE.PY")
- output1, _ = run_script(["python", "tests/cache_tests/different_files/one.py"])
- print(output1)
+ output1, _ = run_script(["python", "tests/cache/different_files/one.py"])
assert cache_hit_log not in output1, "First run of one.py should not hit cache"
print("RUNNING ONE.PY AGAIN")
- output2, _ = run_script(["python", "tests/cache_tests/different_files/one.py"])
- print(output2)
+ output2, _ = run_script(["python", "tests/cache/different_files/one.py"])
assert cache_hit_log in output2, "Second run of one.py should hit cache"
# Run two.py and check for cache behavior
print("RUNNING TWO.PY")
- output3, _ = run_script(["python", "tests/cache_tests/different_files/two.py"])
- print(output3)
+ output3, _ = run_script(["python", "tests/cache/different_files/two.py"])
assert cache_hit_log in output3, "First run of two.py should hit cache"
diff --git a/tests/cache/two.py b/tests/cache/two.py
index 090b5b44..10ff74d4 100644
--- a/tests/cache/two.py
+++ b/tests/cache/two.py
@@ -1,4 +1,4 @@
-from bespokelabs.curator import Prompter
+from bespokelabs.curator import LLM
from datasets import Dataset
import logging
import argparse
@@ -10,7 +10,7 @@
def main(delete_cache: bool = False):
dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})
- prompter = Prompter(
+ prompter = LLM(
prompt_func=lambda row: row["prompt"],
model_name="gpt-4o-mini",
response_format=None,
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..012b8dc6
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,5 @@
+import pytest
+
+
+def pytest_configure(config):
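+    """Register the custom cache_dir marker so pytest does not warn that it is unknown."""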
+ config.addinivalue_line("markers", "cache_dir(path): mark test to use specific cache directory")
diff --git a/tests/litellm/__init__.py b/tests/litellm/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/litellm/test_models.py b/tests/litellm/test_models.py
deleted file mode 100644
index 05bb9b7c..00000000
--- a/tests/litellm/test_models.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import pytest
-import os
-import logging
-from datasets import Dataset
-from bespokelabs.curator import Prompter
-from tests.helpers import prepare_test_cache
-
-"""
-USAGE:
-pytest -s tests/litellm/test_models.py
-"""
-
-
-@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-models"))
-@pytest.mark.usefixtures("prepare_test_cache")
-def test_litellm_models():
-
- env = os.environ.copy()
- assert "ANTHROPIC_API_KEY" in env, "ANTHROPIC_API_KEY must be set"
- assert "OPENAI_API_KEY" in env, "OPENAI_API_KEY must be set"
- assert "GEMINI_API_KEY" in env, "GEMINI_API_KEY must be set"
- assert "TOGETHER_API_KEY" in env, "TOGETHER_API_KEY must be set"
-
- models_list = [
- "claude-3-5-sonnet-20240620", # https://docs.litellm.ai/docs/providers/anthropic # anthropic has a different hidden param tokens structure.
- "claude-3-5-haiku-20241022",
- "claude-3-haiku-20240307",
- "claude-3-opus-20240229",
- "claude-3-sonnet-20240229",
- "gpt-4o-mini", # https://docs.litellm.ai/docs/providers/openai
- "gpt-4o-2024-08-06",
- "gpt-4-0125-preview",
- "gpt-3.5-turbo-1106",
- "gemini/gemini-1.5-flash", # https://docs.litellm.ai/docs/providers/gemini; https://ai.google.dev/gemini-api/docs/models # 20-30 iter/s
- "gemini/gemini-1.5-pro", # 20-30 iter/s
- "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", # https://docs.together.ai/docs/serverless-models
- "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
- ]
-
- for model in models_list:
- print(f"\n\n========== TESTING {model} ==========\n\n")
- logger = logging.getLogger("bespokelabs.curator")
- logger.setLevel(logging.DEBUG)
-
- dataset = Dataset.from_dict({"prompt": ["just say 'hi'"]})
-
- prompter = Prompter(
- prompt_func=lambda row: row["prompt"],
- model_name=model,
- response_format=None,
- backend="litellm",
- )
-
- dataset = prompter(dataset)
- print(dataset.to_pandas())
diff --git a/tests/simple_online.py b/tests/simple_online.py
new file mode 100644
index 00000000..4d5f90df
--- /dev/null
+++ b/tests/simple_online.py
@@ -0,0 +1,54 @@
+from bespokelabs.curator import LLM
+from datasets import Dataset
+import logging
+import argparse
+
+# python tests/simple_online.py --log-level DEBUG --model claude-3-5-haiku-20241022
+
+
+def main(args):
+ if args.log_level is not None:
+ logger = logging.getLogger("bespokelabs.curator")
+ logger.setLevel(args.log_level)
+
+ dataset = Dataset.from_dict({"prompt": ["write me a poem"] * args.n_requests})
+
+ prompter = LLM(
+ prompt_func=lambda row: row["prompt"],
+ model_name=args.model,
+ max_requests_per_minute=args.max_requests_per_minute,
+ max_tokens_per_minute=args.max_tokens_per_minute,
+ max_retries=args.max_retries,
+ require_all_responses=not args.partial_responses,
+ )
+
+ dataset = prompter(dataset, batch_cancel=args.cancel)
+ print(dataset.to_pandas())
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Simple batch test bed")
+ parser.add_argument("--cancel", action="store_true", default=False, help="Cancel the batches")
+ parser.add_argument("--n-requests", type=int, help="Number of requests to process", default=3)
+ parser.add_argument(
+ "--log-level",
+ type=lambda x: getattr(logging, x.upper()),
+ default=None,
+ help="Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)",
+ )
+ parser.add_argument("--model", type=str, help="Model to use", default="gemini/gemini-1.5-flash")
+ parser.add_argument(
+ "--max-requests-per-minute", type=int, help="Max requests per minute", default=None
+ )
+ parser.add_argument(
+ "--max-tokens-per-minute", type=int, help="Max tokens per minute", default=None
+ )
+ parser.add_argument("--max-retries", type=int, help="Max retries", default=None)
+ parser.add_argument(
+ "--partial-responses",
+ action="store_true",
+ default=False,
+ help="Require all responses",
+ )
+ args = parser.parse_args()
+ main(args)
diff --git a/tests/test_caching.py b/tests/test_caching.py
index 73803465..15c3ebd6 100644
--- a/tests/test_caching.py
+++ b/tests/test_caching.py
@@ -1,6 +1,6 @@
from datasets import Dataset
-from bespokelabs.curator import Prompter
+from bespokelabs import curator
def test_same_value_caching(tmp_path):
@@ -13,7 +13,7 @@ def test_same_value_caching(tmp_path):
def prompt_func():
return f"Say '1'. Do not explain."
- prompter = Prompter(
+ prompter = curator.LLM(
prompt_func=prompt_func,
model_name="gpt-4o-mini",
)
@@ -36,7 +36,7 @@ def test_different_values_caching(tmp_path):
def prompt_func():
return f"Say '{x}'. Do not explain."
- prompter = Prompter(
+ prompter = curator.LLM(
prompt_func=prompt_func,
model_name="gpt-4o-mini",
)
@@ -52,7 +52,7 @@ def prompt_func():
def test_same_dataset_caching(tmp_path):
"""Test that using the same dataset multiple times uses cache."""
dataset = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}])
- prompter = Prompter(
+ prompter = curator.LLM(
prompt_func=lambda x: x["instruction"],
model_name="gpt-4o-mini",
)
@@ -72,7 +72,7 @@ def test_different_dataset_caching(tmp_path):
"""Test that using different datasets creates different cache entries."""
dataset1 = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}])
dataset2 = Dataset.from_list([{"instruction": "Say '2'. Do not explain."}])
- prompter = Prompter(
+ prompter = curator.LLM(
prompt_func=lambda x: x["instruction"],
model_name="gpt-4o-mini",
)
@@ -97,7 +97,7 @@ def value_generator():
def prompt_func():
return f"Say '{value_generator()}'. Do not explain."
- prompter = Prompter(
+ prompter = curator.LLM(
prompt_func=prompt_func,
model_name="gpt-4o-mini",
)
@@ -113,3 +113,95 @@ def value_generator():
# Count cache directories, excluding metadata.db
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
assert len(cache_dirs) == 2, f"Expected 2 cache directory but found {len(cache_dirs)}"
+
+
+def test_function_hash_dir_change():
+ """Test that identical functions in different directories but same base filename produce the same hash."""
+ import logging
+ import os
+ import sys
+ import tempfile
+ from pathlib import Path
+
+ from bespokelabs.curator.llm.llm import _get_function_hash
+
+ # Set up logging to write to a file in the current directory
+ debug_log = Path("function_debug.log")
+ logging.basicConfig(
+ level=logging.DEBUG, format="%(message)s", filename=str(debug_log), filemode="w"
+ )
+ logger = logging.getLogger(__name__)
+
+ def dump_function_details(func, prefix):
+ """Helper to dump all function details."""
+ print(f"\n{prefix} details:") # Print to stdout as well
+ logger.debug(f"\n{prefix} details:")
+ # Basic attributes
+ details = {
+ "__name__": func.__name__,
+ "__module__": func.__module__,
+ "__qualname__": func.__qualname__,
+ "__code__.co_filename": func.__code__.co_filename,
+ "__code__.co_name": func.__code__.co_name,
+ "__code__.co_firstlineno": func.__code__.co_firstlineno,
+ "__code__.co_consts": func.__code__.co_consts,
+ "__code__.co_names": func.__code__.co_names,
+ "__code__.co_varnames": func.__code__.co_varnames,
+ "__code__.co_code": func.__code__.co_code.hex(),
+ "__code__.co_flags": func.__code__.co_flags,
+ "__code__.co_stacksize": func.__code__.co_stacksize,
+ "__code__.co_freevars": func.__code__.co_freevars,
+ "__code__.co_cellvars": func.__code__.co_cellvars,
+ "__globals__ keys": sorted(func.__globals__.keys()),
+ "__closure__": func.__closure__,
+ "__defaults__": func.__defaults__,
+ "__kwdefaults__": func.__kwdefaults__,
+ }
+
+ for key, value in details.items():
+ msg = f" {key}: {value}"
+ print(msg) # Print to stdout
+ logger.debug(msg) # Log to file
+
+ def create_function(name, tmp_path):
+ # Create a temporary file with a function definition
+ path = tmp_path / f"{name}.py"
+ with open(path, "w") as f:
+ f.write(
+ """
+def test_func():
+ x = 42 # Add a constant
+ y = "Hello" # Add a string constant
+ z = [1, 2, 3] # Add a list constant
+ return f"{y}, {x}! {z}" # Use all constants
+"""
+ )
+
+ # Import the function from the file
+ import importlib.util
+
+ spec = importlib.util.spec_from_file_location(name, path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module.test_func
+
+    # Create two identical functions with the same base filename in different directories
+    with tempfile.TemporaryDirectory() as tmp_dir1, tempfile.TemporaryDirectory() as tmp_dir2:
+        func1 = create_function("module1", Path(tmp_dir1))
+        func2 = create_function("module1", Path(tmp_dir2))
+
+ # Dump detailed information about both functions
+ dump_function_details(func1, "Function 1")
+ dump_function_details(func2, "Function 2")
+
+ # Both should produce the same hash
+ hash1 = _get_function_hash(func1)
+ hash2 = _get_function_hash(func2)
+ print(f"\nHash comparison:") # Print to stdout
+ print(f" hash1: {hash1}")
+ print(f" hash2: {hash2}")
+ logger.debug(f"\nHash comparison:")
+ logger.debug(f" hash1: {hash1}")
+ logger.debug(f" hash2: {hash2}")
+
+ assert hash1 == hash2, "Identical functions should produce the same hash"
diff --git a/tests/test_litellm_models.py b/tests/test_litellm_models.py
new file mode 100644
index 00000000..972848c9
--- /dev/null
+++ b/tests/test_litellm_models.py
@@ -0,0 +1,64 @@
+import pytest
+import os
+import logging
+from datasets import Dataset
+from bespokelabs.curator import LLM
+from tests.helpers import prepare_test_cache
+
+"""
+USAGE:
+pytest -s tests/test_litellm_models.py
+"""
+
+
+@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-models"))
+@pytest.mark.usefixtures("prepare_test_cache")
+class TestLiteLLMModels:
+ @pytest.fixture(autouse=True)
+ def check_environment(self):
+ env = os.environ.copy()
+ required_keys = [
+ "ANTHROPIC_API_KEY",
+ "OPENAI_API_KEY",
+ "GEMINI_API_KEY",
+ "TOGETHER_API_KEY",
+ ]
+ for key in required_keys:
+ assert key in env, f"{key} must be set"
+
+ @pytest.mark.parametrize(
+ "model",
+ [
+ pytest.param("claude-3-5-sonnet-20240620", id="claude-3-5-sonnet"),
+ pytest.param("claude-3-5-haiku-20241022", id="claude-3-5-haiku"),
+ pytest.param("claude-3-haiku-20240307", id="claude-3-haiku"),
+ pytest.param("claude-3-opus-20240229", id="claude-3-opus"),
+ pytest.param("claude-3-sonnet-20240229", id="claude-3-sonnet"),
+ pytest.param("gpt-4o-mini", id="gpt-4-mini"),
+ pytest.param("gpt-4o-2024-08-06", id="gpt-4"),
+ pytest.param("gpt-4-0125-preview", id="gpt-4-preview"),
+ pytest.param("gpt-3.5-turbo-1106", id="gpt-3.5"),
+ pytest.param("gemini/gemini-1.5-flash", id="gemini-flash"),
+ pytest.param("gemini/gemini-1.5-pro", id="gemini-pro"),
+ pytest.param("together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", id="llama-8b"),
+ pytest.param(
+ "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", id="llama-70b"
+ ),
+ ],
+ )
+ def test_model(self, model):
+ print(f"\n\n========== TESTING {model} ==========\n\n")
+ logger = logging.getLogger("bespokelabs.curator")
+ logger.setLevel(logging.DEBUG)
+
+ dataset = Dataset.from_dict({"prompt": ["just say 'hi'"]})
+
+ prompter = LLM(
+ prompt_func=lambda row: row["prompt"],
+ model_name=model,
+ response_format=None,
+ backend="litellm",
+ )
+
+ dataset = prompter(dataset)
+ print(dataset.to_pandas())
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index f1c327cd..84c2d640 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -1,11 +1,12 @@
import os
from typing import Optional
+from unittest.mock import patch, MagicMock
import pytest
from datasets import Dataset
from pydantic import BaseModel
-from bespokelabs.curator import Prompter
+from bespokelabs.curator import LLM
class MockResponseFormat(BaseModel):
@@ -16,7 +17,7 @@ class MockResponseFormat(BaseModel):
@pytest.fixture
-def prompter() -> Prompter:
+def prompter() -> LLM:
"""Create a Prompter instance for testing.
Returns:
@@ -24,12 +25,18 @@ def prompter() -> Prompter:
"""
def prompt_func(row):
- return {
- "user_prompt": f"Context: {row['context']} Answer this question: {row['question']}",
- "system_prompt": "You are a helpful assistant.",
- }
+ return [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "role": "user",
+ "content": f"Context: {row['context']} Answer this question: {row['question']}",
+ },
+ ]
- return Prompter(
+ return LLM(
model_name="gpt-4o-mini",
prompt_func=prompt_func,
response_format=MockResponseFormat,
@@ -37,7 +44,7 @@ def prompt_func(row):
@pytest.mark.test
-def test_completions(prompter: Prompter, tmp_path):
+def test_completions(prompter: LLM, tmp_path):
"""Test that completions processes a dataset correctly.
Args:
@@ -54,17 +61,31 @@ def test_completions(prompter: Prompter, tmp_path):
# Set up temporary cache directory
os.environ["BELLA_CACHE_DIR"] = str(tmp_path)
- result_dataset = prompter(dataset)
- result_dataset = result_dataset.to_huggingface()
+ # Mock OpenAI API response
+ mock_response = {
+ "choices": [{"message": {"content": "1 + 1 equals 2."}, "finish_reason": "stop"}]
+ }
+
+ with patch("openai.resources.chat.completions.Completions.create", return_value=mock_response):
+ # Process dataset and get responses
+ result_dataset = prompter(dataset)
- # Assertions
- assert len(result_dataset) == len(dataset)
- assert "message" in result_dataset.column_names
- assert "confidence" in result_dataset.column_names
+ # Verify the dataset structure
+ assert len(result_dataset) == len(dataset)
+ assert "response" in result_dataset.column_names
+ # Check that each response has the required fields
+ for row in result_dataset:
+ response = row["response"]
+ if isinstance(response, dict):
+ assert "message" in response
+ assert "confidence" in response
+ else:
+ assert hasattr(response, "message")
+ assert hasattr(response, "confidence")
@pytest.mark.test
-def test_single_completion_batch(prompter: Prompter):
+def test_single_completion_batch(prompter: LLM):
"""Test that a single completion works with batch=True.
Args:
@@ -84,24 +105,36 @@ def simple_prompt_func():
},
]
- batch_prompter = Prompter(
+ batch_prompter = LLM(
model_name="gpt-4o-mini",
prompt_func=simple_prompt_func,
response_format=MockResponseFormat,
batch=True,
)
- # Get single completion
- result = batch_prompter()
+ # Mock response data
+ mock_dataset = Dataset.from_list(
+ [{"response": {"message": "This is a test message.", "confidence": 0.9}}]
+ )
+
+ # Mock the run method of OpenAIBatchRequestProcessor
+ with patch(
+ "bespokelabs.curator.request_processor.openai_batch_request_processor.OpenAIBatchRequestProcessor.run",
+ return_value=mock_dataset,
+ ):
+ # Get single completion
+ result = batch_prompter()
- # Assertions
- assert isinstance(result, MockResponseFormat)
- assert hasattr(result, "message")
- assert hasattr(result, "confidence")
+ # Assertions
+ assert isinstance(result, Dataset)
+ assert len(result) == 1
+ assert isinstance(result[0]["response"], dict)
+ assert result[0]["response"]["message"] == "This is a test message."
+ assert result[0]["response"]["confidence"] == 0.9
@pytest.mark.test
-def test_single_completion_no_batch(prompter: Prompter):
+def test_single_completion_no_batch(prompter: LLM):
"""Test that a single completion works without batch parameter.
Args:
@@ -121,16 +154,28 @@ def simple_prompt_func():
},
]
- non_batch_prompter = Prompter(
+ non_batch_prompter = LLM(
model_name="gpt-4o-mini",
prompt_func=simple_prompt_func,
response_format=MockResponseFormat,
)
- # Get single completion
- result = non_batch_prompter()
+ # Mock response data
+ mock_dataset = Dataset.from_list(
+ [{"response": {"message": "This is a test message.", "confidence": 0.9}}]
+ )
- # Assertions
- assert isinstance(result, MockResponseFormat)
- assert hasattr(result, "message")
- assert hasattr(result, "confidence")
+ # Mock the run method of OpenAIOnlineRequestProcessor
+ with patch(
+ "bespokelabs.curator.request_processor.openai_online_request_processor.OpenAIOnlineRequestProcessor.run",
+ return_value=mock_dataset,
+ ):
+ # Get single completion
+ result = non_batch_prompter()
+
+ # Assertions
+ assert isinstance(result, Dataset)
+ assert len(result) == 1
+ assert isinstance(result[0]["response"], dict)
+ assert result[0]["response"]["message"] == "This is a test message."
+ assert result[0]["response"]["confidence"] == 0.9