diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..45bb9615 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,21 @@ +name: Python Linting + +on: [push, pull_request] + +jobs: + PythonLinting: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Install dependencies + run: | + pip install poetry + poetry install + - name: Run black + run: | + poetry run black --check . diff --git a/examples/distill.py b/examples/distill.py index 7fc785cb..20cfd53b 100644 --- a/examples/distill.py +++ b/examples/distill.py @@ -1,7 +1,9 @@ -from bespokelabs import curator -from datasets import load_dataset import logging +from datasets import load_dataset + +from bespokelabs import curator + dataset = load_dataset("allenai/WildChat", split="train") dataset = dataset.select(range(3_000)) diff --git a/examples/poem.py b/examples/poem.py index 5697e5e2..ffb8c5a5 100644 --- a/examples/poem.py +++ b/examples/poem.py @@ -2,10 +2,12 @@ We generate 10 diverse topics and then generate 2 poems for each topic.""" -from bespokelabs import curator +from typing import List + from datasets import Dataset from pydantic import BaseModel, Field -from typing import List + +from bespokelabs import curator # We use Pydantic and structured outputs to define the format of the response. @@ -41,9 +43,7 @@ class Poems(BaseModel): model_name="gpt-4o-mini", response_format=Poems, # `row` is the input row, and `poems` is the Poems class which is parsed from the structured output from the LLM. - parse_func=lambda row, poems: [ - {"topic": row["topic"], "poem": p} for p in poems.poems_list - ], + parse_func=lambda row, poems: [{"topic": row["topic"], "poem": p} for p in poems.poems_list], ) # We apply the prompter to the topics dataset. diff --git a/poetry.lock b/poetry.lock index f726bb06..13174123 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiofiles" @@ -1135,6 +1135,20 @@ mistralai = ["mistralai (>=1.0.3,<2.0.0)"] test-docs = ["anthropic (>=0.36.2,<0.38.0)", "cohere (>=5.1.8,<6.0.0)", "diskcache (>=5.6.3,<6.0.0)", "fastapi (>=0.109.2,<0.116.0)", "groq (>=0.4.2,<0.12.0)", "litellm (>=1.35.31,<2.0.0)", "mistralai (>=1.0.3,<2.0.0)", "pandas (>=2.2.0,<3.0.0)", "pydantic_extra_types (>=2.6.0,<3.0.0)", "redis (>=5.0.1,<6.0.0)", "tabulate (>=0.9.0,<0.10.0)"] vertexai = ["google-cloud-aiplatform (>=1.53.0,<2.0.0)", "jsonref (>=1.1.0,<2.0.0)"] +[[package]] +name = "isort" +version = "5.13.2" +description = "A Python utility / library to sort Python imports." 
+optional = false +python-versions = ">=3.8.0" +files = [ + {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, + {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, +] + +[package.extras] +colors = ["colorama (>=0.4.6)"] + [[package]] name = "jaraco-classes" version = "3.4.0" @@ -3583,4 +3597,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "f6b5a294e6105fa990fee6139aee98bd03335063a2932f71e152f5de2b599074" +content-hash = "3604f19ac9d9dd28454528f2623f2b638bbd985d12810f4d99934d2bd11a3294" diff --git a/pyproject.toml b/pyproject.toml index 6f5d597a..0e622361 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ tiktoken = "^0.8.0" nest-asyncio = "^1.6.0" rich = "^13.7.0" litellm = "^1.52.11" +isort = "^5.13.2" [tool.poetry.group.dev.dependencies] black = "^24.2.0" @@ -47,4 +48,4 @@ build-backend = "poetry.core.masonry.api" curator-viewer = "bespokelabs.curator.viewer.__main__:main" [tool.black] -line-length = 80 +line-length = 100 diff --git a/src/bespokelabs/__init__.py b/src/bespokelabs/__init__.py index e89e45ee..f7b99017 100644 --- a/src/bespokelabs/__init__.py +++ b/src/bespokelabs/__init__.py @@ -3,9 +3,7 @@ logger = logging.getLogger("bespokelabs.curator") handler = logging.StreamHandler() -formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) logger.setLevel(logging.WARNING) diff --git a/src/bespokelabs/curator/__init__.py b/src/bespokelabs/curator/__init__.py index 37ec5dbb..bb0b7aa2 100644 --- a/src/bespokelabs/curator/__init__.py +++ b/src/bespokelabs/curator/__init__.py @@ -1,2 +1,2 @@ -from .prompter.prompter import Prompter from .dataset import Dataset +from .prompter.prompter import Prompter diff --git a/src/bespokelabs/curator/dataset.py b/src/bespokelabs/curator/dataset.py index 56b6c63d..b0abece0 100644 --- a/src/bespokelabs/curator/dataset.py +++ b/src/bespokelabs/curator/dataset.py @@ -1,19 +1,16 @@ +import glob import json import logging import os -import glob +from typing import Any, Dict, Iterable, Iterator, List, TypeVar import pandas as pd - -from pydantic import BaseModel from datasets import Dataset as HFDataset from datasets.arrow_writer import ArrowWriter, SchemaInferenceError -from typing import Any, Dict, Iterable, Iterator, List, TypeVar +from pydantic import BaseModel from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter -from bespokelabs.curator.request_processor.generic_response import ( - GenericResponse, -) +from bespokelabs.curator.request_processor.generic_response import GenericResponse T = TypeVar("T") @@ -33,9 +30,7 @@ def from_iterable(iterable: Iterable[Dict[str, Any] | BaseModel]): return Dataset(iterable=iterable) def from_working_dir(working_dir: str, prompt_formatter: PromptFormatter): - return Dataset( - working_dir=working_dir, prompt_formatter=prompt_formatter - ) + return Dataset(working_dir=working_dir, prompt_formatter=prompt_formatter) def __iter__(self) -> Iterator[Dict[str, Any] | BaseModel]: if self.iterable is not None: @@ -48,13 +43,9 @@ def __iter__(self) -> Iterator[Dict[str, Any] | BaseModel]: for line in open(response_file, "r"): response = GenericResponse.model_validate_json(line) if self.prompt_formatter.response_format: - 
response.response = self.prompt_formatter.response_format( - **response.response - ) + response.response = self.prompt_formatter.response_format(**response.response) if self.prompt_formatter.parse_func: - response = self.prompt_formatter.parse_func( - response.row, response.response - ) + response = self.prompt_formatter.parse_func(response.row, response.response) else: response = [response.response] @@ -97,10 +88,8 @@ def to_huggingface(self, in_memory: bool = False) -> None: total_responses_count += 1 response = GenericResponse.model_validate_json(line) if self.prompt_formatter.response_format: - response.response = ( - self.prompt_formatter.response_format( - **response.response - ) + response.response = self.prompt_formatter.response_format( + **response.response ) if response is None: @@ -119,9 +108,7 @@ def to_huggingface(self, in_memory: bool = False) -> None: row = row.model_dump() writer.write(row) - logging.info( - f"Read {total_responses_count} responses, {failed_responses_count} failed" - ) + logging.info(f"Read {total_responses_count} responses, {failed_responses_count} failed") logging.info("Finalizing writer") if failed_responses_count == total_responses_count: diff --git a/src/bespokelabs/curator/install_ui.py b/src/bespokelabs/curator/install_ui.py index b526ed60..746e67f7 100644 --- a/src/bespokelabs/curator/install_ui.py +++ b/src/bespokelabs/curator/install_ui.py @@ -4,22 +4,23 @@ It includes progress tracking, status updates, and a polished success message. """ -import sys import subprocess -from typing import Optional, Tuple +import sys from dataclasses import dataclass from enum import Enum +from typing import Optional, Tuple from rich.console import Console -from rich.text import Text from rich.live import Live -from rich.spinner import Spinner from rich.panel import Panel from rich.progress import ProgressBar +from rich.spinner import Spinner +from rich.text import Text class InstallationStage(Enum): """Enum representing different stages of the installation process.""" + PREPARING = ("Preparing your environment...", 0.0) COLLECTING = ("Downloading packages...", 0.2) DOWNLOADING = ("Downloading packages...", 0.4) @@ -35,9 +36,10 @@ def __init__(self, message: str, progress: float): @dataclass class InstallationUI: """Class to manage the installation UI components and styling.""" + package_name: str console: Console = Console() - + def create_progress_bar(self, completed: float = 0) -> Text: """Create a stylish progress bar with the given completion percentage.""" width = 40 @@ -65,25 +67,33 @@ def create_loading_text(self, stage: InstallationStage, progress: float) -> Text ("Your synthetic data journey begins in moments", "dim white"), self.create_progress_bar(progress), ("\n ", ""), - (stage.message, "italic dim white") + (stage.message, "italic dim white"), ) def create_success_text(self) -> Text: """Create the success message with links.""" text = Text() text.append("✨ Curator installed successfully!\n\n", style="bold green") - text.append("Start building production-ready synthetic data pipelines:\n\n", style="dim white") + text.append( + "Start building production-ready synthetic data pipelines:\n\n", style="dim white" + ) text.append(" 📚 ", style="") text.append("docs.bespokelabs.ai", style="dim cyan link https://docs.bespokelabs.ai") text.append("\n 📦 ", style="") - text.append("github.com/bespokelabsai/curator", style="dim cyan link https://github.com/bespokelabsai/curator") + text.append( + "github.com/bespokelabsai/curator", + style="dim cyan link 
https://github.com/bespokelabsai/curator", + ) text.append("\n 💬 ", style="") - text.append("discord.gg/KqpXvpzVBS", style="dim cyan link https://discord.com/invite/KqpXvpzVBS") + text.append( + "discord.gg/KqpXvpzVBS", style="dim cyan link https://discord.com/invite/KqpXvpzVBS" + ) return text class PackageInstaller: """Class to handle the package installation process.""" + def __init__(self, package_name: str, version: Optional[str] = None): self.package_spec = f"{package_name}=={version}" if version else package_name self.ui = InstallationUI(package_name) @@ -96,13 +106,13 @@ def run_pip_install(self) -> subprocess.Popen: stderr=subprocess.PIPE, text=True, bufsize=1, - universal_newlines=True + universal_newlines=True, ) def parse_pip_output(self, line: str) -> Tuple[InstallationStage, float]: """Parse pip output to determine installation stage and progress.""" line = line.strip().lower() - + if "collecting" in line: return InstallationStage.COLLECTING, InstallationStage.COLLECTING.progress elif "downloading" in line: @@ -118,32 +128,30 @@ def parse_pip_output(self, line: str) -> Tuple[InstallationStage, float]: return InstallationStage.INSTALLING, InstallationStage.INSTALLING.progress elif "successfully installed" in line: return InstallationStage.FINALIZING, InstallationStage.FINALIZING.progress - + return InstallationStage.PREPARING, InstallationStage.PREPARING.progress def install(self) -> None: """Execute the installation with progress tracking and UI updates.""" - spinner = Spinner("dots2", text=self.ui.create_loading_text(InstallationStage.PREPARING, 0), style="green") - - with Live( - spinner, - console=self.ui.console, - refresh_per_second=30 - ) as live: + spinner = Spinner( + "dots2", text=self.ui.create_loading_text(InstallationStage.PREPARING, 0), style="green" + ) + + with Live(spinner, console=self.ui.console, refresh_per_second=30) as live: try: process = self.run_pip_install() - + while True: output_line = process.stdout.readline() - if output_line == '' and process.poll() is not None: + if output_line == "" and process.poll() is not None: break - + stage, progress = self.parse_pip_output(output_line) spinner.text = self.ui.create_loading_text(stage, progress) - + # Show completion spinner.text = self.ui.create_loading_text(InstallationStage.COMPLETE, 1.0) - + if process.poll() == 0: live.update(self.ui.create_success_text()) else: @@ -151,19 +159,19 @@ def install(self) -> None: error_text = Text(error, style="red") live.update(error_text) sys.exit(1) - + except Exception as e: error_text = Text(f"Error: {str(e)}", style="red") live.update(error_text) sys.exit(1) - + self.ui.console.print() def enhanced_install(package_name: str, version: Optional[str] = None) -> None: """ Enhance pip installation with a professional progress UI. 
- + Args: package_name: Name of the package to install version: Optional specific version to install diff --git a/src/bespokelabs/curator/prompter/prompt_formatter.py b/src/bespokelabs/curator/prompter/prompt_formatter.py index 40b26e2a..5682c978 100644 --- a/src/bespokelabs/curator/prompter/prompt_formatter.py +++ b/src/bespokelabs/curator/prompter/prompt_formatter.py @@ -25,9 +25,7 @@ class PromptFormatter: def __init__( self, model_name: str, - prompt_func: Callable[ - [Union[Dict[str, Any], BaseModel]], Dict[str, str] - ], + prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]], parse_func: Optional[ Callable[ [ @@ -44,9 +42,7 @@ def __init__( self.parse_func = parse_func self.response_format = response_format - def create_generic_request( - self, row: Dict[str, Any] | BaseModel, idx: int - ) -> GenericRequest: + def create_generic_request(self, row: Dict[str, Any] | BaseModel, idx: int) -> GenericRequest: """Format the request object based off Prompter attributes.""" sig = inspect.signature(self.prompt_func) if len(sig.parameters) == 0: @@ -54,9 +50,7 @@ def create_generic_request( elif len(sig.parameters) == 1: prompts = self.prompt_func(row) else: - raise ValueError( - f"Prompting function {self.prompt_func} must have 0 or 1 arguments." - ) + raise ValueError(f"Prompting function {self.prompt_func} must have 0 or 1 arguments.") if isinstance(prompts, str): messages = [{"role": "user", "content": prompts}] @@ -74,8 +68,6 @@ def create_generic_request( original_row=row, original_row_idx=idx, response_format=( - self.response_format.model_json_schema() - if self.response_format - else None + self.response_format.model_json_schema() if self.response_format else None ), ) diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 4f3d7e46..9a3c705e 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -1,20 +1,20 @@ """Curator: Bespoke Labs Synthetic Data Generation Library.""" import inspect +import logging import os from datetime import datetime +from io import BytesIO from typing import Any, Callable, Dict, Iterable, Optional, Type, TypeVar, Union +import dill from datasets import Dataset from pydantic import BaseModel from xxhash import xxh64 -import logging from bespokelabs.curator.db import MetadataDB from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter -from bespokelabs.curator.request_processor.base_request_processor import ( - BaseRequestProcessor, -) +from bespokelabs.curator.request_processor.base_request_processor import BaseRequestProcessor from bespokelabs.curator.request_processor.openai_batch_request_processor import ( OpenAIBatchRequestProcessor, ) @@ -35,9 +35,7 @@ class Prompter: def __init__( self, model_name: str, - prompt_func: Callable[ - [Union[Dict[str, Any], BaseModel]], Dict[str, str] - ], + prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]], parse_func: Optional[ Callable[ [ @@ -133,9 +131,7 @@ def __init__( else: raise ValueError(f"Unknown backend: {backend}") - def __call__( - self, dataset: Optional[Iterable] = None, working_dir: str = None - ) -> Dataset: + def __call__(self, dataset: Optional[Iterable] = None, working_dir: str = None) -> Dataset: """ Run completions on a dataset. 
@@ -179,11 +175,7 @@ def _completions( else: curator_cache_dir = working_dir - dataset_hash = ( - dataset._fingerprint - if dataset is not None - else xxh64("").hexdigest() - ) + dataset_hash = dataset._fingerprint if dataset is not None else xxh64("").hexdigest() prompt_func_hash = _get_function_hash(self.prompt_formatter.prompt_func) @@ -211,13 +203,9 @@ def _completions( metadata_db = MetadataDB(metadata_db_path) # Get the source code of the prompt function - prompt_func_source = _get_function_source( - self.prompt_formatter.prompt_func - ) + prompt_func_source = _get_function_source(self.prompt_formatter.prompt_func) if self.prompt_formatter.parse_func is not None: - parse_func_source = _get_function_source( - self.prompt_formatter.parse_func - ) + parse_func_source = _get_function_source(self.prompt_formatter.parse_func) else: parse_func_source = "" @@ -252,7 +240,9 @@ def _get_function_hash(func) -> str: if func is None: return xxh64("").hexdigest() - return xxh64(_get_function_source(func)).hexdigest() + file = BytesIO() + dill.Pickler(file, recurse=True).dump(func) + return xxh64(file.getvalue()).hexdigest() def _get_function_source(func) -> str: diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index dcc344b7..d1f0b4e9 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -16,9 +16,7 @@ from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.event_loop import run_in_event_loop from bespokelabs.curator.request_processor.generic_request import GenericRequest -from bespokelabs.curator.request_processor.generic_response import ( - GenericResponse, -) +from bespokelabs.curator.request_processor.generic_response import GenericResponse logger = logging.getLogger(__name__) @@ -42,9 +40,7 @@ def get_rate_limits(self) -> dict: pass @abstractmethod - def create_api_specific_request( - self, generic_request: GenericRequest - ) -> dict: + def create_api_specific_request(self, generic_request: GenericRequest) -> dict: """ Creates a API-specific request body from a GenericRequest. 
@@ -115,9 +111,7 @@ def create_request_files( num_jobs = i + 1 if num_jobs > 0: - logger.info( - f"There are {num_jobs} existing requests in {requests_files[0]}" - ) + logger.info(f"There are {num_jobs} existing requests in {requests_files[0]}") logger.info( f"Example request in {requests_files[0]}:\n{json.dumps(first_job, default=str, indent=2)}" ) @@ -129,19 +123,13 @@ def create_request_files( if dataset is None: with open(requests_file, "w") as f: - generic_request = prompt_formatter.create_generic_request( - dict(), 0 - ) - f.write( - json.dumps(generic_request.model_dump(), default=str) + "\n" - ) + generic_request = prompt_formatter.create_generic_request(dict(), 0) + f.write(json.dumps(generic_request.model_dump(), default=str) + "\n") return requests_files if self.batch_size: num_batches = ceil(len(dataset) / self.batch_size) - requests_files = [ - f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches) - ] + requests_files = [f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)] async def create_all_request_files(): tasks = [ @@ -157,11 +145,7 @@ async def create_all_request_files(): run_in_event_loop(create_all_request_files()) else: - run_in_event_loop( - self.acreate_request_file( - dataset, prompt_formatter, requests_file - ) - ) + run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, requests_file)) return requests_files @@ -184,12 +168,8 @@ async def acreate_request_file( for idx, dataset_row in enumerate(dataset): dataset_row_idx = idx + start_idx # Get the generic request from the map function - request = prompt_formatter.create_generic_request( - dataset_row, dataset_row_idx - ) - await f.write( - json.dumps(request.model_dump(), default=str) + "\n" - ) + request = prompt_formatter.create_generic_request(dataset_row, dataset_row_idx) + await f.write(json.dumps(request.model_dump(), default=str) + "\n") logger.info(f"Wrote {end_idx - start_idx} requests to {request_file}.") def create_dataset_files( @@ -248,9 +228,7 @@ def create_dataset_files( with open(responses_file, "r") as f_in: for generic_response_string in f_in: total_responses_count += 1 - response = GenericResponse.model_validate_json( - generic_response_string - ) + response = GenericResponse.model_validate_json(generic_response_string) # response.response_errors is not None IFF response.response_message is None if response.response_errors is not None: @@ -261,10 +239,8 @@ def create_dataset_files( # Response message is a string, which is converted to a dict # The dict is then used to construct the response_format Pydantic model try: - response.response_message = ( - prompt_formatter.response_format( - **response.response_message - ) + response.response_message = prompt_formatter.response_format( + **response.response_message ) except ValidationError as e: schema_str = json.dumps( @@ -287,17 +263,13 @@ def create_dataset_files( response.response_message, ) except Exception as e: - logger.error( - f"Exception raised in your `parse_func`. {error_help}" - ) + logger.error(f"Exception raised in your `parse_func`. 
{error_help}") os.remove(dataset_file) raise e if not isinstance(dataset_rows, list): dataset_rows = [dataset_rows] else: - dataset_rows = [ - {"response": response.response_message} - ] + dataset_rows = [{"response": response.response_message}] for row in dataset_rows: if isinstance(row, BaseModel): @@ -317,9 +289,7 @@ def create_dataset_files( writer.write(row) - logger.info( - f"Read {total_responses_count} responses, {failed_responses_count} failed" - ) + logger.info(f"Read {total_responses_count} responses, {failed_responses_count} failed") if failed_responses_count == total_responses_count: os.remove(dataset_file) raise ValueError("All requests failed") @@ -345,7 +315,5 @@ def parse_response_message( f"Failed to parse response as JSON: {response_message}, skipping this response." ) response_message = None - response_errors = [ - f"Failed to parse response as JSON: {response_message}" - ] + response_errors = [f"Failed to parse response as JSON: {response_message}"] return response_message, response_errors diff --git a/src/bespokelabs/curator/request_processor/generic_request.py b/src/bespokelabs/curator/request_processor/generic_request.py index a407a12c..1fa23327 100644 --- a/src/bespokelabs/curator/request_processor/generic_request.py +++ b/src/bespokelabs/curator/request_processor/generic_request.py @@ -1,4 +1,5 @@ from typing import Any, Dict, List, Optional, Type + from pydantic import BaseModel """A generic request model for LLM API requests. diff --git a/src/bespokelabs/curator/request_processor/generic_response.py b/src/bespokelabs/curator/request_processor/generic_response.py index ef9b81c0..58471370 100644 --- a/src/bespokelabs/curator/request_processor/generic_response.py +++ b/src/bespokelabs/curator/request_processor/generic_response.py @@ -1,7 +1,9 @@ +import datetime from typing import Any, Dict, List, Optional + from pydantic import BaseModel, Field + from .generic_request import GenericRequest -import datetime """A generic response model for LLM API requests. @@ -23,12 +25,13 @@ class TokenUsage(BaseModel): """Token usage information for an API request. 
- + Attributes: prompt_tokens: Number of tokens in the prompt completion_tokens: Number of tokens in the completion total_tokens: Total number of tokens used """ + prompt_tokens: int completion_tokens: int total_tokens: int @@ -43,4 +46,4 @@ class GenericResponse(BaseModel): created_at: datetime.datetime finished_at: datetime.datetime token_usage: Optional[TokenUsage] = None - response_cost: Optional[float] = None \ No newline at end of file + response_cost: Optional[float] = None diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index 1e0cdc76..e6289ed2 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -1,14 +1,16 @@ import asyncio +import datetime import json import logging import os from dataclasses import dataclass import aiofiles +import litellm from openai import AsyncOpenAI from openai.types import Batch from tqdm import tqdm -import datetime + from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.base_request_processor import ( @@ -18,7 +20,6 @@ parse_response_message, ) from bespokelabs.curator.request_processor.event_loop import run_in_event_loop -import litellm from bespokelabs.curator.request_processor.generic_response import TokenUsage logger = logging.getLogger(__name__) @@ -91,17 +92,13 @@ def get_rate_limits(self) -> dict: else: tpd = model_tpd[self.model] - logger.info( - f"Automatically set max_tokens_per_day to {tpd}, model: {self.model} " - ) + logger.info(f"Automatically set max_tokens_per_day to {tpd}, model: {self.model} ") rate_limits = {"max_tokens_per_day": tpd} return rate_limits - def create_api_specific_request( - self, generic_request: GenericRequest - ) -> dict: + def create_api_specific_request(self, generic_request: GenericRequest) -> dict: """ Creates a API-specific request body from a generic request body. @@ -188,9 +185,7 @@ async def asubmit_batch(self, batch_file: str) -> dict: ) # this let's you upload a file that is larger than 200MB and won't error, so we catch it above - batch_file_upload = await async_client.files.create( - file=file_content, purpose="batch" - ) + batch_file_upload = await async_client.files.create(file=file_content, purpose="batch") logger.info(f"File uploaded: {batch_file_upload}") @@ -202,9 +197,7 @@ async def asubmit_batch(self, batch_file: str) -> dict: "request_file_name": batch_file }, # for downloading the batch to similarly named responses file ) - logger.info( - f"Batch request submitted, received batch object: {batch_object}" - ) + logger.info(f"Batch request submitted, received batch object: {batch_object}") # Explicitly close the client. 
Otherwise we get something like # future: > await async_client.close() @@ -230,9 +223,7 @@ def run( Returns: Dataset: Completed dataset """ - requests_files = self.create_request_files( - dataset, working_dir, prompt_formatter - ) + requests_files = self.create_request_files(dataset, working_dir, prompt_formatter) batch_objects_file = f"{working_dir}/batch_objects.jsonl" # TODO(Ryan): we should have an easy way to cancel all batches in batch_objects.jsonl if the user realized they made a mistake @@ -244,10 +235,7 @@ def run( # upload requests files and submit batches # asyncio gather preserves order async def submit_all_batches(): - tasks = [ - self.asubmit_batch(requests_files[i]) - for i in range(len(requests_files)) - ] + tasks = [self.asubmit_batch(requests_files[i]) for i in range(len(requests_files))] return await asyncio.gather(*tasks) batch_objects = run_in_event_loop(submit_all_batches()) @@ -285,9 +273,7 @@ async def watch_batches(): run_in_event_loop(watch_batches()) - dataset = self.create_dataset_files( - working_dir, parse_func_hash, prompt_formatter - ) + dataset = self.create_dataset_files(working_dir, parse_func_hash, prompt_formatter) return dataset @@ -333,8 +319,7 @@ def __init__( self.batch_objects = [json.loads(line) for line in f] self.batch_ids = [obj["id"] for obj in self.batch_objects] self.batch_id_to_request_file_name = { - obj["id"]: obj["metadata"]["request_file_name"] - for obj in self.batch_objects + obj["id"]: obj["metadata"]["request_file_name"] for obj in self.batch_objects } self.check_interval = check_interval self.working_dir = working_dir @@ -392,18 +377,14 @@ async def check_batch_status(self, batch_id: str) -> Batch | None: logger.warning(f"Unknown batch status: {batch.status}") if batch_returned: - logger.info( - f"Batch {batch.id} returned with status: {batch.status}" - ) + logger.info(f"Batch {batch.id} returned with status: {batch.status}") self.tracker.n_returned_batches += 1 self.tracker.n_completed_returned_requests += n_completed_requests self.tracker.n_failed_returned_requests += n_failed_requests self.remaining_batch_ids.remove(batch.id) return batch else: - self.tracker.n_completed_in_progress_requests += ( - n_completed_requests - ) + self.tracker.n_completed_in_progress_requests += n_completed_requests self.tracker.n_failed_in_progress_requests += n_failed_requests return None @@ -426,8 +407,7 @@ async def watch(self) -> None: # check batch status also updates the tracker status_tasks = [ - self.check_batch_status(batch_id) - for batch_id in self.remaining_batch_ids + self.check_batch_status(batch_id) for batch_id in self.remaining_batch_ids ] batches_to_download = await asyncio.gather(*status_tasks) batches_to_download = filter(None, batches_to_download) @@ -447,10 +427,7 @@ async def watch(self) -> None: # Failed downloads return None and print any errors that occurred all_response_files.extend(await asyncio.gather(*download_tasks)) - if ( - self.tracker.n_returned_batches - < self.tracker.n_submitted_batches - ): + if self.tracker.n_returned_batches < self.tracker.n_submitted_batches: logger.debug( f"Batches returned: {self.tracker.n_returned_batches}/{self.tracker.n_submitted_batches} " f"Requests completed: {pbar.n}/{self.tracker.n_submitted_requests}" @@ -466,9 +443,7 @@ async def watch(self) -> None: "Please check the logs above and https://platform.openai.com/batches for errors." 
) - async def download_batch_to_generic_responses_file( - self, batch: Batch - ) -> str | None: + async def download_batch_to_generic_responses_file(self, batch: Batch) -> str | None: """Download the result of a completed batch to file. Args: @@ -481,9 +456,7 @@ async def download_batch_to_generic_responses_file( file_content = await self.client.files.content(batch.output_file_id) elif batch.status == "failed" and batch.error_file_id: file_content = await self.client.files.content(batch.error_file_id) - logger.warning( - f"Batch {batch.id} failed\n. Errors will be parsed below." - ) + logger.warning(f"Batch {batch.id} failed\n. Errors will be parsed below.") elif batch.status == "failed" and not batch.error_file_id: errors = "\n".join([str(error) for error in batch.errors.data]) logger.error( @@ -514,7 +487,7 @@ async def download_batch_to_generic_responses_file( raw_response = json.loads(raw_response) request_idx = int(raw_response["custom_id"]) generic_request = generic_request_map[request_idx] - + # TODO(Ryan): Add more specific error handling if raw_response["response"]["status_code"] != 200: logger.warning( @@ -531,31 +504,33 @@ async def download_batch_to_generic_responses_file( created_at=request_creation_times[request_idx], finished_at=datetime.datetime.now(), token_usage=None, - response_cost=None + response_cost=None, ) else: response_body = raw_response["response"]["body"] choices = response_body["choices"] usage = response_body.get("usage", {}) - + token_usage = TokenUsage( prompt_tokens=usage.get("prompt_tokens", 0), completion_tokens=usage.get("completion_tokens", 0), - total_tokens=usage.get("total_tokens", 0) + total_tokens=usage.get("total_tokens", 0), ) - + # Calculate cost using litellm cost = litellm.completion_cost( model=generic_request.model, - prompt=str(generic_request.messages), # Convert messages to string for cost calculation - completion=choices[0]["message"]["content"] + prompt=str( + generic_request.messages + ), # Convert messages to string for cost calculation + completion=choices[0]["message"]["content"], ) response_message = choices[0]["message"]["content"] response_message, response_errors = parse_response_message( response_message, self.prompt_formatter.response_format ) - + generic_response = GenericResponse( response_message=response_message, response_errors=response_errors, @@ -565,10 +540,7 @@ async def download_batch_to_generic_responses_file( created_at=request_creation_times[request_idx], finished_at=datetime.datetime.now(), token_usage=token_usage, - response_cost=cost + response_cost=cost, ) - f.write( - json.dumps(generic_response.model_dump(), default=str) - + "\n" - ) + f.write(json.dumps(generic_response.model_dump(), default=str) + "\n") return response_file diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 4cc7f7e0..132ae01a 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -1,16 +1,17 @@ import asyncio +import datetime import json import logging import os import re +import resource import time from dataclasses import dataclass, field from functools import partial from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar -import resource -import datetime import aiohttp +import litellm import requests import tiktoken from tqdm import tqdm @@ -24,7 +25,6 @@ parse_response_message, ) 
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop -import litellm from bespokelabs.curator.request_processor.generic_response import TokenUsage T = TypeVar("T") @@ -77,9 +77,7 @@ def get_rate_limits(self) -> dict: tpm = int(response.headers.get("x-ratelimit-limit-tokens", 0)) if not rpm or not tpm: - logger.warning( - "Failed to get rate limits from OpenAI API, using default values" - ) + logger.warning("Failed to get rate limits from OpenAI API, using default values") rpm = 30_000 tpm = 150_000_000 @@ -93,9 +91,7 @@ def get_rate_limits(self) -> dict: return rate_limits - def create_api_specific_request( - self, generic_request: GenericRequest - ) -> dict: + def create_api_specific_request(self, generic_request: GenericRequest) -> dict: """ Creates a API-specific request body from a generic request body. @@ -151,21 +147,16 @@ def run( Returns: Dataset: Completed dataset """ - generic_requests_files = self.create_request_files( - dataset, working_dir, prompt_formatter - ) + generic_requests_files = self.create_request_files(dataset, working_dir, prompt_formatter) generic_responses_files = [ - f"{working_dir}/responses_{i}.jsonl" - for i in range(len(generic_requests_files)) + f"{working_dir}/responses_{i}.jsonl" for i in range(len(generic_requests_files)) ] rate_limits = self.get_rate_limits() rpm = rate_limits["max_requests_per_minute"] tpm = rate_limits["max_tokens_per_minute"] - token_encoding_name = get_token_encoding_name( - prompt_formatter.model_name - ) + token_encoding_name = get_token_encoding_name(prompt_formatter.model_name) # NOTE(Ryan): If you wanted to do this on batches, you could run a for loop here about request_files. Although I don't recommend it because you are waiting for straggler requests to finish for each batch. # NOTE(Ryan): And if you wanted to do batches in parallel, you would have to divide rpm and tpm by the number of parallel batches. @@ -186,9 +177,7 @@ def run( ) ) - dataset = self.create_dataset_files( - working_dir, parse_func_hash, prompt_formatter - ) + dataset = self.create_dataset_files(working_dir, parse_func_hash, prompt_formatter) return dataset async def process_generic_requests_from_file( @@ -227,12 +216,8 @@ async def process_generic_requests_from_file( # initialize trackers queue_of_requests_to_retry = asyncio.Queue() - task_id_generator = ( - task_id_generator_function() - ) # generates integer IDs of 0, 1, 2, ... - status_tracker = ( - StatusTracker() - ) # single instance to track a collection of variables + task_id_generator = task_id_generator_function() # generates integer IDs of 0, 1, 2, ... 
+ status_tracker = StatusTracker() # single instance to track a collection of variables next_request = None # variable to hold the next request to call # initialize available capacity counts @@ -248,9 +233,7 @@ async def process_generic_requests_from_file( if os.path.exists(save_filepath): if resume: # save all successfully completed requests to a temporary file, then overwrite the original file with the temporary file - logger.debug( - f"Resuming progress from existing file: {save_filepath}" - ) + logger.debug(f"Resuming progress from existing file: {save_filepath}") logger.debug( f"Removing all failed requests from {save_filepath} so they can be retried" ) @@ -268,16 +251,12 @@ async def process_generic_requests_from_file( ) num_previously_failed_requests += 1 else: - completed_request_ids.add( - response.generic_request.original_row_idx - ) + completed_request_ids.add(response.generic_request.original_row_idx) output_file.write(line) logger.info( f"Found {len(completed_request_ids)} completed requests and {num_previously_failed_requests} previously failed requests" ) - logger.info( - "Failed requests and remaining requests will now be processed." - ) + logger.info("Failed requests and remaining requests will now be processed.") os.replace(temp_filepath, save_filepath) elif resume_no_retry: logger.warning( @@ -287,9 +266,7 @@ async def process_generic_requests_from_file( with open(save_filepath, "r") as input_file, open( temp_filepath, "w" ) as output_file: - for line in tqdm( - input_file, desc="Processing existing requests" - ): + for line in tqdm(input_file, desc="Processing existing requests"): data = json.loads(line) if isinstance(data[1], list): # this means that the request failed and we have a list of errors @@ -319,9 +296,7 @@ async def process_generic_requests_from_file( # Count total number of requests total_requests = sum(1 for _ in open(generic_requests_filepath)) if total_requests == len(completed_request_ids): - logger.debug( - "All requests have already been completed so will just reuse cache." 
- ) + logger.debug("All requests have already been completed so will just reuse cache.") return # Create progress bar @@ -338,41 +313,28 @@ async def process_generic_requests_from_file( # get next request (if one is not already waiting for capacity) if next_request is None: if not queue_of_requests_to_retry.empty(): - next_request = ( - queue_of_requests_to_retry.get_nowait() - ) - logger.debug( - f"Retrying request {next_request.task_id}: {next_request}" - ) + next_request = queue_of_requests_to_retry.get_nowait() + logger.debug(f"Retrying request {next_request.task_id}: {next_request}") elif file_not_finished: try: # get new generic request - generic_request_json = json.loads( - next(generic_requests) - ) + generic_request_json = json.loads(next(generic_requests)) generic_request = GenericRequest.model_validate( generic_request_json ) request_idx = generic_request.original_row_idx # Skip requests we already have responses for - if ( - resume - and request_idx in completed_request_ids - ): + if resume and request_idx in completed_request_ids: logger.debug( f"Skipping already completed request {request_idx}" ) - status_tracker.num_tasks_already_completed += ( - 1 - ) + status_tracker.num_tasks_already_completed += 1 continue # Create API-specific request - api_specific_request_json = ( - self.create_api_specific_request( - generic_request - ) + api_specific_request_json = self.create_api_specific_request( + generic_request ) next_request = APIRequest( task_id=next(task_id_generator), @@ -457,16 +419,11 @@ async def process_generic_requests_from_file( # if a rate limit error was hit recently, pause to cool down seconds_since_rate_limit_error = ( - time.time() - - status_tracker.time_of_last_rate_limit_error + time.time() - status_tracker.time_of_last_rate_limit_error ) - if ( - seconds_since_rate_limit_error - < seconds_to_pause_after_rate_limit_error - ): + if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error: remaining_seconds_to_pause = ( - seconds_to_pause_after_rate_limit_error - - seconds_since_rate_limit_error + seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error ) await asyncio.sleep(remaining_seconds_to_pause) # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago @@ -478,9 +435,7 @@ async def process_generic_requests_from_file( pbar.close() # after finishing, log final status - logger.info( - f"""Parallel processing complete. Results saved to {save_filepath}""" - ) + logger.info(f"""Parallel processing complete. 
Results saved to {save_filepath}""") logger.info(f"Status tracker: {status_tracker}") @@ -506,9 +461,7 @@ class StatusTracker: num_rate_limit_errors: int = 0 num_api_errors: int = 0 # excluding rate limit errors, counted above num_other_errors: int = 0 - time_of_last_rate_limit_error: int = ( - 0 # used to cool off after hitting rate limits - ) + time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits @dataclass @@ -543,17 +496,13 @@ async def call_api( ) as response: response = await response.json() if "error" in response: - logger.warning( - f"Request {self.task_id} failed with error {response['error']}" - ) + logger.warning(f"Request {self.task_id} failed with error {response['error']}") status_tracker.num_api_errors += 1 error = response if "rate limit" in response["error"].get("message", "").lower(): status_tracker.time_of_last_rate_limit_error = time.time() status_tracker.num_rate_limit_errors += 1 - status_tracker.num_api_errors -= ( - 1 # rate limit errors are counted separately - ) + status_tracker.num_api_errors -= 1 # rate limit errors are counted separately except ( Exception @@ -575,7 +524,7 @@ async def call_api( raw_response=None, generic_request=self.generic_request, created_at=self.created_at, - finished_at=datetime.datetime.now() + finished_at=datetime.datetime.now(), ) append_generic_response(generic_response, save_filepath) status_tracker.num_tasks_in_progress -= 1 @@ -593,13 +542,11 @@ async def call_api( token_usage = TokenUsage( prompt_tokens=usage.get("prompt_tokens", 0), completion_tokens=usage.get("completion_tokens", 0), - total_tokens=usage.get("total_tokens", 0) + total_tokens=usage.get("total_tokens", 0), ) - + # Calculate cost using litellm - cost = litellm.completion_cost( - completion_response=response - ) + cost = litellm.completion_cost(completion_response=response) generic_response = GenericResponse( response_message=response_message, @@ -610,7 +557,7 @@ async def call_api( created_at=self.created_at, finished_at=datetime.datetime.now(), token_usage=token_usage, - response_cost=cost + response_cost=cost, ) append_generic_response(generic_response, save_filepath) status_tracker.num_tasks_in_progress -= 1 @@ -629,9 +576,7 @@ def get_token_encoding_name(model: str) -> str: return "cl100k_base" -def get_rate_limits( - model: str, request_url: str, api_key: str -) -> Tuple[int, int]: +def get_rate_limits(model: str, request_url: str, api_key: str) -> Tuple[int, int]: """ Function to get rate limits for a given annotator. Makes a single request to openAI API and gets the rate limits from the response headers. 
These rate limits vary per model @@ -654,20 +599,14 @@ def get_rate_limits( json={"model": model, "messages": []}, ) # Extract rate limit information from headers - max_requests = int( - response.headers.get("x-ratelimit-limit-requests", 30_000) - ) - max_tokens = int( - response.headers.get("x-ratelimit-limit-tokens", 150_000_000) - ) + max_requests = int(response.headers.get("x-ratelimit-limit-requests", 30_000)) + max_tokens = int(response.headers.get("x-ratelimit-limit-tokens", 150_000_000)) elif "api.sambanova.ai" in request_url: # Send a dummy request to get rate limit information max_requests = 50 max_tokens = 100_000_000 else: - raise NotImplementedError( - f'Rate limits for API endpoint "{request_url}" not implemented' - ) + raise NotImplementedError(f'Rate limits for API endpoint "{request_url}" not implemented') return max_requests, max_tokens @@ -695,9 +634,7 @@ def api_endpoint_from_url(request_url: str) -> str: return match[1] # for Azure OpenAI deployment urls - match = re.search( - r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url - ) + match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url) if match: return match[1] @@ -707,9 +644,7 @@ def api_endpoint_from_url(request_url: str) -> str: elif "completions" in request_url: return "completions" else: - raise NotImplementedError( - f'API endpoint "{request_url}" not implemented in this script' - ) + raise NotImplementedError(f'API endpoint "{request_url}" not implemented in this script') def append_generic_response(data: GenericResponse, filename: str) -> None: @@ -746,9 +681,7 @@ def num_tokens_consumed_from_request( ) num_tokens += len(str(value)) // 4 if key == "name": # if there's a name, the role is omitted - num_tokens -= ( - 1 # role is always required and always 1 token - ) + num_tokens -= 1 # role is always required and always 1 token num_tokens += 2 # every reply is primed with assistant return num_tokens + completion_tokens # normal completions @@ -781,9 +714,7 @@ def num_tokens_consumed_from_request( ) # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) else: - raise NotImplementedError( - f'API endpoint "{api_endpoint}" not implemented in this script' - ) + raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script') def task_id_generator_function(): diff --git a/src/bespokelabs/curator/viewer/__main__.py b/src/bespokelabs/curator/viewer/__main__.py index e57c63bd..062454a2 100644 --- a/src/bespokelabs/curator/viewer/__main__.py +++ b/src/bespokelabs/curator/viewer/__main__.py @@ -1,16 +1,16 @@ +import logging import os +import platform +import shutil +import socket import subprocess import sys -from pathlib import Path -from argparse import ArgumentParser +import tempfile +import time import webbrowser +from argparse import ArgumentParser from contextlib import closing -import socket -import logging -import time -import platform -import tempfile -import shutil +from pathlib import Path def get_viewer_path(): @@ -32,9 +32,7 @@ def ensure_dependencies(): print(f"Error installing dependencies: {e}") sys.exit(1) except FileNotFoundError: - print( - "Error: Node.js is not installed. Please install Node.js to run the viewer." - ) + print("Error: Node.js is not installed. 
Please install Node.js to run the viewer.") sys.exit(1) @@ -49,9 +47,7 @@ def _setup_logging(level): def check_node_installed(): """Check if Node.js is installed and return version if found""" try: - result = subprocess.run( - ["node", "--version"], capture_output=True, text=True, check=True - ) + result = subprocess.run(["node", "--version"], capture_output=True, text=True, check=True) return result.stdout.strip() except (subprocess.CalledProcessError, FileNotFoundError): return None @@ -105,22 +101,16 @@ def main(): server_file = os.path.join(viewer_path, "server.js") if not os.path.exists(os.path.join(static_dir, ".next")): - print( - "Error: Next.js build artifacts not found. The package may not be built correctly." - ) + print("Error: Next.js build artifacts not found. The package may not be built correctly.") sys.exit(1) try: - subprocess.run( - ["node", server_file], cwd=viewer_path, env=env, check=True - ) + subprocess.run(["node", server_file], cwd=viewer_path, env=env, check=True) except subprocess.CalledProcessError as e: print(f"Error starting Next.js server: {e}") sys.exit(1) except FileNotFoundError: - print( - "Error: Node.js is not installed. Please install Node.js to run the viewer." - ) + print("Error: Node.js is not installed. Please install Node.js to run the viewer.") sys.exit(1) diff --git a/tests/test_caching.py b/tests/test_caching.py new file mode 100644 index 00000000..73803465 --- /dev/null +++ b/tests/test_caching.py @@ -0,0 +1,115 @@ +from datasets import Dataset + +from bespokelabs.curator import Prompter + + +def test_same_value_caching(tmp_path): + """Test that using the same value multiple times uses cache.""" + values = [] + + # Test with same value multiple times + for _ in range(3): + + def prompt_func(): + return f"Say '1'. Do not explain." + + prompter = Prompter( + prompt_func=prompt_func, + model_name="gpt-4o-mini", + ) + result = prompter(working_dir=str(tmp_path)) + values.append(result.to_pandas().iloc[0]["response"]) + + # Count cache directories, excluding metadata.db + cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"] + assert len(cache_dirs) == 1, f"Expected 1 cache directory but found {len(cache_dirs)}" + assert values == ["1", "1", "1"], "Same value should produce same results" + + +def test_different_values_caching(tmp_path): + """Test that using different values creates different cache entries.""" + values = [] + + # Test with different values + for x in [1, 2, 3]: + + def prompt_func(): + return f"Say '{x}'. Do not explain." + + prompter = Prompter( + prompt_func=prompt_func, + model_name="gpt-4o-mini", + ) + result = prompter(working_dir=str(tmp_path)) + values.append(result.to_pandas().iloc[0]["response"]) + + # Count cache directories, excluding metadata.db + cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"] + assert len(cache_dirs) == 3, f"Expected 3 cache directories but found {len(cache_dirs)}" + assert values == ["1", "2", "3"], "Different values should produce different results" + + +def test_same_dataset_caching(tmp_path): + """Test that using the same dataset multiple times uses cache.""" + dataset = Dataset.from_list([{"instruction": "Say '1'. 
Do not explain."}]) + prompter = Prompter( + prompt_func=lambda x: x["instruction"], + model_name="gpt-4o-mini", + ) + + result = prompter(dataset=dataset, working_dir=str(tmp_path)) + assert result.to_pandas().iloc[0]["response"] == "1" + + result = prompter(dataset=dataset, working_dir=str(tmp_path)) + assert result.to_pandas().iloc[0]["response"] == "1" + + # Count cache directories, excluding metadata.db + cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"] + assert len(cache_dirs) == 1, f"Expected 1 cache directory but found {len(cache_dirs)}" + + +def test_different_dataset_caching(tmp_path): + """Test that using different datasets creates different cache entries.""" + dataset1 = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}]) + dataset2 = Dataset.from_list([{"instruction": "Say '2'. Do not explain."}]) + prompter = Prompter( + prompt_func=lambda x: x["instruction"], + model_name="gpt-4o-mini", + ) + + result = prompter(dataset=dataset1, working_dir=str(tmp_path)) + assert result.to_pandas().iloc[0]["response"] == "1" + + result = prompter(dataset=dataset2, working_dir=str(tmp_path)) + assert result.to_pandas().iloc[0]["response"] == "2" + + # Count cache directories, excluding metadata.db + cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"] + assert len(cache_dirs) == 2, f"Expected 2 cache directory but found {len(cache_dirs)}" + + +def test_nested_call_caching(tmp_path): + """Test that changing a nested upstream function invalidates the cache.""" + + def value_generator(): + return 1 + + def prompt_func(): + return f"Say '{value_generator()}'. Do not explain." + + prompter = Prompter( + prompt_func=prompt_func, + model_name="gpt-4o-mini", + ) + result = prompter(working_dir=str(tmp_path)) + assert result.to_pandas().iloc[0]["response"] == "1" + + def value_generator(): + return 2 + + result = prompter(working_dir=str(tmp_path)) + assert result.to_pandas().iloc[0]["response"] == "2" + + # Count cache directories, excluding metadata.db + cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"] + assert len(cache_dirs) == 2, f"Expected 2 cache directory but found {len(cache_dirs)}" diff --git a/tests/test_install_ui.py b/tests/test_install_ui.py index b78c5d6d..2ef39b29 100644 --- a/tests/test_install_ui.py +++ b/tests/test_install_ui.py @@ -1,17 +1,19 @@ """Test script for installation UI.""" -import os -import sys + import argparse import importlib.util +import os +import sys + def import_install_ui(): """Import just the install_ui module without importing the whole package.""" # Get the absolute path to install_ui.py install_ui_path = os.path.join( os.path.dirname(os.path.dirname(__file__)), # Go up one level since we're in tests/ - "src/bespokelabs/curator/install_ui.py" + "src/bespokelabs/curator/install_ui.py", ) - + # Import the module directly from file spec = importlib.util.spec_from_file_location("install_ui", install_ui_path) module = importlib.util.module_from_spec(spec) @@ -19,25 +21,27 @@ def import_install_ui(): spec.loader.exec_module(module) return module + def main(): """Run the test script with command line arguments.""" - parser = argparse.ArgumentParser(description='Test the installation UI.') + parser = argparse.ArgumentParser(description="Test the installation UI.") parser.add_argument( - '--scenario', - choices=['success', 'error'], - default='success', - help='Which scenario to test (success or error)' + "--scenario", + choices=["success", "error"], + default="success", + help="Which 
scenario to test (success or error)", ) args = parser.parse_args() - + # Import just the install_ui module install_ui = import_install_ui() - + # Run the enhanced install based on scenario - if args.scenario == 'success': + if args.scenario == "success": install_ui.enhanced_install("bespokelabs-curator") else: install_ui.enhanced_install("nonexistent-package-12345") + if __name__ == "__main__": main()