Skip to content

Commit 94111a4

Browse files
committed
add nested calling test
1 parent 82ced1a commit 94111a4

File tree

2 files changed

+39
-8
lines changed

2 files changed

+39
-8
lines changed

src/bespokelabs/curator/prompter/prompt_formatter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pydantic import BaseModel
55

66
from bespokelabs.curator.request_processor.generic_request import GenericRequest
7+
78
T = TypeVar("T")
89

910

tests/test_caching.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,21 @@ def test_same_value_caching(tmp_path):
99

1010
# Test with same value multiple times
1111
for _ in range(3):
12+
1213
def prompt_func():
1314
return f"Say '1'. Do not explain."
14-
15+
1516
prompter = Prompter(
1617
prompt_func=prompt_func,
1718
model_name="gpt-4o-mini",
1819
)
1920
result = prompter(working_dir=str(tmp_path))
2021
values.append(result.to_pandas().iloc[0]["response"])
21-
22+
2223
# Count cache directories, excluding metadata.db
2324
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
2425
assert len(cache_dirs) == 1, f"Expected 1 cache directory but found {len(cache_dirs)}"
25-
assert values == ['1', '1', '1'], "Same value should produce same results"
26+
assert values == ["1", "1", "1"], "Same value should produce same results"
2627

2728

2829
def test_different_values_caching(tmp_path):
@@ -31,9 +32,10 @@ def test_different_values_caching(tmp_path):
3132

3233
# Test with different values
3334
for x in [1, 2, 3]:
35+
3436
def prompt_func():
3537
return f"Say '{x}'. Do not explain."
36-
38+
3739
prompter = Prompter(
3840
prompt_func=prompt_func,
3941
model_name="gpt-4o-mini",
@@ -44,7 +46,8 @@ def prompt_func():
4446
# Count cache directories, excluding metadata.db
4547
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
4648
assert len(cache_dirs) == 3, f"Expected 3 cache directories but found {len(cache_dirs)}"
47-
assert values == ['1', '2', '3'], "Different values should produce different results"
49+
assert values == ["1", "2", "3"], "Different values should produce different results"
50+
4851

4952
def test_same_dataset_caching(tmp_path):
5053
"""Test that using the same dataset multiple times uses cache."""
@@ -53,7 +56,7 @@ def test_same_dataset_caching(tmp_path):
5356
prompt_func=lambda x: x["instruction"],
5457
model_name="gpt-4o-mini",
5558
)
56-
59+
5760
result = prompter(dataset=dataset, working_dir=str(tmp_path))
5861
assert result.to_pandas().iloc[0]["response"] == "1"
5962

@@ -62,7 +65,7 @@ def test_same_dataset_caching(tmp_path):
6265

6366
# Count cache directories, excluding metadata.db
6467
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
65-
assert len(cache_dirs) == 1, f"Expected 1 cache directory but found {len(cache_dirs)}"
68+
assert len(cache_dirs) == 1, f"Expected 1 cache directory but found {len(cache_dirs)}"
6669

6770

6871
def test_different_dataset_caching(tmp_path):
@@ -82,4 +85,31 @@ def test_different_dataset_caching(tmp_path):
8285

8386
# Count cache directories, excluding metadata.db
8487
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
85-
assert len(cache_dirs) == 2, f"Expected 2 cache directory but found {len(cache_dirs)}"
88+
assert len(cache_dirs) == 2, f"Expected 2 cache directory but found {len(cache_dirs)}"
89+
90+
91+
def test_nested_call_caching(tmp_path):
92+
"""Test that changing a nested upstream function invalidates the cache."""
93+
94+
def value_generator():
95+
return 1
96+
97+
def prompt_func():
98+
return f"Say '{value_generator()}'. Do not explain."
99+
100+
prompter = Prompter(
101+
prompt_func=prompt_func,
102+
model_name="gpt-4o-mini",
103+
)
104+
result = prompter(working_dir=str(tmp_path))
105+
assert result.to_pandas().iloc[0]["response"] == "1"
106+
107+
def value_generator():
108+
return 2
109+
110+
result = prompter(working_dir=str(tmp_path))
111+
assert result.to_pandas().iloc[0]["response"] == "2"
112+
113+
# Count cache directories, excluding metadata.db
114+
cache_dirs = [d for d in tmp_path.glob("*") if d.name != "metadata.db"]
115+
assert len(cache_dirs) == 2, f"Expected 2 cache directory but found {len(cache_dirs)}"

0 commit comments

Comments
 (0)