Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
vutrung96 committed Nov 21, 2024
1 parent 34bf3d7 commit 18c6cbc
Showing 1 changed file with 85 additions and 0 deletions.
85 changes: 85 additions & 0 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from datasets import Dataset

from bespokelabs.curator import Prompter


def test_same_value_caching(tmp_path):
"""Test that using the same value multiple times uses cache."""
values = []

# Test with same value multiple times
for _ in range(3):
def prompt_func():
return f"Say '1'. Do not explain."

prompter = Prompter(
prompt_func=prompt_func,
model_name="gpt-4o-mini",
)
result = prompter(working_dir=str(tmp_path))
values.append(result.to_pandas().iloc[0]["response"])

# Count cache directories, excluding metadata.db
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
assert len(cache_dirs) == 1, f"Expected 1 cache directory but found {len(cache_dirs)}"
assert values == ['1', '1', '1'], "Same value should produce same results"


def test_different_values_caching(tmp_path):
"""Test that using different values creates different cache entries."""
values = []

# Test with different values
for x in [1, 2, 3]:
def prompt_func():
return f"Say '{x}'. Do not explain."

prompter = Prompter(
prompt_func=prompt_func,
model_name="gpt-4o-mini",
)
result = prompter(working_dir=str(tmp_path))
values.append(result.to_pandas().iloc[0]["response"])

# Count cache directories, excluding metadata.db
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
assert len(cache_dirs) == 3, f"Expected 3 cache directories but found {len(cache_dirs)}"
assert values == ['1', '2', '3'], "Different values should produce different results"

def test_same_dataset_caching(tmp_path):
"""Test that using the same dataset multiple times uses cache."""
dataset = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}])
prompter = Prompter(
prompt_func=lambda x: x["instruction"],
model_name="gpt-4o-mini",
)

result = prompter(dataset=dataset, working_dir=str(tmp_path))
assert result.to_pandas().iloc[0]["response"] == "1"

result = prompter(dataset=dataset, working_dir=str(tmp_path))
assert result.to_pandas().iloc[0]["response"] == "1"

# Count cache directories, excluding metadata.db
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
assert len(cache_dirs) == 1, f"Expected 1 cache directory but found {len(cache_dirs)}"


def test_different_dataset_caching(tmp_path):
"""Test that using different datasets creates different cache entries."""
dataset1 = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}])
dataset2 = Dataset.from_list([{"instruction": "Say '2'. Do not explain."}])
prompter = Prompter(
prompt_func=lambda x: x["instruction"],
model_name="gpt-4o-mini",
)

result = prompter(dataset=dataset1, working_dir=str(tmp_path))
assert result.to_pandas().iloc[0]["response"] == "1"

result = prompter(dataset=dataset2, working_dir=str(tmp_path))
assert result.to_pandas().iloc[0]["response"] == "2"

# Count cache directories, excluding metadata.db
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
assert len(cache_dirs) == 2, f"Expected 2 cache directory but found {len(cache_dirs)}"

0 comments on commit 18c6cbc

Please sign in to comment.