From 5380c9611441a57cbadb2e01f4ccad49cc544751 Mon Sep 17 00:00:00 2001
From: Trung Vu <vu.trung.96@gmail.com>
Date: Thu, 21 Nov 2024 22:46:09 +0000
Subject: [PATCH] Apply black and isort formatting and add a GitHub lint workflow

---
 .github/workflows/lint.yaml                   |  30 ++++
 examples/distill.py                           |   6 +-
 examples/poem.py                              |  10 +-
 poetry.lock                                   |  18 +-
 pyproject.toml                                |   3 +-
 src/bespokelabs/__init__.py                   |   4 +-
 src/bespokelabs/curator/__init__.py           |   2 +-
 src/bespokelabs/curator/dataset.py            |  34 ++--
 src/bespokelabs/curator/install_ui.py         |  62 ++++---
 .../curator/prompter/prompt_formatter.py      |  19 +-
 src/bespokelabs/curator/prompter/prompter.py  |  42 ++---
 .../base_request_processor.py                 |  68 ++------
 .../request_processor/generic_request.py      |   1 +
 .../request_processor/generic_response.py     |   9 +-
 .../openai_batch_request_processor.py         |  93 ++++------
 .../openai_online_request_processor.py        | 162 +++++-------------
 src/bespokelabs/curator/viewer/__main__.py    |  36 ++--
 tests/test_install_ui.py                      |  28 +--
 18 files changed, 256 insertions(+), 371 deletions(-)
 create mode 100644 .github/workflows/lint.yaml
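
Reviewer note: the imports reflowed below use isort's default multi-line style
(backslash continuations) rather than black's parenthesized style, so the two
tools can disagree on import wrapping. A minimal [tool.isort] sketch that would
align isort with black at the new 100-character line length (hypothetical, not
part of this patch):

[tool.isort]
profile = "black"
line_length = 100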

diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
new file mode 100644
index 00000000..f8d46293
--- /dev/null
+++ b/.github/workflows/lint.yaml
@@ -0,0 +1,30 @@
+name: Python Linting
+
+on: [push, pull_request]
+
+jobs:
+  PythonLinting:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        project: [bespoke]  # Add other projects here
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python
+      uses: actions/setup-python@v5
+      with:
+        python-version: '3.11'
+    - name: Install dependencies
+      run: |
+        cd ${{ matrix.project }}
+        pip install poetry
+        poetry install
+    - name: Run black
+      run: |
+        cd ${{ matrix.project }}
+        poetry run black --check .
+    - name: Run isort
+      run: |
+        cd ${{ matrix.project }}
+        poetry run isort --check .
\ No newline at end of file
diff --git a/examples/distill.py b/examples/distill.py
index 7fc785cb..20cfd53b 100644
--- a/examples/distill.py
+++ b/examples/distill.py
@@ -1,7 +1,9 @@
-from bespokelabs import curator
-from datasets import load_dataset
 import logging
 
+from datasets import load_dataset
+
+from bespokelabs import curator
+
 dataset = load_dataset("allenai/WildChat", split="train")
 dataset = dataset.select(range(3_000))
 
diff --git a/examples/poem.py b/examples/poem.py
index 5697e5e2..ffb8c5a5 100644
--- a/examples/poem.py
+++ b/examples/poem.py
@@ -2,10 +2,12 @@
 
 We generate 10 diverse topics and then generate 2 poems for each topic."""
 
-from bespokelabs import curator
+from typing import List
+
 from datasets import Dataset
 from pydantic import BaseModel, Field
-from typing import List
+
+from bespokelabs import curator
 
 
 # We use Pydantic and structured outputs to define the format of the response.
@@ -41,9 +43,7 @@ class Poems(BaseModel):
     model_name="gpt-4o-mini",
     response_format=Poems,
     # `row` is the input row, and `poems` is the Poems class which is parsed from the structured output from the LLM.
-    parse_func=lambda row, poems: [
-        {"topic": row["topic"], "poem": p} for p in poems.poems_list
-    ],
+    parse_func=lambda row, poems: [{"topic": row["topic"], "poem": p} for p in poems.poems_list],
 )
 
 # We apply the prompter to the topics dataset.
diff --git a/poetry.lock b/poetry.lock
index 98f2f8b1..0af8eb19 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "aiofiles"
@@ -1135,6 +1135,20 @@ mistralai = ["mistralai (>=1.0.3,<2.0.0)"]
 test-docs = ["anthropic (>=0.36.2,<0.38.0)", "cohere (>=5.1.8,<6.0.0)", "diskcache (>=5.6.3,<6.0.0)", "fastapi (>=0.109.2,<0.116.0)", "groq (>=0.4.2,<0.12.0)", "litellm (>=1.35.31,<2.0.0)", "mistralai (>=1.0.3,<2.0.0)", "pandas (>=2.2.0,<3.0.0)", "pydantic_extra_types (>=2.6.0,<3.0.0)", "redis (>=5.0.1,<6.0.0)", "tabulate (>=0.9.0,<0.10.0)"]
 vertexai = ["google-cloud-aiplatform (>=1.53.0,<2.0.0)", "jsonref (>=1.1.0,<2.0.0)"]
 
+[[package]]
+name = "isort"
+version = "5.13.2"
+description = "A Python utility / library to sort Python imports."
+optional = false
+python-versions = ">=3.8.0"
+files = [
+    {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"},
+    {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"},
+]
+
+[package.extras]
+colors = ["colorama (>=0.4.6)"]
+
 [[package]]
 name = "jaraco-classes"
 version = "3.4.0"
@@ -3575,4 +3589,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "f6b5a294e6105fa990fee6139aee98bd03335063a2932f71e152f5de2b599074"
+content-hash = "3604f19ac9d9dd28454528f2623f2b638bbd985d12810f4d99934d2bd11a3294"
diff --git a/pyproject.toml b/pyproject.toml
index 6f5d597a..0e622361 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ tiktoken = "^0.8.0"
 nest-asyncio = "^1.6.0"
 rich = "^13.7.0"
 litellm = "^1.52.11"
+isort = "^5.13.2"
 
 [tool.poetry.group.dev.dependencies]
 black = "^24.2.0"
@@ -47,4 +48,4 @@ build-backend = "poetry.core.masonry.api"
 curator-viewer = "bespokelabs.curator.viewer.__main__:main"
 
 [tool.black]
-line-length = 80
+line-length = 100
diff --git a/src/bespokelabs/__init__.py b/src/bespokelabs/__init__.py
index e89e45ee..f7b99017 100644
--- a/src/bespokelabs/__init__.py
+++ b/src/bespokelabs/__init__.py
@@ -3,9 +3,7 @@
 logger = logging.getLogger("bespokelabs.curator")
 
 handler = logging.StreamHandler()
-formatter = logging.Formatter(
-    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 handler.setFormatter(formatter)
 logger.addHandler(handler)
 logger.setLevel(logging.WARNING)
diff --git a/src/bespokelabs/curator/__init__.py b/src/bespokelabs/curator/__init__.py
index 37ec5dbb..bb0b7aa2 100644
--- a/src/bespokelabs/curator/__init__.py
+++ b/src/bespokelabs/curator/__init__.py
@@ -1,2 +1,2 @@
-from .prompter.prompter import Prompter
 from .dataset import Dataset
+from .prompter.prompter import Prompter
diff --git a/src/bespokelabs/curator/dataset.py b/src/bespokelabs/curator/dataset.py
index 56b6c63d..6787de20 100644
--- a/src/bespokelabs/curator/dataset.py
+++ b/src/bespokelabs/curator/dataset.py
@@ -1,19 +1,17 @@
+import glob
 import json
 import logging
 import os
-import glob
+from typing import Any, Dict, Iterable, Iterator, List, TypeVar
 
 import pandas as pd
-
-from pydantic import BaseModel
 from datasets import Dataset as HFDataset
 from datasets.arrow_writer import ArrowWriter, SchemaInferenceError
-from typing import Any, Dict, Iterable, Iterator, List, TypeVar
+from pydantic import BaseModel
 
 from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
-from bespokelabs.curator.request_processor.generic_response import (
-    GenericResponse,
-)
+from bespokelabs.curator.request_processor.generic_response import \
+    GenericResponse
 
 T = TypeVar("T")
 
@@ -33,9 +31,7 @@ def from_iterable(iterable: Iterable[Dict[str, Any] | BaseModel]):
         return Dataset(iterable=iterable)
 
     def from_working_dir(working_dir: str, prompt_formatter: PromptFormatter):
-        return Dataset(
-            working_dir=working_dir, prompt_formatter=prompt_formatter
-        )
+        return Dataset(working_dir=working_dir, prompt_formatter=prompt_formatter)
 
     def __iter__(self) -> Iterator[Dict[str, Any] | BaseModel]:
         if self.iterable is not None:
@@ -48,13 +44,9 @@ def __iter__(self) -> Iterator[Dict[str, Any] | BaseModel]:
             for line in open(response_file, "r"):
                 response = GenericResponse.model_validate_json(line)
                 if self.prompt_formatter.response_format:
-                    response.response = self.prompt_formatter.response_format(
-                        **response.response
-                    )
+                    response.response = self.prompt_formatter.response_format(**response.response)
                 if self.prompt_formatter.parse_func:
-                    response = self.prompt_formatter.parse_func(
-                        response.row, response.response
-                    )
+                    response = self.prompt_formatter.parse_func(response.row, response.response)
                 else:
                     response = [response.response]
 
@@ -97,10 +89,8 @@ def to_huggingface(self, in_memory: bool = False) -> None:
                         total_responses_count += 1
                         response = GenericResponse.model_validate_json(line)
                         if self.prompt_formatter.response_format:
-                            response.response = (
-                                self.prompt_formatter.response_format(
-                                    **response.response
-                                )
+                            response.response = self.prompt_formatter.response_format(
+                                **response.response
                             )
 
                         if response is None:
@@ -119,9 +109,7 @@ def to_huggingface(self, in_memory: bool = False) -> None:
                                 row = row.model_dump()
                             writer.write(row)
 
-            logging.info(
-                f"Read {total_responses_count} responses, {failed_responses_count} failed"
-            )
+            logging.info(f"Read {total_responses_count} responses, {failed_responses_count} failed")
             logging.info("Finalizing writer")
 
             if failed_responses_count == total_responses_count:
diff --git a/src/bespokelabs/curator/install_ui.py b/src/bespokelabs/curator/install_ui.py
index b526ed60..746e67f7 100644
--- a/src/bespokelabs/curator/install_ui.py
+++ b/src/bespokelabs/curator/install_ui.py
@@ -4,22 +4,23 @@
 It includes progress tracking, status updates, and a polished success message.
 """
 
-import sys
 import subprocess
-from typing import Optional, Tuple
+import sys
 from dataclasses import dataclass
 from enum import Enum
+from typing import Optional, Tuple
 
 from rich.console import Console
-from rich.text import Text
 from rich.live import Live
-from rich.spinner import Spinner
 from rich.panel import Panel
 from rich.progress import ProgressBar
+from rich.spinner import Spinner
+from rich.text import Text
 
 
 class InstallationStage(Enum):
     """Enum representing different stages of the installation process."""
+
     PREPARING = ("Preparing your environment...", 0.0)
     COLLECTING = ("Downloading packages...", 0.2)
     DOWNLOADING = ("Downloading packages...", 0.4)
@@ -35,9 +36,10 @@ def __init__(self, message: str, progress: float):
 @dataclass
 class InstallationUI:
     """Class to manage the installation UI components and styling."""
+
     package_name: str
     console: Console = Console()
-    
+
     def create_progress_bar(self, completed: float = 0) -> Text:
         """Create a stylish progress bar with the given completion percentage."""
         width = 40
@@ -65,25 +67,33 @@ def create_loading_text(self, stage: InstallationStage, progress: float) -> Text
             ("Your synthetic data journey begins in moments", "dim white"),
             self.create_progress_bar(progress),
             ("\n ", ""),
-            (stage.message, "italic dim white")
+            (stage.message, "italic dim white"),
         )
 
     def create_success_text(self) -> Text:
         """Create the success message with links."""
         text = Text()
         text.append("✨ Curator installed successfully!\n\n", style="bold green")
-        text.append("Start building production-ready synthetic data pipelines:\n\n", style="dim white")
+        text.append(
+            "Start building production-ready synthetic data pipelines:\n\n", style="dim white"
+        )
         text.append("   📚 ", style="")
         text.append("docs.bespokelabs.ai", style="dim cyan link https://docs.bespokelabs.ai")
         text.append("\n   📦 ", style="")
-        text.append("github.com/bespokelabsai/curator", style="dim cyan link https://github.com/bespokelabsai/curator")
+        text.append(
+            "github.com/bespokelabsai/curator",
+            style="dim cyan link https://github.com/bespokelabsai/curator",
+        )
         text.append("\n   💬 ", style="")
-        text.append("discord.gg/KqpXvpzVBS", style="dim cyan link https://discord.com/invite/KqpXvpzVBS")
+        text.append(
+            "discord.gg/KqpXvpzVBS", style="dim cyan link https://discord.com/invite/KqpXvpzVBS"
+        )
         return text
 
 
 class PackageInstaller:
     """Class to handle the package installation process."""
+
     def __init__(self, package_name: str, version: Optional[str] = None):
         self.package_spec = f"{package_name}=={version}" if version else package_name
         self.ui = InstallationUI(package_name)
@@ -96,13 +106,13 @@ def run_pip_install(self) -> subprocess.Popen:
             stderr=subprocess.PIPE,
             text=True,
             bufsize=1,
-            universal_newlines=True
+            universal_newlines=True,
         )
 
     def parse_pip_output(self, line: str) -> Tuple[InstallationStage, float]:
         """Parse pip output to determine installation stage and progress."""
         line = line.strip().lower()
-        
+
         if "collecting" in line:
             return InstallationStage.COLLECTING, InstallationStage.COLLECTING.progress
         elif "downloading" in line:
@@ -118,32 +128,30 @@ def parse_pip_output(self, line: str) -> Tuple[InstallationStage, float]:
             return InstallationStage.INSTALLING, InstallationStage.INSTALLING.progress
         elif "successfully installed" in line:
             return InstallationStage.FINALIZING, InstallationStage.FINALIZING.progress
-        
+
         return InstallationStage.PREPARING, InstallationStage.PREPARING.progress
 
     def install(self) -> None:
         """Execute the installation with progress tracking and UI updates."""
-        spinner = Spinner("dots2", text=self.ui.create_loading_text(InstallationStage.PREPARING, 0), style="green")
-        
-        with Live(
-            spinner, 
-            console=self.ui.console, 
-            refresh_per_second=30
-        ) as live:
+        spinner = Spinner(
+            "dots2", text=self.ui.create_loading_text(InstallationStage.PREPARING, 0), style="green"
+        )
+
+        with Live(spinner, console=self.ui.console, refresh_per_second=30) as live:
             try:
                 process = self.run_pip_install()
-                
+
                 while True:
                     output_line = process.stdout.readline()
-                    if output_line == '' and process.poll() is not None:
+                    if output_line == "" and process.poll() is not None:
                         break
-                    
+
                     stage, progress = self.parse_pip_output(output_line)
                     spinner.text = self.ui.create_loading_text(stage, progress)
-                
+
                 # Show completion
                 spinner.text = self.ui.create_loading_text(InstallationStage.COMPLETE, 1.0)
-                
+
                 if process.poll() == 0:
                     live.update(self.ui.create_success_text())
                 else:
@@ -151,19 +159,19 @@ def install(self) -> None:
                     error_text = Text(error, style="red")
                     live.update(error_text)
                     sys.exit(1)
-                    
+
             except Exception as e:
                 error_text = Text(f"Error: {str(e)}", style="red")
                 live.update(error_text)
                 sys.exit(1)
-        
+
         self.ui.console.print()
 
 
 def enhanced_install(package_name: str, version: Optional[str] = None) -> None:
     """
     Enhance pip installation with a professional progress UI.
-    
+
     Args:
         package_name: Name of the package to install
         version: Optional specific version to install
diff --git a/src/bespokelabs/curator/prompter/prompt_formatter.py b/src/bespokelabs/curator/prompter/prompt_formatter.py
index 40b26e2a..937fbc61 100644
--- a/src/bespokelabs/curator/prompter/prompt_formatter.py
+++ b/src/bespokelabs/curator/prompter/prompt_formatter.py
@@ -3,7 +3,8 @@
 
 from pydantic import BaseModel
 
-from bespokelabs.curator.request_processor.generic_request import GenericRequest
+from bespokelabs.curator.request_processor.generic_request import \
+    GenericRequest
 
 T = TypeVar("T")
 
@@ -25,9 +26,7 @@ class PromptFormatter:
     def __init__(
         self,
         model_name: str,
-        prompt_func: Callable[
-            [Union[Dict[str, Any], BaseModel]], Dict[str, str]
-        ],
+        prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]],
         parse_func: Optional[
             Callable[
                 [
@@ -44,9 +43,7 @@ def __init__(
         self.parse_func = parse_func
         self.response_format = response_format
 
-    def create_generic_request(
-        self, row: Dict[str, Any] | BaseModel, idx: int
-    ) -> GenericRequest:
+    def create_generic_request(self, row: Dict[str, Any] | BaseModel, idx: int) -> GenericRequest:
         """Format the request object based off Prompter attributes."""
         sig = inspect.signature(self.prompt_func)
         if len(sig.parameters) == 0:
@@ -54,9 +51,7 @@ def create_generic_request(
         elif len(sig.parameters) == 1:
             prompts = self.prompt_func(row)
         else:
-            raise ValueError(
-                f"Prompting function {self.prompt_func} must have 0 or 1 arguments."
-            )
+            raise ValueError(f"Prompting function {self.prompt_func} must have 0 or 1 arguments.")
 
         if isinstance(prompts, str):
             messages = [{"role": "user", "content": prompts}]
@@ -74,8 +69,6 @@ def create_generic_request(
             original_row=row,
             original_row_idx=idx,
             response_format=(
-                self.response_format.model_json_schema()
-                if self.response_format
-                else None
+                self.response_format.model_json_schema() if self.response_format else None
             ),
         )
diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py
index a8afa2da..6782d79a 100644
--- a/src/bespokelabs/curator/prompter/prompter.py
+++ b/src/bespokelabs/curator/prompter/prompter.py
@@ -1,26 +1,24 @@
 """Curator: Bespoke Labs Synthetic Data Generation Library."""
 
 import inspect
+import logging
 import os
 from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, Optional, Type, TypeVar, Union
+from typing import (Any, Callable, Dict, Iterable, Optional, Type, TypeVar,
+                    Union)
 
 from datasets import Dataset
 from pydantic import BaseModel
 from xxhash import xxh64
-import logging
 
 from bespokelabs.curator.db import MetadataDB
 from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
-from bespokelabs.curator.request_processor.base_request_processor import (
-    BaseRequestProcessor,
-)
-from bespokelabs.curator.request_processor.openai_batch_request_processor import (
-    OpenAIBatchRequestProcessor,
-)
-from bespokelabs.curator.request_processor.openai_online_request_processor import (
-    OpenAIOnlineRequestProcessor,
-)
+from bespokelabs.curator.request_processor.base_request_processor import \
+    BaseRequestProcessor
+from bespokelabs.curator.request_processor.openai_batch_request_processor import \
+    OpenAIBatchRequestProcessor
+from bespokelabs.curator.request_processor.openai_online_request_processor import \
+    OpenAIOnlineRequestProcessor
 
 _CURATOR_DEFAULT_CACHE_DIR = "~/.cache/curator"
 T = TypeVar("T")
@@ -34,9 +32,7 @@ class Prompter:
     def __init__(
         self,
         model_name: str,
-        prompt_func: Callable[
-            [Union[Dict[str, Any], BaseModel]], Dict[str, str]
-        ],
+        prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]],
         parse_func: Optional[
             Callable[
                 [
@@ -115,9 +111,7 @@ def __init__(
                 frequency_penalty=frequency_penalty,
             )
 
-    def __call__(
-        self, dataset: Optional[Iterable] = None, working_dir: str = None
-    ) -> Dataset:
+    def __call__(self, dataset: Optional[Iterable] = None, working_dir: str = None) -> Dataset:
         """
         Run completions on a dataset.
 
@@ -161,11 +155,7 @@ def _completions(
         else:
             curator_cache_dir = working_dir
 
-        dataset_hash = (
-            dataset._fingerprint
-            if dataset is not None
-            else xxh64("").hexdigest()
-        )
+        dataset_hash = dataset._fingerprint if dataset is not None else xxh64("").hexdigest()
 
         prompt_func_hash = _get_function_hash(self.prompt_formatter.prompt_func)
 
@@ -192,13 +182,9 @@ def _completions(
         metadata_db = MetadataDB(metadata_db_path)
 
         # Get the source code of the prompt function
-        prompt_func_source = _get_function_source(
-            self.prompt_formatter.prompt_func
-        )
+        prompt_func_source = _get_function_source(self.prompt_formatter.prompt_func)
         if self.prompt_formatter.parse_func is not None:
-            parse_func_source = _get_function_source(
-                self.prompt_formatter.parse_func
-            )
+            parse_func_source = _get_function_source(self.prompt_formatter.parse_func)
         else:
             parse_func_source = ""
 
diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py
index dcc344b7..f1b37cbf 100644
--- a/src/bespokelabs/curator/request_processor/base_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/base_request_processor.py
@@ -15,10 +15,10 @@
 
 from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
 from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
-from bespokelabs.curator.request_processor.generic_request import GenericRequest
-from bespokelabs.curator.request_processor.generic_response import (
-    GenericResponse,
-)
+from bespokelabs.curator.request_processor.generic_request import \
+    GenericRequest
+from bespokelabs.curator.request_processor.generic_response import \
+    GenericResponse
 
 logger = logging.getLogger(__name__)
 
@@ -42,9 +42,7 @@ def get_rate_limits(self) -> dict:
         pass
 
     @abstractmethod
-    def create_api_specific_request(
-        self, generic_request: GenericRequest
-    ) -> dict:
+    def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
         """
         Creates an API-specific request body from a GenericRequest.
 
@@ -115,9 +113,7 @@ def create_request_files(
                     num_jobs = i + 1
 
                 if num_jobs > 0:
-                    logger.info(
-                        f"There are {num_jobs} existing requests in {requests_files[0]}"
-                    )
+                    logger.info(f"There are {num_jobs} existing requests in {requests_files[0]}")
                     logger.info(
                         f"Example request in {requests_files[0]}:\n{json.dumps(first_job, default=str, indent=2)}"
                     )
@@ -129,19 +125,13 @@ def create_request_files(
 
         if dataset is None:
             with open(requests_file, "w") as f:
-                generic_request = prompt_formatter.create_generic_request(
-                    dict(), 0
-                )
-                f.write(
-                    json.dumps(generic_request.model_dump(), default=str) + "\n"
-                )
+                generic_request = prompt_formatter.create_generic_request(dict(), 0)
+                f.write(json.dumps(generic_request.model_dump(), default=str) + "\n")
             return requests_files
 
         if self.batch_size:
             num_batches = ceil(len(dataset) / self.batch_size)
-            requests_files = [
-                f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)
-            ]
+            requests_files = [f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)]
 
             async def create_all_request_files():
                 tasks = [
@@ -157,11 +147,7 @@ async def create_all_request_files():
 
             run_in_event_loop(create_all_request_files())
         else:
-            run_in_event_loop(
-                self.acreate_request_file(
-                    dataset, prompt_formatter, requests_file
-                )
-            )
+            run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, requests_file))
 
         return requests_files
 
@@ -184,12 +170,8 @@ async def acreate_request_file(
             for idx, dataset_row in enumerate(dataset):
                 dataset_row_idx = idx + start_idx
                 # Get the generic request from the map function
-                request = prompt_formatter.create_generic_request(
-                    dataset_row, dataset_row_idx
-                )
-                await f.write(
-                    json.dumps(request.model_dump(), default=str) + "\n"
-                )
+                request = prompt_formatter.create_generic_request(dataset_row, dataset_row_idx)
+                await f.write(json.dumps(request.model_dump(), default=str) + "\n")
         logger.info(f"Wrote {end_idx - start_idx} requests to {request_file}.")
 
     def create_dataset_files(
@@ -248,9 +230,7 @@ def create_dataset_files(
                 with open(responses_file, "r") as f_in:
                     for generic_response_string in f_in:
                         total_responses_count += 1
-                        response = GenericResponse.model_validate_json(
-                            generic_response_string
-                        )
+                        response = GenericResponse.model_validate_json(generic_response_string)
 
                         # response.response_errors is not None IFF response.response_message is None
                         if response.response_errors is not None:
@@ -261,10 +241,8 @@ def create_dataset_files(
                             # Response message is a string, which is converted to a dict
                             # The dict is then used to construct the response_format Pydantic model
                             try:
-                                response.response_message = (
-                                    prompt_formatter.response_format(
-                                        **response.response_message
-                                    )
+                                response.response_message = prompt_formatter.response_format(
+                                    **response.response_message
                                 )
                             except ValidationError as e:
                                 schema_str = json.dumps(
@@ -287,17 +265,13 @@ def create_dataset_files(
                                     response.response_message,
                                 )
                             except Exception as e:
-                                logger.error(
-                                    f"Exception raised in your `parse_func`. {error_help}"
-                                )
+                                logger.error(f"Exception raised in your `parse_func`. {error_help}")
                                 os.remove(dataset_file)
                                 raise e
                             if not isinstance(dataset_rows, list):
                                 dataset_rows = [dataset_rows]
                         else:
-                            dataset_rows = [
-                                {"response": response.response_message}
-                            ]
+                            dataset_rows = [{"response": response.response_message}]
 
                         for row in dataset_rows:
                             if isinstance(row, BaseModel):
@@ -317,9 +291,7 @@ def create_dataset_files(
 
                             writer.write(row)
 
-            logger.info(
-                f"Read {total_responses_count} responses, {failed_responses_count} failed"
-            )
+            logger.info(f"Read {total_responses_count} responses, {failed_responses_count} failed")
             if failed_responses_count == total_responses_count:
                 os.remove(dataset_file)
                 raise ValueError("All requests failed")
@@ -345,7 +317,5 @@ def parse_response_message(
                 f"Failed to parse response as JSON: {response_message}, skipping this response."
             )
             response_message = None
-            response_errors = [
-                f"Failed to parse response as JSON: {response_message}"
-            ]
+            response_errors = [f"Failed to parse response as JSON: {response_message}"]
     return response_message, response_errors
diff --git a/src/bespokelabs/curator/request_processor/generic_request.py b/src/bespokelabs/curator/request_processor/generic_request.py
index a407a12c..1fa23327 100644
--- a/src/bespokelabs/curator/request_processor/generic_request.py
+++ b/src/bespokelabs/curator/request_processor/generic_request.py
@@ -1,4 +1,5 @@
 from typing import Any, Dict, List, Optional, Type
+
 from pydantic import BaseModel
 
 """A generic request model for LLM API requests.
diff --git a/src/bespokelabs/curator/request_processor/generic_response.py b/src/bespokelabs/curator/request_processor/generic_response.py
index ef9b81c0..58471370 100644
--- a/src/bespokelabs/curator/request_processor/generic_response.py
+++ b/src/bespokelabs/curator/request_processor/generic_response.py
@@ -1,7 +1,9 @@
+import datetime
 from typing import Any, Dict, List, Optional
+
 from pydantic import BaseModel, Field
+
 from .generic_request import GenericRequest
-import datetime
 
 """A generic response model for LLM API requests.
 
@@ -23,12 +25,13 @@
 
 class TokenUsage(BaseModel):
     """Token usage information for an API request.
-    
+
     Attributes:
         prompt_tokens: Number of tokens in the prompt
         completion_tokens: Number of tokens in the completion
         total_tokens: Total number of tokens used
     """
+
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int
@@ -43,4 +46,4 @@ class GenericResponse(BaseModel):
     created_at: datetime.datetime
     finished_at: datetime.datetime
     token_usage: Optional[TokenUsage] = None
-    response_cost: Optional[float] = None
\ No newline at end of file
+    response_cost: Optional[float] = None
diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py
index 1e0cdc76..81ec139e 100644
--- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py
@@ -1,24 +1,22 @@
 import asyncio
+import datetime
 import json
 import logging
 import os
 from dataclasses import dataclass
 
 import aiofiles
+import litellm
 from openai import AsyncOpenAI
 from openai.types import Batch
 from tqdm import tqdm
-import datetime
+
 from bespokelabs.curator.dataset import Dataset
 from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
 from bespokelabs.curator.request_processor.base_request_processor import (
-    BaseRequestProcessor,
-    GenericRequest,
-    GenericResponse,
-    parse_response_message,
-)
+    BaseRequestProcessor, GenericRequest, GenericResponse,
+    parse_response_message)
 from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
-import litellm
 from bespokelabs.curator.request_processor.generic_response import TokenUsage
 
 logger = logging.getLogger(__name__)
@@ -91,17 +89,13 @@ def get_rate_limits(self) -> dict:
         else:
             tpd = model_tpd[self.model]
 
-        logger.info(
-            f"Automatically set max_tokens_per_day to {tpd}, model: {self.model} "
-        )
+        logger.info(f"Automatically set max_tokens_per_day to {tpd}, model: {self.model} ")
 
         rate_limits = {"max_tokens_per_day": tpd}
 
         return rate_limits
 
-    def create_api_specific_request(
-        self, generic_request: GenericRequest
-    ) -> dict:
+    def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
         """
         Creates an API-specific request body from a generic request body.
 
@@ -188,9 +182,7 @@ async def asubmit_batch(self, batch_file: str) -> dict:
             )
 
         # this lets you upload a file that is larger than 200MB and won't error, so we catch it above
-        batch_file_upload = await async_client.files.create(
-            file=file_content, purpose="batch"
-        )
+        batch_file_upload = await async_client.files.create(file=file_content, purpose="batch")
 
         logger.info(f"File uploaded: {batch_file_upload}")
 
@@ -202,9 +194,7 @@ async def asubmit_batch(self, batch_file: str) -> dict:
                 "request_file_name": batch_file
             },  # for downloading the batch to similarly named responses file
         )
-        logger.info(
-            f"Batch request submitted, received batch object: {batch_object}"
-        )
+        logger.info(f"Batch request submitted, received batch object: {batch_object}")
         # Explicitly close the client. Otherwise we get something like
         # future: <Task finished name='Task-46' coro=<AsyncClient.aclose() done ... >>
         await async_client.close()
@@ -230,9 +220,7 @@ def run(
         Returns:
             Dataset: Completed dataset
         """
-        requests_files = self.create_request_files(
-            dataset, working_dir, prompt_formatter
-        )
+        requests_files = self.create_request_files(dataset, working_dir, prompt_formatter)
         batch_objects_file = f"{working_dir}/batch_objects.jsonl"
 
         # TODO(Ryan): we should have an easy way to cancel all batches in batch_objects.jsonl if the user realizes they made a mistake
@@ -244,10 +232,7 @@ def run(
             # upload requests files and submit batches
             # asyncio gather preserves order
             async def submit_all_batches():
-                tasks = [
-                    self.asubmit_batch(requests_files[i])
-                    for i in range(len(requests_files))
-                ]
+                tasks = [self.asubmit_batch(requests_files[i]) for i in range(len(requests_files))]
                 return await asyncio.gather(*tasks)
 
             batch_objects = run_in_event_loop(submit_all_batches())
@@ -285,9 +270,7 @@ async def watch_batches():
 
         run_in_event_loop(watch_batches())
 
-        dataset = self.create_dataset_files(
-            working_dir, parse_func_hash, prompt_formatter
-        )
+        dataset = self.create_dataset_files(working_dir, parse_func_hash, prompt_formatter)
 
         return dataset
 
@@ -333,8 +316,7 @@ def __init__(
             self.batch_objects = [json.loads(line) for line in f]
         self.batch_ids = [obj["id"] for obj in self.batch_objects]
         self.batch_id_to_request_file_name = {
-            obj["id"]: obj["metadata"]["request_file_name"]
-            for obj in self.batch_objects
+            obj["id"]: obj["metadata"]["request_file_name"] for obj in self.batch_objects
         }
         self.check_interval = check_interval
         self.working_dir = working_dir
@@ -392,18 +374,14 @@ async def check_batch_status(self, batch_id: str) -> Batch | None:
                 logger.warning(f"Unknown batch status: {batch.status}")
 
         if batch_returned:
-            logger.info(
-                f"Batch {batch.id} returned with status: {batch.status}"
-            )
+            logger.info(f"Batch {batch.id} returned with status: {batch.status}")
             self.tracker.n_returned_batches += 1
             self.tracker.n_completed_returned_requests += n_completed_requests
             self.tracker.n_failed_returned_requests += n_failed_requests
             self.remaining_batch_ids.remove(batch.id)
             return batch
         else:
-            self.tracker.n_completed_in_progress_requests += (
-                n_completed_requests
-            )
+            self.tracker.n_completed_in_progress_requests += n_completed_requests
             self.tracker.n_failed_in_progress_requests += n_failed_requests
             return None
 
@@ -426,8 +404,7 @@ async def watch(self) -> None:
 
             # check batch status also updates the tracker
             status_tasks = [
-                self.check_batch_status(batch_id)
-                for batch_id in self.remaining_batch_ids
+                self.check_batch_status(batch_id) for batch_id in self.remaining_batch_ids
             ]
             batches_to_download = await asyncio.gather(*status_tasks)
             batches_to_download = filter(None, batches_to_download)
@@ -447,10 +424,7 @@ async def watch(self) -> None:
             # Failed downloads return None and print any errors that occurred
             all_response_files.extend(await asyncio.gather(*download_tasks))
 
-            if (
-                self.tracker.n_returned_batches
-                < self.tracker.n_submitted_batches
-            ):
+            if self.tracker.n_returned_batches < self.tracker.n_submitted_batches:
                 logger.debug(
                     f"Batches returned: {self.tracker.n_returned_batches}/{self.tracker.n_submitted_batches} "
                     f"Requests completed: {pbar.n}/{self.tracker.n_submitted_requests}"
@@ -466,9 +440,7 @@ async def watch(self) -> None:
                 "Please check the logs above and https://platform.openai.com/batches for errors."
             )
 
-    async def download_batch_to_generic_responses_file(
-        self, batch: Batch
-    ) -> str | None:
+    async def download_batch_to_generic_responses_file(self, batch: Batch) -> str | None:
         """Download the result of a completed batch to file.
 
         Args:
@@ -481,9 +453,7 @@ async def download_batch_to_generic_responses_file(
             file_content = await self.client.files.content(batch.output_file_id)
         elif batch.status == "failed" and batch.error_file_id:
             file_content = await self.client.files.content(batch.error_file_id)
-            logger.warning(
-                f"Batch {batch.id} failed\n. Errors will be parsed below."
-            )
+            logger.warning(f"Batch {batch.id} failed\n. Errors will be parsed below.")
         elif batch.status == "failed" and not batch.error_file_id:
             errors = "\n".join([str(error) for error in batch.errors.data])
             logger.error(
@@ -514,7 +484,7 @@ async def download_batch_to_generic_responses_file(
                 raw_response = json.loads(raw_response)
                 request_idx = int(raw_response["custom_id"])
                 generic_request = generic_request_map[request_idx]
-                
+
                 # TODO(Ryan): Add more specific error handling
                 if raw_response["response"]["status_code"] != 200:
                     logger.warning(
@@ -531,31 +501,33 @@ async def download_batch_to_generic_responses_file(
                         created_at=request_creation_times[request_idx],
                         finished_at=datetime.datetime.now(),
                         token_usage=None,
-                        response_cost=None
+                        response_cost=None,
                     )
                 else:
                     response_body = raw_response["response"]["body"]
                     choices = response_body["choices"]
                     usage = response_body.get("usage", {})
-                    
+
                     token_usage = TokenUsage(
                         prompt_tokens=usage.get("prompt_tokens", 0),
                         completion_tokens=usage.get("completion_tokens", 0),
-                        total_tokens=usage.get("total_tokens", 0)
+                        total_tokens=usage.get("total_tokens", 0),
                     )
-                    
+
                     # Calculate cost using litellm
                     cost = litellm.completion_cost(
                         model=generic_request.model,
-                        prompt=str(generic_request.messages),  # Convert messages to string for cost calculation
-                        completion=choices[0]["message"]["content"]
+                        prompt=str(
+                            generic_request.messages
+                        ),  # Convert messages to string for cost calculation
+                        completion=choices[0]["message"]["content"],
                     )
 
                     response_message = choices[0]["message"]["content"]
                     response_message, response_errors = parse_response_message(
                         response_message, self.prompt_formatter.response_format
                     )
-                    
+
                     generic_response = GenericResponse(
                         response_message=response_message,
                         response_errors=response_errors,
@@ -565,10 +537,7 @@ async def download_batch_to_generic_responses_file(
                         created_at=request_creation_times[request_idx],
                         finished_at=datetime.datetime.now(),
                         token_usage=token_usage,
-                        response_cost=cost
+                        response_cost=cost,
                     )
-                f.write(
-                    json.dumps(generic_response.model_dump(), default=str)
-                    + "\n"
-                )
+                f.write(json.dumps(generic_response.model_dump(), default=str) + "\n")
         return response_file
diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py
index 4cc7f7e0..f20d3d1f 100644
--- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py
@@ -1,16 +1,17 @@
 import asyncio
+import datetime
 import json
 import logging
 import os
 import re
+import resource
 import time
 from dataclasses import dataclass, field
 from functools import partial
 from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar
-import resource
-import datetime
 
 import aiohttp
+import litellm
 import requests
 import tiktoken
 from tqdm import tqdm
@@ -18,13 +19,9 @@
 from bespokelabs.curator.dataset import Dataset
 from bespokelabs.curator.prompter.prompter import PromptFormatter
 from bespokelabs.curator.request_processor.base_request_processor import (
-    BaseRequestProcessor,
-    GenericRequest,
-    GenericResponse,
-    parse_response_message,
-)
+    BaseRequestProcessor, GenericRequest, GenericResponse,
+    parse_response_message)
 from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
-import litellm
 from bespokelabs.curator.request_processor.generic_response import TokenUsage
 
 T = TypeVar("T")
@@ -77,9 +74,7 @@ def get_rate_limits(self) -> dict:
         tpm = int(response.headers.get("x-ratelimit-limit-tokens", 0))
 
         if not rpm or not tpm:
-            logger.warning(
-                "Failed to get rate limits from OpenAI API, using default values"
-            )
+            logger.warning("Failed to get rate limits from OpenAI API, using default values")
             rpm = 30_000
             tpm = 150_000_000
 
@@ -93,9 +88,7 @@ def get_rate_limits(self) -> dict:
 
         return rate_limits
 
-    def create_api_specific_request(
-        self, generic_request: GenericRequest
-    ) -> dict:
+    def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
         """
         Creates an API-specific request body from a generic request body.
 
@@ -151,21 +144,16 @@ def run(
         Returns:
             Dataset: Completed dataset
         """
-        generic_requests_files = self.create_request_files(
-            dataset, working_dir, prompt_formatter
-        )
+        generic_requests_files = self.create_request_files(dataset, working_dir, prompt_formatter)
         generic_responses_files = [
-            f"{working_dir}/responses_{i}.jsonl"
-            for i in range(len(generic_requests_files))
+            f"{working_dir}/responses_{i}.jsonl" for i in range(len(generic_requests_files))
         ]
 
         rate_limits = self.get_rate_limits()
         rpm = rate_limits["max_requests_per_minute"]
         tpm = rate_limits["max_tokens_per_minute"]
 
-        token_encoding_name = get_token_encoding_name(
-            prompt_formatter.model_name
-        )
+        token_encoding_name = get_token_encoding_name(prompt_formatter.model_name)
 
         # NOTE(Ryan): If you wanted to do this on batches, you could run a for loop here over request_files. Although I don't recommend it because you are waiting for straggler requests to finish for each batch.
         # NOTE(Ryan): And if you wanted to do batches in parallel, you would have to divide rpm and tpm by the number of parallel batches.
@@ -186,9 +174,7 @@ def run(
                 )
             )
 
-        dataset = self.create_dataset_files(
-            working_dir, parse_func_hash, prompt_formatter
-        )
+        dataset = self.create_dataset_files(working_dir, parse_func_hash, prompt_formatter)
         return dataset
 
     async def process_generic_requests_from_file(
@@ -227,12 +213,8 @@ async def process_generic_requests_from_file(
 
         # initialize trackers
         queue_of_requests_to_retry = asyncio.Queue()
-        task_id_generator = (
-            task_id_generator_function()
-        )  # generates integer IDs of 0, 1, 2, ...
-        status_tracker = (
-            StatusTracker()
-        )  # single instance to track a collection of variables
+        task_id_generator = task_id_generator_function()  # generates integer IDs of 0, 1, 2, ...
+        status_tracker = StatusTracker()  # single instance to track a collection of variables
         next_request = None  # variable to hold the next request to call
 
         # initialize available capacity counts
@@ -248,9 +230,7 @@ async def process_generic_requests_from_file(
         if os.path.exists(save_filepath):
             if resume:
                 # save all successfully completed requests to a temporary file, then overwrite the original file with the temporary file
-                logger.debug(
-                    f"Resuming progress from existing file: {save_filepath}"
-                )
+                logger.debug(f"Resuming progress from existing file: {save_filepath}")
                 logger.debug(
                     f"Removing all failed requests from {save_filepath} so they can be retried"
                 )
@@ -268,16 +248,12 @@ async def process_generic_requests_from_file(
                             )
                             num_previously_failed_requests += 1
                         else:
-                            completed_request_ids.add(
-                                response.generic_request.original_row_idx
-                            )
+                            completed_request_ids.add(response.generic_request.original_row_idx)
                             output_file.write(line)
                 logger.info(
                     f"Found {len(completed_request_ids)} completed requests and {num_previously_failed_requests} previously failed requests"
                 )
-                logger.info(
-                    "Failed requests and remaining requests will now be processed."
-                )
+                logger.info("Failed requests and remaining requests will now be processed.")
                 os.replace(temp_filepath, save_filepath)
             elif resume_no_retry:
                 logger.warning(
@@ -287,9 +263,7 @@ async def process_generic_requests_from_file(
                 with open(save_filepath, "r") as input_file, open(
                     temp_filepath, "w"
                 ) as output_file:
-                    for line in tqdm(
-                        input_file, desc="Processing existing requests"
-                    ):
+                    for line in tqdm(input_file, desc="Processing existing requests"):
                         data = json.loads(line)
                         if isinstance(data[1], list):
                             # this means that the request failed and we have a list of errors
@@ -319,9 +293,7 @@ async def process_generic_requests_from_file(
             # Count total number of requests
             total_requests = sum(1 for _ in open(generic_requests_filepath))
             if total_requests == len(completed_request_ids):
-                logger.debug(
-                    "All requests have already been completed so will just reuse cache."
-                )
+                logger.debug("All requests have already been completed so will just reuse cache.")
                 return
 
             # Create progress bar
@@ -338,41 +310,28 @@ async def process_generic_requests_from_file(
                     # get next request (if one is not already waiting for capacity)
                     if next_request is None:
                         if not queue_of_requests_to_retry.empty():
-                            next_request = (
-                                queue_of_requests_to_retry.get_nowait()
-                            )
-                            logger.debug(
-                                f"Retrying request {next_request.task_id}: {next_request}"
-                            )
+                            next_request = queue_of_requests_to_retry.get_nowait()
+                            logger.debug(f"Retrying request {next_request.task_id}: {next_request}")
                         elif file_not_finished:
                             try:
                                 # get new generic request
-                                generic_request_json = json.loads(
-                                    next(generic_requests)
-                                )
+                                generic_request_json = json.loads(next(generic_requests))
                                 generic_request = GenericRequest.model_validate(
                                     generic_request_json
                                 )
                                 request_idx = generic_request.original_row_idx
 
                                 # Skip requests we already have responses for
-                                if (
-                                    resume
-                                    and request_idx in completed_request_ids
-                                ):
+                                if resume and request_idx in completed_request_ids:
                                     logger.debug(
                                         f"Skipping already completed request {request_idx}"
                                     )
-                                    status_tracker.num_tasks_already_completed += (
-                                        1
-                                    )
+                                    status_tracker.num_tasks_already_completed += 1
                                     continue
 
                                 # Create API-specific request
-                                api_specific_request_json = (
-                                    self.create_api_specific_request(
-                                        generic_request
-                                    )
+                                api_specific_request_json = self.create_api_specific_request(
+                                    generic_request
                                 )
                                 next_request = APIRequest(
                                     task_id=next(task_id_generator),
@@ -457,16 +416,11 @@ async def process_generic_requests_from_file(
 
                     # if a rate limit error was hit recently, pause to cool down
                     seconds_since_rate_limit_error = (
-                        time.time()
-                        - status_tracker.time_of_last_rate_limit_error
+                        time.time() - status_tracker.time_of_last_rate_limit_error
                     )
-                    if (
-                        seconds_since_rate_limit_error
-                        < seconds_to_pause_after_rate_limit_error
-                    ):
+                    if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
                         remaining_seconds_to_pause = (
-                            seconds_to_pause_after_rate_limit_error
-                            - seconds_since_rate_limit_error
+                            seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
                         )
                         await asyncio.sleep(remaining_seconds_to_pause)
                         # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
@@ -478,9 +432,7 @@ async def process_generic_requests_from_file(
             pbar.close()
 
             # after finishing, log final status
-            logger.info(
-                f"""Parallel processing complete. Results saved to {save_filepath}"""
-            )
+            logger.info(f"""Parallel processing complete. Results saved to {save_filepath}""")
 
             logger.info(f"Status tracker: {status_tracker}")
 
@@ -506,9 +458,7 @@ class StatusTracker:
     num_rate_limit_errors: int = 0
     num_api_errors: int = 0  # excluding rate limit errors, counted above
     num_other_errors: int = 0
-    time_of_last_rate_limit_error: int = (
-        0  # used to cool off after hitting rate limits
-    )
+    time_of_last_rate_limit_error: int = 0  # used to cool off after hitting rate limits
 
 
 @dataclass
@@ -543,17 +493,13 @@ async def call_api(
             ) as response:
                 response = await response.json()
             if "error" in response:
-                logger.warning(
-                    f"Request {self.task_id} failed with error {response['error']}"
-                )
+                logger.warning(f"Request {self.task_id} failed with error {response['error']}")
                 status_tracker.num_api_errors += 1
                 error = response
                 if "rate limit" in response["error"].get("message", "").lower():
                     status_tracker.time_of_last_rate_limit_error = time.time()
                     status_tracker.num_rate_limit_errors += 1
-                    status_tracker.num_api_errors -= (
-                        1  # rate limit errors are counted separately
-                    )
+                    status_tracker.num_api_errors -= 1  # rate limit errors are counted separately
 
         except (
             Exception
@@ -575,7 +521,7 @@ async def call_api(
                     raw_response=None,
                     generic_request=self.generic_request,
                     created_at=self.created_at,
-                    finished_at=datetime.datetime.now()
+                    finished_at=datetime.datetime.now(),
                 )
                 append_generic_response(generic_response, save_filepath)
                 status_tracker.num_tasks_in_progress -= 1
@@ -593,13 +539,11 @@ async def call_api(
             token_usage = TokenUsage(
                 prompt_tokens=usage.get("prompt_tokens", 0),
                 completion_tokens=usage.get("completion_tokens", 0),
-                total_tokens=usage.get("total_tokens", 0)
+                total_tokens=usage.get("total_tokens", 0),
             )
-            
+
             # Calculate cost using litellm
-            cost = litellm.completion_cost(
-                completion_response=response
-            )
+            cost = litellm.completion_cost(completion_response=response)
 
             generic_response = GenericResponse(
                 response_message=response_message,
@@ -610,7 +554,7 @@ async def call_api(
                 created_at=self.created_at,
                 finished_at=datetime.datetime.now(),
                 token_usage=token_usage,
-                response_cost=cost
+                response_cost=cost,
             )
             append_generic_response(generic_response, save_filepath)
             status_tracker.num_tasks_in_progress -= 1
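
The usage-and-cost hunk above pairs tolerant .get() defaults with litellm's pricing helper. A hedged sketch of that accounting step, where response_body stands in for the parsed JSON body of a chat completion (the function name is illustrative, not from the patch):

    import litellm

    def summarize_usage_and_cost(response_body: dict) -> tuple[int, int, int, float]:
        # Missing usage fields default to 0 instead of raising KeyError.
        usage = response_body.get("usage", {})
        prompt = usage.get("prompt_tokens", 0)
        completion = usage.get("completion_tokens", 0)
        total = usage.get("total_tokens", 0)
        # litellm derives the dollar cost from the response's model and usage fields.
        cost = litellm.completion_cost(completion_response=response_body)
        return prompt, completion, total, cost
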
@@ -629,9 +573,7 @@ def get_token_encoding_name(model: str) -> str:
         return "cl100k_base"
 
 
-def get_rate_limits(
-    model: str, request_url: str, api_key: str
-) -> Tuple[int, int]:
+def get_rate_limits(model: str, request_url: str, api_key: str) -> Tuple[int, int]:
     """
     Function to get rate limits for a given annotator. Makes a single request to the OpenAI API
     and gets the rate limits from the response headers. These rate limits vary per model
@@ -654,20 +596,14 @@ def get_rate_limits(
             json={"model": model, "messages": []},
         )
         # Extract rate limit information from headers
-        max_requests = int(
-            response.headers.get("x-ratelimit-limit-requests", 30_000)
-        )
-        max_tokens = int(
-            response.headers.get("x-ratelimit-limit-tokens", 150_000_000)
-        )
+        max_requests = int(response.headers.get("x-ratelimit-limit-requests", 30_000))
+        max_tokens = int(response.headers.get("x-ratelimit-limit-tokens", 150_000_000))
     elif "api.sambanova.ai" in request_url:
         # No header probe implemented for this endpoint; use fixed defaults
         max_requests = 50
         max_tokens = 100_000_000
     else:
-        raise NotImplementedError(
-            f'Rate limits for API endpoint "{request_url}" not implemented'
-        )
+        raise NotImplementedError(f'Rate limits for API endpoint "{request_url}" not implemented')
 
     return max_requests, max_tokens
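
The probe above works because the API reports x-ratelimit-* headers even when the request body is rejected. A minimal sketch of that discovery step, mirroring the defaults in this hunk (treat the function itself as illustrative, not the module's actual API):

    import requests

    def probe_openai_rate_limits(model: str, api_key: str) -> tuple[int, int]:
        # A deliberately empty chat request still returns rate limit headers.
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={"model": model, "messages": []},
        )
        max_requests = int(response.headers.get("x-ratelimit-limit-requests", 30_000))
        max_tokens = int(response.headers.get("x-ratelimit-limit-tokens", 150_000_000))
        return max_requests, max_tokens
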
 
@@ -695,9 +631,7 @@ def api_endpoint_from_url(request_url: str) -> str:
         return match[1]
 
     # for Azure OpenAI deployment urls
-    match = re.search(
-        r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url
-    )
+    match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
     if match:
         return match[1]
 
@@ -707,9 +641,7 @@ def api_endpoint_from_url(request_url: str) -> str:
     elif "completions" in request_url:
         return "completions"
     else:
-        raise NotImplementedError(
-            f'API endpoint "{request_url}" not implemented in this script'
-        )
+        raise NotImplementedError(f'API endpoint "{request_url}" not implemented in this script')
 
 
 def append_generic_response(data: GenericResponse, filename: str) -> None:
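
Of the endpoint regexes collapsed above, the Azure one is the least obvious, so here is a hedged standalone check of what it captures (the helper name is illustrative):

    import re

    def endpoint_from_azure_url(request_url: str) -> str | None:
        # Capture everything between the deployment name and an optional query string,
        # e.g. ".../deployments/my-gpt/chat/completions?api-version=1" -> "chat/completions".
        match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
        return match[1] if match else None

    assert (
        endpoint_from_azure_url(
            "https://example.openai.azure.com/openai/deployments/my-gpt/chat/completions?api-version=1"
        )
        == "chat/completions"
    )
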
@@ -746,9 +678,7 @@ def num_tokens_consumed_from_request(
                         )
                         num_tokens += len(str(value)) // 4
                     if key == "name":  # if there's a name, the role is omitted
-                        num_tokens -= (
-                            1  # role is always required and always 1 token
-                        )
+                        num_tokens -= 1  # role is always required and always 1 token
             num_tokens += 2  # every reply is primed with <im_start>assistant
             return num_tokens + completion_tokens
         # normal completions
@@ -781,9 +711,7 @@ def num_tokens_consumed_from_request(
             )
     # more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
     else:
-        raise NotImplementedError(
-            f'API endpoint "{api_endpoint}" not implemented in this script'
-        )
+        raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script')
 
 
 def task_id_generator_function():
diff --git a/src/bespokelabs/curator/viewer/__main__.py b/src/bespokelabs/curator/viewer/__main__.py
index e57c63bd..062454a2 100644
--- a/src/bespokelabs/curator/viewer/__main__.py
+++ b/src/bespokelabs/curator/viewer/__main__.py
@@ -1,16 +1,16 @@
+import logging
 import os
+import platform
+import shutil
+import socket
 import subprocess
 import sys
-from pathlib import Path
-from argparse import ArgumentParser
+import tempfile
+import time
 import webbrowser
+from argparse import ArgumentParser
 from contextlib import closing
-import socket
-import logging
-import time
-import platform
-import tempfile
-import shutil
+from pathlib import Path
 
 
 def get_viewer_path():
@@ -32,9 +32,7 @@ def ensure_dependencies():
             print(f"Error installing dependencies: {e}")
             sys.exit(1)
         except FileNotFoundError:
-            print(
-                "Error: Node.js is not installed. Please install Node.js to run the viewer."
-            )
+            print("Error: Node.js is not installed. Please install Node.js to run the viewer.")
             sys.exit(1)
 
 
@@ -49,9 +47,7 @@ def _setup_logging(level):
 def check_node_installed():
     """Check if Node.js is installed and return version if found"""
     try:
-        result = subprocess.run(
-            ["node", "--version"], capture_output=True, text=True, check=True
-        )
+        result = subprocess.run(["node", "--version"], capture_output=True, text=True, check=True)
         return result.stdout.strip()
     except (subprocess.CalledProcessError, FileNotFoundError):
         return None
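
The collapsed subprocess.run call above is a reusable probe-for-a-tool idiom. A hedged generalization (the function name is illustrative):

    import subprocess

    def tool_version(executable: str) -> str | None:
        # Return the tool's --version output, or None if it is missing or exits nonzero.
        try:
            result = subprocess.run([executable, "--version"], capture_output=True, text=True, check=True)
            return result.stdout.strip()
        except (subprocess.CalledProcessError, FileNotFoundError):
            return None
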
@@ -105,22 +101,16 @@ def main():
     server_file = os.path.join(viewer_path, "server.js")
 
     if not os.path.exists(os.path.join(static_dir, ".next")):
-        print(
-            "Error: Next.js build artifacts not found. The package may not be built correctly."
-        )
+        print("Error: Next.js build artifacts not found. The package may not be built correctly.")
         sys.exit(1)
 
     try:
-        subprocess.run(
-            ["node", server_file], cwd=viewer_path, env=env, check=True
-        )
+        subprocess.run(["node", server_file], cwd=viewer_path, env=env, check=True)
     except subprocess.CalledProcessError as e:
         print(f"Error starting Next.js server: {e}")
         sys.exit(1)
     except FileNotFoundError:
-        print(
-            "Error: Node.js is not installed. Please install Node.js to run the viewer."
-        )
+        print("Error: Node.js is not installed. Please install Node.js to run the viewer.")
         sys.exit(1)
 
 
diff --git a/tests/test_install_ui.py b/tests/test_install_ui.py
index b78c5d6d..2ef39b29 100644
--- a/tests/test_install_ui.py
+++ b/tests/test_install_ui.py
@@ -1,17 +1,19 @@
 """Test script for installation UI."""
-import os
-import sys
+
 import argparse
 import importlib.util
+import os
+import sys
+
 
 def import_install_ui():
     """Import just the install_ui module without importing the whole package."""
     # Get the absolute path to install_ui.py
     install_ui_path = os.path.join(
         os.path.dirname(os.path.dirname(__file__)),  # Go up one level since we're in tests/
-        "src/bespokelabs/curator/install_ui.py"
+        "src/bespokelabs/curator/install_ui.py",
     )
-    
+
     # Import the module directly from file
     spec = importlib.util.spec_from_file_location("install_ui", install_ui_path)
     module = importlib.util.module_from_spec(spec)
@@ -19,25 +21,27 @@ def import_install_ui():
     spec.loader.exec_module(module)
     return module
 
+
 def main():
     """Run the test script with command line arguments."""
-    parser = argparse.ArgumentParser(description='Test the installation UI.')
+    parser = argparse.ArgumentParser(description="Test the installation UI.")
     parser.add_argument(
-        '--scenario', 
-        choices=['success', 'error'],
-        default='success',
-        help='Which scenario to test (success or error)'
+        "--scenario",
+        choices=["success", "error"],
+        default="success",
+        help="Which scenario to test (success or error)",
     )
     args = parser.parse_args()
-    
+
     # Import just the install_ui module
     install_ui = import_install_ui()
-    
+
     # Run the enhanced install based on scenario
-    if args.scenario == 'success':
+    if args.scenario == "success":
         install_ui.enhanced_install("bespokelabs-curator")
     else:
         install_ui.enhanced_install("nonexistent-package-12345")
 
+
 if __name__ == "__main__":
     main()
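
The import_install_ui helper reformatted in this last file follows the stdlib recipe for importing a single file without importing its parent package. As a hedged standalone sketch (the wrapper name is illustrative):

    import importlib.util

    def load_module_from_path(name: str, path: str):
        # Build a spec from the file, materialize a module from it, then execute it.
        spec = importlib.util.spec_from_file_location(name, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
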