
Commit 5013908
Merge pull request #271 from bespokelabsai/ryanm/update-tests
Update pytests
RyanMarten authored Dec 17, 2024
2 parents 125a83f + 38649ca commit 5013908
Showing 12 changed files with 80 additions and 88 deletions.
2 changes: 1 addition & 1 deletion build_pkg.py
@@ -81,7 +81,7 @@ def nextjs_build():
 def run_pytest():
     print("Running pytest")
     try:
-        run_command("pytest", cwd="tests")
+        run_command("pytest")
     except subprocess.CalledProcessError:
         print("Pytest failed. Aborting build.")
         sys.exit(1)
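Note: dropping cwd="tests" means pytest now runs from the repository root, which presumably keeps the repo-relative script paths used in tests/cache/test_different_files.py (e.g. "tests/cache/different_files/one.py") valid and picks up the new tests/conftest.py.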
3 changes: 3 additions & 0 deletions src/bespokelabs/curator/llm/prompt_formatter.py
@@ -110,6 +110,9 @@ def response_to_response_format(self, response_message: str | dict) -> Optional[
"""
# Response message is a string, which is converted to a dict
# The dict is then used to construct the response_format Pydantic model
if self.response_format is None:
return response_message

try:
# First try to parse the response message as JSON
if isinstance(response_message, str):
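For context, a standalone sketch of the pass-through behavior the early return added above gives callers; this is an assumption-level illustration that handles only the string case, not the library's own code:

import json
from typing import Optional, Type

from pydantic import BaseModel


def to_response_format(message: str, response_format: Optional[Type[BaseModel]]):
    # No structured output requested: return the raw message untouched (the new early return).
    if response_format is None:
        return message
    # Otherwise parse the JSON string and validate it into the requested Pydantic model.
    return response_format(**json.loads(message))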
5 changes: 5 additions & 0 deletions
@@ -204,6 +204,11 @@ def run(
         parse_func_hash: str,
         prompt_formatter: PromptFormatter,
     ) -> Dataset:
+        # load from already completed dataset
+        output_dataset = self.attempt_loading_cached_dataset(working_dir, parse_func_hash)
+        if output_dataset is not None:
+            return output_dataset
+
         """Run completions using the online API with async processing."""
         logger.info(f"Running {self.__class__.__name__} completions with model: {self.model}")
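The added block short-circuits run(): if a dataset keyed by parse_func_hash has already been written under working_dir, it is returned before any new API requests are made.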
23 changes: 11 additions & 12 deletions src/bespokelabs/curator/request_processor/base_request_processor.py
@@ -315,19 +315,18 @@ def create_dataset_files(
                     failed_responses_count += 1
                     continue

-                if prompt_formatter.response_format:
-                    try:
-                        response.response_message = (
-                            self.prompt_formatter.response_to_response_format(
-                                response.response_message
-                            )
+                try:
+                    response.response_message = (
+                        self.prompt_formatter.response_to_response_format(
+                            response.response_message
                         )
-                    except (json.JSONDecodeError, ValidationError) as e:
-                        logger.warning(
-                            "Skipping response due to error parsing response message into response format"
-                        )
-                        failed_responses_count += 1
-                        continue
+                    )
+                except (json.JSONDecodeError, ValidationError) as e:
+                    logger.warning(
+                        "Skipping response due to error parsing response message into response format"
+                    )
+                    failed_responses_count += 1
+                    continue

                 # parse_func can return a single row or a list of rows
                 if prompt_formatter.parse_func:
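With the None check now living inside response_to_response_format (see the prompt_formatter.py hunk above), the prompt_formatter.response_format guard here is redundant, which is presumably why the try/except becomes unconditional.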
1 change: 1 addition & 0 deletions tests/batch/test_resume.py
@@ -10,6 +10,7 @@
"""


@pytest.mark.skip(reason="Temporarily disabled, need to add mocking")
@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-resume"))
@pytest.mark.usefixtures("prepare_test_cache")
def test_batch_resume():
1 change: 1 addition & 0 deletions tests/batch/test_switch_keys.py
@@ -10,6 +10,7 @@
"""


@pytest.mark.skip(reason="Temporarily disabled, need to add mocking")
@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-switch-keys"))
@pytest.mark.usefixtures("prepare_test_cache")
def test_batch_switch_keys():
30 changes: 8 additions & 22 deletions tests/cache/different_files/one.py
@@ -1,32 +1,18 @@
 from bespokelabs.curator import LLM
 from datasets import Dataset
 import logging
-import argparse

 logger = logging.getLogger("bespokelabs.curator")
 logger.setLevel(logging.INFO)


-def main(delete_cache: bool = False):
-    dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})
+dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})

-    prompter = LLM(
-        prompt_func=lambda row: row["prompt"],
-        model_name="gpt-4o-mini",
-        response_format=None,
-        delete_cache=delete_cache,
-    )
+prompter = LLM(
+    prompt_func=lambda row: row["prompt"],
+    model_name="gpt-4o-mini",
+    response_format=None,
+)

-    dataset = prompter(dataset)
-    print(dataset.to_pandas())
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run prompter with cache control")
-    parser.add_argument(
-        "--delete-cache",
-        action="store_true",
-        help="Delete the cache before running",
-    )
-    args = parser.parse_args()
-    main(delete_cache=args.delete_cache)
+dataset = prompter(dataset)
+print(dataset.to_pandas())
30 changes: 8 additions & 22 deletions tests/cache/different_files/two.py
@@ -1,32 +1,18 @@
 from bespokelabs.curator import LLM
 from datasets import Dataset
 import logging
-import argparse

 logger = logging.getLogger("bespokelabs.curator")
 logger.setLevel(logging.INFO)


-def main(delete_cache: bool = False):
-    dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})
+dataset = Dataset.from_dict({"prompt": ["just say 'hi'"] * 3})

-    prompter = LLM(
-        prompt_func=lambda row: row["prompt"],
-        model_name="gpt-4o-mini",
-        response_format=None,
-        delete_cache=delete_cache,
-    )
+prompter = LLM(
+    prompt_func=lambda row: row["prompt"],
+    model_name="gpt-4o-mini",
+    response_format=None,
+)

-    dataset = prompter(dataset)
-    print(dataset.to_pandas())
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run prompter with cache control")
-    parser.add_argument(
-        "--delete-cache",
-        action="store_true",
-        help="Delete the cache before running",
-    )
-    args = parser.parse_args()
-    main(delete_cache=args.delete_cache)
+dataset = prompter(dataset)
+print(dataset.to_pandas())
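After this change one.py and two.py are identical module-level scripts, which is what lets the first run of two.py reuse the cache populated by one.py in tests/cache/test_different_files.py below.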
9 changes: 3 additions & 6 deletions tests/cache/test_different_files.py
@@ -16,17 +16,14 @@ def test_cache_behavior():

     # Run one.py twice and check for cache behavior
     print("RUNNING ONE.PY")
-    output1, _ = run_script(["python", "tests/cache_tests/different_files/one.py"])
-    print(output1)
+    output1, _ = run_script(["python", "tests/cache/different_files/one.py"])
     assert cache_hit_log not in output1, "First run of one.py should not hit cache"

     print("RUNNING ONE.PY AGAIN")
-    output2, _ = run_script(["python", "tests/cache_tests/different_files/one.py"])
-    print(output2)
+    output2, _ = run_script(["python", "tests/cache/different_files/one.py"])
     assert cache_hit_log in output2, "Second run of one.py should hit cache"

     # Run two.py and check for cache behavior
     print("RUNNING TWO.PY")
-    output3, _ = run_script(["python", "tests/cache_tests/different_files/two.py"])
-    print(output3)
+    output3, _ = run_script(["python", "tests/cache/different_files/two.py"])
     assert cache_hit_log in output3, "First run of two.py should hit cache"
5 changes: 5 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,5 @@
+import pytest
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "cache_dir(path): mark test to use specific cache directory")
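For reference, a minimal sketch of how a fixture such as prepare_test_cache (referenced via usefixtures elsewhere in this diff but not shown here) might consume the cache_dir marker; the fixture body and the CURATOR_CACHE_DIR variable name are assumptions, only the marker registration above comes from the commit:

import os
import shutil

import pytest


@pytest.fixture
def prepare_test_cache(request, monkeypatch):
    # Read the path supplied by @pytest.mark.cache_dir(path) on the test.
    marker = request.node.get_closest_marker("cache_dir")
    assert marker is not None, "use @pytest.mark.cache_dir(path) with this fixture"
    cache_dir = marker.args[0]
    shutil.rmtree(cache_dir, ignore_errors=True)  # start each test from an empty cache
    monkeypatch.setenv("CURATOR_CACHE_DIR", cache_dir)  # assumed env var read by curator
    yield cache_dir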
2 changes: 1 addition & 1 deletion tests/test_caching.py
@@ -123,7 +123,7 @@ def test_function_hash_dir_change():
     import tempfile
     from pathlib import Path

-    from bespokelabs.curator.prompter.llm import _get_function_hash
+    from bespokelabs.curator.llm.llm import _get_function_hash

     # Set up logging to write to a file in the current directory
     debug_log = Path("function_debug.log")
57 changes: 33 additions & 24 deletions tests/test_litellm_models.py
@@ -13,31 +13,40 @@

 @pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-models"))
 @pytest.mark.usefixtures("prepare_test_cache")
-def test_litellm_models():
+class TestLiteLLMModels:
+    @pytest.fixture(autouse=True)
+    def check_environment(self):
+        env = os.environ.copy()
+        required_keys = [
+            "ANTHROPIC_API_KEY",
+            "OPENAI_API_KEY",
+            "GEMINI_API_KEY",
+            "TOGETHER_API_KEY",
+        ]
+        for key in required_keys:
+            assert key in env, f"{key} must be set"

-    env = os.environ.copy()
-    assert "ANTHROPIC_API_KEY" in env, "ANTHROPIC_API_KEY must be set"
-    assert "OPENAI_API_KEY" in env, "OPENAI_API_KEY must be set"
-    assert "GEMINI_API_KEY" in env, "GEMINI_API_KEY must be set"
-    assert "TOGETHER_API_KEY" in env, "TOGETHER_API_KEY must be set"
-
-    models_list = [
-        "claude-3-5-sonnet-20240620", # https://docs.litellm.ai/docs/providers/anthropic # anthropic has a different hidden param tokens structure.
-        "claude-3-5-haiku-20241022",
-        "claude-3-haiku-20240307",
-        "claude-3-opus-20240229",
-        "claude-3-sonnet-20240229",
-        "gpt-4o-mini", # https://docs.litellm.ai/docs/providers/openai
-        "gpt-4o-2024-08-06",
-        "gpt-4-0125-preview",
-        "gpt-3.5-turbo-1106",
-        "gemini/gemini-1.5-flash", # https://docs.litellm.ai/docs/providers/gemini; https://ai.google.dev/gemini-api/docs/models # 20-30 iter/s
-        "gemini/gemini-1.5-pro", # 20-30 iter/s
-        "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", # https://docs.together.ai/docs/serverless-models
-        "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-    ]
-
-    for model in models_list:
+    @pytest.mark.parametrize(
+        "model",
+        [
+            pytest.param("claude-3-5-sonnet-20240620", id="claude-3-5-sonnet"),
+            pytest.param("claude-3-5-haiku-20241022", id="claude-3-5-haiku"),
+            pytest.param("claude-3-haiku-20240307", id="claude-3-haiku"),
+            pytest.param("claude-3-opus-20240229", id="claude-3-opus"),
+            pytest.param("claude-3-sonnet-20240229", id="claude-3-sonnet"),
+            pytest.param("gpt-4o-mini", id="gpt-4-mini"),
+            pytest.param("gpt-4o-2024-08-06", id="gpt-4"),
+            pytest.param("gpt-4-0125-preview", id="gpt-4-preview"),
+            pytest.param("gpt-3.5-turbo-1106", id="gpt-3.5"),
+            pytest.param("gemini/gemini-1.5-flash", id="gemini-flash"),
+            pytest.param("gemini/gemini-1.5-pro", id="gemini-pro"),
+            pytest.param("together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", id="llama-8b"),
+            pytest.param(
+                "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", id="llama-70b"
+            ),
+        ],
+    )
+    def test_model(self, model):
         print(f"\n\n========== TESTING {model} ==========\n\n")
         logger = logging.getLogger("bespokelabs.curator")
         logger.setLevel(logging.DEBUG)
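With ids on every pytest.param, a single model can now be selected directly, e.g. pytest tests/test_litellm_models.py -k llama-8b or pytest "tests/test_litellm_models.py::TestLiteLLMModels::test_model[gemini-flash]"; the autouse check_environment fixture still runs before each case.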
